# Author: lfoppiano
# Commit message: add query analyzer with min and avg similarity
# Commit: 0188e45
from pathlib import Path
from typing import Any, Optional, List, Dict, Tuple, ClassVar, Collection
from langchain.schema import Document
from langchain_community.vectorstores.chroma import Chroma, DEFAULT_K
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.utils import xor_args
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
class AdvancedVectorStoreRetriever(VectorStoreRetriever):
    """Retriever that can annotate returned documents with scores and embeddings.

    Extends ``VectorStoreRetriever`` with the ``"similarity_with_embeddings"``
    search type, which requires the backing vector store to provide an
    ``advanced_similarity_search`` method yielding
    ``(document, score, embedding)`` triples.
    """

    # Search types accepted by the base class's validation, plus the custom one.
    allowed_search_types: ClassVar[Collection[str]] = (
        "similarity",
        "similarity_score_threshold",
        "mmr",
        "similarity_with_embeddings",
    )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return documents relevant to ``query`` according to ``self.search_type``.

        - ``"similarity_with_embeddings"``: stores each hit's embedding under
          ``metadata['__embeddings']`` and its score under
          ``metadata['__similarity']``.
        - ``"similarity_score_threshold"``: stores the relevance score under
          ``metadata['__similarity']``.
        - anything else: delegates to ``VectorStoreRetriever``.

        Existing ``__embeddings`` / ``__similarity`` metadata entries are never
        overwritten.
        """
        if self.search_type == "similarity_with_embeddings":
            hits = self.vectorstore.advanced_similarity_search(
                query, **self.search_kwargs
            )
            for document, score, vector in hits:
                # NOTE(review): with ChromaAdvancedRetrieval this score is the raw
                # query distance, not the relevance score stored by the
                # threshold branch below — the two are not directly comparable.
                document.metadata.setdefault('__embeddings', vector)
                document.metadata.setdefault('__similarity', score)
            return [document for document, _score, _vector in hits]

        if self.search_type == "similarity_score_threshold":
            scored = self.vectorstore.similarity_search_with_relevance_scores(
                query, **self.search_kwargs
            )
            for document, relevance in scored:
                document.metadata.setdefault('__similarity', relevance)
            return [document for document, _relevance in scored]

        return super()._get_relevant_documents(query, run_manager=run_manager)
class AdvancedVectorStore(VectorStore):
    """Vector store base whose retrievers are ``AdvancedVectorStoreRetriever`` objects."""

    def as_retriever(self, **kwargs: Any) -> AdvancedVectorStoreRetriever:
        """Return an ``AdvancedVectorStoreRetriever`` backed by this store.

        Args:
            **kwargs: Forwarded to ``AdvancedVectorStoreRetriever`` (for example
                ``search_type`` and ``search_kwargs``). An optional ``tags``
                iterable is combined with this store's default retriever tags.

        Returns:
            A retriever whose ``vectorstore`` is ``self``.
        """
        # Copy the caller's tags: extending their list in place would leak the
        # store tags back into it and accumulate duplicates on reuse.
        tags = list(kwargs.pop("tags", None) or [])
        tags.extend(self._get_retriever_tags())
        return AdvancedVectorStoreRetriever(vectorstore=self, **kwargs, tags=tags)
class ChromaAdvancedRetrieval(Chroma, AdvancedVectorStore):
    """Chroma vector store that can also return the stored embedding of each hit.

    Mixes ``AdvancedVectorStore`` into ``Chroma`` so that ``as_retriever()``
    produces an ``AdvancedVectorStoreRetriever``, which supports the
    ``"similarity_with_embeddings"`` search type via
    ``advanced_similarity_search``.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        # Forward positional arguments too, so construction works exactly like Chroma.
        super().__init__(*args, **kwargs)

    @xor_args(("query_texts", "query_embeddings"))
    def __query_collection(
        self,
        query_texts: Optional[List[str]] = None,
        query_embeddings: Optional[List[List[float]]] = None,
        n_results: int = 4,
        where: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Query the chroma collection.

        Exactly one of ``query_texts`` / ``query_embeddings`` must be given
        (enforced by ``xor_args``). Extra keyword arguments such as
        ``include`` are passed straight to ``self._collection.query``.

        Returns:
            The raw result mapping returned by ``self._collection.query``
            (not ``Document`` objects).

        Raises:
            ValueError: If the ``chromadb`` package cannot be imported.
        """
        try:
            import chromadb  # noqa: F401
        except ImportError as err:
            raise ValueError(
                "Could not import chromadb python package. "
                "Please install it with `pip install chromadb`."
            ) from err
        return self._collection.query(
            query_texts=query_texts,
            query_embeddings=query_embeddings,
            n_results=n_results,
            where=where,
            where_document=where_document,
            **kwargs,
        )

    def advanced_similarity_search(
        self,
        query: str,
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float, List[float]]]:
        """Return the ``k`` nearest documents with their distances and embeddings.

        Args:
            query: Text to search for.
            k: Number of results to return.
            filter: Metadata filter passed to Chroma as ``where``.
            **kwargs: Forwarded to ``similarity_search_with_scores_and_embeddings``
                (e.g. ``where_document``); previously these were dropped.

        Returns:
            ``(document, distance, embedding)`` triples.
        """
        return self.similarity_search_with_scores_and_embeddings(
            query, k, filter=filter, **kwargs
        )

    def similarity_search_with_scores_and_embeddings(
        self,
        query: str,
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float, List[float]]]:
        """Query Chroma and return documents with distances and stored embeddings.

        When no embedding function is configured, the raw query text is sent and
        Chroma embeds it; otherwise the query is embedded locally first.

        Args:
            query: Text to search for.
            k: Number of results to return.
            filter: Metadata filter passed to Chroma as ``where``.
            where_document: Document-content filter passed to Chroma.
            **kwargs: Accepted for signature compatibility; currently unused.

        Returns:
            ``(document, distance, embedding)`` triples for the query.
        """
        if self._embedding_function is None:
            query_input: Dict[str, Any] = {"query_texts": [query]}
        else:
            query_input = {
                "query_embeddings": [self._embedding_function.embed_query(query)]
            }
        results = self.__query_collection(
            n_results=k,
            where=filter,
            where_document=where_document,
            # Embeddings must be requested explicitly; they are not returned by default.
            include=['metadatas', 'documents', 'embeddings', 'distances'],
            **query_input,
        )
        return _results_to_docs_scores_and_embeddings(results)
def _results_to_docs_scores_and_embeddings(results: Any) -> List[Tuple[Document, float, List[float]]]:
    """Turn a Chroma query result into ``(Document, distance, embedding)`` triples.

    Only the first query's results (index ``0`` of each column) are used.
    Missing metadata (``None``) is replaced with an empty dict.
    """
    texts = results["documents"][0]
    metadatas = results["metadatas"][0]
    distances = results["distances"][0]
    vectors = results["embeddings"][0]
    return [
        (Document(page_content=text, metadata=metadata or {}), distance, vector)
        for text, metadata, distance, vector in zip(texts, metadatas, distances, vectors)
    ]