# ai_school_hw5 / backend / semantic_search.py
# complynx's picture
# Add all the variables right there
# 5fa8f2f
# raw
# history blame
# 1.17 kB
import lancedb
import os
import gradio as gr
from sentence_transformers import SentenceTransformer
# Connection to the LanceDB database located at the relative path ".lancedb".
db = lancedb.connect(".lancedb")
# Cache of already-opened tables, keyed by table name (filled in by table()).
tables = {}
def table(tname):
    """Return the table named *tname*, opening it on first use.

    Opened tables are kept in the module-level ``tables`` dict so that each
    table is opened through ``db.open_table`` at most once per process.

    Args:
        tname: Name of the table to look up in ``db``.

    Returns:
        Whatever ``db.open_table(tname)`` returned for this name.

    Raises:
        Any exception raised by ``db.open_table`` is propagated unchanged;
        nothing is cached in that case.
    """
    if tname not in tables:
        tables[tname] = db.open_table(tname)
    return tables[tname]
# Table selected through the TABLE_NAME environment variable, opened at import.
# NOTE(review): if TABLE_NAME is unset, None is passed to open_table — confirm
# that this variable is always provided by the deployment.
TABLE = db.open_table(os.environ.get("TABLE_NAME"))

# Column names used when searching and when extracting the result text.
VECTOR_COLUMN = os.environ.get("VECTOR_COLUMN", "vector")
TEXT_COLUMN = os.environ.get("TEXT_COLUMN", "text")

# Batch size setting, read from the environment (defaults to 32).
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "32"))

# Query encoders, loaded once at import time and selected in retrieve().
retriever_bge = SentenceTransformer("BAAI/bge-large-en-v1.5")
retriever_minilm = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
def retrieve(query, k, model_kind, sub_vector_size, chunk_size, splitter_type):
    """Return the text of the *k* nearest entries to *query*.

    The query is encoded with the model chosen by *model_kind* and searched
    against the table named
    ``f"{splitter_type}_{model_kind}_{sub_vector_size}_{chunk_size}"``.

    Args:
        query: Input passed to the encoder's ``encode`` method.
        k: Value passed to ``.limit()`` on the search.
        model_kind: ``"bge"`` selects ``retriever_bge``; any other value
            selects ``retriever_minilm``. Also part of the table name.
        sub_vector_size: Used only as a component of the table name.
        chunk_size: Used only as a component of the table name.
        splitter_type: Used only as a component of the table name.

    Returns:
        A list with the ``TEXT_COLUMN`` value of every search result, in the
        order returned by the search.

    Raises:
        gr.Error: If opening the table, running the search, or extracting the
            text column fails. The original exception is attached as the
            cause. Errors raised while encoding the query are not wrapped.
    """
    # Any model_kind other than "bge" falls back to the MiniLM encoder.
    encoder = retriever_bge if model_kind == "bge" else retriever_minilm
    query_vec = encoder.encode(query)

    try:
        table_name = f"{splitter_type}_{model_kind}_{sub_vector_size}_{chunk_size}"
        hits = (
            table(table_name)
            .search(query_vec, vector_column_name=VECTOR_COLUMN)
            .limit(k)
            .to_list()
        )
        return [hit[TEXT_COLUMN] for hit in hits]
    except Exception as e:
        # Surface the failure through gr.Error while keeping the original
        # traceback chained for debugging.
        raise gr.Error(str(e)) from e