from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.memory import ConversationBufferWindowMemory

from modules.constants import *
from modules.helpers import get_prompt
from modules.chat_model_loader import ChatModelLoader
from modules.vector_db import VectorDB, VectorDBScore


class LLMTutor:
    def __init__(self, config, logger=None):
        self.config = config
        self.vector_db = VectorDB(config, logger=logger)
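        # Re-embed the source files only when the config asks for it;
        # otherwise qa_bot() loads the previously saved database.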
        if self.config["embedding_options"]["embedd_files"]:
            self.vector_db.create_database()
            self.vector_db.save_database()

    def set_custom_prompt(self):
        """Return the prompt template for QA retrieval, built from the config."""
        prompt = get_prompt(self.config)
        return prompt

    def retrieval_qa_chain(self, llm, prompt, db):
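        """Assemble a QA chain from `llm` and a retriever over `db`.

        Uses ConversationalRetrievalChain when chat history is enabled in the
        config, and a plain RetrievalQA chain otherwise.
        """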
        if self.config["embedding_options"]["db_option"] in ["FAISS", "Chroma"]:
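            # Score-threshold retrieval: return at most `search_top_k`
            # documents whose similarity score clears `score_threshold`.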
            retriever = VectorDBScore(
                vectorstore=db,
                search_type="similarity_score_threshold",
                search_kwargs={
                    "score_threshold": self.config["embedding_options"]["score_threshold"],
                    "k": self.config["embedding_options"]["search_top_k"],
                },
            )
        elif self.config["embedding_options"]["db_option"] == "RAGatouille":
            retriever = db.as_langchain_retriever(
                k=self.config["embedding_options"]["search_top_k"]
            )
        else:
            # Fail fast instead of leaving `retriever` unbound below.
            raise ValueError(
                f"Unsupported db_option: {self.config['embedding_options']['db_option']}"
            )

        if self.config["llm_params"]["use_history"]:
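            # Sliding-window memory: keep only the last `memory_window`
            # exchanges so the prompt stays within the model's context budget.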
            memory = ConversationBufferWindowMemory(
                k=self.config["llm_params"]["memory_window"],
                memory_key="chat_history",
                return_messages=True,
                output_key="answer",
            )
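            # Note: ConversationalRetrievalChain expects inputs under the
            # "question" key, unlike RetrievalQA below, which uses "query".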
            qa_chain = ConversationalRetrievalChain.from_llm(
                llm=llm,
                chain_type="stuff",
                retriever=retriever,
                return_source_documents=True,
                memory=memory,
                combine_docs_chain_kwargs={"prompt": prompt},
            )
        else:
            qa_chain = RetrievalQA.from_chain_type(
                llm=llm,
                chain_type="stuff",
                retriever=retriever,
                return_source_documents=True,
                chain_type_kwargs={"prompt": prompt},
            )
        return qa_chain

    def load_llm(self):
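        """Instantiate the chat model specified in the config."""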
        chat_model_loader = ChatModelLoader(self.config)
        llm = chat_model_loader.load_chat_model()
        return llm

    def qa_bot(self):
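        """Build and return the full QA chain: vector DB + LLM + prompt."""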
        db = self.vector_db.load_database()
        self.llm = self.load_llm()
        qa_prompt = self.set_custom_prompt()
        qa = self.retrieval_qa_chain(self.llm, qa_prompt, db)
        return qa


def final_result(llm_tutor, query):
    """Convenience wrapper: build the chain from an `LLMTutor` instance and run one query.

    With use_history enabled, the chain expects {"question": query} instead.
    """
    qa_chain = llm_tutor.qa_bot()
    response = qa_chain({"query": query})
    return response
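

if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original module):
    # the config filename and the query below are illustrative only, and this
    # path assumes use_history is disabled so the chain takes a "query" key.
    import yaml

    with open("config.yml", "r") as f:
        config = yaml.safe_load(f)

    tutor = LLMTutor(config)
    qa_chain = tutor.qa_bot()
    result = qa_chain({"query": "What topics does this course cover?"})
    print(result["result"])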