from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
from langchain.chains.summarize import load_summarize_chain
from langchain_community.llms.huggingface_hub import HuggingFaceHub
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.retrievers import ContextualCompressionRetriever
from langchain.chains.question_answering import load_qa_chain
#from Api_Key import google_plam
from langchain_groq import ChatGroq
import os
from dotenv import load_dotenv
load_dotenv()  # pull environment variables (API keys) from a local .env file, if present


def prompt_template_to_analyze_resume():
    template = """
    You are provided with the resume of a candidate in the context below.
    As a Talent Acquisition bot, your task is to provide insights about the candidate in a precise manner.
    \n\n{context}
    """
    prompt = ChatPromptTemplate.from_messages(
        [
            ('system', template),
            ('human', '{input}'),
        ]
    )
    return prompt
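
# Sanity-check sketch (not part of the app flow): render the prompt with dummy
# values to confirm both placeholders resolve. `format_messages` is standard
# ChatPromptTemplate API; the sample strings here are made up.
#
#   msgs = prompt_template_to_analyze_resume().format_messages(
#       context='5 years at Acme Corp as a data engineer.',
#       input='Summarize the candidate.',
#   )
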
def prompt_template_for_reliability():
    template = """
    You are provided with the resume of a candidate in the context below.
    If asked about reliability, check how frequently the candidate has switched from one company to another.
    Grade the candidate on the following basis (average tenure per company):
    less than 2 years - not very reliable
    more than 2 years but less than 5 years - reliable
    more than 5 years - highly reliable
    and generate a verdict.
    \n\n{context}
    """
    prompt = ChatPromptTemplate.from_messages(
        [
            ('system', template),
            ('human', '{input}'),
        ]
    )
    return prompt

def summarize(documents, llm):
    """Summarize a list of Documents with the 'refine' strategy (iterative summary refinement)."""
    summarize_chain = load_summarize_chain(llm=llm, chain_type='refine', verbose=True)
    results = summarize_chain.invoke({'input_documents': documents})
    return results['output_text']
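
# Usage sketch (assumption: `pages` comes from a document loader/splitter
# elsewhere in the app; it is not defined in this module):
#
#   from langchain_core.documents import Document
#   pages = [Document(page_content='...resume text...')]
#   llm = get_groq_model(api_key=os.environ['GROQ_API_KEY'])
#   print(summarize(pages, llm))
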
def get_hugging_face_model(model_id='mistralai/Mistral-7B-Instruct-v0.2', temperature=0.01, max_tokens=4096, api_key=None):
    # Note: HuggingFaceHub is deprecated in recent langchain-community releases;
    # HuggingFaceEndpoint is the suggested replacement if you upgrade.
    llm = HuggingFaceHub(
        huggingfacehub_api_token=api_key,
        repo_id=model_id,
        model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens},
    )
    return llm

def get_groq_model(api_key):
    os.environ["GROQ_API_KEY"] = api_key  # ChatGroq reads the key from the environment
    llm = ChatGroq(model="llama3-8b-8192")  # alternative: model="gemma2-9b-it"
    return llm

def Q_A(vectorstore, question, API_KEY, compressor=False):
    # Pick the chat model from the API key prefix: Groq keys start with 'gsk',
    # Hugging Face tokens with 'hf'.
    if API_KEY.startswith('gsk'):
        chat_llm = get_groq_model(api_key=API_KEY)
    elif API_KEY.startswith('hf'):
        chat_llm = get_hugging_face_model(api_key=API_KEY)
    else:
        raise ValueError("Unrecognized API key: expected a Groq ('gsk...') or Hugging Face ('hf...') key.")

    # Create a retriever over the vector store
    retriever = vectorstore.as_retriever(search_type='similarity', search_kwargs={'k': 2})

    if compressor:
        # Wrap the retriever in a contextual compressor so only the passages
        # relevant to the question are passed to the LLM
        extractor = LLMChainExtractor.from_llm(chat_llm)
        retriever = ContextualCompressionRetriever(base_compressor=extractor, base_retriever=retriever)

    # 'reliab' matches both 'reliable' and 'reliability' in the question
    if 'reliab' in question.lower():
        prompt = prompt_template_for_reliability()
    else:
        prompt = prompt_template_to_analyze_resume()

    # question_answer_chain = load_qa_chain(chat_llm, chain_type="stuff", prompt=prompt)
    question_answer_chain = create_stuff_documents_chain(chat_llm, prompt)
    chain = create_retrieval_chain(retriever, question_answer_chain)
    result = chain.invoke({'input': question})
    return result['answer']
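
# Minimal end-to-end sketch (assumptions: FAISS + HuggingFaceEmbeddings for the
# vector store and a GROQ_API_KEY in the environment; none of this is wired up
# elsewhere in this module):
if __name__ == '__main__':
    from langchain_community.vectorstores import FAISS
    from langchain_community.embeddings import HuggingFaceEmbeddings
    from langchain_core.documents import Document

    docs = [Document(page_content='Jane Doe. Data engineer at Acme Corp, 2019-2024. Previously at Initech, 2016-2019.')]
    embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
    vectorstore = FAISS.from_documents(docs, embeddings)

    print(Q_A(vectorstore, 'How reliable is the candidate?', os.environ['GROQ_API_KEY']))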