LOUIS SANNA committed · Commit 780c913 · Parent(s): 3a575de

feat(domains)

Files changed:
- anyqa/config.py +10 -0
- anyqa/retriever.py +7 -8
- app.py +11 -12
anyqa/config.py ADDED
@@ -0,0 +1,10 @@
+
+
+import os
+
+def get_domains():
+    domains = []
+    for root, dirs, files in os.walk("data"):
+        for dir in dirs:
+            domains.append(dir)
+    return domains
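For context, get_domains() collects the name of every directory that os.walk finds under data/, at any depth, so each checkbox option in the UI corresponds to a folder on disk. A minimal sketch of what a call would return, assuming a hypothetical data/ layout (the folder names below are illustrative, not part of the commit):

import os

# Hypothetical layout, created here only so the sketch is runnable:
#   data/
#   ├── climate/
#   │   └── reports/
#   └── finance/
os.makedirs("data/climate/reports", exist_ok=True)
os.makedirs("data/finance", exist_ok=True)

from anyqa.config import get_domains

print(get_domains())
# e.g. ['climate', 'finance', 'reports'] -- os.walk recurses, so nested
# folders are returned as "domains" too; restricting the list to the top
# level would need something like next(os.walk("data"))[1] instead.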
anyqa/retriever.py CHANGED
@@ -13,25 +13,24 @@ SUMMARY_TYPES = []
 
 class QARetriever(BaseRetriever):
     vectorstore: VectorStore
-    sources: list = []
+    domains: list = []
     threshold: float = 22
     k_summary: int = 0
     k_total: int = 10
     namespace: str = "vectors"
 
     def get_relevant_documents(self, query: str) -> List[Document]:
-
-        assert isinstance(self.sources, list)
+        assert isinstance(self.domains, list)
         assert self.k_total > self.k_summary, "k_total should be greater than k_summary"
 
-        query = "He who can bear the misfortune of a nation is called the ruler of the world."
         # Prepare base search kwargs
         filters = {}
-        if len(self.sources):
-            filters["
+        if len(self.domains):
+            filters["domain"] = {"$in": self.domains}
 
         if self.k_summary > 0:
             # Search for k_summary documents in the summaries dataset
+            filters_summaries = {**filters}
             if len(SUMMARY_TYPES):
                 filters_summaries = {
                     **filters_summaries,
@@ -48,7 +47,8 @@ class QARetriever(BaseRetriever):
             docs_summaries = []
 
         # Search for k_total - k_summary documents in the full reports dataset
-        filters_full = {}
+        filters_full = {**filters}
+        print("filters", filters)
         if len(SUMMARY_TYPES):
             filters_full = {**filters_full, "report_type": {"$nin": SUMMARY_TYPES}}
 
@@ -59,7 +59,6 @@ class QARetriever(BaseRetriever):
             filter=self.format_filter(filters_full),
             k=k_full,
         )
-        print("docs_full", docs_full)
 
         # Concatenate documents
        docs = docs_summaries + docs_full
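The net effect of the retriever change is that the domains list now feeds into every metadata filter. A standalone sketch of the filter dictionaries this produces, with placeholder values; SUMMARY_TYPES is empty in the committed code, and the summaries branch is cut off in the hunk above, so the "$in" restriction shown for it is an assumption:

SUMMARY_TYPES = ["summary"]        # placeholder; the committed constant is an empty list
domains = ["climate", "finance"]   # hypothetical selection coming from the UI checkboxes

# Base filter shared by both searches (mirrors the diff above)
filters = {}
if len(domains):
    filters["domain"] = {"$in": domains}

# Summaries search: the hunk above is truncated, so the "$in" clause is assumed
filters_summaries = {**filters}
if len(SUMMARY_TYPES):
    filters_summaries = {**filters_summaries, "report_type": {"$in": SUMMARY_TYPES}}

# Full-report search: same domain filter, with summary types excluded
filters_full = {**filters}
if len(SUMMARY_TYPES):
    filters_full = {**filters_full, "report_type": {"$nin": SUMMARY_TYPES}}

print(filters_full)
# {'domain': {'$in': ['climate', 'finance']}, 'report_type': {'$nin': ['summary']}}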
app.py CHANGED
@@ -7,6 +7,7 @@ from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
 # ClimateQ&A imports
+from anyqa.config import get_domains
 from anyqa.embeddings import EMBEDDING_MODEL_NAME
 from anyqa.llm import get_llm
 from anyqa.qa_logging import log
@@ -136,16 +137,14 @@ def answer_user_example(query, query_example, history):
     return query_example, history + [[query_example, ". . ."]]
 
 
-def fetch_sources(query, sources):
-    # Prepare default values
-    if len(sources) == 0:
-        sources = ["IPCC"]
+def fetch_sources(query, domains):
 
     llm_reformulation = get_llm(
         max_tokens=512, temperature=0.0, verbose=True, streaming=False
     )
+    print("domains", domains)
     retriever = QARetriever(
-        vectorstore=vectorstore,
+        vectorstore=vectorstore, domains=domains, k_summary=0, k_total=10
     )
     reformulation_chain = load_reformulation_chain(llm_reformulation)
 
@@ -379,11 +378,11 @@ with gr.Blocks(title="❓ Q&A", css="style.css", theme=theme) as demo:
             gr.Markdown(
                 "Reminder: You can talk in any language, this tool is multi-lingual!"
             )
-
-
-
-                label="Select
-                value=[
+            domains = get_domains()
+            dropdown_domains = gr.CheckboxGroup(
+                domains,
+                label="Select source types",
+                value=[],
                 interactive=True,
             )
 
@@ -419,7 +418,7 @@ with gr.Blocks(title="❓ Q&A", css="style.css", theme=theme) as demo:
         .success(change_tab, None, tabs)
         .success(
             fetch_sources,
-            [textbox,
+            [textbox, dropdown_domains],
             [
                 textbox,
                 sources_textbox,
@@ -454,7 +453,7 @@ with gr.Blocks(title="❓ Q&A", css="style.css", theme=theme) as demo:
         .success(change_tab, None, tabs)
         .success(
             fetch_sources,
-            [textbox,
+            [textbox, dropdown_domains],
             [
                 textbox,
                 sources_textbox,
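Taken together, the three files wire the folder names under data/ into the UI and down to the retriever: get_domains() fills a CheckboxGroup, the selected values are passed to fetch_sources alongside the query, and fetch_sources hands them to QARetriever as domains. A stripped-down sketch of that wiring, with a stand-in for fetch_sources since the real one also builds the reformulation and QA chains:

import gradio as gr
from anyqa.config import get_domains

def fetch_sources_stub(query, domains):
    # Stand-in for fetch_sources: the real function builds
    # QARetriever(vectorstore=vectorstore, domains=domains, k_summary=0, k_total=10)
    # and runs the reformulation chain; here we just echo the inputs.
    return f"query={query!r}, domains={domains}"

with gr.Blocks() as demo:
    textbox = gr.Textbox(label="Ask a question")
    dropdown_domains = gr.CheckboxGroup(get_domains(), label="Select source types", value=[])
    debug_out = gr.Textbox(label="Retrieval settings")
    textbox.submit(fetch_sources_stub, [textbox, dropdown_domains], [debug_out])

# demo.launch()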