import torch
import requests
import numpy as np
import pandas as pd
import gradio as gr
from io import BytesIO
from PIL import Image as PILIMAGE
from transformers import CLIPProcessor, CLIPModel
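# Reverse image search demo: embed the uploaded image with CLIP and return the
# closest photo from a precomputed feature index (the photos.tsv000,
# features.npy and photo_ids.csv files expected next to this script).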

## Define model (loaded once at startup rather than on every request)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

## Load data: photo metadata and precomputed CLIP image features
photos = pd.read_csv("./photos.tsv000", sep='\t', header=0)
photo_features = np.load("./features.npy")
photo_ids = list(pd.read_csv("./photo_ids.csv")["photo_id"])

def find_similar(image):
    # Convert the numpy array supplied by Gradio into a PIL image
    image = PILIMAGE.fromarray(image.astype('uint8'), 'RGB')
    
    ## Inference: embed the query image with CLIP
    with torch.no_grad():
        photo_preprocessed = processor(text=None, images=image, return_tensors="pt", padding=True)["pixel_values"]
        search_photo_feature = model.get_image_features(photo_preprocessed.to(device))
        # L2-normalise so the dot product below acts as a cosine similarity
        search_photo_feature /= search_photo_feature.norm(dim=-1, keepdim=True)
    search_photo_feature = search_photo_feature.cpu().numpy()
    
    ## Find similarity: a dot product equals cosine similarity when both sides are
    ## L2-normalised (the query is normalised above; the precomputed features are
    ## assumed to be normalised as well)
    similarities = (search_photo_feature @ photo_features.T).squeeze(0)

    ## Return best image :)
    idx = int(np.argmax(similarities))
    photo_id = photo_ids[idx]
    try:
        photo_data = photos[photos["photo_id"] == photo_id].iloc[0]
    except IndexError:
        photo_data = photos.iloc[0]
    response = requests.get(photo_data["photo_image_url"] + "?w=640", timeout=10)
    img = PILIMAGE.open(BytesIO(response.content))
    return img

iface = gr.Interface(fn=find_similar, inputs=gr.Image(), outputs=gr.Image(type="pil"))
iface.launch()