import clip from PIL import Image import pandas as pd import torch from dataloader.extract_features_dataloader import transform_resize, question_preprocess from model.vqa_model import NetVQA from dataclasses import dataclass from torch.cuda.amp import autocast import gradio as gr @dataclass class InferenceConfig: ''' Describes configuration of the training process ''' model: str = "RN50x64" checkpoint_root_clip: str = "./checkpoints/clip" checkpoint_root_head: str = "./checkpoints/head" use_question_preprocess: bool = True # True: delete ? at end aux_mapping = {0: "unanswerable", 1: "unsuitable", 2: "yes", 3: "no", 4: "number", 5: "color", 6: "other"} folds = 10 # Data n_classes: int = 5726 # class mapping class_mapping: str = "./data/annotations/class_mapping.csv" device = "cuda" if torch.cuda.is_available() else "cpu" config = InferenceConfig() # load class mapping cm = pd.read_csv(config.class_mapping) classid_to_answer = {} for i in range(len(cm)): row = cm.iloc[i] classid_to_answer[row["class_id"]] = row["answer"] clip_model, preprocess = clip.load(config.model, download_root=config.checkpoint_root_clip, device=config.device) model = NetVQA(config).to(config.device) config.checkpoint_head = "{}/{}.pt".format(config.checkpoint_root_head, config.model) model_state_dict = torch.load(config.checkpoint_head) model.load_state_dict(model_state_dict, strict=True) model.eval() # Select Preprocessing image_transforms = transform_resize(clip_model.visual.input_resolution) if config.use_question_preprocess: question_transforms = question_preprocess else: question_transforms = None clip_model.eval() def predict(img, text): img = Image.fromarray(img) img = image_transforms(img) img = img.unsqueeze(dim=0) if question_transforms is not None: question = question_transforms(text) else: question = text question_tokens = clip.tokenize(question, truncate=True) with torch.no_grad(): img = img.to(config.device) img_feature = clip_model.encode_image(img) question_tokens = question_tokens.to(config.device) question_feature = clip_model.encode_text(question_tokens) with autocast(): output, output_aux = model(img_feature, question_feature) prediction_vqa = dict() output = output.cpu().squeeze(0) for k, v in classid_to_answer.items(): prediction_vqa[v] = float(output[k]) prediction_aux = dict() output_aux = output_aux.cpu().squeeze(0) for k, v in config.aux_mapping.items(): prediction_aux[v] = float(output_aux[k]) return prediction_vqa, prediction_aux description = """ Less Is More: Linear Layers on CLIP Features as Powerful VizWiz Model Our approach focuses on visual question answering for visual impaired people. We fine-tuned our approach on the CVPR Grand Challenge VizWiz 2022 data set. You may click on one of the examples or upload your own image and question. The Gradio app shows the current answer for your question and an answer category. Link to our paper. """ gr.Interface(fn=predict, description=description, inputs=[gr.Image(label='Image'), gr.Textbox(label='Question')], outputs=[gr.outputs.Label(label='Answer', num_top_classes=5), gr.outputs.Label(label='Answer Category', num_top_classes=7)], examples=[['examples/Augustiner.jpg', 'What is this?'],['examples/VizWiz_test_00006968.jpg', 'Can you tell me the color of the dog?'], ['examples/VizWiz_test_00005604.jpg', 'What drink is this?'], ['examples/VizWiz_test_00006246.jpg', 'Can you please tell me what kind of tea this is?'], ['examples/VizWiz_train_00004056.jpg', 'Is that a beer or a coke?'], ['examples/VizWiz_train_00017146.jpg', 'Can you tell me what\'s on this envelope please?'], ['examples/VizWiz_val_00003077.jpg', 'What is this?']] ).launch()