mshukor HF staff committed on
Commit
babacde
·
1 Parent(s): 76f9c23

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -19,6 +19,9 @@ from tasks.mm_tasks.caption import CaptionTask
19
  from tasks.mm_tasks.refcoco import RefcocoTask
20
  from tasks.mm_tasks.vqa_gen import VqaGenTask
21
 
 
 
 
22
  # video
23
  from data.video_utils import VIDEO_READER_FUNCS
24
 
@@ -175,8 +178,8 @@ move2gpu(audio_caption_models, general_cfg)
175
  caption_generator = caption_task.build_generator(caption_models, caption_cfg.generation)
176
  refcoco_generator = refcoco_task.build_generator(refcoco_models, refcoco_cfg.generation)
177
  vqa_generator = vqa_task.build_generator(vqa_models, vqa_cfg.generation)
178
- vqa_generator.zero_shot = True
179
- vqa_generator.constraint_trie = None
180
  general_generator = general_task.build_generator(general_models, general_cfg.generation)
181
 
182
  video_caption_generator = caption_task.build_generator(video_caption_models, video_caption_cfg.generation)
@@ -449,8 +452,13 @@ def inference(image, audio, video, task_type, instruction):
449
 
450
  # Generate result
451
  with torch.no_grad():
452
- hypos = task.inference_step(generator, models, sample)
453
- tokens, bins, imgs = decode_fn(hypos[0][0]["tokens"], task.tgt_dict, task.bpe, generator)
 
 
 
 
 
454
 
455
  if bins.strip() != '':
456
  w, h = image.size
 
19
  from tasks.mm_tasks.refcoco import RefcocoTask
20
  from tasks.mm_tasks.vqa_gen import VqaGenTask
21
 
22
+
23
+ from utils.zero_shot_utils import zero_shot_step
24
+
25
  # video
26
  from data.video_utils import VIDEO_READER_FUNCS
27
 
 
178
  caption_generator = caption_task.build_generator(caption_models, caption_cfg.generation)
179
  refcoco_generator = refcoco_task.build_generator(refcoco_models, refcoco_cfg.generation)
180
  vqa_generator = vqa_task.build_generator(vqa_models, vqa_cfg.generation)
181
+ # vqa_generator.zero_shot = True
182
+ # vqa_generator.constraint_trie = None
183
  general_generator = general_task.build_generator(general_models, general_cfg.generation)
184
 
185
  video_caption_generator = caption_task.build_generator(video_caption_models, video_caption_cfg.generation)
 
452
 
453
  # Generate result
454
  with torch.no_grad():
455
+ if task_type == 'Visual Question Answering':
456
+ result, scores = zero_shot_step(vqa_task, generator, models, sample)
457
+ tokens = result[0]['answer']
458
+ bins = ''
459
+ else:
460
+ hypos = task.inference_step(generator, models, sample)
461
+ tokens, bins, imgs = decode_fn(hypos[0][0]["tokens"], task.tgt_dict, task.bpe, generator)
462
 
463
  if bins.strip() != '':
464
  w, h = image.size