Spaces:
Sleeping
Sleeping
import os | |
from typing import List | |
from langchain_community.document_loaders import UnstructuredPDFLoader | |
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings | |
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint | |
from langchain.prompts import ChatPromptTemplate | |
from langchain.schema.output_parser import StrOutputParser | |
from langchain.schema.runnable import RunnablePassthrough | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_community.vectorstores.chroma import Chroma | |
from langchain_core.runnables.base import RunnableSequence | |
from langchain_core.vectorstores import VectorStoreRetriever | |
from dotenv import load_dotenv | |
load_dotenv() | |
HF_API_KEY = os.environ["HF_API_KEY"] | |
class MistralOutputParser(StrOutputParser):
    """Output parser that keeps only the answer part of a Mistral completion."""

    def parse(self, text: str) -> str:
        """
        Extract the text that follows the last "[/INST]" marker.

        When the marker does not occur, the whole text is kept. In both
        cases leading and trailing whitespace is stripped.

        Args:
            text (str): raw text produced by the LLM

        Returns:
            str: text after the final "[/INST]" marker, stripped
        """
        # rpartition splits on the LAST occurrence; the tail is the answer.
        _, _, answer = text.rpartition("[/INST]")
        return answer.strip()
def load_pdf(
    document_path: str,
    mode: str = "single",
    strategy: str = "fast",
    chunk_size: int = 500,
    chunk_overlap: int = 0,
) -> List[str]:
    """
    Read one pdf with UnstructuredPDFLoader and cut its content into chunks.

    Args:
        document_path (str): path to the pdf document
        mode (str, optional): forwarded to the loader as `mode`.
            Defaults to "single".
        strategy (str, optional): forwarded to the loader as `strategy`.
            Defaults to "fast".
        chunk_size (int, optional): `chunk_size` given to the text splitter.
            Defaults to 500.
        chunk_overlap (int, optional): `chunk_overlap` given to the text
            splitter. Defaults to 0.

    Returns:
        List[str]: result of `RecursiveCharacterTextSplitter.split_documents`
            on the loaded document.
            NOTE(review): these are presumably LangChain Document objects
            rather than plain strings -- confirm and adjust the annotation.
    """
    # Step 1: load the pdf.
    pages = UnstructuredPDFLoader(
        document_path, mode=mode, strategy=strategy
    ).load()

    # Step 2: split what was loaded into chunks.
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_documents(pages)
def store_vector(all_splits: List[str]) -> VectorStoreRetriever:
    """
    Embed every chunk and index the embeddings in a Chroma vector store.

    Args:
        all_splits (List[str]): chunks to index, as produced by `load_pdf`
            or `load_multiple_pdf`; passed to `Chroma.from_documents`

    Returns:
        VectorStoreRetriever: retriever over the newly created Chroma store
    """
    # Hosted Inference API endpoint of the embedding model.
    embeddings_model_url = (
        "https://api-inference.huggingface.co/models/Salesforce/SFR-Embedding-Mistral"
    )
    # HuggingFaceInferenceAPIEmbeddings exposes its custom endpoint as
    # `api_url`; it has no `endpoint_url` field, so passing the URL under
    # that name never selected this model.
    embeddings = HuggingFaceInferenceAPIEmbeddings(
        api_key=HF_API_KEY,
        api_url=embeddings_model_url,
    )
    # Store the embeddings of each chunk of text into ChromaDB.
    vector_store = Chroma.from_documents(all_splits, embeddings)
    return vector_store.as_retriever()
def generate_mistral_rag_prompt() -> ChatPromptTemplate:
    """
    Build the prompt template used for Mistral with RAG.

    The template exposes two placeholders, `context` and `prompt`, placed
    between "[INST]" and "[/INST]" after a leading "<s>".

    Returns:
        ChatPromptTemplate: prompt template for the Mistral API
    """
    return ChatPromptTemplate.from_template("<s>[INST] {context} {prompt} [/INST]")
def generate_mistral_simple_prompt() -> ChatPromptTemplate:
    """
    Build the prompt template used for Mistral without RAG.

    The template exposes a single placeholder, `prompt`, placed between
    "[INST]" and "[/INST]".

    Returns:
        ChatPromptTemplate: prompt template for the Mistral API
    """
    return ChatPromptTemplate.from_template("[INST] {prompt} [/INST]")
def _format_docs(docs) -> str:
    """Join the text of the retrieved documents into one context string."""
    return "\n\n".join(doc.page_content for doc in docs)


def generate_rag_chain(retriever: VectorStoreRetriever = None) -> RunnableSequence:
    """
    Generate a chain that queries the hosted Mixtral model, optionally with RAG.

    Args:
        retriever (VectorStoreRetriever, optional): retriever used to fetch
            context for the question. When None (the default) the chain uses
            the simple prompt and no retrieval step.

    Returns:
        RunnableSequence: chain of input mapping, prompt, model endpoint and
            output parser
    """
    # Hosted Inference API URL of the instruct model.
    mistral_url = (
        "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x22B-Instruct-v0.1"
    )
    model_endpoint = HuggingFaceEndpoint(
        endpoint_url=mistral_url,
        huggingfacehub_api_token=HF_API_KEY,
        task="text2text-generation",
        max_new_tokens=1024,
    )
    # Custom parser: keeps only what follows the last "[/INST]" marker.
    output_parser = MistralOutputParser()

    # No retriever: the input is passed straight into the simple prompt.
    if retriever is None:
        entry = {"prompt": RunnablePassthrough()}
        return entry | generate_mistral_simple_prompt() | model_endpoint | output_parser

    # With a retriever: the retrieved documents are reduced to their text
    # before being interpolated into the prompt. Without this step the
    # `{context}` slot would receive the repr of a list of Document objects
    # (metadata included) instead of the passages themselves.
    retrieval = {
        "context": retriever | _format_docs,
        "prompt": RunnablePassthrough(),
    }
    return retrieval | generate_mistral_rag_prompt() | model_endpoint | output_parser
def load_multiple_pdf(
    document_paths: List[str],
    *,
    mode: str = "single",
    strategy: str = "fast",
    chunk_size: int = 500,
    chunk_overlap: int = 25,
) -> List[str]:
    """
    Load multiple pdf documents and split them into chunks of text.

    All documents are loaded first, then split together with one splitter.
    The loader and splitter settings were previously hard-coded; they are
    now keyword-only parameters whose defaults reproduce the old behaviour.

    Args:
        document_paths (List[str]): list of paths to the pdf documents
        mode (str, optional): forwarded to each loader as `mode`.
            Defaults to "single".
        strategy (str, optional): forwarded to each loader as `strategy`.
            Defaults to "fast".
        chunk_size (int, optional): `chunk_size` given to the text splitter.
            Defaults to 500.
        chunk_overlap (int, optional): `chunk_overlap` given to the text
            splitter. Defaults to 25 (note: `load_pdf` defaults to 0).

    Returns:
        List[str]: result of `RecursiveCharacterTextSplitter.split_documents`
            on all loaded documents.
            NOTE(review): these are presumably LangChain Document objects
            rather than plain strings -- confirm and adjust the annotation.
    """
    docs = []
    for document_path in document_paths:
        loader = UnstructuredPDFLoader(
            document_path,
            mode=mode,
            strategy=strategy,
        )
        docs.extend(loader.load())

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return text_splitter.split_documents(docs)