import os

import gradio as gr
import lancedb
from sentence_transformers import CrossEncoder, SentenceTransformer

db = lancedb.connect(".lancedb")

# Cache of opened LanceDB tables, keyed by table name.
tables = {}


def table(tname):
    if tname not in tables:
        tables[tname] = db.open_table(tname)
    return tables[tname]


VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "32"))

# Bi-encoders used to embed the query for vector search.
retriever_bge = SentenceTransformer("BAAI/bge-large-en-v1.5")
retriever_minilm = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Cross-encoders used to re-rank the retrieved candidates.
reranker_bge = CrossEncoder("BAAI/bge-reranker-large")
reranker_minilm = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")


def retrieve(query, k, model_kind, sub_vector_size, chunk_size, splitter_type,
             reranker_kind=None, pre_ranker_size=None):
    # Embed the query with the same model family that was used to build the index.
    if model_kind == "bge":
        query_vec = retriever_bge.encode(query)
    else:
        query_vec = retriever_minilm.encode(query)

    # Without a re-ranker there is no point in fetching more than k candidates.
    if pre_ranker_size is None:
        pre_ranker_size = k

    try:
        documents = (
            table(f"{splitter_type}_{model_kind}_{sub_vector_size}_{chunk_size}")
            .search(query_vec, vector_column_name=VECTOR_COLUMN)
            .limit(pre_ranker_size)
            .to_list()
        )
        documents = [doc[TEXT_COLUMN] for doc in documents]

        if reranker_kind is None:
            return documents

        # Pair the query with each document for re-ranking
        query_document_pairs = [(query, text) for text in documents]

        # Score documents using the reranker
        if reranker_kind == "bge":
            scores = reranker_bge.predict(query_document_pairs)
        else:
            scores = reranker_minilm.predict(query_document_pairs)

        # Aggregate and sort the documents based on the scores
        scored_documents = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)

        # Return the top k documents based on re-ranking
        return [doc for doc, _ in scored_documents[:k]]
    except Exception as e:
        raise gr.Error(str(e))
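

# Minimal usage sketch. The table-name parameters below (splitter_type,
# sub_vector_size, chunk_size) are assumptions for illustration only; they must
# match a table that the indexing step actually created in .lancedb, following
# the "<splitter>_<model>_<sub_vector>_<chunk>" naming scheme used above.
if __name__ == "__main__":
    hits = retrieve(
        query="What is LanceDB?",
        k=3,
        model_kind="bge",           # embed the query with BAAI/bge-large-en-v1.5
        sub_vector_size=96,         # assumed index parameter
        chunk_size=512,             # assumed chunk size used at ingestion
        splitter_type="recursive",  # assumed splitter label used at ingestion
        reranker_kind="bge",        # re-rank with BAAI/bge-reranker-large
        pre_ranker_size=20,         # fetch 20 candidates, keep top 3 after re-ranking
    )
    for text in hits:
        print(text[:80])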