import lancedb
import os
import gradio as gr
from sentence_transformers import SentenceTransformer, CrossEncoder


# Connect to the local LanceDB directory and lazily cache opened tables.
db = lancedb.connect(".lancedb")

tables = {}

def table(tname):
    if tname not in tables:
        tables[tname] = db.open_table(tname)
    return tables[tname]

# Column names and batch size are configurable via environment variables.
VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))

# Bi-encoder retrievers used to embed queries.
retriever_bge = SentenceTransformer("BAAI/bge-large-en-v1.5")
retriever_minilm = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Cross-encoder rerankers used to rescore retrieved passages.
reranker_bge = CrossEncoder("BAAI/bge-reranker-large")
reranker_minilm = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")


def retrieve(query, k, model_kind, sub_vector_size, chunk_size, splitter_type, reranker_kind=None, pre_ranker_size=None):
    """Retrieve the top-k passages for `query`, optionally reranking with a cross-encoder.

    The LanceDB table is selected by splitter type, embedding model, sub-vector size,
    and chunk size; when a reranker is given, `pre_ranker_size` candidates are fetched
    first and the best `k` are returned after rescoring.
    """
    # Embed the query with the selected bi-encoder.
    if model_kind == "bge":
        query_vec = retriever_bge.encode(query)
    else:
        query_vec = retriever_minilm.encode(query)

    if pre_ranker_size is None:
        pre_ranker_size = k

    try:
        # Vector search against the table matching the requested configuration.
        documents = table(
            f"{splitter_type}_{model_kind}_{sub_vector_size}_{chunk_size}",
        ).search(query_vec, vector_column_name=VECTOR_COLUMN).limit(pre_ranker_size).to_list()
        documents = [doc[TEXT_COLUMN] for doc in documents]

        if reranker_kind is None:
            return documents
        # Pair the query with each document for re-ranking
        query_document_pairs = [(query, text) for text in documents]

        # Score documents using the reranker
        if reranker_kind == "bge":
            scores = reranker_bge.predict(query_document_pairs)
        else:
            scores = reranker_minilm.predict(query_document_pairs)

        # Sort the documents by reranker score, highest first
        scored_documents = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)

        # Return the top K documents based on re-ranking
        return [doc for doc, _ in scored_documents[:k]]

    except Exception as e:
        raise gr.Error(str(e))
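

# A minimal usage sketch (not part of the original module): it assumes a LanceDB table
# named "recursive_bge_768_500" exists, i.e. splitter_type="recursive", model_kind="bge",
# sub_vector_size=768, chunk_size=500. The query, parameter values, and the gr.Interface
# wiring below are illustrative only.
if __name__ == "__main__":
    def demo_fn(query):
        # Fetch 10 candidates, rerank with the BGE cross-encoder, and keep the top 3.
        docs = retrieve(
            query,
            k=3,
            model_kind="bge",
            sub_vector_size=768,
            chunk_size=500,
            splitter_type="recursive",
            reranker_kind="bge",
            pre_ranker_size=10,
        )
        return "\n\n---\n\n".join(docs)

    gr.Interface(fn=demo_fn, inputs="text", outputs="text").launch()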