Update app.py
```diff
--- a/app.py
+++ b/app.py
@@ -19,6 +19,9 @@ from tasks.mm_tasks.caption import CaptionTask
 from tasks.mm_tasks.refcoco import RefcocoTask
 from tasks.mm_tasks.vqa_gen import VqaGenTask
 
+
+from utils.zero_shot_utils import zero_shot_step
+
 # video
 from data.video_utils import VIDEO_READER_FUNCS
 
@@ -175,8 +178,8 @@ move2gpu(audio_caption_models, general_cfg)
 caption_generator = caption_task.build_generator(caption_models, caption_cfg.generation)
 refcoco_generator = refcoco_task.build_generator(refcoco_models, refcoco_cfg.generation)
 vqa_generator = vqa_task.build_generator(vqa_models, vqa_cfg.generation)
-vqa_generator.zero_shot = True
-vqa_generator.constraint_trie = None
+# vqa_generator.zero_shot = True
+# vqa_generator.constraint_trie = None
 general_generator = general_task.build_generator(general_models, general_cfg.generation)
 
 video_caption_generator = caption_task.build_generator(video_caption_models, video_caption_cfg.generation)
@@ -449,8 +452,13 @@ def inference(image, audio, video, task_type, instruction):
 
     # Generate result
     with torch.no_grad():
-        hypos = task.inference_step(generator, models, sample)
-        tokens, bins, imgs = decode_fn(hypos[0][0]["tokens"], task.tgt_dict, task.bpe, generator)
+        if task_type == 'Visual Question Answering':
+            result, scores = zero_shot_step(vqa_task, generator, models, sample)
+            tokens = result[0]['answer']
+            bins = ''
+        else:
+            hypos = task.inference_step(generator, models, sample)
+            tokens, bins, imgs = decode_fn(hypos[0][0]["tokens"], task.tgt_dict, task.bpe, generator)
 
     if bins.strip() != '':
         w, h = image.size
```
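For reference, the commit routes Visual Question Answering around the constraint-trie decoding path and through OFA's open-ended zero-shot path via `zero_shot_step`. Below is a minimal standalone sketch of the new dispatch, not the app's exact code: the function name `run_generation` is hypothetical, and it assumes `zero_shot_step` returns `(results, scores)` with `results` a list of dicts carrying an `'answer'` string, which is the shape the diff reads with `result[0]['answer']`.

```python
import torch

# Hypothetical standalone version of the dispatch added in this commit;
# variable names mirror app.py, zero_shot_step comes from OFA's
# utils/zero_shot_utils as imported in the diff above.
from utils.zero_shot_utils import zero_shot_step

def run_generation(task_type, task, vqa_task, generator, models, sample, decode_fn):
    """Sketch: pick the decoding path based on the selected demo task."""
    with torch.no_grad():
        if task_type == 'Visual Question Answering':
            # Open-ended zero-shot decoding: no answer-candidate trie.
            result, scores = zero_shot_step(vqa_task, generator, models, sample)
            tokens = result[0]['answer']  # assumed shape: [{'answer': str, ...}]
            bins = ''  # VQA emits no region tokens, so no boxes to draw
        else:
            # Original path: beam-search inference plus BPE decoding.
            # decode_fn also returns imgs, which only this branch produces.
            hypos = task.inference_step(generator, models, sample)
            tokens, bins, imgs = decode_fn(hypos[0][0]["tokens"], task.tgt_dict, task.bpe, generator)
    return tokens, bins
```

Setting `bins = ''` in the VQA branch keeps the downstream `if bins.strip() != '':` block, which converts region tokens into box coordinates via `image.size`, from firing on free-form answers.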