|
from datasets import load_dataset |
|
from torchvision import transforms |
|
import torch |
|
from timm import create_model |
|
from omegaconf import OmegaConf |
|
import faiss |
|
import pickle |
|
import gradio as gr |
|
import os |
|
import joblib |
|
import torch.nn as nn |
|
from typing import Dict, Iterable, Callable |
|
from torch import Tensor |
|
import torchvision |
|
from PIL import Image |
|
|
|
|
|
def get_model(args, arch, load_from, arch_path):
    """Build a backbone in eval mode on the configured device.

    Args:
        args: Config object; only ``args.PARAMETERS.device`` is read here.
        arch: Architecture name. For ``load_from == 'timm'`` it is passed to
            ``timm.create_model``; for ``'torchvision'`` only ``'resnet50'``
            is supported.
        load_from: Model source, ``'timm'`` or ``'torchvision'``.
        arch_path: Checkpoint path used by the torchvision branch. An empty
            value (``''`` or ``None``) skips checkpoint loading.

    Returns:
        The model in eval mode, moved to ``args.PARAMETERS.device``.

    Raises:
        ValueError: If ``load_from`` / ``arch`` is not a supported combination.
    """
    device = args.PARAMETERS.device

    if load_from == 'timm':
        model = create_model(arch, pretrained=True)
        print("Load model timm")
    elif load_from == 'torchvision':
        if arch != 'resnet50':
            # Fail loudly instead of hitting an unbound ``model`` below.
            raise ValueError(
                f"Unsupported torchvision architecture: {arch!r} "
                "(only 'resnet50' is supported)"
            )
        model = torchvision.models.resnet50(pretrained=False)
        if arch_path:
            print("Loading pretrained Model")
            # The checkpoint is expected to wrap the weights under 'state_dict'.
            checkpoint = torch.load(arch_path, map_location='cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
    else:
        raise ValueError(
            f"Unsupported model source: {load_from!r} "
            "(expected 'timm' or 'torchvision')"
        )

    # Inference only: eval mode for every source, and the same device the
    # query images are moved to.
    model.eval()
    return model.to(device)
|
|
|
|
|
def get_transform(args):
    """Return the preprocessing pipeline applied to every query image.

    The pipeline resizes to a square of side ``args.PARAMETERS.img_resize``,
    centre-crops to a square of side ``args.PARAMETERS.img_crop`` and converts
    the result to a tensor.
    """
    params = args.PARAMETERS
    resize_hw = [params.img_resize, params.img_resize]
    crop_hw = [params.img_crop, params.img_crop]
    steps = [
        transforms.Resize(resize_hw),
        transforms.CenterCrop(crop_hw),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)
|
|
|
|
|
class FeatureExtractor(nn.Module):
    """Wrap a model and capture the outputs of selected submodules.

    A forward hook is registered on each requested layer. Every forward pass
    through the wrapper overwrites the captured output for those layers.
    """

    def __init__(self, model: nn.Module, layers: Iterable[str]):
        """Register one forward hook per requested layer.

        Args:
            model: Network to wrap.
            layers: Submodule names as reported by ``model.named_modules()``.

        Raises:
            KeyError: If a requested name is not a submodule of ``model``.
        """
        super().__init__()
        self.model = model
        # Materialise once: ``layers`` may be a one-shot iterator, and it is
        # needed both for the placeholders and for hook registration.
        self.layers = list(layers)
        self._features = {layer: torch.empty(0) for layer in self.layers}

        # Build the name -> module lookup a single time, not once per layer.
        modules = dict(self.model.named_modules())
        for layer_id in self.layers:
            if layer_id not in modules:
                raise KeyError(
                    f"Layer {layer_id!r} not found in model.named_modules()"
                )
            modules[layer_id].register_forward_hook(
                self.save_outputs_hook(layer_id)
            )

    def save_outputs_hook(self, layer_id: str) -> Callable:
        """Return a forward hook that stores the layer output under ``layer_id``."""
        def fn(_, __, output):
            self._features[layer_id] = output
        return fn

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Run the wrapped model on ``x`` and return the captured outputs.

        The model's own output is discarded. The returned dict is the
        extractor's internal store, so the next forward pass overwrites it.
        """
        _ = self.model(x)
        return self._features
|
|
|
|
|
def _load_dataset(args): |
|
if args.PARAMETERS.metric == 'L2': |
|
faiss_metric = faiss.METRIC_L2 |
|
dataset = load_dataset(args.PARAMETERS.dataset,split = 'train') |
|
dataset = dataset.add_faiss_index(column=args.ROBUST.embedding_col,metric_type = faiss_metric) |
|
dataset = dataset.add_faiss_index(column=args.NONROBUST.embedding_col,metric_type = faiss_metric) |
|
return dataset |
|
|
|
|
|
# Module-level state used by the retrieval helpers below: the config, the
# FAISS-indexed dataset, the preprocessing pipeline and one hooked feature
# extractor per model variant.
args = OmegaConf.load("configs/resnet.yaml")
wiki_dataset = _load_dataset(args)
TRANSFORMS = get_transform(args)

robust_model = get_model(
    args,
    arch=args.ROBUST.arch,
    load_from=args.ROBUST.load_from,
    arch_path=args.ROBUST.arch_path,
)
non_robust_model = get_model(
    args,
    arch=args.NONROBUST.arch,
    load_from=args.NONROBUST.load_from,
    arch_path=args.NONROBUST.arch_path,
)

fe_robust_model = FeatureExtractor(
    model=robust_model,
    layers=[args.ROBUST.layer],
)
fe_nonrobust_model = FeatureExtractor(
    model=non_robust_model,
    layers=[args.NONROBUST.layer],
)
|
|
|
|
|
|
|
def retrieval_fn(image, radio):
    """Retrieve the three nearest dataset examples for a query image.

    Args:
        image: Query image as an array accepted by ``PIL.Image.fromarray``.
        radio: Which model to embed with, ``'robust'`` or ``'standard'``.

    Returns:
        ``(scores, retrieved_examples)`` as produced by
        ``wiki_dataset.get_nearest_examples`` with ``k=3``.

    Raises:
        ValueError: If ``image`` is ``None`` or ``radio`` is not recognised.
    """
    # Validate first so bad input fails with a clear message, not an
    # AttributeError / UnboundLocalError further down.
    if image is None:
        raise ValueError("No input image provided")
    if radio not in ('robust', 'standard'):
        raise ValueError(
            f"Unknown model choice: {radio!r} (expected 'robust' or 'standard')"
        )

    if radio == 'robust':
        extractor, cfg = fe_robust_model, args.ROBUST
    else:
        extractor, cfg = fe_nonrobust_model, args.NONROBUST

    batch = TRANSFORMS(Image.fromarray(image)).unsqueeze(0)
    batch = batch.to(args.PARAMETERS.device)

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        emb = extractor(batch)[cfg.layer]
    # Flatten to one row; reshape also handles non-contiguous outputs.
    emb = emb.reshape(1, -1).cpu().numpy()

    scores, retrieved_examples = wiki_dataset.get_nearest_examples(
        index_name=cfg.embedding_col,
        query=emb,
        k=3,
    )
    return scores, retrieved_examples
|
|
|
def gradio_fn(image, radio):
    """Gradio callback: run retrieval and flatten the matches for the UI.

    Returns a flat list ``[description, image, description, image, ...]``
    with one pair per retrieved example, in retrieval order.
    """
    scores, matches = retrieval_fn(image, radio)
    outputs = []
    triples = zip(matches['description'], matches['image'], scores)
    for caption, picture, _score in triples:
        outputs.extend((caption, picture))
    return outputs
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Build the demo UI: inputs on the left, three ranked matches on the right.
    with gr.Blocks() as demo:
        gr.Markdown("# Robust vs Standard Image Retrieval")
        with gr.Tabs():
            with gr.TabItem("Upload your Image"):
                with gr.Row():
                    # Left column: query image, model selector, trigger button.
                    with gr.Column():
                        with gr.Row():
                            image_input = gr.Image(label="Input Image")
                        with gr.Row():
                            radio_button = gr.Radio(
                                ["robust", "standard"],
                                value="robust",
                                label="OD Model",
                            )
                        with gr.Row():
                            calculate_button = gr.Button("Compute")
                    # Right column: one (caption, image) pair per match, in
                    # the same order gradio_fn returns them.
                    with gr.Column():
                        result_widgets = []
                        for rank in ("1st", "2nd", "3rd"):
                            result_widgets.append(
                                gr.Textbox(label="Artist / Title / Style / Genre / Date")
                            )
                            result_widgets.append(
                                gr.Image(label=f"{rank} Best match")
                            )

        calculate_button.click(
            fn=gradio_fn,
            inputs=[image_input, radio_button],
            outputs=result_widgets,
        )

    demo.launch(share=False, debug=True)
|
|