Samarth991 committed on
Commit
a918942
·
verified ·
1 Parent(s): 1648d0f

Update QnA.py

Browse files
Files changed (1) hide show
  1. QnA.py +9 -5
QnA.py CHANGED
@@ -64,7 +64,7 @@ def get_hugging_face_model(model_id='mistralai/Mistral-7B-Instruct-v0.2',tempera
64
  )
65
  return llm
66
 
67
- def Q_A(vectorstore,question,API_KEY):
68
 
69
  if API_KEY.startswith('gsk'):
70
  os.environ["GROQ_API_KEY"] = API_KEY
@@ -74,9 +74,13 @@ def Q_A(vectorstore,question,API_KEY):
74
 
75
  # Create a retriever
76
  retriever = vectorstore.as_retriever(search_type = 'similarity',search_kwargs = {'k':2},)
77
- #Create a contextual compressor
78
- compressor = LLMChainExtractor.from_llm(chat_llm)
79
- compression_retriever = ContextualCompressionRetriever(base_compressor=compressor,base_retriever=retriever)
 
 
 
 
80
 
81
  if 'reliable' in question.lower() or 'relaibility' in question.lower():
82
  question_answer_chain = create_stuff_documents_chain(chat_llm, prompt_template_for_relaibility())
@@ -84,7 +88,7 @@ def Q_A(vectorstore,question,API_KEY):
84
  else:
85
  question_answer_chain = create_stuff_documents_chain(chat_llm, prompt_template_to_analyze_resume())
86
 
87
- chain = create_retrieval_chain(compression_retriever, question_answer_chain)
88
  result = chain.invoke({'input':question})
89
  return result['answer']
90
 
 
64
  )
65
  return llm
66
 
67
+ def Q_A(vectorstore,question,API_KEY,compressor=False):
68
 
69
  if API_KEY.startswith('gsk'):
70
  os.environ["GROQ_API_KEY"] = API_KEY
 
74
 
75
  # Create a retriever
76
  retriever = vectorstore.as_retriever(search_type = 'similarity',search_kwargs = {'k':2},)
77
+
78
+ if compressor:
79
+ #Create a contextual compressor
80
+ compressor = LLMChainExtractor.from_llm(chat_llm)
81
+ compression_retriever = ContextualCompressionRetriever(base_compressor=compressor,base_retriever=retriever)
82
+ retriever = compression_retriever
83
+
84
 
85
  if 'reliable' in question.lower() or 'relaibility' in question.lower():
86
  question_answer_chain = create_stuff_documents_chain(chat_llm, prompt_template_for_relaibility())
 
88
  else:
89
  question_answer_chain = create_stuff_documents_chain(chat_llm, prompt_template_to_analyze_resume())
90
 
91
+ chain = create_retrieval_chain(retriever, question_answer_chain)
92
  result = chain.invoke({'input':question})
93
  return result['answer']
94