Spaces:
Runtime error
Runtime error
import lancedb | |
import os | |
import gradio as gr | |
from sentence_transformers import SentenceTransformer, CrossEncoder | |
# Connection to the LanceDB database located at the relative path ".lancedb".
db = lancedb.connect(".lancedb")

# Table handles opened so far, keyed by table name; filled lazily by table().
tables: dict = {}
def table(tname):
    """Return the handle for table ``tname``, opening it on first use.

    Handles are cached in the module-level ``tables`` dict, so
    ``db.open_table`` is only called for names that are not cached yet.

    Args:
        tname: Name of the table to look up in ``db``.

    Returns:
        The cached result of ``db.open_table(tname)`` for this name.
    """
    if tname not in tables:
        tables[tname] = db.open_table(tname)
    return tables[tname]
# Names of the vector and text columns used when querying the tables;
# both can be overridden through environment variables.
VECTOR_COLUMN = os.environ.get("VECTOR_COLUMN", "vector")
TEXT_COLUMN = os.environ.get("TEXT_COLUMN", "text")

# Batch size setting, read from the environment and coerced to int.
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 32))

# Bi-encoders used to embed the query text (selected by `model_kind`).
retriever_bge = SentenceTransformer("BAAI/bge-large-en-v1.5")
retriever_minilm = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Cross-encoders used to score (query, document) pairs (selected by `reranker_kind`).
reranker_bge = CrossEncoder("BAAI/bge-reranker-large")
reranker_minilm = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
def retrieve(query, k, model_kind, sub_vector_size, chunk_size, splitter_type, reranker_kind=None, pre_ranker_size=None):
    """Return up to ``k`` text chunks matching ``query``, optionally re-ranked.

    The query is embedded with the bi-encoder selected by ``model_kind`` and
    used for a vector search over the table named
    ``{splitter_type}_{model_kind}_{sub_vector_size}_{chunk_size}``. When a
    re-ranker is requested, ``pre_ranker_size`` candidates are fetched, scored
    against the query with a cross-encoder, and the ``k`` best are kept.

    Args:
        query: Query text to search for.
        k: Maximum number of texts to return.
        model_kind: ``"bge"`` selects ``retriever_bge``; any other value
            selects ``retriever_minilm``. Also part of the table name.
        sub_vector_size: Used only to build the table name.
        chunk_size: Used only to build the table name.
        splitter_type: Used only to build the table name.
        reranker_kind: ``None`` disables re-ranking; ``"bge"`` selects
            ``reranker_bge``; any other value selects ``reranker_minilm``.
        pre_ranker_size: Number of candidates requested from the vector
            search; defaults to ``k``.

    Returns:
        A list of at most ``k`` ``TEXT_COLUMN`` values: in search-result
        order without a re-ranker, otherwise by descending re-ranker score.

    Raises:
        gr.Error: If opening the table, searching, or re-ranking fails. The
            original exception is attached as ``__cause__``.
    """
    encoder = retriever_bge if model_kind == "bge" else retriever_minilm
    query_vec = encoder.encode(query)

    if pre_ranker_size is None:
        pre_ranker_size = k

    try:
        rows = (
            table(f"{splitter_type}_{model_kind}_{sub_vector_size}_{chunk_size}")
            .search(query_vec, vector_column_name=VECTOR_COLUMN)
            .limit(pre_ranker_size)
            .to_list()
        )
        documents = [row[TEXT_COLUMN] for row in rows]

        if reranker_kind is None:
            # Honour k even when the caller asked for a larger candidate pool.
            return documents[:k]

        if not documents:
            # Nothing to score; avoid handing the cross-encoder an empty
            # batch, which it may not accept.
            return []

        # Score every (query, document) pair with the selected cross-encoder.
        reranker = reranker_bge if reranker_kind == "bge" else reranker_minilm
        scores = reranker.predict([(query, text) for text in documents])

        # Highest score first; keep only the top k texts.
        ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
        return [text for text, _ in ranked[:k]]
    except Exception as e:
        # Surface the failure in the Gradio UI while keeping the original
        # traceback available for debugging.
        raise gr.Error(str(e)) from e