# VizWiz-CLIP-VQA / app.py
import clip
from PIL import Image
import pandas as pd
import torch
from dataloader.extract_features_dataloader import transform_resize, question_preprocess
from model.vqa_model import NetVQA
from dataclasses import dataclass
from torch.cuda.amp import autocast
import gradio as gr
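
# Gradio demo for VizWiz VQA: a pretrained CLIP backbone encodes the image and the
# question, and the NetVQA head maps the two features to a distribution over 5726
# answer classes plus an auxiliary answer-category prediction
# (unanswerable, unsuitable, yes, no, number, color, other).
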
@dataclass
class InferenceConfig:
    '''
    Describes the configuration of the inference process
    '''
    model: str = "RN50x64"
    checkpoint_root_clip: str = "./checkpoints/clip"
    checkpoint_root_head: str = "./checkpoints/head"

    use_question_preprocess: bool = True  # True: strip a trailing "?" from the question

    # auxiliary answer-category head: class id -> category name
    aux_mapping = {0: "unanswerable",
                   1: "unsuitable",
                   2: "yes",
                   3: "no",
                   4: "number",
                   5: "color",
                   6: "other"}
    folds = 10

    tta = False  # test-time augmentation

    # Data
    n_classes: int = 5726

    # class mapping
    class_mapping: str = "./data/annotations/class_mapping.csv"

    device = "cuda" if torch.cuda.is_available() else "cpu"


config = InferenceConfig()
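
# Hypothetical override (not part of the original app): another CLIP backbone could be
# selected through the dataclass fields, as long as a matching head checkpoint exists
# under checkpoint_root_head, e.g. InferenceConfig(model="RN50") would expect
# ./checkpoints/head/RN50.pt.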

# load class mapping
cm = pd.read_csv(config.class_mapping)
classid_to_answer = {}
for i in range(len(cm)):
    row = cm.iloc[i]
    classid_to_answer[row["class_id"]] = row["answer"]

clip_model, preprocess = clip.load(config.model, download_root=config.checkpoint_root_clip)

model = NetVQA(config).to(config.device)

config.checkpoint_head = "{}/{}.pt".format(config.checkpoint_root_head, config.model)
# map_location keeps loading working on CPU-only machines
model_state_dict = torch.load(config.checkpoint_head, map_location=config.device)
model.load_state_dict(model_state_dict, strict=True)
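
# Note: clip.load() downloads the CLIP weights into checkpoint_root_clip if they are
# not already cached there; the VQA head weights are expected to be present as
# ./checkpoints/head/RN50x64.pt.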

#%%
# Select Preprocessing
image_transforms = transform_resize(clip_model.visual.input_resolution)

if config.use_question_preprocess:
    question_transforms = question_preprocess
else:
    # identity fallback: predict() always calls question_transforms
    question_transforms = lambda question: question

clip_model.eval()
model.eval()

def predict(img, text):
    img = Image.fromarray(img)

    if config.tta:
        # test-time augmentation: one crop per transform, stacked along the batch dim
        image_augmentations = []
        for transform in image_transforms:
            image_augmentations.append(transform(img))
        img = torch.stack(image_augmentations, dim=0)
    else:
        img = image_transforms(img)
        img = img.unsqueeze(dim=0)

    question = question_transforms(text)
    question_tokens = clip.tokenize(question, truncate=True)

    with torch.no_grad():
        img = img.to(config.device)
        img_feature = clip_model.encode_image(img)

        if config.tta:
            # weighted sum of the augmented image features
            # (config.features_selection has to be defined when tta is enabled)
            weights = torch.tensor(config.features_selection).reshape((len(config.features_selection), 1))
            img_feature = img_feature * weights.to(config.device)
            img_feature = img_feature.sum(0)
            img_feature = img_feature.unsqueeze(0)

        question_tokens = question_tokens.to(config.device)
        question_feature = clip_model.encode_text(question_tokens)

        with autocast():
            output, output_aux = model(img_feature, question_feature)

    # map class logits back to answer strings and auxiliary answer categories
    prediction_vqa = dict()
    output = output.cpu().squeeze(0)
    for k, v in classid_to_answer.items():
        prediction_vqa[v] = float(output[k])

    prediction_aux = dict()
    output_aux = output_aux.cpu().squeeze(0)
    for k, v in config.aux_mapping.items():
        prediction_aux[v] = float(output_aux[k])

    return prediction_vqa, prediction_aux
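
# Quick local sanity check (a sketch, assuming the bundled example image exists;
# commented out so it does not run when the Space starts):
#   import numpy as np
#   answers, categories = predict(
#       np.array(Image.open("examples/VizWiz_val_00003077.jpg")), "What is this?")
#   print(sorted(answers.items(), key=lambda kv: -kv[1])[:5])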

gr.Interface(fn=predict,
             inputs=[gr.Image(label='Image'), gr.Textbox(label='Question')],
             outputs=[gr.Label(label='Answer', num_top_classes=5),
                      gr.Label(label='Answer Category', num_top_classes=7)],
             examples=[['examples/VizWiz_train_00004056.jpg', 'Is that a beer or a coke?'],
                       ['examples/VizWiz_train_00017146.jpg', "Can you tell me what's on this envelope please?"],
                       ['examples/VizWiz_val_00003077.jpg', 'What is this?']]
             ).launch()