import gradio as gr
from facenet_pytorch import InceptionResnetV1
import torch.nn as nn
import torchvision.transforms as tf
import numpy as np
import torch
import faiss
import h5py
import tqdm
import os
import random
from PIL import Image
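
# Files expected alongside this script (all referenced below): 'list_eval_partition.txt'
# and the 'img_align_celeba' image folder from CelebA, precomputed face embeddings in
# 'face_vecs_full.h5', and model weights in 'model.pt'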

# CelebA file names in file order; FAISS search indices are mapped back to
# these names, so this order must match the order the vectors were indexed in
img_names = []
with open('list_eval_partition.txt', 'r') as f:
    for line in f:
        img_name, _ = line.rstrip().split()  # second field is the train/val/test partition
        img_names.append(img_name)


# Backbone pretrained on VGGFace2; final weights are loaded from model.pt below
print('Loading model weights ........')

class SiameseModel(nn.Module):
    def __init__(self):
        super().__init__()
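        # InceptionResnetV1 from facenet-pytorch maps a 160x160 face crop to a
        # 512-dim embedding; 'vggface2' selects the VGGFace2-pretrained weights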
        self.backbone = InceptionResnetV1(pretrained='vggface2')
    def forward(self, x):
        x = self.backbone(x)
        x = torch.nn.functional.normalize(x, dim=1)
        return x
    
model = SiameseModel()
model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))
model.eval()


# Build the FAISS index over all stored face vectors
print('Building index .............')
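# Exact (brute-force) L2 search over 512-dim embeddings; because the model
# L2-normalizes its outputs, L2 distance ranks neighbors identically to cosine similarity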
index = faiss.IndexFlatL2(512)

# Each HDF5 dataset is assumed to hold one (1, 512) face vector, with keys
# iterating in the same order as img_names so results map back to file names
hf = h5py.File('face_vecs_full.h5', 'r')
for key in tqdm.tqdm(hf.keys()):
    vec = np.array(hf.get(key), dtype=np.float32)  # FAISS requires float32
    index.add(vec)

hf.close()

print("Finished indexing")

# Retrieve the k nearest-neighbor faces for a query image
def image_search(image, k=5):

    transform = tf.Compose([
        tf.Resize((160, 160)),
        tf.ToTensor()
    ])

    # Uploaded images may be RGBA or grayscale; the backbone expects 3 channels
    query_img = transform(image.convert('RGB'))
    query_img = torch.unsqueeze(query_img, 0)

    # Embed the query without tracking gradients
    with torch.no_grad():
        query_vec = model(query_img).numpy()

    # The Gradio slider may deliver k as a float; FAISS expects an int
    D, I = index.search(query_vec, k=int(k))

    retrieval_imgs = []

    # Map FAISS indices back to image files on disk
    FOLDER = 'img_align_celeba'
    for idx in I[0]:
        img_file_name = img_names[idx]
        path = os.path.join(FOLDER, img_file_name)
        retrieval_imgs.append((Image.open(path), ''))

    return retrieval_imgs

with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown('''
    # Face Image Retrieval with Content-based Image Retrieval (CBIR)
    --------
    ''')
    
    with gr.Row():
        with gr.Column():
            image = gr.Image(type='pil', scale=1)
            slider = gr.Slider(1, 10, value=5, step=1, label='Number of retrieved images')
            with gr.Row():
                btn = gr.Button('Search')
                clear_btn = gr.ClearButton()
        
        gallery = gr.Gallery(label='Retrieval Images', columns=[5], show_label=True, scale=2)
        
    # gr.Examples accepts file paths for image inputs, so the files need not be
    # opened here; random.sample avoids duplicate example images
    img_dir = './img_align_celeba'
    examples = random.sample(img_names, k=5)
    examples = [os.path.join(img_dir, ex) for ex in examples]
                                  
    with gr.Row():
        gr.Examples(
            examples=examples,
            inputs=image
        )

    btn.click(image_search,
              inputs=[image, slider],
              outputs=[gallery])
    
    # gr.ClearButton clears nothing until it is wired to components,
    # so the image input is reset explicitly here
    def clear_image():
        return None

    clear_btn.click(
        fn=clear_image,
        inputs=[],
        outputs=[image]
    )
    
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)