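# Gradio demo: text-to-image search over a clothing dataset, using precomputed
# SigLIP2 image embeddings and text embeddings computed on the fly.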
import numpy as np
from safetensors import safe_open
from datasets import load_dataset
import torch
from transformers import AutoModel, AutoTokenizer
import gradio as gr


def load_embeddings(file_path, key="vectors"):
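    """Read a dense embedding matrix from a safetensors file (NumPy backend)."""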
    with safe_open(file_path, framework="numpy") as f:
        embeddings = f.get_tensor(key)
    return embeddings


image_embeddings = load_embeddings("clothes_desc.safetensors")


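# L2-normalise the image embeddings so the dot products below are cosine similarities.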
image_embeddings = image_embeddings / np.linalg.norm(
    image_embeddings, axis=1, keepdims=True
)


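# The dataset provides the original images for display; the SigLIP2 model and
# tokenizer embed the text queries on CPU.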
ds = load_dataset("wbensvage/clothes_desc")["train"]

model_name = "google/siglip2-large-patch16-512"
model = AutoModel.from_pretrained(model_name, device_map="cpu").eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)


def encode_text(texts, model, tokenizer):
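    """Embed a batch of texts with the SigLIP2 text tower and L2-normalise the output."""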
    # SigLIP/SigLIP2 text encoders are trained with fixed-length padding, so pad
    # to max_length when tokenizing (as recommended in the model card).
    inputs = tokenizer(texts, padding="max_length", return_tensors="pt").to(model.device)
    with torch.no_grad():
        embs = model.get_text_features(**inputs)
    embs = embs.detach().cpu().numpy()
    embs = embs / np.linalg.norm(embs, axis=1, keepdims=True)
    return embs


def find_images(query, top_k):
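    """Return the top_k dataset images whose embeddings best match the query."""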
    query_embedding = encode_text([query], model, tokenizer)
    similarity = np.dot(query_embedding, image_embeddings.T)
    # Cast top_k in case the slider delivers a float; rank by descending similarity.
    top_k_indices = np.argsort(similarity[0])[::-1][: int(top_k)]
    images = [ds[int(i)]["image"] for i in top_k_indices]
    return images


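# Gradio UI: a free-text query plus a slider for the number of results,
# displayed as an image gallery.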
iface = gr.Interface(
    fn=find_images,
    inputs=[
        gr.Textbox(
            lines=2,
            placeholder="Enter search text here (Shift + Enter to submit)",
            label="Query",
        ),
        gr.Slider(10, 50, step=10, value=20, label="Number of images"),
    ],
    outputs=gr.Gallery(label="Search Results", columns=5, height="auto"),
    title="SigLIP2 Image Search",
    description="The demo uses [siglip2-large-patch16-512](https://huggingface.co/google/siglip2-large-patch16-512). Compare with [Multilingual CLIP](https://huggingface.co/spaces/adorkin/m-clip-clothes).",
    examples=[
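        # English, French, and Estonian queries to exercise multilingual search.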
        ["a red dress", 20],
        ["a blue shirt", 20],
        ["la blouse rouge", 20],
        ["la jupe bleue", 20],
        ["punane kleit", 20],
        ["sinine särk", 20],
    ],
)

iface.launch()