haihuynh commited on
Commit
b4f7e81
1 Parent(s): c66bfe7

Upload 2 files

Browse files
Files changed (2) hide show
  1. cbir_system.py +129 -0
  2. list_eval_partition.txt +0 -0
cbir_system.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from facenet_pytorch import InceptionResnetV1
3
+ import torch.nn as nn
4
+ import torchvision.transforms as tf
5
+ import numpy as np
6
+ import torch
7
+ import faiss
8
+ import h5py
9
+ import os
10
+ import random
11
+ from PIL import Image
12
+ import matplotlib.cm as cm
13
+ import matplotlib as mpl
14
+
15
+ img_names = []
16
+ with open('list_eval_partition.txt', 'r') as f:
17
+ for line in f:
18
+ img_name, dtype = line.rstrip().split(' ')
19
+ img_names.append(img_name)
20
+
21
+
22
+ # For a model pretrained on VGGFace2
23
+ print('Loading model weights ........')
24
+
25
+ class SiameseModel(nn.Module):
26
+ def __init__(self):
27
+ super().__init__()
28
+ self.backbone = InceptionResnetV1(pretrained='vggface2')
29
+ def forward(self, x):
30
+ x = self.backbone(x)
31
+ x = torch.nn.functional.normalize(x, dim=1)
32
+ return x
33
+
34
+ model = SiameseModel()
35
+ model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))
36
+ model.eval()
37
+
38
+
39
+ # Make FAISS index
40
+ print('Make index .............')
41
+ index = faiss.IndexFlatL2(512)
42
+
43
+ hf = h5py.File('face_vecs_full.h5', 'r')
44
+ for key in hf.keys():
45
+ vec = np.array(hf.get(key))
46
+ index.add(vec)
47
+
48
+ hf.close()
49
+
50
+ # Function to search image
51
+ def image_search(image, k=5):
52
+
53
+ transform = tf.Compose([
54
+ tf.Resize((160, 160)),
55
+ tf.ToTensor()
56
+ ])
57
+
58
+ query_img = transform(image)
59
+ query_img = torch.unsqueeze(query_img, 0)
60
+
61
+ model.eval()
62
+ query_vec = model(query_img).detach().numpy()
63
+
64
+ D, I = index.search(query_vec, k=k)
65
+
66
+ retrieval_imgs = []
67
+
68
+ FOLDER = 'img_align_celeba'
69
+ for idx in I[0]:
70
+ img_file_name = img_names[idx]
71
+ path = os.path.join(FOLDER, img_file_name)
72
+
73
+ image = Image.open(path)
74
+ retrieval_imgs.append((image, ''))
75
+
76
+ return retrieval_imgs
77
+
78
+ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
79
+ gr.Markdown('''
80
+
81
+
82
+ # Face Image Retrieval with Content-based Retrieval Image (CBIR) & Saliency Map
83
+ --------
84
+
85
+
86
+ ''')
87
+
88
+ with gr.Row():
89
+ with gr.Column():
90
+ image = gr.Image(type='pil', scale=1)
91
+ slider = gr.Slider(1, 10, value=5, step=1, label='Number of retrieval image')
92
+ with gr.Row():
93
+ btn = gr.Button('Search')
94
+ clear_btn = gr.ClearButton()
95
+
96
+ gallery = gr.Gallery(label='Retrieval Images', columns=[5], show_label=True, scale=2)
97
+
98
+ img_dir = './img_align_celeba'
99
+ examples = random.choices(img_names, k=6)
100
+ examples = [os.path.join(img_dir, ex) for ex in examples]
101
+ examples = [Image.open(img) for img in examples]
102
+
103
+ with gr.Row():
104
+ gr.Examples(
105
+ examples = examples,
106
+ inputs = image
107
+ )
108
+
109
+
110
+ btn.click(image_search,
111
+ inputs= [image, slider],
112
+ outputs= [gallery])
113
+
114
+ def clear_image():
115
+ return None
116
+
117
+ clear_btn.click(
118
+ fn = clear_image,
119
+ inputs = [],
120
+ outputs = [image]
121
+ )
122
+
123
+
124
+ def parse_args():
125
+
126
+
127
+ if __name__ == "__main__":
128
+ demo.launch()
129
+
list_eval_partition.txt ADDED
The diff for this file is too large to render. See raw diff