from typing import Optional
import os
os.environ["WANDB_DISABLED"] = "true"

import numpy as np

from PIL import Image
import gradio as gr
import torch
import torch.nn as nn
from torchvision.transforms import (
    CenterCrop,
    ConvertImageDtype,
    InterpolationMode,
    Normalize,
    Resize,
)
from transformers import CLIPModel, AutoModel

from huggingface_hub import hf_hub_download
from safetensors.torch import load_model

from transformers import (
    AutoConfig,
    AutoImageProcessor,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    logging,
)

class Transform(torch.nn.Module):
    """Scriptable CLIP-style image preprocessing: resize, center-crop, float conversion, normalization."""

    def __init__(self, image_size, mean, std):
        super().__init__()
        self.transforms = torch.nn.Sequential(
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ConvertImageDtype(torch.float),
            Normalize(mean, std),
        )

    def forward(self, x) -> torch.Tensor:
        """`x` should be a uint8 image tensor of shape (channels, height, width)."""
        with torch.no_grad():
            x = self.transforms(x)
        return x



class VisionTextDualEncoderModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        # Text encoder: multilingual XLM-RoBERTa sentiment model
        self.text_encoder = AutoModel.from_pretrained("cardiffnlp/twitter-xlm-roberta-base-sentiment-multilingual")

        # Vision encoder: CLIP ViT-B/32 (only its vision tower is used in forward)
        self.vision_encoder = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

        vision_output_dim = self.vision_encoder.config.vision_config.hidden_size

        # Classification head over the concatenated text and vision features
        self.fc = nn.Linear(
            self.text_encoder.config.hidden_size + vision_output_dim, num_classes
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ):
        # Encode text inputs
        text_outputs = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        ).pooler_output

        # Encode vision inputs
        vision_outputs = self.vision_encoder.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Concatenate text and vision features
        combined_features = torch.cat(
            (text_outputs, vision_outputs.pooler_output), dim=1
        )

        # Forward through a linear layer for classification
        logits = self.fc(combined_features)

        return {"logits": logits}

id2label = {0: "negative", 1: "neutral", 2: "positive"}
label2id = {"negative": 0, "neutral": 1, "positive": 2}

tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-xlm-roberta-base-sentiment-multilingual")

model = VisionTextDualEncoderModel(num_classes=3)
config = model.vision_encoder.config

# https://huggingface.co/FFZG-cleopatra/M2SA/blob/main/model.safetensors
sf_filename = hf_hub_download("FFZG-cleopatra/M2SA", filename="model.safetensors")

load_model(model, sf_filename) 
# model.load_state_dict(torch.load(model_args.model_name_or_path+"-finetuned/pytorch_model.bin"))


# model = AutoModelForSequenceClassification.from_pretrained(
#         "FFZG-cleopatra/M2SA",
#         num_labels=3, id2label=id2label,
#         label2id=label2id
#     )
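
# Rough sanity-check sketch (commented out so nothing runs at import time); the
# tensor shapes are assumptions based on CLIP ViT-B/32 (224x224 inputs) and the
# 3-class head defined above:
#
#   out = model(
#       input_ids=torch.ones(1, 8, dtype=torch.long),
#       attention_mask=torch.ones(1, 8, dtype=torch.long),
#       pixel_values=torch.zeros(1, 3, 224, 224),
#   )
#   out["logits"].shape  # expected: torch.Size([1, 3])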



image_processor = AutoImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

def predict_sentiment(text, image):
    text_inputs = tokenizer(
        text,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    image_transformations = Transform(
        config.vision_config.image_size,
        image_processor.image_mean,
        image_processor.image_std,
    )
    image_transformations = torch.jit.script(image_transformations)

    # gr.Image() provides an H x W x C uint8 numpy array; convert it to a
    # C x H x W tensor, preprocess, and add a batch dimension.
    pixel_values = torch.from_numpy(image).permute(2, 0, 1)
    pixel_values = image_transformations(pixel_values).unsqueeze(0)

    model_input = {
        "input_ids": text_inputs.input_ids,
        "pixel_values": pixel_values,
        "attention_mask": text_inputs.attention_mask,
    }
    with torch.no_grad():
        prediction = model(**model_input)
        print(prediction)
    return id2label[prediction["logits"].argmax(dim=-1).item()]
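
# Example call (commented out; "example.jpg" is a placeholder path, not a file
# shipped with this script):
#   img = np.array(Image.open("example.jpg").convert("RGB"))
#   print(predict_sentiment("I love this!", img))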


interface = gr.Interface(
    fn=predict_sentiment,
    inputs=[gr.Textbox(), gr.Image()],
    outputs="text",
    title="Multilingual-Multimodal-Sentiment-Analysis",
    description="Get the positive/neutral/negative sentiment for the given input.",
)

interface.launch(inline=False)