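# Gradio demo: visual question answering on VizWiz-style images, using a CLIP
# backbone for image and question features and a trained VQA classification head.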
from urllib.request import urlopen
import argparse
import clip
from PIL import Image
import pandas as pd
import time
import torch
from dataloader.extract_features_dataloader import transform_resize, question_preprocess
from model.vqa_model import NetVQA
from dataclasses import dataclass
from torch.cuda.amp import autocast
import gradio as gr
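
# Inference configuration: CLIP variant, checkpoint locations, and output mappings.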
@dataclass
class InferenceConfig:
    '''
    Describes the configuration of the inference process
    '''
    model: str = "RN50x64"
    checkpoint_root_clip: str = "./checkpoints/clip"
    checkpoint_root_head: str = "./checkpoints/head"

    use_question_preprocess: bool = True  # True: strip a trailing "?" from the question

    # Mapping of auxiliary head outputs to answer categories
    aux_mapping = {0: "unanswerable",
                   1: "unsuitable",
                   2: "yes",
                   3: "no",
                   4: "number",
                   5: "color",
                   6: "other"}
    folds = 10
    tta: bool = False  # test-time augmentation (disabled for the demo)

    # Data
    n_classes: int = 5726

    # class mapping from class id to answer string
    class_mapping: str = "./data/annotations/class_mapping.csv"

    device = "cuda" if torch.cuda.is_available() else "cpu"
config = InferenceConfig()

# Load the class mapping (class id -> answer string)
cm = pd.read_csv(config.class_mapping)

classid_to_answer = {}
for i in range(len(cm)):
    row = cm.iloc[i]
    classid_to_answer[row["class_id"]] = row["answer"]

# Load the CLIP backbone and its preprocessing pipeline
clip_model, preprocess = clip.load(config.model, download_root=config.checkpoint_root_clip)

# Build the VQA head and load the checkpoint matching the chosen CLIP model
model = NetVQA(config).to(config.device)

config.checkpoint_head = "{}/{}.pt".format(config.checkpoint_root_head, config.model)
model_state_dict = torch.load(config.checkpoint_head, map_location=config.device)
model.load_state_dict(model_state_dict, strict=True)
#%%
# Select preprocessing
image_transforms = transform_resize(clip_model.visual.input_resolution)

if config.use_question_preprocess:
    question_transforms = question_preprocess
else:
    # No question preprocessing: pass the question through unchanged
    question_transforms = lambda question: question

clip_model.eval()
model.eval()
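
# Run one VQA prediction: encode the image and the question with CLIP, feed both
# features to the VQA head, and return per-answer and per-category scores for Gradio.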
def predict(img, text):
    img = Image.fromarray(img)

    if config.tta:
        # Test-time augmentation: encode every augmented view of the image
        image_augmentations = []
        for transform in image_transforms:
            image_augmentations.append(transform(img))
        img = torch.stack(image_augmentations, dim=0)
    else:
        img = image_transforms(img)
        img = img.unsqueeze(dim=0)

    question = question_transforms(text)
    question_tokens = clip.tokenize(question, truncate=True)

    with torch.no_grad():
        img = img.to(config.device)
        img_feature = clip_model.encode_image(img)

        if config.tta:
            # Weight and sum the augmented image features
            weights = torch.tensor(config.features_selection).reshape((len(config.features_selection), 1))
            img_feature = img_feature * weights.to(config.device)
            img_feature = img_feature.sum(0)
            img_feature = img_feature.unsqueeze(0)

        question_tokens = question_tokens.to(config.device)
        question_feature = clip_model.encode_text(question_tokens)

        with autocast():
            output, output_aux = model(img_feature, question_feature)

    # Map answer-class scores back to answer strings
    prediction_vqa = dict()
    output = output.cpu().squeeze(0)
    for k, v in classid_to_answer.items():
        prediction_vqa[v] = float(output[k])

    # Map auxiliary scores back to answer categories
    prediction_aux = dict()
    output_aux = output_aux.cpu().squeeze(0)
    for k, v in config.aux_mapping.items():
        prediction_aux[v] = float(output_aux[k])

    return prediction_vqa, prediction_aux
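
# Gradio UI: an image and a free-form question in, the top answers and the
# predicted answer category out.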
gr.Interface(fn=predict,
             inputs=[gr.Image(label='Image'), gr.Textbox(label='Question')],
             outputs=[gr.Label(label='Answer', num_top_classes=5),
                      gr.Label(label='Answer Category', num_top_classes=7)],
             examples=[['examples/VizWiz_train_00004056.jpg', 'Is that a beer or a coke?'],
                       ['examples/VizWiz_train_00017146.jpg', "Can you tell me what's on this envelope please?"],
                       ['examples/VizWiz_val_00003077.jpg', 'What is this?']]
             ).launch()