ucaslcl commited on
Commit
8af0d0c
1 Parent(s): 5bc9b12

Upload 13 files

Browse files
GOT_ocr_2_0.py ADDED
@@ -0,0 +1,583 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoConfig, AutoModelForCausalLM, \
2
+ Qwen2Config, Qwen2Model, Qwen2ForCausalLM, \
3
+ CLIPVisionModel, CLIPImageProcessor
4
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
5
+ from typing import List, Optional, Tuple, Union
6
+ from transformers.cache_utils import Cache, DynamicCache
7
+ # import sys
8
+ # import os
9
+ # sys.path.append(os.path.dirname(__file__))
10
+ # print(os.path.dirname(__file__))
11
+ # sys.path.append('/data/code/a2hf/GOT-OCR2_0')
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from torch.nn import CrossEntropyLoss
17
+ from .constants import *
18
+ from .vary_b import build_vary_vit_b
19
+ from .blip_process import BlipImageEvalProcessor
20
+ from .run_ocr import *
21
+
22
+ class GOTConfig(Qwen2Config):
23
+ model_type = "GOT"
24
+
25
+
26
+ class GOTQwenModel(Qwen2Model):
27
+ config_class = GOTConfig
28
+
29
+ def __init__(self, config: Qwen2Config):
30
+ super(GOTQwenModel, self).__init__(config)
31
+
32
+ self.vision_tower_high = build_vary_vit_b()
33
+
34
+ self.mm_projector_vary = nn.Linear(1024, 1024)
35
+
36
+
37
+ def initialize_vision_modules(
38
+ self,
39
+ vision_tower,
40
+ pretrained_stage1_model=None,
41
+ freeze_vision_tower=False,
42
+ use_im_start_end=False,
43
+ vision_select_layer=-1,
44
+ dtype=torch.float16,
45
+ device="cuda"
46
+ ):
47
+
48
+ # Vary old codes, not use in GOT
49
+ image_processor = BlipImageEvalProcessor(image_size=1024)
50
+ # 1024*1024
51
+
52
+ image_processor_high = BlipImageEvalProcessor(image_size=1024)
53
+
54
+
55
+
56
+ self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device)
57
+
58
+ self.mm_projector_vary = self.mm_projector_vary.to(dtype=dtype, device=device)
59
+
60
+
61
+ image_token_len = 256
62
+
63
+ self.config.vision_tower = vision_tower
64
+ self.config.image_token_len = image_token_len
65
+ # self.config.use_im_start_end = use_im_start_end
66
+ self.config.use_im_start_end = True
67
+
68
+ self.config.vision_select_layer = vision_select_layer
69
+ self.config.freeze_vision_tower = freeze_vision_tower
70
+
71
+ return dict(
72
+ image_processor=image_processor,
73
+ image_processor_high=image_processor_high,
74
+ image_token_len=image_token_len,
75
+ )
76
+
77
+ # def get_input_embeddings(self, x):
78
+ # return self.wte(x)
79
+
80
+ def forward(
81
+ self,
82
+ input_ids: torch.LongTensor = None,
83
+ attention_mask: Optional[torch.Tensor] = None,
84
+ position_ids: Optional[torch.LongTensor] = None,
85
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
86
+ inputs_embeds: Optional[torch.FloatTensor] = None,
87
+ use_cache: Optional[bool] = None,
88
+ output_attentions: Optional[bool] = None,
89
+ output_hidden_states: Optional[bool] = None,
90
+ images: Optional[torch.FloatTensor] = None,
91
+ return_dict: Optional[bool] = None,
92
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
93
+
94
+ # HACK: replace back original embeddings for LLaVA pretraining
95
+ orig_embeds_params = getattr(self, 'orig_embeds_params', None)
96
+ if orig_embeds_params is not None:
97
+ with torch.no_grad():
98
+ self.get_input_embeddings().weight[:-self.num_new_tokens] = orig_embeds_params[:-self.num_new_tokens].data
99
+
100
+ if inputs_embeds is None:
101
+ inputs_embeds = self.embed_tokens(input_ids)
102
+
103
+
104
+ vision_tower_high = getattr(self, 'vision_tower_high', None)
105
+
106
+
107
+ if vision_tower_high is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
108
+ # if True:
109
+ # assert type(images) is list, ValueError("To fit both interleave and conversation, images must be list of batches of images")
110
+ # print(im)
111
+ use_im_start_end = getattr(self.config, "use_im_start_end", -1)
112
+
113
+ vision_select_layer = getattr(self.config, "vision_select_layer", -1)
114
+ im_patch_token = getattr(self.config, "im_patch_token", -1)
115
+ im_start_token = getattr(self.config, "im_start_token", -1)
116
+ im_end_token = getattr(self.config, "im_end_token", -1)
117
+ freeze_vision_tower = getattr(self.config, "freeze_vision_tower", False)
118
+
119
+ im_patch_token = 151859
120
+
121
+ im_start_token = 151857
122
+
123
+ im_end_token = 151858
124
+
125
+
126
+
127
+ image_features = []
128
+
129
+
130
+ for image in images:
131
+ P, C, H, W = image[1].shape
132
+ # with torch.set_grad_enabled(True):
133
+ # # print(image[1].shape)
134
+ # cnn_feature = vision_tower_high(image[1])
135
+ # cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256 1024
136
+ # # image_features.append(cnn_feature)
137
+ # image_features_2.append(cnn_feature)
138
+ if P == 1:
139
+ with torch.set_grad_enabled(False):
140
+ # print(image[1].shape)
141
+ cnn_feature = vision_tower_high(image[1])
142
+ cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256*1024
143
+ # image_features.append(cnn_feature)
144
+ # image_features_2.append(cnn_feature)
145
+ image_feature = self.mm_projector_vary(cnn_feature)
146
+ image_features.append(image_feature)
147
+
148
+ else:
149
+ image_patches = torch.unbind(image[1])
150
+ image_patches_features = []
151
+ for image_patch in image_patches:
152
+ image_p = torch.stack([image_patch])
153
+ with torch.set_grad_enabled(False):
154
+ cnn_feature_p = vision_tower_high(image_p)
155
+ cnn_feature_p = cnn_feature_p.flatten(2).permute(0, 2, 1)
156
+ image_feature_p = self.mm_projector_vary(cnn_feature_p)
157
+ image_patches_features.append(image_feature_p)
158
+ image_feature = torch.cat(image_patches_features, dim=1)
159
+ # print(P)
160
+ # print(image_feature.shape)
161
+ # exit()
162
+ image_features.append(image_feature)
163
+
164
+
165
+
166
+ dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
167
+ # dummy_image_features_2 = self.mm_projector_vary(dummy_image_features_2)
168
+ dummy_image_features = dummy_image_features_2
169
+ use_im_start_end = True
170
+ new_input_embeds = []
171
+ for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
172
+ if (cur_input_ids == im_patch_token).sum() == 0:
173
+ # multimodal LLM, but the current sample is not multimodal
174
+ cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
175
+ new_input_embeds.append(cur_input_embeds)
176
+ continue
177
+
178
+ if use_im_start_end:
179
+ if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum():
180
+ raise ValueError("The number of image start tokens and image end tokens should be the same.")
181
+
182
+ image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]
183
+ for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features):
184
+ per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)
185
+ num_patches = per_cur_image_features.shape[0]
186
+
187
+ if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
188
+ raise ValueError("The image end token should follow the image start token.")
189
+
190
+ cur_input_embeds = torch.cat(
191
+ (
192
+ cur_input_embeds[:image_start_token_pos+1],
193
+ per_cur_image_features,
194
+ cur_input_embeds[image_start_token_pos + num_patches + 1:]
195
+ ),
196
+ dim=0
197
+ )
198
+
199
+
200
+ new_input_embeds.append(cur_input_embeds)
201
+ else:
202
+ raise NotImplementedError
203
+
204
+ inputs_embeds = torch.stack(new_input_embeds, dim=0)
205
+
206
+ return super(GOTQwenModel, self).forward(
207
+ input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
208
+ inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids,
209
+ output_attentions=output_attentions, output_hidden_states=output_hidden_states,
210
+ return_dict=return_dict
211
+ )
212
+
213
+
214
+
215
+ class GOTQwenForCausalLM(Qwen2ForCausalLM):
216
+ config_class = GOTConfig
217
+ # supports_gradient_checkpointing = True
218
+
219
+ def __init__(self, config):
220
+ super(Qwen2ForCausalLM, self).__init__(config)
221
+ self.model = GOTQwenModel(config)
222
+
223
+ self.vocab_size = config.vocab_size
224
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
225
+
226
+ # Initialize weights and apply final processing
227
+ self.post_init()
228
+
229
+ def get_model(self):
230
+ return self.model
231
+
232
+ # def _set_gradient_checkpointing(self, module, value=False):
233
+ # if isinstance(module, GOTQwenModel):
234
+ # module.gradient_checkpointing = value
235
+ # @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
236
+ # @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
237
+ def forward(
238
+ self,
239
+ input_ids: torch.LongTensor = None,
240
+ attention_mask: Optional[torch.Tensor] = None,
241
+ position_ids: Optional[torch.LongTensor] = None,
242
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
243
+ inputs_embeds: Optional[torch.FloatTensor] = None,
244
+ labels: Optional[torch.LongTensor] = None,
245
+ use_cache: Optional[bool] = None,
246
+ output_attentions: Optional[bool] = None,
247
+ output_hidden_states: Optional[bool] = None,
248
+ images: Optional[torch.FloatTensor] = None,
249
+ return_dict: Optional[bool] = None,
250
+
251
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
252
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
253
+ output_hidden_states = (
254
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
255
+ )
256
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
257
+
258
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
259
+ # print(input_ids)
260
+ # print(len(images))
261
+
262
+ # print(inputs_embeds)
263
+
264
+ outputs = self.model(
265
+ input_ids=input_ids,
266
+ past_key_values=past_key_values,
267
+ attention_mask=attention_mask,
268
+ position_ids=position_ids,
269
+ inputs_embeds=inputs_embeds,
270
+ use_cache=use_cache,
271
+ output_attentions=output_attentions,
272
+ output_hidden_states=output_hidden_states,
273
+ images=images,
274
+ return_dict=return_dict
275
+
276
+ )
277
+
278
+
279
+ hidden_states = outputs[0]
280
+ logits = self.lm_head(hidden_states)
281
+ logits = logits.float()
282
+
283
+ # logits
284
+
285
+ loss = None
286
+ if labels is not None:
287
+ # Shift so that tokens < n predict n
288
+ shift_logits = logits[..., :-1, :].contiguous()
289
+ shift_labels = labels[..., 1:].contiguous()
290
+ # Flatten the tokens
291
+ loss_fct = CrossEntropyLoss()
292
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
293
+ shift_labels = shift_labels.view(-1)
294
+ # Enable model parallelism
295
+ shift_labels = shift_labels.to(shift_logits.device)
296
+ loss = loss_fct(shift_logits, shift_labels)
297
+
298
+ if not return_dict:
299
+ output = (logits,) + outputs[1:]
300
+ return (loss,) + output if loss is not None else output
301
+
302
+ return CausalLMOutputWithPast(
303
+ loss=loss,
304
+ logits=logits,
305
+ past_key_values=outputs.past_key_values,
306
+ hidden_states=outputs.hidden_states,
307
+ attentions=outputs.attentions,
308
+ )
309
+
310
+
311
+ def prepare_inputs_for_generation(
312
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
313
+ ):
314
+ # Omit tokens covered by past_key_values
315
+ if past_key_values is not None:
316
+ if isinstance(past_key_values, Cache):
317
+ cache_length = past_key_values.get_seq_length()
318
+ past_length = past_key_values.seen_tokens
319
+ max_cache_length = past_key_values.get_max_length()
320
+ else:
321
+ cache_length = past_length = past_key_values[0][0].shape[2]
322
+ max_cache_length = None
323
+
324
+ # Keep only the unprocessed tokens:
325
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
326
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
327
+ # input)
328
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
329
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
330
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
331
+ # input_ids based on the past_length.
332
+ elif past_length < input_ids.shape[1]:
333
+ input_ids = input_ids[:, past_length:]
334
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
335
+
336
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
337
+ if (
338
+ max_cache_length is not None
339
+ and attention_mask is not None
340
+ and cache_length + input_ids.shape[1] > max_cache_length
341
+ ):
342
+ attention_mask = attention_mask[:, -max_cache_length:]
343
+
344
+ position_ids = kwargs.get("position_ids", None)
345
+ if attention_mask is not None and position_ids is None:
346
+ # create position_ids on the fly for batch generation
347
+ position_ids = attention_mask.long().cumsum(-1) - 1
348
+ position_ids.masked_fill_(attention_mask == 0, 1)
349
+ if past_key_values:
350
+ position_ids = position_ids[:, -input_ids.shape[1] :]
351
+
352
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
353
+ if inputs_embeds is not None and past_key_values is None:
354
+ model_inputs = {"inputs_embeds": inputs_embeds}
355
+ else:
356
+ model_inputs = {"input_ids": input_ids}
357
+
358
+ model_inputs.update(
359
+ {
360
+ "position_ids": position_ids,
361
+ "past_key_values": past_key_values,
362
+ "use_cache": kwargs.get("use_cache"),
363
+ "attention_mask": attention_mask,
364
+ "images": kwargs.get("images", None),
365
+ }
366
+ )
367
+ return model_inputs
368
+
369
+ def initialize_vision_tokenizer(
370
+ self,
371
+ tokenizer,
372
+ freeze_lm_model=False,
373
+ pretrained_stage1_model=None,
374
+ device="cuda"
375
+ ):
376
+ config = self.get_model().config
377
+
378
+ # add image patch token <image>
379
+ # tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
380
+ self.resize_token_embeddings(len(tokenizer))
381
+ # config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
382
+
383
+ config.im_patch_token = 151859
384
+
385
+ config.use_im_start_end = True
386
+
387
+ # add image start token <im_start> and end token <im_end>
388
+ if config.use_im_start_end:
389
+ # num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
390
+ self.resize_token_embeddings(len(tokenizer))
391
+ # config.im_start_token, config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
392
+
393
+ config.im_start_token, config.im_end_token = 151857, 151858
394
+
395
+
396
+ def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False):
397
+ # Model
398
+ disable_torch_init()
399
+ # model_name = os.path.expanduser(args.model_name)
400
+
401
+ # tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
402
+ # model = GOTQwenForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=151643).eval()
403
+ # model.to(device='cuda', dtype=torch.bfloat16)
404
+
405
+
406
+ # TODO vary old codes, NEED del
407
+ image_processor = BlipImageEvalProcessor(image_size=1024)
408
+
409
+ image_processor_high = BlipImageEvalProcessor(image_size=1024)
410
+
411
+ use_im_start_end = True
412
+
413
+ image_token_len = 256
414
+
415
+ image = load_image(image_file)
416
+
417
+ w, h = image.size
418
+ # print(image.size)
419
+
420
+ if ocr_type == 'format':
421
+ qs = 'OCR with format: '
422
+ else:
423
+ qs = 'OCR: '
424
+
425
+ if ocr_box:
426
+ bbox = eval(ocr_box)
427
+ if len(bbox) == 2:
428
+ bbox[0] = int(bbox[0]/w*1000)
429
+ bbox[1] = int(bbox[1]/h*1000)
430
+ if len(bbox) == 4:
431
+ bbox[0] = int(bbox[0]/w*1000)
432
+ bbox[1] = int(bbox[1]/h*1000)
433
+ bbox[2] = int(bbox[2]/w*1000)
434
+ bbox[3] = int(bbox[3]/h*1000)
435
+ if ocr_type == 'format':
436
+ qs = str(bbox) + ' ' + 'OCR with format: '
437
+ else:
438
+ qs = str(bbox) + ' ' + 'OCR: '
439
+
440
+ if ocr_color:
441
+ if ocr_type == 'format':
442
+ qs = '[' + ocr_color + ']' + ' ' + 'OCR with format: '
443
+ else:
444
+ qs = '[' + ocr_color + ']' + ' ' + 'OCR: '
445
+
446
+ if use_im_start_end:
447
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
448
+ else:
449
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
450
+
451
+
452
+
453
+ conv_mode = "mpt"
454
+ args.conv_mode = conv_mode
455
+
456
+ conv = conv_templates[args.conv_mode].copy()
457
+ conv.append_message(conv.roles[0], qs)
458
+ conv.append_message(conv.roles[1], None)
459
+ prompt = conv.get_prompt()
460
+
461
+ print(prompt)
462
+
463
+
464
+ inputs = tokenizer([prompt])
465
+
466
+
467
+ # vary old codes, no use
468
+ image_1 = image.copy()
469
+ image_tensor = image_processor(image)
470
+
471
+
472
+ image_tensor_1 = image_processor_high(image_1)
473
+
474
+
475
+ input_ids = torch.as_tensor(inputs.input_ids).cuda()
476
+
477
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
478
+ keywords = [stop_str]
479
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
480
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
481
+
482
+
483
+ with torch.autocast("cuda", dtype=torch.bfloat16):
484
+ output_ids = self.generate(
485
+ input_ids,
486
+ images=[(image_tensor.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).half().cuda())],
487
+ do_sample=False,
488
+ num_beams = 1,
489
+ no_repeat_ngram_size = 20,
490
+ streamer=streamer,
491
+ max_new_tokens=4096,
492
+ stopping_criteria=[stopping_criteria]
493
+ )
494
+
495
+
496
+ if render:
497
+ print('==============rendering===============')
498
+
499
+ outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
500
+
501
+ if outputs.endswith(stop_str):
502
+ outputs = outputs[:-len(stop_str)]
503
+ outputs = outputs.strip()
504
+
505
+ if '**kern' in outputs:
506
+ import verovio
507
+ from cairosvg import svg2png
508
+ import cv2
509
+ import numpy as np
510
+ tk = verovio.toolkit()
511
+ tk.loadData(outputs)
512
+ tk.setOptions({"pageWidth": 2100, "footer": 'none',
513
+ 'barLineWidth': 0.5, 'beamMaxSlope': 15,
514
+ 'staffLineWidth': 0.2, 'spacingStaff': 6})
515
+ tk.getPageCount()
516
+ svg = tk.renderToSVG()
517
+ svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
518
+
519
+ svg_to_html(svg, "./results/demo.html")
520
+
521
+ if ocr_type == 'format' and '**kern' not in outputs:
522
+
523
+
524
+ if '\\begin{tikzpicture}' not in outputs:
525
+ html_path = "./render_tools/" + "/content-mmd-to-html.html"
526
+ html_path_2 = "./results/demo.html"
527
+ right_num = outputs.count('\\right')
528
+ left_num = outputs.count('\left')
529
+
530
+ if right_num != left_num:
531
+ outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
532
+
533
+
534
+ outputs = outputs.replace('"', '``').replace('$', '')
535
+
536
+ outputs_list = outputs.split('\n')
537
+ gt= ''
538
+ for out in outputs_list:
539
+ gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
540
+
541
+ gt = gt[:-2]
542
+
543
+ with open(html_path, 'r') as web_f:
544
+ lines = web_f.read()
545
+ lines = lines.split("const text =")
546
+ new_web = lines[0] + 'const text =' + gt + lines[1]
547
+ else:
548
+ html_path = "./render_tools/" + "/tikz.html"
549
+ html_path_2 = "./results/demo.html"
550
+ outputs = outputs.translate(translation_table)
551
+ outputs_list = outputs.split('\n')
552
+ gt= ''
553
+ for out in outputs_list:
554
+ if out:
555
+ if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
556
+ while out[-1] == ' ':
557
+ out = out[:-1]
558
+ if out is None:
559
+ break
560
+
561
+ if out:
562
+ if out[-1] != ';':
563
+ gt += out[:-1] + ';\n'
564
+ else:
565
+ gt += out + '\n'
566
+ else:
567
+ gt += out + '\n'
568
+
569
+
570
+ with open(html_path, 'r') as web_f:
571
+ lines = web_f.read()
572
+ lines = lines.split("const text =")
573
+ new_web = lines[0] + gt + lines[1]
574
+
575
+ with open(html_path_2, 'w') as web_f_new:
576
+ web_f_new.write(new_web)
577
+
578
+
579
+
580
+
581
+ AutoConfig.register("GOT", GOTConfig)
582
+ AutoModelForCausalLM.register(GOTConfig, GOTQwenForCausalLM)
583
+
blip_process.py ADDED
@@ -0,0 +1,504 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import cv2
9
+ import numpy as np
10
+
11
+ import torch
12
+
13
+ # from omegaconf import OmegaConf
14
+ from torchvision import transforms
15
+ from torchvision.transforms.functional import InterpolationMode
16
+ from PIL import Image
17
+
18
+ class BaseProcessor:
19
+ def __init__(self):
20
+ self.transform = lambda x: x
21
+ return
22
+
23
+ def __call__(self, item):
24
+ return self.transform(item)
25
+
26
+ # @classmethod
27
+ # def from_config(cls, cfg=None):
28
+ # return cls()
29
+
30
+ # def build(self, **kwargs):
31
+ # cfg = OmegaConf.create(kwargs)
32
+
33
+ # return self.from_config(cfg)
34
+
35
+ class BlipImageBaseProcessor(BaseProcessor):
36
+ def __init__(self, mean=None, std=None):
37
+ if mean is None:
38
+ mean = (0.48145466, 0.4578275, 0.40821073)
39
+ if std is None:
40
+ std = (0.26862954, 0.26130258, 0.27577711)
41
+ # mean = (0.0, 0.0, 0.0)
42
+ # std = (1.0, 1.0, 1.0)
43
+
44
+ self.normalize = transforms.Normalize(mean, std)
45
+
46
+
47
+ ## aug functions
48
+ def identity_func(img):
49
+ return img
50
+
51
+
52
+ def autocontrast_func(img, cutoff=0):
53
+ """
54
+ same output as PIL.ImageOps.autocontrast
55
+ """
56
+ n_bins = 256
57
+
58
+ def tune_channel(ch):
59
+ n = ch.size
60
+ cut = cutoff * n // 100
61
+ if cut == 0:
62
+ high, low = ch.max(), ch.min()
63
+ else:
64
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
65
+ low = np.argwhere(np.cumsum(hist) > cut)
66
+ low = 0 if low.shape[0] == 0 else low[0]
67
+ high = np.argwhere(np.cumsum(hist[::-1]) > cut)
68
+ high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
69
+ if high <= low:
70
+ table = np.arange(n_bins)
71
+ else:
72
+ scale = (n_bins - 1) / (high - low)
73
+ offset = -low * scale
74
+ table = np.arange(n_bins) * scale + offset
75
+ table[table < 0] = 0
76
+ table[table > n_bins - 1] = n_bins - 1
77
+ table = table.clip(0, 255).astype(np.uint8)
78
+ return table[ch]
79
+
80
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
81
+ out = cv2.merge(channels)
82
+ return out
83
+
84
+
85
+ def equalize_func(img):
86
+ """
87
+ same output as PIL.ImageOps.equalize
88
+ PIL's implementation is different from cv2.equalize
89
+ """
90
+ n_bins = 256
91
+
92
+ def tune_channel(ch):
93
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
94
+ non_zero_hist = hist[hist != 0].reshape(-1)
95
+ step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
96
+ if step == 0:
97
+ return ch
98
+ n = np.empty_like(hist)
99
+ n[0] = step // 2
100
+ n[1:] = hist[:-1]
101
+ table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
102
+ return table[ch]
103
+
104
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
105
+ out = cv2.merge(channels)
106
+ return out
107
+
108
+
109
+ def rotate_func(img, degree, fill=(0, 0, 0)):
110
+ """
111
+ like PIL, rotate by degree, not radians
112
+ """
113
+ H, W = img.shape[0], img.shape[1]
114
+ center = W / 2, H / 2
115
+ M = cv2.getRotationMatrix2D(center, degree, 1)
116
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
117
+ return out
118
+
119
+
120
+ def solarize_func(img, thresh=128):
121
+ """
122
+ same output as PIL.ImageOps.posterize
123
+ """
124
+ table = np.array([el if el < thresh else 255 - el for el in range(256)])
125
+ table = table.clip(0, 255).astype(np.uint8)
126
+ out = table[img]
127
+ return out
128
+
129
+
130
+ def color_func(img, factor):
131
+ """
132
+ same output as PIL.ImageEnhance.Color
133
+ """
134
+ ## implementation according to PIL definition, quite slow
135
+ # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
136
+ # out = blend(degenerate, img, factor)
137
+ # M = (
138
+ # np.eye(3) * factor
139
+ # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
140
+ # )[np.newaxis, np.newaxis, :]
141
+ M = np.float32(
142
+ [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
143
+ ) * factor + np.float32([[0.114], [0.587], [0.299]])
144
+ out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
145
+ return out
146
+
147
+
148
+ def contrast_func(img, factor):
149
+ """
150
+ same output as PIL.ImageEnhance.Contrast
151
+ """
152
+ mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
153
+ table = (
154
+ np.array([(el - mean) * factor + mean for el in range(256)])
155
+ .clip(0, 255)
156
+ .astype(np.uint8)
157
+ )
158
+ out = table[img]
159
+ return out
160
+
161
+
162
+ def brightness_func(img, factor):
163
+ """
164
+ same output as PIL.ImageEnhance.Contrast
165
+ """
166
+ table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
167
+ out = table[img]
168
+ return out
169
+
170
+
171
+ def sharpness_func(img, factor):
172
+ """
173
+ The differences the this result and PIL are all on the 4 boundaries, the center
174
+ areas are same
175
+ """
176
+ kernel = np.ones((3, 3), dtype=np.float32)
177
+ kernel[1][1] = 5
178
+ kernel /= 13
179
+ degenerate = cv2.filter2D(img, -1, kernel)
180
+ if factor == 0.0:
181
+ out = degenerate
182
+ elif factor == 1.0:
183
+ out = img
184
+ else:
185
+ out = img.astype(np.float32)
186
+ degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
187
+ out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
188
+ out = out.astype(np.uint8)
189
+ return out
190
+
191
+
192
+ def shear_x_func(img, factor, fill=(0, 0, 0)):
193
+ H, W = img.shape[0], img.shape[1]
194
+ M = np.float32([[1, factor, 0], [0, 1, 0]])
195
+ out = cv2.warpAffine(
196
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
197
+ ).astype(np.uint8)
198
+ return out
199
+
200
+
201
+ def translate_x_func(img, offset, fill=(0, 0, 0)):
202
+ """
203
+ same output as PIL.Image.transform
204
+ """
205
+ H, W = img.shape[0], img.shape[1]
206
+ M = np.float32([[1, 0, -offset], [0, 1, 0]])
207
+ out = cv2.warpAffine(
208
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
209
+ ).astype(np.uint8)
210
+ return out
211
+
212
+
213
+ def translate_y_func(img, offset, fill=(0, 0, 0)):
214
+ """
215
+ same output as PIL.Image.transform
216
+ """
217
+ H, W = img.shape[0], img.shape[1]
218
+ M = np.float32([[1, 0, 0], [0, 1, -offset]])
219
+ out = cv2.warpAffine(
220
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
221
+ ).astype(np.uint8)
222
+ return out
223
+
224
+
225
+ def posterize_func(img, bits):
226
+ """
227
+ same output as PIL.ImageOps.posterize
228
+ """
229
+ out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
230
+ return out
231
+
232
+
233
+ def shear_y_func(img, factor, fill=(0, 0, 0)):
234
+ H, W = img.shape[0], img.shape[1]
235
+ M = np.float32([[1, 0, 0], [factor, 1, 0]])
236
+ out = cv2.warpAffine(
237
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
238
+ ).astype(np.uint8)
239
+ return out
240
+
241
+
242
+ def cutout_func(img, pad_size, replace=(0, 0, 0)):
243
+ replace = np.array(replace, dtype=np.uint8)
244
+ H, W = img.shape[0], img.shape[1]
245
+ rh, rw = np.random.random(2)
246
+ pad_size = pad_size // 2
247
+ ch, cw = int(rh * H), int(rw * W)
248
+ x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
249
+ y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
250
+ out = img.copy()
251
+ out[x1:x2, y1:y2, :] = replace
252
+ return out
253
+
254
+
255
+ ### level to args
256
+ def enhance_level_to_args(MAX_LEVEL):
257
+ def level_to_args(level):
258
+ return ((level / MAX_LEVEL) * 1.8 + 0.1,)
259
+
260
+ return level_to_args
261
+
262
+
263
+ def shear_level_to_args(MAX_LEVEL, replace_value):
264
+ def level_to_args(level):
265
+ level = (level / MAX_LEVEL) * 0.3
266
+ if np.random.random() > 0.5:
267
+ level = -level
268
+ return (level, replace_value)
269
+
270
+ return level_to_args
271
+
272
+
273
+ def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
274
+ def level_to_args(level):
275
+ level = (level / MAX_LEVEL) * float(translate_const)
276
+ if np.random.random() > 0.5:
277
+ level = -level
278
+ return (level, replace_value)
279
+
280
+ return level_to_args
281
+
282
+
283
+ def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
284
+ def level_to_args(level):
285
+ level = int((level / MAX_LEVEL) * cutout_const)
286
+ return (level, replace_value)
287
+
288
+ return level_to_args
289
+
290
+
291
+ def solarize_level_to_args(MAX_LEVEL):
292
+ def level_to_args(level):
293
+ level = int((level / MAX_LEVEL) * 256)
294
+ return (level,)
295
+
296
+ return level_to_args
297
+
298
+
299
+ def none_level_to_args(level):
300
+ return ()
301
+
302
+
303
+ def posterize_level_to_args(MAX_LEVEL):
304
+ def level_to_args(level):
305
+ level = int((level / MAX_LEVEL) * 4)
306
+ return (level,)
307
+
308
+ return level_to_args
309
+
310
+
311
+ def rotate_level_to_args(MAX_LEVEL, replace_value):
312
+ def level_to_args(level):
313
+ level = (level / MAX_LEVEL) * 30
314
+ if np.random.random() < 0.5:
315
+ level = -level
316
+ return (level, replace_value)
317
+
318
+ return level_to_args
319
+
320
+
321
+ func_dict = {
322
+ "Identity": identity_func,
323
+ "AutoContrast": autocontrast_func,
324
+ "Equalize": equalize_func,
325
+ "Rotate": rotate_func,
326
+ "Solarize": solarize_func,
327
+ "Color": color_func,
328
+ "Contrast": contrast_func,
329
+ "Brightness": brightness_func,
330
+ "Sharpness": sharpness_func,
331
+ "ShearX": shear_x_func,
332
+ "TranslateX": translate_x_func,
333
+ "TranslateY": translate_y_func,
334
+ "Posterize": posterize_func,
335
+ "ShearY": shear_y_func,
336
+ }
337
+
338
+ translate_const = 10
339
+ MAX_LEVEL = 10
340
+ replace_value = (128, 128, 128)
341
+ arg_dict = {
342
+ "Identity": none_level_to_args,
343
+ "AutoContrast": none_level_to_args,
344
+ "Equalize": none_level_to_args,
345
+ "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
346
+ "Solarize": solarize_level_to_args(MAX_LEVEL),
347
+ "Color": enhance_level_to_args(MAX_LEVEL),
348
+ "Contrast": enhance_level_to_args(MAX_LEVEL),
349
+ "Brightness": enhance_level_to_args(MAX_LEVEL),
350
+ "Sharpness": enhance_level_to_args(MAX_LEVEL),
351
+ "ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
352
+ "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
353
+ "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
354
+ "Posterize": posterize_level_to_args(MAX_LEVEL),
355
+ "ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
356
+ }
357
+
358
+
359
+ class RandomAugment(object):
360
+ def __init__(self, N=2, M=10, isPIL=False, augs=[]):
361
+ self.N = N
362
+ self.M = M
363
+ self.isPIL = isPIL
364
+ if augs:
365
+ self.augs = augs
366
+ else:
367
+ self.augs = list(arg_dict.keys())
368
+
369
+ def get_random_ops(self):
370
+ sampled_ops = np.random.choice(self.augs, self.N)
371
+ return [(op, 0.5, self.M) for op in sampled_ops]
372
+
373
+ def __call__(self, img):
374
+ if self.isPIL:
375
+ img = np.array(img)
376
+ ops = self.get_random_ops()
377
+ for name, prob, level in ops:
378
+ if np.random.random() > prob:
379
+ continue
380
+ args = arg_dict[name](level)
381
+ img = func_dict[name](img, *args)
382
+ return img
383
+
384
+
385
+ class VideoRandomAugment(object):
386
+ def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):
387
+ self.N = N
388
+ self.M = M
389
+ self.p = p
390
+ self.tensor_in_tensor_out = tensor_in_tensor_out
391
+ if augs:
392
+ self.augs = augs
393
+ else:
394
+ self.augs = list(arg_dict.keys())
395
+
396
+ def get_random_ops(self):
397
+ sampled_ops = np.random.choice(self.augs, self.N, replace=False)
398
+ return [(op, self.M) for op in sampled_ops]
399
+
400
+ def __call__(self, frames):
401
+ assert (
402
+ frames.shape[-1] == 3
403
+ ), "Expecting last dimension for 3-channels RGB (b, h, w, c)."
404
+
405
+ if self.tensor_in_tensor_out:
406
+ frames = frames.numpy().astype(np.uint8)
407
+
408
+ num_frames = frames.shape[0]
409
+
410
+ ops = num_frames * [self.get_random_ops()]
411
+ apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]
412
+
413
+ frames = torch.stack(
414
+ list(map(self._aug, frames, ops, apply_or_not)), dim=0
415
+ ).float()
416
+
417
+ return frames
418
+
419
+ def _aug(self, img, ops, apply_or_not):
420
+ for i, (name, level) in enumerate(ops):
421
+ if not apply_or_not[i]:
422
+ continue
423
+ args = arg_dict[name](level)
424
+ img = func_dict[name](img, *args)
425
+ return torch.from_numpy(img)
426
+
427
+
428
+ # if __name__ == "__main__":
429
+ # a = RandomAugment()
430
+ # img = np.random.randn(32, 32, 3)
431
+ # a(img)
432
+
433
+
434
+
435
+
436
+
437
+
438
+ class BlipImageTrainProcessor(BlipImageBaseProcessor):
439
+ def __init__(
440
+ self, image_size=384, mean=None, std=None, min_scale=0.5, max_scale=1.0
441
+ ):
442
+ super().__init__(mean=mean, std=std)
443
+
444
+ self.transform = transforms.Compose(
445
+ [
446
+ transforms.RandomResizedCrop(
447
+ image_size,
448
+ scale=(min_scale, max_scale),
449
+ interpolation=InterpolationMode.BICUBIC,
450
+ ),
451
+ # transforms.RandomHorizontalFlip(),
452
+ RandomAugment(
453
+ 2,
454
+ 5,
455
+ isPIL=True,
456
+ augs=[
457
+ "Identity",
458
+ # "AutoContrast",
459
+ "Brightness",
460
+ "Sharpness",
461
+ "Equalize",
462
+ # "ShearX",
463
+ # "ShearY",
464
+ # "TranslateX",
465
+ # "TranslateY",
466
+ # "Rotate",
467
+ ],
468
+ ),
469
+ transforms.ToTensor(),
470
+ self.normalize,
471
+ ]
472
+ )
473
+
474
+ def __call__(self, item):
475
+ return self.transform(item)
476
+
477
+
478
+ class BlipImageEvalProcessor(BlipImageBaseProcessor):
479
+ def __init__(self, image_size=384, mean=None, std=None):
480
+ super().__init__(mean=mean, std=std)
481
+
482
+ self.transform = transforms.Compose(
483
+ [
484
+ transforms.Resize(
485
+ (image_size, image_size), interpolation=InterpolationMode.BICUBIC
486
+ ),
487
+ transforms.ToTensor(),
488
+ self.normalize,
489
+ ]
490
+ )
491
+
492
+ def __call__(self, item):
493
+ return self.transform(item)
494
+
495
+
496
+ # if __name__ == "__main__":
497
+ # a = BlipImageTrainProcessor(image_size=1024)
498
+ # # img = np.random.randn(1024, 1024, 3)
499
+ # # x = torch.zeros(1024, 1024, 3)
500
+ # x = Image.open("/data/codes/GOT-main/log/serve_images/2023-05-23/a2a783d89ede819cdeae943a2199ad3d.jpg").convert("RGB")
501
+ # print(x.size)
502
+ # y = a(x)
503
+
504
+ # print(y.size())
config.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "ucaslcl/GOT-OCR2_0",
3
+ "architectures": [
4
+ "GOTQwenForCausalLM"
5
+ ],
6
+ "auto_map": {
7
+ "AutoModel": "GOT_ocr_2_0.GOTQwenForCausalLM"
8
+ },
9
+ "attention_dropout": 0.0,
10
+ "bos_token_id": 151643,
11
+ "eos_token_id": 151643,
12
+ "freeze_vision_tower": false,
13
+ "hidden_act": "silu",
14
+ "hidden_size": 1024,
15
+ "im_end_token": 151858,
16
+ "im_patch_token": 151859,
17
+ "im_start_token": 151857,
18
+ "image_token_len": 256,
19
+ "initializer_range": 0.02,
20
+ "intermediate_size": 2816,
21
+ "max_position_embeddings": 32768,
22
+ "max_window_layers": 21,
23
+ "model_type": "mmgpt",
24
+ "num_attention_heads": 16,
25
+ "num_hidden_layers": 24,
26
+ "num_key_value_heads": 16,
27
+ "rms_norm_eps": 1e-06,
28
+ "rope_theta": 1000000.0,
29
+ "sliding_window": 32768,
30
+ "tie_word_embeddings": true,
31
+ "torch_dtype": "bfloat16",
32
+ "transformers_version": "4.37.2",
33
+ "use_cache": true,
34
+ "use_im_start_end": true,
35
+ "use_sliding_window": false,
36
+ "vocab_size": 151860
37
+ }
constants.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
2
+ WORKER_HEART_BEAT_INTERVAL = 15
3
+
4
+ LOGDIR = "log"
5
+
6
+ IGNORE_INDEX = -100
7
+ # DEFAULT_PAD_TOKEN = "[PAD]"
8
+
9
+ DEFAULT_PAD_TOKEN = "<|endoftext|>"
10
+ DEFAULT_EOS_TOKEN = "</s>"
11
+ DEFAULT_BOS_TOKEN = "</s>"
12
+ DEFAULT_UNK_TOKEN = "<unk>"
13
+ DEFAULT_IMAGE_TOKEN = "<image>"
14
+ DEFAULT_BOX_TOKEN = "<box>"
15
+
16
+ DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
17
+
18
+ DEFAULT_IM_START_TOKEN = '<img>'
19
+ DEFAULT_IM_END_TOKEN = '</img>'
20
+
21
+
22
+
23
+ CONVERSATION_DATA = {
24
+
25
+ 'data_1': {
26
+ 'images': '/path/',
27
+ 'annotations': '/path/data1.json',
28
+ },
29
+ 'data_2': {
30
+ 'images': '/path/',
31
+ 'annotations': '/path/data2.json',
32
+ },
33
+ 'data_3': {
34
+ 'images': '/path/',
35
+ 'annotations': '/path/data3.json',
36
+ },
37
+
38
+
39
+ }
conversation.py ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+
5
+
6
+ class SeparatorStyle(Enum):
7
+ """Different separator style."""
8
+ SINGLE = auto()
9
+ TWO = auto()
10
+ MPT = auto()
11
+
12
+
13
+
14
+ # simple_conv_multimodal = Conversation(
15
+ # system="You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology."
16
+ # "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
17
+ # "Follow the instructions carefully and explain your answers in detail.",
18
+ # # system="",
19
+ # roles=("Human", "Assistant"),
20
+ # messages=(
21
+ # ("Human", "Hi!"),
22
+ # ("Assistant", "Hi there! How can I help you today?\n")
23
+ # ),
24
+ # offset=2,
25
+ # sep_style=SeparatorStyle.SINGLE,
26
+ # sep="###",
27
+ # )
28
+
29
+ # conv_mpt = Conversation(
30
+ # system="""<|im_start|>system
31
+ # - You are a helpful language and vision assistant.
32
+ # - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
33
+ # - You should follow the instructions carefully and explain your answers in detail.""",
34
+ # roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
35
+ # version="mpt",
36
+ # messages=(),
37
+ # offset=0,
38
+ # sep_style=SeparatorStyle.MPT,
39
+ # sep="<|im_end|>",
40
+ # )
41
+
42
+ @dataclasses.dataclass
43
+ class Conversation:
44
+ """A class that keeps all conversation history."""
45
+ system: str
46
+ roles: List[str]
47
+ messages: List[List[str]]
48
+ offset: int
49
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
50
+ sep: str = "<|im_end|>"
51
+ sep2: str = None
52
+ version: str = "Unknown"
53
+
54
+ skip_next: bool = False
55
+
56
+ def get_prompt(self):
57
+ if self.sep_style == SeparatorStyle.SINGLE:
58
+ ret = self.system + self.sep + '\n'
59
+ for role, message in self.messages:
60
+ if message:
61
+ if type(message) is tuple:
62
+ message, _, _ = message
63
+ ret += role + ": " + message + self.sep
64
+ else:
65
+ ret += role + ":"
66
+ return ret
67
+ elif self.sep_style == SeparatorStyle.TWO:
68
+ seps = [self.sep, self.sep2]
69
+ ret = self.system + seps[0]
70
+ for i, (role, message) in enumerate(self.messages):
71
+ if message:
72
+ if type(message) is tuple:
73
+ message, _, _ = message
74
+ ret += role + ": " + message + seps[i % 2]
75
+ else:
76
+ ret += role + ":"
77
+ return ret
78
+ if self.sep_style == SeparatorStyle.MPT:
79
+ if self.system:
80
+ ret = self.system + self.sep
81
+ else:
82
+ ret = ''
83
+ for role, message in self.messages:
84
+ if message:
85
+ if type(message) is tuple:
86
+ message, _, _ = message
87
+ ret += role + message + self.sep
88
+ else:
89
+ ret += role
90
+ return ret
91
+ else:
92
+ raise ValueError(f"Invalid style: {self.sep_style}")
93
+ # if self.sep_style == SeparatorStyle.MPT:
94
+ # if self.system:
95
+ # ret = self.system + self.sep
96
+ # else:
97
+ # ret = ''
98
+ # for role, message in self.messages:
99
+ # if message:
100
+ # if type(message) is tuple:
101
+ # message, _, _ = message
102
+ # ret += role + message + self.sep
103
+ # # if 'user' in role:
104
+ # # ret += role + message + self.sep + "\n"
105
+ # # else:
106
+ # # ret += role + message + self.sep
107
+ # else:
108
+ # ret += role
109
+ # return ret
110
+ # else:
111
+ # raise ValueError(f"Invalid style: {self.sep_style}")
112
+
113
+ def append_message(self, role, message):
114
+ self.messages.append([role, message])
115
+
116
+ def get_images(self, return_pil=False):
117
+ images = []
118
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
119
+ if i % 2 == 0:
120
+ if type(msg) is tuple:
121
+ import base64
122
+ from io import BytesIO
123
+ from PIL import Image
124
+ msg, image, image_process_mode = msg
125
+ if image_process_mode == "Pad":
126
+ def expand2square(pil_img, background_color=(122, 116, 104)):
127
+ width, height = pil_img.size
128
+ if width == height:
129
+ return pil_img
130
+ elif width > height:
131
+ result = Image.new(pil_img.mode, (width, width), background_color)
132
+ # result.paste(pil_img, (0, (width - height) // 2))
133
+ result.paste(pil_img)
134
+ return result
135
+ else:
136
+ result = Image.new(pil_img.mode, (height, height), background_color)
137
+ # result.paste(pil_img, ((height - width) // 2, 0))
138
+ result.paste(pil_img)
139
+ return result
140
+ image = expand2square(image)
141
+ elif image_process_mode == "Crop":
142
+ max_hw, min_hw = max(image.size), min(image.size)
143
+ aspect_ratio = max_hw / min_hw
144
+ max_len, min_len = 800, 400
145
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
146
+ longest_edge = int(shortest_edge * aspect_ratio)
147
+ W, H = image.size
148
+ if H > W:
149
+ H, W = longest_edge, shortest_edge
150
+ else:
151
+ H, W = shortest_edge, longest_edge
152
+ image = image.resize((W, H))
153
+ elif image_process_mode == "Resize":
154
+ image = image.resize((224, 224))
155
+ else:
156
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
157
+
158
+ if return_pil:
159
+ images.append(image)
160
+ else:
161
+ buffered = BytesIO()
162
+ image.convert('RGB').save(buffered, format="JPEG")
163
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
164
+ images.append(img_b64_str)
165
+ return images
166
+
167
+ def to_gradio_chatbot(self):
168
+ ret = []
169
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
170
+ if i % 2 == 0:
171
+ if type(msg) is tuple:
172
+ import base64
173
+ from io import BytesIO
174
+ msg, image, image_process_mode = msg
175
+ max_hw, min_hw = max(image.size), min(image.size)
176
+ aspect_ratio = max_hw / min_hw
177
+ max_len, min_len = 800, 400
178
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
179
+ longest_edge = int(shortest_edge * aspect_ratio)
180
+ W, H = image.size
181
+ if H > W:
182
+ H, W = longest_edge, shortest_edge
183
+ else:
184
+ H, W = shortest_edge, longest_edge
185
+ image = image.resize((W, H))
186
+ # image = image.resize((224, 224))
187
+ buffered = BytesIO()
188
+ image.save(buffered, format="JPEG")
189
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
190
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
191
+ msg = msg.replace('<image>', img_str)
192
+ ret.append([msg, None])
193
+ else:
194
+ ret[-1][-1] = msg
195
+ return ret
196
+
197
+ def copy(self):
198
+ return Conversation(
199
+ system=self.system,
200
+ roles=self.roles,
201
+ messages=[[x, y] for x, y in self.messages],
202
+ offset=self.offset,
203
+ sep_style=self.sep_style,
204
+ sep=self.sep,
205
+ sep2=self.sep2)
206
+
207
+ def dict(self):
208
+ if len(self.get_images()) > 0:
209
+ return {
210
+ "system": self.system,
211
+ "roles": self.roles,
212
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
213
+ "offset": self.offset,
214
+ "sep": self.sep,
215
+ "sep2": self.sep2,
216
+ }
217
+ return {
218
+ "system": self.system,
219
+ "roles": self.roles,
220
+ "messages": self.messages,
221
+ "offset": self.offset,
222
+ "sep": self.sep,
223
+ "sep2": self.sep2,
224
+ }
225
+
226
+
227
+ conv_v1 = Conversation(
228
+ system="A chat between a curious human and an artificial intelligence assistant. "
229
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
230
+ roles=("Human", "Assistant"),
231
+ messages=(
232
+ ("Human", "Give three tips for staying healthy."),
233
+ ("Assistant",
234
+ "Sure, here are three tips for staying healthy:\n"
235
+ "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
236
+ "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
237
+ "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
238
+ "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
239
+ "activities at least two days per week.\n"
240
+ "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
241
+ "vegetables, whole grains, lean proteins, and healthy fats can help support "
242
+ "your overall health. Try to limit your intake of processed and high-sugar foods, "
243
+ "and aim to drink plenty of water throughout the day.\n"
244
+ "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
245
+ "and mental health. Adults should aim for seven to nine hours of sleep per night. "
246
+ "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
247
+ "help improve the quality of your sleep.")
248
+ ),
249
+ offset=2,
250
+ sep_style=SeparatorStyle.SINGLE,
251
+ sep="###",
252
+ )
253
+
254
+ conv_v1_2 = Conversation(
255
+ system="A chat between a curious human and an artificial intelligence assistant. "
256
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
257
+ roles=("Human", "Assistant"),
258
+ messages=(
259
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
260
+ ("Assistant",
261
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
262
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
263
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
264
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
265
+ "renewable and non-renewable energy sources:\n"
266
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
267
+ "energy sources are finite and will eventually run out.\n"
268
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
269
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
270
+ "and other negative effects.\n"
271
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
272
+ "have lower operational costs than non-renewable sources.\n"
273
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
274
+ "locations than non-renewable sources.\n"
275
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
276
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
277
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
278
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
279
+ ),
280
+ offset=2,
281
+ sep_style=SeparatorStyle.SINGLE,
282
+ sep="###",
283
+ )
284
+
285
+ conv_vicuna_v1_1 = Conversation(
286
+ system="A chat between a curious user and an artificial intelligence assistant. "
287
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
288
+ roles=("USER", "ASSISTANT"),
289
+ version="v1",
290
+ messages=(),
291
+ offset=0,
292
+ sep_style=SeparatorStyle.TWO,
293
+ sep=" ",
294
+ sep2="</s>",
295
+ )
296
+
297
+ # conv_mpt = Conversation(
298
+ # system="""<|im_start|>system
299
+ # - You are designed by Megvii(旷视), and your name is GOT.
300
+ # - 你叫GOT, 你来自旷视, 你是旷视开发的。
301
+ # - 你擅长分析表格,仔细读图表中的内容,然后给出你的答案。""",
302
+ # roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
303
+ # version="mpt",
304
+ # messages=(),
305
+ # offset=0,
306
+ # sep_style=SeparatorStyle.MPT,
307
+ # sep="<|im_end|>",
308
+ # )
309
+
310
+ conv_mpt = Conversation(
311
+ system="""<|im_start|>system
312
+ You should follow the instructions carefully and explain your answers in detail.""",
313
+ # system = None,
314
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
315
+ version="mpt",
316
+ messages=(),
317
+ offset=0,
318
+ sep_style=SeparatorStyle.MPT,
319
+ sep="<|im_end|>",
320
+ )
321
+
322
+ conv_mpt_eval = Conversation(
323
+ system="",
324
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
325
+ version="mpt",
326
+ messages=(),
327
+ offset=0,
328
+ sep_style=SeparatorStyle.MPT,
329
+ sep="<|im_end|>",
330
+ )
331
+
332
+ conv_mpt_text = Conversation(
333
+ system="""<|im_start|>system
334
+ - You are a helpful assistant chatbot trained by MosaicML.
335
+ - You answer questions.
336
+ - You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
337
+ - You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
338
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
339
+ version="mpt",
340
+ messages=(),
341
+ offset=0,
342
+ sep_style=SeparatorStyle.MPT,
343
+ sep="<|im_end|>",
344
+ )
345
+
346
+ conv_bair_v1 = Conversation(
347
+ system="BEGINNING OF CONVERSATION:",
348
+ roles=("USER", "GPT"),
349
+ messages=(),
350
+ offset=0,
351
+ sep_style=SeparatorStyle.TWO,
352
+ sep=" ",
353
+ sep2="</s>",
354
+ )
355
+
356
+ # simple_conv = Conversation(
357
+ # system="You are GOT, a large language model trained by Foundation Model Group, Megvii Technology, based on LLaMA architecture."
358
+ # "You are designed to assist human with a variety of tasks using natural language."
359
+ # "Follow the instructions carefully.",
360
+ # roles=("Human", "Assistant"),
361
+ # messages=(
362
+ # ("Human", "Hi!"),
363
+ # ("Assistant", "Hi there! How can I help you today?\n")
364
+ # ),
365
+ # offset=2,
366
+ # sep_style=SeparatorStyle.SINGLE,
367
+ # sep="###",
368
+ # )
369
+
370
+
371
+ simple_conv = Conversation(
372
+ system="",
373
+ roles=("Human", "Assistant"),
374
+ messages=(
375
+ ),
376
+ offset=0,
377
+ sep_style=SeparatorStyle.SINGLE,
378
+ sep="###",
379
+ )
380
+
381
+ simple_conv_multimodal = Conversation(
382
+ system="You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology."
383
+ "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
384
+ "Follow the instructions carefully and explain your answers in detail.",
385
+ # system="",
386
+ roles=("Human", "Assistant"),
387
+ messages=(
388
+ ("Human", "Hi!"),
389
+ ("Assistant", "Hi there! How can I help you today?\n")
390
+ ),
391
+ offset=2,
392
+ sep_style=SeparatorStyle.SINGLE,
393
+ sep="###",
394
+ )
395
+
396
+ simple_conv_mpt_multimodal = Conversation(
397
+ system="""<|im_start|>system
398
+ - You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology.
399
+ - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
400
+ - You should follow the instructions carefully and explain your answers in detail.""",
401
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
402
+ version="mpt",
403
+ messages=(),
404
+ offset=0,
405
+ sep_style=SeparatorStyle.MPT,
406
+ sep="<|im_end|>",
407
+ )
408
+
409
+ simple_conv_legacy = Conversation(
410
+ system="You are GOT, a large language model trained by Foundation Model Group, Megvii Technology."
411
+ "You are designed to assist human with a variety of tasks using natural language."
412
+ "Follow the instructions carefully.",
413
+ roles=("Human", "Assistant"),
414
+ messages=(
415
+ ("Human", "Hi!\n\n### Response:"),
416
+ ("Assistant", "Hi there! How can I help you today?\n")
417
+ ),
418
+ offset=2,
419
+ sep_style=SeparatorStyle.SINGLE,
420
+ sep="###",
421
+ )
422
+
423
+ conv_llava_v1 = Conversation(
424
+ system="You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology."
425
+ "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
426
+ "Follow the instructions carefully and explain your answers in detail.",
427
+ roles=("USER", "ASSISTANT"),
428
+ version="v1",
429
+ messages=(),
430
+ offset=0,
431
+ sep_style=SeparatorStyle.TWO,
432
+ sep=" ",
433
+ sep2="</s>",
434
+ )
435
+
436
+ default_conversation = conv_mpt
437
+ conv_templates = {
438
+ "default": simple_conv_multimodal,
439
+ "simple": simple_conv,
440
+ "simple_legacy": simple_conv_legacy,
441
+ "multimodal": simple_conv,
442
+ "mpt_multimodal": simple_conv_mpt_multimodal,
443
+ "llava_v1": conv_llava_v1,
444
+ "mpt_eval": conv_mpt_eval,
445
+ # fastchat
446
+ "v1": conv_vicuna_v1_1,
447
+ "bair_v1": conv_bair_v1,
448
+ "vicuna_v1_1": conv_vicuna_v1_1,
449
+ "mpt": conv_mpt,
450
+ "mpt_text": conv_mpt_text,
451
+ }
452
+
453
+
454
+ if __name__ == "__main__":
455
+ print(default_conversation.get_prompt())
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 151643,
3
+ "eos_token_id": 151643,
4
+ "max_new_tokens": 2048,
5
+ "transformers_version": "4.37.2"
6
+ }
qwen.tiktoken ADDED
The diff for this file is too large to render. See raw diff
 
run_ocr.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import torch
4
+ import os
5
+ from .conversation import conv_templates, SeparatorStyle
6
+ from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
7
+ from .utils import KeywordsStoppingCriteria, disable_torch_init
8
+
9
+ from PIL import Image
10
+
11
+ import os
12
+ import requests
13
+ from PIL import Image
14
+ from io import BytesIO
15
+ from .blip_process import BlipImageEvalProcessor
16
+
17
+ from .GOT_ocr_2_0 import GOTQwenModel, GOTQwenForCausalLM, GOTConfig
18
+
19
+ from transformers import TextStreamer
20
+ import re
21
+ import string
22
+
23
+
24
+ import string
25
+
26
+ punctuation_dict = {
27
+ ",": ",",
28
+ "。": ".",
29
+ }
30
+
31
+
32
+ def svg_to_html(svg_content, output_filename):
33
+
34
+ html_content = f"""
35
+ <!DOCTYPE html>
36
+ <html lang="en">
37
+ <head>
38
+ <meta charset="UTF-8">
39
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
40
+ <title>SVG Embedded in HTML</title>
41
+ </head>
42
+ <body>
43
+ <svg width="2100" height="15000" xmlns="http://www.w3.org/2000/svg">
44
+ {svg_content}
45
+ </svg>
46
+ </body>
47
+ </html>
48
+ """
49
+
50
+ with open(output_filename, 'w') as file:
51
+ file.write(html_content)
52
+
53
+
54
+
55
+ DEFAULT_IMAGE_TOKEN = "<image>"
56
+ DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
57
+
58
+ DEFAULT_IM_START_TOKEN = '<img>'
59
+ DEFAULT_IM_END_TOKEN = '</img>'
60
+
61
+
62
+
63
+ translation_table = str.maketrans(punctuation_dict)
64
+
65
+
66
+ def load_image(image_file):
67
+ if image_file.startswith('http') or image_file.startswith('https'):
68
+ response = requests.get(image_file)
69
+ image = Image.open(BytesIO(response.content)).convert('RGB')
70
+ else:
71
+ image = Image.open(image_file).convert('RGB')
72
+ return image
73
+
74
+
75
+ def eval_model(model_name, image_file, ocr_type, ocr_box='', ocr_color='', render=False):
76
+ # Model
77
+ disable_torch_init()
78
+ # model_name = os.path.expanduser(args.model_name)
79
+
80
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
81
+
82
+
83
+ model = GOTQwenForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=151643).eval()
84
+
85
+
86
+
87
+ model.to(device='cuda', dtype=torch.bfloat16)
88
+
89
+
90
+ # TODO vary old codes, NEED del
91
+ image_processor = BlipImageEvalProcessor(image_size=1024)
92
+
93
+ image_processor_high = BlipImageEvalProcessor(image_size=1024)
94
+
95
+ use_im_start_end = True
96
+
97
+ image_token_len = 256
98
+
99
+ image = load_image(image_file)
100
+
101
+ w, h = image.size
102
+ # print(image.size)
103
+
104
+ if ocr_type == 'format':
105
+ qs = 'OCR with format: '
106
+ else:
107
+ qs = 'OCR: '
108
+
109
+ if ocr_box:
110
+ bbox = eval(ocr_box)
111
+ if len(bbox) == 2:
112
+ bbox[0] = int(bbox[0]/w*1000)
113
+ bbox[1] = int(bbox[1]/h*1000)
114
+ if len(bbox) == 4:
115
+ bbox[0] = int(bbox[0]/w*1000)
116
+ bbox[1] = int(bbox[1]/h*1000)
117
+ bbox[2] = int(bbox[2]/w*1000)
118
+ bbox[3] = int(bbox[3]/h*1000)
119
+ if ocr_type == 'format':
120
+ qs = str(bbox) + ' ' + 'OCR with format: '
121
+ else:
122
+ qs = str(bbox) + ' ' + 'OCR: '
123
+
124
+ if ocr_color:
125
+ if ocr_type == 'format':
126
+ qs = '[' + ocr_color + ']' + ' ' + 'OCR with format: '
127
+ else:
128
+ qs = '[' + ocr_color + ']' + ' ' + 'OCR: '
129
+
130
+ if use_im_start_end:
131
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
132
+ else:
133
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
134
+
135
+
136
+
137
+ conv_mode = "mpt"
138
+ args.conv_mode = conv_mode
139
+
140
+ conv = conv_templates[args.conv_mode].copy()
141
+ conv.append_message(conv.roles[0], qs)
142
+ conv.append_message(conv.roles[1], None)
143
+ prompt = conv.get_prompt()
144
+
145
+ print(prompt)
146
+
147
+
148
+ inputs = tokenizer([prompt])
149
+
150
+
151
+ # vary old codes, no use
152
+ image_1 = image.copy()
153
+ image_tensor = image_processor(image)
154
+
155
+
156
+ image_tensor_1 = image_processor_high(image_1)
157
+
158
+
159
+ input_ids = torch.as_tensor(inputs.input_ids).cuda()
160
+
161
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
162
+ keywords = [stop_str]
163
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
164
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
165
+
166
+
167
+ with torch.autocast("cuda", dtype=torch.bfloat16):
168
+ output_ids = model.generate(
169
+ input_ids,
170
+ images=[(image_tensor.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).half().cuda())],
171
+ do_sample=False,
172
+ num_beams = 1,
173
+ no_repeat_ngram_size = 20,
174
+ streamer=streamer,
175
+ max_new_tokens=4096,
176
+ stopping_criteria=[stopping_criteria]
177
+ )
178
+
179
+
180
+ if render:
181
+ print('==============rendering===============')
182
+
183
+ outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
184
+
185
+ if outputs.endswith(stop_str):
186
+ outputs = outputs[:-len(stop_str)]
187
+ outputs = outputs.strip()
188
+
189
+ if '**kern' in outputs:
190
+ import verovio
191
+ from cairosvg import svg2png
192
+ import cv2
193
+ import numpy as np
194
+ tk = verovio.toolkit()
195
+ tk.loadData(outputs)
196
+ tk.setOptions({"pageWidth": 2100, "footer": 'none',
197
+ 'barLineWidth': 0.5, 'beamMaxSlope': 15,
198
+ 'staffLineWidth': 0.2, 'spacingStaff': 6})
199
+ tk.getPageCount()
200
+ svg = tk.renderToSVG()
201
+ svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
202
+
203
+ svg_to_html(svg, "./results/demo.html")
204
+
205
+ if ocr_type == 'format' and '**kern' not in outputs:
206
+
207
+
208
+ if '\\begin{tikzpicture}' not in outputs:
209
+ html_path = "./render_tools/" + "/content-mmd-to-html.html"
210
+ html_path_2 = "./results/demo.html"
211
+ right_num = outputs.count('\\right')
212
+ left_num = outputs.count('\left')
213
+
214
+ if right_num != left_num:
215
+ outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
216
+
217
+
218
+ outputs = outputs.replace('"', '``').replace('$', '')
219
+
220
+ outputs_list = outputs.split('\n')
221
+ gt= ''
222
+ for out in outputs_list:
223
+ gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
224
+
225
+ gt = gt[:-2]
226
+
227
+ with open(html_path, 'r') as web_f:
228
+ lines = web_f.read()
229
+ lines = lines.split("const text =")
230
+ new_web = lines[0] + 'const text =' + gt + lines[1]
231
+ else:
232
+ html_path = "./render_tools/" + "/tikz.html"
233
+ html_path_2 = "./results/demo.html"
234
+ outputs = outputs.translate(translation_table)
235
+ outputs_list = outputs.split('\n')
236
+ gt= ''
237
+ for out in outputs_list:
238
+ if out:
239
+ if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
240
+ while out[-1] == ' ':
241
+ out = out[:-1]
242
+ if out is None:
243
+ break
244
+
245
+ if out:
246
+ if out[-1] != ';':
247
+ gt += out[:-1] + ';\n'
248
+ else:
249
+ gt += out + '\n'
250
+ else:
251
+ gt += out + '\n'
252
+
253
+
254
+ with open(html_path, 'r') as web_f:
255
+ lines = web_f.read()
256
+ lines = lines.split("const text =")
257
+ new_web = lines[0] + gt + lines[1]
258
+
259
+ with open(html_path_2, 'w') as web_f_new:
260
+ web_f_new.write(new_web)
261
+
262
+
263
+
264
+
265
+
266
+ if __name__ == "__main__":
267
+ parser = argparse.ArgumentParser()
268
+ parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
269
+ parser.add_argument("--image-file", type=str, required=True)
270
+ parser.add_argument("--type", type=str, required=True)
271
+ parser.add_argument("--box", type=str, default= '')
272
+ parser.add_argument("--color", type=str, default= '')
273
+ parser.add_argument("--render", action='store_true')
274
+ args = parser.parse_args()
275
+
276
+ eval_model(args)
special_tokens_map.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "pad_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ }
9
+ }
tokenization_qwen.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """Tokenization classes for QWen."""
7
+
8
+ import base64
9
+ import logging
10
+ import os
11
+ import unicodedata
12
+ from typing import Collection, Dict, List, Set, Tuple, Union
13
+
14
+ import tiktoken
15
+ from transformers import PreTrainedTokenizer, AddedToken
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"}
21
+
22
+ PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
23
+ ENDOFTEXT = "<|endoftext|>"
24
+ IMSTART = "<|im_start|>"
25
+ IMEND = "<|im_end|>"
26
+ # as the default behavior is changed to allow special tokens in
27
+ # regular texts, the surface forms of special tokens need to be
28
+ # as different as possible to minimize the impact
29
+ EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
30
+ SPECIAL_TOKENS = (
31
+ ENDOFTEXT,
32
+ IMSTART,
33
+ IMEND,
34
+ ) + EXTRAS
35
+
36
+
37
+ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
38
+ with open(tiktoken_bpe_file, "rb") as f:
39
+ contents = f.read()
40
+ return {
41
+ base64.b64decode(token): int(rank)
42
+ for token, rank in (line.split() for line in contents.splitlines() if line)
43
+ }
44
+
45
+ class QWenTokenizer(PreTrainedTokenizer):
46
+ """QWen tokenizer."""
47
+
48
+ vocab_files_names = VOCAB_FILES_NAMES
49
+
50
+ def __init__(
51
+ self,
52
+ vocab_file,
53
+ errors="replace",
54
+ image_start_tag='<img>',
55
+ image_end_tag='</img>',
56
+ image_pad_tag='<imgpad>',
57
+ ref_start_tag='<ref>',
58
+ ref_end_tag='</ref>',
59
+ box_start_tag='<box>',
60
+ box_end_tag='</box>',
61
+ quad_start_tag='<quad>',
62
+ quad_end_tag='</quad>',
63
+ **kwargs,
64
+ ):
65
+ super().__init__(**kwargs)
66
+
67
+ self.image_start_tag = image_start_tag
68
+ self.image_end_tag = image_end_tag
69
+ self.image_pad_tag = image_pad_tag
70
+ self.ref_start_tag = ref_start_tag
71
+ self.ref_end_tag = ref_end_tag
72
+ self.box_start_tag = box_start_tag
73
+ self.box_end_tag = box_end_tag
74
+ self.quad_start_tag = quad_start_tag
75
+ self.quad_end_tag = quad_end_tag
76
+ self.IMAGE_ST = (
77
+ ref_start_tag, ref_end_tag,
78
+ box_start_tag, box_end_tag,
79
+ quad_start_tag, quad_end_tag,
80
+ image_start_tag, image_end_tag,
81
+ image_pad_tag
82
+ )
83
+
84
+ self.errors = errors # how to handle errors in decoding
85
+
86
+ self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: dict[bytes, int]
87
+ self.special_tokens = {
88
+ token: index
89
+ for index, token in enumerate(
90
+ SPECIAL_TOKENS + self.IMAGE_ST, start=len(self.mergeable_ranks)
91
+ )
92
+ }
93
+
94
+ self.img_start_id = self.special_tokens[self.image_start_tag]
95
+ self.img_end_id = self.special_tokens[self.image_end_tag]
96
+ self.img_pad_id = self.special_tokens[self.image_pad_tag]
97
+ self.ref_start_id = self.special_tokens[self.ref_start_tag]
98
+ self.ref_end_id = self.special_tokens[self.ref_end_tag]
99
+ self.box_start_id = self.special_tokens[self.box_start_tag]
100
+ self.box_end_id = self.special_tokens[self.box_end_tag]
101
+ self.quad_start_id = self.special_tokens[self.quad_start_tag]
102
+ self.quad_end_id = self.special_tokens[self.quad_end_tag]
103
+
104
+ enc = tiktoken.Encoding(
105
+ "Qwen",
106
+ pat_str=PAT_STR,
107
+ mergeable_ranks=self.mergeable_ranks,
108
+ special_tokens=self.special_tokens,
109
+ )
110
+ assert (
111
+ len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
112
+ ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
113
+
114
+ self.decoder = {
115
+ v: k for k, v in self.mergeable_ranks.items()
116
+ } # type: dict[int, bytes|str]
117
+ self.decoder.update({v: k for k, v in self.special_tokens.items()})
118
+
119
+ self.tokenizer = enc # type: tiktoken.Encoding
120
+
121
+ self.eod_id = self.tokenizer.eot_token
122
+ self.im_start_id = self.special_tokens[IMSTART]
123
+ self.im_end_id = self.special_tokens[IMEND]
124
+
125
+ def __len__(self) -> int:
126
+ return self.tokenizer.n_vocab
127
+
128
+ def get_vocab(self) -> Dict[bytes, int]:
129
+ return self.mergeable_ranks
130
+
131
+ def convert_tokens_to_ids(
132
+ self, tokens: Union[bytes, str, List[Union[bytes, str]]]
133
+ ) -> List[int]:
134
+ ids = []
135
+ if isinstance(tokens, (str, bytes)):
136
+ if tokens in self.special_tokens:
137
+ return self.special_tokens[tokens]
138
+ else:
139
+ return self.mergeable_ranks.get(tokens)
140
+ for token in tokens:
141
+ if token in self.special_tokens:
142
+ ids.append(self.special_tokens[token])
143
+ else:
144
+ ids.append(self.mergeable_ranks.get(token))
145
+ return ids
146
+
147
+ def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
148
+ if not special_tokens and new_tokens:
149
+ raise ValueError('Adding regular tokens is not supported')
150
+ for token in new_tokens:
151
+ surface_form = token.content if isinstance(token, AddedToken) else token
152
+ if surface_form not in SPECIAL_TOKENS:
153
+ raise ValueError('Adding unknown special tokens is not supported')
154
+ return 0
155
+
156
+ def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
157
+ """
158
+ Save only the vocabulary of the tokenizer (vocabulary).
159
+
160
+ Returns:
161
+ `Tuple(str)`: Paths to the files saved.
162
+ """
163
+ file_path = os.path.join(save_directory, "qwen.tiktoken")
164
+ with open(file_path, "w", encoding="utf8") as w:
165
+ for k, v in self.mergeable_ranks.items():
166
+ line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
167
+ w.write(line)
168
+ return (file_path,)
169
+
170
+ def tokenize(
171
+ self,
172
+ text: str,
173
+ allowed_special: Union[Set, str] = "all",
174
+ disallowed_special: Union[Collection, str] = (),
175
+ **kwargs,
176
+ ) -> List[Union[bytes, str]]:
177
+ """
178
+ Converts a string in a sequence of tokens.
179
+
180
+ Args:
181
+ text (`str`):
182
+ The sequence to be encoded.
183
+ allowed_special (`Literal["all"]` or `set`):
184
+ The surface forms of the tokens to be encoded as special tokens in regular texts.
185
+ Default to "all".
186
+ disallowed_special (`Literal["all"]` or `Collection`):
187
+ The surface forms of the tokens that should not be in regular texts and trigger errors.
188
+ Default to an empty tuple.
189
+
190
+ kwargs (additional keyword arguments, *optional*):
191
+ Will be passed to the underlying model specific encode method.
192
+
193
+ Returns:
194
+ `List[bytes|str]`: The list of tokens.
195
+ """
196
+ tokens = []
197
+ text = unicodedata.normalize("NFC", text)
198
+
199
+ # this implementation takes a detour: text -> token id -> token surface forms
200
+ for t in self.tokenizer.encode(
201
+ text, allowed_special=allowed_special, disallowed_special=disallowed_special
202
+ ):
203
+ tokens.append(self.decoder[t])
204
+ return tokens
205
+
206
+ def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
207
+ """
208
+ Converts a sequence of tokens in a single string.
209
+ """
210
+ text = ""
211
+ temp = b""
212
+ for t in tokens:
213
+ if isinstance(t, str):
214
+ if temp:
215
+ text += temp.decode("utf-8", errors=self.errors)
216
+ temp = b""
217
+ text += t
218
+ elif isinstance(t, bytes):
219
+ temp += t
220
+ else:
221
+ raise TypeError("token should only be of type types or str")
222
+ if temp:
223
+ text += temp.decode("utf-8", errors=self.errors)
224
+ return text
225
+
226
+ @property
227
+ def vocab_size(self):
228
+ return self.tokenizer.n_vocab
229
+
230
+ def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
231
+ """Converts an id to a token, special tokens included"""
232
+ if index in self.decoder:
233
+ return self.decoder[index]
234
+ raise ValueError("unknown ids")
235
+
236
+ def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
237
+ """Converts a token to an id using the vocab, special tokens included"""
238
+ if token in self.special_tokens:
239
+ return self.special_tokens[token]
240
+ if token in self.mergeable_ranks:
241
+ return self.mergeable_ranks[token]
242
+ raise ValueError("unknown token")
243
+
244
+ def _tokenize(self, text: str, **kwargs):
245
+ """
246
+ Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
247
+ vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
248
+
249
+ Do NOT take care of added tokens.
250
+ """
251
+ raise NotImplementedError
252
+
253
+ def _decode(
254
+ self,
255
+ token_ids: Union[int, List[int]],
256
+ skip_special_tokens: bool = False,
257
+ errors: str = None,
258
+ **kwargs,
259
+ ) -> str:
260
+ if isinstance(token_ids, int):
261
+ token_ids = [token_ids]
262
+ if skip_special_tokens:
263
+ token_ids = [i for i in token_ids if i < self.eod_id]
264
+ return self.tokenizer.decode(token_ids, errors=errors or self.errors)
tokenizer_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {},
3
+ "auto_map": {
4
+ "AutoTokenizer": [
5
+ "tokenization_qwen.QWenTokenizer",
6
+ null
7
+ ]
8
+ },
9
+ "clean_up_tokenization_spaces": true,
10
+ "model_max_length": 8000,
11
+ "pad_token": "<|endoftext|>",
12
+ "padding_side": "right",
13
+ "tokenizer_class": "QWenTokenizer"
14
+ }
utils.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import logging
3
+ import logging.handlers
4
+ import os
5
+ import sys
6
+ import torch
7
+ import requests
8
+
9
+ from transformers import StoppingCriteria
10
+ from .constants import LOGDIR
11
+
12
+ server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
13
+ moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
14
+
15
+ handler = None
16
+
17
+
18
+ def build_logger(logger_name, logger_filename):
19
+ global handler
20
+
21
+ formatter = logging.Formatter(
22
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
23
+ datefmt="%Y-%m-%d %H:%M:%S",
24
+ )
25
+
26
+ # Set the format of root handlers
27
+ if not logging.getLogger().handlers:
28
+ logging.basicConfig(level=logging.INFO)
29
+ logging.getLogger().handlers[0].setFormatter(formatter)
30
+
31
+ # Redirect stdout and stderr to loggers
32
+ stdout_logger = logging.getLogger("stdout")
33
+ stdout_logger.setLevel(logging.INFO)
34
+ sl = StreamToLogger(stdout_logger, logging.INFO)
35
+ sys.stdout = sl
36
+
37
+ stderr_logger = logging.getLogger("stderr")
38
+ stderr_logger.setLevel(logging.ERROR)
39
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
40
+ sys.stderr = sl
41
+
42
+ # Get logger
43
+ logger = logging.getLogger(logger_name)
44
+ logger.setLevel(logging.INFO)
45
+
46
+ # Add a file handler for all loggers
47
+ if handler is None:
48
+ os.makedirs(LOGDIR, exist_ok=True)
49
+ filename = os.path.join(LOGDIR, logger_filename)
50
+ handler = logging.handlers.TimedRotatingFileHandler(
51
+ filename, when='D', utc=True)
52
+ handler.setFormatter(formatter)
53
+
54
+ for name, item in logging.root.manager.loggerDict.items():
55
+ if isinstance(item, logging.Logger):
56
+ item.addHandler(handler)
57
+
58
+ return logger
59
+
60
+
61
+ class StreamToLogger(object):
62
+ """
63
+ Fake file-like stream object that redirects writes to a logger instance.
64
+ """
65
+ def __init__(self, logger, log_level=logging.INFO):
66
+ self.terminal = sys.stdout
67
+ self.logger = logger
68
+ self.log_level = log_level
69
+ self.linebuf = ''
70
+
71
+ def __getattr__(self, attr):
72
+ return getattr(self.terminal, attr)
73
+
74
+ def write(self, buf):
75
+ temp_linebuf = self.linebuf + buf
76
+ self.linebuf = ''
77
+ for line in temp_linebuf.splitlines(True):
78
+ # From the io.TextIOWrapper docs:
79
+ # On output, if newline is None, any '\n' characters written
80
+ # are translated to the system default line separator.
81
+ # By default sys.stdout.write() expects '\n' newlines and then
82
+ # translates them so this is still cross platform.
83
+ if line[-1] == '\n':
84
+ self.logger.log(self.log_level, line.rstrip())
85
+ else:
86
+ self.linebuf += line
87
+
88
+ def flush(self):
89
+ if self.linebuf != '':
90
+ self.logger.log(self.log_level, self.linebuf.rstrip())
91
+ self.linebuf = ''
92
+
93
+
94
+ def disable_torch_init():
95
+ """
96
+ Disable the redundant torch default initialization to accelerate model creation.
97
+ """
98
+ import torch
99
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
100
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
101
+
102
+
103
+ def violates_moderation(text):
104
+ """
105
+ Check whether the text violates OpenAI moderation API.
106
+ """
107
+ url = "https://api.openai.com/v1/moderations"
108
+ headers = {"Content-Type": "application/json",
109
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
110
+ text = text.replace("\n", "")
111
+ data = "{" + '"input": ' + f'"{text}"' + "}"
112
+ data = data.encode("utf-8")
113
+ try:
114
+ ret = requests.post(url, headers=headers, data=data, timeout=5)
115
+ flagged = ret.json()["results"][0]["flagged"]
116
+ except requests.exceptions.RequestException as e:
117
+ flagged = False
118
+ except KeyError as e:
119
+ flagged = False
120
+
121
+ return flagged
122
+
123
+
124
+ def pretty_print_semaphore(semaphore):
125
+ if semaphore is None:
126
+ return "None"
127
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
128
+
129
+
130
+ class KeywordsStoppingCriteria(StoppingCriteria):
131
+ def __init__(self, keywords, tokenizer, input_ids):
132
+ self.keywords = keywords
133
+ self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
134
+ self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1]
135
+ self.tokenizer = tokenizer
136
+ self.start_len = None
137
+ self.input_ids = input_ids
138
+
139
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
140
+ if self.start_len is None:
141
+ self.start_len = self.input_ids.shape[1]
142
+ else:
143
+ for keyword_id in self.keyword_ids:
144
+ if output_ids[0, -1] == keyword_id:
145
+ return True
146
+ outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
147
+ for keyword in self.keywords:
148
+ if keyword in outputs:
149
+ return True
150
+ return False
151
+
152
+
153
+ def smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, model):
154
+ """Resize tokenizer and embedding.
155
+
156
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
157
+ """
158
+ # num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
159
+ # # num_new_tokens = 1
160
+ # # tokenizer.add_tokens(special_tokens_dict, special_tokens=True)
161
+ # model.resize_token_embeddings(len(tokenizer))
162
+
163
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
164
+ model.resize_token_embeddings(len(tokenizer))
165
+
166
+ if num_new_tokens > 0:
167
+ input_embeddings = model.get_input_embeddings().weight.data
168
+ output_embeddings = model.get_output_embeddings().weight.data
169
+
170
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
171
+ dim=0, keepdim=True)
172
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
173
+ dim=0, keepdim=True)
174
+
175
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
176
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
177
+
178
+
179
+ def maybe_zero_3(param, ignore_status=False, name=None):
180
+ from deepspeed import zero
181
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
182
+ if hasattr(param, "ds_id"):
183
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
184
+ if not ignore_status:
185
+ logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
186
+ with zero.GatheredParameters([param]):
187
+ param = param.data.detach().cpu().clone()
188
+ else:
189
+ param = param.detach().cpu().clone()
190
+ return param
191
+
192
+
193
+ # Borrowed from peft.utils.get_peft_model_state_dict
194
+ def get_peft_state_maybe_zero_3(named_params, bias):
195
+ if bias == "none":
196
+ to_return = {k: t for k, t in named_params if "lora_" in k}
197
+ elif bias == "all":
198
+ to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
199
+ elif bias == "lora_only":
200
+ to_return = {}
201
+ maybe_lora_bias = {}
202
+ lora_bias_names = set()
203
+ for k, t in named_params:
204
+ if "lora_" in k:
205
+ to_return[k] = t
206
+ bias_name = k.split("lora_")[0] + "bias"
207
+ lora_bias_names.add(bias_name)
208
+ elif "bias" in k:
209
+ maybe_lora_bias[k] = t
210
+ for k, t in maybe_lora_bias:
211
+ if bias_name in lora_bias_names:
212
+ to_return[bias_name] = t
213
+ else:
214
+ raise NotImplementedError
215
+ to_return = {k: maybe_zero_3(v, name=k) for k, v in to_return.items()}
216
+ return to_return
217
+
218
+
219
+ def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
220
+ to_return = {k: t for k, t in named_params if "lora_" not in k}
221
+ if require_grad_only:
222
+ to_return = {k: t for k, t in to_return.items() if t.requires_grad}
223
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
224
+ return to_return
225
+
226
+
227
+ def find_all_linear_names(model):
228
+ cls = torch.nn.Linear
229
+ lora_module_names = set()
230
+ for name, module in model.named_modules():
231
+ if isinstance(module, cls) and 'vision_model' not in name and 'mm_projector' not in name and 'vision_encoder' not in name and 'conv_final' not in name and'lm_head' not in name:
232
+ lora_module_names.add(name)
233
+
234
+ print(lora_module_names)
235
+ return list(lora_module_names)
vary_b.py ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from typing import Optional, Tuple, Type
12
+
13
+ from functools import partial
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+ from typing import Type
19
+
20
+ # from GOT.model.vision_encoder.vitg_qwen import Resampler
21
+ import math
22
+
23
+
24
+ class Projector(nn.Module):
25
+ def __init__(
26
+ self,
27
+ width: 256,
28
+ n_queries: int = 256,
29
+ output_dim: int = 4096,
30
+ **kwargs
31
+ ):
32
+ super().__init__()
33
+
34
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
35
+ self.attn_pool = Resampler(
36
+ grid_size=int(math.sqrt(n_queries)),
37
+ embed_dim=output_dim,
38
+ num_heads=output_dim // 128,
39
+ kv_dim=width,
40
+ norm_layer=norm_layer,
41
+ )
42
+ self.ln_post = norm_layer(output_dim)
43
+ self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim))
44
+
45
+ def forward(self, x: torch.Tensor):
46
+ x = self.attn_pool(x)
47
+ x = self.ln_post(x)
48
+ x = x @ self.proj
49
+
50
+ return x
51
+
52
+
53
+ class MLPBlock(nn.Module):
54
+ def __init__(
55
+ self,
56
+ embedding_dim: int,
57
+ mlp_dim: int,
58
+ act: Type[nn.Module] = nn.GELU,
59
+ ) -> None:
60
+ super().__init__()
61
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
62
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
63
+ self.act = act()
64
+
65
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
66
+ return self.lin2(self.act(self.lin1(x)))
67
+
68
+
69
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
70
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
71
+ class LayerNorm2d(nn.Module):
72
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
73
+ super().__init__()
74
+ self.weight = nn.Parameter(torch.ones(num_channels))
75
+ self.bias = nn.Parameter(torch.zeros(num_channels))
76
+ self.eps = eps
77
+
78
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
79
+ u = x.mean(1, keepdim=True)
80
+ s = (x - u).pow(2).mean(1, keepdim=True)
81
+ x = (x - u) / torch.sqrt(s + self.eps)
82
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
83
+ return x
84
+
85
+
86
+ # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
87
+ class ImageEncoderViT(nn.Module):
88
+ def __init__(
89
+ self,
90
+ img_size: int = 1024,
91
+ patch_size: int = 16,
92
+ in_chans: int = 3,
93
+ embed_dim: int = 768,
94
+ depth: int = 12,
95
+ num_heads: int = 12,
96
+ mlp_ratio: float = 4.0,
97
+ out_chans: int = 256,
98
+ qkv_bias: bool = True,
99
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
100
+ act_layer: Type[nn.Module] = nn.GELU,
101
+ use_abs_pos: bool = True,
102
+ use_rel_pos: bool = False,
103
+ rel_pos_zero_init: bool = True,
104
+ window_size: int = 0,
105
+ global_attn_indexes: Tuple[int, ...] = (),
106
+ ) -> None:
107
+ """
108
+ Args:
109
+ img_size (int): Input image size.
110
+ patch_size (int): Patch size.
111
+ in_chans (int): Number of input image channels.
112
+ embed_dim (int): Patch embedding dimension.
113
+ depth (int): Depth of ViT.
114
+ num_heads (int): Number of attention heads in each ViT block.
115
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
116
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
117
+ norm_layer (nn.Module): Normalization layer.
118
+ act_layer (nn.Module): Activation layer.
119
+ use_abs_pos (bool): If True, use absolute positional embeddings.
120
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
121
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
122
+ window_size (int): Window size for window attention blocks.
123
+ global_attn_indexes (list): Indexes for blocks using global attention.
124
+ """
125
+ super().__init__()
126
+ self.img_size = img_size
127
+
128
+ self.patch_embed = PatchEmbed(
129
+ kernel_size=(patch_size, patch_size),
130
+ stride=(patch_size, patch_size),
131
+ in_chans=in_chans,
132
+ embed_dim=embed_dim,
133
+ )
134
+
135
+ self.pos_embed: Optional[nn.Parameter] = None
136
+ if use_abs_pos:
137
+ # Initialize absolute positional embedding with pretrain image size.
138
+ self.pos_embed = nn.Parameter(
139
+ torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
140
+ )
141
+
142
+ self.blocks = nn.ModuleList()
143
+ for i in range(depth):
144
+ block = Block(
145
+ dim=embed_dim,
146
+ num_heads=num_heads,
147
+ mlp_ratio=mlp_ratio,
148
+ qkv_bias=qkv_bias,
149
+ norm_layer=norm_layer,
150
+ act_layer=act_layer,
151
+ use_rel_pos=use_rel_pos,
152
+ rel_pos_zero_init=rel_pos_zero_init,
153
+ window_size=window_size if i not in global_attn_indexes else 0,
154
+ input_size=(img_size // patch_size, img_size // patch_size),
155
+ )
156
+ self.blocks.append(block)
157
+
158
+ self.neck = nn.Sequential(
159
+ nn.Conv2d(
160
+ embed_dim,
161
+ out_chans,
162
+ kernel_size=1,
163
+ bias=False,
164
+ ),
165
+ LayerNorm2d(out_chans),
166
+ nn.Conv2d(
167
+ out_chans,
168
+ out_chans,
169
+ kernel_size=3,
170
+ padding=1,
171
+ bias=False,
172
+ ),
173
+ LayerNorm2d(out_chans),
174
+ )
175
+
176
+
177
+ self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
178
+ self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)
179
+
180
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
181
+ x = self.patch_embed(x)
182
+ if self.pos_embed is not None:
183
+ x = x + self.pos_embed
184
+
185
+ for blk in self.blocks:
186
+ x = blk(x)
187
+
188
+ x = self.neck(x.permute(0, 3, 1, 2))
189
+ x = self.net_2(x)
190
+ x = self.net_3(x)
191
+
192
+
193
+ return x
194
+
195
+
196
+ class Block(nn.Module):
197
+ """Transformer blocks with support of window attention and residual propagation blocks"""
198
+
199
+ def __init__(
200
+ self,
201
+ dim: int,
202
+ num_heads: int,
203
+ mlp_ratio: float = 4.0,
204
+ qkv_bias: bool = True,
205
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
206
+ act_layer: Type[nn.Module] = nn.GELU,
207
+ use_rel_pos: bool = False,
208
+ rel_pos_zero_init: bool = True,
209
+ window_size: int = 0,
210
+ input_size: Optional[Tuple[int, int]] = None,
211
+ ) -> None:
212
+ """
213
+ Args:
214
+ dim (int): Number of input channels.
215
+ num_heads (int): Number of attention heads in each ViT block.
216
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
217
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
218
+ norm_layer (nn.Module): Normalization layer.
219
+ act_layer (nn.Module): Activation layer.
220
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
221
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
222
+ window_size (int): Window size for window attention blocks. If it equals 0, then
223
+ use global attention.
224
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
225
+ positional parameter size.
226
+ """
227
+ super().__init__()
228
+ self.norm1 = norm_layer(dim)
229
+ self.attn = Attention(
230
+ dim,
231
+ num_heads=num_heads,
232
+ qkv_bias=qkv_bias,
233
+ use_rel_pos=use_rel_pos,
234
+ rel_pos_zero_init=rel_pos_zero_init,
235
+ input_size=input_size if window_size == 0 else (window_size, window_size),
236
+ )
237
+
238
+ self.norm2 = norm_layer(dim)
239
+ self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
240
+
241
+ self.window_size = window_size
242
+
243
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
244
+ shortcut = x
245
+ x = self.norm1(x)
246
+ # Window partition
247
+ if self.window_size > 0:
248
+ H, W = x.shape[1], x.shape[2]
249
+ x, pad_hw = window_partition(x, self.window_size)
250
+
251
+ x = self.attn(x)
252
+ # Reverse window partition
253
+ if self.window_size > 0:
254
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
255
+
256
+ x = shortcut + x
257
+ x = x + self.mlp(self.norm2(x))
258
+
259
+ return x
260
+
261
+
262
+ class Attention(nn.Module):
263
+ """Multi-head Attention block with relative position embeddings."""
264
+
265
+ def __init__(
266
+ self,
267
+ dim: int,
268
+ num_heads: int = 8,
269
+ qkv_bias: bool = True,
270
+ use_rel_pos: bool = False,
271
+ rel_pos_zero_init: bool = True,
272
+ input_size: Optional[Tuple[int, int]] = None,
273
+ ) -> None:
274
+ """
275
+ Args:
276
+ dim (int): Number of input channels.
277
+ num_heads (int): Number of attention heads.
278
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
279
+ rel_pos (bool): If True, add relative positional embeddings to the attention map.
280
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
281
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
282
+ positional parameter size.
283
+ """
284
+ super().__init__()
285
+ self.num_heads = num_heads
286
+ head_dim = dim // num_heads
287
+ self.scale = head_dim**-0.5
288
+
289
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
290
+ self.proj = nn.Linear(dim, dim)
291
+
292
+ self.use_rel_pos = use_rel_pos
293
+ if self.use_rel_pos:
294
+ assert (
295
+ input_size is not None
296
+ ), "Input size must be provided if using relative positional encoding."
297
+ # initialize relative positional embeddings
298
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
299
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
300
+
301
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
302
+ B, H, W, _ = x.shape
303
+ # qkv with shape (3, B, nHead, H * W, C)
304
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
305
+ # q, k, v with shape (B * nHead, H * W, C)
306
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
307
+
308
+ attn = (q * self.scale) @ k.transpose(-2, -1)
309
+
310
+ if self.use_rel_pos:
311
+ attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
312
+
313
+ attn = attn.softmax(dim=-1)
314
+ x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
315
+ x = self.proj(x)
316
+
317
+ return x
318
+
319
+
320
+ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
321
+ """
322
+ Partition into non-overlapping windows with padding if needed.
323
+ Args:
324
+ x (tensor): input tokens with [B, H, W, C].
325
+ window_size (int): window size.
326
+
327
+ Returns:
328
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
329
+ (Hp, Wp): padded height and width before partition
330
+ """
331
+ B, H, W, C = x.shape
332
+
333
+ pad_h = (window_size - H % window_size) % window_size
334
+ pad_w = (window_size - W % window_size) % window_size
335
+ if pad_h > 0 or pad_w > 0:
336
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
337
+ Hp, Wp = H + pad_h, W + pad_w
338
+
339
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
340
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
341
+ return windows, (Hp, Wp)
342
+
343
+
344
+ def window_unpartition(
345
+ windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
346
+ ) -> torch.Tensor:
347
+ """
348
+ Window unpartition into original sequences and removing padding.
349
+ Args:
350
+ windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
351
+ window_size (int): window size.
352
+ pad_hw (Tuple): padded height and width (Hp, Wp).
353
+ hw (Tuple): original height and width (H, W) before padding.
354
+
355
+ Returns:
356
+ x: unpartitioned sequences with [B, H, W, C].
357
+ """
358
+ Hp, Wp = pad_hw
359
+ H, W = hw
360
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
361
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
362
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
363
+
364
+ if Hp > H or Wp > W:
365
+ x = x[:, :H, :W, :].contiguous()
366
+ return x
367
+
368
+
369
+ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
370
+ """
371
+ Get relative positional embeddings according to the relative positions of
372
+ query and key sizes.
373
+ Args:
374
+ q_size (int): size of query q.
375
+ k_size (int): size of key k.
376
+ rel_pos (Tensor): relative position embeddings (L, C).
377
+
378
+ Returns:
379
+ Extracted positional embeddings according to relative positions.
380
+ """
381
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
382
+ # Interpolate rel pos if needed.
383
+ if rel_pos.shape[0] != max_rel_dist:
384
+ # Interpolate rel pos.
385
+ rel_pos_resized = F.interpolate(
386
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
387
+ size=max_rel_dist,
388
+ mode="linear",
389
+ )
390
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
391
+ else:
392
+ rel_pos_resized = rel_pos
393
+
394
+ # Scale the coords with short length if shapes for q and k are different.
395
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
396
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
397
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
398
+
399
+ return rel_pos_resized[relative_coords.long()]
400
+
401
+
402
+ def add_decomposed_rel_pos(
403
+ attn: torch.Tensor,
404
+ q: torch.Tensor,
405
+ rel_pos_h: torch.Tensor,
406
+ rel_pos_w: torch.Tensor,
407
+ q_size: Tuple[int, int],
408
+ k_size: Tuple[int, int],
409
+ ) -> torch.Tensor:
410
+ """
411
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
412
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
413
+ Args:
414
+ attn (Tensor): attention map.
415
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
416
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
417
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
418
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
419
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
420
+
421
+ Returns:
422
+ attn (Tensor): attention map with added relative positional embeddings.
423
+ """
424
+ q_h, q_w = q_size
425
+ k_h, k_w = k_size
426
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
427
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
428
+
429
+ B, _, dim = q.shape
430
+ r_q = q.reshape(B, q_h, q_w, dim)
431
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
432
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
433
+
434
+ attn = (
435
+ attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
436
+ ).view(B, q_h * q_w, k_h * k_w)
437
+
438
+ return attn
439
+
440
+
441
+ class PatchEmbed(nn.Module):
442
+ """
443
+ Image to Patch Embedding.
444
+ """
445
+
446
+ def __init__(
447
+ self,
448
+ kernel_size: Tuple[int, int] = (16, 16),
449
+ stride: Tuple[int, int] = (16, 16),
450
+ padding: Tuple[int, int] = (0, 0),
451
+ in_chans: int = 3,
452
+ embed_dim: int = 768,
453
+ ) -> None:
454
+ """
455
+ Args:
456
+ kernel_size (Tuple): kernel size of the projection layer.
457
+ stride (Tuple): stride of the projection layer.
458
+ padding (Tuple): padding size of the projection layer.
459
+ in_chans (int): Number of input image channels.
460
+ embed_dim (int): Patch embedding dimension.
461
+ """
462
+ super().__init__()
463
+
464
+ self.proj = nn.Conv2d(
465
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
466
+ )
467
+
468
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
469
+ x = self.proj(x)
470
+ # B C H W -> B H W C
471
+ x = x.permute(0, 2, 3, 1)
472
+ return x
473
+
474
+
475
+
476
+ def build_vary_vit_b(checkpoint=None):
477
+ return _build_vary(
478
+ encoder_embed_dim=768,
479
+ encoder_depth=12,
480
+ encoder_num_heads=12,
481
+ encoder_global_attn_indexes=[2, 5, 8, 11],
482
+ checkpoint=checkpoint,
483
+ )
484
+
485
+
486
+ def _build_vary(
487
+ encoder_embed_dim,
488
+ encoder_depth,
489
+ encoder_num_heads,
490
+ encoder_global_attn_indexes,
491
+ checkpoint=None,
492
+ ):
493
+ prompt_embed_dim = 256
494
+ image_size = 1024
495
+ vit_patch_size = 16
496
+ image_embedding_size = image_size // vit_patch_size
497
+ image_encoder=ImageEncoderViT(
498
+ depth=encoder_depth,
499
+ embed_dim=encoder_embed_dim,
500
+ img_size=image_size,
501
+ mlp_ratio=4,
502
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
503
+ num_heads=encoder_num_heads,
504
+ patch_size=vit_patch_size,
505
+ qkv_bias=True,
506
+ use_rel_pos=True,
507
+ global_attn_indexes=encoder_global_attn_indexes,
508
+ window_size=14,
509
+ out_chans=prompt_embed_dim,
510
+ )
511
+
512
+
513
+ return image_encoder
514
+