VishalD1234 committed on
Commit 490abdf · verified · 1 Parent(s): 1b3635e

Delete processing_florence2.py

Files changed (1)
  1. processing_florence2.py +0 -1078
processing_florence2.py DELETED
@@ -1,1078 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2024 Microsoft and The HuggingFace Inc. team.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """
16
- Processor class for Florence-2.
17
- """
18
-
19
- import re
20
- import logging
21
- from typing import List, Optional, Union
22
- import numpy as np
23
-
24
- import torch
25
-
26
- from transformers.feature_extraction_utils import BatchFeature
27
- from transformers.image_utils import ImageInput, is_valid_image
28
- from transformers.processing_utils import ProcessorMixin
29
- from transformers.tokenization_utils_base import (
30
- PaddingStrategy,
31
- PreTokenizedInput,
32
- TextInput,
33
- TruncationStrategy,
34
- )
35
- from transformers.utils import TensorType
- from transformers import BartTokenizer, BartTokenizerFast, T5Tokenizer, T5TokenizerFast
36
-
37
-
38
- logger = logging.getLogger(__name__)
39
-
40
- # Copied from transformers.models.idefics2.processing_idefics2.is_url
41
- def is_url(val) -> bool:
42
- return isinstance(val, str) and val.startswith("http")
43
-
44
- # Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url
45
- def is_image_or_image_url(elem):
46
- return is_url(elem) or is_valid_image(elem)
47
-
48
-
49
- def _is_str_or_image(elem):
50
- return isinstance(elem, (str)) or is_image_or_image_url(elem)
51
-
52
-
53
- class Florence2Processor(ProcessorMixin):
54
- r"""
55
- Constructs a Florence2 processor which wraps a Florence2 image processor and a Florence2 tokenizer into a single processor.
56
- [`Florence2Processor`] offers all the functionalities of [`CLIPImageProcessor`] and [`BartTokenizerFast`]. See the
57
- [`~Florence2Processor.__call__`] and [`~Florence2Processor.decode`] for more information.
58
- Args:
59
- image_processor ([`CLIPImageProcessor`], *optional*):
60
- The image processor is a required input.
61
- tokenizer ([`BartTokenizerFast`], *optional*):
62
- The tokenizer is a required input.
63
- """
64
-
65
- attributes = ["image_processor", "tokenizer"]
66
- image_processor_class = "CLIPImageProcessor"
67
- tokenizer_class = ("BartTokenizer", "BartTokenizerFast")
68
-
69
- def __init__(
70
- self,
71
- image_processor=None,
72
- tokenizer=None,
73
- ):
74
- if image_processor is None:
75
- raise ValueError("You need to specify an `image_processor`.")
76
- if tokenizer is None:
77
- raise ValueError("You need to specify a `tokenizer`.")
78
- if not hasattr(image_processor, "image_seq_length"):
79
- raise ValueError("Image processor is missing an `image_seq_length` attribute.")
80
-
81
- self.image_seq_length = image_processor.image_seq_length
82
-
83
- tokens_to_add = {
84
- 'additional_special_tokens': \
85
- tokenizer.additional_special_tokens + \
86
- ['<od>', '</od>', '<ocr>', '</ocr>'] + \
87
- [f'<loc_{x}>' for x in range(1000)] + \
88
- ['<cap>', '</cap>', '<ncap>', '</ncap>','<dcap>', '</dcap>', '<grounding>', '</grounding>', '<seg>', '</seg>', '<sep>', '<region_cap>', '</region_cap>', '<region_to_desciption>', '</region_to_desciption>', '<proposal>', '</proposal>', '<poly>', '</poly>', '<and>']
89
- }
90
- tokenizer.add_special_tokens(tokens_to_add)
91
-
92
- self.tasks_answer_post_processing_type = {
93
- '<OCR>': 'pure_text',
94
- '<OCR_WITH_REGION>': 'ocr',
95
- '<CAPTION>': 'pure_text',
96
- '<DETAILED_CAPTION>': 'pure_text',
97
- '<MORE_DETAILED_CAPTION>': 'pure_text',
98
- '<OD>': 'description_with_bboxes',
99
- '<DENSE_REGION_CAPTION>': 'description_with_bboxes',
100
- '<CAPTION_TO_PHRASE_GROUNDING>': "phrase_grounding",
101
- '<REFERRING_EXPRESSION_SEGMENTATION>': 'polygons',
102
- '<REGION_TO_SEGMENTATION>': 'polygons',
103
- '<OPEN_VOCABULARY_DETECTION>': 'description_with_bboxes_or_polygons',
104
- '<REGION_TO_CATEGORY>': 'pure_text',
105
- '<REGION_TO_DESCRIPTION>': 'pure_text',
106
- '<REGION_TO_OCR>': 'pure_text',
107
- '<REGION_PROPOSAL>': 'bboxes'
108
- }
109
-
110
- self.task_prompts_without_inputs = {
111
- '<OCR>': 'What is the text in the image?',
112
- '<OCR_WITH_REGION>': 'What is the text in the image, with regions?',
113
- '<CAPTION>': 'What does the image describe?',
114
- '<DETAILED_CAPTION>': 'Describe in detail what is shown in the image.',
115
- '<MORE_DETAILED_CAPTION>': 'Describe with a paragraph what is shown in the image.',
116
- '<OD>': 'Locate the objects with category name in the image.',
117
- '<DENSE_REGION_CAPTION>': 'Locate the objects in the image, with their descriptions.',
118
- '<REGION_PROPOSAL>': 'Locate the region proposals in the image.'
119
- }
120
-
121
- self.task_prompts_with_input = {
122
- '<CAPTION_TO_PHRASE_GROUNDING>': "Locate the phrases in the caption: {input}",
123
- '<REFERRING_EXPRESSION_SEGMENTATION>': 'Locate {input} in the image with mask',
124
- '<REGION_TO_SEGMENTATION>': 'What is the polygon mask of region {input}',
125
- '<OPEN_VOCABULARY_DETECTION>': 'Locate {input} in the image.',
126
- '<REGION_TO_CATEGORY>': 'What is the region {input}?',
127
- '<REGION_TO_DESCRIPTION>': 'What does the region {input} describe?',
128
- '<REGION_TO_OCR>': 'What text is in the region {input}?',
129
- }
130
-
131
- self.post_processor = Florence2PostProcesser(tokenizer=tokenizer)
132
-
133
-
134
- super().__init__(image_processor, tokenizer)
135
-
136
- def _construct_prompts(self, text):
137
- # replace the task tokens with the task prompts if task token is in the text
138
- prompts = []
139
- for _text in text:
140
- # 1. fixed task prompts without additional inputs
141
- for task_token, task_prompt in self.task_prompts_without_inputs.items():
142
- if task_token in _text:
143
- assert _text == task_token, f"Task token {task_token} should be the only token in the text."
144
- _text = task_prompt
145
- break
146
- # 2. task prompts with additional inputs
147
- for task_token, task_prompt in self.task_prompts_with_input.items():
148
- if task_token in _text:
149
- _text = task_prompt.format(input=_text.replace(task_token, ''))
150
- break
151
- prompts.append(_text)
152
- return prompts
153
-
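
For illustration, a minimal standalone sketch of the task-token replacement implemented above; the helper name is hypothetical and the two mapping entries are copied from the tables defined in __init__:

task_prompts_without_inputs = {'<OD>': 'Locate the objects with category name in the image.'}
task_prompts_with_input = {'<CAPTION_TO_PHRASE_GROUNDING>': 'Locate the phrases in the caption: {input}'}

def construct_prompt(text):
    # Fixed task prompts: the task token must be the whole input.
    for token, prompt in task_prompts_without_inputs.items():
        if token in text:
            assert text == token, f"Task token {token} should be the only token in the text."
            return prompt
    # Task prompts with an extra input: strip the token and format the remainder in.
    for token, prompt in task_prompts_with_input.items():
        if token in text:
            return prompt.format(input=text.replace(token, ''))
    return text

print(construct_prompt('<OD>'))
# Locate the objects with category name in the image.
print(construct_prompt('<CAPTION_TO_PHRASE_GROUNDING>A green car'))
# Locate the phrases in the caption: A green car
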
154
- def __call__(
155
- self,
156
- text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
157
- images: ImageInput = None,
158
- tokenize_newline_separately: bool = True,
159
- padding: Union[bool, str, PaddingStrategy] = False,
160
- truncation: Union[bool, str, TruncationStrategy] = None,
161
- max_length=None,
162
- return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
163
- do_resize: bool = None,
164
- do_normalize: bool = None,
165
- image_mean: Optional[Union[float, List[float]]] = None,
166
- image_std: Optional[Union[float, List[float]]] = None,
167
- data_format: Optional["ChannelDimension"] = "channels_first", # noqa: F821
168
- input_data_format: Optional[
169
- Union[str, "ChannelDimension"] # noqa: F821
170
- ] = None,
171
- resample: "PILImageResampling" = None, # noqa: F821
172
- do_convert_rgb: bool = None,
173
- do_thumbnail: bool = None,
174
- do_align_long_axis: bool = None,
175
- do_rescale: bool = None,
176
- ) -> BatchFeature:
177
- """
178
- Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
179
- and `kwargs` arguments to BartTokenizerFast's [`~BartTokenizerFast.__call__`] if `text` is not `None` to encode
180
- the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
181
- CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
182
- of the above two methods for more information.
183
- Args:
184
- text (`str`, `List[str]`, `List[List[str]]`):
185
- The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
186
- (pretokenized string). If the sequences are provided as a list of strings (pretokenized), you must set
187
- `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
188
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
189
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
190
- tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is the
191
- number of channels, and H and W are the image height and width.
192
- tokenize_newline_separately (`bool`, defaults to `True`):
193
- Adds a separately tokenized '\n' at the end of the prompt.
194
- padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
195
- Select a strategy to pad the returned sequences (according to the model's padding side and padding
196
- index) among:
197
- - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
198
- sequence is provided).
199
- - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
200
- acceptable input length for the model if that argument is not provided.
201
- - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
202
- lengths).
203
- max_length (`int`, *optional*):
204
- Maximum length of the returned list and optionally padding length (see above).
205
- truncation (`bool`, *optional*):
206
- Activates truncation to cut input sequences longer than `max_length` to `max_length`.
207
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
208
- If set, will return tensors of a particular framework. Acceptable values are:
209
- - `'tf'`: Return TensorFlow `tf.constant` objects.
210
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
211
- - `'np'`: Return NumPy `np.ndarray` objects.
212
- - `'jax'`: Return JAX `jnp.ndarray` objects.
213
- Returns:
214
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
215
- - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. If `suffix`
216
- is provided, the `input_ids` will also contain the suffix input ids.
217
- - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
218
- `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
219
- `None`).
220
- - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
221
- - **labels** -- Labels compatible with training if `suffix` is not None
222
- """
223
-
224
- return_token_type_ids = False
225
-
226
- if images is None:
227
- raise ValueError("`images` are expected as arguments to a `Florence2Processor` instance.")
228
- if text is None:
229
- logger.warning_once(
230
- "You are using Florence-2 without a text prompt."
231
- )
232
- text = ""
233
-
234
- if isinstance(text, List) and isinstance(images, List):
235
- if len(images) < len(text):
236
- raise ValueError(
237
- f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image."
238
- )
239
- if _is_str_or_image(text):
240
- text = [text]
241
- elif isinstance(text, list) and _is_str_or_image(text[0]):
242
- pass
243
-
244
- pixel_values = self.image_processor(
245
- images,
246
- do_resize=do_resize,
247
- do_normalize=do_normalize,
248
- return_tensors=return_tensors,
249
- image_mean=image_mean,
250
- image_std=image_std,
251
- input_data_format=input_data_format,
252
- data_format=data_format,
253
- resample=resample,
254
- do_convert_rgb=do_convert_rgb,
255
- )["pixel_values"]
256
-
257
- if max_length is not None:
258
- max_length -= self.image_seq_length # max_length has to account for the image tokens
259
-
260
- text = self._construct_prompts(text)
261
-
262
- inputs = self.tokenizer(
263
- text,
264
- return_tensors=return_tensors,
265
- padding=padding,
266
- max_length=max_length,
267
- truncation=truncation,
268
- return_token_type_ids=return_token_type_ids,
269
- )
270
-
271
- return_data = {**inputs, "pixel_values": pixel_values}
272
-
273
- if return_token_type_ids:
274
- labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100)
275
- return_data.update({"labels": labels})
276
- return BatchFeature(data=return_data)
277
-
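
A short usage sketch for this call path, assuming the processor is loaded with trust_remote_code from a Florence-2 checkpoint; the checkpoint id and image URL are illustrative:

import requests
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# The task token is expanded into its natural-language prompt by _construct_prompts.
inputs = processor(text="<OD>", images=image, return_tensors="pt")
print(inputs["input_ids"].shape, inputs["pixel_values"].shape)
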
278
- # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Florence2
279
- def batch_decode(self, *args, **kwargs):
280
- """
281
- This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
282
- refer to the docstring of this method for more information.
283
- """
284
- return self.tokenizer.batch_decode(*args, **kwargs)
285
-
286
- # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Florence2
287
- def decode(self, *args, **kwargs):
288
- """
289
- This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
290
- the docstring of this method for more information.
291
- """
292
- return self.tokenizer.decode(*args, **kwargs)
293
-
294
- @property
295
- # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Florence2
296
- def model_input_names(self):
297
- tokenizer_input_names = self.tokenizer.model_input_names
298
- image_processor_input_names = self.image_processor.model_input_names
299
- return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
300
-
301
- def post_process_generation(self, text, task, image_size):
302
- """
303
- Post-process the output of the model to each of the task outputs.
304
- Args:
305
- text (`str`): The text to post-process.
306
- task (`str`): The task to post-process the text for.
307
- image_size (`Tuple[int, int]`): The size of the image, given as (width, height).
308
- """
309
-
310
- task_answer_post_processing_type = self.tasks_answer_post_processing_type.get(task, 'pure_text')
311
- task_answer = self.post_processor(
312
- text=text,
313
- image_size=image_size,
314
- parse_tasks=task_answer_post_processing_type,
315
- )[task_answer_post_processing_type]
316
-
317
- if task_answer_post_processing_type == 'pure_text':
318
- final_answer = task_answer
319
- # remove the special tokens
320
- final_answer = final_answer.replace('<s>', '').replace('</s>', '')
321
- elif task_answer_post_processing_type in ['od', 'description_with_bboxes', 'bboxes']:
322
- od_instances = task_answer
323
- bboxes_od = [_od_instance['bbox'] for _od_instance in od_instances]
324
- labels_od = [str(_od_instance['cat_name']) for _od_instance in od_instances]
325
- final_answer = {'bboxes': bboxes_od, 'labels': labels_od}
326
- elif task_answer_post_processing_type in ['ocr']:
327
- bboxes = [_od_instance['quad_box'] for _od_instance in task_answer]
328
- labels = [str(_od_instance['text']) for _od_instance in task_answer]
329
- final_answer = {'quad_boxes': bboxes, 'labels': labels}
330
- elif task_answer_post_processing_type in ['phrase_grounding']:
331
- bboxes = []
332
- labels = []
333
- for _grounded_phrase in task_answer:
334
- for _bbox in _grounded_phrase['bbox']:
335
- bboxes.append(_bbox)
336
- labels.append(_grounded_phrase['cat_name'])
337
- final_answer = {'bboxes': bboxes, 'labels': labels}
338
- elif task_answer_post_processing_type in ['description_with_polygons', 'polygons']:
339
- labels = []
340
- polygons = []
341
- for result in task_answer:
342
- label = result['cat_name']
343
- _polygons = result['polygons']
344
- labels.append(label)
345
- polygons.append(_polygons)
346
- final_answer = {'polygons': polygons, 'labels': labels}
347
- elif task_answer_post_processing_type in ['description_with_bboxes_or_polygons']:
348
- bboxes = []
349
- bboxes_labels = []
350
- polygons = []
351
- polygons_labels = []
352
- for result in task_answer:
353
- label = result['cat_name']
354
- if 'polygons' in result:
355
- _polygons = result['polygons']
356
- polygons.append(_polygons)
357
- polygons_labels.append(label)
358
- else:
359
- _bbox = result['bbox']
360
- bboxes.append(_bbox)
361
- bboxes_labels.append(label)
362
- final_answer = {'bboxes': bboxes, 'bboxes_labels': bboxes_labels, 'polygons': polygons, 'polygons_labels': polygons_labels}
363
- else:
364
- raise ValueError('Unknown task answer post processing type: {}'.format(task_answer_post_processing_type))
365
-
366
- final_answer = {
367
- task: final_answer}
368
- return final_answer
369
-
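
For the '<OD>' branch above, a toy sketch of how the per-instance dicts returned by the parser are reshaped into the final answer; the instance values are made up:

od_instances = [
    {'bbox': [34.2, 160.1, 597.4, 371.8], 'cat_name': 'car'},
    {'bbox': [454.1, 96.2, 581.3, 262.5], 'cat_name': 'door'},
]
final_answer = {
    '<OD>': {
        'bboxes': [inst['bbox'] for inst in od_instances],
        'labels': [str(inst['cat_name']) for inst in od_instances],
    }
}
print(final_answer['<OD>']['labels'])
# ['car', 'door']
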
370
- class BoxQuantizer(object):
371
- def __init__(self, mode, bins):
372
- self.mode = mode
373
- self.bins = bins
374
-
375
- def quantize(self, boxes: torch.Tensor, size):
376
- bins_w, bins_h = self.bins # Quantization bins.
377
- size_w, size_h = size # Original image size.
378
- size_per_bin_w = size_w / bins_w
379
- size_per_bin_h = size_h / bins_h
380
- xmin, ymin, xmax, ymax = boxes.split(1, dim=-1) # Shape: 4 * [N, 1].
381
-
382
- if self.mode == 'floor':
383
- quantized_xmin = (
384
- xmin / size_per_bin_w).floor().clamp(0, bins_w - 1)
385
- quantized_ymin = (
386
- ymin / size_per_bin_h).floor().clamp(0, bins_h - 1)
387
- quantized_xmax = (
388
- xmax / size_per_bin_w).floor().clamp(0, bins_w - 1)
389
- quantized_ymax = (
390
- ymax / size_per_bin_h).floor().clamp(0, bins_h - 1)
391
-
392
- elif self.mode == 'round':
393
- raise NotImplementedError()
394
-
395
- else:
396
- raise ValueError('Incorrect quantization type.')
397
-
398
- quantized_boxes = torch.cat(
399
- (quantized_xmin, quantized_ymin, quantized_xmax, quantized_ymax), dim=-1
400
- ).int()
401
-
402
- return quantized_boxes
403
-
404
- def dequantize(self, boxes: torch.Tensor, size):
405
- bins_w, bins_h = self.bins # Quantization bins.
406
- size_w, size_h = size # Original image size.
407
- size_per_bin_w = size_w / bins_w
408
- size_per_bin_h = size_h / bins_h
409
- xmin, ymin, xmax, ymax = boxes.split(1, dim=-1) # Shape: 4 * [N, 1].
410
-
411
- if self.mode == 'floor':
412
- # Add 0.5 to use the center position of the bin as the coordinate.
413
- dequantized_xmin = (xmin + 0.5) * size_per_bin_w
414
- dequantized_ymin = (ymin + 0.5) * size_per_bin_h
415
- dequantized_xmax = (xmax + 0.5) * size_per_bin_w
416
- dequantized_ymax = (ymax + 0.5) * size_per_bin_h
417
-
418
- elif self.mode == 'round':
419
- raise NotImplementedError()
420
-
421
- else:
422
- raise ValueError('Incorrect quantization type.')
423
-
424
- dequantized_boxes = torch.cat(
425
- (dequantized_xmin, dequantized_ymin,
426
- dequantized_xmax, dequantized_ymax), dim=-1
427
- )
428
-
429
- return dequantized_boxes
430
-
431
-
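
A quick round-trip sketch for the quantizer above, assuming it runs in the same module; the box, image size, and 1000x1000 bin grid are illustrative:

import torch

quantizer = BoxQuantizer(mode='floor', bins=(1000, 1000))
boxes = torch.tensor([[34.2, 160.1, 597.4, 371.8]])  # xyxy, in pixels
size = (640, 480)                                     # (width, height)

bins = quantizer.quantize(boxes, size)        # integer bin indices, e.g. [[53, 333, 933, 774]]
recovered = quantizer.dequantize(bins, size)  # bin centers mapped back to pixel coordinates
print(bins, recovered)
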
432
- class CoordinatesQuantizer(object):
433
- """
434
- Quantize coordinates (Nx2)
435
- """
436
-
437
- def __init__(self, mode, bins):
438
- self.mode = mode
439
- self.bins = bins
440
-
441
- def quantize(self, coordinates: torch.Tensor, size):
442
- bins_w, bins_h = self.bins # Quantization bins.
443
- size_w, size_h = size # Original image size.
444
- size_per_bin_w = size_w / bins_w
445
- size_per_bin_h = size_h / bins_h
446
- assert coordinates.shape[-1] == 2, 'coordinates should be shape (N, 2)'
447
- x, y = coordinates.split(1, dim=-1) # Shape: 2 * [N, 1].
448
-
449
- if self.mode == 'floor':
450
- quantized_x = (x / size_per_bin_w).floor().clamp(0, bins_w - 1)
451
- quantized_y = (y / size_per_bin_h).floor().clamp(0, bins_h - 1)
452
-
453
- elif self.mode == 'round':
454
- raise NotImplementedError()
455
-
456
- else:
457
- raise ValueError('Incorrect quantization type.')
458
-
459
- quantized_coordinates = torch.cat(
460
- (quantized_x, quantized_y), dim=-1
461
- ).int()
462
-
463
- return quantized_coordinates
464
-
465
- def dequantize(self, coordinates: torch.Tensor, size):
466
- bins_w, bins_h = self.bins # Quantization bins.
467
- size_w, size_h = size # Original image size.
468
- size_per_bin_w = size_w / bins_w
469
- size_per_bin_h = size_h / bins_h
470
- assert coordinates.shape[-1] == 2, 'coordinates should be shape (N, 2)'
471
- x, y = coordinates.split(1, dim=-1) # Shape: 2 * [N, 1].
472
-
473
- if self.mode == 'floor':
474
- # Add 0.5 to use the center position of the bin as the coordinate.
475
- dequantized_x = (x + 0.5) * size_per_bin_w
476
- dequantized_y = (y + 0.5) * size_per_bin_h
477
-
478
- elif self.mode == 'round':
479
- raise NotImplementedError()
480
-
481
- else:
482
- raise ValueError('Incorrect quantization type.')
483
-
484
- dequantized_coordinates = torch.cat(
485
- (dequantized_x, dequantized_y), dim=-1
486
- )
487
-
488
- return dequantized_coordinates
489
-
490
-
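
The same round trip for point coordinates, again assuming the class above is in scope; note the +0.5 bin-center offset applied on dequantization:

import torch

coord_quantizer = CoordinatesQuantizer(mode='floor', bins=(1000, 1000))
points = torch.tensor([[100.0, 200.0]])  # (N, 2), x and y in pixels
size = (640, 480)

bins = coord_quantizer.quantize(points, size)      # [[156, 416]] on a 640x480 image
centers = coord_quantizer.dequantize(bins, size)   # [[100.16, 199.92]], the bin centers
print(bins, centers)
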
491
- class Florence2PostProcesser(object):
492
- """
493
- Florence-2 post-processor for converting text predictions into task-specific results.
494
- Args:
495
- config: A dict of configs.
496
- tokenizer: A tokenizer for decoding text to spans.
497
- sample config:
498
- UNIFIED_POST_PROCESS:
499
- # common configs
500
- NUM_BBOX_HEIGHT_BINS: 1000
501
- NUM_BBOX_WIDTH_BINS: 1000
502
- COORDINATES_HEIGHT_BINS: 1000
503
- COORDINATES_WIDTH_BINS: 1000
504
- # task specific configs, override the common configs
505
- PARSE_TASKS:
506
- - TASK_NAME: 'video_dense_caption'
507
- PATTERN: 'r<time_(\d+)><time_(\d+)>([a-zA-Z0-9 ]+)'
508
- SCORE_MODE: 'avg_cat_name_scores'
509
- NUM_BINS: 100
510
- - TASK_NAME: 'od'
511
- PATTERN: 'r<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>([a-zA-Z0-9 ]+)'
512
- SCORE_MODE: 'avg_cat_name_scores'
513
- Returns:
514
- parsed_dict (dict): A dict of parsed results.
515
- """
516
- def __init__(
517
- self,
518
- tokenizer=None
519
- ):
520
- parse_tasks = []
521
- parse_task_configs = {}
522
- config = self._create_default_config()
523
- for task in config['PARSE_TASKS']:
524
- parse_tasks.append(task['TASK_NAME'])
525
- parse_task_configs[task['TASK_NAME']] = task
526
-
527
- self.config = config
528
- self.parse_tasks = parse_tasks
529
- self.parse_tasks_configs = parse_task_configs
530
-
531
- self.tokenizer = tokenizer
532
- if self.tokenizer is not None:
533
- self.all_special_tokens = set(self.tokenizer.all_special_tokens)
534
-
535
- self.init_quantizers()
536
- self.black_list_of_phrase_grounding = self._create_black_list_of_phrase_grounding()
537
-
538
- def _create_black_list_of_phrase_grounding(self):
539
- black_list = {}
540
-
541
- if 'phrase_grounding' in self.parse_tasks and self.parse_tasks_configs['phrase_grounding']['FILTER_BY_BLACK_LIST']:
542
- black_list = set(
543
- ['it', 'I', 'me', 'mine',
544
- 'you', 'your', 'yours',
545
- 'he', 'him', 'his',
546
- 'she', 'her', 'hers',
547
- 'they', 'them', 'their', 'theirs',
548
- 'one', 'oneself',
549
- 'we', 'us', 'our', 'ours',
550
- 'you', 'your', 'yours',
551
- 'they', 'them', 'their', 'theirs',
552
- 'mine', 'yours', 'his', 'hers', 'its',
553
- 'ours', 'yours', 'theirs',
554
- 'myself', 'yourself', 'himself', 'herself', 'itself',
555
- 'ourselves', 'yourselves', 'themselves',
556
- 'this', 'that',
557
- 'these', 'those',
558
- 'who', 'whom', 'whose', 'which', 'what',
559
- 'who', 'whom', 'whose', 'which', 'that',
560
- 'all', 'another', 'any', 'anybody', 'anyone', 'anything',
561
- 'each', 'everybody', 'everyone', 'everything',
562
- 'few', 'many', 'nobody', 'none', 'one', 'several',
563
- 'some', 'somebody', 'someone', 'something',
564
- 'each other', 'one another',
565
- 'myself', 'yourself', 'himself', 'herself', 'itself',
566
- 'ourselves', 'yourselves', 'themselves',
567
- 'the image', 'image', 'images', 'the', 'a', 'an', 'a group',
568
- 'other objects', 'lots', 'a set',
569
- ]
570
- )
571
-
572
- return black_list
573
-
574
- def _create_default_config(self):
575
- config = {
576
- 'NUM_BBOX_HEIGHT_BINS': 1000,
577
- 'NUM_BBOX_WIDTH_BINS': 1000,
578
- 'BOX_QUANTIZATION_MODE': 'floor',
579
- 'COORDINATES_HEIGHT_BINS': 1000,
580
- 'COORDINATES_WIDTH_BINS': 1000,
581
- 'COORDINATES_QUANTIZATION_MODE': 'floor',
582
- 'PARSE_TASKS': [
583
- {
584
- 'TASK_NAME': 'od',
585
- 'PATTERN': r'([a-zA-Z0-9 ]+)<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>'
586
- },
587
- {
588
- 'TASK_NAME': 'ocr',
589
- 'PATTERN': r'(.+?)<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>',
590
- 'AREA_THRESHOLD': 0.00
591
- },
592
- {
593
- 'TASK_NAME': 'phrase_grounding',
594
- 'FILTER_BY_BLACK_LIST': True
595
- },
596
- {
597
- 'TASK_NAME': 'pure_text',
598
- },
599
- {
600
- 'TASK_NAME': 'description_with_bboxes',
601
- },
602
- {
603
- 'TASK_NAME': 'description_with_polygons',
604
- },
605
- {
606
- 'TASK_NAME': 'polygons',
607
- },
608
- {
609
- 'TASK_NAME': 'bboxes',
610
- },
611
- {
612
- 'TASK_NAME': 'description_with_bboxes_or_polygons',
613
- }
614
- ]
615
- }
616
-
617
- return config
618
-
619
- def init_quantizers(self):
620
- # we have box_quantizer (od, grounding) and coordinates_quantizer (ocr, referring_segmentation)
621
- num_bbox_height_bins = self.config.get('NUM_BBOX_HEIGHT_BINS', 1000)
622
- num_bbox_width_bins = self.config.get('NUM_BBOX_WIDTH_BINS', 1000)
623
- box_quantization_mode = self.config.get('BOX_QUANTIZATION_MODE', 'floor')
624
- self.box_quantizer = BoxQuantizer(
625
- box_quantization_mode,
626
- (num_bbox_width_bins, num_bbox_height_bins),
627
- )
628
-
629
- num_bbox_height_bins = self.config['COORDINATES_HEIGHT_BINS'] if 'COORDINATES_HEIGHT_BINS' in self.config else self.config.get('NUM_BBOX_HEIGHT_BINS', 1000)
630
- num_bbox_width_bins = self.config['COORDINATES_WIDTH_BINS'] if 'COORDINATES_WIDTH_BINS' in self.config else self.config.get('NUM_BBOX_WIDTH_BINS', 1000)
631
- box_quantization_mode = self.config.get('COORDINATES_QUANTIZATION_MODE') if 'COORDINATES_QUANTIZATION_MODE' in self.config else self.config.get('BOX_QUANTIZATION_MODE', 'floor')
632
- self.coordinates_quantizer = CoordinatesQuantizer(
633
- box_quantization_mode,
634
- (num_bbox_width_bins, num_bbox_height_bins),
635
- )
636
-
637
- def decode_with_spans(self, tokenizer, token_ids):
638
- filtered_tokens = tokenizer.convert_ids_to_tokens(
639
- token_ids, skip_special_tokens=False)
640
- assert len(filtered_tokens) == len(token_ids)
641
-
642
- # To avoid mixing byte-level and unicode for byte-level BPE
643
- # we need to build string separately for added tokens and byte-level tokens
644
- # cf. https://github.com/huggingface/transformers/issues/1133
645
- sub_texts = []
646
- for token in filtered_tokens:
647
- if token in self.all_special_tokens:
648
- sub_texts.append(token)
649
- else:
650
- if isinstance(tokenizer, (BartTokenizer, BartTokenizerFast)):
651
- sub_text = tokenizer.convert_tokens_to_string([token])
652
- elif isinstance(tokenizer, (T5Tokenizer, T5TokenizerFast)):
653
- # Ref: https://github.com/google/sentencepiece#whitespace-is-treated-as-a-basic-symbol
654
- # Note: Do not strip sub_text as it may have functional whitespace
655
- sub_text = token.replace('▁', ' ')
656
- else:
657
- raise ValueError(f'type {type(tokenizer)} not supported')
658
- sub_texts.append(sub_text)
659
-
660
- text = ''
661
- spans = []
662
- for sub_text in sub_texts:
663
- span = (len(text), len(text) + len(sub_text)) # [start index, end index).
664
- text += sub_text
665
- spans.append(span)
666
-
667
- # Text format:
668
- # 1. T5Tokenizer/T5TokenizerFast:
669
- # "<loc_1><loc_2><loc_3><loc_4> transplanting dog<loc_1><loc_2><loc_3><loc_4> cat</s>"
670
- # Equivalent to t5_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
671
- # 2. BartTokenizer (need to double check):
672
- # "<s><loc_1><loc_2><loc_3><loc_4>transplanting dog<loc_1><loc_2><loc_3><loc_4>cat</s>"
673
- # Equivalent to bart_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
674
- return text, spans
675
-
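
A toy sketch of the span bookkeeping performed in the loop above, using plain strings in place of decoded tokens:

sub_texts = ['<s>', '<loc_53>', '<loc_333>', '<loc_933>', '<loc_774>', 'car', '</s>']

text, spans = '', []
for sub_text in sub_texts:
    spans.append((len(text), len(text) + len(sub_text)))  # [start index, end index)
    text += sub_text

print(text)        # '<s><loc_53><loc_333><loc_933><loc_774>car</s>'
print(spans[-2])   # character span of 'car' inside the concatenated string
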
676
- def parse_od_from_text_and_spans(
677
- self,
678
- text,
679
- pattern,
680
- image_size,
681
- phrase_centric=False
682
- ):
683
- parsed = list(re.finditer(pattern, text))
684
-
685
- instances = []
686
- for i in range(len(parsed)):
687
- # Prepare instance.
688
- instance = {}
689
-
690
- if phrase_centric:
691
- bbox_bins = [int(parsed[i].group(j)) for j in range(2, 6)]
692
- else:
693
- bbox_bins = [int(parsed[i].group(j)) for j in range(1, 5)]
694
- instance['bbox'] = self.box_quantizer.dequantize(
695
- boxes=torch.tensor(bbox_bins),
696
- size=image_size
697
- ).tolist()
698
-
699
- if phrase_centric:
700
- instance['cat_name'] = parsed[i].group(1).lower().strip()
701
- else:
702
- instance['cat_name'] = parsed[i].group(5).lower().strip()
703
- instances.append(instance)
704
-
705
- return instances
706
-
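
For reference, a standalone sketch of the phrase-centric parse on a made-up output string; the pattern mirrors the 'od' entry in the default config, with \d for the location digits:

import re

pattern = r'([a-zA-Z0-9 ]+)<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>'
text = 'car<loc_53><loc_333><loc_933><loc_774>wheel<loc_100><loc_600><loc_300><loc_770>'

for m in re.finditer(pattern, text):
    cat_name = m.group(1).lower().strip()
    bbox_bins = [int(m.group(j)) for j in range(2, 6)]  # phrase-centric grouping
    print(cat_name, bbox_bins)
# car [53, 333, 933, 774]
# wheel [100, 600, 300, 770]
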
707
- def parse_ocr_from_text_and_spans(self,
708
- text,
709
- pattern,
710
- image_size,
711
- area_threshold=-1.0,
712
- ):
713
- bboxes = []
714
- labels = []
715
- text = text.replace('<s>', '')
716
- # ocr with regions
717
- parsed = re.findall(pattern, text)
718
- instances = []
719
- image_width, image_height = image_size
720
-
721
- for ocr_line in parsed:
722
- ocr_content = ocr_line[0]
723
- quad_box = ocr_line[1:]
724
- quad_box = [int(i) for i in quad_box]
725
- quad_box = self.coordinates_quantizer.dequantize(
726
- torch.tensor(np.array(quad_box).reshape(-1, 2)),
727
- size=image_size
728
- ).reshape(-1).tolist()
729
-
730
- if area_threshold > 0:
731
- x_coords = [i for i in quad_box[0::2]]
732
- y_coords = [i for i in quad_box[1::2]]
733
-
734
- # apply the Shoelace formula
735
- area = 0.5 * abs(sum(x_coords[i] * y_coords[(i + 1) % 4] - x_coords[(i + 1) % 4] * y_coords[i] for i in range(4)))
736
-
737
- if area < (image_width * image_height) * area_threshold:
738
- continue
739
-
740
- bboxes.append(quad_box)
741
- labels.append(ocr_content)
742
- instances.append({
743
- 'quad_box': quad_box,
744
- 'text': ocr_content,
745
- })
746
- return instances
747
-
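
A quick standalone check of the Shoelace-formula area used by the threshold filter above, on a made-up 200 x 50 rectangle (vertices given as x1, y1, ..., x4, y4):

quad_box = [100.0, 100.0, 300.0, 100.0, 300.0, 150.0, 100.0, 150.0]
x_coords = quad_box[0::2]
y_coords = quad_box[1::2]

# Shoelace formula over all four edges, including the closing edge back to the first vertex.
area = 0.5 * abs(sum(x_coords[i] * y_coords[(i + 1) % 4] - x_coords[(i + 1) % 4] * y_coords[i] for i in range(4)))
print(area)  # 10000.0 == 200 * 50
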
748
- def parse_phrase_grounding_from_text_and_spans(self, text, pattern, image_size):
749
- # ignore <s> </s> and <pad>
750
- cur_span = 0
751
- if text.startswith('<s>'):
752
- cur_span += 3
753
-
754
- text = text.replace('<s>', '')
755
- text = text.replace('</s>', '')
756
- text = text.replace('<pad>', '')
757
-
758
- pattern = r"([^<]+(?:<loc_\d+>){4,})"
759
- phrases = re.findall(pattern, text)
760
-
761
- # pattern should be text pattern and od pattern
762
- pattern = r'^\s*(.*?)(?=<od>|</od>|<box>|</box>|<bbox>|</bbox>|<loc_)'
763
- box_pattern = r'<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>'
764
-
765
- instances = []
766
- for pharse_text in phrases:
767
- phrase_text_strip = pharse_text.replace('<ground>', '', 1)
768
- phrase_text_strip = phrase_text_strip.replace('<obj>', '', 1)
769
-
770
- if phrase_text_strip == '':
771
- cur_span += len(pharse_text)
772
- continue
773
-
774
- # Prepare instance.
775
- instance = {}
776
-
777
- # parse phrase, get string
778
- phrase = re.search(pattern, phrase_text_strip)
779
- if phrase is None:
780
- cur_span += len(pharse_text)
781
- continue
782
-
783
- # parse bboxes by box_pattern
784
- bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
785
- if len(bboxes_parsed) == 0:
786
- cur_span += len(pharse_text)
787
- continue
788
-
789
- phrase = phrase.group()
790
- # remove leading and trailing spaces
791
- phrase = phrase.strip()
792
-
793
- if phrase in self.black_list_of_phrase_grounding:
794
- cur_span += len(pharse_text)
795
- continue
796
-
797
- # a list of list
798
- bbox_bins = [[int(_bboxes_parsed.group(j)) for j in range(1, 5)] for _bboxes_parsed in bboxes_parsed]
799
- instance['bbox'] = self.box_quantizer.dequantize(
800
- boxes=torch.tensor(bbox_bins),
801
- size=image_size
802
- ).tolist()
803
-
804
- # exclude non-ascii characters
805
- phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
806
- instance['cat_name'] = phrase
807
-
808
- instances.append(instance)
809
-
810
- return instances
811
-
812
- def parse_description_with_bboxes_from_text_and_spans(self, text, pattern, image_size, allow_empty_phrase=False):
813
- # temporary parse solution, split by '.'
814
- # ignore <s> </s> and <pad>
815
-
816
- text = text.replace('<s>', '')
817
- text = text.replace('</s>', '')
818
- text = text.replace('<pad>', '')
819
-
820
- if allow_empty_phrase:
821
- pattern = rf"(?:(?:<loc_\d+>){{4,}})"
822
- else:
823
- pattern = r"([^<]+(?:<loc_\d+>){4,})"
824
- phrases = re.findall(pattern, text)
825
-
826
- # pattern should be text pattern and od pattern
827
- pattern = r'^\s*(.*?)(?=<od>|</od>|<box>|</box>|<bbox>|</bbox>|<loc_)'
828
- box_pattern = r'<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>'
829
-
830
- instances = []
831
- for pharse_text in phrases:
832
- phrase_text_strip = pharse_text.replace('<ground>', '', 1)
833
- phrase_text_strip = phrase_text_strip.replace('<obj>', '', 1)
834
-
835
- if phrase_text_strip == '' and not allow_empty_phrase:
836
- continue
837
-
838
- # parse phrase, get string
839
- phrase = re.search(pattern, phrase_text_strip)
840
- if phrase is None:
841
- continue
842
-
843
- phrase = phrase.group()
844
- # remove leading and trailing spaces
845
- phrase = phrase.strip()
846
-
847
- # parse bboxes by box_pattern
848
- bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
849
- if len(bboxes_parsed) == 0:
850
- continue
851
-
852
- # a list of list
853
- bbox_bins = [[int(_bboxes_parsed.group(j)) for j in range(1, 5)] for _bboxes_parsed in bboxes_parsed]
854
-
855
- bboxes = self.box_quantizer.dequantize(
856
- boxes=torch.tensor(bbox_bins),
857
- size=image_size
858
- ).tolist()
859
-
860
- phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
861
- for _bboxes in bboxes:
862
- # Prepare instance.
863
- instance = {}
864
- instance['bbox'] = _bboxes
865
- # exclude non-ascii characters
866
- instance['cat_name'] = phrase
867
- instances.append(instance)
868
-
869
- return instances
870
-
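
A standalone sketch of the two-stage regex parse above (phrase chunks first, then quantized box bins), on a made-up output string:

import re

text = 'car<loc_53><loc_333><loc_933><loc_774>wheel<loc_100><loc_600><loc_300><loc_770>'

# Stage 1: split the decoded string into "phrase + run of <loc_*> tokens" chunks.
chunks = re.findall(r"([^<]+(?:<loc_\d+>){4,})", text)

# Stage 2: pull the phrase and the quantized bbox bins out of each chunk.
box_pattern = r'<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>'
for chunk in chunks:
    phrase = re.match(r'^\s*(.*?)(?=<loc_)', chunk).group().strip()
    bbox_bins = [[int(m.group(j)) for j in range(1, 5)] for m in re.finditer(box_pattern, chunk)]
    print(phrase, bbox_bins)
# car [[53, 333, 933, 774]]
# wheel [[100, 600, 300, 770]]
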
871
- def parse_description_with_polygons_from_text_and_spans(self, text, pattern, image_size,
872
- allow_empty_phrase=False,
873
- polygon_sep_token='<sep>',
874
- polygon_start_token='<poly>',
875
- polygon_end_token='</poly>',
876
- with_box_at_start=False,
877
- ):
878
-
879
- # ref_seg format: '<expression><x1><y1><x2><y2><><><sep><><><><>'
880
- # ignore <s> </s> and <pad>
881
-
882
- text = text.replace('<s>', '')
883
- text = text.replace('</s>', '')
884
- text = text.replace('<pad>', '')
885
-
886
- if allow_empty_phrase:
887
- pattern = rf"(?:(?:<loc_\d+>|{re.escape(polygon_sep_token)}|{re.escape(polygon_start_token)}|{re.escape(polygon_end_token)}){{4,}})"
888
- else:
889
- # [^<]+: This part matches one or more characters that are not the < symbol.
890
- # The ^ inside the square brackets [] is a negation, meaning it matches anything except <.
891
- #
892
- pattern = rf"([^<]+(?:<loc_\d+>|{re.escape(polygon_sep_token)}|{re.escape(polygon_start_token)}|{re.escape(polygon_end_token)}){{4,}})"
893
- phrases = re.findall(pattern, text)
894
-
895
- phrase_string_pattern = r'^\s*(.*?)(?=<od>|</od>|<box>|</box>|<bbox>|</bbox>|<loc_|<poly>)'
896
- box_pattern = rf'((?:<loc_\d+>)+)(?:{re.escape(polygon_sep_token)}|$)'
897
-
898
- # one polygons instance is separated by polygon_start_token and polygon_end_token
899
- polygons_instance_pattern = rf'{re.escape(polygon_start_token)}(.*?){re.escape(polygon_end_token)}'
900
-
901
- instances = []
902
- for phrase_text in phrases:
903
-
904
- # exclude loc_\d+>
905
- # need to get span if want to include category score
906
- phrase_text_strip = re.sub(r'^loc_\d+>', '', phrase_text, count=1)
907
-
908
- # phrase = phrase.replace('<poly>', '')
909
- # phrase = phrase.replace('poly>', '')
910
-
911
- if phrase_text_strip == '' and not allow_empty_phrase:
912
- continue
913
-
914
-
915
- # parse phrase, get string
916
- phrase = re.search(phrase_string_pattern, phrase_text_strip)
917
- if phrase is None:
918
- continue
919
- phrase = phrase.group()
920
- # remove leading and trailing spaces
921
- phrase = phrase.strip()
922
-
923
- # parse bboxes by box_pattern
924
-
925
- # split by polygon_start_token and polygon_end_token first using polygons_instance_pattern
926
- if polygon_start_token in phrase_text and polygon_end_token in phrase_text:
927
- polygons_instances_parsed = list(re.finditer(polygons_instance_pattern, phrase_text))
928
- else:
929
- polygons_instances_parsed = [phrase_text]
930
-
931
- for _polygons_instances_parsed in polygons_instances_parsed:
932
- # Prepare instance.
933
- instance = {}
934
-
935
- # polygons_parsed= list(re.finditer(box_pattern, phrase_text))
936
- if isinstance(_polygons_instances_parsed, str):
937
- polygons_parsed= list(re.finditer(box_pattern, _polygons_instances_parsed))
938
- else:
939
- polygons_parsed= list(re.finditer(box_pattern, _polygons_instances_parsed.group(1)))
940
- if len(polygons_parsed) == 0:
941
- continue
942
-
943
- # a list of list (polygon)
944
- bbox = []
945
- polygons = []
946
- for _polygon_parsed in polygons_parsed:
947
- # group 1: whole <loc_\d+>...</loc_\d+>
948
- _polygon = _polygon_parsed.group(1)
949
- # parse into list of int
950
- _polygon = [int(_loc_parsed.group(1)) for _loc_parsed in re.finditer(r'<loc_(\d+)>', _polygon)]
951
- if with_box_at_start and len(bbox) == 0:
952
- if len(_polygon) > 4:
953
- # the first four values are the bbox prediction
954
- bbox = _polygon[:4]
955
- _polygon = _polygon[4:]
956
- else:
957
- bbox = [0, 0, 0, 0] # no valid bbox prediction
958
- # abandon the last element if it is not paired
959
- if len(_polygon) % 2 == 1:
960
- _polygon = _polygon[:-1]
961
-
962
- # reshape into (n, 2)
963
- _polygon = self.coordinates_quantizer.dequantize(
964
- torch.tensor(np.array(_polygon).reshape(-1, 2)),
965
- size=image_size
966
- ).reshape(-1).tolist()
967
- # reshape back
968
- polygons.append(_polygon)
969
-
970
- instance['cat_name'] = phrase
971
- instance['polygons'] = polygons
972
- if len(bbox) != 0:
973
- instance['bbox'] = self.box_quantizer.dequantize(
974
- boxes=torch.tensor([bbox]),
975
- size=image_size
976
- ).tolist()[0]
977
-
978
- instances.append(instance)
979
-
980
- return instances
981
-
982
- def __call__(
983
- self,
984
- text=None,
985
- image_size=None,
986
- parse_tasks=None,
987
- ):
988
- """
989
- Args:
990
- text: model outputs
991
- image_size: (width, height)
992
- parse_tasks: a list of tasks to parse, if None, parse all tasks.
993
- """
994
- if parse_tasks is not None:
995
- if isinstance(parse_tasks, str):
996
- parse_tasks = [parse_tasks]
997
- for _parse_task in parse_tasks:
998
- assert _parse_task in self.parse_tasks, f'parse task {_parse_task} not supported'
999
-
1000
- # sequence or text should be provided
1001
- assert text is not None, 'text should be provided'
1002
-
1003
- parsed_dict = {
1004
- 'text': text
1005
- }
1006
-
1007
- for task in self.parse_tasks:
1008
- if parse_tasks is not None and task not in parse_tasks:
1009
- continue
1010
-
1011
- pattern = self.parse_tasks_configs[task].get('PATTERN', None)
1012
-
1013
- if task == 'ocr':
1014
- instances = self.parse_ocr_from_text_and_spans(
1015
- text,
1016
- pattern=pattern,
1017
- image_size=image_size,
1018
- area_threshold=self.parse_tasks_configs[task].get('AREA_THRESHOLD', 0.0),
1019
- )
1020
- parsed_dict['ocr'] = instances
1021
- elif task == 'phrase_grounding':
1022
- instances = self.parse_phrase_grounding_from_text_and_spans(
1023
- text,
1024
- pattern=pattern,
1025
- image_size=image_size,
1026
- )
1027
- parsed_dict['phrase_grounding'] = instances
1028
- elif task == 'pure_text':
1029
- parsed_dict['pure_text'] = text
1030
- elif task == 'description_with_bboxes':
1031
- instances = self.parse_description_with_bboxes_from_text_and_spans(
1032
- text,
1033
- pattern=pattern,
1034
- image_size=image_size,
1035
- )
1036
- parsed_dict['description_with_bboxes'] = instances
1037
- elif task == 'description_with_polygons':
1038
- instances = self.parse_description_with_polygons_from_text_and_spans(
1039
- text,
1040
- pattern=pattern,
1041
- image_size=image_size,
1042
- )
1043
- parsed_dict['description_with_polygons'] = instances
1044
- elif task == 'polygons':
1045
- instances = self.parse_description_with_polygons_from_text_and_spans(
1046
- text,
1047
- pattern=pattern,
1048
- image_size=image_size,
1049
- allow_empty_phrase=True,
1050
- )
1051
- parsed_dict['polygons'] = instances
1052
- elif task == 'bboxes':
1053
- instances = self.parse_description_with_bboxes_from_text_and_spans(
1054
- text,
1055
- pattern=pattern,
1056
- image_size=image_size,
1057
- allow_empty_phrase=True,
1058
- )
1059
- parsed_dict['bboxes'] = instances
1060
- elif task == 'description_with_bboxes_or_polygons':
1061
- if '<poly>' in text:
1062
- # only support either polygons or bboxes, not both at the same time
1063
- instances = self.parse_description_with_polygons_from_text_and_spans(
1064
- text,
1065
- pattern=pattern,
1066
- image_size=image_size,
1067
- )
1068
- else:
1069
- instances = self.parse_description_with_bboxes_from_text_and_spans(
1070
- text,
1071
- pattern=pattern,
1072
- image_size=image_size,
1073
- )
1074
- parsed_dict['description_with_bboxes_or_polygons'] = instances
1075
- else:
1076
- raise ValueError("task {} is not supported".format(task))
1077
-
1078
- return parsed_dict
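
Taken together, a typical end-to-end usage sketch for this processor; the checkpoint id, image URL, and generation settings are illustrative, and loading assumes trust_remote_code:

import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

model_id = "microsoft/Florence-2-large"
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = Image.open(requests.get(url, stream=True).raw)

task = "<OD>"
inputs = processor(text=task, images=image, return_tensors="pt")
with torch.no_grad():
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
    )

generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
parsed = processor.post_process_generation(generated_text, task=task, image_size=(image.width, image.height))
print(parsed)  # {'<OD>': {'bboxes': [...], 'labels': [...]}}
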