JLD committed on
Commit
cb57dca
·
1 Parent(s): 17a1ebc

Update search

Browse files
Files changed (1) hide show
  1. app.py +43 -25
app.py CHANGED
@@ -2,34 +2,52 @@ import gradio as gr
2
  import random
3
  from datasets import load_dataset
4
  from sentence_transformers import SentenceTransformer, util
 
 
 
 
5
 
6
- model = SentenceTransformer('clip-ViT-B-32')
7
-
8
- def fake_gan():
9
- images = [
10
- (random.choice(
11
- [
12
- "https://upload.wikimedia.org/wikipedia/commons/6/69/NASA-HS201427a-HubbleUltraDeepField2014-20140603.jpg",
13
- "https://upload.wikimedia.org/wikipedia/commons/7/73/Cycliste_%C3%A0_place_d%27Italie-Paris.jpg",
14
- "https://upload.wikimedia.org/wikipedia/commons/3/31/Great_white_shark_south_africa.jpg",
15
- ]
16
- ), f"label {i}" if i != 0 else "label" * 50)
17
- for i in range(3)
18
- ]
19
- return images
20
-
21
- def search_images_from_text(text):
22
- emb = model.encode(text)
23
- return fake_gan()
24
-
25
- def search_images_from_image(image):
26
- image_emb = model.encode(image)
27
- return fake_gan()
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  def main():
30
- dataset = load_dataset("JLD/unsplash25k-image-embeddings", trust_remote_code=True, split="train").with_format("torch", device="cuda:0")
31
- text_to_image_iface = gr.Interface(fn=search_images_from_text, inputs="text", outputs="gallery")
32
- image_to_image_iface = gr.Interface(fn=search_images_from_image, inputs="image", outputs="gallery")
 
33
  demo = gr.TabbedInterface([text_to_image_iface, image_to_image_iface], ["Text query", "Image query"])
34
  demo.launch()
35
 
 
2
  import random
3
  from datasets import load_dataset
4
  from sentence_transformers import SentenceTransformer, util
5
+ import logging
6
+ from PIL import Image
7
# Module-level logger for this app; all messages go to the console
# through a stream handler attached below.
logger = logging.getLogger(__name__)

# INFO level: emit INFO, WARNING, ERROR, and CRITICAL messages.
logger.setLevel(logging.INFO)

# Console (stderr) handler, same threshold as the logger.
c_handler = logging.StreamHandler()
c_handler.setLevel(logging.INFO)

# "name - level - message" format for console output.
c_format = logging.Formatter('%(name)s - %(levelname)s - %(message)s')
c_handler.setFormatter(c_format)

# Register the handler so log records are actually emitted.
logger.addHandler(c_handler)
23
+
24
class SearchEngine:
    """CLIP-based image search over the Unsplash-25k dataset.

    Holds the sentence-transformers CLIP model, the precomputed image
    embeddings, and a photo_id -> photo URL lookup table. Text and image
    queries are encoded with the same model and matched against the
    embedding corpus via semantic search.
    """

    def __init__(self):
        # torch is already a transitive dependency (``.with_format("torch")``
        # below); imported locally to pick the device without touching the
        # file-level import block.
        import torch

        # Fall back to CPU when no GPU is present instead of crashing with a
        # CUDA error from the hard-coded "cuda:0".
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model = SentenceTransformer('clip-ViT-B-32')
        # Precomputed CLIP embeddings for the Unsplash images, as torch
        # tensors on the selected device.
        self.embedding_dataset = load_dataset("JLD/unsplash25k-image-embeddings", trust_remote_code=True, split="train").with_format("torch", device=self.device)
        image_dataset = load_dataset("jamescalam/unsplash-25k-photos", trust_remote_code=True, revision="refs/pr/3")
        # Map photo_id -> public image URL for turning search hits into
        # gallery entries.
        self.image_dataset = {image["photo_id"]: image["photo_image_url"] for image in image_dataset["train"]}

    def get_candidates(self, query_embedding, top_k=5):
        """Return the URLs of the ``top_k`` images closest to ``query_embedding``.

        Falls back to a placeholder Hubble image URL when a matched
        embedding's image_id is missing from the URL table.
        """
        logger.info("Getting candidates")
        candidates = util.semantic_search(query_embeddings=query_embedding.unsqueeze(0), corpus_embeddings=self.embedding_dataset["image_embedding"].squeeze(1), top_k=top_k)[0]
        return [self.image_dataset.get(self.embedding_dataset[candidate["corpus_id"]]["image_id"], "https://upload.wikimedia.org/wikipedia/commons/6/69/NASA-HS201427a-HubbleUltraDeepField2014-20140603.jpg") for candidate in candidates]

    def search_images_from_text(self, text):
        """Encode a text query and return matching image URLs."""
        logger.info("Searching images from text")
        emb = self.model.encode(text, convert_to_tensor=True, device=self.device)
        return self.get_candidates(query_embedding=emb)

    def search_images_from_image(self, image):
        """Encode an image query (numpy array from gradio) and return matching image URLs."""
        logger.info("Searching images from image")
        emb = self.model.encode(Image.fromarray(image), convert_to_tensor=True, device=self.device)
        return self.get_candidates(query_embedding=emb)
45
 
46
def main():
    """Build the search engine and serve the two-tab Gradio UI."""
    logger.info("Loading dataset")
    engine = SearchEngine()
    demo = gr.TabbedInterface(
        [
            gr.Interface(fn=engine.search_images_from_text, inputs="text", outputs="gallery"),
            gr.Interface(fn=engine.search_images_from_image, inputs="image", outputs="gallery"),
        ],
        ["Text query", "Image query"],
    )
    demo.launch()
53