haihuynh commited on
Commit
3d49071
1 Parent(s): 596f184

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +32 -32
main.py CHANGED
@@ -13,40 +13,40 @@ from PIL import Image
13
  import matplotlib.cm as cm
14
  import matplotlib as mpl
15
 
16
- # img_names = []
17
- # with open('list_eval_partition.txt', 'r') as f:
18
- # for line in tqdm.tqdm(f):
19
- # img_name, dtype = line.rstrip().split(' ')
20
- # img_names.append(img_name)
21
 
22
 
23
  # For a model pretrained on VGGFace2
24
  print('Loading model weights ........')
25
 
26
- # class SiameseModel(nn.Module):
27
- # def __init__(self):
28
- # super().__init__()
29
- # self.backbone = InceptionResnetV1(pretrained='vggface2')
30
- # def forward(self, x):
31
- # x = self.backbone(x)
32
- # x = torch.nn.functional.normalize(x, dim=1)
33
- # return x
34
 
35
- # model = SiameseModel()
36
- # model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))
37
- # model.eval()
38
 
39
 
40
  # Make FAISS index
41
- # print('Make index .............')
42
- # index = faiss.IndexFlatL2(512)
43
 
44
- # hf = h5py.File('face_vecs_full.h5', 'r')
45
- # for key in tqdm.tqdm(hf.keys()):
46
- # vec = np.array(hf.get(key))
47
- # index.add(vec)
48
 
49
- # hf.close()
50
 
51
  print("Finished indexing")
52
 
@@ -98,16 +98,16 @@ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
98
 
99
  gallery = gr.Gallery(label='Retrieval Images', columns=[5], show_label=True, scale=2)
100
 
101
- # img_dir = './img_align_celeba'
102
- # examples = random.choices(img_names, k=6)
103
- # examples = [os.path.join(img_dir, ex) for ex in examples]
104
- # examples = [Image.open(img) for img in examples]
105
 
106
- # with gr.Row():
107
- # gr.Examples(
108
- # examples = examples,
109
- # inputs = image
110
- # )
111
 
112
 
113
  btn.click(image_search,
 
13
  import matplotlib.cm as cm
14
  import matplotlib as mpl
15
 
16
+ img_names = []
17
+ with open('list_eval_partition.txt', 'r') as f:
18
+ for line in f:
19
+ img_name, dtype = line.rstrip().split(' ')
20
+ img_names.append(img_name)
21
 
22
 
23
  # For a model pretrained on VGGFace2
24
  print('Loading model weights ........')
25
 
26
+ class SiameseModel(nn.Module):
27
+ def __init__(self):
28
+ super().__init__()
29
+ self.backbone = InceptionResnetV1(pretrained='vggface2')
30
+ def forward(self, x):
31
+ x = self.backbone(x)
32
+ x = torch.nn.functional.normalize(x, dim=1)
33
+ return x
34
 
35
+ model = SiameseModel()
36
+ model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))
37
+ model.eval()
38
 
39
 
40
  # Make FAISS index
41
+ print('Make index .............')
42
+ index = faiss.IndexFlatL2(512)
43
 
44
+ hf = h5py.File('face_vecs_full.h5', 'r')
45
+ for key in tqdm.tqdm(hf.keys()):
46
+ vec = np.array(hf.get(key))
47
+ index.add(vec)
48
 
49
+ hf.close()
50
 
51
  print("Finished indexing")
52
 
 
98
 
99
  gallery = gr.Gallery(label='Retrieval Images', columns=[5], show_label=True, scale=2)
100
 
101
+ img_dir = './img_align_celeba'
102
+ examples = random.choices(img_names, k=6)
103
+ examples = [os.path.join(img_dir, ex) for ex in examples]
104
+ examples = [Image.open(img) for img in examples]
105
 
106
+ with gr.Row():
107
+ gr.Examples(
108
+ examples = examples,
109
+ inputs = image
110
+ )
111
 
112
 
113
  btn.click(image_search,