jadechoghari commited on
Commit
7f5b506
·
verified ·
1 Parent(s): 2345503

Create mm_utils.py

Browse files
Files changed (1) hide show
  1. mm_utils.py +260 -0
mm_utils.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from io import BytesIO
3
+ import base64
4
+ import torch
5
+ import math
6
+ import ast
7
+ from typing import Optional, Callable
8
+
9
+ from transformers import StoppingCriteria
10
+ from ferretui.constants import IMAGE_TOKEN_INDEX
11
+
12
+
13
+ def select_best_resolution(original_size, possible_resolutions):
14
+ """
15
+ Selects the best resolution from a list of possible resolutions based on the original size.
16
+
17
+ Args:
18
+ original_size (tuple): The original size of the image in the format (width, height).
19
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
20
+
21
+ Returns:
22
+ tuple: The best fit resolution in the format (width, height).
23
+ """
24
+ original_width, original_height = original_size
25
+ best_fit = None
26
+ max_effective_resolution = 0
27
+ min_wasted_resolution = float('inf')
28
+
29
+ for width, height in possible_resolutions:
30
+ scale = min(width / original_width, height / original_height)
31
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
32
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
33
+ wasted_resolution = (width * height) - effective_resolution
34
+
35
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
36
+ max_effective_resolution = effective_resolution
37
+ min_wasted_resolution = wasted_resolution
38
+ best_fit = (width, height)
39
+
40
+ return best_fit
41
+
42
+
43
+ def resize_and_pad_image(image, target_resolution, is_pad=False):
44
+ """
45
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
46
+ Args:
47
+ image (PIL.Image.Image): The input image.
48
+ target_resolution (tuple): The target resolution (width, height) of the image.
49
+ Returns:
50
+ PIL.Image.Image: The resized and padded image.
51
+ """
52
+ original_width, original_height = image.size
53
+ target_width, target_height = target_resolution
54
+
55
+ if is_pad:
56
+ scale_w = target_width / original_width
57
+ scale_h = target_height / original_height
58
+
59
+ if scale_w < scale_h:
60
+ new_width = target_width
61
+ new_height = min(math.ceil(original_height * scale_w), target_height)
62
+ else:
63
+ new_height = target_height
64
+ new_width = min(math.ceil(original_width * scale_h), target_width)
65
+
66
+ # Resize the image
67
+ resized_image = image.resize((new_width, new_height))
68
+
69
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
70
+ paste_x = (target_width - new_width) // 2
71
+ paste_y = (target_height - new_height) // 2
72
+ new_image.paste(resized_image, (paste_x, paste_y))
73
+ else:
74
+ new_image = image.resize((target_width, target_height))
75
+
76
+ return new_image
77
+
78
+
79
+ def divide_to_patches(image, patch_size):
80
+ """
81
+ Divides an image into patches of a specified size.
82
+
83
+ Args:
84
+ image (PIL.Image.Image): The input image.
85
+ patch_size (int): The size of each patch.
86
+
87
+ Returns:
88
+ list: A list of PIL.Image.Image objects representing the patches.
89
+ """
90
+ patches = []
91
+ width, height = image.size
92
+ for i in range(0, height, patch_size):
93
+ for j in range(0, width, patch_size):
94
+ box = (j, i, j + patch_size, i + patch_size)
95
+ patch = image.crop(box)
96
+ patches.append(patch)
97
+
98
+ return patches
99
+
100
+
101
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
102
+ """
103
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
104
+
105
+ Args:
106
+ image_size (tuple): The size of the input image in the format (width, height).
107
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
108
+ patch_size (int): The size of each image patch.
109
+
110
+ Returns:
111
+ tuple: The shape of the image patch grid in the format (width, height).
112
+ """
113
+ if type(grid_pinpoints) is list:
114
+ possible_resolutions = grid_pinpoints
115
+ else:
116
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
117
+ width, height = select_best_resolution(image_size, possible_resolutions)
118
+ return width // patch_size, height // patch_size
119
+
120
+
121
+ def process_anyres_image(image, processor, grid_pinpoints, image_process_func: Optional[Callable] = None):
122
+ """
123
+ Process an image with variable resolutions.
124
+
125
+ Args:
126
+ image (PIL.Image.Image): The input image to be processed.
127
+ processor: The image processor object.
128
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
129
+
130
+ Returns:
131
+ torch.Tensor: A tensor containing the processed image patches.
132
+ """
133
+ if type(grid_pinpoints) is list:
134
+ possible_resolutions = grid_pinpoints
135
+ else:
136
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
137
+
138
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
139
+
140
+ # FIXME: not sure if do_pad or undo_pad may affect the referring side
141
+ image_padded = resize_and_pad_image(image, best_resolution, is_pad=False)
142
+
143
+ patches = divide_to_patches(image_padded, processor.crop_size['height'])
144
+
145
+ if image_process_func:
146
+ resized_image_h, resized_image_w = image_process_func.keywords['size']
147
+ image_original_resize = image.resize((resized_image_w, resized_image_h))
148
+ image_patches = [image_original_resize] + patches
149
+ image_patches = [image_process_func(image_patch)['pixel_values'][0]
150
+ for image_patch in image_patches]
151
+ else:
152
+ image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
153
+ image_patches = [image_original_resize] + patches
154
+ image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
155
+ for image_patch in image_patches]
156
+
157
+ return torch.stack(image_patches, dim=0)
158
+
159
+
160
+ def load_image_from_base64(image):
161
+ return Image.open(BytesIO(base64.b64decode(image)))
162
+
163
+
164
+ def expand2square(pil_img, background_color):
165
+ width, height = pil_img.size
166
+ if width == height:
167
+ return pil_img
168
+ elif width > height:
169
+ result = Image.new(pil_img.mode, (width, width), background_color)
170
+ result.paste(pil_img, (0, (width - height) // 2))
171
+ return result
172
+ else:
173
+ result = Image.new(pil_img.mode, (height, height), background_color)
174
+ result.paste(pil_img, ((height - width) // 2, 0))
175
+ return result
176
+
177
+
178
+ def process_images(images, image_processor, model_cfg, image_process_func: Optional[Callable] = None):
179
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
180
+ new_images = []
181
+ if image_aspect_ratio == 'pad':
182
+ for image in images:
183
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
184
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
185
+ new_images.append(image)
186
+ elif image_aspect_ratio == "anyres":
187
+ # image_processor(images, return_tensors='pt', do_resize=True, do_center_crop=False, size=[image_h, image_w])['pixel_values']
188
+ for image in images:
189
+ image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints, image_process_func=image_process_func)
190
+ new_images.append(image)
191
+ else:
192
+ return image_processor(images, return_tensors='pt')['pixel_values']
193
+ if all(x.shape == new_images[0].shape for x in new_images):
194
+ new_images = torch.stack(new_images, dim=0)
195
+ return new_images
196
+
197
+
198
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
199
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
200
+
201
+ def insert_separator(X, sep):
202
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
203
+
204
+ input_ids = []
205
+ offset = 0
206
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
207
+ offset = 1
208
+ input_ids.append(prompt_chunks[0][0])
209
+
210
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
211
+ input_ids.extend(x[offset:])
212
+
213
+ if return_tensors is not None:
214
+ if return_tensors == 'pt':
215
+ return torch.tensor(input_ids, dtype=torch.long)
216
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
217
+ return input_ids
218
+
219
+
220
+ def get_model_name_from_path(model_path):
221
+ model_path = model_path.strip("/")
222
+ model_paths = model_path.split("/")
223
+ if model_paths[-1].startswith('checkpoint-'):
224
+ return model_paths[-2] + "_" + model_paths[-1]
225
+ else:
226
+ return model_paths[-1]
227
+
228
+ class KeywordsStoppingCriteria(StoppingCriteria):
229
+ def __init__(self, keywords, tokenizer, input_ids):
230
+ self.keywords = keywords
231
+ self.keyword_ids = []
232
+ self.max_keyword_len = 0
233
+ for keyword in keywords:
234
+ cur_keyword_ids = tokenizer(keyword).input_ids
235
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
236
+ cur_keyword_ids = cur_keyword_ids[1:]
237
+ if len(cur_keyword_ids) > self.max_keyword_len:
238
+ self.max_keyword_len = len(cur_keyword_ids)
239
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
240
+ self.tokenizer = tokenizer
241
+ self.start_len = input_ids.shape[1]
242
+
243
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
244
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
245
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
246
+ for keyword_id in self.keyword_ids:
247
+ truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
248
+ if torch.equal(truncated_output_ids, keyword_id):
249
+ return True
250
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
251
+ for keyword in self.keywords:
252
+ if keyword in outputs:
253
+ return True
254
+ return False
255
+
256
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
257
+ outputs = []
258
+ for i in range(output_ids.shape[0]):
259
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
260
+ return all(outputs)