sohaibcs1 committed
Commit 06fdca2 · 1 Parent(s): caa476c

Delete app.py

Files changed (1)
  1. app.py +0 -272
app.py DELETED
@@ -1,272 +0,0 @@
-import os
-os.system("gdown https://drive.google.com/uc?id=14pXWwB4Zm82rsDdvbGguLfx9F8aM7ovT")
-os.system("gdown https://drive.google.com/uc?id=1IdaBtMSvtyzF0ByVaBHtvM0JYSXRExRX")
-import clip
-import os
-from torch import nn
-import numpy as np
-import torch
-import torch.nn.functional as nnf
-import sys
-from typing import Tuple, List, Union, Optional
-from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
-from tqdm import tqdm, trange
-import skimage.io as io
-import PIL.Image
-import gradio as gr
-
-N = type(None)
-V = np.array
-ARRAY = np.ndarray
-ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
-VS = Union[Tuple[V, ...], List[V]]
-VN = Union[V, N]
-VNS = Union[VS, N]
-T = torch.Tensor
-TS = Union[Tuple[T, ...], List[T]]
-TN = Optional[T]
-TNS = Union[Tuple[TN, ...], List[TN]]
-TSN = Optional[TS]
-TA = Union[T, ARRAY]
-
-
-D = torch.device
-CPU = torch.device('cpu')
-
-
-def get_device(device_id: int) -> D:
-    if not torch.cuda.is_available():
-        return CPU
-    device_id = min(torch.cuda.device_count() - 1, device_id)
-    return torch.device(f'cuda:{device_id}')
-
-
-CUDA = get_device
-
-class MLP(nn.Module):
-
-    def forward(self, x: T) -> T:
-        return self.model(x)
-
-    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
-        super(MLP, self).__init__()
-        layers = []
-        for i in range(len(sizes) - 1):
-            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
-            if i < len(sizes) - 2:
-                layers.append(act())
-        self.model = nn.Sequential(*layers)
-
-
-class ClipCaptionModel(nn.Module):
-
-    #@functools.lru_cache #FIXME
-    def get_dummy_token(self, batch_size: int, device: D) -> T:
-        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
-
-    def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):
-        embedding_text = self.gpt.transformer.wte(tokens)
-        prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
-        #print(embedding_text.size()) #torch.Size([5, 67, 768])
-        #print(prefix_projections.size()) #torch.Size([5, 1, 768])
-        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
-        if labels is not None:
-            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
-            labels = torch.cat((dummy_token, tokens), dim=1)
-        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
-        return out
-
-    def __init__(self, prefix_length: int, prefix_size: int = 512):
-        super(ClipCaptionModel, self).__init__()
-        self.prefix_length = prefix_length
-        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
-        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
-        if prefix_length > 10:  # not enough memory
-            self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
-        else:
-            self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2, self.gpt_embedding_size * prefix_length))
-
-
-class ClipCaptionPrefix(ClipCaptionModel):
-
-    def parameters(self, recurse: bool = True):
-        return self.clip_project.parameters()
-
-    def train(self, mode: bool = True):
-        super(ClipCaptionPrefix, self).train(mode)
-        self.gpt.eval()
-        return self
-
-
-#@title Caption prediction
-
-def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
-                  entry_length=67, temperature=1., stop_token: str = '.'):
-
-    model.eval()
-    stop_token_index = tokenizer.encode(stop_token)[0]
-    tokens = None
-    scores = None
-    device = next(model.parameters()).device
-    seq_lengths = torch.ones(beam_size, device=device)
-    is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
-    with torch.no_grad():
-        if embed is not None:
-            generated = embed
-        else:
-            if tokens is None:
-                tokens = torch.tensor(tokenizer.encode(prompt))
-                tokens = tokens.unsqueeze(0).to(device)
-                generated = model.gpt.transformer.wte(tokens)
-        for i in range(entry_length):
-            outputs = model.gpt(inputs_embeds=generated)
-            logits = outputs.logits
-            logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
-            logits = logits.softmax(-1).log()
-            if scores is None:
-                scores, next_tokens = logits.topk(beam_size, -1)
-                generated = generated.expand(beam_size, *generated.shape[1:])
-                next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
-                if tokens is None:
-                    tokens = next_tokens
-                else:
-                    tokens = tokens.expand(beam_size, *tokens.shape[1:])
-                    tokens = torch.cat((tokens, next_tokens), dim=1)
-            else:
-                logits[is_stopped] = -float(np.inf)
-                logits[is_stopped, 0] = 0
-                scores_sum = scores[:, None] + logits
-                seq_lengths[~is_stopped] += 1
-                scores_sum_average = scores_sum / seq_lengths[:, None]
-                scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
-                next_tokens_source = next_tokens // scores_sum.shape[1]
-                seq_lengths = seq_lengths[next_tokens_source]
-                next_tokens = next_tokens % scores_sum.shape[1]
-                next_tokens = next_tokens.unsqueeze(1)
-                tokens = tokens[next_tokens_source]
-                tokens = torch.cat((tokens, next_tokens), dim=1)
-                generated = generated[next_tokens_source]
-                scores = scores_sum_average * seq_lengths
-                is_stopped = is_stopped[next_tokens_source]
-            next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
-            generated = torch.cat((generated, next_token_embed), dim=1)
-            is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
-            if is_stopped.all():
-                break
-    scores = scores / seq_lengths
-    output_list = tokens.cpu().numpy()
-    output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
-    order = scores.argsort(descending=True)
-    output_texts = [output_texts[i] for i in order]
-    return output_texts
-
-
-def generate2(
-        model,
-        tokenizer,
-        tokens=None,
-        prompt=None,
-        embed=None,
-        entry_count=1,
-        entry_length=67,  # maximum number of words
-        top_p=0.8,
-        temperature=1.,
-        stop_token: str = '.',
-):
-    model.eval()
-    generated_num = 0
-    generated_list = []
-    stop_token_index = tokenizer.encode(stop_token)[0]
-    filter_value = -float("Inf")
-    device = next(model.parameters()).device
-
-    with torch.no_grad():
-
-        for entry_idx in trange(entry_count):
-            if embed is not None:
-                generated = embed
-            else:
-                if tokens is None:
-                    tokens = torch.tensor(tokenizer.encode(prompt))
-                    tokens = tokens.unsqueeze(0).to(device)
-
-                generated = model.gpt.transformer.wte(tokens)
-
-            for i in range(entry_length):
-
-                outputs = model.gpt(inputs_embeds=generated)
-                logits = outputs.logits
-                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
-                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-                cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
-                sorted_indices_to_remove = cumulative_probs > top_p
-                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
-                    ..., :-1
-                ].clone()
-                sorted_indices_to_remove[..., 0] = 0
-
-                indices_to_remove = sorted_indices[sorted_indices_to_remove]
-                logits[:, indices_to_remove] = filter_value
-                next_token = torch.argmax(logits, -1).unsqueeze(0)
-                next_token_embed = model.gpt.transformer.wte(next_token)
-                if tokens is None:
-                    tokens = next_token
-                else:
-                    tokens = torch.cat((tokens, next_token), dim=1)
-                generated = torch.cat((generated, next_token_embed), dim=1)
-                if stop_token_index == next_token.item():
-                    break
-
-            output_list = list(tokens.squeeze().cpu().numpy())
-            output_text = tokenizer.decode(output_list)
-            generated_list.append(output_text)
-
-    return generated_list[0]
-
-is_gpu = False
-device = CUDA(0) if is_gpu else "cpu"
-clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
-tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-
-def inference(img, model_name):
-    prefix_length = 10
-
-    model = ClipCaptionModel(prefix_length)
-
-    if model_name == "COCO":
-        model_path = 'coco_weights.pt'
-    else:
-        model_path = 'conceptual_weights.pt'
-    model.load_state_dict(torch.load(model_path, map_location=CPU))
-    model = model.eval()
-    device = CUDA(0) if is_gpu else "cpu"
-    model = model.to(device)
-
-    use_beam_search = False
-    image = io.imread(img.name)
-    pil_image = PIL.Image.fromarray(image)
-    image = preprocess(pil_image).unsqueeze(0).to(device)
-    with torch.no_grad():
-        prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
-        prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
-    if use_beam_search:
-        generated_text_prefix = generate_beam(model, tokenizer, embed=prefix_embed)[0]
-    else:
-        generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)
-    return generated_text_prefix
-
-title = "ImageSummarizer"
-description = "Gradio demo for Image Summarizer: To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
-
-
-examples = [['water.jpeg', "COCO"]]
-gr.Interface(
-    inference,
-    [gr.inputs.Image(type="file", label="Input"), gr.inputs.Radio(choices=["COCO", "Conceptual captions"], type="value", default="COCO", label="Model")],
-    gr.outputs.Textbox(label="Output"),
-    title=title,
-    description=description,
-    article=article,
-    enable_queue=True,
-    examples=examples
-).launch(debug=True)