File size: 3,227 Bytes
6158da4 b83cc65 6158da4 b83cc65 6158da4 b83cc65 6158da4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
from langchain import PromptTemplate
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.chat_models import ChatOpenAI
from langchain_community.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.llms import CTransformers
from langchain.memory import ConversationBufferMemory
from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
import os
from modules.constants import *
from modules.chat_model_loader import ChatModelLoader
from modules.vector_db import VectorDB
class LLMTutor:
    """Build and run a retrieval-augmented QA chain over the project's vector DB.

    Behaviour is driven by the ``config`` dict. The keys read directly here are
    ``config["embedding_options"]["embedd_files"]``,
    ``config["embedding_options"]["search_top_k"]`` and
    ``config["llm_params"]["use_history"]``. ``VectorDB`` and
    ``ChatModelLoader`` receive the same dict and may read further keys.
    """

    def __init__(self, config, logger=None):
        """Store the config and set up the vector database wrapper.

        Args:
            config: Nested configuration dict (see class docstring).
            logger: Optional logger, forwarded to ``VectorDB``.

        If ``config["embedding_options"]["embedd_files"]`` is truthy,
        ``create_database()`` and ``save_database()`` are called on the
        vector DB before the constructor returns.
        """
        self.config = config
        # Set by qa_bot(); initialised here so the attribute always exists.
        self.llm = None
        self.vector_db = VectorDB(config, logger=logger)
        if self.config["embedding_options"]["embedd_files"]:
            self.vector_db.create_database()
            self.vector_db.save_database()

    def set_custom_prompt(self):
        """Return the prompt used by the "stuff" documents chain.

        With ``use_history`` enabled, ``prompt_template_with_history`` is used
        with inputs ``context``, ``chat_history`` and ``question``. Otherwise
        ``prompt_template`` is used with only ``context`` and ``question``.
        RetrievalQA never supplies ``chat_history``, so declaring it there
        would make the chain fail with a missing-input error.
        """
        if self.config["llm_params"]["use_history"]:
            template = prompt_template_with_history
            input_variables = ["context", "chat_history", "question"]
        else:
            template = prompt_template
            input_variables = ["context", "question"]
        return PromptTemplate(template=template, input_variables=input_variables)

    def retrieval_qa_chain(self, llm, prompt, db):
        """Create the retrieval QA chain.

        Args:
            llm: Chat model returned by ``load_llm``.
            prompt: Prompt returned by ``set_custom_prompt``.
            db: Vector store exposing ``as_retriever``.

        Returns:
            A ``ConversationalRetrievalChain`` backed by a fresh
            ``ConversationBufferMemory`` when ``use_history`` is enabled,
            otherwise a ``RetrievalQA`` chain. Both use the "stuff" chain type
            and return source documents.
        """
        # Both chain variants share the same top-k retriever.
        retriever = db.as_retriever(
            search_kwargs={"k": self.config["embedding_options"]["search_top_k"]}
        )
        if self.config["llm_params"]["use_history"]:
            # The chain also returns source documents, so tell the memory
            # which output ("answer") to record in the history.
            memory = ConversationBufferMemory(
                memory_key="chat_history", return_messages=True, output_key="answer"
            )
            return ConversationalRetrievalChain.from_llm(
                llm=llm,
                chain_type="stuff",
                retriever=retriever,
                return_source_documents=True,
                memory=memory,
                combine_docs_chain_kwargs={"prompt": prompt},
            )
        return RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": prompt},
        )

    def load_llm(self):
        """Return the chat model produced by ``ChatModelLoader`` for this config."""
        chat_model_loader = ChatModelLoader(self.config)
        return chat_model_loader.load_chat_model()

    def qa_bot(self):
        """Load the vector DB and the LLM, then assemble the QA chain.

        Side effect: stores the loaded model on ``self.llm``.
        """
        db = self.vector_db.load_database()
        self.llm = self.load_llm()
        qa_prompt = self.set_custom_prompt()
        return self.retrieval_qa_chain(self.llm, qa_prompt, db)

    def final_result(self, query):
        """Answer a single query with a freshly built QA chain.

        Args:
            query: The user's question.

        Returns:
            The dict produced by the chain, including source documents
            (the chain is built with ``return_source_documents=True``).

        Note:
            ``qa_bot()`` runs on every call, so a new chain is built each time.
            In history mode that means a new, empty conversation memory too.
        """
        qa_result = self.qa_bot()
        # ConversationalRetrievalChain reads its input from "question";
        # RetrievalQA reads it from "query".
        if self.config["llm_params"]["use_history"]:
            inputs = {"question": query}
        else:
            inputs = {"query": query}
        return qa_result(inputs)
|