from typing import Optional
import os

os.environ["WANDB_DISABLED"] = "true"

import gradio as gr
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from safetensors.torch import load_model
from torchvision.transforms import (
    CenterCrop,
    ConvertImageDtype,
    InterpolationMode,
    Normalize,
    Resize,
)
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer, CLIPModel


class Transform(torch.nn.Module):
    """Resize, crop and normalize an image tensor to the CLIP input format."""

    def __init__(self, image_size, mean, std):
        super().__init__()
        self.transforms = torch.nn.Sequential(
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ConvertImageDtype(torch.float),
            Normalize(mean, std),
        )

    def forward(self, x) -> torch.Tensor:
        """`x` should be a uint8 `torch.Tensor` of shape (C, H, W)."""
        with torch.no_grad():
            x = self.transforms(x)
        return x


class VisionTextDualEncoderModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        # Text encoder: sentiment-tuned multilingual XLM-RoBERTa
        self.text_encoder = AutoModel.from_pretrained(
            "cardiffnlp/twitter-xlm-roberta-base-sentiment-multilingual"
        )
        # Vision encoder: CLIP ViT-B/32
        self.vision_encoder = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        vision_output_dim = self.vision_encoder.config.vision_config.hidden_size
        # Classification head over the concatenated text and vision features
        self.fc = nn.Linear(
            self.text_encoder.config.hidden_size + vision_output_dim, num_classes
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ):
        # Encode the text and take the pooled representation
        text_outputs = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        ).pooler_output
        # Encode the image with CLIP's vision tower
        vision_outputs = self.vision_encoder.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Concatenate text and vision features and classify
        combined_features = torch.cat(
            (text_outputs, vision_outputs.pooler_output), dim=1
        )
        logits = self.fc(combined_features)
        return {"logits": logits}
id2label = {0: "negative", 1: "neutral", 2: "positive"}
label2id = {"negative": 0, "neutral": 1, "positive": 2}
tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-xlm-roberta-base-sentiment-multilingual")
model = VisionTextDualEncoderModel(num_classes=3)
config = model.vision_encoder.config
# https://huggingface.co/FFZG-cleopatra/M2SA/blob/main/model.safetensors
sf_filename = hf_hub_download("FFZG-cleopatra/M2SA", filename="model.safetensors")
load_model(model, sf_filename)
# model.load_state_dict(torch.load(model_args.model_name_or_path+"-finetuned/pytorch_model.bin"))
# model = AutoModelForSequenceClassification.from_pretrained(
# "FFZG-cleopatra/M2SA",
# num_labels=3, id2label=id2label,
# label2id=label2id
# )
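# The CLIP image processor supplies the normalization mean/std applied to input images.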
image_processor = AutoImageProcessor.from_pretrained("openai/clip-vit-base-patch32")


def predict_sentiment(text, image):
    # Tokenize the text into fixed-length PyTorch tensors
    text_inputs = tokenizer(
        text,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    # Build (and script) the image preprocessing pipeline from the CLIP config/processor
    image_transformations = Transform(
        config.vision_config.image_size,
        image_processor.image_mean,
        image_processor.image_std,
    )
    image_transformations = torch.jit.script(image_transformations)
    # Gradio passes the image as an (H, W, C) uint8 numpy array; convert to (C, H, W)
    image = torch.from_numpy(image).permute(2, 0, 1)
    pixel_values = image_transformations(image).unsqueeze(0)
    model_input = {
        "input_ids": text_inputs.input_ids,
        "pixel_values": pixel_values,
        "attention_mask": text_inputs.attention_mask,
    }
    with torch.no_grad():
        prediction = model(**model_input)
    # Map the highest-scoring logit to its sentiment label
    return id2label[prediction["logits"].argmax(dim=-1).item()]
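
# Quick sanity check outside Gradio (hypothetical example; assumes a local "sample.jpg"):
#   import numpy as np
#   from PIL import Image
#   img = np.array(Image.open("sample.jpg").convert("RGB"))
#   print(predict_sentiment("I love this!", img))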

interface = gr.Interface(
    fn=predict_sentiment,
    inputs=[gr.Textbox(), gr.Image()],
    outputs=["text"],
    title="Multilingual-Multimodal-Sentiment-Analysis",
    examples=[],
    description="Get the positive/neutral/negative sentiment for the given input.",
)
interface.launch(inline=False)