Spaces:
Runtime error
Runtime error
LOUIS SANNA
committed on
Commit
·
cde6d5c
1
Parent(s):
dfcff8d
feat(code): step down rule
Browse files — climateqa/chains.py (+57 −57)
climateqa/chains.py
CHANGED
@@ -7,52 +7,10 @@ from langchain.chains import QAWithSourcesChain
|
|
7 |
from langchain.chains import TransformChain, SequentialChain
|
8 |
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
|
9 |
|
10 |
-
from climateqa.prompts import answer_prompt, reformulation_prompt
|
11 |
from climateqa.custom_retrieval_chain import CustomRetrievalQAWithSourcesChain
|
12 |
|
13 |
|
14 |
-
def load_reformulation_chain(llm):
    """Build a chain that reformulates a raw user query.

    An LLM step runs the reformulation prompt and emits a JSON string under
    the output key ``"json"``; a transform step then parses it into a
    standalone ``question`` plus the detected ``language``.

    Args:
        llm: Language model used to run the reformulation prompt.

    Returns:
        A ``SequentialChain`` mapping ``{"query"}`` to
        ``{"question", "language"}``.
    """
    prompt = PromptTemplate(
        template=reformulation_prompt,
        input_variables=["query"],
    )
    llm_chain = LLMChain(llm=llm, prompt=prompt, output_key="json")

    def parse_output(output):
        # NOTE(review): relies on "query" being forwarded into the transform
        # inputs alongside "json" — confirm with the installed langchain version.
        query = output["query"]
        # Guard against malformed LLM output: fall back to the raw query and
        # English instead of crashing the whole pipeline on a JSON error.
        try:
            json_output = json.loads(output["json"])
        except (json.JSONDecodeError, TypeError):
            json_output = {}
        return {
            "question": json_output.get("question", query),
            "language": json_output.get("language", "English"),
        }

    transform_chain = TransformChain(
        input_variables=["json"],
        output_variables=["question", "language"],
        transform=parse_output,
    )

    return SequentialChain(
        chains=[llm_chain, transform_chain],
        input_variables=["query"],
        output_variables=["question", "language"],
    )
|
45 |
-
|
46 |
-
|
47 |
-
def load_combine_documents_chain(llm):
    """Return a "stuff" QA-with-sources chain driven by the answer prompt."""
    answer_template = PromptTemplate(
        template=answer_prompt,
        input_variables=["summaries", "question", "audience", "language"],
    )
    return load_qa_with_sources_chain(llm, chain_type="stuff", prompt=answer_template)
|
54 |
-
|
55 |
-
|
56 |
def load_qa_chain_with_docs(llm):
|
57 |
"""Load a QA chain with documents.
|
58 |
Useful when you already have retrieved docs
|
@@ -78,6 +36,15 @@ def load_qa_chain_with_docs(llm):
|
|
78 |
return chain
|
79 |
|
80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
def load_qa_chain_with_text(llm):
|
82 |
prompt = PromptTemplate(
|
83 |
template=answer_prompt,
|
@@ -87,6 +54,53 @@ def load_qa_chain_with_text(llm):
|
|
87 |
return qa_chain
|
88 |
|
89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
def load_qa_chain_with_retriever(retriever, llm):
|
91 |
qa_chain = load_combine_documents_chain(llm)
|
92 |
|
@@ -101,17 +115,3 @@ def load_qa_chain_with_retriever(retriever, llm):
|
|
101 |
fallback_answer="**⚠️ No relevant passages found in the climate science reports (IPCC and IPBES), you may want to ask a more specific question (specifying your question on climate issues).**",
|
102 |
)
|
103 |
return answer_chain
|
104 |
-
|
105 |
-
|
106 |
-
def load_climateqa_chain(retriever, llm_reformulation, llm_answer):
    """Assemble the full ClimateQA pipeline: query reformulation, then retrieval QA."""
    return SequentialChain(
        chains=[
            load_reformulation_chain(llm_reformulation),
            load_qa_chain_with_retriever(retriever, llm_answer),
        ],
        input_variables=["query", "audience"],
        output_variables=["answer", "question", "language", "source_documents"],
        return_all=True,
        verbose=True,
    )
|
|
|
7 |
from langchain.chains import TransformChain, SequentialChain
|
8 |
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
|
9 |
|
10 |
+
from climateqa.prompts import answer_prompt, reformulation_prompt
|
11 |
from climateqa.custom_retrieval_chain import CustomRetrievalQAWithSourcesChain
|
12 |
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
def load_qa_chain_with_docs(llm):
|
15 |
"""Load a QA chain with documents.
|
16 |
Useful when you already have retrieved docs
|
|
|
36 |
return chain
|
37 |
|
38 |
|
39 |
+
def load_combine_documents_chain(llm):
    """Build the combine-documents step: stuffs retrieved summaries into the answer prompt."""
    variables = ["summaries", "question", "audience", "language"]
    prompt = PromptTemplate(template=answer_prompt, input_variables=variables)
    qa_chain = load_qa_with_sources_chain(llm, chain_type="stuff", prompt=prompt)
    return qa_chain
|
46 |
+
|
47 |
+
|
48 |
def load_qa_chain_with_text(llm):
|
49 |
prompt = PromptTemplate(
|
50 |
template=answer_prompt,
|
|
|
54 |
return qa_chain
|
55 |
|
56 |
|
57 |
+
def load_climateqa_chain(retriever, llm_reformulation, llm_answer):
    """Wire the reformulation chain into the retrieval-QA chain.

    Takes the raw user ``query`` and target ``audience``; produces the
    ``answer`` together with the reformulated ``question``, its ``language``,
    and the retrieved ``source_documents``.
    """
    step_reformulate = load_reformulation_chain(llm_reformulation)
    step_answer = load_qa_chain_with_retriever(retriever, llm_answer)
    pipeline = SequentialChain(
        chains=[step_reformulate, step_answer],
        input_variables=["query", "audience"],
        output_variables=["answer", "question", "language", "source_documents"],
        return_all=True,
        verbose=True,
    )
    return pipeline
|
69 |
+
|
70 |
+
|
71 |
+
def load_reformulation_chain(llm):
    """Build a chain that reformulates a raw user query.

    An LLM step runs the reformulation prompt and emits a JSON string under
    the output key ``"json"``; a transform step then parses it into a
    standalone ``question`` plus the detected ``language``.

    Args:
        llm: Language model used to run the reformulation prompt.

    Returns:
        A ``SequentialChain`` mapping ``{"query"}`` to
        ``{"question", "language"}``.
    """
    prompt = PromptTemplate(
        template=reformulation_prompt,
        input_variables=["query"],
    )
    llm_chain = LLMChain(llm=llm, prompt=prompt, output_key="json")

    def parse_output(output):
        # NOTE(review): relies on "query" being forwarded into the transform
        # inputs alongside "json" — confirm with the installed langchain version.
        query = output["query"]
        # Guard against malformed LLM output: fall back to the raw query and
        # English instead of crashing the whole pipeline on a JSON error.
        try:
            json_output = json.loads(output["json"])
        except (json.JSONDecodeError, TypeError):
            json_output = {}
        return {
            "question": json_output.get("question", query),
            "language": json_output.get("language", "English"),
        }

    transform_chain = TransformChain(
        input_variables=["json"],
        output_variables=["question", "language"],
        transform=parse_output,
    )

    return SequentialChain(
        chains=[llm_chain, transform_chain],
        input_variables=["query"],
        output_variables=["question", "language"],
    )
|
102 |
+
|
103 |
+
|
104 |
def load_qa_chain_with_retriever(retriever, llm):
|
105 |
qa_chain = load_combine_documents_chain(llm)
|
106 |
|
|
|
115 |
fallback_answer="**⚠️ No relevant passages found in the climate science reports (IPCC and IPBES), you may want to ask a more specific question (specifying your question on climate issues).**",
|
116 |
)
|
117 |
return answer_chain
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|