from typing import Optional
import os

# Disable Weights & Biases reporting before any Hugging Face library is imported.
os.environ["WANDB_DISABLED"] = "true"

import numpy as np
import gradio as gr
import torch
import torch.nn as nn
import torchvision
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode
from torchvision import transforms
from torchvision.io import ImageReadMode, read_image
from transformers import CLIPModel, AutoModel
from huggingface_hub import hf_hub_download
from safetensors.torch import load_model
# NOTE(review): these `datasets` names are not used in this file, and
# `load_metric` was removed in datasets 3.0, so this import raises ImportError
# on newer versions. Consider dropping it.
from datasets import load_dataset, load_metric
from transformers import (
    AutoConfig,
    AutoImageProcessor,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    logging,
)

# `logging` here is transformers' logging module (it shadows the stdlib name).
logger = logging.get_logger(__name__)


class Transform(torch.nn.Module):
    """CLIP-style image preprocessing: resize, center-crop, convert to float, normalize.

    Args:
        image_size: Target size; the image is resized (bicubic) with
            ``Resize([image_size])`` and then center-cropped to
            ``image_size`` x ``image_size``.
        mean: Per-channel means passed to ``Normalize``.
        std: Per-channel standard deviations passed to ``Normalize``.
    """

    def __init__(self, image_size, mean, std):
        super().__init__()
        self.transforms = torch.nn.Sequential(
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ConvertImageDtype(torch.float),
            Normalize(mean, std),
        )

    def forward(self, x) -> torch.Tensor:
        """Apply the preprocessing pipeline to an image tensor.

        ``x`` is an image tensor, e.g. the output of
        ``torchvision.io.read_image`` (``ConvertImageDtype`` works on tensors,
        not on ``PIL.Image.Image`` objects).
        """
        with torch.no_grad():
            x = self.transforms(x)
        return x


class VisionTextDualEncoderModel(nn.Module):
    """Late-fusion classifier over text and image features.

    The pooled output of the text encoder and the pooled output of the CLIP
    vision tower are concatenated and passed through a single linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Multilingual XLM-RoBERTa text encoder (base model, no classification head).
        self.text_encoder = AutoModel.from_pretrained(
            "cardiffnlp/twitter-xlm-roberta-base-sentiment-multilingual"
        )
        # Full CLIP model; only its vision tower is used in forward().
        self.vision_encoder = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        vision_output_dim = self.vision_encoder.config.vision_config.hidden_size
        # Classification head over the concatenated pooled features.
        self.fc = nn.Linear(
            self.text_encoder.config.hidden_size + vision_output_dim, num_classes
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ):
        """Return ``{"logits": ...}`` for a batch of (text, image) inputs.

        ``return_loss`` and ``labels`` are accepted for interface compatibility
        but are not used; no loss is computed.
        """
        # Pooled text representation.
        text_outputs = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        ).pooler_output

        # Pooled image representation from the CLIP vision tower.
        vision_outputs = self.vision_encoder.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        combined_features = torch.cat(
            (text_outputs, vision_outputs.pooler_output), dim=1
        )
        logits = self.fc(combined_features)
        return {"logits": logits}


id2label = {0: "negative", 1: "neutral", 2: "positive"}
label2id = {"negative": 0, "neutral": 1, "positive": 2}

tokenizer = AutoTokenizer.from_pretrained(
    "cardiffnlp/twitter-xlm-roberta-base-sentiment-multilingual"
)
model = VisionTextDualEncoderModel(num_classes=len(id2label))
config = model.vision_encoder.config

# https://huggingface.co/FFZG-cleopatra/M2SA/blob/main/model.safetensors
sf_filename = hf_hub_download("FFZG-cleopatra/M2SA", filename="model.safetensors")
load_model(model, sf_filename)
# Inference only: put every submodule in eval mode explicitly.
model.eval()

image_processor = AutoImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Build and TorchScript-compile the preprocessing pipeline once at startup,
# instead of re-scripting it on every request.
image_transformations = torch.jit.script(
    Transform(
        config.vision_config.image_size,
        image_processor.image_mean,
        image_processor.image_std,
    )
)


def predict_sentiment(text, image):
    """Classify the sentiment of a (text, image) pair.

    Args:
        text: Input text; tokenized with padding/truncation to 512 tokens.
        image: Path to an image file (the Gradio input uses ``type="filepath"``).

    Returns:
        The predicted label: ``"negative"``, ``"neutral"`` or ``"positive"``.

    Raises:
        gr.Error: If no image was provided.
    """
    logger.info("predict_sentiment(text=%r, image=%r)", text, image)
    if image is None:
        # read_image(None) would fail with an opaque TypeError; tell the user instead.
        raise gr.Error("Please provide an image.")

    image_tensor = read_image(image, mode=ImageReadMode.RGB)
    inputs = tokenizer(
        text, max_length=512, padding="max_length", truncation=True, return_tensors="pt"
    )
    # Add a batch dimension so the image matches the batch of one text.
    inputs["pixel_values"] = image_transformations(image_tensor).unsqueeze(0)

    with torch.no_grad():
        outputs = model(**inputs)
    logger.info("model outputs: %s", outputs)

    # Use torch's argmax directly; np.argmax on a torch tensor only works via
    # NumPy's conversion fallback.
    label = id2label[int(outputs["logits"].argmax(dim=-1)[0])]
    logger.info("predicted label: %s", label)
    return label


interface = gr.Interface(
    fn=predict_sentiment,
    inputs=[gr.Textbox(), gr.Image(type="filepath")],
    outputs=["text"],
    title="Multilingual Multimodal Sentiment Analysis",
    examples=[["I am enjoying", "A_Sep20_14_1189155141.jpg"]],
    description="Get the positive/neutral/negative sentiment for the given input.",
)
interface.launch(inline=False)