chrisc36 commited on
Commit
c106b4d
1 Parent(s): 68e0611

Delete multimodal_preprocessor.py

Browse files
Files changed (1) hide show
  1. multimodal_preprocessor.py +0 -1549
multimodal_preprocessor.py DELETED
@@ -1,1549 +0,0 @@
1
- import dataclasses
2
- import logging
3
- import re
4
- from collections import defaultdict
5
- from typing import Tuple, Optional, Any, Dict, List, Union, Mapping
6
-
7
- import einops
8
- import seqio
9
- import numpy as np
10
- import tensorflow as tf
11
-
12
- from .mm_data import seqio_tokenizer
13
- from .data_utils import pad_to_bounding_box, \
14
- get_3d_subsegments, _append_to_innermost_axis, resize_and_pad, \
15
- apply_with_random_selector, get_special_token_ids, make_autoregressive_inputs, \
16
- trim_and_pad_dataset, assert_not_truncated
17
- from .prompts import apply_keyword_prompt, STYLE_TO_GENERAL_PROMPT, GENERAL_PROMPTS_V1
18
- import .constants as config
19
-
20
-
21
- def siglip_resize(src, imgsize, truncate):
22
- """Resize and preprocess for SigLIP ViT in the offical jax implementation"""
23
- assert src.dtype == tf.uint8
24
- # SigCLIP removes aspect ratio by default
25
- resized = tf.image.resize(src, imgsize, method=tf.image.ResizeMethod.BILINEAR, antialias=False)
26
- dtype = src.dtype
27
- tf_dtype = tf.type_spec_from_value(src).dtype
28
- resized = tf.cast(tf.clip_by_value(resized, tf_dtype.min, tf_dtype.max), dtype)
29
-
30
- # Normalize between -1 and 1 without using imagenet standard mean/std
31
- vmin=-1; vmax=1; in_min=0; in_max=255.0
32
- in_min_t = tf.constant(in_min, tf.float32)
33
- in_max_t = tf.constant(in_max, tf.float32)
34
- image = tf.cast(resized, tf.float32)
35
- image = (image - in_min_t) / (in_max_t - in_min_t)
36
- image = vmin + image * (vmax - vmin)
37
- if truncate:
38
- image = image[:truncate, :truncate]
39
- return image
40
-
41
-
42
- def extract_bboxes(text, image_w, image_h):
43
- points = extract_points(text, image_w, image_h)
44
- boxes = []
45
- for i in range(len(points)//2):
46
- x1, y1 = points[i*2]
47
- x2, y2 = points[i*2 + 1]
48
- boxes.append([x1, y1, x2, y2])
49
- return boxes
50
-
51
-
52
- def extract_annotated_points(caption, image_w, image_h):
53
- points = []
54
- for match in re.finditer("<point x=\"([0-9\\.]*)\" y=\"([0-9\\.]*)\" alt=\"([^\"]*)\">", caption):
55
- x = float(match.group(1))
56
- y = float(match.group(2))
57
- points.append(([[x, y]], match.group(3)))
58
- for match in re.finditer("<points ([^<]*) alt=\"([^\"]*)\">", caption):
59
- loc_str = match.group(1)
60
- locations = defaultdict(dict)
61
- if loc_str.startswith("points="):
62
- point_grp = []
63
- for point_match in re.finditer(r"([0-9]+\.[0-9]),? ([0-9]+\.[0-9])", loc_str):
64
- try:
65
- point = [float(point_match.group(i)) for i in range(1, 3)]
66
- point_grp.append(point)
67
- except ValueError:
68
- pass
69
- else:
70
- for val in loc_str.split():
71
- try:
72
- key, val = val.split("=")
73
- locations[key[1:]][key[:1]] = float(val.strip("\""))
74
- except ValueError:
75
- import pdb; pdb.set_trace()
76
- logging.warning(f"Failed to parse {val} from {match.group(0)}")
77
- point_grp = []
78
- for key, coords in locations.items():
79
- if sorted(coords) == ["x", "y"]:
80
- point_grp.append([coords["x"], coords["y"]])
81
- if point_grp:
82
- points.append((point_grp, match.group(2)))
83
-
84
- normalized = []
85
- for point_grp, point_text in points:
86
- normalized.append((
87
- np.array(point_grp) / 100.0 * np.array([image_w, image_h]),
88
- point_text,
89
- ))
90
- return normalized
91
-
92
-
93
- def extract_points(text, image_w, image_h):
94
- all_points = []
95
- for match in re.finditer(r"Click\(([0-9]+\.[0-9]), ?([0-9]+\.[0-9])\)", text):
96
- try:
97
- point = [float(match.group(i)) for i in range(1, 3)]
98
- except ValueError:
99
- pass
100
- else:
101
- point = np.array(point)
102
- if np.max(point) > 100:
103
- # Treat as an invalid output
104
- continue
105
- point /= 100.0
106
- point = point * np.array([image_w, image_h])
107
- all_points.append(point)
108
-
109
- for match in re.finditer(r"\(([0-9]+\.[0-9]),? ?([0-9]+\.[0-9])\)", text):
110
- try:
111
- point = [float(match.group(i)) for i in range(1, 3)]
112
- except ValueError:
113
- pass
114
- else:
115
- point = np.array(point)
116
- if np.max(point) > 100:
117
- # Treat as an invalid output
118
- continue
119
- point /= 100.0
120
- point = point * np.array([image_w, image_h])
121
- all_points.append(point)
122
- for match in re.finditer(r'x\d*="\s*([0-9]+(?:\.[0-9]+)?)"\s+y\d*="\s*([0-9]+(?:\.[0-9]+)?)"', text):
123
- try:
124
- point = [float(match.group(i)) for i in range(1, 3)]
125
- except ValueError:
126
- pass
127
- else:
128
- point = np.array(point)
129
- if np.max(point) > 100:
130
- # Treat as an invalid output
131
- continue
132
- point /= 100.0
133
- point = point * np.array([image_w, image_h])
134
- all_points.append(point)
135
- for match in re.finditer(r'(?:\d+|p)\s*=\s*([0-9]{3})\s*,\s*([0-9]{3})', text):
136
- try:
137
- point = [int(match.group(i)) / 10.0 for i in range(1, 3)]
138
- except ValueError:
139
- pass
140
- else:
141
- point = np.array(point)
142
- if np.max(point) > 100:
143
- # Treat as an invalid output
144
- continue
145
- point /= 100.0
146
- point = point * np.array([image_w, image_h])
147
- all_points.append(point)
148
- return all_points
149
-
150
-
151
- def extract_points_from_point_count(text, image_w, image_h):
152
- all_points = []
153
- points = re.findall(r"(\d+\.\d+),\s*(\d+\.\d+)", text)
154
-
155
- for match in points:
156
- try:
157
- point = [float(match[0]), float(match[1])]
158
- except ValueError:
159
- pass
160
- else:
161
- point = np.array(point)
162
- if np.max(point) > 100:
163
- # Treat as an invalid output
164
- continue
165
- point = point * np.array([image_w, image_h])
166
- all_points.append(point)
167
- return all_points
168
-
169
-
170
- def select_tiling(h, w, patch_size, max_num_patches):
171
- """Decide how best to divide in image of size [w, h] in up to max_num_patches of size patch_size"""
172
- original_size = tf.stack([h, w]) # [1, 2]
173
- original_res = h * w
174
- tilings = []
175
- for i in range(1, max_num_patches+1):
176
- for j in range(1, max_num_patches+1):
177
- if i*j <= max_num_patches:
178
- tilings.append((i, j))
179
- # sort so argmin and argmax favour smaller tilings in the event of a tie
180
- tilings.sort(key=lambda x: (x[0]*x[1], x[0]))
181
- candidate_tilings = tf.constant(tilings, dtype=tf.int32) # [n_resolutions, 2]
182
- candidate_resolutions = candidate_tilings * patch_size # [n_resolutions, 2]
183
-
184
- # How much we would need to scale the image to fit exactly in each tiling
185
- required_scale_d = tf.cast(candidate_resolutions, tf.float32) / tf.cast(original_size[None, :], tf.float32)
186
- required_scale = tf.reduce_min(required_scale_d, axis=-1, keepdims=True) # [n_resolutions, 1]
187
- if tf.reduce_all(required_scale < 1):
188
- # We are forced to downscale, so try to minimize the amount of downscaling
189
- ix = tf.argmax(required_scale)[0]
190
- else:
191
- # Pick the resolution that required the least upscaling so that it most closely fits the image
192
- required_scale = tf.where(required_scale < 1.0, 10e9, required_scale)
193
- ix = tf.argmin(required_scale)[0]
194
- return candidate_tilings[ix]
195
-
196
-
197
- DEMO_STYLES = [
198
- "point_count",
199
- "pointing",
200
- "user_qa",
201
- "scifi_charts_exp",
202
- "scifi_charts_exp",
203
- "scifi_charts_exp",
204
- "scifi_charts_exp",
205
- "long_caption",
206
- "named_entity"
207
- ]
208
-
209
-
210
- @dataclasses.dataclass
211
- class MultiModalPreprocessor:
212
- """Turns text/image inputs into tensors that can be input to the model"""
213
- tokenizer: Any
214
-
215
- # How to prompt the model
216
- prompt_templates: str = "none" # How to template prompts for examples
217
- message_format: str = "none" # How to format messages
218
- system_prompt: Optional[str] = None # How to generate system prompts
219
- prompt_override: Optional[str] = None # Used for setting prompt manually
220
- always_start_with_space: bool = False # Always include a leading space for the first bit of text
221
- default_inference_len: int = 65 # Inference len for length-conditioned prompting
222
-
223
- # How to crops/resize images
224
- crop_mode: str = "resize"
225
- max_crops: int = 6
226
- overlap_margins: Tuple[int, int] = (4, 4)
227
- do_random_scale: Optional[bool] = False
228
- resize: str = "default"
229
- random_scale_max: float = 1.1
230
- random_scale_min: float = 0.9
231
- random_scale_ratio: float = 0.5
232
- use_col_tokens: bool = True
233
-
234
- # Data about the ViT and connector we need when deciding the crops
235
- base_image_input_size: Tuple[int, int] = (336, 336)
236
- image_token_length_w: int = 12
237
- image_token_length_h: int = 12
238
- image_patch_size: int = 14
239
- image_padding_mask: bool = False
240
-
241
- # Other settings
242
- loss_token_weighting: Optional[str] = None
243
- unconditioned: Union[bool, float] = False # Ignore images
244
- fix_image_input_idx: int = 2 # backwards compatibility fix
245
- pad_to: Optional[int] = None # experimental feature
246
-
247
- _special_tokens: Dict[str, int] = None
248
- split_at: Optional[int] = None
249
-
250
- def get_max_total_crops(self):
251
- if self.crop_mode == "resize":
252
- return 1
253
- elif "resize" in self.crop_mode:
254
- return 1 + self.max_crops
255
- else:
256
- return self.max_crops
257
-
258
- @property
259
- def image_num_patch(self):
260
- h, w = self.base_image_input_size
261
- return h//self.image_patch_size, w//self.image_patch_size
262
-
263
- @property
264
- def special_token_ids(self):
265
- if self._special_tokens is None:
266
- self._special_tokens = get_special_token_ids(self.tokenizer)
267
- return self._special_tokens
268
-
269
- def image_to_patches_and_tokens(self, image, is_training):
270
- """Preprocesses an image
271
-
272
- Args:
273
- image: [h, w, 3] image to preprocessing
274
- Returns:
275
- crops: (n_crops, n_patches, patch_dim) individual crops, `n_crops` might
276
- change between images but the other dimension are fixed
277
- tokens: (n_tokens,) tf.int32 tokens, pad tokens indicate where to insert the
278
- patch features, might include other special tokens as well
279
- patch_ordering: (n_crops, n_tokens_per_crop) order image features should be inserted
280
- into the `tokens`, negative values indicates patches features to exclude
281
- padding_mask: (n_crops, h, w) mask of what pixels are padding, can be None
282
- """
283
- do_random_scale = self.do_random_scale
284
- if do_random_scale:
285
- do_random_scale = is_training
286
-
287
- base_image_input_size = self.base_image_input_size
288
- if isinstance(base_image_input_size, int):
289
- base_image_input_size = (base_image_input_size, base_image_input_size)
290
-
291
- image_token_length_w, image_token_length_h = self.image_token_length_w, self.image_token_length_h
292
- base_image_input_d = self.image_patch_size
293
- tokens_per_image = image_token_length_w * image_token_length_h
294
- image_base_patch_w = base_image_input_size[1] // base_image_input_d
295
- image_base_patch_h = base_image_input_size[0] // base_image_input_d
296
- extra_image = False
297
- patch_ordering = None
298
-
299
- if self.resize == "default":
300
- image = tf.image.convert_image_dtype(image, dtype=tf.float32)
301
- def _resize(_image, sz):
302
- return resize_and_pad(
303
- _image, sz,
304
- do_random_scale=do_random_scale,
305
- random_scale_max=self.random_scale_max,
306
- random_scale_min=self.random_scale_min,
307
- random_scale_ratio=self.random_scale_ratio,
308
- return_outputs=False,
309
- resize_method='random' if is_training else tf.image.ResizeMethod.BILINEAR)
310
- elif self.resize == "stretch":
311
- image = tf.image.convert_image_dtype(image, dtype=tf.float32)
312
- assert not do_random_scale
313
-
314
- def _resize(_image, sz):
315
- if not is_training:
316
- img = tf.image.resize(_image, sz, antialias=True, method=tf.image.ResizeMethod.BILINEAR)
317
- else:
318
- resize_methods = sorted([k for k in tf.image.ResizeMethod.__dict__.keys() if k.isupper()])
319
- img = apply_with_random_selector(
320
- _image,
321
- lambda x, method_idx: tf.image.resize(x, sz,
322
- tf.image.ResizeMethod.__dict__[resize_methods[method_idx]],
323
- antialias=True),
324
- num_cases=len(resize_methods))
325
- return img, tf.ones(tf.shape(img)[:2], dtype=tf.bool)
326
- elif self.resize in "siglip":
327
- assert not do_random_scale
328
-
329
- def _resize(_image, sz):
330
- img = siglip_resize(_image, sz, truncate=None)
331
- return img, tf.ones(tf.shape(img)[:2], dtype=tf.bool)
332
- else:
333
- raise NotImplementedError(self.resize)
334
-
335
- def _img_to_patches(_img, _img_mask, dy=1, dx=1):
336
- _img = einops.rearrange(
337
- _img, '(dy h dh) (dx w dw) c -> (dy dx) (h w) (dh dw c)',
338
- dh=base_image_input_d,
339
- dw=base_image_input_d,
340
- dy=dy,
341
- dx=dx,
342
- h=image_base_patch_h,
343
- w=image_base_patch_w
344
- )
345
- _img_mask = einops.rearrange(
346
- _img_mask, '(dy h dh) (dx w dw) -> (dy dx) (h w) (dh dw)',
347
- dh=base_image_input_d,
348
- dw=base_image_input_d,
349
- dy=dy,
350
- dx=dx,
351
- h=image_base_patch_h,
352
- w=image_base_patch_w
353
- )
354
- return _img, tf.reduce_mean(tf.cast(_img_mask, tf.float32), -1)
355
-
356
- mode = self.crop_mode
357
- if mode == "resize":
358
- patches, img_mask = _resize(image, base_image_input_size)
359
- patches, img_mask = _img_to_patches(patches, img_mask)
360
- image_layout_impatch_w = 1
361
- image_layout_impatch_h = 1
362
- patch_ordering = tf.range(tokens_per_image)[None, :]
363
-
364
- elif mode in ["overlap", "overlap-and-resize-c2"]:
365
- original_image_h = tf.shape(image, out_type=tf.int32)[0]
366
- original_image_w = tf.shape(image, out_type=tf.int32)[1]
367
- crop_size = base_image_input_size[0]
368
-
369
- # Discard this many patches from the (left/top, right/bottom) of crops
370
- left_margin, right_margin = self.overlap_margins
371
- # left_margin, right_margin = 2, 2
372
- assert left_margin % 2 == 0 # Required for compatibility with 2x2 pooling
373
- total_margin_pixels = base_image_input_d*(right_margin + left_margin) # pixels removed per dim
374
- crop_patches = base_image_input_size[0] // base_image_input_d # patches per crop dim
375
- crop_window_patches = crop_patches - (right_margin + left_margin) # usable patches
376
- crop_window_size = crop_window_patches * base_image_input_d
377
- tiling = select_tiling(original_image_h - total_margin_pixels, original_image_w - total_margin_pixels,
378
- crop_window_size, self.max_crops)
379
- src, img_mask = _resize(
380
- image, [tiling[0]*crop_window_size+total_margin_pixels, tiling[1]*crop_window_size+total_margin_pixels])
381
-
382
- n_crops = tiling[0]*tiling[1]
383
- patches_arr = tf.TensorArray(
384
- tf.float32, n_crops, element_shape=[crop_size, crop_size, 3])
385
- mask_arr = tf.TensorArray(
386
- tf.bool, n_crops, element_shape=[crop_size, crop_size])
387
- # We assume 2x2 pooling, but can allow padding the right/bottom with extra
388
- # patches if the number of patches per side is not even
389
- assert (crop_patches+1)//2 == image_token_length_h
390
- assert (crop_patches+1)//2 == image_token_length_w
391
- patch_ordering_arr = tf.TensorArray(
392
- tf.int32, n_crops, element_shape=[image_token_length_h, image_token_length_w])
393
- on = 0
394
- on_patch = 0
395
- for i in range(tiling[0]):
396
- y0 = i*crop_window_size
397
- if i == 0:
398
- crop_y0 = 0
399
- else:
400
- crop_y0 = left_margin // 2
401
-
402
- crop_h = image_base_patch_h - (right_margin + left_margin)
403
- if i == 0:
404
- crop_h += left_margin
405
- if i == (tiling[0]-1):
406
- crop_h += right_margin
407
- for j in range(tiling[1]):
408
- x0 = j*crop_window_size
409
- if j == 0:
410
- crop_x0 = 0
411
- else:
412
- crop_x0 = left_margin // 2
413
-
414
- crop_w = image_base_patch_w - (right_margin + left_margin)
415
- if j == 0:
416
- crop_w += left_margin
417
- if j == (tiling[1]-1):
418
- crop_w += right_margin
419
-
420
- pooled_w = (crop_w + 1) // 2
421
- pooled_h = (crop_h + 1) // 2
422
- patch_ordering_arr = patch_ordering_arr.write(
423
- on_patch,
424
- pad_to_bounding_box(
425
- tf.reshape(tf.range(on, on+pooled_h*pooled_w, dtype=tf.int32), (pooled_h, pooled_w, 1)),
426
- crop_y0, crop_x0, image_token_length_h, image_token_length_w, value=-1
427
- )[:, :, 0]
428
- )
429
- patches_arr = patches_arr.write(on_patch, src[y0:y0+crop_size, x0:x0+crop_size])
430
- mask_arr = mask_arr.write(on_patch, img_mask[y0:y0+crop_size, x0:x0+crop_size])
431
-
432
- on += pooled_h*pooled_w
433
- on_patch += 1
434
- patches = patches_arr.stack()
435
- patch_ordering = patch_ordering_arr.stack()
436
- img_mask = mask_arr.stack()
437
-
438
- image_layout_impatch_w, image_layout_impatch_h = tiling[0], tiling[1]
439
- patches = einops.rearrange(
440
- patches, 'p (h dh) (w dw) c -> p (h w) (dh dw c)',
441
- dh=base_image_input_d,
442
- dw=base_image_input_d,
443
- h=image_base_patch_h,
444
- w=image_base_patch_w
445
- )
446
- img_mask = einops.rearrange(
447
- img_mask, 'p (h dh) (w dw) -> p (h w) (dh dw)',
448
- dh=base_image_input_d,
449
- dw=base_image_input_d,
450
- h=image_base_patch_h,
451
- w=image_base_patch_w
452
- )
453
- img_mask = tf.reduce_mean(tf.cast(img_mask, tf.float32), -1)
454
- patch_ordering = tf.reshape(patch_ordering, [-1])
455
- valid = patch_ordering >= 0
456
-
457
- # Transpose, to get left-to-right order
458
- patch_ordering_rh = tf.reshape(patch_ordering,
459
- [tiling[0], tiling[1], image_token_length_h, image_token_length_w])
460
- patch_ordering_rh = tf.transpose(patch_ordering_rh, [0, 2, 1, 3])
461
- patch_ordering_rh = tf.reshape(patch_ordering_rh, [-1])
462
-
463
- # The tranpose will screw up which patches are masked, project the
464
- # new order into sparse structure of `patch_ordering` to fix this
465
- patch_ordering = tf.tensor_scatter_nd_update(
466
- patch_ordering,
467
- tf.where(valid),
468
- tf.boolean_mask(patch_ordering_rh, patch_ordering_rh >= 0),
469
- name="patch_order_transpose_Scatter"
470
- )
471
-
472
- h = tiling[0]*crop_window_patches + (right_margin+left_margin)
473
- w = tiling[1]*crop_window_patches + (right_margin+left_margin)
474
- special_token_ids = self.special_token_ids
475
- per_row = tf.fill(((w+1)//2,),
476
- special_token_ids[config.DEFAULT_IMAGE_PATCH_TOKEN],)
477
- if self.use_col_tokens:
478
- per_row = tf.concat([per_row, [special_token_ids[config.DEFAULT_IM_COL_TOKEN]]], 0)
479
-
480
- joint = tf.tile(per_row, [(h+1)//2])
481
- joint = [
482
- [special_token_ids[config.DEFAULT_IM_START_TOKEN]],
483
- joint,
484
- [special_token_ids[config.DEFAULT_IM_END_TOKEN]]
485
- ]
486
-
487
- if "resize" in mode:
488
- resized, resized_mask = _resize(image, base_image_input_size)
489
- resized, resized_mask = _img_to_patches(resized, resized_mask)
490
- if 'c2' in mode:
491
- patches = tf.concat([resized, patches], 0)
492
- image_mask = tf.concat([resized_mask, img_mask], 0)
493
- else:
494
- patches = tf.concat([patches, resized], 0)
495
- image_mask = tf.concat([img_mask, resized_mask], 0)
496
-
497
- if patch_ordering is not None:
498
- if 'c2' in mode:
499
- patch_ordering = tf.where(
500
- patch_ordering >= 0,
501
- patch_ordering + tokens_per_image,
502
- -1
503
- )
504
- patch_ordering = tf.concat([tf.range(0, tokens_per_image), patch_ordering], 0)
505
- else:
506
- raise ValueError()
507
- per_row = tf.fill((image_token_length_w,), special_token_ids[config.DEFAULT_IMAGE_PATCH_TOKEN],)
508
- if self.use_col_tokens:
509
- per_row = tf.concat([per_row, [special_token_ids[config.DEFAULT_IM_COL_TOKEN]]], 0)
510
- extra_tokens = tf.tile(per_row, [image_token_length_h])
511
- joint = [
512
- [special_token_ids[config.DEFAULT_IM_START_TOKEN]],
513
- extra_tokens,
514
- [special_token_ids[config.DEFAULT_IM_END_TOKEN]],
515
- ] + joint
516
-
517
- joint = tf.concat(joint, 0)
518
- return patches, joint, patch_ordering, img_mask
519
-
520
- elif mode in ["patchify", "patchify-and-resize", "patchify-v2", "patchify-v2-and-resize", "patchify-v2-and-resize-c2"]:
521
- original_image_w = tf.shape(image, out_type=tf.int32)[0]
522
- original_image_h = tf.shape(image, out_type=tf.int32)[1]
523
- assert base_image_input_size[0] == base_image_input_size[1]
524
- base_patch_size = base_image_input_size[0]
525
- tiling = select_tiling(original_image_w, original_image_h, base_patch_size, self.max_crops)
526
-
527
- patches, img_mask = _resize(
528
- image, [tiling[0]*base_patch_size, tiling[1]*base_patch_size])
529
- patches, img_mask = _img_to_patches(patches, img_mask, tiling[0], tiling[1])
530
- if 'v2' in mode:
531
- # Order patches left-to-right not crop-by-crop
532
- patch_ordering = tf.reshape(
533
- tf.range(tokens_per_image*tiling[0]*tiling[1]),
534
- [tiling[0], tiling[1], image_token_length_w, image_token_length_h])
535
- patch_ordering = tf.transpose(patch_ordering, [0, 2, 1, 3])
536
- patch_ordering = tf.reshape(patch_ordering, (-1, tokens_per_image))
537
- else:
538
- patch_ordering = None
539
-
540
- # given image size, determine the number of patch size.
541
- image_layout_impatch_w = tiling[0]
542
- image_layout_impatch_h = tiling[1]
543
-
544
- if "resize" in mode:
545
- extra_image = True
546
- resized, resized_mask = _resize(image, base_image_input_size)
547
- resized, resized_mask = _img_to_patches(resized, resized_mask)
548
- if 'c2' in mode:
549
- patches = tf.concat([resized, patches], 0)
550
- image_mask = tf.concat([resized_mask, img_mask], 0)
551
- else:
552
- patches = tf.concat([patches, resized], 0)
553
- image_mask = tf.concat([img_mask, resized_mask], 0)
554
-
555
- if patch_ordering is not None:
556
- if 'c2' in mode:
557
- patch_ordering = tf.concat(
558
- [tf.range(0, tokens_per_image)[None, :], patch_ordering+tokens_per_image], 0)
559
- else:
560
- n = tf.shape(patch_ordering)[0]
561
- patch_ordering = tf.concat(patch_ordering, [tf.range(n, n+tokens_per_image)[None, :]], 0)
562
- else:
563
- raise NotImplementedError(mode)
564
-
565
- special_token_ids = self.special_token_ids
566
-
567
- per_row = tf.fill((image_token_length_w*image_layout_impatch_w,),
568
- special_token_ids[config.DEFAULT_IMAGE_PATCH_TOKEN],)
569
- if self.use_col_tokens:
570
- per_row = tf.concat([per_row, [special_token_ids[config.DEFAULT_IM_COL_TOKEN]]], 0)
571
-
572
- joint = tf.tile(per_row, [image_token_length_h * image_layout_impatch_h])
573
- joint = [
574
- [special_token_ids[config.DEFAULT_IM_START_TOKEN]],
575
- joint,
576
- [special_token_ids[config.DEFAULT_IM_END_TOKEN]]
577
- ]
578
- if extra_image:
579
- assert not self.image_padding_mask
580
- per_row = tf.fill((image_token_length_w,), special_token_ids[config.DEFAULT_IMAGE_PATCH_TOKEN],)
581
- if self.use_col_tokens:
582
- per_row = tf.concat([per_row, [special_token_ids[config.DEFAULT_IM_COL_TOKEN]]], 0)
583
- extra_tokens = tf.tile(per_row, [image_token_length_h])
584
- if 'c2' in mode:
585
- joint = [
586
- [special_token_ids[config.DEFAULT_IM_START_TOKEN]],
587
- extra_tokens,
588
- [special_token_ids[config.DEFAULT_IM_END_TOKEN]],
589
- ] + joint
590
- else:
591
- joint += [
592
- [special_token_ids[config.DEFAULT_IM_START_TOKEN]],
593
- extra_tokens,
594
- [special_token_ids[config.DEFAULT_IM_END_TOKEN]]
595
- ]
596
- if self.pad_to is not None:
597
- n = [tf.shape(x)[0] for x in joint]
598
- assert len(joint[-1]) == 1
599
- to_pad = self.pad_to - tf.reduce_sum(tf.stack(n))
600
- joint = tf.concat(joint[:-1] + [
601
- tf.zeros(to_pad, dtype=tf.int32) - 1,
602
- joint[-1]
603
- ], axis=0)
604
- else:
605
- joint = tf.concat(joint, 0)
606
- return patches, tf.concat(joint, 0), patch_ordering, img_mask
607
-
608
- def build_image_input_idx(self, input_tokens, patch_order, no_image=None):
609
- """Builds the index used to insert patch features into `input_tokens`"""
610
- tokens_per_image = self.image_token_length_w * self.image_token_length_h
611
- if no_image is not None and no_image:
612
- return tf.zeros((0, tokens_per_image), tf.int32)
613
-
614
- image_input_idx = input_tokens == self.special_token_ids[config.DEFAULT_IMAGE_PATCH_TOKEN]
615
- image_input_idx = tf.experimental.numpy.nonzero(image_input_idx)[0]
616
- image_input_idx = tf.cast(image_input_idx, tf.int32)
617
-
618
- if patch_order is not None:
619
- n_tokens = tf.shape(image_input_idx)[0]
620
- # Item N should have the value of image_input_index[where(patch_order == n)] if >= 0 else -1
621
- patch_order = tf.reshape(patch_order, [-1])
622
- n_patches = tf.shape(patch_order)[0]
623
- if n_tokens != n_patches:
624
- # Most complex case where some patches are dropped
625
- # First invert the valid tokens
626
- valid = patch_order >= 0
627
- sorted_patch_ixs = tf.scatter_nd(
628
- tf.boolean_mask(patch_order, valid)[:, None],
629
- tf.range(tf.reduce_sum(tf.cast(valid, tf.int32)), dtype=tf.int32),
630
- [n_tokens],
631
- name="valid_order_scatter"
632
- )
633
-
634
- # Project the inverted mapping into same sparse structure
635
- tmp = tf.fill(tf.shape(patch_order), -1)
636
- sorted_patch_ixs_ex = tf.tensor_scatter_nd_update(
637
- tmp,
638
- tf.where(valid),
639
- sorted_patch_ixs,
640
- name="order_with_padding_scatter"
641
- )
642
-
643
- # Do the gather and then re-masked outputs that were masked in `sorted_patch_ixs`
644
- valid = tf.cast(sorted_patch_ixs_ex >= 0, tf.int32)
645
- image_input_idx = tf.gather(image_input_idx, sorted_patch_ixs_ex*valid)
646
- image_input_idx = image_input_idx*valid - 100*(1 - valid)
647
- else:
648
- sorted_patch_ixs = tf.scatter_nd(patch_order[:, None], tf.range(n_patches), [n_patches])
649
- image_input_idx = tf.gather(tf.reshape(image_input_idx, [-1]), sorted_patch_ixs)
650
- image_input_idx = tf.reshape(image_input_idx, [-1, tokens_per_image])
651
- return image_input_idx
652
-
653
- def build_multimodel_features(self, tokens, mask, subsegments, images, is_training):
654
- """Builds input features by pre-processing `images` and modifying `tokens`
655
- to include image col/pad/start/end tokens instead image placeholder tokens
656
- """
657
- image_token_id = self.special_token_ids[config.IMAGE_PROMPT]
658
- image_idx = tf.experimental.numpy.nonzero(tokens == image_token_id)[0]
659
- if images is None or tf.shape(images)[0] == 0:
660
- tf.debugging.assert_equal(image_idx, tf.cast(0, tf.int64),
661
- "Image placeholders in input, but no images given!")
662
- tokens_per_image = self.image_token_length_w * self.image_token_length_h
663
- n_pixels = self.image_patch_size ** 2 * 3
664
- image_num_patch = np.prod(self.image_num_patch)
665
- crops = tf.zeros((0, image_num_patch, n_pixels), dtype=tf.float32)
666
- image_idx = tf.zeros((0, tokens_per_image), tf.int32)
667
- out = dict(
668
- target_tokens=tokens,
669
- images=crops,
670
- image_input_idx=image_idx,
671
- loss_masks=mask
672
- )
673
- if self.image_padding_mask:
674
- out["image_masks"] = tf.zeros((0, image_num_patch), dtype=tf.float32)
675
- if subsegments is not None:
676
- out["subsegment_ids"] = subsegments
677
- return out
678
- elif tf.shape(image_idx)[0] == 0 and tf.shape(images)[0] > 0:
679
- # As a special case, no image prompt means the images are all at the start
680
- image_idx = tf.zeros([tf.shape(images)[0]], tf.int64) - 1
681
- else:
682
- tf.debugging.assert_equal(
683
- tf.shape(images)[0], tf.shape(image_idx)[0],
684
- message="Different number of images and image placeholders")
685
-
686
- # Each image will produce a variable number of crops/tokens, so we aggregate things
687
- # the results tensor arrays and the concat them
688
- tokens_per_image = self.image_token_length_w * self.image_token_length_h
689
- n_pixels = self.image_patch_size*self.image_patch_size*3
690
- n_patches = self.image_num_patch[0]*self.image_num_patch[1]
691
-
692
- n = tf.shape(images)[0]
693
- all_crops = tf.TensorArray(dtype=tf.float32, size=n, infer_shape=False,
694
- element_shape=[None, n_patches, n_pixels])
695
- all_image_idx = tf.TensorArray(dtype=tf.int32, size=n, infer_shape=False,
696
- element_shape=[None, tokens_per_image])
697
- out_tokens = tf.TensorArray(dtype=tf.int32, size=n, infer_shape=False,
698
- element_shape=[None])
699
- out_masks = tf.TensorArray(dtype=tf.float32, size=n, infer_shape=False,
700
- element_shape=[None])
701
- if self.image_padding_mask:
702
- all_crop_masks = tf.TensorArray(dtype=tf.float32, size=n, infer_shape=False,
703
- element_shape=[None, None])
704
- else:
705
- # Dummy array to keep tensorflow's control analysis happy
706
- all_crop_masks = tf.TensorArray(dtype=tf.float32, size=0, infer_shape=False,
707
- element_shape=[None, None])
708
- if subsegments is not None:
709
- out_subsegments = tf.TensorArray(dtype=tf.int32, size=n, element_shape=[None])
710
- else:
711
- out_subsegments = tf.TensorArray(dtype=tf.int32, size=0, element_shape=[None])
712
-
713
- image_idx = tf.cast(image_idx, tf.int32)
714
- for ix in range(tf.shape(image_idx)[0]):
715
- token_ix = image_idx[ix]
716
- crops, image_tokens, patch_ordering, img_mask = self.image_to_patches_and_tokens(images[ix], is_training)
717
- patch_idx = self.build_image_input_idx(image_tokens, patch_ordering)
718
-
719
- if token_ix == -1: # -1 is an image inserted at the very start
720
- start = 0
721
- token_ix = 0
722
- end = 0
723
- else:
724
- start = 0 if ix == 0 else image_idx[ix-1] + 1
725
- end = token_ix + 1
726
-
727
- all_image_idx = all_image_idx.write(ix, patch_idx + token_ix)
728
- all_crops = all_crops.write(ix, crops)
729
- image_token_mask = tf.zeros_like(image_tokens, dtype=tf.float32)
730
-
731
- if ix == (tf.shape(images)[0] - 1):
732
- tokens_part = tf.concat([tokens[start:token_ix], image_tokens, tokens[end:]], 0)
733
- mask_part = tf.concat([mask[start:token_ix], image_token_mask, mask[end:]], 0)
734
- else:
735
- tokens_part = tf.concat([tokens[start:token_ix], image_tokens], 0)
736
- mask_part = tf.concat([mask[start:token_ix], image_token_mask], 0)
737
-
738
- out_tokens = out_tokens.write(ix, tokens_part)
739
- out_masks = out_masks.write(ix, mask_part)
740
- if self.image_padding_mask:
741
- all_crop_masks = all_crop_masks.write(ix, img_mask)
742
- if subsegments is not None:
743
- parts = tf.fill([tf.shape(image_tokens)[0]], subsegments[token_ix])
744
- if ix == (tf.shape(images)[0] - 1):
745
- seg = tf.concat([subsegments[start:token_ix], parts, subsegments[end:]], 0)
746
- else:
747
- seg = tf.concat([subsegments[start:token_ix], parts], 0)
748
- out_subsegments = out_subsegments.write(ix, seg)
749
-
750
- out = dict(
751
- target_tokens=out_tokens.concat(),
752
- images=all_crops.concat(),
753
- image_input_idx=all_image_idx.concat(),
754
- loss_masks=out_masks.concat()
755
- )
756
- if self.image_padding_mask:
757
- out["image_masks"] = all_crop_masks.concat()
758
- if subsegments is not None:
759
- out["subsegment_ids"] = out_subsegments.concat()
760
- return out
761
-
762
- def _format_message(self, args):
763
- message, ix = args
764
- return self.format_message(message, ix)
765
-
766
- def format_message(self, message, ix):
767
- """Applies system formatting to ith message from a sequence of messages"""
768
- # If the image placeholder text is not preceded by space it will not get tokenized
769
- # correctly by some tokenizers, so double check it here
770
- assert config.IMAGE_PROMPT == "<|image|>"
771
- tf.debugging.assert_equal(
772
- tf.strings.regex_full_match(message, r".*[^ ]<\|image\|>.*"),
773
- False,
774
- message="Image token must always be preceded by a space"
775
- )
776
- is_user = ix % 2 == 0
777
- if self.message_format == "none" or self.message_format is None:
778
- pass
779
- elif self.message_format == "role":
780
- if is_user:
781
- # We put the "System:" prefix here since it doesn't need a loss
782
- message = tf.strings.join(["User: ", message, " Assistant:"])
783
- elif self.message_format == "cleanup":
784
- if is_user:
785
- # We put the "System:" prefix here since it doesn't need a loss
786
- message = tf.strings.join(
787
- [
788
- "[[User]]: Correct the spelling and punctuation mistakes on the following transcript based on what appears in the image.\n\n{before} ",
789
- message,
790
- "\n[[Assistant]]: {after}"
791
- ]
792
- )
793
- elif self.message_format == "mistral":
794
- if is_user:
795
- message = tf.strings.join(["[INST] ", message, " [/INST]"])
796
- else:
797
- raise NotImplementedError(self.message_format)
798
-
799
- # For now assume a space will be used to separate the messages
800
- if not self.tokenizer.adds_space:
801
- if ix != 0 or self.always_start_with_space:
802
- message = tf.strings.join([" ", message])
803
- # Else space added automatically by the tokenizer
804
-
805
- return message
806
-
807
- def get_multi_message_token_input(self, conversations, text_weights=None):
808
- """Build inputs for a ragged tensor of conversations, where each row of the tensor,
809
- is a different conversation"""
810
- tf.debugging.assert_equal(tf.reduce_any(tf.strings.regex_full_match(
811
- conversations.values, re.escape(config.IMAGE_PROMPT))), False, "Segmented prompts must start with the image")
812
-
813
- n_conversation = tf.shape(conversations)[0]
814
- ar = tf.TensorArray(dtype=tf.int32, infer_shape=False, element_shape=[None],
815
- size=n_conversation)
816
- n_messages_per_conversation = conversations.row_lengths()
817
- for ix in range(n_conversation):
818
- ar = ar.write(ix, tf.range(n_messages_per_conversation[ix], dtype=tf.int32))
819
- message_ix = ar.concat()
820
- messages = tf.map_fn(
821
- self._format_message, elems=(conversations.values, message_ix), fn_output_signature=tf.string)
822
- messages = self.tokenizer.encode_tf(messages)
823
-
824
- # Append EOS
825
- is_response = message_ix % 2 == 1
826
- is_response_int = tf.cast(is_response, tf.int32)
827
- eos = tf.RaggedTensor.from_row_lengths(
828
- tf.fill([tf.reduce_sum(is_response_int)], self.tokenizer.eos_token_id),
829
- tf.cast(is_response_int, messages.row_splits.dtype)
830
- )
831
- messages = tf.concat([messages, eos], axis=1)
832
-
833
- # Build mask over system responses
834
- mask = tf.ones_like(messages) * tf.cast(tf.expand_dims(is_response, axis=1), tf.int32)
835
- decoder_loss_weights = tf.cast(mask.values, tf.float32)
836
-
837
- # Build subsegment ids for each conversation
838
- tokens_per_message = tf.RaggedTensor.from_row_splits(
839
- row_splits=conversations.row_splits,
840
- values=messages.row_lengths()
841
- )
842
- token_per_conversation = tf.reduce_sum(tokens_per_message, axis=1)
843
- subsegment_ids = tf.repeat(tf.range(n_conversation, dtype=tf.int32)+1, token_per_conversation)
844
-
845
- image_ix = self.special_token_ids[config.IMAGE_PROMPT]
846
- messages = tf.concat([[image_ix], messages.values], axis=0)
847
- decoder_loss_weights = tf.concat([[0], decoder_loss_weights], axis=0)
848
- subsegment_ids = tf.concat([[10000], subsegment_ids], axis=0)
849
- return messages, decoder_loss_weights, subsegment_ids
850
-
851
- def get_multi_response_token_input(self, user_prompt, text, text_weights=None):
852
- """Build tokens for a multi-response-per-image example"""
853
- # FIXME this could be relaxed to just having the same prefix
854
- tf.debugging.assert_equal(tf.reduce_any(tf.strings.regex_full_match(
855
- user_prompt, re.escape(config.IMAGE_PROMPT))), False, "Segmented prompts must start with the image")
856
- user_prompt = self.format_message(user_prompt, 0)
857
- vocab = self.tokenizer
858
- prompts = vocab.encode_tf(user_prompt)
859
- response = self.format_message(text, 1)
860
- responses = vocab.encode_tf(response)
861
- responses = _append_to_innermost_axis(responses, vocab.eos_token_id)
862
- response_mask = tf.ones_like(responses, dtype=tf.float32)
863
- if text_weights is not None:
864
- response_mask *= text_weights
865
- image_tokens = tf.constant([self.special_token_ids[config.IMAGE_PROMPT]])
866
-
867
- if len(responses.shape) == 3:
868
- # Tricky case where we have multiple questions, each of which has multiple answers
869
- assert len(prompts.shape) == 2
870
-
871
- # Also shift the last tokens to the response segment since that tokens will
872
- # have multiple possible target tokens to predict
873
- last_prompt_tokens = prompts[:, -1:]
874
- last_prompt_tokens = tf.repeat(last_prompt_tokens, responses.row_lengths())
875
- last_prompt_tokens = tf.RaggedTensor.from_row_splits(
876
- values=tf.RaggedTensor.from_row_lengths(
877
- values=last_prompt_tokens,
878
- row_lengths=tf.ones_like(last_prompt_tokens, dtype=responses.row_splits.dtype)
879
- ),
880
- row_splits=responses.row_splits
881
- )
882
- responses = tf.concat([last_prompt_tokens, responses], 2)
883
- prompts = prompts[:, :-1]
884
-
885
- shared_prefix = image_tokens
886
- segmented_suffix = tf.concat([tf.expand_dims(prompts, 1), responses], 1)
887
- targets = tf.concat([shared_prefix, segmented_suffix.values.values], 0)
888
-
889
- segmented_mask = tf.concat([
890
- tf.zeros_like(tf.expand_dims(prompts, 1), dtype=tf.float32),
891
- tf.concat([
892
- tf.zeros_like(last_prompt_tokens, dtype=tf.float32),
893
- response_mask
894
- ], 2)
895
- ], 1).values.values
896
- decoder_loss_weights = tf.concat(
897
- [tf.zeros_like(shared_prefix, dtype=tf.float32), segmented_mask], 0)
898
-
899
- text_segment_ids = get_3d_subsegments(segmented_suffix)
900
- subsegment_ids = tf.concat([
901
- tf.zeros_like(shared_prefix) + tf.reduce_max(text_segment_ids)+1,
902
- text_segment_ids], 0)
903
- subsegment_ids = tf.cast(subsegment_ids, tf.int32)
904
- else:
905
- if len(prompts.shape) == 1:
906
- # One prompt for all responses, we use the last token of the prompt as the
907
- # first token of each response segment since there will be multiple targets
908
- # for that token, the remaining targets are part of the prefix
909
- shared_prefix = tf.concat([image_tokens, prompts[:-1]], 0)
910
- prompts = prompts[-1:]
911
- prompts = tf.tile(tf.expand_dims(prompts, axis=0), [tf.shape(text)[0], 1])
912
- else:
913
- shared_prefix = image_tokens
914
-
915
- # Separate prompt for each response
916
- segmented_suffix = tf.concat([prompts, responses], 1)
917
- segmented_mask = tf.concat([tf.zeros_like(prompts, dtype=tf.float32), response_mask], 1).values
918
-
919
- targets = tf.concat([shared_prefix, segmented_suffix.values], 0)
920
- decoder_loss_weights = tf.concat(
921
- [tf.zeros_like(shared_prefix, dtype=tf.float32), segmented_mask], 0)
922
- subsegments = tf.ragged.row_splits_to_segment_ids(segmented_suffix.row_splits) + 1
923
- subsegment_ids = tf.concat([tf.zeros_like(shared_prefix)+10000,
924
- tf.cast(subsegments, tf.int32)], 0)
925
- return targets, decoder_loss_weights, subsegment_ids
926
-
927
- def get_tokens_input(self, messages, for_inference=False, text_weights=None):
928
- """Gets the token input for an example, using image placeholder tokens to
929
- indicate where images features should be inserted
930
-
931
- inputs
932
- messages: List or tensor users/system text messages, can have image placeholder tokens
933
- for_inference: bool, if true truncate the messages if it is a system message
934
- text_weights: Weights per a system message
935
-
936
- returns
937
- tokens: [n_tokens] tf.int32 token inputs with image placeholder tokens
938
- loss_mask: [n_tokens] tf.float32 token weights for loss
939
- subsegment: [n_tokens] tf.int32 or None, subsegment ids used to build more complex
940
- attention masks if needed
941
- """
942
- if isinstance(messages, tf.RaggedTensor):
943
- assert not for_inference, "Cannot have multiple target messages for inference"
944
- return self.get_multi_message_token_input(messages, text_weights)
945
- elif len(tf.shape(messages[-1])) > 0:
946
- assert not for_inference, "Cannot have multiple target messages for inference"
947
- assert len(messages) == 2
948
- prompt = messages[0]
949
- response = messages[1]
950
- return self.get_multi_response_token_input(prompt, response, text_weights)
951
- else:
952
- messages = tf.convert_to_tensor(messages)
953
- if for_inference:
954
- if tf.shape(messages) % 2 == 0:
955
- # Remove the last message since the model should predict it
956
- messages = messages[:-1]
957
-
958
- # Apply system formatting
959
- ix = tf.range(tf.shape(messages)[0])
960
- is_response = ix % 2 == 1
961
- messages = tf.map_fn(
962
- self._format_message, elems=(messages, ix), fn_output_signature=tf.string)
963
-
964
- # Tokenize
965
- messages = self.tokenizer.encode_tf(messages)
966
-
967
- # Add EOS to system messages
968
- is_response_int = tf.cast(is_response, tf.int32)
969
- eos = tf.RaggedTensor.from_row_lengths(
970
- tf.fill([tf.reduce_sum(is_response_int)], self.tokenizer.eos_token_id),
971
- tf.cast(is_response_int, messages.row_splits.dtype)
972
- )
973
- messages = tf.concat([messages, eos], axis=1)
974
- targets = messages.values
975
-
976
- # Build mask over system responses
977
- mask = tf.ones_like(messages) * tf.cast(tf.expand_dims(is_response, axis=1), tf.int32)
978
- decoder_loss_weights = tf.cast(mask.values, tf.float32)
979
- if text_weights is not None:
980
- decoder_loss_weights = decoder_loss_weights * text_weights
981
- return messages.values, decoder_loss_weights, None
982
-
983
- def preprocess(self, image, input_text, is_training=False,
984
- seq_len=None, pad_images=1, style=None, for_inference=True):
985
- """Get input tensors for the given image/text data
986
-
987
- image: [h, w, 3] numpy uint8 array of image pixels
988
- input_text: string input text, a list of text for a multi-turn conversation or dictionary
989
- of inputs to use to build the prompt from a template
990
- is_training: allow training-time preprocessing (e.g., image augmentation)
991
- seq_len: pad input tokens to `seq_len`
992
- pad_images: pad input images to `self.get_max_total_crops()`
993
- style: Style to use for prompt templating
994
- """
995
- if image is not None and len(tf.shape(image)) == 3:
996
- image = tf.expand_dims(image, axis=0)
997
-
998
- messages = self.get_messages(input_text, style, is_training, for_inference=for_inference, user_prompt_seed=None, system_prompt_seed=None)
999
- targets, loss_masks, subsegments = self.get_tokens_input(messages, for_inference=for_inference)
1000
- batch = self.build_multimodel_features(
1001
- targets, loss_masks, subsegments, image, is_training)
1002
-
1003
- # Optionally padding to get constant sized arrays
1004
- if pad_images:
1005
- max_crops = self.get_max_total_crops() * pad_images
1006
- image = batch["images"]
1007
- n = max_crops - tf.shape(batch["images"])[0]
1008
- batch["images"] = tf.pad(image, [[0, n], [0, 0], [0, 0]], constant_values=-1)
1009
- if self.image_padding_mask:
1010
- m = max_crops - tf.shape(batch["image_masks"])[0]
1011
- batch["image_masks"] = tf.pad(batch["image_masks"], [[0, m], [0, 0]], constant_values=-1)
1012
- batch["image_input_idx"] = tf.pad(batch["image_input_idx"], [[0, n], [0, 0]], constant_values=-1)
1013
-
1014
- if seq_len is not None:
1015
- targets = batch["target_tokens"]
1016
- if seq_len < len(targets):
1017
- raise ValueError("Sequence length too short")
1018
- n = seq_len - len(targets)
1019
- batch["target_tokens"] = tf.pad(targets, [[0, n]], constant_values=-1)
1020
- batch["loss_masks"] = tf.pad(batch["loss_masks"], [[0, n]], constant_values=-1)
1021
-
1022
- batch = self.get_post_mixing_preprocessor(pack=False)._convert_example(batch)
1023
- return batch
1024
-
1025
- def get_user_prompt(self, style, example, is_training=True, for_inference=False, seed=None):
1026
- """Build a list of strings of what a user might type in to the model for the given example,
1027
- and its responses, by applying a prompt template to the fields in `example`
1028
-
1029
- Can return multiple strings for one message for multi-response examples
1030
- """
1031
- if "style" in example:
1032
- style = example["style"]
1033
-
1034
- if "prompt" in example:
1035
- # Examples have a complete user prompt pre-specified, usually for eval sets
1036
- prompt = example["prompt"]
1037
-
1038
- elif self.prompt_templates == "none":
1039
- # Bare-bone prompt with not templating of instructions
1040
- if "prompt" in example:
1041
- prompt = example["prompt"]
1042
- elif "refexp" in example:
1043
- prompt = example["refexp"]
1044
- elif "question" in example and "options" in example:
1045
- prompt = tf.strings.join([example["question"], "\n", example["options"], "\n"])
1046
- elif "question" in example:
1047
- prompt = example["question"]
1048
- else:
1049
- prompt = ""
1050
-
1051
- elif self.prompt_templates == "uber_model":
1052
- if not isinstance(style, str):
1053
- tf.debugging.assert_equal(tf.logical_or(
1054
- style == "ai2_diagram_no_letter",
1055
- style == "ai2_diagram",
1056
- ), True)
1057
- prompt = tf.strings.join([example["question"], "\n", example["options"], "\n"])
1058
- else:
1059
- # We template long captions and pointing since they are "demo" tasks, and use
1060
- # plain text for everything else
1061
- if style == "long_caption":
1062
- prompt = apply_keyword_prompt(GENERAL_PROMPTS_V1["long_caption"], example, seed)
1063
- elif style == "pointing":
1064
- prompt = apply_keyword_prompt(GENERAL_PROMPTS_V1["pointing"], example, seed)
1065
- elif style == "point_count":
1066
- prompt = apply_keyword_prompt(GENERAL_PROMPTS_V1["point_count"], example, seed)
1067
- elif "prompt" in example:
1068
- prompt = example["prompt"]
1069
- elif "refexp" in example:
1070
- prompt = example["refexp"]
1071
- elif "question" in example and "options" in example:
1072
- prompt = tf.strings.join([example["question"], "\n", example["options"], "\n"])
1073
- elif "question" in example:
1074
- prompt = example["question"]
1075
- else:
1076
- prompt = ""
1077
-
1078
- elif self.prompt_templates == "uber_model_pointing":
1079
- if style == "long_caption":
1080
- long_captions = GENERAL_PROMPTS_V1["long_caption_no_pointing"]
1081
- prompt = apply_keyword_prompt(GENERAL_PROMPTS_V1["long_caption"], example, seed)
1082
- elif style == "pointing":
1083
- prompt = apply_keyword_prompt(GENERAL_PROMPTS_V1["pointing"], example, seed)
1084
- elif style in [
1085
- "scifi_charts_explanation",
1086
- "scifi_table_explanation",
1087
- "scifi_document_explanation",
1088
- "scifi_diagram_explanation",
1089
- "user_qa",
1090
- "long_caption",
1091
- ]:
1092
- raise NotImplementedError()
1093
- if style == "long_caption":
1094
- prompts = GENERAL_PROMPTS_V1["long_caption"]
1095
- elif "prompt" in example:
1096
- prompts = tf.expand_dims(example["prompt"], axis=0)
1097
- else:
1098
- prompts = tf.expand_dims(example["question"], axis=0)
1099
- suffixes = []
1100
- for suffix in GENERAL_PROMPTS_V1["no_pointing_suffix"]:
1101
- if not suffix[0].isspace():
1102
- suffix = " " + suffix
1103
- suffixes.append(suffix)
1104
- no_point_prompts = tf.reshape(tf.strings.join([
1105
- tf.tile(tf.expand_dims(suffixes, 1), [1, tf.shape(prompts)[1]]),
1106
- tf.tile(prompts, [len(suffixes), 1]),
1107
- ]), [-1])
1108
- # prefixes = []
1109
- # for prefix in GENERAL_PROMPTS_V1["no_pointing_prefix"]:
1110
- # if not prefix[0].isspace():
1111
- # prefix = prefix + " "
1112
- # prefixes.append(prompts + prefix)
1113
- prompt = apply_keyword_prompt(no_point_prompts, example, seed, keywords=[])
1114
- elif "prompt" in example:
1115
- prompt = example["prompt"]
1116
- elif "refexp" in example:
1117
- prompt = example["refexp"]
1118
- elif "question" in example and "options" in example:
1119
- prompt = tf.strings.join([example["question"], "\n", example["options"], "\n"])
1120
- elif "question" in example:
1121
- prompt = example["question"]
1122
- else:
1123
- prompt = ""
1124
-
1125
- elif self.prompt_templates == "general_instructions_v1":
1126
- if isinstance(style, str):
1127
- prompt = apply_keyword_prompt(GENERAL_PROMPTS_V1[STYLE_TO_GENERAL_PROMPT[style]], example, seed)
1128
- elif isinstance(style, list):
1129
- # This ia bit of hack to allow apply prompts to joint caption/transcript data
1130
- # FIXME ideally we can apply the templating to multiple styles more generally
1131
- def _apply(_style, ix):
1132
- tmp = dict(example)
1133
- # prevent apply_keyword_prompt for generating multiple templates
1134
- tmp["text"] = tmp["text"][0]
1135
- if _style == "long_caption":
1136
- return apply_keyword_prompt(GENERAL_PROMPTS_V1["long_caption"], tmp, seed)
1137
- elif _style == "transcript":
1138
- return apply_keyword_prompt(GENERAL_PROMPTS_V1["transcript"], tmp, seed)
1139
- else:
1140
- raise NotImplementedError(_style)
1141
- prompt = [_apply(x, ix) for ix, x in enumerate(style)]
1142
- else:
1143
- raise NotImplementedError()
1144
-
1145
- elif self.prompt_templates == "zero_shot_v1":
1146
- assert style is not None
1147
- if not isinstance(style, str):
1148
- # FIXME can we handle tensor style's in a better way?
1149
- if style == "ai2_diagram":
1150
- prompt = "Question: {question}\nAnswer with correct answer option letter only\nOptions: {options}\nAnswer:"
1151
- prompt = apply_keyword_prompt([prompt], example, seed)
1152
- elif style == "ai2_diagram_no_letter":
1153
- prompt = "Question: {question}\nAnswer with correct answer option only\nOptions: {options}\nAnswer:"
1154
- prompt = apply_keyword_prompt([prompt], example, seed)
1155
- else:
1156
- prompt = ""
1157
- tf.debugging.assert_equal(prompt != "", True)
1158
- else:
1159
- general_style = STYLE_TO_GENERAL_PROMPT[style]
1160
- if general_style == "short_answer":
1161
- prompt = apply_keyword_prompt(["Question: {question} Answer with as few words as possible. Answer:"], example, seed)
1162
- elif general_style == "multiple_choice":
1163
- prompt = apply_keyword_prompt(["Question: {question}\nAnswer with correct answer option letter only\nOptions: {options}\nAnswer:"], example, seed)
1164
- elif general_style == "count_bench":
1165
- prompt = apply_keyword_prompt(["Question: How many {object} are there?\nRespond with only a number.\nAnswer:"], example, seed)
1166
- else:
1167
- raise NotImplementedError(general_style)
1168
-
1169
- elif self.prompt_templates == "zero_shot_v2":
1170
- assert style is not None
1171
-
1172
- if self.prompt_override:
1173
- prompt = apply_keyword_prompt([self.prompt_override], example, seed)
1174
- elif not isinstance(style, str):
1175
- if style == "ai2_diagram":
1176
- prompt = "{question} Answer with correct answer option letter only. Options: {options}"
1177
- prompt = apply_keyword_prompt([prompt], example, seed)
1178
- elif style == "ai2_diagram_no_letter":
1179
- prompt = "{question} Answer with correct answer option only. Options: {options}"
1180
- prompt = apply_keyword_prompt([prompt], example, seed)
1181
- else:
1182
- prompt = ""
1183
- tf.debugging.assert_equal(prompt != "", True)
1184
- else:
1185
- if style in ["vqa2", "gqa", "tally_qa", "okvqa", "a_okvqa_da"]:
1186
- prompt = "Answer with a single word. {question}"
1187
- elif style in ["text_vqa", "doc_qa", "info_qa", "chart_qa", "st_qa", "ocr_vqa", "dv_qa", "tabwmp_da", "figure_qa", "figure_qa_zero_shot", "plot_qa"]:
1188
- prompt = "{question}\nRespond as concisely as possible, do not output anything other than the answer."
1189
- elif STYLE_TO_GENERAL_PROMPT[style] == "multiple_choice":
1190
- prompt = "{question} Answer with correct answer option letter only. Options: {options}"
1191
- elif STYLE_TO_GENERAL_PROMPT[style] == "short_answer":
1192
- prompt = "{question} Answer with as few words as possible."
1193
- elif style == "vtabfact":
1194
- prompt = "{question}"
1195
- elif style == "count_bench":
1196
- prompt = "How many {object} are there?\nRespond with only a number."
1197
- else:
1198
- raise NotImplementedError(style)
1199
- prompt = apply_keyword_prompt([prompt], example, seed)
1200
- else:
1201
- raise NotImplementedError(self.prompt_templates)
1202
-
1203
- if for_inference:
1204
- return [prompt]
1205
- else:
1206
- return [prompt, example["text"]]
1207
-
1208
- def get_system_prompt(self, style, example, for_inference,
1209
- messages, seed=None):
1210
- if isinstance(style, str) and style == "count_bench":
1211
- style = "ok_vqa"
1212
-
1213
- if self.system_prompt == "style":
1214
- if isinstance(style, str):
1215
- prefix = style + ":"
1216
- else:
1217
- prefix = tf.strings.join([style, ":"])
1218
-
1219
- elif self.system_prompt == "demo_or_style":
1220
- if isinstance(style, str):
1221
- if style == "android_control" or style == "demo":
1222
- # android is a special case since I hacked in prefix in the preprocessor
1223
- prefix = ""
1224
- elif style in ["scifi_demo", "synthetic_qa"] or style in DEMO_STYLES:
1225
- if style == "scifi_demo":
1226
- p_no_prompt = 0.2
1227
- elif style == "synthetic_qa":
1228
- p_no_prompt = 0.25
1229
- else:
1230
- p_no_prompt = 0.9
1231
- if len(tf.shape(messages)) > 1:
1232
- n_messages = tf.shape(messages)[1]
1233
- style = tf.tile(tf.expand_dims(style, axis=0), [n_messages])
1234
- r = tf.random.stateless_uniform([n_messages], seed, 0, 1)
1235
- else:
1236
- r = tf.random.stateless_uniform((), seed, 0, 1)
1237
- prefix = tf.where(r < p_no_prompt, "", tf.strings.join([style + ":"]))
1238
- else:
1239
- prefix = style + ":"
1240
- else:
1241
- if tf.reduce_any(style == tf.constant(DEMO_STYLES + ["scifi_demo", "android_control", "demo"])):
1242
- prefix = ""
1243
- else:
1244
- prefix = tf.strings.join([style, ":"])
1245
-
1246
- elif self.system_prompt in ["long_caption_length_hint", "style_long_caption_length_hint"]:
1247
- if seed is not None:
1248
- raise NotImplementedError("Determinism")
1249
- std = 25
1250
- use_hint = tf.logical_or(
1251
- tf.equal(style, "long_caption"), tf.equal(style, "transcript"))
1252
- if self.system_prompt == "style_long_caption_length_hint":
1253
- default = tf.strings.join([style, ": "])
1254
- else:
1255
- default = ""
1256
- if for_inference:
1257
- assert len(tf.shape(use_hint)) == 0
1258
- if self.default_inference_len and use_hint:
1259
- prefix = tf.strings.join([style, " ", str(self.default_inference_len), ": "])
1260
- else:
1261
- prefix = default
1262
- else:
1263
- std = 25
1264
- n = tf.strings.length(messages[-1])
1265
- n += tf.cast(tf.random.normal(n.shape)*std, tf.int32)
1266
- hint = tf.strings.join([style, " ", tf.strings.as_string(n//15), ": "])
1267
- use_hint = tf.logical_and(use_hint, tf.random.uniform(tf.shape(hint)) > 0.1)
1268
- prefix = tf.where(use_hint, hint, default)
1269
-
1270
- elif for_inference and self.system_prompt in ["style_and_length", "style_and_length_v2"]:
1271
- v2 = self.system_prompt == "style_and_length_v2"
1272
- if example.get("length_cond") is not None:
1273
- # Examples have individual length conditioning
1274
- n = tf.strings.as_string(example["length_cond"])
1275
- else:
1276
- inference_len = self.default_inference_len
1277
- n = None if inference_len is None else str(inference_len)
1278
- logging.warning(f"eval len: {n}")
1279
- if n is not None and tf.strings.length(n) > 0: # allow empty string to signal unconditioned
1280
- prefix = tf.strings.join([style, " ", n, ":"])
1281
- else:
1282
- prefix = tf.strings.join([style, ":" if v2 else " :"])
1283
- elif self.system_prompt in ["style_and_length", "style_and_length_v2"]:
1284
- v2 = self.system_prompt == "style_and_length_v2"
1285
- std = 25
1286
- logging.info(f"style prompt std={std}, percent=10")
1287
- if seed is not None:
1288
- seeds = tf.random.split(seed)
1289
- p = tf.random.stateless_uniform((), seed=seeds[0])
1290
- else:
1291
- p = tf.random.uniform(())
1292
- if p > 0.10:
1293
- n = tf.strings.length(messages[-1])
1294
- if seed is not None:
1295
- n += tf.cast(tf.random.stateless_normal(n.shape, seed=seeds[1])*std, tf.int32)
1296
- else:
1297
- n += tf.cast(tf.random.normal(n.shape)*std, tf.int32)
1298
- n = tf.strings.as_string(n//15)
1299
- prefix = tf.strings.join([style, " ", n, ":"])
1300
- else:
1301
- prefix = tf.strings.join([style, ":" if v2 else " :"])
1302
- else:
1303
- raise NotImplementedError(self.system_prompt)
1304
-
1305
- return prefix
1306
-
1307
- def preprend_system_prompt(self, style, example, for_inference, messages, seed=None):
1308
- prefix = self.get_system_prompt(style, example, for_inference, messages, seed=seed)
1309
- separator = tf.where(tf.logical_and(
1310
- tf.strings.length(prefix) > 0, tf.strings.length(messages[0]) > 0), " ", "")
1311
- with_system_prompt = tf.strings.join([prefix, separator, messages[0]])
1312
- if isinstance(messages, list):
1313
- messages = [with_system_prompt] + messages[1:]
1314
- else:
1315
- messages = tf.concat([tf.expand_dims(with_system_prompt, 0), messages[1:]], axis=0)
1316
- return messages
1317
-
1318
- def get_messages(self, ex, style, is_training, for_inference, user_prompt_seed, system_prompt_seed):
1319
- if isinstance(ex, list):
1320
- messages = ex
1321
- elif isinstance(ex, str):
1322
- messages = [ex]
1323
- elif "messages" in ex:
1324
- messages = ex["messages"]
1325
- else:
1326
- # Apply a prompt template
1327
- messages = self.get_user_prompt(style, ex, is_training, for_inference=for_inference, seed=user_prompt_seed)
1328
-
1329
- # Maybe add a system prompt. The system prompt gets concatenated with the first user input
1330
- if self.system_prompt and self.system_prompt != "none":
1331
- if isinstance(ex, dict):
1332
- style = ex.get("style", style)
1333
-
1334
- if isinstance(messages, tf.RaggedTensor):
1335
- n = tf.shape(messages)[0]
1336
- message_arr = tf.TensorArray(dtype=tf.string, size=n, element_shape=(None,))
1337
- seeds = tf.random.split(system_prompt_seed, n)
1338
- for i in range(n):
1339
- message_arr = message_arr.write(i, self.preprend_system_prompt(style, None, for_inference, messages[i], seed=seeds[i]))
1340
- messages = tf.RaggedTensor.from_row_splits(
1341
- values=message_arr.concat(), row_splits=messages.row_splits)
1342
- else:
1343
- messages = self.preprend_system_prompt(style, ex, for_inference, messages, seed=system_prompt_seed)
1344
-
1345
- return messages
1346
-
1347
- def get_preprocessor(self, is_training, for_inference, style=None, include_metadata=None):
1348
- """Build a preprocessing function that can be applied ot a tf.data.Dataset"""
1349
- vocab = self.tokenizer
1350
- include_response = not for_inference
1351
- if include_metadata is None:
1352
- include_metadata = for_inference
1353
-
1354
- @seqio.map_over_dataset(num_seeds=2)
1355
- def to_inputs_and_targets(ex, seeds):
1356
- if "unconditioned" in ex:
1357
- raise NotImplementedError()
1358
- if "image" not in ex:
1359
- image = None
1360
- elif ex['image'].dtype == tf.string:
1361
- image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
1362
- else:
1363
- image = ex['image']
1364
- raw_image = image
1365
- if image is not None and len(tf.shape(image)) == 3:
1366
- image = tf.expand_dims(image, axis=0)
1367
-
1368
- unconditioned = self.unconditioned
1369
- if unconditioned and isinstance(unconditioned, float):
1370
- assert image is not None
1371
- if is_training and tf.random.uniform((), 0, 1, dtype=tf.float32) < unconditioned:
1372
- image = image[:0]
1373
- elif unconditioned:
1374
- image = None
1375
-
1376
- messages = self.get_messages(ex, style, is_training, for_inference, seeds[0], seeds[1])
1377
- targets, loss_masks, subsegments = self.get_tokens_input(
1378
- messages, for_inference, ex.get("text_weights"))
1379
- # if "scifi" in style and style.endswith("_explanation"):
1380
- # logging.warning(f"No loss on EOS for {style}")
1381
- # loss_masks = tf.where(targets == self.tokenizer.eos_token_id, tf.zeros_like(loss_masks), loss_masks)
1382
- out = self.build_multimodel_features(targets, loss_masks, subsegments, image, is_training)
1383
-
1384
- if include_metadata:
1385
- # FIXME remove these special cases
1386
- if "text" in ex:
1387
- if len(ex["text"].shape) > 0:
1388
- # FIXME can this be variable lengths after all?
1389
- out["metadata/captions"] = tf.strings.reduce_join(
1390
- tf.strings.regex_replace(ex['text'], "\\s+", " "),
1391
- separator="\n"
1392
- )
1393
- else:
1394
- out["metadata/captions"] = ex["text"]
1395
-
1396
- if "image_url" in ex:
1397
- out["metadata/image_url"] = ex["image_url"]
1398
- elif "url" in ex:
1399
- out["metadata/image_url"] = ex["url"]
1400
- if "image_id" in ex:
1401
- out["metadata/image_id"] = ex["image_id"]
1402
- for k, v in ex.items():
1403
- if k.startswith("metadata"):
1404
- out[k] = v
1405
- if raw_image is not None and "metadata/image_size" not in out:
1406
- img_h = tf.shape(raw_image)[0]
1407
- img_w = tf.shape(raw_image)[1]
1408
- out["metadata/image_size"] = [img_w, img_h]
1409
- if "metadata/image_url" not in out and raw_image is not None:
1410
- if len(ex["image"].shape) < 4:
1411
- # For visualizations FIXME can we make this variable length
1412
- out["metadata/image"] = tf.io.encode_jpeg(
1413
- tf.image.convert_image_dtype(raw_image, tf.uint8))
1414
- return out
1415
- return to_inputs_and_targets
1416
-
1417
- def get_post_mixing_preprocessor(self, pack=False):
1418
- """Build a feature conversion function that can be applied ot a tf.data.Dataset
1419
-
1420
- This function applies a second stage of pre-processing, but unlike `self.get_preprocessor`
1421
- this stage can be applied after mixing tf.data.Datasets into a mixture
1422
- """
1423
- return MultiModalLMFeatureConverter(
1424
- loss_token_weighting=self.loss_token_weighting,
1425
- bos_id=self.tokenizer.bos_token_id,
1426
- fix_image_input_idx=self.fix_image_input_idx,
1427
- pack=pack,
1428
- special_tokens=list(self.special_token_ids.values()),
1429
- )
1430
-
1431
-
1432
- class MultiModalLMFeatureConverter:
1433
-
1434
- def __init__(
1435
- self, pack: bool = False, loss_token_weighting: str=None, bos_id: int = 1,
1436
- special_tokens=None, fix_image_input_idx=2
1437
- ):
1438
- self.pack = pack
1439
- self.bos_id = bos_id
1440
- self.fix_image_input_idx = fix_image_input_idx
1441
- self.special_tokens = tf.constant(special_tokens) if special_tokens else None
1442
- self.loss_token_weighting = loss_token_weighting
1443
-
1444
- def _convert_example(
1445
- self, features: Mapping[str, tf.Tensor]
1446
- ) -> Mapping[str, tf.Tensor]:
1447
- """Convert an LM example into an example with model features."""
1448
- # targets_segment_id is present only for a packed dataset.
1449
- decoder_input_tokens = make_autoregressive_inputs(
1450
- features["target_tokens"],
1451
- sequence_id=features.get("targets_segment_ids", None),
1452
- bos_id=self.bos_id,
1453
- )
1454
-
1455
- tf.assert_equal(
1456
- True,
1457
- tf.reduce_all(decoder_input_tokens[-1] != self.special_tokens),
1458
- message="An input ends with an image special token",
1459
- )
1460
-
1461
- image_input_idx = features["image_input_idx"]
1462
- if self.fix_image_input_idx == 2:
1463
- # plus one sine we have added BOS to the inputs
1464
- image_input_idx = tf.where(image_input_idx < 0, image_input_idx, image_input_idx + 1)
1465
- else:
1466
- # Some old models trained like this, sometimes image_input_idx will go from -1 -> 0 didn't
1467
- # effect performance but keep this code path for backwards compatiblity with those checkpoints
1468
- image_input_idx = image_input_idx + 1
1469
-
1470
- d = {
1471
- "target_tokens": features["target_tokens"],
1472
- "input_tokens": decoder_input_tokens,
1473
- "loss_masks": features["loss_masks"],
1474
- "images": features["images"],
1475
- "image_input_idx": image_input_idx
1476
- }
1477
- if "image_masks" in features:
1478
- d["image_masks"] = features["image_masks"]
1479
-
1480
- has_custom_text_weight = features.get("has_custom_loss_weight", False)
1481
-
1482
- if "subsegment_ids" in features:
1483
- subsegment_ids = make_autoregressive_inputs(
1484
- features["subsegment_ids"],
1485
- sequence_id=features.get("targets_segment_ids", None),
1486
- bos_id=features["subsegment_ids"][0],
1487
- )
1488
-
1489
- # Subsegment have a position based on the sum of previous positions they can attend to
1490
- position_ids = tf.zeros_like(subsegment_ids)
1491
- unique_segments = tf.unique(subsegment_ids)[0]
1492
- for i in unique_segments:
1493
- segment_position_ids = tf.cumsum(tf.cast(subsegment_ids >= i, tf.int32)) - 1
1494
- position_ids = tf.where(subsegment_ids == i, segment_position_ids, position_ids)
1495
-
1496
- # Apply loss weighting, this is done here so it occurs after truncation
1497
- if has_custom_text_weight:
1498
- pass
1499
- elif self.loss_token_weighting in ["subsegments", "root_subsegments"]:
1500
- n_loss_segments = tf.shape(tf.unique(tf.boolean_mask(subsegment_ids, d["loss_masks"] > 0))[0])[0]
1501
- n_loss_segments = tf.maximum(tf.cast(n_loss_segments, tf.float32), 1)
1502
- weight = 1/n_loss_segments if self.loss_token_weighting == "subsegments" else tf.math.rsqrt(n_loss_segments)
1503
- d["loss_masks"] = tf.where(d["loss_masks"] > 0, d["loss_masks"]*weight, d["loss_masks"])
1504
- elif self.loss_token_weighting is not None:
1505
- raise NotImplementedError(self.loss_token_weighting)
1506
-
1507
- d["subsegment_ids"] = subsegment_ids
1508
- d["position_ids"] = position_ids
1509
- else:
1510
- if self.loss_token_weighting not in [None, "subsegments", "root_subsegments"] and not has_custom_text_weight:
1511
- raise NotImplementedError(self.loss_token_weighting)
1512
- if self.pack:
1513
- d["decoder_segment_ids"] = features["targets_segment_ids"]
1514
- d["decoder_positions"] = features["targets_positions"]
1515
-
1516
- for k in features:
1517
- if k.startswith("metadata/"):
1518
- d[k] = features[k]
1519
- return d
1520
-
1521
- def _pack_or_pad(self, ds, task_feature_lengths):
1522
- if self.pack:
1523
- raise NotImplementedError()
1524
- else:
1525
- return trim_and_pad_dataset(ds, task_feature_lengths)
1526
-
1527
- def __call__(self, ds: tf.data.Dataset, task_feature_lengths: Mapping[str, int]) -> tf.data.Dataset:
1528
- """Convert the dataset to be fed to a language model."""
1529
- task_feature_lengths = dict(task_feature_lengths)
1530
-
1531
- if "images" in ds.element_spec and "images" in task_feature_lengths:
1532
- # Images should never be truncated
1533
- ds = assert_not_truncated(ds, ["images", "image_input_idx"], task_feature_lengths["images"])
1534
-
1535
- if any(x.startswith("metadata/") for x in ds.element_spec):
1536
- # Metadata indicates the dataset is being used for inference, inference datasets
1537
- # should not be truncated
1538
- ds = assert_not_truncated(ds, ["target_tokens"], task_feature_lengths["target_tokens"])
1539
-
1540
- if "image_masks" in ds.element_spec and "images" in task_feature_lengths:
1541
- task_feature_lengths["image_masks"] = task_feature_lengths["images"]
1542
- if "subsegment_ids" in ds.element_spec and "target_tokens" in task_feature_lengths:
1543
- task_feature_lengths["subsegment_ids"] = task_feature_lengths["target_tokens"]
1544
- if "loss_masks" not in task_feature_lengths and "target_tokens" in task_feature_lengths:
1545
- task_feature_lengths["loss_masks"] = task_feature_lengths["target_tokens"]
1546
- ds = self._pack_or_pad(ds, task_feature_lengths)
1547
-
1548
- return ds.map(
1549
- self._convert_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)