Spaces:
Runtime error
Runtime error
from typing import Optional | |
import os | |
os.environ["WANDB_DISABLED"] = "true" | |
import numpy as np | |
from PIL import Image | |
import gradio as gr | |
import torch | |
import torch.nn as nn | |
import torchvision | |
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize | |
from torchvision.transforms.functional import InterpolationMode | |
from torchvision import transforms | |
from torchvision.io import ImageReadMode, read_image | |
from transformers import CLIPModel, AutoModel | |
from huggingface_hub import hf_hub_download | |
from safetensors.torch import load_model | |
from datasets import load_dataset, load_metric | |
from transformers import ( | |
AutoConfig, | |
AutoImageProcessor, | |
AutoModelForSequenceClassification, | |
AutoTokenizer, | |
logging, | |
) | |
class Transform(torch.nn.Module):
    """Image preprocessing pipeline for the CLIP vision encoder.

    Steps, in order:
    1. Resize with a one-element size list (bicubic), so the shorter edge becomes ``image_size``.
    2. Centre-crop to ``image_size`` x ``image_size``.
    3. Convert the image tensor to float.
    4. Normalise each channel with ``mean`` / ``std``.

    Only TorchScript-compatible modules are used, so an instance can be passed to
    ``torch.jit.script``.
    """

    def __init__(self, image_size, mean, std):
        super().__init__()
        steps = [
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ConvertImageDtype(torch.float),
            Normalize(mean, std),
        ]
        self.transforms = torch.nn.Sequential(*steps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the pipeline to the image tensor ``x`` (e.g. as returned by ``read_image``) without tracking gradients."""
        with torch.no_grad():
            out = self.transforms(x)
        return out
class VisionTextDualEncoderModel(nn.Module):
    """Late-fusion classifier over a text encoder and a CLIP vision tower.

    The model has three parts:
    - ``text_encoder`` (XLM-RoBERTa) produces a pooled text representation.
    - ``vision_encoder.vision_model`` (CLIP) produces a pooled image representation.
    - ``fc``, a single linear layer, maps the concatenation of the two to ``num_classes`` logits.

    The attribute names ``text_encoder``, ``vision_encoder`` and ``fc`` define the
    state-dict keys, so they must stay in sync with any checkpoint loaded into this model.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Text tower: multilingual XLM-RoBERTa (the checkpoint is sentiment-tuned).
        self.text_encoder = AutoModel.from_pretrained(
            "cardiffnlp/twitter-xlm-roberta-base-sentiment-multilingual"
        )
        # Vision tower: the full CLIP model is kept (only its vision_model is used in
        # forward) so the parameter names match the saved checkpoint.
        self.vision_encoder = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        vision_output_dim = self.vision_encoder.config.vision_config.hidden_size
        # Fusion head over the concatenated pooled features.
        self.fc = nn.Linear(
            self.text_encoder.config.hidden_size + vision_output_dim, num_classes
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ):
        """Classify a batch of text/image pairs.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids: Tokenised text,
                forwarded to the text encoder.
            pixel_values: Preprocessed, batched images for the CLIP vision model.
            output_attentions, output_hidden_states: Forwarded to the vision model only.
            return_dict: Accepted for API compatibility but ignored. Both encoders are
                always called with ``return_dict=True`` because their ``pooler_output``
                is read by name; a tuple output would raise ``AttributeError``.
            return_loss, labels: Accepted for API compatibility; currently unused.

        Returns:
            ``{"logits": tensor}``, with one row of ``num_classes`` scores per example.
        """
        # Pooled text features.
        text_features = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            return_dict=True,
        ).pooler_output

        # Pooled image features.
        vision_features = self.vision_encoder.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        ).pooler_output

        # Late fusion: concatenate along the feature axis, then classify.
        combined_features = torch.cat((text_features, vision_features), dim=1)
        logits = self.fc(combined_features)
        return {"logits": logits}
# Class index <-> sentiment label mapping used by the classification head.
id2label = {0: "negative", 1: "neutral", 2: "positive"}
# Derived from id2label so the two mappings can never disagree.
label2id = {label: idx for idx, label in id2label.items()}

tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-xlm-roberta-base-sentiment-multilingual")
model = VisionTextDualEncoderModel(num_classes=len(id2label))
# CLIP config of the vision tower; its vision_config.image_size sizes the image preprocessing.
config = model.vision_encoder.config

# Fine-tuned weights: https://huggingface.co/FFZG-cleopatra/M2SA/blob/main/model.safetensors
sf_filename = hf_hub_download("FFZG-cleopatra/M2SA", filename="model.safetensors")
load_model(model, sf_filename)
# This app only runs inference: put every submodule into eval mode explicitly
# so dropout and similar training-only behaviour are disabled.
model.eval()

image_processor = AutoImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
# TorchScript-compiled image preprocessing. It is built lazily by
# _get_image_transform() on first use and reused for every later request.
_image_transform = None


def _get_image_transform():
    """Return the scripted ``Transform``, building it on the first call only.

    The pipeline depends only on module-level settings (``config`` and
    ``image_processor``), so it does not need to be rebuilt and re-scripted
    on every prediction.
    """
    global _image_transform
    if _image_transform is None:
        _image_transform = torch.jit.script(
            Transform(
                config.vision_config.image_size,
                image_processor.image_mean,
                image_processor.image_std,
            )
        )
    return _image_transform


def predict_sentiment(text, image):
    """Run the multimodal sentiment model on a single text/image pair.

    Args:
        text: Text to classify.
        image: Path to an image file. The Gradio ``Image`` input is configured
            with ``type="filepath"``.

    Returns:
        The model output dict ``{"logits": tensor}`` for a batch of one example.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        raise gr.Error("Please provide an image.")

    # read_image decodes one image with no batch axis. The CLIP vision model
    # expects a batch, so add a leading dimension of size 1 after preprocessing.
    pixel_values = _get_image_transform()(read_image(image, mode=ImageReadMode.RGB))
    pixel_values = pixel_values.unsqueeze(0)

    # return_tensors="pt" already gives the text inputs a batch dimension of 1.
    text_inputs = tokenizer(
        text,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        prediction = model(
            input_ids=text_inputs.input_ids,
            attention_mask=text_inputs.attention_mask,
            pixel_values=pixel_values,
        )
    return prediction
def _to_label_scores(prediction):
    """Convert a model output dict into ``{label: probability}`` for a Gradio ``Label``.

    Only the first row of ``prediction["logits"]`` is used, so the output must
    contain a single example. Class indices are mapped to names via ``id2label``.
    """
    probabilities = torch.softmax(prediction["logits"][0], dim=-1).tolist()
    return {id2label[idx]: prob for idx, prob in enumerate(probabilities)}


def _classify(text, image):
    """Gradio callback: predict and return per-sentiment probabilities."""
    return _to_label_scores(predict_sentiment(text, image))


interface = gr.Interface(
    fn=_classify,
    inputs=[gr.Textbox(), gr.Image(type="filepath")],
    # Show the sentiment labels with their probabilities instead of a raw tensor repr.
    outputs=gr.Label(num_top_classes=len(id2label)),
    title='Multilingual-Multimodal-Sentiment-Analysis',
    examples=[["I am enjoying", "A_Sep20_14_1189155141.jpg"]],
    description='Get the positive/neutral/negative sentiment for the given input.',
)
interface.launch(inline=False)