Spaces:
Sleeping
Sleeping
import gradio as gr | |
from facenet_pytorch import InceptionResnetV1 | |
import torch.nn as nn | |
import torchvision.transforms as tf | |
import numpy as np | |
import torch | |
import faiss | |
import h5py | |
import tqdm | |
import os | |
import random | |
from PIL import Image | |
import matplotlib.cm as cm | |
import matplotlib as mpl | |
# Ordered list of CelebA image file names. image_search() maps a FAISS result
# id i to img_names[i], so this order is assumed to match the order in which
# vectors were added to the index — TODO confirm against face_vecs_full.h5.
img_names = []
with open('list_eval_partition.txt', 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing newline at EOF) instead of
            # crashing on the two-field unpacking below.
            continue
        # Each line is "<image file> <second field>"; the second field is
        # presumably the train/val/test partition id and is not used here.
        img_name, _partition = line.split()
        img_names.append(img_name)
# The backbone below starts from facenet_pytorch's VGGFace2-pretrained weights.
print('Loading model weights ........')


class SiameseModel(nn.Module):
    """Face-embedding network: InceptionResnetV1 backbone plus L2 normalisation.

    The backbone is created with ``pretrained='vggface2'``; the weights are then
    replaced from model.pt via ``load_state_dict``, whose keys are prefixed with
    the attribute name ``backbone`` — so that name must not change.
    """

    def __init__(self):
        super().__init__()
        self.backbone = InceptionResnetV1(pretrained='vggface2')

    def forward(self, x):
        """Embed a batch ``x`` and rescale each embedding to unit L2 norm along dim 1."""
        embeddings = self.backbone(x)
        return nn.functional.normalize(embeddings, dim=1)
# Build the embedding network, overwrite its weights with the checkpoint in
# model.pt (tensors mapped onto the CPU so no GPU is required), and switch it
# to inference mode.
model = SiameseModel()
model.load_state_dict(
    torch.load('model.pt', map_location='cpu')
)
model.eval()
# Build an exact (brute-force) L2 FAISS index over the 512-dimensional face
# embeddings stored in face_vecs_full.h5. FAISS assigns ids in insertion order,
# and image_search() maps those ids back to file names through img_names.
print('Make index .............')
index = faiss.IndexFlatL2(512)
# The context manager closes the HDF5 file even if reading or adding a dataset
# raises; the previous explicit close() was skipped on error.
with h5py.File('face_vecs_full.h5', 'r') as hf:
    # NOTE(review): unless the file tracks creation order, h5py iterates group
    # members in name order, so ids follow the sorted dataset keys. If keys are
    # unpadded numbers ("10" sorts before "2"), ids will not line up with
    # img_names — confirm how the file was written.
    for key in tqdm.tqdm(hf.keys()):
        # Each dataset presumably holds an (n, 512) block of embeddings — TODO confirm.
        vec = np.array(hf[key])
        index.add(vec)
print("Finished indexing")
# Preprocessing for every query: resize to 160x160 and convert to a float
# tensor in [0, 1]. Built once at import time rather than on every request.
_QUERY_TRANSFORM = tf.Compose([
    tf.Resize((160, 160)),
    tf.ToTensor(),
])
# Directory holding the image files named in img_names.
_IMAGE_FOLDER = 'img_align_celeba'


def image_search(image, k=5):
    """Return the ``k`` indexed face images nearest to ``image``.

    Args:
        image: Query as a PIL image (the Gradio input uses ``type='pil'``), or
            None when nothing has been uploaded.
        k: Number of neighbours to retrieve. The value comes from a Gradio
            slider and may arrive as a float, so it is coerced to int.

    Returns:
        A list of ``(PIL image, caption)`` tuples for ``gr.Gallery``, in the
        order FAISS returns them (nearest first by L2 distance). The list is
        empty when there is no query image or ``k < 1``.
    """
    k = int(k)
    if image is None or k < 1:
        # Search clicked without a query image: show an empty gallery instead
        # of failing inside the transform.
        return []

    # Uploaded images may be RGBA or greyscale; the network expects 3 channels.
    query = _QUERY_TRANSFORM(image.convert('RGB')).unsqueeze(0)
    model.eval()
    # Retrieval needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        query_vec = model(query).numpy()
    _distances, ids = index.search(query_vec, k=k)

    retrieval_imgs = []
    for idx in ids[0]:
        if idx < 0:
            # FAISS pads with -1 when fewer than k vectors are indexed; without
            # this check img_names[-1] would silently return the last image.
            continue
        path = os.path.join(_IMAGE_FOLDER, img_names[idx])
        retrieval_imgs.append((Image.open(path), ''))
    return retrieval_imgs
# Gradio UI: query image + neighbour-count slider on the left, results gallery
# on the right, a row of clickable example images below.
# NOTE(review): the original indentation was lost; the layout nesting below is
# a reconstruction — confirm it matches the intended arrangement.
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown('''
    # Face Image Retrieval with Content-Based Image Retrieval (CBIR) & Saliency Map
    --------
    ''')
    with gr.Row():
        with gr.Column():
            image = gr.Image(type='pil', scale=1)
            slider = gr.Slider(1, 10, value=5, step=1, label='Number of retrieved images')
            with gr.Row():
                btn = gr.Button('Search')
                clear_btn = gr.ClearButton()
        gallery = gr.Gallery(label='Retrieval Images', columns=[5], show_label=True, scale=2)

    img_dir = './img_align_celeba'
    # Pick distinct example images (random.choices could repeat an entry), and
    # cap the count so a short name list does not make random.sample raise.
    example_names = random.sample(img_names, k=min(6, len(img_names)))
    examples = [Image.open(os.path.join(img_dir, name)) for name in example_names]
    with gr.Row():
        gr.Examples(
            examples=examples,
            inputs=image,
        )

    btn.click(image_search,
              inputs=[image, slider],
              outputs=[gallery])

    def clear_image():
        """Return None so the query image input is reset."""
        return None

    clear_btn.click(
        fn=clear_image,
        inputs=[],
        outputs=[image],
    )
if __name__ == "__main__":
    # Bind to every network interface (not only localhost) on port 7860.
    launch_options = dict(server_name="0.0.0.0", server_port=7860)
    demo.launch(**launch_options)