File size: 3,294 Bytes
fbc7e49
52a9cd3
 
 
e055325
02b7760
52a9cd3
81d4c87
 
 
 
 
e055325
52a9cd3
 
 
 
 
81d4c87
52a9cd3
 
81d4c87
 
52a9cd3
 
 
 
 
 
 
 
 
 
81d4c87
52a9cd3
81d4c87
52a9cd3
e055325
5e3cc67
81d4c87
 
 
 
 
 
52a9cd3
81d4c87
5e3cc67
81d4c87
c3f2eff
8b97f05
81d4c87
52a9cd3
81d4c87
 
 
 
 
 
 
 
 
 
 
 
 
f5fba4f
81d4c87
 
 
f5fba4f
a4370d3
52a9cd3
81d4c87
52a9cd3
 
 
 
 
fbc7e49
81d4c87
e8c22b8
 
c77bb9e
fbc7e49
81d4c87
afc3612
 
 
 
3c9bd97
afc3612
81d4c87
02b7760
afc3612
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
import chromadb
from sentence_transformers import CrossEncoder, SentenceTransformer
import json
from qdrant_client import QdrantClient

print("Setup client")
# In-memory Qdrant instance: the collection lives only for the lifetime of this
# process. (An earlier chromadb-based collection setup was replaced by this.)
client = QdrantClient(":memory:")

print("load data")
# Read the JSON as UTF-8 explicitly instead of relying on the platform's
# default text encoding, which is not UTF-8 everywhere.
with open("test_json.json", "r", encoding="utf-8") as f:
    payload = json.load(f)


_SENTENCE_MODEL_NAME = "mixedbread-ai/mxbai-embed-large-v1"
_sentence_model = None


def _get_sentence_model():
    """Return the shared SentenceTransformer, loading it on first use."""
    global _sentence_model
    if _sentence_model is None:
        _sentence_model = SentenceTransformer(_SENTENCE_MODEL_NAME)
    return _sentence_model


def embedding_function(items_to_embed: list[str]) -> list[list[float]]:
    """Embed each string and return the vectors as plain Python lists.

    Args:
        items_to_embed: Texts to encode with the sentence-embedding model.

    Returns:
        One list per input text, in input order, produced by calling
        ``.tolist()`` on each encoded item. An empty input yields an empty
        list without loading the model.
    """
    # Nothing to encode: skip the (expensive) model load entirely.
    if not items_to_embed:
        return []

    print("embedding")
    embedded_items = _get_sentence_model().encode(items_to_embed)
    # Convert each encoded item to a plain list so the result is serialisable
    # and independent of the array type returned by the model.
    return [item.tolist() for item in embedded_items]


print("upserting")
print("printing item:")
# Build the document list once and reuse it for both the embedding call and
# the collection insert.
food_documents = [entry["doc"] for entry in payload]
# NOTE(review): `embedding` is computed and its type printed, but it is not
# passed to client.add() below, which receives only the raw documents.
embedding = embedding_function(food_documents)
print(type(embedding))
client.add(
    collection_name="food",
    documents=food_documents,
    # Keep the full source record alongside each document under "payload".
    metadata=[{"payload": entry} for entry in payload],
    # Ids are the positions of the records in the loaded payload.
    ids=list(range(len(payload))),
)


def search_chroma(query: str):
    """Search the "food" collection, rerank the hits and render them as Markdown.

    The five best matches for *query* are fetched from the vector store, then
    reordered by ``reranking_results``. Each hit is rendered as a Markdown
    section with its title, description, instructions and retrieval score.

    Args:
        query: Free-text search string.

    Returns:
        The Markdown sections joined by newlines, in reranked order, or an
        empty string when the search returns no hits.
    """
    results = client.query(
        collection_name="food",
        query_text=query,
        limit=5,
    )
    # No hits: nothing to rerank or render.
    if not results:
        return ""

    top_k = [hit.document for hit in results]
    reranked = reranking_results(query, top_k)

    # Each reranked entry carries "corpus_id", the index of the document in
    # `top_k` (and therefore in `results`). Indexing by it avoids the text
    # comparison, which duplicated entries when two hits shared the same text.
    ordered_results = [results[entry["corpus_id"]] for entry in reranked]

    sections = []
    for hit in ordered_results:
        recipe = hit.metadata["payload"]
        # assumes recipe["instructions"] is a list of strings — TODO confirm
        instructions = "- " + "<br>- ".join(recipe["instructions"])
        # The score shown is the retrieval score on the hit, not the rerank score.
        sections.append(
            f"# Dish: {recipe['title']}\n\n## Description:\n{recipe['doc']}\n\n ## Instructions:\n{instructions}\n\n### Score: {hit.score}\n"
        )
    return "\n".join(sections)


_RERANK_MODEL_NAME = "mixedbread-ai/mxbai-rerank-xsmall-v1"
_rerank_model = None


def _get_rerank_model():
    """Return the shared CrossEncoder reranker, loading it on first use."""
    global _rerank_model
    if _rerank_model is None:
        _rerank_model = CrossEncoder(_RERANK_MODEL_NAME)
    return _rerank_model


def reranking_results(query: str, top_k_results: list[str]):
    """Rerank candidate documents against *query* with the cross-encoder.

    Args:
        query: The search string the candidates are scored against.
        top_k_results: Candidate document texts to rerank.

    Returns:
        The output of ``CrossEncoder.rank(..., return_documents=True)`` for
        the candidates, or an empty list when there are no candidates.
    """
    # Nothing to rank: avoid loading or invoking the model at all.
    if not top_k_results:
        return []
    # The model is cached so it is not reloaded on every query.
    return _get_rerank_model().rank(query, top_k_results, return_documents=True)


def run_query(query_string: str):
    """Return the Markdown search results produced by ``search_chroma`` for *query_string*."""
    return search_chroma(query_string)


# Single-page UI: a query textbox next to a Markdown results pane, plus a Run button.
with gr.Blocks() as meal_search:
    gr.Markdown("Start typing below and then click **Run** to see the output.")
    with gr.Row():
        query_box = gr.Textbox(placeholder="What sort of meal are you after?")
        results_view = gr.Markdown()
    run_button = gr.Button("Run")
    # Clicking Run passes the textbox content to run_query and shows its
    # return value in the Markdown pane.
    run_button.click(fn=run_query, inputs=query_box, outputs=results_view)

meal_search.launch()