guidel committed
Commit d19ac89 · 1 Parent(s): 2a90d2b

Upload snli_ve_dataset.py

Files changed (1)
  1. data/mm_data/snli_ve_dataset.py +203 -0
data/mm_data/snli_ve_dataset.py ADDED
@@ -0,0 +1,203 @@
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

from io import BytesIO

import logging
import warnings

import numpy as np
import torch
import base64
from torchvision import transforms

from PIL import Image, ImageFile

from data import data_utils
from data.ofa_dataset import OFADataset

ImageFile.LOAD_TRUNCATED_IMAGES = True
ImageFile.MAX_IMAGE_PIXELS = None
Image.MAX_IMAGE_PIXELS = None

logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

def collate(samples, pad_idx, eos_idx):
    if len(samples) == 0:
        return {}

    def merge(key):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=eos_idx,
        )

    id = np.array([s["id"] for s in samples])
    src_tokens = merge("source")
    src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])

    patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
    patch_masks = torch.cat([sample['patch_mask'] for sample in samples])

    ref_dict = None
    if samples[0].get("ref_dict", None) is not None:
        ref_dict = np.array([s['ref_dict'] for s in samples])

    constraint_masks = None
    if samples[0].get("constraint_mask", None) is not None:
        constraint_masks = merge("constraint_mask")

    decoder_prompts = None
    if samples[0].get("decoder_prompt", None) is not None:
        decoder_prompts = np.array([s['decoder_prompt'].tolist() for s in samples])

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge("target")
        tgt_lengths = torch.LongTensor(
            [s["target"].ne(pad_idx).long().sum() for s in samples]
        )
        ntokens = tgt_lengths.sum().item()

        if samples[0].get("prev_output_tokens", None) is not None:
            prev_output_tokens = merge("prev_output_tokens")
    else:
        ntokens = src_lengths.sum().item()

    batch = {
        "id": id,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
            "patch_images": patch_images,
            "patch_masks": patch_masks,
            "prev_output_tokens": prev_output_tokens
        },
        "ref_dict": ref_dict,
        "constraint_masks": constraint_masks,
        "decoder_prompts": decoder_prompts,
        "target": target
    }

    return batch

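# Illustrative note, assuming the fairseq-style data_utils.collate_tokens used
# above (right-padding by default): two "source" tensors of length 5 and 7
# collate into a (2, 7) LongTensor filled with pad_idx past each sequence,
# while src_lengths keeps the unpadded lengths [5, 7] and patch_images stacks
# the per-sample image tensors into a (2, 3, H, W) batch.
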
class SnliVeDataset(OFADataset):
    def __init__(
        self,
        split,
        dataset,
        bpe,
        src_dict,
        tgt_dict=None,
        max_src_length=80,
        max_tgt_length=30,
        patch_image_size=224,
        add_caption=False,
        constraint_trie=None,
        imagenet_default_mean_and_std=False,
        prompt_type="none"
    ):
        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
        self.max_src_length = max_src_length
        self.max_tgt_length = max_tgt_length
        self.patch_image_size = patch_image_size

        self.add_caption = add_caption
        self.constraint_trie = constraint_trie
        self.prompt_type = prompt_type

        if imagenet_default_mean_and_std:
            mean = IMAGENET_DEFAULT_MEAN
            std = IMAGENET_DEFAULT_STD
        else:
            mean = [0.5, 0.5, 0.5]
            std = [0.5, 0.5, 0.5]

        self.patch_resize_transform = transforms.Compose([
            lambda image: image.convert("RGB"),
            transforms.Resize((patch_image_size, patch_image_size), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

    def __getitem__(self, index):
        uniq_id, image, hypothesis, caption, label = self.dataset[index]
        if label == 'contradiction':
            label = 'no'
        elif label == 'entailment':
            label = 'yes'
        elif label == 'neutral':
            label = 'maybe'
        else:
            raise NotImplementedError

        image = Image.open(BytesIO(base64.urlsafe_b64decode(image)))
        patch_image = self.patch_resize_transform(image)
        patch_mask = torch.tensor([True])

        hypothesis = self.pre_caption(hypothesis, self.max_src_length)
        src_item = self.encode_text(' does the image describe " {} "?'.format(hypothesis))
        tgt_item = self.encode_text(" {}".format(label))
        ref_dict = {label: 1.0}

        if self.add_caption:
            caption = self.pre_caption(caption, self.max_src_length)
            src_item = self.encode_text(' can image and text1 " {} " imply text2 " {} "?'.format(caption, hypothesis))

        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
        if self.prompt_type == 'none':
            prev_output_item = torch.cat([self.bos_item, tgt_item])
            target_item = torch.cat([prev_output_item[1:], self.eos_item])
            decoder_prompt = self.bos_item
        elif self.prompt_type == 'src':
            prev_output_item = torch.cat([src_item, tgt_item])
            target_item = torch.cat([prev_output_item[1:], self.eos_item])
            decoder_prompt = src_item
        elif self.prompt_type == 'prev_output':
            prev_output_item = torch.cat([src_item[:-1], tgt_item])
            target_item = torch.cat([prev_output_item[1:], self.eos_item])
            decoder_prompt = src_item[:-1]
        else:
            raise NotImplementedError
        target_item[:-len(tgt_item)-1] = self.tgt_dict.pad()

        example = {
            "id": uniq_id,
            "source": src_item,
            "patch_image": patch_image,
            "patch_mask": patch_mask,
            "target": target_item,
            "prev_output_tokens": prev_output_item,
            "decoder_prompt": decoder_prompt,
            "ref_dict": ref_dict,
        }
        if self.constraint_trie is not None:
            constraint_mask = torch.zeros((len(target_item), len(self.tgt_dict))).bool()
            start_idx = len(target_item) - len(tgt_item) - 1
            for i in range(len(target_item)-len(tgt_item)-1, len(target_item)):
                constraint_prefix_token = [self.tgt_dict.bos()] + target_item[start_idx:i].tolist()
                constraint_nodes = self.constraint_trie.get_next_layer(constraint_prefix_token)
                constraint_mask[i][constraint_nodes] = True
            example["constraint_mask"] = constraint_mask
        return example

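    # Note on the prompt types handled in __getitem__ above: with
    # prompt_type="none", prev_output_tokens is [bos] + answer tokens and the
    # target is answer tokens + [eos]; with "src" or "prev_output", the encoded
    # source is reused as the decoder prompt, and every target position that
    # merely echoes the prompt is overwritten with tgt_dict.pad() so it can be
    # ignored as padding by the criterion.
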
    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.
        Args:
            samples (List[dict]): samples to collate
        Returns:
            dict: a mini-batch containing the data of the task
        """
        return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
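
For reference, a minimal usage sketch follows. It is illustrative only: `task` stands in for an already-built OFA task object exposing `bpe`, `src_dict` and `tgt_dict` (none of which are defined in this file), and the single in-memory row mimics the (uniq_id, base64-encoded image, hypothesis, caption, label) layout that `__getitem__` unpacks.

from io import BytesIO
import base64
from PIL import Image

# build a tiny base64-encoded image so the row matches what __getitem__ decodes
buf = BytesIO()
Image.new("RGB", (32, 32), color="white").save(buf, format="PNG")
b64_image = base64.urlsafe_b64encode(buf.getvalue()).decode("utf-8")

# one (uniq_id, image, hypothesis, caption, label) row; a plain list stands in
# for the file-backed dataset used elsewhere in the repo
rows = [("0", b64_image, "two dogs are running on the beach", "dogs playing outside", "entailment")]

snli_ve = SnliVeDataset(
    split="valid",
    dataset=rows,
    bpe=task.bpe,              # placeholder: comes from the OFA task setup
    src_dict=task.src_dict,    # placeholder
    tgt_dict=task.tgt_dict,    # placeholder
    patch_image_size=224,
    prompt_type="none",
)

sample = snli_ve[0]                 # dict with "source", "patch_image", "target", ...
batch = snli_ve.collater([sample])  # padded mini-batch as built by collate()
print(batch["net_input"]["src_tokens"].shape, batch["target"].shape)

The batch layout (id / nsentences / ntokens / net_input / target) follows the fairseq convention that the surrounding OFA training code expects, which is why the collater can be passed straight to a standard DataLoader as its collate_fn.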