Add confidence scores for the `od` and `description_with_bboxes` tasks

#25
Files changed (1)
  1. processing_florence2.py +83 -23
processing_florence2.py CHANGED
@@ -20,6 +20,7 @@ import re
20
  import logging
21
  from typing import List, Optional, Union
22
  import numpy as np
 
23
 
24
  import torch
25
 
@@ -32,6 +33,7 @@ from transformers.tokenization_utils_base import (
32
  TextInput,
33
  TruncationStrategy,
34
  )
 
35
  from transformers.utils import TensorType
36
 
37
 
@@ -304,7 +306,7 @@ class Florence2Processor(ProcessorMixin):
304
  image_processor_input_names = self.image_processor.model_input_names
305
  return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
306
 
307
- def post_process_generation(self, text, task, image_size):
308
  """
309
  Post-process the output of the model to each of the task outputs.
310
 
@@ -317,6 +319,8 @@ class Florence2Processor(ProcessorMixin):
317
  task_answer_post_processing_type = self.tasks_answer_post_processing_type.get(task, 'pure_text')
318
  task_answer = self.post_processor(
319
  text=text,
 
 
320
  image_size=image_size,
321
  parse_tasks=task_answer_post_processing_type,
322
  )[task_answer_post_processing_type]
@@ -330,6 +334,9 @@ class Florence2Processor(ProcessorMixin):
330
  bboxes_od = [_od_instance['bbox'] for _od_instance in od_instances]
331
  labels_od = [str(_od_instance['cat_name']) for _od_instance in od_instances]
332
  final_answer = {'bboxes': bboxes_od, 'labels': labels_od}
 
 
 
333
  elif task_answer_post_processing_type in ['ocr']:
334
  bboxes = [_od_instance['quad_box'] for _od_instance in task_answer]
335
  labels = [str(_od_instance['text']) for _od_instance in task_answer]
@@ -591,7 +598,8 @@ class Florence2PostProcesser(object):
591
  'PARSE_TASKS': [
592
  {
593
  'TASK_NAME': 'od',
594
- 'PATTERN': r'([a-zA-Z0-9 ]+)<loc_(\\d+)><loc_(\\d+)><loc_(\\d+)><loc_(\\d+)>'
 
595
  },
596
  {
597
  'TASK_NAME': 'ocr',
@@ -607,6 +615,7 @@ class Florence2PostProcesser(object):
607
  },
608
  {
609
  'TASK_NAME': 'description_with_bboxes',
 
610
  },
611
  {
612
  'TASK_NAME': 'description_with_polygons',
@@ -648,9 +657,6 @@ class Florence2PostProcesser(object):
648
  token_ids, skip_special_tokens=False)
649
  assert len(filtered_tokens) == len(token_ids)
650
 
651
- # To avoid mixing byte-level and unicode for byte-level BPT
652
- # we need to build string separately for added tokens and byte-level tokens
653
- # cf. https://github.com/huggingface/transformers/issues/1133
654
  sub_texts = []
655
  for token in filtered_tokens:
656
  if token in self.all_special_tokens:
@@ -658,10 +664,6 @@ class Florence2PostProcesser(object):
658
  else:
659
  if isinstance(tokenizer, (BartTokenizer, BartTokenizerFast)):
660
  sub_text = tokenizer.convert_tokens_to_string([token])
661
- elif isinstance(tokenizer, (T5Tokenizer, T5TokenizerFast)):
662
- # Ref: https://github.com/google/sentencepiece#whitespace-is-treated-as-a-basic-symbol
663
- # Note: Do not strip sub_text as it may have functional whitespace
664
- sub_text = token.replace('▁', ' ')
665
  else:
666
  raise ValueError(f'type {type(tokenizer)} not supported')
667
  sub_texts.append(sub_text)
@@ -673,13 +675,6 @@ class Florence2PostProcesser(object):
673
  text += sub_text
674
  spans.append(span)
675
 
676
- # Text format:
677
- # 1. T5Tokenizer/T5TokenizerFast:
678
- # "<loc_1><loc_2><loc_3><loc_4> transplanting dog<loc_1><loc_2><loc_3><loc_4> cat</s>"
679
- # Equivalent to t5_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
680
- # 2. BartTokenizer (need to double check):
681
- # "<s><loc_1><loc_2><loc_3><loc_4>transplanting dog<loc_1><loc_2><loc_3><loc_4>cat</s>"
682
- # Equivalent to bart_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
683
  return text, spans
684
 
685
  def parse_od_from_text_and_spans(
@@ -714,7 +709,7 @@ class Florence2PostProcesser(object):
714
  return instances
715
 
716
  def parse_ocr_from_text_and_spans(self,
717
- text,
718
  pattern,
719
  image_size,
720
  area_threshold=-1.0,
@@ -818,9 +813,26 @@ class Florence2PostProcesser(object):
818
 
819
  return instances
820
 
821
- def parse_description_with_bboxes_from_text_and_spans(self, text, pattern, image_size, allow_empty_phrase=False):
822
- # temporary parse solution, split by '.'
823
- # ignore <s> </s> and <pad>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
824
 
825
  text = text.replace('<s>', '')
826
  text = text.replace('</s>', '')
@@ -842,13 +854,16 @@ class Florence2PostProcesser(object):
842
  phrase_text_strip = pharse_text.replace('<obj>', '', 1)
843
 
844
  if phrase_text_strip == '' and not allow_empty_phrase:
 
845
  continue
846
 
847
  # parse phrase, get string
848
  phrase = re.search(pattern, phrase_text_strip)
849
  if phrase is None:
 
850
  continue
851
 
 
852
  phrase = phrase.group()
853
  # remove leading and trailing spaces
854
  phrase = phrase.strip()
@@ -856,6 +871,7 @@ class Florence2PostProcesser(object):
856
  # parse bboxes by box_pattern
857
  bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
858
  if len(bboxes_parsed) == 0:
 
859
  continue
860
 
861
  # a list of list
@@ -866,14 +882,42 @@ class Florence2PostProcesser(object):
866
  size=image_size
867
  ).tolist()
868
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
869
  phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
870
- for _bboxes in bboxes:
871
  # Prepare instance.
872
  instance = {}
873
  instance['bbox'] = _bboxes
874
  # exclude non-ascii characters
875
  instance['cat_name'] = phrase
 
 
876
  instances.append(instance)
 
 
877
 
878
  return instances
879
 
@@ -991,6 +1035,8 @@ class Florence2PostProcesser(object):
991
  def __call__(
992
  self,
993
  text=None,
 
 
994
  image_size=None,
995
  parse_tasks=None,
996
  ):
@@ -999,7 +1045,6 @@ class Florence2PostProcesser(object):
999
  text: model outputs
1000
  image_size: (width, height)
1001
  parse_tasks: a list of tasks to parse, if None, parse all tasks.
1002
-
1003
  """
1004
  if parse_tasks is not None:
1005
  if isinstance(parse_tasks, str):
@@ -1008,7 +1053,18 @@ class Florence2PostProcesser(object):
1008
  assert _parse_task in self.parse_tasks, f'parse task {_parse_task} not supported'
1009
 
1010
  # sequence or text should be provided
1011
- assert text is not None, 'text should be provided'
 
 
 
 
 
 
 
 
 
 
 
1012
 
1013
  parsed_dict = {
1014
  'text': text
@@ -1019,6 +1075,7 @@ class Florence2PostProcesser(object):
1019
  continue
1020
 
1021
  pattern = self.parse_tasks_configs[task].get('PATTERN', None)
 
1022
 
1023
  if task == 'ocr':
1024
  instances = self.parse_ocr_from_text_and_spans(
@@ -1040,6 +1097,9 @@ class Florence2PostProcesser(object):
1040
  elif task == 'description_with_bboxes':
1041
  instances = self.parse_description_with_bboxes_from_text_and_spans(
1042
  text,
 
 
 
1043
  pattern=pattern,
1044
  image_size=image_size,
1045
  )
 
20
  import logging
21
  from typing import List, Optional, Union
22
  import numpy as np
23
+ import math
24
 
25
  import torch
26
 
 
33
  TextInput,
34
  TruncationStrategy,
35
  )
36
+ from transformers import BartTokenizer, BartTokenizerFast
37
  from transformers.utils import TensorType
38
 
39
 
 
306
  image_processor_input_names = self.image_processor.model_input_names
307
  return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
308
 
309
+ def post_process_generation(self, text=None, sequence=None, transition_beam_score=None, task=None, image_size=None):
310
  """
311
  Post-process the output of the model to each of the task outputs.
312
 
 
319
  task_answer_post_processing_type = self.tasks_answer_post_processing_type.get(task, 'pure_text')
320
  task_answer = self.post_processor(
321
  text=text,
322
+ sequence=sequence,
323
+ transition_beam_score=transition_beam_score,
324
  image_size=image_size,
325
  parse_tasks=task_answer_post_processing_type,
326
  )[task_answer_post_processing_type]
 
334
  bboxes_od = [_od_instance['bbox'] for _od_instance in od_instances]
335
  labels_od = [str(_od_instance['cat_name']) for _od_instance in od_instances]
336
  final_answer = {'bboxes': bboxes_od, 'labels': labels_od}
337
+ if len(od_instances) and 'score' in od_instances[0]:
338
+ scores_od = [_od_instance['score'] for _od_instance in od_instances]
339
+ final_answer['scores'] = scores_od
340
  elif task_answer_post_processing_type in ['ocr']:
341
  bboxes = [_od_instance['quad_box'] for _od_instance in task_answer]
342
  labels = [str(_od_instance['text']) for _od_instance in task_answer]
 
598
  'PARSE_TASKS': [
599
  {
600
  'TASK_NAME': 'od',
601
+ 'PATTERN': r'([a-zA-Z0-9 ]+)<loc_(\\d+)><loc_(\\d+)><loc_(\\d+)><loc_(\\d+)>',
602
+ 'SCORE_MODE': 'avg_loc_scores'
603
  },
604
  {
605
  'TASK_NAME': 'ocr',
 
615
  },
616
  {
617
  'TASK_NAME': 'description_with_bboxes',
618
+ 'SCORE_MODE': 'avg_loc_scores'
619
  },
620
  {
621
  'TASK_NAME': 'description_with_polygons',
 
657
  token_ids, skip_special_tokens=False)
658
  assert len(filtered_tokens) == len(token_ids)
659
 
 
 
 
660
  sub_texts = []
661
  for token in filtered_tokens:
662
  if token in self.all_special_tokens:
 
664
  else:
665
  if isinstance(tokenizer, (BartTokenizer, BartTokenizerFast)):
666
  sub_text = tokenizer.convert_tokens_to_string([token])
 
 
 
 
667
  else:
668
  raise ValueError(f'type {type(tokenizer)} not supported')
669
  sub_texts.append(sub_text)
 
675
  text += sub_text
676
  spans.append(span)
677
 
 
 
 
 
 
 
 
678
  return text, spans
679
 
680
  def parse_od_from_text_and_spans(
 
709
  return instances
710
 
711
  def parse_ocr_from_text_and_spans(self,
712
+ text,
713
  pattern,
714
  image_size,
715
  area_threshold=-1.0,
 
813
 
814
  return instances
815
 
816
+ def parse_description_with_bboxes_from_text_and_spans(
817
+ self,
818
+ text,
819
+ spans=None,
820
+ scores=None,
821
+ score_mode=None,
822
+ pattern=None,
823
+ image_size=None,
824
+ allow_empty_phrase=False
825
+ ):
826
+ def find_matched_token_indices(cur_span, token_spans):
827
+ inds = []
828
+ for i, token_span in enumerate(token_spans):
829
+ if not (token_span[1] <= cur_span[0] or token_span[0] >= cur_span[1]):
830
+ inds.append(i)
831
+ return inds
832
+
833
+ cur_span = 0
834
+ if text.startswith('<s>'):
835
+ cur_span += 3
836
 
837
  text = text.replace('<s>', '')
838
  text = text.replace('</s>', '')
 
854
  phrase_text_strip = pharse_text.replace('<obj>', '', 1)
855
 
856
  if phrase_text_strip == '' and not allow_empty_phrase:
857
+ cur_span += len(pharse_text)
858
  continue
859
 
860
  # parse phrase, get string
861
  phrase = re.search(pattern, phrase_text_strip)
862
  if phrase is None:
863
+ cur_span += len(pharse_text)
864
  continue
865
 
866
+ phrase_span = phrase.span()
867
  phrase = phrase.group()
868
  # remove leading and trailing spaces
869
  phrase = phrase.strip()
 
871
  # parse bboxes by box_pattern
872
  bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
873
  if len(bboxes_parsed) == 0:
874
+ cur_span += len(pharse_text)
875
  continue
876
 
877
  # a list of list
 
882
  size=image_size
883
  ).tolist()
884
 
885
+ if score_mode == 'avg_loc_scores':
886
+ if spans is None or scores is None:
887
+ all_scores = None
888
+ else:
889
+ bbox_end_spans = [_bboxes_parsed.span(0) for _bboxes_parsed in bboxes_parsed]
890
+ all_scores = []
891
+ for _spans in bbox_end_spans:
892
+ token_inds = find_matched_token_indices((_spans[0] + cur_span, _spans[1]+ cur_span), spans)
893
+ loc_scores = [scores[token_i] for token_i in token_inds]
894
+ score = sum(loc_scores) / len(loc_scores)
895
+ all_scores.append(score)
896
+ elif score_mode == 'avg_cat_name_scores':
897
+ if spans is None or scores is None:
898
+ all_scores = None
899
+ else:
900
+ cat_name_token_inds = find_matched_token_indices((phrase_span[0] + cur_span, phrase_span[1]+cur_span), spans)
901
+ cat_name_scores = [scores[token_i] for token_i in cat_name_token_inds]
902
+ score = sum(cat_name_scores) / len(cat_name_scores)
903
+ all_scores = [score] * len(bboxes)
904
+ elif score_mode is None:
905
+ all_scores = None
906
+ else:
907
+ raise ValueError('Unknown score mode: {}'.format(score_mode))
908
+
909
  phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
910
+ for _idx, _bboxes in enumerate(bboxes):
911
  # Prepare instance.
912
  instance = {}
913
  instance['bbox'] = _bboxes
914
  # exclude non-ascii characters
915
  instance['cat_name'] = phrase
916
+ if all_scores is not None:
917
+ instance['score'] = math.exp(all_scores[_idx])
918
  instances.append(instance)
919
+
920
+ cur_span += len(pharse_text)
921
 
922
  return instances
923
 
 
1035
  def __call__(
1036
  self,
1037
  text=None,
1038
+ sequence=None,
1039
+ transition_beam_score=None,
1040
  image_size=None,
1041
  parse_tasks=None,
1042
  ):
 
1045
  text: model outputs
1046
  image_size: (width, height)
1047
  parse_tasks: a list of tasks to parse, if None, parse all tasks.
 
1048
  """
1049
  if parse_tasks is not None:
1050
  if isinstance(parse_tasks, str):
 
1053
  assert _parse_task in self.parse_tasks, f'parse task {_parse_task} not supported'
1054
 
1055
  # sequence or text should be provided
1056
+ assert sequence is not None or text is not None, 'sequence or text should be provided'
1057
+ assert sequence is None or text is None, 'only one of sequence and text should be provided'
1058
+
1059
+ if sequence is not None:
1060
+ sequence = sequence.tolist()[1:]
1061
+ text, spans = self.decode_with_spans(self.tokenizer, sequence)
1062
+ if transition_beam_score is not None:
1063
+ transition_beam_score = transition_beam_score.tolist()
1064
+ assert len(sequence) == len(transition_beam_score)
1065
+ else:
1066
+ spans = None
1067
+ transition_beam_score = None
1068
 
1069
  parsed_dict = {
1070
  'text': text
 
1075
  continue
1076
 
1077
  pattern = self.parse_tasks_configs[task].get('PATTERN', None)
1078
+ score_mode = self.parse_tasks_configs[task].get('SCORE_MODE', None)
1079
 
1080
  if task == 'ocr':
1081
  instances = self.parse_ocr_from_text_and_spans(
 
1097
  elif task == 'description_with_bboxes':
1098
  instances = self.parse_description_with_bboxes_from_text_and_spans(
1099
  text,
1100
+ spans=spans,
1101
+ scores=transition_beam_score,
1102
+ score_mode=score_mode,
1103
  pattern=pattern,
1104
  image_size=image_size,
1105
  )