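# app.py -- Multilingual multimodal sentiment analysis demo (Gradio).
# Provenance (from the Hugging Face file viewer): uploaded by thak123,
# commit 738ac11 ("Update app.py", verified), file size 4.45 kB.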
import numpy as np
import os
import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPModel, AutoModel
from typing import Optional
from safetensors.torch import load_model
os.environ["WANDB_DISABLED"] = "true"
from datasets import load_dataset, load_metric
from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
logging,
)
class VisionTextDualEncoderModel(nn.Module):
    """Classify a (text, image) pair from concatenated encoder features.

    The text is encoded with an XLM-RoBERTa based encoder and the image with
    the vision tower of CLIP. The pooled outputs of both encoders are
    concatenated and passed through one linear layer that produces
    ``num_classes`` logits.

    The attribute names ``text_encoder``, ``vision_encoder`` and ``fc`` are
    the parameter-name prefixes of this module's state dict, so they must
    not be renamed without also converting saved checkpoints.

    Args:
        num_classes: Number of output classes of the linear head.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Text tower: XLM-RoBERTa based encoder.
        self.text_encoder = AutoModel.from_pretrained(
            "cardiffnlp/twitter-xlm-roberta-base-sentiment-multilingual"
        )
        # Vision tower: the full CLIP model is loaded, but only its
        # ``vision_model`` sub-module is used in ``forward``.
        self.vision_encoder = CLIPModel.from_pretrained(
            "openai/clip-vit-base-patch32"
        )
        vision_output_dim = self.vision_encoder.config.vision_config.hidden_size
        # Classification head over the concatenated text + vision features.
        self.fc = nn.Linear(
            self.text_encoder.config.hidden_size + vision_output_dim,
            num_classes,
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> dict:
        """Compute classification logits for a batch of text/image pairs.

        Args:
            input_ids: Token ids for the text encoder (required).
            pixel_values: Preprocessed images for the vision encoder
                (required).
            attention_mask: Optional attention mask for the text encoder.
            position_ids: Optional position ids for the text encoder.
            token_type_ids: Optional token type ids for the text encoder.
            output_attentions: Forwarded to the vision encoder.
            output_hidden_states: Forwarded to the vision encoder.
            return_loss: Accepted for signature compatibility; unused.
            return_dict: Accepted for signature compatibility; unused. The
                result is always a plain ``dict``.
            labels: Accepted for signature compatibility; unused (no loss
                is computed).

        Returns:
            ``{"logits": tensor}`` where the last dimension of the tensor
            has ``num_classes`` entries.

        Raises:
            ValueError: If ``input_ids`` or ``pixel_values`` is ``None``.
        """
        if input_ids is None or pixel_values is None:
            raise ValueError(
                "Both `input_ids` and `pixel_values` are required."
            )

        # `return_dict=True` is forced on both encoders: `.pooler_output`
        # below only exists on the structured output, so honouring a
        # caller-supplied `return_dict=False` would raise AttributeError.
        text_features = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            return_dict=True,
        ).pooler_output

        vision_outputs = self.vision_encoder.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        # Fuse the modalities along the feature dimension, then classify.
        combined_features = torch.cat(
            (text_features, vision_outputs.pooler_output), dim=1
        )
        logits = self.fc(combined_features)
        return {"logits": logits}
# `hf_hub_download` is used below but was never imported by this file.
# huggingface_hub is already a hard dependency of `transformers`.
from huggingface_hub import hf_hub_download

# Class-index <-> label mappings for the three sentiment classes.
id2label = {0: "negative", 1: "neutral", 2: "positive"}
label2id = {"negative": 0, "neutral": 1, "positive": 2}

tokenizer = AutoTokenizer.from_pretrained(
    "cardiffnlp/twitter-xlm-roberta-base-sentiment-multilingual"
)

model = VisionTextDualEncoderModel(num_classes=len(id2label))
# The dual encoder has no `vision_text_model` attribute; the CLIP config
# (which carries `vision_config.image_size`) lives on `vision_encoder`.
config = model.vision_encoder.config

# Fine-tuned weights:
# https://huggingface.co/FFZG-cleopatra/M2SA/blob/main/model.safetensors
sf_filename = hf_hub_download(
    "FFZG-cleopatra/M2SA", filename="model.safetensors"
)
# Load from the path returned by the hub download (a location inside the
# local hub cache), not from a relative "model.safetensors" that is not
# guaranteed to exist in the working directory.
load_model(model, sf_filename)

# NOTE: `model` is deliberately NOT replaced by
# AutoModelForSequenceClassification afterwards: that would discard the
# multimodal model loaded above in favour of a text-only classifier that
# cannot consume `pixel_values`.
model.eval()  # inference only: disable dropout
# Preprocessing constants matching the "openai/clip-vit-base-patch32" image
# processor: input resolution and per-channel normalization statistics.
# NOTE(review): hard-coded so this function does not depend on an
# `image_processor` object, which this file never defines -- confirm they
# match the preprocessing used when the checkpoint was fine-tuned.
_CLIP_IMAGE_SIZE = 224
_CLIP_IMAGE_MEAN = (0.48145466, 0.4578275, 0.40821073)
_CLIP_IMAGE_STD = (0.26862954, 0.26130258, 0.27577711)


def _preprocess_image(image):
    """Convert an image to a normalized ``(1, 3, S, S)`` float tensor.

    Accepts anything ``numpy.asarray`` can turn into an ``H x W`` (grayscale)
    or ``H x W x C`` (RGB / RGBA) array. The shorter side is resized to
    ``_CLIP_IMAGE_SIZE`` with bicubic interpolation, the result is
    center-cropped to a square and normalized with the CLIP statistics.

    Raises:
        ValueError: If ``image`` is ``None`` or does not have an image shape.
    """
    if image is None:
        raise ValueError("An image is required to predict the sentiment.")

    pixels = torch.as_tensor(np.asarray(image))
    if pixels.ndim == 2:
        # Grayscale: replicate the single channel to get three channels.
        pixels = pixels.unsqueeze(-1).expand(-1, -1, 3)
    if pixels.ndim != 3 or pixels.shape[-1] not in (3, 4) or 0 in pixels.shape:
        raise ValueError(
            "Expected a grayscale, RGB or RGBA image, got an array of shape "
            f"{tuple(pixels.shape)}."
        )

    is_float = pixels.is_floating_point()
    # Drop a possible alpha channel; H x W x C -> 1 x C x H x W.
    pixels = pixels[..., :3].permute(2, 0, 1).unsqueeze(0).float()
    if not is_float:
        # Integer images are assumed to be 8-bit; floating point images are
        # assumed to already be in [0, 1].
        pixels = pixels / 255.0

    size = _CLIP_IMAGE_SIZE
    height, width = pixels.shape[-2:]
    scale = size / min(height, width)
    new_height = max(size, round(height * scale))
    new_width = max(size, round(width * scale))
    pixels = torch.nn.functional.interpolate(
        pixels, size=(new_height, new_width), mode="bicubic", align_corners=False
    )
    # Bicubic interpolation can overshoot the valid value range.
    pixels = pixels.clamp(0.0, 1.0)

    top = (new_height - size) // 2
    left = (new_width - size) // 2
    pixels = pixels[..., top:top + size, left:left + size]

    mean = torch.tensor(_CLIP_IMAGE_MEAN).view(1, 3, 1, 1)
    std = torch.tensor(_CLIP_IMAGE_STD).view(1, 3, 1, 1)
    return (pixels - mean) / std


def predict_sentiment(text, image):
    """Predict the sentiment label for one text/image pair.

    Args:
        text: The input text; ``None`` is treated as an empty string.
        image: The input image (e.g. the NumPy array produced by a Gradio
            image component, or a PIL image).

    Returns:
        The predicted label, one of the values of the module-level
        ``id2label`` mapping.

    Raises:
        ValueError: If ``image`` is missing or is not a valid image array.
    """
    # Validate and preprocess the image first so that bad input fails
    # fast, before any tokenization or model work happens.
    pixel_values = _preprocess_image(image)

    # `return_tensors="pt"` is required: without it the tokenizer returns
    # Python lists, which the model cannot consume.
    text_inputs = tokenizer(
        text or "",
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        # Pass the inputs as keyword arguments; a single dict passed
        # positionally would be bound to `input_ids`.
        outputs = model(
            input_ids=text_inputs["input_ids"],
            attention_mask=text_inputs["attention_mask"],
            pixel_values=pixel_values,
        )

    predicted_index = int(outputs["logits"].argmax(dim=-1).item())
    return id2label[predicted_index]
# Gradio UI: one text box and one image as inputs, the predicted sentiment
# label as text output.
interface = gr.Interface(
    fn=predict_sentiment,
    # Top-level components replace the `gr.inputs.*` namespace, which was
    # deprecated in Gradio 3 and removed in Gradio 4 (as was the `shape`
    # argument of `Image`).
    inputs=[gr.Textbox(), gr.Image()],
    outputs=["text"],
    title='Multilingual-Multimodal-Sentiment-Analysis',
    # With several input components the examples must be a nested list with
    # one entry per input; `None` leaves the image input without an example.
    examples=[["I love tea", None], ["I hate coffee", None]],
    # The examples carry no image, so they must not be executed ahead of
    # time to pre-compute outputs.
    cache_examples=False,
    description='Get the positive/neutral/negative sentiment for the given input.',
)
interface.launch(inline=False)