import numpy as np
from safetensors import safe_open
from datasets import load_dataset
import torch
from transformers import AutoModel, AutoTokenizer
import gradio as gr

def load_embeddings(file_path, key="vectors"):
    # Read the tensor stored under `key` from a safetensors file as a NumPy array.
    with safe_open(file_path, framework="numpy") as f:
        embeddings = f.get_tensor(key)
    return embeddings

# Precomputed image embeddings for the dataset (assumed to be encoded with the
# SigLIP2 checkpoint loaded below). L2-normalize them so that a dot product with
# a normalized text embedding is cosine similarity.
image_embeddings = load_embeddings("clothes_desc.safetensors")
image_embeddings = image_embeddings / np.linalg.norm(
    image_embeddings, axis=1, keepdims=True
)
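
# The clothes_desc dataset supplies the images; row i of the dataset is assumed
# to correspond to row i of the embedding matrix above (this is how find_images
# indexes it).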
ds = load_dataset("wbensvage/clothes_desc")["train"]
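
# SigLIP2 text/image encoder, kept on CPU in eval mode; only the text tower is
# used at query time.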
model_name = "google/siglip2-large-patch16-512"
model = AutoModel.from_pretrained(model_name, device_map="cpu").eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)
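
# The safetensors file above is assumed to hold one image embedding per dataset
# row under the "vectors" key. A minimal sketch of how it could be regenerated
# with this same checkpoint (the AutoProcessor / safetensors.numpy.save_file
# usage here is an assumption, not part of the original Space):
#
#   from transformers import AutoProcessor
#   from safetensors.numpy import save_file
#   processor = AutoProcessor.from_pretrained(model_name)
#   vecs = []
#   for row in ds:
#       pixels = processor(images=row["image"], return_tensors="pt").to(model.device)
#       with torch.no_grad():
#           vecs.append(model.get_image_features(**pixels).cpu().numpy()[0])
#   save_file({"vectors": np.stack(vecs)}, "clothes_desc.safetensors")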

def encode_text(texts, model, tokenizer):
    # SigLIP-style models are trained on fixed-length padded text, so pad to the
    # maximum length (64 tokens is the setting recommended in the SigLIP/SigLIP2
    # model cards).
    inputs = tokenizer(
        texts, padding="max_length", max_length=64, return_tensors="pt"
    ).to(model.device)
    with torch.no_grad():
        embs = model.get_text_features(**inputs)
    embs = embs.detach().cpu().numpy()
    # L2-normalize so the dot product with image embeddings is cosine similarity.
    embs = embs / np.linalg.norm(embs, axis=1, keepdims=True)
    return embs

def find_images(query, top_k):
    # Gradio sliders may deliver floats; slicing below needs an int.
    top_k = int(top_k)
    query_embedding = encode_text([query], model, tokenizer)
    # Cosine similarity against every image, then keep the top_k highest-scoring rows.
    similarity = np.dot(query_embedding, image_embeddings.T)
    top_k_indices = np.argsort(similarity[0])[::-1][:top_k]
    images = [ds[int(i)]["image"] for i in top_k_indices]
    return images
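
# Gradio UI: a free-text query plus a slider for how many results to show in a gallery.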
iface = gr.Interface(
    fn=find_images,
    inputs=[
        gr.Textbox(
            lines=2,
            placeholder="Enter search text here (Shift + Enter to submit)",
            label="Query",
        ),
        gr.Slider(10, 50, step=10, value=20, label="Number of images"),
    ],
    outputs=gr.Gallery(label="Search Results", columns=5, height="auto"),
    title="SigLIP2 Image Search",
    description="The demo uses [siglip2-large-patch16-512](https://huggingface.co/google/siglip2-large-patch16-512). Compare with [Multilingual CLIP](https://huggingface.co/spaces/adorkin/m-clip-clothes).",
    examples=[
        ["a red dress", 20],
        ["a blue shirt", 20],
        # French and Estonian queries ("the red blouse", "the blue skirt",
        # "red dress", "blue shirt") to probe multilingual behaviour.
        ["la blouse rouge", 20],
        ["la jupe bleue", 20],
        ["punane kleit", 20],
        ["sinine särk", 20],
    ],
)
iface.launch()