"""Gradio demo for multilingual, multimodal (text + image) sentiment analysis.

A dual-encoder classifier (XLM-RoBERTa text encoder + CLIP vision encoder +
linear head) is loaded with weights from the ``FFZG-cleopatra/M2SA`` Hub
repository, and ``predict_sentiment`` labels a (text, image) pair as
negative / neutral / positive.
"""
from typing import Optional
import os

# Opt out of Weights & Biases reporting in transformers.
os.environ["WANDB_DISABLED"] = "true"

import numpy as np
from PIL import Image
import gradio as gr
import torch
import torch.nn as nn
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode
from transformers import CLIPModel, AutoModel
from huggingface_hub import hf_hub_download
from safetensors.torch import load_model
# NOTE: `load_metric` was unused and has been removed from `datasets` 3.x,
# where importing it raises ImportError, so it is no longer imported.
from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoImageProcessor,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    logging,
)

logger = logging.get_logger(__name__)


class Transform(torch.nn.Module):
    """Image preprocessing: bicubic resize, center crop, float conversion, normalize."""

    def __init__(self, image_size, mean, std):
        super().__init__()
        self.transforms = torch.nn.Sequential(
            # A single-element size resizes the shorter edge to `image_size`.
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ConvertImageDtype(torch.float),
            Normalize(mean, std),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Preprocess an image tensor laid out as ``(..., C, H, W)``.

        ``x`` must be a tensor (``ConvertImageDtype`` does not accept PIL
        images); a ``uint8`` tensor is rescaled to ``[0, 1]`` before the
        per-channel normalization.
        """
        with torch.no_grad():
            x = self.transforms(x)
        return x


class VisionTextDualEncoderModel(nn.Module):
    """Late-fusion classifier over a text encoder and a CLIP vision encoder.

    The text encoder's ``pooler_output`` and the CLIP vision model's
    ``pooler_output`` are concatenated and passed through one linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Text tower: XLM-RoBERTa checkpoint fine-tuned for multilingual sentiment.
        self.text_encoder = AutoModel.from_pretrained(
            "cardiffnlp/twitter-xlm-roberta-base-sentiment-multilingual"
        )
        # Vision tower: only `vision_model` of this CLIP model is used in forward().
        self.vision_encoder = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        vision_output_dim = self.vision_encoder.config.vision_config.hidden_size
        # Classification head over the concatenated pooled features.
        self.fc = nn.Linear(
            self.text_encoder.config.hidden_size + vision_output_dim, num_classes
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ):
        """Return ``{"logits": (batch, num_classes)}`` for a batch of text/image pairs.

        ``return_loss`` and ``labels`` are accepted but not used; no loss is
        computed here.
        """
        text_features = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        ).pooler_output

        vision_outputs = self.vision_encoder.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        combined_features = torch.cat(
            (text_features, vision_outputs.pooler_output), dim=1
        )
        logits = self.fc(combined_features)
        return {"logits": logits}


id2label = {0: "negative", 1: "neutral", 2: "positive"}
label2id = {"negative": 0, "neutral": 1, "positive": 2}

tokenizer = AutoTokenizer.from_pretrained(
    "cardiffnlp/twitter-xlm-roberta-base-sentiment-multilingual"
)
model = VisionTextDualEncoderModel(num_classes=3)
config = model.vision_encoder.config

# Fine-tuned weights: https://huggingface.co/FFZG-cleopatra/M2SA/blob/main/model.safetensors
sf_filename = hf_hub_download("FFZG-cleopatra/M2SA", filename="model.safetensors")
load_model(model, sf_filename)
# Inference only: switch off dropout so identical inputs give identical outputs.
model.eval()

image_processor = AutoImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Built and TorchScript-compiled once at import time rather than per request.
_image_transformations = torch.jit.script(
    Transform(
        config.vision_config.image_size,
        image_processor.image_mean,
        image_processor.image_std,
    )
)


def _to_chw_uint8(image) -> torch.Tensor:
    """Convert a PIL image or NumPy ``(H, W[, C])`` array to an RGB ``(3, H, W)`` uint8 tensor.

    NumPy input is assumed to hold uint8 pixel values (what ``gr.Image()``
    produces) — TODO confirm if other Gradio image types are used.
    """
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image))
    # np.array (not asarray) copies, giving torch a writable buffer.
    rgb = np.array(image.convert("RGB"), dtype=np.uint8)
    return torch.from_numpy(rgb).permute(2, 0, 1).contiguous()


def predict_sentiment(text, image):
    """Classify the sentiment of a (text, image) pair.

    Args:
        text: Input text; ``None`` is treated as an empty string.
        image: Image from ``gr.Image()`` (a NumPy array by default); a PIL
            image is also accepted.

    Returns:
        The predicted label: ``"negative"``, ``"neutral"`` or ``"positive"``.

    Raises:
        gr.Error: If no image was provided (shown to the user by Gradio).
    """
    if image is None:
        raise gr.Error("Please provide an image.")
    if text is None:
        text = ""
    logger.info("Predicting sentiment for text: %r", text)

    text_inputs = tokenizer(
        text,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",  # the model needs batched tensors, not Python lists
    )
    # (3, H, W) -> (1, 3, image_size, image_size): a batch of one image.
    pixel_values = _image_transformations(_to_chw_uint8(image)).unsqueeze(0)

    with torch.no_grad():
        prediction = model(
            input_ids=text_inputs["input_ids"],
            pixel_values=pixel_values,
            attention_mask=text_inputs["attention_mask"],
        )
    logits = prediction["logits"]
    logger.info("Logits: %s", logits)
    return id2label[int(logits.argmax(dim=-1)[0])]
# Gradio UI: a free-text box and an image upload in, the predicted sentiment out.
text_input = gr.Textbox()
image_input = gr.Image()

interface = gr.Interface(
    fn=lambda text, image: predict_sentiment(text, image),
    inputs=[text_input, image_input],
    outputs=["text"],
    title="Multilingual-Multimodal-Sentiment-Analysis",
    examples=[],
    description="Get the positive/neutral/negative sentiment for the given input.",
)

interface.launch(inline=False)