Upload snli_ve_dataset.py
data/mm_data/snli_ve_dataset.py +203 -0
ADDED
@@ -0,0 +1,203 @@
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

from io import BytesIO

import logging
import warnings

import numpy as np
import torch
import base64
from torchvision import transforms

from PIL import Image, ImageFile

from data import data_utils
from data.ofa_dataset import OFADataset

ImageFile.LOAD_TRUNCATED_IMAGES = True
ImageFile.MAX_IMAGE_PIXELS = None
Image.MAX_IMAGE_PIXELS = None

logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def collate(samples, pad_idx, eos_idx):
    """Collate a list of per-example dicts into a padded mini-batch."""
    if len(samples) == 0:
        return {}

    def merge(key):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=eos_idx,
        )

    id = np.array([s["id"] for s in samples])
    src_tokens = merge("source")
    src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])

    patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
    patch_masks = torch.cat([sample['patch_mask'] for sample in samples])

    ref_dict = None
    if samples[0].get("ref_dict", None) is not None:
        ref_dict = np.array([s['ref_dict'] for s in samples])

    constraint_masks = None
    if samples[0].get("constraint_mask", None) is not None:
        constraint_masks = merge("constraint_mask")

    decoder_prompts = None
    if samples[0].get("decoder_prompt", None) is not None:
        decoder_prompts = np.array([s['decoder_prompt'].tolist() for s in samples])

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge("target")
        tgt_lengths = torch.LongTensor(
            [s["target"].ne(pad_idx).long().sum() for s in samples]
        )
        ntokens = tgt_lengths.sum().item()

        if samples[0].get("prev_output_tokens", None) is not None:
            prev_output_tokens = merge("prev_output_tokens")
    else:
        ntokens = src_lengths.sum().item()

    batch = {
        "id": id,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
            "patch_images": patch_images,
            "patch_masks": patch_masks,
            "prev_output_tokens": prev_output_tokens
        },
        "ref_dict": ref_dict,
        "constraint_masks": constraint_masks,
        "decoder_prompts": decoder_prompts,
        "target": target
    }

    return batch


class SnliVeDataset(OFADataset):
    def __init__(
        self,
        split,
        dataset,
        bpe,
        src_dict,
        tgt_dict=None,
        max_src_length=80,
        max_tgt_length=30,
        patch_image_size=224,
        add_caption=False,
        constraint_trie=None,
        imagenet_default_mean_and_std=False,
        prompt_type="none"
    ):
        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
        self.max_src_length = max_src_length
        self.max_tgt_length = max_tgt_length
        self.patch_image_size = patch_image_size

        self.add_caption = add_caption
        self.constraint_trie = constraint_trie
        self.prompt_type = prompt_type

        if imagenet_default_mean_and_std:
            mean = IMAGENET_DEFAULT_MEAN
            std = IMAGENET_DEFAULT_STD
        else:
            mean = [0.5, 0.5, 0.5]
            std = [0.5, 0.5, 0.5]

        self.patch_resize_transform = transforms.Compose([
            lambda image: image.convert("RGB"),
            transforms.Resize((patch_image_size, patch_image_size), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

    def __getitem__(self, index):
        uniq_id, image, hypothesis, caption, label = self.dataset[index]
        # Map SNLI-VE entailment labels onto the single-word answers the model generates.
        if label == 'contradiction':
            label = 'no'
        elif label == 'entailment':
            label = 'yes'
        elif label == 'neutral':
            label = 'maybe'
        else:
            raise NotImplementedError

        # Images arrive as base64-encoded strings; decode and preprocess.
        image = Image.open(BytesIO(base64.urlsafe_b64decode(image)))
        patch_image = self.patch_resize_transform(image)
        patch_mask = torch.tensor([True])

        hypothesis = self.pre_caption(hypothesis, self.max_src_length)
        src_item = self.encode_text(' does the image describe " {} "?'.format(hypothesis))
        tgt_item = self.encode_text(" {}".format(label))
        ref_dict = {label: 1.0}

        if self.add_caption:
            caption = self.pre_caption(caption, self.max_src_length)
            src_item = self.encode_text(' can image and text1 " {} " imply text2 " {} "?'.format(caption, hypothesis))

        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
        # Build decoder inputs/targets; prompt_type controls how much of the
        # source sequence is replayed to the decoder before the answer tokens.
        if self.prompt_type == 'none':
            prev_output_item = torch.cat([self.bos_item, tgt_item])
            target_item = torch.cat([prev_output_item[1:], self.eos_item])
            decoder_prompt = self.bos_item
        elif self.prompt_type == 'src':
            prev_output_item = torch.cat([src_item, tgt_item])
            target_item = torch.cat([prev_output_item[1:], self.eos_item])
            decoder_prompt = src_item
        elif self.prompt_type == 'prev_output':
            prev_output_item = torch.cat([src_item[:-1], tgt_item])
            target_item = torch.cat([prev_output_item[1:], self.eos_item])
            decoder_prompt = src_item[:-1]
        else:
            raise NotImplementedError
        # Mask the replayed prompt positions so only the answer (and EOS) is scored.
        target_item[:-len(tgt_item)-1] = self.tgt_dict.pad()

        example = {
            "id": uniq_id,
            "source": src_item,
            "patch_image": patch_image,
            "patch_mask": patch_mask,
            "target": target_item,
            "prev_output_tokens": prev_output_item,
            "decoder_prompt": decoder_prompt,
            "ref_dict": ref_dict,
        }
        if self.constraint_trie is not None:
            # At each answer position, allow only the tokens reachable in the
            # constraint trie given the answer prefix decoded so far.
            constraint_mask = torch.zeros((len(target_item), len(self.tgt_dict))).bool()
            start_idx = len(target_item) - len(tgt_item) - 1
            for i in range(len(target_item)-len(tgt_item)-1, len(target_item)):
                constraint_prefix_token = [self.tgt_dict.bos()] + target_item[start_idx:i].tolist()
                constraint_nodes = self.constraint_trie.get_next_layer(constraint_prefix_token)
                constraint_mask[i][constraint_nodes] = True
            example["constraint_mask"] = constraint_mask
        return example

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
        Returns:
            dict: a mini-batch containing the data of the task
        """
        return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
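
For orientation, this dataset is normally constructed by OFA's SNLI-VE task rather than used directly, but a minimal sketch of standalone use is shown below. All names here are hypothetical: `table` stands for any indexable source yielding (uniq_id, base64_image, hypothesis, caption, label) rows, and `bpe`, `src_dict`, `tgt_dict` for the BPE encoder and fairseq dictionaries the task would normally supply; none of them are defined in this file.

# Hypothetical usage sketch; not part of the uploaded file.
from torch.utils.data import DataLoader

snli_ve = SnliVeDataset(
    split="train",
    dataset=table,       # assumed: indexable rows of (id, b64 image, hypothesis, caption, label)
    bpe=bpe,             # assumed: BPE encoder from the task
    src_dict=src_dict,   # assumed: fairseq source dictionary
    tgt_dict=tgt_dict,   # assumed: fairseq target dictionary
    patch_image_size=224,
    prompt_type="none",
)
loader = DataLoader(snli_ve, batch_size=8, collate_fn=snli_ve.collater)
batch = next(iter(loader))  # dict with "net_input", "target", "ref_dict", ...

Because collate() nests the encoder/decoder inputs under "net_input", the resulting batch can typically be passed straight to an OFA model as model(**batch["net_input"]).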