Spaces:
Runtime error
Runtime error
File size: 2,037 Bytes
741514a 7588eb3 741514a 5fa8f2f 741514a 5fa8f2f 7588eb3 741514a 7588eb3 5fa8f2f 741514a 7588eb3 741514a 5fa8f2f 7588eb3 741514a 7588eb3 741514a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import lancedb
import os
import gradio as gr
from sentence_transformers import SentenceTransformer, CrossEncoder
# Connection to the LanceDB database at the relative path ".lancedb".
db = lancedb.connect(".lancedb")
# Cache of already-opened table handles, keyed by table name; filled by table().
tables = {}
def table(tname):
    """Return the LanceDB table called *tname*, opening it on first use.

    Handles are memoised in the module-level ``tables`` dict, so
    ``db.open_table`` is invoked at most once per distinct table name.
    Any exception raised by ``db.open_table`` propagates to the caller
    and nothing is cached for that name.
    """
    # Fast path: this table has already been opened.
    if tname in tables:
        return tables[tname]

    handle = db.open_table(tname)
    tables[tname] = handle
    return handle
# Column names are read from the environment, falling back to "vector"/"text".
# VECTOR_COLUMN is the column searched against; TEXT_COLUMN is the column
# whose value retrieve() returns for each hit.
VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
# NOTE(review): BATCH_SIZE is not referenced in the visible part of this file;
# presumably used elsewhere or left over -- confirm.
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
# Embedding models used by retrieve() to encode the query; all four models
# below are instantiated once, at import time.
retriever_bge = SentenceTransformer("BAAI/bge-large-en-v1.5")
retriever_minilm = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Cross-encoders used by retrieve() to score (query, document) pairs.
reranker_bge = CrossEncoder("BAAI/bge-reranker-large")
reranker_minilm = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
def retrieve(query, k, model_kind, sub_vector_size, chunk_size, splitter_type, reranker_kind=None, pre_ranker_size=None):
    """Return up to *k* text chunks matching *query*, optionally re-ranked.

    The query is embedded, a vector search is run against the table named
    ``"{splitter_type}_{model_kind}_{sub_vector_size}_{chunk_size}"``, and
    the ``TEXT_COLUMN`` value of every hit is collected.  When a re-ranker
    is requested, the hits are re-scored against the query and the *k*
    best-scoring texts are returned.

    Args:
        query: Query text to embed and search for.
        k: Maximum number of documents to return.
        model_kind: ``"bge"`` selects ``retriever_bge``; any other value
            selects ``retriever_minilm``.  The value is also part of the
            table name.
        sub_vector_size: Component of the table name.
        chunk_size: Component of the table name.
        splitter_type: Component of the table name.
        reranker_kind: ``None`` disables re-ranking; ``"bge"`` selects
            ``reranker_bge``; any other value selects ``reranker_minilm``.
        pre_ranker_size: Number of candidates to fetch before re-ranking.
            Defaults to *k*.  Ignored when ``reranker_kind`` is ``None``,
            so the result never exceeds *k* items in that case.

    Returns:
        A list of document texts: in search order when no re-ranker is
        used, otherwise sorted by descending re-ranker score.

    Raises:
        gr.Error: If opening the table, searching, or re-ranking fails.
            The original exception is attached as ``__cause__``.
    """
    # Any model_kind other than "bge" falls back to the MiniLM encoder.
    encoder = retriever_bge if model_kind == "bge" else retriever_minilm
    query_vec = encoder.encode(query)

    # pre_ranker_size only widens the candidate pool handed to a re-ranker.
    # Without a re-ranker nothing would trim the pool back down, so fetch
    # exactly k to honour the "at most k results" contract.
    if reranker_kind is None or pre_ranker_size is None:
        candidate_count = k
    else:
        candidate_count = pre_ranker_size

    try:
        rows = (
            table(f"{splitter_type}_{model_kind}_{sub_vector_size}_{chunk_size}")
            .search(query_vec, vector_column_name=VECTOR_COLUMN)
            .limit(candidate_count)
            .to_list()
        )
        documents = [row[TEXT_COLUMN] for row in rows]

        # Nothing to re-rank: either no re-ranker was requested or the
        # search came back empty (avoid calling predict() on no pairs).
        if reranker_kind is None or not documents:
            return documents

        # Pair the query with each document for re-ranking.
        query_document_pairs = [(query, text) for text in documents]

        # Any reranker_kind other than "bge" falls back to the MiniLM model.
        reranker = reranker_bge if reranker_kind == "bge" else reranker_minilm
        scores = reranker.predict(query_document_pairs)

        # Highest score first; sorted() is stable, so ties keep search order.
        ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
        return [doc for doc, _ in ranked[:k]]
    except Exception as e:
        # Surface the failure in the Gradio UI while keeping the cause chained.
        raise gr.Error(str(e)) from e
|