sasha HF staff commited on
Commit
52477e5
·
1 Parent(s): 3aa66d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -40
app.py CHANGED
@@ -1,60 +1,32 @@
1
- import pickle
2
  import gradio as gr
3
  from datasets import load_dataset
4
- from transformers import AutoModel, AutoFeatureExtractor
5
- import wikipedia
6
 
7
-
8
- # Only runs once when the script is first run.
9
- with open("index_768_cosine.pickle", "rb") as handle:
10
- index = pickle.load(handle)
11
-
12
- # Load model for computing embeddings.
13
- feature_extractor = AutoFeatureExtractor.from_pretrained("sasha/autotrain-butterfly-similarity-2490576840")
14
- model = AutoModel.from_pretrained("sasha/autotrain-butterfly-similarity-2490576840")
15
 
16
  # Candidate images.
17
- dataset = load_dataset("sasha/butterflies_10k_names_multiple")
18
  ds = dataset["train"]
 
19
 
20
 
21
- def query(image, top_k=4):
22
- inputs = feature_extractor(image, return_tensors="pt")
23
- model_output = model(**inputs)
24
- embedding = model_output.pooler_output.detach()
25
- results = index.query(embedding, k=top_k)
26
- inx = results[0][0].tolist()
27
- logits = results[1][0].tolist()
28
- images = ds.select(inx)["image"]
29
- captions = ds.select(inx)["name"]
30
- images_with_captions = [(i, c) for i, c in zip(images,captions)]
31
- labels_with_probs = dict(zip(captions,logits))
32
- labels_with_probs = {k: 1- v for k, v in labels_with_probs.items()}
33
- try:
34
- description = wikipedia.summary(captions[0], sentences = 1)
35
- description = "### " + description
36
- url = wikipedia.page(captions[0]).url
37
- url = " You can learn more about your butterfly [here](" + str(url) + ")!"
38
- description = description + url
39
- except:
40
- description = "### Butterflies are insects in the order Lepidoptera, which also includes moths. Adult butterflies have large, often brightly coloured wings."
41
- url = "https://en.wikipedia.org/wiki/Butterfly"
42
- url = " You can learn more about butterflies [here](" + str(url) + ")!"
43
- description = description + url
44
- return images_with_captions, labels_with_probs, description
45
 
46
 
47
  with gr.Blocks() as demo:
48
- gr.Markdown("# Find my Butterfly 🦋")
49
- gr.Markdown("## Use this Space to find your butterfly, based on the [iNaturalist butterfly dataset](https://huggingface.co/datasets/huggan/inat_butterflies_top10k)!")
50
  with gr.Row():
51
  with gr.Column(min_width= 900):
52
  inputs = gr.Image(shape=(800, 1600))
53
- btn = gr.Button("Find my butterfly!")
54
  description = gr.Markdown()
55
 
56
  with gr.Column():
57
- outputs=gr.Gallery().style(grid=[2], height="auto")
58
  labels = gr.Label()
59
 
60
  gr.Markdown("### Image Examples")
 
 
1
  import gradio as gr
2
  from datasets import load_dataset
3
+ from sentence_transformers import SentenceTransformer
 
4
 
5
+ model = SentenceTransformer('clip-ViT-B-32')
 
 
 
 
 
 
 
6
 
7
  # Candidate images.
8
+ dataset = load_dataset("sasha/pedro-embeddings")
9
  ds = dataset["train"]
10
+ ds.add_faiss_index(column='embeddings')
11
 
12
 
13
+ def query(image, number_to_retrieve=1):
14
+ prompt = model.encode(image)
15
+ scores, retrieved_examples = ds.get_nearest_examples('embeddings', prompt, k=number_to_retrieve)
16
+ return retrieved_examples
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
 
19
  with gr.Blocks() as demo:
20
+ gr.Markdown("# Find my Pedro Pascal")
21
+ gr.Markdown("## Use this Space to find the Pedro Pascal most similar to your input image!")
22
  with gr.Row():
23
  with gr.Column(min_width= 900):
24
  inputs = gr.Image(shape=(800, 1600))
25
+ btn = gr.Button("Find my Pedro!")
26
  description = gr.Markdown()
27
 
28
  with gr.Column():
29
+ outputs=gr.Gallery().style(grid=[1], height="auto")
30
  labels = gr.Label()
31
 
32
  gr.Markdown("### Image Examples")