import string
from typing import List, Optional, Tuple

from langchain.chains import LLMChain
from langchain.chains.base import Chain
from loguru import logger

from app.chroma import ChromaDenseVectorDB
from app.config.models.configs import (
    Config,
    ResponseModel,
    SemanticSearchConfig,
)
from app.ranking import BCEReranker, rerank
from app.splade import SpladeSparseVectorDB

class LLMBundle:
    """Bundles the QA chain with dense (Chroma) and sparse (SPLADE) retrieval,
    a BCE re-ranker and an optional HyDE chain for query expansion."""

    def __init__(
        self,
        chain: Chain,
        dense_db: ChromaDenseVectorDB,
        reranker: BCEReranker,
        sparse_db: SpladeSparseVectorDB,
        chunk_sizes: List[int],
        hyde_chain: Optional[LLMChain] = None,
    ) -> None:
        self.chain = chain
        self.dense_db = dense_db
        self.reranker = reranker
        self.sparse_db = sparse_db
        self.chunk_sizes = chunk_sizes
        self.hyde_chain = hyde_chain

    def get_relevant_documents(
        self,
        original_query: str,
        query: str,
        config: SemanticSearchConfig,
        label: str,
    ) -> Tuple[List[str], float]:
        most_relevant_docs = []
        docs = []
        current_reranker_score, reranker_score = -1e5, -1e5

        # Apply the retrieval prefix once, before querying the indices.
        if config.query_prefix:
            logger.info(f"Adding query prefix for retrieval: {config.query_prefix}")
            query = config.query_prefix + query

        for chunk_size in self.chunk_sizes:
            all_relevant_docs = []
            all_relevant_doc_ids = set()
            logger.debug("Evaluating query: {}", query)

            # Stage 1: sparse (SPLADE) retrieval.
            sparse_search_docs_ids, sparse_scores = self.sparse_db.query(
                search=query, n=config.max_k, label=label, chunk_size=chunk_size
            )
            logger.info(f"Stage 1: Got {len(sparse_search_docs_ids)} documents.")

            # Stage 2: dense (Chroma) retrieval, filtered by chunk size and label.
            filter = (
                {"chunk_size": chunk_size} if len(self.chunk_sizes) > 1 else dict()
            )
            if label:
                filter.update({"label": label})
            if not filter:
                filter = None
            logger.info(f"Dense embeddings filter: {filter}")
            res = self.dense_db.similarity_search_with_relevance_scores(
                query, filter=filter
            )
            dense_search_doc_ids = [r[0].metadata["document_id"] for r in res]

            # Combine sparse and dense hits and fetch the corresponding documents.
            all_doc_ids = (
                set(sparse_search_docs_ids).union(set(dense_search_doc_ids))
            ).difference(all_relevant_doc_ids)
            if all_doc_ids:
                relevant_docs = self.dense_db.get_documents_by_id(
                    document_ids=list(all_doc_ids)
                )
                all_relevant_docs += relevant_docs

            # Re-rank the merged candidates against the original query and keep
            # the result set from the best-scoring chunk size.
            reranker_score, relevant_docs = rerank(
                rerank_model=self.reranker,
                query=original_query,
                docs=all_relevant_docs,
            )
            if reranker_score > current_reranker_score:
                docs = relevant_docs
                current_reranker_score = reranker_score

        # Pack documents into the result up to the configured character budget.
        len_ = 0
        for doc in docs:
            doc_length = len(doc.page_content)
            if len_ + doc_length < config.max_char_size:
                most_relevant_docs.append(doc)
                len_ += doc_length

        return most_relevant_docs, current_reranker_score

    def get_and_parse_response(
        self,
        query: str,
        config: Config,
        label: str = "",
    ) -> ResponseModel:
        original_query = query

        # Add HyDE queries: append the hypothetical answer to the query used for
        # retrieval, keeping the original query for re-ranking and the QA chain.
        if self.hyde_chain is not None:
            hyde_response = self.hyde_chain.run(query)
            query += hyde_response
        logger.info(f"query: {query}")

        semantic_search_config = config.semantic_search
        most_relevant_docs, score = self.get_relevant_documents(
            original_query, query, semantic_search_config, label
        )

        res = self.chain(
            {"input_documents": most_relevant_docs, "question": original_query},
        )

        out = ResponseModel(
            response=res["output_text"],
            question=query,
            average_score=score,
            hyde_response="",
        )
        for doc in res["input_documents"]:
            out.semantic_search.append(doc.page_content)
        return out
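
# Illustrative usage sketch (comments only, not part of the module): the chain,
# vector stores, re-ranker and config are constructed elsewhere in the app, so
# the names below are placeholders.
#
#   bundle = LLMBundle(
#       chain=qa_chain,
#       dense_db=dense_db,
#       reranker=reranker,
#       sparse_db=sparse_db,
#       chunk_sizes=[512, 1024],
#       hyde_chain=hyde_chain,
#   )
#   response = bundle.get_and_parse_response(query="What does SPLADE do?", config=config)
#   print(response.response, response.average_score)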

class PartialFormatter(string.Formatter):
    def __init__(self, missing="~~", bad_fmt="!!"):
        self.missing, self.bad_fmt = missing, bad_fmt

    def get_field(self, field_name, args, kwargs):
        try:
            val = super(PartialFormatter, self).get_field(field_name, args, kwargs)
        except (KeyError, AttributeError):
            val = None, field_name
        return val

    def format_field(self, value, spec):
        if value is None:
            return self.missing
        try:
            return super(PartialFormatter, self).format_field(value, spec)
        except ValueError:
            if self.bad_fmt is not None:
                return self.bad_fmt
            else:
                raise
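
# A minimal sketch of PartialFormatter's behaviour (illustrative; the template
# below is a made-up example): fields with no supplied value render as the
# `missing` placeholder instead of raising KeyError, and values that reject the
# format spec fall back to `bad_fmt`.
if __name__ == "__main__":
    fmt = PartialFormatter()
    template = "Question: {question}\nContext: {context}"
    # Only `question` is supplied; `context` renders as "~~".
    print(fmt.format(template, question="What is SPLADE?"))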