|
from langchain.schema.vectorstore import VectorStoreRetriever |
|
from langchain.callbacks.manager import CallbackManagerForRetrieverRun |
|
from langchain.schema.document import Document |
|
from langchain_core.callbacks import AsyncCallbackManagerForRetrieverRun |
|
from typing import List |
|
|
|
|
|
class VectorStoreRetrieverScore(VectorStoreRetriever):
    """Vector-store retriever that exposes each hit's relevance score.

    Queries the underlying vector store with
    ``similarity_search_with_relevance_scores`` (or its async counterpart)
    using ``self.search_kwargs``. It stores the score returned for each
    document under ``doc.metadata["score"]`` and returns the documents in the
    order the store produced them.
    """

    @staticmethod
    def _attach_scores(docs_and_similarities: List[tuple]) -> List[Document]:
        """Copy each relevance score into its document's metadata.

        Args:
            docs_and_similarities: ``(document, score)`` pairs as returned by
                the vector store's relevance-score search.

        Returns:
            The documents only, in their original order, each with a
            ``"score"`` metadata entry.
        """
        docs = []
        for doc, similarity in docs_and_similarities:
            # The Document objects returned by the store are mutated in place.
            doc.metadata["score"] = similarity
            docs.append(doc)
        return docs

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Synchronously retrieve documents for ``query`` with scores attached.

        Args:
            query: Text to search the vector store for.
            run_manager: Callback manager for this retriever run (not used here).

        Returns:
            Matching documents, each carrying ``metadata["score"]``.
        """
        docs_and_similarities = (
            self.vectorstore.similarity_search_with_relevance_scores(
                query, **self.search_kwargs
            )
        )
        return self._attach_scores(docs_and_similarities)

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Asynchronously retrieve documents for ``query`` with scores attached.

        Uses the vector store's async relevance-score search so the event loop
        is not blocked by a synchronous store call (the previous version
        called the sync method here).

        Args:
            query: Text to search the vector store for.
            run_manager: Async callback manager for this run (not used here).

        Returns:
            Matching documents, each carrying ``metadata["score"]``.
        """
        docs_and_similarities = (
            await self.vectorstore.asimilarity_search_with_relevance_scores(
                query, **self.search_kwargs
            )
        )
        return self._attach_scores(docs_and_similarities)
|
|