# rag-retrieve / app.py
# Author: davidberenstein1957 — last commit 53ff92d ("Update app.py")
import gradio as gr
from sentence_transformers import SentenceTransformer
import duckdb
from huggingface_hub import get_token
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding
import duckdb
# --- Embedding model -------------------------------------------------------
# Model2Vec "potion-base-8M" static embeddings wrapped in a SentenceTransformer
# so the rest of the app can call `model.encode(...)`.
static_embedding = StaticEmbedding.from_model2vec("minishlab/potion-base-8M")
model = SentenceTransformer(modules=[static_embedding])

# --- Vector store ----------------------------------------------------------
# Hub dataset to search and the column that holds its precomputed vectors.
dataset_name = "smol-blueprint/fineweb-bbc-news-text-embeddings"
embedding_column = "embedding"

# Width of the fixed-size array the stored vectors are cast to; it is taken
# from the model so stored and query vectors have the same length
# (assumes the dataset was embedded with this same model — confirm).
_embedding_dim = model.get_sentence_embedding_dimension()

# Install/load DuckDB's `vss` extension, copy every parquet shard of the
# dataset into the `embeddings` table (adding a FLOAT[dim] copy of the
# embedding column as `embedding_float`), then build an HNSW index on that
# column with the cosine metric.
_setup_sql = f"""
INSTALL vss;
LOAD vss;
CREATE TABLE embeddings AS
SELECT *, {embedding_column}::float[{_embedding_dim}] as embedding_float
FROM 'hf://datasets/{dataset_name}/**/*.parquet';
CREATE INDEX my_hnsw_index ON embeddings USING HNSW (embedding_float) WITH (metric = 'cosine');
"""
duckdb.sql(_setup_sql)
def similarity_search(query: str, k: int = 5):
    """Return the ``k`` stored chunks closest to ``query`` by cosine distance.

    The query is embedded with the module-level ``model`` and compared with the
    ``embedding_float`` column of the DuckDB ``embeddings`` table using
    ``array_cosine_distance`` (smaller distance = more similar).

    Args:
        query: Free-text search query. A blank/whitespace-only query returns
            an empty result without touching the model or the database.
        k: Maximum number of rows to return. Coerced to a non-negative
            ``int`` because a UI slider may deliver a float.

    Returns:
        A pandas DataFrame with columns ``url``, ``chunk`` and ``distance``,
        ordered by ascending distance.
    """
    if not query or not query.strip():
        # Nothing meaningful to embed; return an empty table with the same
        # columns the real query would produce.
        import pandas as pd

        return pd.DataFrame(columns=["url", "chunk", "distance"])

    # int() fixes fractional slider values and guarantees that only a plain
    # integer is interpolated into the SQL text below (no injection via k).
    limit = max(int(k), 0)
    dim = model.get_sentence_embedding_dimension()
    # A list of Python floats renders as a bracketed list literal; the cast
    # turns it into a fixed-size FLOAT array comparable to `embedding_float`.
    embedding = model.encode(query).tolist()
    return duckdb.sql(
        query=f"""
        SELECT url, chunk, array_cosine_distance(embedding_float, {embedding}::FLOAT[{dim}]) as distance
        FROM embeddings
        ORDER BY distance
        LIMIT {limit};
        """
    ).to_df()
# Gradio UI: a query box, a result-count slider and a results table, with the
# Search button wired to `similarity_search`.
with gr.Blocks() as demo:
    gr.Markdown("""# Vector Search Hub Datasets
Part of [smol blueprint](https://github.com/huggingface/smol-blueprint) - a smol blueprint for AI development, focusing on applied examples of RAG, information extraction, analysis and fine-tuning in the age of LLMs. """)
    query = gr.Textbox(label="Query")
    # step=1 keeps the result count integral; without an explicit step Gradio
    # may derive a fractional one from the range, and this value ends up in
    # the SQL LIMIT clause.
    k = gr.Slider(1, 50, value=5, step=1, label="Number of results")
    btn = gr.Button("Search")
    results = gr.Dataframe(headers=["url", "chunk", "distance"])
    btn.click(fn=similarity_search, inputs=[query, k], outputs=[results])

# Start the web server at import time (as the original script does).
demo.launch()