ydshieh HF staff commited on
Commit
3e604a5
·
1 Parent(s): a89397f

Upload 6 files

Browse files
image_processing_kosmos2.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for Kosmos2."""
16
+
17
+ from typing import Dict, List, Optional, Union
18
+
19
+ import numpy as np
20
+
21
+ from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
22
+ from ...image_transforms import (
23
+ convert_to_rgb,
24
+ get_resize_output_image_size,
25
+ resize,
26
+ to_channel_dimension_format,
27
+ )
28
+ from ...image_utils import (
29
+ OPENAI_CLIP_MEAN,
30
+ OPENAI_CLIP_STD,
31
+ ChannelDimension,
32
+ ImageInput,
33
+ PILImageResampling,
34
+ infer_channel_dimension_format,
35
+ make_list_of_images,
36
+ to_numpy_array,
37
+ valid_images,
38
+ )
39
+ from ...utils import TensorType, is_vision_available, logging
40
+
41
+
42
+ logger = logging.get_logger(__name__)
43
+
44
+
45
+ if is_vision_available():
46
+ import PIL
47
+
48
+
49
+ class Kosmos2ImageProcessor(BaseImageProcessor):
50
+ r"""
51
+ Constructs a CLIP image processor.
52
+
53
+ Args:
54
+ do_resize (`bool`, *optional*, defaults to `True`):
55
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
56
+ `do_resize` in the `preprocess` method.
57
+ size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
58
+ Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
59
+ the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
60
+ method.
61
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
62
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
63
+ do_center_crop (`bool`, *optional*, defaults to `True`):
64
+ Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
65
+ `preprocess` method.
66
+ crop_size (`Dict[str, int]` *optional*, defaults to 224):
67
+ Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
68
+ method.
69
+ do_rescale (`bool`, *optional*, defaults to `True`):
70
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
71
+ the `preprocess` method.
72
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
73
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
74
+ method.
75
+ do_normalize:
76
+ Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
77
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
78
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
79
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
80
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
81
+ Image standard deviation.
82
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
83
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
84
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
85
+ """
86
+
87
+ model_input_names = ["pixel_values"]
88
+
89
+ def __init__(
90
+ self,
91
+ do_resize: bool = True,
92
+ size: Dict[str, int] = None,
93
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
94
+ do_center_crop: bool = True,
95
+ crop_size: Dict[str, int] = None,
96
+ do_rescale: bool = True,
97
+ rescale_factor: Union[int, float] = 1 / 255,
98
+ do_normalize: bool = True,
99
+ image_mean: Optional[Union[float, List[float]]] = None,
100
+ image_std: Optional[Union[float, List[float]]] = None,
101
+ do_convert_rgb: bool = True,
102
+ **kwargs,
103
+ ) -> None:
104
+ super().__init__(**kwargs)
105
+ size = size if size is not None else {"shortest_edge": 224}
106
+ size = get_size_dict(size, default_to_square=False)
107
+ crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
108
+ crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
109
+
110
+ self.do_resize = do_resize
111
+ self.size = size
112
+ self.resample = resample
113
+ self.do_center_crop = do_center_crop
114
+ self.crop_size = crop_size
115
+ self.do_rescale = do_rescale
116
+ self.rescale_factor = rescale_factor
117
+ self.do_normalize = do_normalize
118
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
119
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
120
+ self.do_convert_rgb = do_convert_rgb
121
+
122
+ def resize(
123
+ self,
124
+ image: np.ndarray,
125
+ size: Dict[str, int],
126
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
127
+ data_format: Optional[Union[str, ChannelDimension]] = None,
128
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
129
+ **kwargs,
130
+ ) -> np.ndarray:
131
+ """
132
+ Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
133
+ resized to keep the input aspect ratio.
134
+
135
+ Args:
136
+ image (`np.ndarray`):
137
+ Image to resize.
138
+ size (`Dict[str, int]`):
139
+ Size of the output image.
140
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
141
+ Resampling filter to use when resiizing the image.
142
+ data_format (`str` or `ChannelDimension`, *optional*):
143
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
144
+ input_data_format (`ChannelDimension` or `str`, *optional*):
145
+ The channel dimension format of the input image. If not provided, it will be inferred.
146
+ """
147
+ size = get_size_dict(size)
148
+ if "shortest_edge" not in size:
149
+ raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
150
+ output_size = get_resize_output_image_size(
151
+ image, size=size["shortest_edge"], input_data_format=input_data_format
152
+ )
153
+ return resize(
154
+ image,
155
+ size=output_size,
156
+ resample=resample,
157
+ data_format=data_format,
158
+ input_data_format=input_data_format,
159
+ **kwargs,
160
+ )
161
+
162
+ def preprocess(
163
+ self,
164
+ images: ImageInput,
165
+ do_resize: bool = None,
166
+ size: Dict[str, int] = None,
167
+ resample: PILImageResampling = None,
168
+ do_center_crop: bool = None,
169
+ crop_size: int = None,
170
+ do_rescale: bool = None,
171
+ rescale_factor: float = None,
172
+ do_normalize: bool = None,
173
+ image_mean: Optional[Union[float, List[float]]] = None,
174
+ image_std: Optional[Union[float, List[float]]] = None,
175
+ do_convert_rgb: bool = None,
176
+ return_tensors: Optional[Union[str, TensorType]] = None,
177
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
178
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
179
+ **kwargs,
180
+ ) -> PIL.Image.Image:
181
+ """
182
+ Preprocess an image or batch of images.
183
+
184
+ Args:
185
+ images (`ImageInput`):
186
+ Image to preprocess.
187
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
188
+ Whether to resize the image.
189
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
190
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
191
+ the longest edge resized to keep the input aspect ratio.
192
+ resample (`int`, *optional*, defaults to `self.resample`):
193
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
194
+ has an effect if `do_resize` is set to `True`.
195
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
196
+ Whether to center crop the image.
197
+ crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
198
+ Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
199
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
200
+ Whether to rescale the image.
201
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
202
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
203
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
204
+ Whether to normalize the image.
205
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
206
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
207
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
208
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
209
+ `True`.
210
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
211
+ Whether to convert the image to RGB.
212
+ return_tensors (`str` or `TensorType`, *optional*):
213
+ The type of tensors to return. Can be one of:
214
+ - Unset: Return a list of `np.ndarray`.
215
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
216
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
217
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
218
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
219
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
220
+ The channel dimension format for the output image. Can be one of:
221
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
222
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
223
+ - Unset: Use the channel dimension format of the input image.
224
+ input_data_format (`ChannelDimension` or `str`, *optional*):
225
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
226
+ from the input image. Can be one of:
227
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
228
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
229
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
230
+ """
231
+ do_resize = do_resize if do_resize is not None else self.do_resize
232
+ size = size if size is not None else self.size
233
+ size = get_size_dict(size, param_name="size", default_to_square=False)
234
+ resample = resample if resample is not None else self.resample
235
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
236
+ crop_size = crop_size if crop_size is not None else self.crop_size
237
+ crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
238
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
239
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
240
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
241
+ image_mean = image_mean if image_mean is not None else self.image_mean
242
+ image_std = image_std if image_std is not None else self.image_std
243
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
244
+
245
+ images = make_list_of_images(images)
246
+
247
+ if not valid_images(images):
248
+ raise ValueError(
249
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
250
+ "torch.Tensor, tf.Tensor or jax.ndarray."
251
+ )
252
+
253
+ if do_resize and size is None:
254
+ raise ValueError("Size must be specified if do_resize is True.")
255
+
256
+ if do_center_crop and crop_size is None:
257
+ raise ValueError("Crop size must be specified if do_center_crop is True.")
258
+
259
+ if do_rescale and rescale_factor is None:
260
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
261
+
262
+ if do_normalize and (image_mean is None or image_std is None):
263
+ raise ValueError("Image mean and std must be specified if do_normalize is True.")
264
+
265
+ # PIL RGBA images are converted to RGB
266
+ if do_convert_rgb:
267
+ images = [convert_to_rgb(image) for image in images]
268
+
269
+ # All transformations expect numpy arrays.
270
+ images = [to_numpy_array(image) for image in images]
271
+
272
+ if input_data_format is None:
273
+ # We assume that all images have the same channel dimension format.
274
+ input_data_format = infer_channel_dimension_format(images[0])
275
+
276
+ if do_resize:
277
+ images = [
278
+ self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
279
+ for image in images
280
+ ]
281
+
282
+ if do_center_crop:
283
+ images = [
284
+ self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
285
+ ]
286
+
287
+ if do_rescale:
288
+ images = [
289
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
290
+ for image in images
291
+ ]
292
+
293
+ if do_normalize:
294
+ images = [
295
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
296
+ for image in images
297
+ ]
298
+
299
+ images = [
300
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
301
+ ]
302
+
303
+ data = {"pixel_values": images}
304
+ return BatchFeature(data=data, tensor_type=return_tensors)
processing_kosmos2.py CHANGED
@@ -58,7 +58,12 @@ class Kosmos2Processor(ProcessorMixin):
58
  An instance of ['Kosmos2TokenizerFast`]. The tokenizer is a required input.
59
  """
60
  attributes = ["image_processor", "tokenizer"]
61
- image_processor_class = "CLIPImageProcessor"
 
 
 
 
 
62
  tokenizer_class = "AutoTokenizer"
63
 
64
  def __init__(self, image_processor, tokenizer):
 
58
  An instance of ['Kosmos2TokenizerFast`]. The tokenizer is a required input.
59
  """
60
  attributes = ["image_processor", "tokenizer"]
61
+ # Better to use explicit classes if local code works
62
+ # image_processor_class = "Kosmos2ImageProcessor"
63
+ # tokenizer_class = ("Kosmos2Tokenizer", "Kosmos2TokenizerFast")
64
+
65
+ # To make remote code work
66
+ image_processor_class = "AutoImageProcessor"
67
  tokenizer_class = "AutoTokenizer"
68
 
69
  def __init__(self, image_processor, tokenizer):