import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModel
from PIL import Image
from torchvision import transforms
import json
from torch import nn
from typing import Literal

# Define Multimodal Classifier
class MultimodalClassifier(nn.Module):
    def __init__(
        self,
        text_encoder_id_or_path: str,
        image_encoder_id_or_path: str,
        projection_dim: int,
        fusion_method: Literal["concat", "align", "cosine_similarity"] = "concat",
        proj_dropout: float = 0.1,
        fusion_dropout: float = 0.1,
        num_classes: int = 1,
    ) -> None:
        super().__init__()

        self.fusion_method = fusion_method
        self.projection_dim = projection_dim
        self.num_classes = num_classes

        # Text Encoder
        self.text_encoder = AutoModel.from_pretrained(text_encoder_id_or_path)
        self.text_projection = nn.Sequential(
            nn.Linear(self.text_encoder.config.hidden_size, self.projection_dim),
            nn.Dropout(proj_dropout),
        )

        # Image Encoder (AutoModel loads the headless ResNet backbone; replacing
        # `classifier` is a defensive no-op in case a checkpoint ships a head)
        self.image_encoder = AutoModel.from_pretrained(image_encoder_id_or_path, trust_remote_code=True)
        self.image_encoder.classifier = nn.Identity()
        self.image_projection = nn.Sequential(
            nn.Linear(512, self.projection_dim),  # 512 = ResNet-34 final feature width
            nn.Dropout(proj_dropout),
        )

        # Fusion Layer
        fusion_input_dim = self.projection_dim * 2 if fusion_method == "concat" else self.projection_dim
        self.fusion_layer = nn.Sequential(
            nn.Dropout(fusion_dropout),
            nn.Linear(fusion_input_dim, self.projection_dim),
            nn.GELU(),
            nn.Dropout(fusion_dropout),
        )

        # Classification Layer
        self.classifier = nn.Linear(self.projection_dim, self.num_classes)

    def forward(self, pixel_values: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Text Encoder Projection
        full_text_features = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True).last_hidden_state
        full_text_features = full_text_features[:, 0, :]  # CLS token
        full_text_features = self.text_projection(full_text_features)

        # Image Encoder Projection
        resnet_image_features = self.image_encoder(pixel_values=pixel_values).last_hidden_state
        resnet_image_features = resnet_image_features.mean(dim=[-2, -1])  # Global average pooling
        resnet_image_features = self.image_projection(resnet_image_features)

        # Fusion: concatenate, or fuse by element-wise product (the product
        # branch serves both "align" and "cosine_similarity")
        if self.fusion_method == "concat":
            fused_features = torch.cat([full_text_features, resnet_image_features], dim=-1)
        else:
            fused_features = full_text_features * resnet_image_features

        # Classification
        fused_features = self.fusion_layer(fused_features)
        classification_output = self.classifier(fused_features)
        return classification_output
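
# A minimal sanity-check sketch (illustrative only; never called at import
# time). It builds the classifier with the same encoders used below and runs
# a forward pass on dummy inputs; the projection_dim here is an assumed value,
# since the app reads the real one from config.json.
def _smoke_test() -> None:
    model = MultimodalClassifier(
        text_encoder_id_or_path="bert-base-uncased",
        image_encoder_id_or_path="microsoft/resnet-34",
        projection_dim=256,  # illustrative; the deployed value comes from config.json
    )
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    text_inputs = tokenizer("a dummy headline", return_tensors="pt")
    pixel_values = torch.randn(1, 3, 224, 224)  # one fake normalized RGB image
    with torch.no_grad():
        logits = model(
            pixel_values=pixel_values,
            input_ids=text_inputs["input_ids"],
            attention_mask=text_inputs["attention_mask"],
        )
    assert logits.shape == (1, 1)  # batch of 1, num_classes defaults to 1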

# Load the model
def load_model():
    with open("config.json", "r") as f:
        config = json.load(f)

    model = MultimodalClassifier(
        text_encoder_id_or_path=config["text_encoder_id_or_path"],
        image_encoder_id_or_path="microsoft/resnet-34",
        projection_dim=config["projection_dim"],
        fusion_method=config["fusion_method"],
        proj_dropout=config["proj_dropout"],
        fusion_dropout=config["fusion_dropout"],
        num_classes=config["num_classes"]
    )

    checkpoint = torch.load("model_weights.pth", map_location=torch.device('cpu'))
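    # strict=False tolerates keys missing from or unexpected in the checkpoint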
    model.load_state_dict(checkpoint, strict=False)

    return model
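
# Expected shape of config.json (the keys are those read above; the values
# shown are illustrative, not the deployed settings):
# {
#   "text_encoder_id_or_path": "bert-base-uncased",
#   "projection_dim": 256,
#   "fusion_method": "concat",
#   "proj_dropout": 0.1,
#   "fusion_dropout": 0.1,
#   "num_classes": 1
# }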

# Load model and tokenizer (the tokenizer is assumed to match the text encoder
# named in config.json)
model = load_model()
model.eval()
text_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Image transform pipeline (standard ImageNet preprocessing, matching the
# ResNet encoder's pretraining statistics)
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Prediction function
def predict(image: Image.Image, text: str) -> str:
    # Process text input
    text_inputs = text_tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=512
    )
    
    # Process image input (force RGB in case of RGBA or grayscale uploads)
    image_input = image_transform(image.convert("RGB")).unsqueeze(0)  # Add batch dimension

    # Model inference
    with torch.no_grad():
        classification_output = model(
            pixel_values=image_input,
            input_ids=text_inputs["input_ids"],
            attention_mask=text_inputs["attention_mask"]
        )
        predicted_class = torch.sigmoid(classification_output).round().item()  # 0.5 threshold
    
    return "Fake News" if predicted_class == 1 else "Real News"

# Gradio Interface
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload Related Image"),
        gr.Textbox(lines=2, placeholder="Enter news text for classification...", label="Input Text")
    ],
    outputs=gr.Label(label="Prediction"),
    title="Fake News Detector",
    description="Upload an image and provide text to classify the news as 'Fake' or 'Real'."
)

interface.launch()
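
# When running locally, launch() also accepts options such as share=True for a
# temporary public link; the plain call above is fine for a hosted deployment.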