Spaces:
Runtime error
Runtime error
Update search
Browse files
app.py
CHANGED
@@ -2,34 +2,52 @@ import gradio as gr
|
|
2 |
import random
|
3 |
from datasets import load_dataset
|
4 |
from sentence_transformers import SentenceTransformer, util
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
def
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
def main():
|
30 |
-
|
31 |
-
|
32 |
-
|
|
|
33 |
demo = gr.TabbedInterface([text_to_image_iface, image_to_image_iface], ["Text query", "Image query"])
|
34 |
demo.launch()
|
35 |
|
|
|
2 |
import random
|
3 |
from datasets import load_dataset
|
4 |
from sentence_transformers import SentenceTransformer, util
|
5 |
+
import logging
|
6 |
+
from PIL import Image
|
7 |
+
# Create a custom logger
|
8 |
+
logger = logging.getLogger(__name__)
|
9 |
|
10 |
+
# Set the level of this logger. INFO means that it will log all INFO, WARNING, ERROR, and CRITICAL messages.
|
11 |
+
logger.setLevel(logging.INFO)
|
12 |
+
|
13 |
+
# Create handlers
|
14 |
+
c_handler = logging.StreamHandler()
|
15 |
+
c_handler.setLevel(logging.INFO)
|
16 |
+
|
17 |
+
# Create formatters and add it to handlers
|
18 |
+
c_format = logging.Formatter('%(name)s - %(levelname)s - %(message)s')
|
19 |
+
c_handler.setFormatter(c_format)
|
20 |
+
|
21 |
+
# Add handlers to the logger
|
22 |
+
logger.addHandler(c_handler)
|
23 |
+
|
24 |
+
class SearchEngine:
|
25 |
+
def __init__(self):
|
26 |
+
self.model = SentenceTransformer('clip-ViT-B-32')
|
27 |
+
self.embedding_dataset = load_dataset("JLD/unsplash25k-image-embeddings", trust_remote_code=True, split="train").with_format("torch", device="cuda:0")
|
28 |
+
image_dataset = load_dataset("jamescalam/unsplash-25k-photos", trust_remote_code=True, revision="refs/pr/3")
|
29 |
+
self.image_dataset = {image["photo_id"]: image["photo_image_url"] for image in image_dataset["train"]}
|
30 |
+
|
31 |
+
def get_candidates(self, query_embedding, top_k=5):
|
32 |
+
logger.info("Getting candidates")
|
33 |
+
candidates = util.semantic_search(query_embeddings=query_embedding.unsqueeze(0), corpus_embeddings=self.embedding_dataset["image_embedding"].squeeze(1), top_k=top_k)[0]
|
34 |
+
return [self.image_dataset.get(self.embedding_dataset[candidate["corpus_id"]]["image_id"], "https://upload.wikimedia.org/wikipedia/commons/6/69/NASA-HS201427a-HubbleUltraDeepField2014-20140603.jpg") for candidate in candidates]
|
35 |
+
|
36 |
+
def search_images_from_text(self, text):
|
37 |
+
logger.info("Searching images from text")
|
38 |
+
emb = self.model.encode(text, convert_to_tensor=True, device="cuda:0")
|
39 |
+
return self.get_candidates(query_embedding=emb)
|
40 |
+
|
41 |
+
def search_images_from_image(self, image):
|
42 |
+
logger.info("Searching images from image")
|
43 |
+
emb = self.model.encode(Image.fromarray(image), convert_to_tensor=True, device="cuda:0")
|
44 |
+
return self.get_candidates(query_embedding=emb)
|
45 |
|
46 |
def main():
|
47 |
+
logger.info("Loading dataset")
|
48 |
+
search_engine = SearchEngine()
|
49 |
+
text_to_image_iface = gr.Interface(fn=search_engine.search_images_from_text, inputs="text", outputs="gallery")
|
50 |
+
image_to_image_iface = gr.Interface(fn=search_engine.search_images_from_image, inputs="image", outputs="gallery")
|
51 |
demo = gr.TabbedInterface([text_to_image_iface, image_to_image_iface], ["Text query", "Image query"])
|
52 |
demo.launch()
|
53 |
|