"""Gradio demo: face image retrieval over CelebA with a FAISS index.

At import time this script reads the CelebA image list, loads the face
embedding model, builds an exact-L2 FAISS index from precomputed embeddings
stored in ``face_vecs_full.h5``, and constructs the Gradio UI (``demo``).
"""
import gradio as gr
from facenet_pytorch import InceptionResnetV1
import torch.nn as nn
import torchvision.transforms as tf
import numpy as np
import torch
import faiss
import h5py
# The progress-bar callable lives inside the module: `import tqdm` + `tqdm(f)` raised TypeError.
from tqdm import tqdm
import os
import random
from PIL import Image
# NOTE(review): the matplotlib imports are unused in this file as shown.
import matplotlib.cm as cm
import matplotlib as mpl

EMBED_DIM = 512  # embedding size; also the dimensionality of the FAISS index
IMG_FOLDER = 'img_align_celeba'  # single source of truth for where the CelebA images live

# Image file names, in the order they appear in the partition file.
img_names = []
with open('list_eval_partition.txt', 'r', encoding='utf-8') as f:
    for line in tqdm(f):
        fields = line.split()
        if not fields:  # tolerate blank / trailing lines
            continue
        img_names.append(fields[0])  # the second column is not used here

# For a model pretrained on VGGFace2
print('Loading model weights ........')


class SiameseModel(nn.Module):
    """InceptionResnetV1 (VGGFace2 weights) followed by L2 normalisation of its output."""

    def __init__(self):
        super().__init__()
        self.backbone = InceptionResnetV1(pretrained='vggface2')

    def forward(self, x):
        x = self.backbone(x)
        x = torch.nn.functional.normalize(x, dim=1)
        return x


# image_search() relies on `model`; leaving this commented out made every search raise NameError.
# model.pt is a local, trusted checkpoint (torch.load unpickles it).
model = SiameseModel()
model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))
model.eval()

# Make FAISS index
print('Make index .............')
index = faiss.IndexFlatL2(EMBED_DIM)
# NOTE(review): row i of the index is assumed to correspond to img_names[i]. By default h5py
# iterates keys in alphanumeric order, not insertion order -- confirm this matches the order of
# list_eval_partition.txt for the file that produced face_vecs_full.h5.
with h5py.File('face_vecs_full.h5', 'r') as hf:
    for key in tqdm(hf.keys()):
        # faiss needs a 2-D float32 (n, d) array; reshape also accepts vectors stored as 1-D.
        vec = np.asarray(hf[key], dtype='float32').reshape(-1, EMBED_DIM)
        index.add(vec)
print("Finished indexing")

# Query preprocessing, built once instead of on every request.
_transform = tf.Compose([
    tf.Resize((160, 160)),
    tf.ToTensor(),
])


def image_search(image, k=5):
    """Return the ``k`` indexed faces nearest (L2) to the query ``image``.

    Args:
        image: query as a PIL image (from ``gr.Image(type='pil')``), or None
            when the user has not uploaded anything.
        k: number of results requested; the Gradio slider may deliver it as a
            float, so it is converted to int and clamped to the index size.

    Returns:
        A list of ``(PIL.Image, caption)`` tuples suitable for ``gr.Gallery``.

    Raises:
        gr.Error: if no query image was provided.
    """
    if image is None:
        raise gr.Error('Please upload a query image first.')
    # faiss requires an int k; asking for more neighbours than stored vectors only yields -1 ids.
    k = max(1, min(int(k), index.ntotal))

    # ToTensor keeps the source channel count (e.g. 4 for RGBA, 1 for L); the backbone needs RGB.
    query_img = torch.unsqueeze(_transform(image.convert('RGB')), 0)
    model.eval()
    with torch.no_grad():  # inference only: no autograd graph needed
        query_vec = model(query_img).numpy()

    _, ids = index.search(query_vec, k)
    retrieval_imgs = []
    for idx in ids[0]:
        if idx < 0:  # faiss pads with -1 when fewer than k results exist
            continue
        path = os.path.join(IMG_FOLDER, img_names[idx])
        # copy() reads the pixels so the underlying file handle is closed right away.
        with Image.open(path) as img:
            retrieval_imgs.append((img.copy(), ''))
    return retrieval_imgs


with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown('''
# Face Image Retrieval with Content-Based Image Retrieval (CBIR) & Saliency Map
--------
''')
    with gr.Row():
        with gr.Column():
            image = gr.Image(type='pil', scale=1)
            slider = gr.Slider(1, 10, value=5, step=1, label='Number of retrieval image')
            with gr.Row():
                btn = gr.Button('Search')
                clear_btn = gr.ClearButton()
        gallery = gr.Gallery(label='Retrieval Images', columns=[5], show_label=True, scale=2)

    # Up to six distinct random sample faces, passed as paths so gr.Image loads them on demand.
    img_dir = IMG_FOLDER
    examples = [
        os.path.join(img_dir, name)
        for name in random.sample(img_names, k=min(6, len(img_names)))
    ]
    with gr.Row():
        gr.Examples(
            examples=examples,
            inputs=image,
        )

    btn.click(image_search, inputs=[image, slider], outputs=[gallery])

    def clear_image():
        """Reset the query image component to empty."""
        return None

    clear_btn.click(
        fn=clear_image,
        inputs=[],
        outputs=[image],
    )

if __name__ == "__main__":
    demo.launch()