#!/usr/bin/env python
# coding: utf-8

# In[1]:


import torch
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import gradio as gr
import warnings
import streamlit as st
from PIL import Image
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DistilBertForSequenceClassification, AutoModelForSeq2SeqLM
from tqdm import tqdm
from torchvision import models
from torchvision.transforms import v2
from torch.utils.data import Dataset, DataLoader
from keras.preprocessing import image
from torchmetrics.classification import MultilabelF1Score
from sklearn.metrics import average_precision_score, ndcg_score


# In[2]:


warnings.filterwarnings("ignore")


# In[3]:


genres = ["Crime", "Thriller", "Fantasy", "Horror", "Sci-Fi", "Comedy", "Documentary", "Adventure", "Film-Noir", "Animation", "Romance", "Drama", "Western", "Musical", "Action", "Mystery", "War", "Children's"]
mapping = {i: genre for i, genre in enumerate(genres)}
mapping


# In[4]:


# Plot generator: LaMini-Flan-T5 turns a movie title into a short synthetic plot summary.
tokenizer_gen = AutoTokenizer.from_pretrained("MBZUAI/LaMini-Flan-T5-248M")
model_gen = AutoModelForSeq2SeqLM.from_pretrained("MBZUAI/LaMini-Flan-T5-248M")

# Title branch: DistilBERT with an 18-way multi-label head (weights are overwritten by the fused checkpoint loaded below).
tokenizer1 = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model1 = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", problem_type="multi_label_classification", num_labels=18)
model1.config.id2label = mapping

# Plot branch: a sequence classifier trained on movie plots.
tokenizer2 = AutoTokenizer.from_pretrained("dduy193/plot-classification")
model2 = AutoModelForSequenceClassification.from_pretrained("dduy193/plot-classification")
model2.config.id2label = mapping

# Poster branch: ResNet-101 with its final layer replaced by an 18-way genre head (weights also come from the checkpoint).
model3 = models.resnet101(pretrained=False)
model3.fc = torch.nn.Linear(2048, len(genres))


# In[5]:


class Multimodal(torch.nn.Module):
    """Late-fusion genre classifier: each branch (title, plot, poster) produces
    18 logits, which are passed through a per-branch linear layer and summed."""

    def __init__(self, model1, model2, model3):
        super().__init__()
        self.model1 = model1
        self.model2 = model2
        self.model3 = model3
        self.fc1 = torch.nn.Linear(18, 18)
        self.fc2 = torch.nn.Linear(18, 18)
        self.fc3 = torch.nn.Linear(18, 18)

    def forward(self, 
                title_input_ids, title_attention_mask,
                plot_input_ids, plot_attention_mask,
                image_input):
        title_output = self.model1(title_input_ids, title_attention_mask)
        plot_output = self.model2(plot_input_ids, plot_attention_mask)
        image_output = self.model3(image_input)

        # Project each branch's logits, then fuse them by elementwise addition.
        title_output = self.fc1(title_output.logits)
        plot_output = self.fc2(plot_output.logits)
        image_output = self.fc3(image_output)

        output = torch.add(title_output, plot_output)
        output = torch.add(output, image_output)
        return output
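
# _Sanity check: with dummy title, plot, and poster inputs, the fused model should return one logit per genre, i.e. a tensor of shape `(1, 18)`. This is an illustrative cell that uses the untrained branches defined above, so only the shape is meaningful._

# In[ ]:


with torch.no_grad():
    _title = tokenizer1("A test title", return_tensors="pt")
    _plot = tokenizer2("A short test plot.", return_tensors="pt")
    _fused = Multimodal(model1, model2, model3)(
        _title["input_ids"], _title["attention_mask"],
        _plot["input_ids"], _plot["attention_mask"],
        torch.zeros(1, 3, 224, 224),
    )
print(_fused.shape)  # expected: torch.Size([1, 18])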

# **_PLEASE DOWNLOAD THE MODEL CHECKPOINT (multimodel.pt) FROM THE LINK IN README.txt_**

# In[7]:

device = torch.device('cpu')

# Load the fused checkpoint (see README.txt) and switch to inference mode.
model = Multimodal(model1, model2, model3)
model.load_state_dict(torch.load('multimodel.pt', map_location=device))
model.eval()


# In[8]:


def generate_plot(title: str, model: AutoModelForSeq2SeqLM, tokenizer: AutoTokenizer, device) -> str:
    """Generate a short plot summary for the given movie title with the seq2seq model."""
    prompt = 'What is the story of the movie {}?'
    model.to(device)
    model.eval()

    input_ids = tokenizer(prompt.format(title), return_tensors='pt').input_ids.to(device)
    output = model.generate(input_ids, max_length=256, do_sample=True, temperature=0.09)
    return tokenizer.decode(output[0], skip_special_tokens=True)
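
# _For example, the plot generator can be exercised on its own (output varies between runs because sampling is enabled):_

# In[ ]:


print(generate_plot("The Matrix", model_gen, tokenizer_gen, device))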


# In[9]:


def inference(title, image, 
              tokenizer1=tokenizer1, tokenizer2=tokenizer2, tokenizer_gen=tokenizer_gen,
              model_gen=model_gen, model=model, 
              genres=genres, device=device):
    # Tokenize the title for the DistilBERT branch.
    title_input = tokenizer1(title, return_tensors='pt', padding=True, truncation=True)
    title_input_ids = title_input['input_ids'].to(device)
    title_attention_mask = title_input['attention_mask'].to(device)

    # Generate a plot summary from the title and tokenize it for the plot branch.
    plot = generate_plot(title, model_gen, tokenizer_gen, device)
    plot_input = tokenizer2(plot, return_tensors='pt', padding=True, truncation=True)
    plot_input_ids = plot_input['input_ids'].to(device)
    plot_attention_mask = plot_input['attention_mask'].to(device)

    # If no poster is uploaded, feed a blank image to the vision branch.
    if image is None:
        image_input = torch.zeros((1, 3, 224, 224)).to(device)
    else:
        image_input = image.convert('RGB').resize((224, 224))  # force 3 channels for ResNet
        image_input = v2.ToTensor()(image_input)
        image_input = image_input.unsqueeze(0)
        image_input = image_input.to(device)

    # Fuse the three branches, threshold each genre probability at 0.5, and map back to genre names.
    output = model(title_input_ids, title_attention_mask, plot_input_ids, plot_attention_mask, image_input)
    output = torch.sigmoid(output)
    output = output.cpu().detach().numpy()
    output = np.where(output > 0.5, 1, 0)
    output = output.squeeze()
    output = np.where(output == 1)[0]
    output = [genres[i] for i in output]
    return output
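
# _Quick check: calling `inference` directly with no poster exercises the zero-image fallback above._

# In[ ]:


print(inference("The Matrix", None))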


# In[10]:


app = gr.Interface(fn=inference,
                   inputs=[gr.Textbox(label="Title"), gr.Image(type="pil", label="Poster")],
                   outputs="text",
                   title="Movie Genre Classification",
                   description="Predicts a movie's genres from its title, a generated plot summary, and (optionally) its poster.",
                   examples=[["The Matrix", "https://upload.wikimedia.org/wikipedia/en/c/c1/The_Matrix_Poster.jpg"],
                             ["The Dark Knight", "https://upload.wikimedia.org/wikipedia/en/1/1c/The_Dark_Knight_%282008_film%29.jpg"],
                             ["The Godfather", "https://upload.wikimedia.org/wikipedia/en/1/1c/Godfather_ver1.jpg"],
                             ["The Shawshank Redemption", "https://upload.wikimedia.org/wikipedia/en/8/81/ShawshankRedemptionMoviePoster.jpg"],
                             ["The Lord of the Rings: The Return of the King", "https://upload.wikimedia.org/wikipedia/en/2/23/The_Lord_of_the_Rings%2C_TROTK_%282003%29.jpg"],
                             ["The Godfather: Part II", "https://upload.wikimedia.org/wikipedia/en/0/03/Godfather_part_ii.jpg"]])


# In[11]:


app.launch(share=True)