Spaces:
Runtime error
Runtime error
File size: 3,193 Bytes
f0fc5f8 cc2ce8c f0fc5f8 6e28a81 cc2ce8c 6e28a81 780c913 6e28a81 cc2ce8c 6e28a81 f0fc5f8 780c913 f0fc5f8 cc2ce8c 780c913 cc2ce8c 780c913 cc2ce8c c6c35dc cc2ce8c f0fc5f8 780c913 cc2ce8c f0fc5f8 6e28a81 cc2ce8c c6c35dc cc2ce8c 6e28a81 f0fc5f8 6e28a81 f0fc5f8 6e28a81 f0fc5f8 cc2ce8c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
# https://github.com/langchain-ai/langchain/issues/8623
from langchain.schema.retriever import BaseRetriever, Document
from langchain.vectorstores import VectorStore
from langchain.vectorstores import Chroma
from typing import List
# Values of the "report_type" metadata field that mark a document as a summary.
# The idea is that some documents are summaries, which makes them easier to
# exploit, so QARetriever can search them separately from the full reports.
# While this list is empty, no "report_type" filter is added to either search.
SUMMARY_TYPES = []
class QARetriever(BaseRetriever):
    """Retriever that mixes summary documents and full-report documents.

    A query is answered with up to ``k_summary`` hits restricted to the
    ``SUMMARY_TYPES`` report types, topped up to ``k_total`` hits taken from
    every other report type. Hits whose score is not strictly greater than
    ``threshold`` are discarded.
    """

    # Vector store that is queried through ``similarity_search_with_score``.
    vectorstore: VectorStore
    # When non-empty, restricts hits to documents whose "domain" metadata
    # value is in this list.
    domains: list = []
    # Hits are kept only when their score is strictly greater than this value.
    # NOTE(review): the meaning of the score (similarity vs. distance) depends
    # on the vector store backend - confirm that "greater" is the right side.
    threshold: float = 22
    # Number of hits requested from the summary documents (0 disables it).
    k_summary: int = 0
    # Target number of hits overall; must be strictly greater than k_summary.
    k_total: int = 10
    # Forwarded as the ``namespace`` keyword of every search call.
    namespace: str = "vectors"

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Return the documents matching *query*, enriched with metadata.

        Each returned document is modified in place: its metadata gains
        ``similarity_score`` and ``content`` (the untouched page content),
        ``page_number`` is converted to ``int``, and ``page_content`` is
        prefixed with ``"Doc <rank> - <short_name>: "``.

        Args:
            query: Text to search for.

        Returns:
            Summary hits first, then full-report hits, all scoring strictly
            above ``threshold``.

        Raises:
            TypeError: If ``domains`` is not a list.
            ValueError: If ``k_total`` is not greater than ``k_summary``.
            KeyError: If a hit lacks the ``page_number`` or ``short_name``
                metadata entries.
        """
        # Explicit raises instead of ``assert`` so that the checks survive
        # running Python with -O.
        if not isinstance(self.domains, list):
            raise TypeError("domains should be a list")
        if self.k_total <= self.k_summary:
            raise ValueError("k_total should be greater than k_summary")

        # Base filter shared by both searches.
        filters = {}
        if self.domains:
            filters["domain"] = {"$in": self.domains}

        # Search for k_summary documents in the summaries dataset.
        docs_summaries = []
        if self.k_summary > 0:
            filters_summaries = dict(filters)
            if SUMMARY_TYPES:
                filters_summaries["report_type"] = {"$in": SUMMARY_TYPES}
            docs_summaries = self._search(query, filters_summaries, self.k_summary)

        # Fill the remaining slots from the full reports dataset. Summary hits
        # dropped by the threshold free their slot for a full-report hit.
        filters_full = dict(filters)
        if SUMMARY_TYPES:
            filters_full["report_type"] = {"$nin": SUMMARY_TYPES}
        k_full = self.k_total - len(docs_summaries)
        docs_full = self._search(query, filters_full, k_full)

        # Expose the score and the raw text through the metadata, then label
        # the page content with the rank and the short name of the document.
        results = []
        for rank, (doc, score) in enumerate(docs_summaries + docs_full, start=1):
            doc.metadata["similarity_score"] = score
            doc.metadata["content"] = doc.page_content
            doc.metadata["page_number"] = int(doc.metadata["page_number"])
            doc.page_content = (
                f"""Doc {rank} - {doc.metadata['short_name']}: {doc.page_content}"""
            )
            results.append(doc)
        return results

    def _search(self, query, filters, k):
        """Run one scored search and keep the hits scoring above the threshold."""
        hits = self.vectorstore.similarity_search_with_score(
            query=query,
            namespace=self.namespace,
            filter=self.format_filter(filters),
            k=k,
        )
        return [hit for hit in hits if hit[1] > self.threshold]

    def format_filter(self, filters):
        """Adapt a ``{field: condition}`` filter to the vector store in use.

        For a Chroma store, a filter with more than one field is rewritten as
        ``{"$and": [{field: condition}, ...]}``. Any other filter, and any
        filter for another store, is returned unchanged.

        Args:
            filters: Mapping of metadata field name to filter condition.

        Returns:
            The filter to pass to ``similarity_search_with_score``.
        """
        # https://docs.trychroma.com/usage-guide#using-logical-operators
        if isinstance(self.vectorstore, Chroma) and len(filters) > 1:
            return {
                "$and": [{field: condition} for field, condition in filters.items()]
            }
        return filters
|