thak123's picture
Update app.py
79c3b28 verified
raw
history blame contribute delete
No virus
5.24 kB
from typing import Optional
import os
os.environ["WANDB_DISABLED"] = "true"
import numpy as np
import gradio as gr
import torch
import torch.nn as nn
import torchvision
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode
from torchvision import transforms
from torchvision.io import ImageReadMode, read_image
from transformers import CLIPModel, AutoModel
from huggingface_hub import hf_hub_download
from safetensors.torch import load_model
from datasets import load_dataset, load_metric
from transformers import (
AutoConfig,
AutoImageProcessor,
AutoModelForSequenceClassification,
AutoTokenizer,
logging,
)
class Transform(torch.nn.Module):
    """Image preprocessing pipeline: resize, centre-crop, convert to float, normalise.

    The steps are wrapped in a ``torch.nn.Sequential`` so the whole module can
    be compiled with ``torch.jit.script`` (as the caller in this file does).

    Args:
        image_size: Target size; used both for the resize (passed as a
            one-element list) and for the square centre crop.
        mean: Per-channel mean handed to ``Normalize``.
        std: Per-channel standard deviation handed to ``Normalize``.
    """

    def __init__(self, image_size, mean, std):
        super().__init__()
        steps = [
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ConvertImageDtype(torch.float),
            Normalize(mean, std),
        ]
        self.transforms = torch.nn.Sequential(*steps)

    def forward(self, x) -> torch.Tensor:
        """Apply the pipeline to `x` with gradient tracking disabled.

        In this file `x` is the image tensor returned by
        ``torchvision.io.read_image``.
        """
        with torch.no_grad():
            out = self.transforms(x)
        return out
class VisionTextDualEncoderModel(nn.Module):
    """Late-fusion text + image classifier.

    A text encoder and the vision tower of a CLIP model each produce a pooled
    feature vector; the two vectors are concatenated and fed to a single
    linear layer that emits `num_classes` logits.

    The attribute names (`text_encoder`, `vision_encoder`, `fc`) determine the
    parameter names of this module, so they must stay in sync with whatever
    checkpoint is loaded into it.

    Args:
        num_classes: Number of output classes of the linear head.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Text branch: XLM-RoBERTa backbone.
        self.text_encoder = AutoModel.from_pretrained(
            "cardiffnlp/twitter-xlm-roberta-base-sentiment-multilingual"
        )
        # Vision branch: the full CLIP model is loaded, but forward() only
        # uses its `vision_model` sub-module.
        self.vision_encoder = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        vision_output_dim = self.vision_encoder.config.vision_config.hidden_size
        # Classification head over the concatenated [text ; vision] features.
        self.fc = nn.Linear(
            self.text_encoder.config.hidden_size + vision_output_dim, num_classes
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ):
        """Compute classification logits for a batch of (text, image) pairs.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids:
                Forwarded unchanged to the text encoder.
            pixel_values: Forwarded to the CLIP vision tower.
            output_attentions, output_hidden_states: Forwarded to the vision
                tower; the extra outputs are not returned by this method.
            return_loss, labels, return_dict: Accepted for signature
                compatibility only. No loss is computed, and the result is
                always a plain dict.

        Returns:
            ``{"logits": tensor}`` where the tensor is the output of the
            linear head applied to the concatenated pooled features.
        """
        # The pooled outputs are read by attribute below, so dict-style model
        # outputs are requested explicitly. Forwarding the caller's
        # `return_dict=False` would yield tuples and break `.pooler_output`.
        text_features = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            return_dict=True,
        ).pooler_output

        vision_outputs = self.vision_encoder.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        # Concatenate along the feature dimension (dim=1), keeping the batch
        # dimension intact.
        combined_features = torch.cat(
            (text_features, vision_outputs.pooler_output), dim=1
        )

        logits = self.fc(combined_features)
        return {"logits": logits}
# Mapping between the classifier's output indices and sentiment names.
# `label2id` is derived from `id2label` so the two can never drift apart.
id2label = {0: "negative", 1: "neutral", 2: "positive"}
label2id = {label: index for index, label in id2label.items()}

# Tokenizer matching the text encoder used inside VisionTextDualEncoderModel.
tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-xlm-roberta-base-sentiment-multilingual")

model = VisionTextDualEncoderModel(num_classes=3)
# CLIP config; read later for `vision_config.image_size`.
config = model.vision_encoder.config

# Download `model.safetensors` from the FFZG-cleopatra/M2SA repository on the
# Hugging Face Hub and load it into the freshly built model.
sf_filename = hf_hub_download("FFZG-cleopatra/M2SA", filename="model.safetensors")
load_model(model, sf_filename)

# This script only runs inference: switch the whole module tree to eval mode
# explicitly so train-only layers such as dropout are disabled, regardless of
# the mode each sub-encoder was constructed in.
model.eval()

# Supplies the normalisation statistics (`image_mean` / `image_std`) used when
# preprocessing images.
image_processor = AutoImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
# TorchScript-compiled preprocessing pipeline. It is built on first use and
# then reused, because constructing and scripting it is request-independent.
_image_transform = None


def _get_image_transform():
    """Return the scripted `Transform`, compiling it on the first call only."""
    global _image_transform
    if _image_transform is None:
        _image_transform = torch.jit.script(
            Transform(
                config.vision_config.image_size,
                image_processor.image_mean,
                image_processor.image_std,
            )
        )
    return _image_transform


def predict_sentiment(text, image):
    """Classify the sentiment of a (text, image) pair.

    Args:
        text: Input text passed to the tokenizer.
        image: Filesystem path of the image to read (the Gradio image input
            in this file is configured with ``type="filepath"``).

    Returns:
        The predicted label name looked up in `id2label`.

    Raises:
        gr.Error: If no image path was provided.
    """
    print(text, image)
    if not image:
        # Without this guard a missing upload reaches `read_image` and fails
        # with an error that means nothing to the end user.
        raise gr.Error("Please provide an image.")

    pixels = read_image(image, mode=ImageReadMode.RGB)

    text_inputs = tokenizer(
        text,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # Preprocess the image and add a leading batch dimension so it lines up
    # with the batched tokenizer output.
    pixel_values = _get_image_transform()(pixels)
    text_inputs["pixel_values"] = pixel_values.unsqueeze(0)

    with torch.no_grad():
        outputs = model(**text_inputs)
    print(outputs)

    # Use the tensor's own argmax rather than routing a torch tensor through
    # NumPy; index 0 selects the single example in the batch.
    predicted_index = outputs["logits"].argmax(dim=-1)[0].item()
    label = id2label[predicted_index]
    print(label)
    return label
# Gradio UI: a text box and an image upload (delivered to the callback as a
# file path) in, the predicted sentiment label out.
interface = gr.Interface(
    # Pass the function itself; wrapping it in a lambda with the same
    # signature adds nothing.
    fn=predict_sentiment,
    inputs=[gr.Textbox(), gr.Image(type="filepath")],
    outputs=["text"],
    title="Multilingual Multimodal Sentiment Analysis",
    # NOTE(review): the example image path is relative — presumably the file
    # ships alongside this script; confirm it exists in the deployment.
    examples=[["I am enjoying", "A_Sep20_14_1189155141.jpg"]],
    description="Get the positive/neutral/negative sentiment for the given input.",
)
interface.launch(inline=False)