# app/pipeline.py
import string
from typing import List, Optional, Tuple
from langchain.chains import LLMChain
from langchain.chains.base import Chain
from loguru import logger
from app.chroma import ChromaDenseVectorDB
from app.config.models.configs import (
ResponseModel,
Config, SemanticSearchConfig,
)
from app.ranking import BCEReranker, rerank
from app.splade import SpladeSparseVectorDB
class LLMBundle:
    """Ties the retrieval stack (sparse + dense vector stores and a reranker)
    to the LLM chains used to answer a query over indexed documents."""

    def __init__(
        self,
        chain: Chain,
        dense_db: ChromaDenseVectorDB,
        reranker: BCEReranker,
        sparse_db: SpladeSparseVectorDB,
        chunk_sizes: List[int],
        hyde_chain: Optional[LLMChain] = None,
    ) -> None:
        """
        Args:
            chain: QA chain mapping {input_documents, question} to an answer.
            dense_db: Dense (Chroma) vector store for similarity search.
            reranker: Cross-encoder model used to re-rank retrieved documents.
            sparse_db: Sparse (SPLADE) vector store.
            chunk_sizes: Chunk sizes to search over; the best-scoring chunk
                size (per the reranker) wins.
            hyde_chain: Optional HyDE chain; when set, its output is appended
                to the query before retrieval.
        """
        self.chain = chain
        self.dense_db = dense_db
        self.reranker = reranker
        self.sparse_db = sparse_db
        self.chunk_sizes = chunk_sizes
        self.hyde_chain = hyde_chain

    def get_relevant_documents(
        self,
        original_query: str,
        query: str,
        config: SemanticSearchConfig,
        label: str,
    ) -> Tuple[List[str], float]:
        """Retrieve and rerank documents relevant to *query*.

        For every configured chunk size, runs sparse (SPLADE) and dense
        (Chroma) retrieval, reranks the union of hits against
        *original_query*, and keeps the result set of the best-scoring chunk
        size. Finally returns as many of those documents as fit within
        ``config.max_char_size`` characters.

        Args:
            original_query: The user's raw question (used for reranking).
            query: Possibly expanded query (e.g. with a HyDE response).
            config: Semantic-search settings (max_k, query_prefix,
                max_char_size).
            label: Optional label to restrict retrieval; empty means no
                restriction.

        Returns:
            Tuple of (selected documents, best reranker score).
        """
        most_relevant_docs = []
        docs = []
        current_reranker_score = -1e5

        logger.debug("Evaluating query: {}", query)
        # BUGFIX: prepend the retrieval prefix exactly once. The previous
        # code mutated `query` inside the chunk-size loop, so with more than
        # one chunk size the prefix accumulated on every iteration.
        if config.query_prefix:
            logger.info(f"Adding query prefix for retrieval: {config.query_prefix}")
            query = config.query_prefix + query

        for chunk_size in self.chunk_sizes:
            # Stage 1: sparse retrieval (scores unused; ranking is redone below).
            sparse_search_docs_ids, _sparse_scores = self.sparse_db.query(
                search=query, n=config.max_k, label=label, chunk_size=chunk_size
            )
            logger.info(f"Stage 1: Got {len(sparse_search_docs_ids)} documents.")

            # Stage 2: dense retrieval, restricted by chunk size / label.
            # Renamed from `filter` to avoid shadowing the builtin.
            metadata_filter = (
                {"chunk_size": chunk_size} if len(self.chunk_sizes) > 1 else {}
            )
            if label:
                metadata_filter["label"] = label
            if not metadata_filter:
                metadata_filter = None
            logger.info(f"Dense embeddings filter: {metadata_filter}")
            res = self.dense_db.similarity_search_with_relevance_scores(
                query, filter=metadata_filter
            )
            dense_search_doc_ids = [r[0].metadata["document_id"] for r in res]

            # Union of sparse and dense hits for this chunk size.
            all_doc_ids = set(sparse_search_docs_ids).union(dense_search_doc_ids)
            if all_doc_ids:
                relevant_docs = self.dense_db.get_documents_by_id(
                    document_ids=list(all_doc_ids)
                )
                # Re-rank against the original (unexpanded) query and keep
                # the best-scoring chunk size's documents.
                reranker_score, relevant_docs = rerank(
                    rerank_model=self.reranker,
                    query=original_query,
                    docs=relevant_docs,
                )
                if reranker_score > current_reranker_score:
                    docs = relevant_docs
                    current_reranker_score = reranker_score

        # Greedily keep documents until the character budget is exhausted.
        len_ = 0
        for doc in docs:
            doc_length = len(doc.page_content)
            if len_ + doc_length < config.max_char_size:
                most_relevant_docs.append(doc)
                len_ += doc_length
        return most_relevant_docs, current_reranker_score

    def get_and_parse_response(
        self,
        query: str,
        config: Config,
        label: str = "",
    ) -> ResponseModel:
        """Answer *query* end-to-end: HyDE expansion, retrieval, QA chain.

        Args:
            query: User question.
            config: Application config; `config.semantic_search` is used.
            label: Optional label restricting retrieval.

        Returns:
            ResponseModel with the chain's answer, average reranker score,
            and the page contents of the documents used.
        """
        original_query = query
        # BUGFIX: `hyde_chain` is Optional and defaults to None — only expand
        # the query when a HyDE chain was actually provided.
        if self.hyde_chain is not None:
            hyde_response = self.hyde_chain.run(query)
            query += hyde_response
        logger.info(f"query: {query}")

        semantic_search_config = config.semantic_search
        most_relevant_docs, score = self.get_relevant_documents(
            original_query, query, semantic_search_config, label
        )
        res = self.chain(
            {"input_documents": most_relevant_docs, "question": original_query},
        )
        out = ResponseModel(
            response=res["output_text"],
            question=query,
            average_score=score,
            hyde_response="",
        )
        for doc in res["input_documents"]:
            out.semantic_search.append(doc.page_content)
        return out
class PartialFormatter(string.Formatter):
    """A ``string.Formatter`` that tolerates missing fields and bad specs.

    Fields that cannot be resolved render as ``missing`` (default ``"~~"``);
    values that fail their format spec render as ``bad_fmt`` (default
    ``"!!"``). Passing ``bad_fmt=None`` re-raises format-spec errors instead.
    """

    def __init__(self, missing="~~", bad_fmt="!!"):
        self.missing, self.bad_fmt = missing, bad_fmt

    def get_field(self, field_name, args, kwargs):
        """Resolve a field, substituting ``None`` when it cannot be found."""
        try:
            # BUGFIX: also catch IndexError so missing *positional* fields
            # ("{0}" with no args) degrade gracefully like keyword fields,
            # instead of raising out of format().
            return super().get_field(field_name, args, kwargs)
        except (KeyError, AttributeError, IndexError):
            return None, field_name

    def format_field(self, value, spec):
        """Format *value*; fall back to the placeholder strings on failure."""
        if value is None:
            # Unresolved field (see get_field) or a genuine None value.
            return self.missing
        try:
            return super().format_field(value, spec)
        except ValueError:
            if self.bad_fmt is not None:
                return self.bad_fmt
            raise