ydshieh HF staff committed on
Commit
0dce8b2
·
1 Parent(s): 456aa68

Upload 5 files

Browse files
configuration_kosmos2.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ KOSMOS-2 model configuration"""
16
+
17
+ import copy
18
+ import os
19
+ from typing import Union
20
+
21
+ from ...configuration_utils import PretrainedConfig
22
+ from ...utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
# Maps each published KOSMOS-2 checkpoint name to the URL of its hosted `config.json`.
KOSMOS2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "microsoft/kosmos-2-patch14-224": (
        "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/config.json"
    ),
    # See all KOSMOS-2 models at https://huggingface.co/models?filter=kosmos-2
}

# Backward-compatible alias: the map was first published under a name carried over from the BEiT
# configuration module. Both names are bound to the same dictionary object.
BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP = KOSMOS2_PRETRAINED_CONFIG_ARCHIVE_MAP
33
+
34
+
35
class Kosmos2TextConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Kosmos2TextModel`]. It is used to instantiate a
    KOSMOS-2 text decoder according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the text decoder of the KOSMOS-2
    [microsoft/kosmos-2-patch14-224](https://huggingface.co/microsoft/kosmos-2-patch14-224) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 65037):
            Vocabulary size of the Kosmos2 model. Defines the number of different tokens that can be represented by
            the `input_ids` passed when calling [`Kosmos2Model`].
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        embed_dim (`int`, *optional*, defaults to 2048):
            Dimensionality of the decoder layers. Also reachable as `hidden_size`.
        layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer decoder. Also reachable as `num_hidden_layers`.
        ffn_dim (`int`, *optional*, defaults to 8192):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer decoder.
        attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder. Also reachable as
            `num_attention_heads`.
        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string). If string, `"gelu"`, `"relu"`, `"silu"` and
            `"gelu_new"` are supported.
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings and decoder layers.
        attention_dropout (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        activation_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for activations inside the fully connected layer.
        layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the decoder. See the [LayerDrop
            paper](https://arxiv.org/abs/1909.11556) for more details.
        layer_norm_eps (`float`, *optional*, defaults to 1e-5):
            The epsilon used by the layer normalization layers.
        scale_embedding (`bool`, *optional*, defaults to `True`):
            Whether to scale the embeddings by a factor based on `sqrt(embed_dim)`.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        pad_token_id (`int`, *optional*, defaults to 1):
            Id of the padding token, forwarded to [`PretrainedConfig`].
        bos_token_id (`int`, *optional*, defaults to 0):
            Id of the beginning-of-sequence token, forwarded to [`PretrainedConfig`].
        eos_token_id (`int`, *optional*, defaults to 2):
            Id of the end-of-sequence token, forwarded to [`PretrainedConfig`].

    Example:

    ```python
    >>> from transformers import Kosmos2TextConfig, Kosmos2TextModel

    >>> # Initializing a Kosmos2TextConfig microsoft/kosmos-2-patch14-224 style configuration
    >>> configuration = Kosmos2TextConfig()

    >>> # Initializing a Kosmos2TextModel (with random weights) from the microsoft/kosmos-2-patch14-224 style configuration
    >>> model = Kosmos2TextModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    model_type = "kosmos_2_text_model"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Expose the generic `PretrainedConfig` attribute names as aliases of the KOSMOS-2 specific ones, so
    # library code that reads `hidden_size` / `num_attention_heads` / `num_hidden_layers` keeps working.
    attribute_map = {
        "num_attention_heads": "attention_heads",
        "hidden_size": "embed_dim",
        "num_hidden_layers": "layers",
    }

    def __init__(
        self,
        vocab_size=65037,
        max_position_embeddings=2048,
        embed_dim=2048,
        layers=24,
        ffn_dim=8192,
        attention_heads=32,
        activation_function="gelu",
        dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.0,
        layerdrop=0.0,
        layer_norm_eps=1e-5,
        scale_embedding=True,
        use_cache=True,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        **kwargs,
    ):
        # The special token ids are handled by the base class; everything else is stored below.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.embed_dim = embed_dim
        self.layers = layers
        self.ffn_dim = ffn_dim
        self.attention_heads = attention_heads
        self.activation_function = activation_function
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.layerdrop = layerdrop
        self.layer_norm_eps = layer_norm_eps
        self.scale_embedding = scale_embedding
        self.use_cache = use_cache

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """
        Load a text configuration from a checkpoint name or path.

        When the loaded configuration is a composite one (`model_type == "kosmos-2"`), only its `"text_config"`
        sub-dictionary is used. A warning is logged if the resulting `model_type` differs from this class's.
        """
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # get the text config dict if we are loading from Kosmos2Config
        if config_dict.get("model_type") == "kosmos-2":
            config_dict = config_dict["text_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
157
+
158
+
159
class Kosmos2VisionConfig(PretrainedConfig):
    r"""
    Configuration holder for a [`Kosmos2VisionModel`]. It is used to instantiate a KOSMOS-2 vision encoder according
    to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will
    yield a similar configuration to that of the vision encoder of the KOSMOS-2
    [microsoft/kosmos-2-patch14-224](https://huggingface.co/microsoft/kosmos-2-patch14-224) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 4096):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        projection_dim (`int`, *optional*, defaults to 512):
            Stored on the configuration as `projection_dim`. NOTE(review): presumably the size of a CLIP-style
            projection head; confirm against the model code.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            Number of channels of the input images.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 14):
            The size (resolution) of each patch.
        hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-5):
            The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        initializer_factor (`float`, *optional*, defaults to 1):
            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
            testing).

    Example:

    ```python
    >>> from transformers import Kosmos2VisionConfig, Kosmos2VisionModel

    >>> # Initializing a Kosmos2VisionConfig with microsoft/kosmos-2-patch14-224 style configuration
    >>> configuration = Kosmos2VisionConfig()

    >>> # Initializing a Kosmos2VisionModel (with random weights) from the microsoft/kosmos-2-patch14-224 style configuration
    >>> model = Kosmos2VisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "kosmos_2_vision_model"

    def __init__(
        self,
        hidden_size=1024,
        intermediate_size=4096,
        projection_dim=512,
        num_hidden_layers=24,
        num_attention_heads=16,
        num_channels=3,
        image_size=224,
        patch_size=14,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-5,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Architecture sizes.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # Input geometry.
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        # Initialization and regularization settings.
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """
        Load a vision configuration from a checkpoint name or path.

        If the loaded dictionary belongs to a composite configuration (`model_type == "kosmos-2"`), its
        `"vision_config"` entry is used instead. A warning is logged when the final `model_type` does not match
        this class's `model_type`.
        """
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # get the vision config dict if we are loading from Kosmos2Config
        loaded_from_composite = config_dict.get("model_type") == "kosmos-2"
        if loaded_from_composite:
            config_dict = config_dict["vision_config"]

        type_is_comparable = "model_type" in config_dict and hasattr(cls, "model_type")
        if type_is_comparable and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
262
+
263
+
264
class Kosmos2Config(PretrainedConfig):
    r"""
    Composite configuration for a [`Kosmos2Model`]: it bundles a [`Kosmos2TextConfig`] and a
    [`Kosmos2VisionConfig`]. It is used to instantiate a KOSMOS-2 model according to the specified arguments,
    defining the model architecture. Instantiating a configuration with the defaults will yield a similar
    configuration to that of the KOSMOS-2
    [microsoft/kosmos-2-patch14-224](https://huggingface.co/microsoft/kosmos-2-patch14-224) architecture.

    Args:
        text_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`Kosmos2TextConfig`]. When `None`, the text
            configuration is built with its default values.
        vision_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`Kosmos2VisionConfig`]. When `None`, the vision
            configuration is built with its default values.
        latent_query_num (`int`, *optional*, defaults to 64):
            The number of latent query tokens that represent the image features used in the text decoder component.
        kwargs (*optional*):
            Dictionary of keyword arguments.

    Example:

    ```python
    >>> from transformers import Kosmos2Config, Kosmos2Model

    >>> # Initializing a Kosmos-2 kosmos-2-patch14-224 style configuration
    >>> configuration = Kosmos2Config()

    >>> # Initializing a model (with random weights) from the kosmos-2-patch14-224 style configuration
    >>> model = Kosmos2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    model_type = "kosmos-2"
    is_composition = True

    def __init__(
        self,
        text_config=None,
        vision_config=None,
        latent_query_num=64,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # A missing sub-configuration falls back to an empty dict, i.e. the sub-config defaults.
        if text_config is None:
            text_config = {}
            logger.info("`text_config` is `None`. Initializing the `Kosmos2TextConfig` with default values.")

        if vision_config is None:
            vision_config = {}
            logger.info("`vision_config` is `None`. Initializing the `Kosmos2VisionConfig` with default values.")

        self.text_config = Kosmos2TextConfig(**text_config)
        self.vision_config = Kosmos2VisionConfig(**vision_config)

        self.latent_query_num = latent_query_num

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        The nested text and vision configurations are serialized through their own `to_dict`, and the class-level
        `model_type` is added to the result.

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        serialized = copy.deepcopy(self.__dict__)
        serialized.update(
            text_config=self.text_config.to_dict(),
            vision_config=self.vision_config.to_dict(),
            model_type=self.__class__.model_type,
        )
        return serialized
modeling_kosmos2.py ADDED
@@ -0,0 +1,1747 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch KOSMOS-2 model."""
16
+
17
+
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+
26
+ from ...activations import ACT2FN
27
+ from ...modeling_outputs import (
28
+ BaseModelOutput,
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPooling,
31
+ CausalLMOutputWithCrossAttentions,
32
+ )
33
+ from ...modeling_utils import PreTrainedModel
34
+ from ...utils import (
35
+ ModelOutput,
36
+ add_start_docstrings,
37
+ add_start_docstrings_to_model_forward,
38
+ logging,
39
+ replace_return_docstrings,
40
+ )
41
+ from .configuration_kosmos2 import Kosmos2Config, Kosmos2TextConfig, Kosmos2VisionConfig
42
+
43
+
44
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)

# Documentation constants. NOTE(review): presumably consumed by the docstring helpers imported above
# (`add_start_docstrings*`, `replace_return_docstrings`); their use sites are outside this view.
_CHECKPOINT_FOR_DOC = "microsoft/kosmos-2-patch14-224"
_CONFIG_FOR_DOC = Kosmos2Config
_EXPECTED_OUTPUT_SHAPE = None
49
+
50
+
51
# Adapted from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.

    The result is an additive mask: positions where `mask` is 1 become 0, and positions where `mask` is 0 become
    the most negative finite value of `dtype`.

    Args:
        mask: 2D tensor of shape `[bsz, src_len]`.
        dtype: floating point dtype of the returned mask.
        tgt_len: number of target positions; defaults to `src_len` when `None`.
    """
    batch, source_length = mask.size()
    if tgt_len is None:
        tgt_len = source_length

    # Broadcast the per-token mask over a singleton head axis and every target position.
    keep = mask[:, None, None, :].expand(batch, 1, tgt_len, source_length).to(dtype)
    blocked = 1.0 - keep

    return blocked.masked_fill(blocked.to(torch.bool), torch.finfo(dtype).min)
64
+
65
+
66
# Adapted from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make the additive causal mask used for uni-directional (left-to-right) self-attention.

    Args:
        input_ids_shape: `(bsz, tgt_len)` shape of the current input ids.
        dtype: floating point dtype of the returned mask.
        device: device on which the mask is created.
        past_key_values_length: number of cached positions; they are prepended as fully visible columns.

    Returns:
        Tensor of shape `(bsz, 1, tgt_len, tgt_len + past_key_values_length)` holding 0 where attention is allowed
        and `torch.finfo(dtype).min` where it is not.
    """
    bsz, tgt_len = input_ids_shape
    # Build the mask directly in the requested dtype. Routing the fill value through a default-dtype
    # (float32) tensor overflows to -inf for wider dtypes such as float64.
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, dtype=dtype, device=device)
    positions = torch.arange(tgt_len, device=device)
    # Entry (i, j) is visible when j <= i, i.e. the lower triangle including the diagonal.
    mask.masked_fill_(positions < (positions + 1).view(tgt_len, 1), 0)

    if past_key_values_length > 0:
        visible_past = torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device)
        mask = torch.cat([visible_past, mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
82
+
83
+
84
# Adapted from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    """
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        input_ids (`torch.Tensor`): token ids; positions are counted along dimension 1.
        padding_idx (`int`): id of the padding token; padded positions receive `padding_idx` as position id.
        past_key_values_length (`int`, defaults to 0): offset added to the position of every non-padding token.

    Returns: torch.Tensor
    """
    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
    not_padding = input_ids.ne(padding_idx).int()
    running_count = torch.cumsum(not_padding, dim=1).type_as(not_padding)
    # Multiplying by the mask resets padded positions to 0 before the final shift by `padding_idx`.
    shifted = (running_count + past_key_values_length) * not_padding
    return shifted.long() + padding_idx
99
+
100
+
101
# Shared docstring fragments. NOTE(review): all four still hold the placeholder text "Kosmos-2";
# presumably they are meant to be attached through the `add_start_docstrings*` helpers. Fill in before release.
KOSMOS2_START_DOCSTRING = r"""Kosmos-2"""
KOSMOS2_VISION_INPUTS_DOCSTRING = r"""Kosmos-2"""
KOSMOS2_TEXT_INPUTS_DOCSTRING = r"""Kosmos-2"""
KOSMOS2_INPUTS_DOCSTRING = r"""Kosmos-2"""
105
+
106
+
107
@dataclass
class Kosmos2ModelOutput(ModelOutput):
    """
    Output container for the KOSMOS-2 model: the text hidden states plus the optional image-related outputs.

    Args:
        last_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model. Note that the field name is
            plural (`last_hidden_states`).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
            encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
            input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, returned when being computed by the model):
            Sequence of hidden-states at the output of `Kosmos2ImageToTextConnector`.
        image_connector_attention (`tuple(torch.FloatTensor)`, *optional*, returned when being computed by the model):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights given by `Kosmos2ImageToTextConnector`, after the attention softmax, used to compute
            the weighted average in the self-attention heads.
        vision_model_output (`BaseModelOutputWithPooling`, *optional*, returned when being computed by the model):
            The output of the [`Kosmos2VisionModel`].
    """

    # Every field defaults to `None`; the field order below is part of the class's public layout.
    last_hidden_states: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_features: Optional[torch.FloatTensor] = None
    image_connector_attention: Optional[Tuple[torch.FloatTensor]] = None
    vision_model_output: Optional[BaseModelOutputWithPooling] = None
154
+
155
+
156
@dataclass
class Kosmos2ForConditionalGenerationModelOutput(ModelOutput):
    """
    Model output class for `Kosmos2ForConditionalGeneration`.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
            encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
            input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, returned when being computed by the model):
            Sequence of hidden-states at the output of `Kosmos2ImageToTextConnector`.
        image_connector_attention (`tuple(torch.FloatTensor)`, *optional*, returned when being computed by the model):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights given by `Kosmos2ImageToTextConnector`, after the attention softmax, used to compute
            the weighted average in the self-attention heads.
        vision_model_output (`BaseModelOutputWithPooling`, *optional*, returned when being computed by the model):
            The output of the [`Kosmos2VisionModel`].
    """

    # Every field defaults to `None`; the field order below is part of the class's public layout.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_features: Optional[torch.FloatTensor] = None
    image_connector_attention: Optional[Tuple[torch.FloatTensor]] = None
    vision_model_output: Optional[BaseModelOutputWithPooling] = None
206
+
207
+
208
# Adapted from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Kosmos2
class Kosmos2VisionEmbeddings(nn.Module):
    """
    Turns a batch of images into a sequence of embeddings: one learned class token followed by one token per image
    patch, with a learned position embedding added to every token.
    """

    def __init__(self, config: Kosmos2VisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Learned vector placed in front of the patch tokens.
        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        # Kernel size equals stride, so each patch is projected independently of its neighbours.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side**2
        # One extra position for the class token.
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        # Non-persistent buffer: follows the module across devices but is not saved in the state dict.
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed `pixel_values` and return the class token + patch tokens with position embeddings added."""
        num_images = pixel_values.shape[0]
        # Conv output is [batch, embed_dim, grid, grid]; flatten the grid and move it to the sequence axis.
        patch_tokens = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)

        class_tokens = self.class_embedding.expand(num_images, 1, -1)
        tokens = torch.cat([class_tokens, patch_tokens], dim=1)
        return tokens + self.position_embedding(self.position_ids)
241
+
242
+
243
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->Kosmos2Vision
244
class Kosmos2VisionAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        # Queries are pre-scaled by 1/sqrt(head_dim) before the dot product.
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim)
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Self-attention over `hidden_states`. Input shape: Batch x Time x Channel.

        Args:
            hidden_states (`torch.Tensor`): input of shape `(bsz, tgt_len, embed_dim)`.
            attention_mask (`torch.Tensor`, *optional*): additive mask of shape `(bsz, 1, tgt_len, src_len)`,
                added to the attention scores before the softmax.
            causal_attention_mask (`torch.Tensor`, *optional*): additive mask of the same shape, applied before
                `attention_mask`.
            output_attentions (`bool`, *optional*): whether to also return the attention probabilities.

        Returns:
            A 2-tuple `(attn_output, attn_weights)`: `attn_output` has shape `(bsz, tgt_len, embed_dim)`;
            `attn_weights` has shape `(bsz, num_heads, tgt_len, src_len)`, or is `None` when `output_attentions`
            is falsy.

        Raises:
            ValueError: if the attention scores, either mask, or the attention output has an unexpected size.
        """

        bsz, tgt_len, embed_dim = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scale
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        # Fold the head axis into the batch axis so one `bmm` handles every head.
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # apply the causal_attention_mask first
        if causal_attention_mask is not None:
            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {causal_attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            # Report the same 3-D shape that the check above compares against.
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # (bsz * num_heads, tgt_len, head_dim) -> (bsz, tgt_len, embed_dim)
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped
346
+
347
+
348
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Kosmos2Vision
349
class Kosmos2VisionMLP(nn.Module):
    """Two-layer feed-forward network: `hidden_size -> intermediate_size -> hidden_size`."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Activation is looked up by name from the config.
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # expand -> non-linearity -> project back
        expanded = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        return self.fc2(activated)
362
+
363
+
364
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Kosmos2Vision
365
class Kosmos2VisionEncoderLayer(nn.Module):
    """Pre-norm transformer layer: `x + attn(norm1(x))` followed by `x + mlp(norm2(x))`."""

    def __init__(self, config: Kosmos2VisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = Kosmos2VisionAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Kosmos2VisionMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)`, passed
                through to the self-attention module.
            causal_attention_mask (`torch.FloatTensor`): second mask of size `(batch, 1, tgt_len, src_len)`, passed
                through to the self-attention module.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.

        Returns:
            `(hidden_states,)`, or `(hidden_states, attn_weights)` when `output_attentions` is truthy.
        """
        # Attention sub-block with residual connection.
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block with residual connection.
        mlp_out = self.mlp(self.layer_norm2(hidden_states))
        hidden_states = hidden_states + mlp_out

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
413
+
414
+
415
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Kosmos2Vision
416
class Kosmos2VisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`Kosmos2VisionEncoderLayer`].

    Args:
        config: Kosmos2VisionConfig
    """

    def __init__(self, config: Kosmos2VisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([Kosmos2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Run `inputs_embeds` through every layer in order.

        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded input sequence; used as the hidden state fed to the first layer.
            attention_mask (`torch.Tensor`, *optional*):
                Forwarded unchanged to every layer.
            causal_attention_mask (`torch.Tensor`, *optional*):
                Forwarded unchanged to every layer.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. Defaults to
                `config.output_attentions`.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers (the input embeddings plus the output of
                each layer). Defaults to `config.output_hidden_states`.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Defaults to
                `config.use_return_dict`.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # `None` means "not requested"; a list collects one entry per step otherwise.
        collected_states = [] if output_hidden_states else None
        collected_attentions = [] if output_attentions else None

        def make_checkpointed_call(module):
            # `checkpoint` only forwards positional tensors, so the flag is bound through the closure.
            def call(*inputs):
                return module(*inputs, output_attentions)

            return call

        hidden_states = inputs_embeds
        for layer in self.layers:
            if collected_states is not None:
                collected_states.append(hidden_states)

            if self.gradient_checkpointing and self.training:
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    make_checkpointed_call(layer),
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if collected_attentions is not None:
                collected_attentions.append(layer_outputs[1])

        # The output of the last layer is recorded as well.
        if collected_states is not None:
            collected_states.append(hidden_states)

        encoder_states = tuple(collected_states) if collected_states is not None else None
        all_attentions = tuple(collected_attentions) if collected_attentions is not None else None

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
517
+
518
+
519
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer with CLIPVision->Kosmos2Vision,CLIP_VISION->KOSMOS2_VISION,CLIP->Kosmos2Vision
520
class Kosmos2VisionTransformer(nn.Module):
    # Pipeline: patch/class/position embeddings -> layer norm -> encoder stack; the class-token state is layer-normed
    # once more to form the pooled output.

    def __init__(self, config: Kosmos2VisionConfig):
        super().__init__()
        self.config = config
        hidden_dim = config.hidden_size

        self.embeddings = Kosmos2VisionEmbeddings(config)
        # NOTE: the misspelled attribute name is kept on purpose -- renaming it would change the parameter names
        # (state-dict keys) of this module.
        self.pre_layrnorm = nn.LayerNorm(hidden_dim, eps=config.layer_norm_eps)
        self.encoder = Kosmos2VisionEncoder(config)
        self.post_layernorm = nn.LayerNorm(hidden_dim, eps=config.layer_norm_eps)

    @add_start_docstrings_to_model_forward(KOSMOS2_VISION_INPUTS_DOCSTRING)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        # Unset flags fall back to the values stored on the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedded = self.pre_layrnorm(self.embeddings(pixel_values))

        encoder_outputs = self.encoder(
            inputs_embeds=embedded,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        # Position 0 holds the class token; only that token goes through the final layer norm.
        pooled_output = self.post_layernorm(last_hidden_state[:, 0, :])

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=last_hidden_state,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        return (last_hidden_state, pooled_output) + encoder_outputs[1:]
571
+
572
+
573
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding with M2M100->Kosmos2
574
class Kosmos2TextSinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length."""

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
        super().__init__()
        # Extra rows allocated on top of `num_positions` when the table is (re)built.
        self.offset = 2
        self.embedding_dim = embedding_dim
        # NOTE(review): `forward` and `create_position_ids_from_inputs_embeds` compute `self.padding_idx + 1`, so a
        # `None` padding index only works for building the table, not for looking positions up.
        self.padding_idx = padding_idx
        self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)

    def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        """(Re)build the sinusoidal table and store it in the non-persistent buffer `weights`."""
        emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
        if hasattr(self, "weights"):
            # in forward put the weights on the correct dtype and device of the param
            emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)

        self.register_buffer("weights", emb_weights, persistent=False)

    @staticmethod
    def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        """
        Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
        "Attention Is All You Need".

        Args:
            num_embeddings: number of rows (positions) in the table.
            embedding_dim: width of each row; an odd width gets one trailing zero column.
            padding_idx: if given, that row is set to zero.

        Returns:
            Tensor of shape `(num_embeddings, embedding_dim)` in the default dtype.
        """
        half_dim = embedding_dim // 2
        # With a single frequency (`embedding_dim` of 2 or 3) `half_dim - 1` is 0. The divisor is irrelevant in that
        # case because the only exponent is 0, so clamp it instead of dividing by zero.
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        # First half of every row holds the sines, second half the cosines.
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0

        return emb.to(torch.get_default_dtype())

    @torch.no_grad()
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values_length: int = 0,
    ):
        """
        Look up position embeddings for a batch.

        Args:
            input_ids: token ids of shape `(bsz, seq_len)`; takes precedence over `inputs_embeds` when given.
            inputs_embeds: embeddings of shape `(bsz, seq_len, dim)`; used only when `input_ids` is `None`.
            past_key_values_length: number of already-processed positions to offset by.

        Returns:
            Detached tensor of shape `(bsz, seq_len, embedding_dim)`.
        """
        if input_ids is not None:
            bsz, seq_len = input_ids.size()
            # Create the position ids from the input token ids. Any padded tokens remain padded.
            position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to(
                input_ids.device
            )
        else:
            bsz, seq_len = inputs_embeds.size()[:-1]
            position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length)

        # expand embeddings if needed
        max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
        if max_pos > self.weights.size(0):
            self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)

        return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach()

    def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor
            past_key_values_length: offset added to every generated position id.

        Returns: torch.Tensor
        """
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        # Positions start right after the padding index.
        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length
650
+
651
+
652
+ # Similar to transformers.models.bart.modeling_bart.BartAttention with an additional `inner_attn_ln`.
653
class KosmosTextAttention(nn.Module):
    """Multi-headed attention with an optional layer norm applied to the merged heads before the output projection."""

    def __init__(
        self,
        config,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        add_inner_attn_layernorm: bool = False,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        # `config` is only consulted for the epsilon of the optional inner layer norm.
        self.inner_attn_ln = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) if add_inner_attn_layernorm else None

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim)
        per_head = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Input shape: Batch x Time x Channel

        Acts as cross-attention when `key_value_states` is given and as self-attention otherwise. Returns
        `(attn_output, attn_weights or None, past_key_value)`; when `is_decoder` is set, the returned
        `past_key_value` is the `(key_states, value_states)` pair used by this call.
        """
        is_cross_attention = key_value_states is not None
        bsz, tgt_len, _ = hidden_states.size()

        # Queries are scaled up front, before the dot product with the keys.
        query_states = self.q_proj(hidden_states) * self.scaling

        # Cached cross-attention keys/values are reused only when their sequence length still matches the provided
        # `key_value_states` (this supports prefix tuning).
        reuse_cached_cross_kv = (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        )

        if reuse_cached_cross_kv:
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        else:
            kv_source = key_value_states if is_cross_attention else hidden_states
            key_states = self._shape(self.k_proj(kv_source), -1, bsz)
            value_states = self._shape(self.v_proj(kv_source), -1, bsz)
            if not is_cross_attention and past_key_value is not None:
                # incremental self-attention: append the new keys/values to the cached ones
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)

        if self.is_decoder:
            # Hand the keys/values back so later calls can reuse them (cross-attention) or extend them
            # (uni-directional self-attention).
            past_key_value = (key_states, value_states)

        # Merge the head axis into the batch axis for `bmm`.
        flat_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*flat_shape)
        key_states = key_states.reshape(*flat_shape)
        value_states = value_states.reshape(*flat_shape)

        src_len = key_states.size(1)
        flat_weights_shape = (bsz * self.num_heads, tgt_len, src_len)
        per_head_weights_shape = (bsz, self.num_heads, tgt_len, src_len)

        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != flat_weights_shape:
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            # additive mask, broadcast over the head axis
            attn_weights = attn_weights.view(*per_head_weights_shape) + attention_mask
            attn_weights = attn_weights.view(*flat_weights_shape)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            # multiplicative per-head mask applied to the probabilities
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(*per_head_weights_shape)
            attn_weights = attn_weights.view(*flat_weights_shape)

        attn_weights_reshaped = None
        if output_attentions:
            # Reshape twice and keep using the second view so the returned weights stay part of the graph and
            # keep their gradient.
            attn_weights_reshaped = attn_weights.view(*per_head_weights_shape)
            attn_weights = attn_weights_reshaped.view(*flat_weights_shape)

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim).transpose(1, 2)

        # Use the `embed_dim` stored on the module rather than the width of `hidden_states` because `attn_output`
        # can be partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        if self.inner_attn_ln is not None:
            attn_output = self.inner_attn_ln(attn_output)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
814
+
815
+
816
class Kosmos2TextFFN(nn.Module):
    """Feed-forward block `embed_dim -> ffn_dim -> embed_dim` with a layer norm on the intermediate activations."""

    def __init__(self, config: Kosmos2TextConfig):
        super().__init__()

        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        self.fc1 = nn.Linear(config.embed_dim, config.ffn_dim)
        self.fc2 = nn.Linear(config.ffn_dim, config.embed_dim)

        self.ffn_layernorm = nn.LayerNorm(config.ffn_dim, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        is_training = self.training

        # expand + activation, with its own dropout rate
        intermediate = self.activation_fn(self.fc1(hidden_states))
        intermediate = nn.functional.dropout(intermediate, p=self.activation_dropout, training=is_training)

        # normalise in the wide (`ffn_dim`) space before projecting back
        projected = self.fc2(self.ffn_layernorm(intermediate))
        return nn.functional.dropout(projected, p=self.dropout, training=is_training)
837
+
838
+
839
class Kosmos2TextBlock(nn.Module):
    """
    Pre-norm decoder layer: self-attention, optional cross-attention (only built when
    `config.add_cross_attention` is set), then the feed-forward network, each wrapped in a residual connection.
    """

    def __init__(self, config: Kosmos2TextConfig):
        super().__init__()
        self.embed_dim = config.embed_dim

        self.self_attn = KosmosTextAttention(
            config,
            embed_dim=self.embed_dim,
            num_heads=config.attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            add_inner_attn_layernorm=True,
        )
        self.dropout = config.dropout
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

        if config.add_cross_attention:
            # Unlike the self-attention above, cross-attention has no inner layer norm.
            self.encoder_attn = KosmosTextAttention(
                config,
                embed_dim=self.embed_dim,
                num_heads=config.attention_heads,
                dropout=config.attention_dropout,
                is_decoder=True,
                add_inner_attn_layernorm=False,
            )
            self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

        self.ffn = Kosmos2TextFFN(config)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Returns a tuple starting with the new hidden states, followed by
        `(self_attn_weights, cross_attn_weights)` when `output_attentions` is truthy and by the present key/value
        cache when `use_cache` is truthy.

        Raises:
            ValueError: if `encoder_hidden_states` is passed but the block was built without cross-attention.
        """
        # Cache layout: the first two entries belong to self-attention, the last two (if any) to cross-attention.
        cached_self_kv = None if past_key_value is None else past_key_value[:2]

        # --- self-attention ---
        attn_out, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=self.self_attn_layer_norm(hidden_states),
            past_key_value=cached_self_kv,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        attn_out = nn.functional.dropout(attn_out, p=self.dropout, training=self.training)
        hidden_states = hidden_states + attn_out

        # --- cross-attention (only when encoder states are supplied) ---
        cross_attn_present_key_value = None
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            if not hasattr(self, "encoder_attn"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            cached_cross_kv = None if past_key_value is None else past_key_value[-2:]
            cross_out, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
                hidden_states=self.encoder_attn_layer_norm(hidden_states),
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cached_cross_kv,
                output_attentions=output_attentions,
            )
            cross_out = nn.functional.dropout(cross_out, p=self.dropout, training=self.training)
            hidden_states = hidden_states + cross_out

            # append the cross-attention keys/values after the self-attention ones
            present_key_value = present_key_value + cross_attn_present_key_value

        # --- feed-forward ---
        ffn_out = self.ffn(self.final_layer_norm(hidden_states))
        hidden_states = hidden_states + ffn_out

        outputs = (hidden_states,)
        if output_attentions:
            outputs = outputs + (self_attn_weights, cross_attn_weights)
        if use_cache:
            outputs = outputs + (present_key_value,)
        return outputs
948
+
949
+
950
+ class Kosmos2TextTransformer(nn.Module):
951
+ """
952
+ Transformer decoder consisting of `config.layers` layers. Each layer is a [`Kosmos2TextBlock`].
953
+
954
+ Args:
955
+ config: Kosmos2TextConfig
956
+ """
957
+
958
def __init__(self, config: Kosmos2TextConfig):
    """Build the token embedding, sinusoidal position embedding, `config.layers` decoder blocks and final norm."""
    super().__init__()
    self.config = config
    self.dropout = config.dropout
    self.layerdrop = config.layerdrop

    # Token embeddings are multiplied by sqrt(embed_dim) only when the config asks for it.
    if config.scale_embedding:
        self.embed_scale = math.sqrt(config.embed_dim)
    else:
        self.embed_scale = 1.0
    self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim, padding_idx=config.pad_token_id)

    self.embed_positions = Kosmos2TextSinusoidalPositionalEmbedding(
        num_positions=config.max_position_embeddings,
        embedding_dim=config.embed_dim,
        padding_idx=config.pad_token_id,
    )

    decoder_blocks = [Kosmos2TextBlock(config) for _ in range(config.layers)]
    self.layers = nn.ModuleList(decoder_blocks)
    self.layer_norm = nn.LayerNorm(config.embed_dim, eps=config.layer_norm_eps)

    self.gradient_checkpointing = False
977
+
978
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
979
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
    """
    Combine the causal mask with the expanded `attention_mask`, i.e. go from
    `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.

    Returns `None` when the input is a single position and no `attention_mask` is given.
    """
    # A causal mask is only built when there is more than one target position.
    causal_mask = None
    if input_shape[-1] > 1:
        causal_mask = _make_causal_mask(
            input_shape,
            inputs_embeds.dtype,
            device=inputs_embeds.device,
            past_key_values_length=past_key_values_length,
        )

    if attention_mask is None:
        return causal_mask

    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
    expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
    expanded_attn_mask = expanded_attn_mask.to(inputs_embeds.device)

    if causal_mask is None:
        return expanded_attn_mask
    return expanded_attn_mask + causal_mask
1001
+
1002
def forward_embedding(
    self, input_ids, inputs_embeds=None, img_features=None, img_input_mask=None, past_key_values_length: int = 0
):
    """
    Build the decoder input: token embeddings (with image features written into the masked slots), scaled by
    `self.embed_scale`, plus position embeddings, followed by dropout.

    The argument `inputs_embeds` should be the one without being multiplied by `self.embed_scale`.
    """
    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if img_features is not None:
        # NOTE(review): this writes into `inputs_embeds` in place, so a tensor supplied by the caller is modified.
        image_slots = img_input_mask.to(dtype=torch.bool)
        inputs_embeds[image_slots] = img_features

    inputs_embeds = inputs_embeds * self.embed_scale

    # embed positions
    positions = self.embed_positions(
        input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length
    ).to(inputs_embeds.device)

    summed = inputs_embeds + positions
    return nn.functional.dropout(summed, p=self.dropout, training=self.training)
1025
+
1026
def forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    img_features: Optional[torch.Tensor] = None,
    img_attn_mask: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    cross_attn_head_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
    """
    Run the decoder layers over the embedded inputs and apply the final layer norm.

    Exactly one of `input_ids` / `inputs_embeds` must be given. `img_features` and `img_attn_mask` are forwarded
    to `forward_embedding`; they are dropped once cached keys/values are present (`past_key_values_length > 0`).
    Flags left as `None` (`use_cache`, `output_attentions`, `output_hidden_states`, `return_dict`) fall back to
    the values stored in `self.config`.

    Returns:
        A `BaseModelOutputWithPastAndCrossAttentions` when `return_dict` resolves to true; otherwise a tuple with
        the non-`None` entries among (last hidden state, cache, all hidden states, self-attentions,
        cross-attentions), in that order.

    Raises:
        ValueError: if both or neither of `input_ids` / `inputs_embeds` are given, or if `head_mask` /
            `cross_attn_head_mask` does not have one entry per decoder layer.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        input_shape = input_ids.shape
        input_ids = input_ids.view(-1, input_shape[-1])
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    # past_key_values_length
    past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

    # We don't need img info. when `past_key_values_length` > 0
    if past_key_values_length > 0:
        img_features = None
        img_attn_mask = None

    hidden_states = self.forward_embedding(
        input_ids=input_ids,
        inputs_embeds=inputs_embeds,
        img_features=img_features,
        img_input_mask=img_attn_mask,
        past_key_values_length=past_key_values_length,
    )

    attention_mask = self._prepare_decoder_attention_mask(
        attention_mask, input_shape, hidden_states, past_key_values_length
    )

    # expand encoder attention mask
    if encoder_hidden_states is not None and encoder_attention_mask is not None:
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        # Use `hidden_states` for the dtype: `inputs_embeds` is `None` whenever the caller passed `input_ids`.
        encoder_attention_mask = _expand_mask(encoder_attention_mask, hidden_states.dtype, tgt_len=input_shape[-1])

    # NOTE(review): `forward_embedding` already applies dropout with the same probability, so this is a second
    # application during training. Kept unchanged to preserve training behavior -- confirm whether it is intended.
    hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
    next_decoder_cache = () if use_cache else None

    # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
    for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
        if attn_mask is not None and attn_mask.size()[0] != len(self.layers):
            # Report the size of the mask actually being checked (not always `head_mask`, which may be `None`).
            raise ValueError(
                f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                f" {attn_mask.size()[0]}."
            )

    for idx, decoder_layer in enumerate(self.layers):
        # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        if self.training:
            dropout_probability = torch.rand([])
            if dropout_probability < self.layerdrop:
                continue

        past_key_value = past_key_values[idx] if past_key_values is not None else None

        if self.gradient_checkpointing and self.training:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # None for past_key_value
                    return module(*inputs, output_attentions, use_cache)

                return custom_forward

            layer_outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(decoder_layer),
                hidden_states,
                attention_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                head_mask[idx] if head_mask is not None else None,
                cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                None,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                cross_attn_layer_head_mask=(
                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                ),
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
        hidden_states = layer_outputs[0]

        if use_cache:
            # The cache entry follows the attention weights when those are returned.
            next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

            if encoder_hidden_states is not None:
                all_cross_attentions += (layer_outputs[2],)

    # add final layer norm
    hidden_states = self.layer_norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(
            v
            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
            if v is not None
        )
    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
        cross_attentions=all_cross_attentions,
    )
1185
+
1186
+
1187
class Kosmos2PreTrainedModel(PreTrainedModel):
    """
    Shared base class of the KOSMOS-2 models defined in this file.

    It builds on `PreTrainedModel`, associates the models with `Kosmos2Config` and declares that gradient
    checkpointing is supported.
    """

    # Configuration class used by this model family.
    config_class = Kosmos2Config
    # Opt in to gradient checkpointing support.
    supports_gradient_checkpointing = True
1195
+
1196
+
1197
@add_start_docstrings(
    """The vision model from KOSMOS-2 without any head or projection on top.""",
    KOSMOS2_START_DOCSTRING,
)
class Kosmos2VisionModel(Kosmos2PreTrainedModel):
    # Thin wrapper exposing `Kosmos2VisionTransformer` (stored as `self.model`) as a standalone model.
    config_class = Kosmos2VisionConfig
    main_input_name = "pixel_values"

    # Copied from transformers.models.clip.modeling_clip.CLIPVisionModel.__init__ with CLIP_VISION->KOSMOS2_VISION,CLIP->Kosmos2
    def __init__(self, config: Kosmos2VisionConfig):
        super().__init__(config)
        self.model = Kosmos2VisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.clip.modeling_clip.CLIPVisionModel.get_input_embeddings with CLIP_VISION->KOSMOS2_VISION,CLIP->Kosmos2
    def get_input_embeddings(self) -> nn.Module:
        # The patch embedding layer of the wrapped transformer is what consumes the pixel input.
        vision_embeddings = self.model.embeddings
        return vision_embeddings.patch_embedding

    @add_start_docstrings_to_model_forward(KOSMOS2_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Kosmos2VisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        # Pure delegation: every argument is handed to the wrapped vision transformer unchanged.
        vision_outputs = self.model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return vision_outputs
1235
+
1236
+
1237
@add_start_docstrings(
    """The text model from KOSMOS-2 without any head or projection on top.""",
    KOSMOS2_START_DOCSTRING,
)
class Kosmos2TextModel(Kosmos2PreTrainedModel):
    # Thin wrapper exposing `Kosmos2TextTransformer` (stored as `self.model`) as a standalone model.
    config_class = Kosmos2TextConfig

    _no_split_modules = ["Kosmos2TextBlock"]

    def __init__(self, config: Kosmos2TextConfig):
        super().__init__(config)
        self.model = Kosmos2TextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        # The token embedding table lives on the wrapped transformer.
        text_transformer = self.model
        return text_transformer.embed_tokens

    def set_input_embeddings(self, value):
        # Replace the token embedding table of the wrapped transformer.
        text_transformer = self.model
        text_transformer.embed_tokens = value

    @add_start_docstrings_to_model_forward(KOSMOS2_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPastAndCrossAttentions, config_class=Kosmos2TextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        img_features: Optional[torch.Tensor] = None,
        img_attn_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        r"""
        Returns:

        """
        # Pure delegation: every argument is handed to the wrapped text transformer unchanged.
        text_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            img_features=img_features,
            img_attn_mask=img_attn_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return text_outputs
1297
+
1298
+
1299
@add_start_docstrings(
    """
    The text model from KOSMOS-2 with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    KOSMOS2_START_DOCSTRING,
)
class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel):
    # `self.model` is the text transformer; `self.lm_head` projects its hidden states to vocabulary logits.
    config_class = Kosmos2TextConfig
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: Kosmos2TextConfig):
        super().__init__(config)

        self.model = Kosmos2TextTransformer(config)
        self.lm_head = nn.Linear(in_features=config.embed_dim, out_features=config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    @add_start_docstrings_to_model_forward(KOSMOS2_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=Kosmos2TextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        img_features: Optional[torch.Tensor] = None,
        img_attn_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        Returns:

        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Caching is switched off whenever a loss is requested.
        if labels is not None:
            if use_cache:
                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
            use_cache = False

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            img_features=img_features,
            img_attn_mask=img_attn_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = self.lm_head(outputs[0])

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        img_features,
        img_attn_mask,
        past_key_values=None,
        attention_mask=None,
        use_cache=None,
        **model_kwargs,
    ):
        """
        Assemble the keyword arguments for one `forward` call during generation.

        With a cache, only the last token of `input_ids` is kept and the image inputs are dropped. Without a
        cache, `img_attn_mask` is right-padded with zeros up to the current length of `input_ids`. Extra
        `model_kwargs` are accepted but not forwarded.

        Returns:
            A dict with the keys `input_ids`, `img_features`, `img_attn_mask`, `past_key_values`,
            `attention_mask` and `use_cache`.
        """
        input_shape = input_ids.shape
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # cut input_ids if past_key_values is used
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
            # the image info. is already encoded into the past keys/values
            img_features = None
            img_attn_mask = None
        elif img_attn_mask is not None:
            # appending `False` to `img_attn_mask` (because `input_ids` grows during generation)
            batch_size, seq_len = input_ids.size()
            mask_len = img_attn_mask.size()[-1]
            # `new_zeros` allocates the padding on the mask's own device (and dtype); a plain `torch.zeros`
            # lands on the default device and makes `torch.cat` fail when the mask lives elsewhere.
            mask_padding = img_attn_mask.new_zeros((batch_size, seq_len - mask_len))
            img_attn_mask = torch.cat((img_attn_mask, mask_padding), dim=1)

        return {
            "input_ids": input_ids,
            "img_features": img_features,
            "img_attn_mask": img_attn_mask,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
        }
1443
+
1444
+
1445
class Kosmos2ImageToTextConnector(nn.Module):
    """The layer that transforms the image model's output to part of the text model's input (namely, image features)"""

    def __init__(self, config: Kosmos2Config):
        super().__init__()
        text_config = config.text_config
        embed_dim = text_config.embed_dim

        # Projects vision hidden states into the text embedding space.
        self.dense = nn.Linear(config.vision_config.hidden_size, embed_dim)
        # Learned queries: one row per latent query.
        self.latent_query = nn.Parameter(torch.randn(config.latent_query_num, embed_dim))

        self.x_attn = KosmosTextAttention(
            text_config,
            embed_dim,
            text_config.attention_heads,
            dropout=text_config.attention_dropout,
            is_decoder=False,
            add_inner_attn_layernorm=False,
        )

    def forward(self, features):
        """
        Attend from the latent queries over the projected image features.

        `features` is passed through `self.dense`; the latent queries are expanded to the batch size and used as
        the attention queries, while keys/values are the projected features concatenated with the queries along
        dim 1.

        Returns:
            A tuple of the first two outputs of `self.x_attn` (named hidden states and attention weights here).
        """
        projected = self.dense(features)
        batch_size = projected.size(0)

        # shape = [batch, latent_query_num, h_dim]
        queries = self.latent_query.unsqueeze(0).expand(batch_size, -1, -1)

        attended, attn_weights, _ = self.x_attn(
            hidden_states=queries,
            key_value_states=torch.cat([projected, queries], dim=1),
            past_key_value=None,
            attention_mask=None,
            output_attentions=None,
        )

        return attended, attn_weights
1478
+
1479
+
1480
@add_start_docstrings(
    """
    KOSMOS-2 Model for generating text and image features. The model consists of a vision encoder (CLIP) and a language
    model.
    """,
    KOSMOS2_START_DOCSTRING,
)
class Kosmos2Model(Kosmos2PreTrainedModel):
    # Combines a text model, a vision model and the connector that maps vision outputs to text-side image features.
    config_class = Kosmos2Config

    def __init__(self, config: Kosmos2Config):
        super().__init__(config)

        self.text_model = Kosmos2TextModel(config.text_config)
        self.vision_model = Kosmos2VisionModel(config.vision_config)
        self.image_to_text_connector = Kosmos2ImageToTextConnector(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        text_transformer = self.text_model.model
        return text_transformer.embed_tokens

    def set_input_embeddings(self, value):
        text_transformer = self.text_model.model
        text_transformer.embed_tokens = value

    @add_start_docstrings_to_model_forward(KOSMOS2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Kosmos2ModelOutput, config_class=Kosmos2Config)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        img_attn_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        img_features: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Kosmos2ModelOutput]:
        # TODO: Add this
        r"""
        Returns:

        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        vision_model_output = None
        image_connector_attention = None
        # Image features are computed from `pixel_values` only when the caller did not supply them.
        if img_features is None:
            if pixel_values is None:
                raise ValueError("You have to specify either `pixel_values` or `img_features`.")

            vision_model_output = self.vision_model(pixel_values)
            # HF's CLIP has `last_hidden_state` without going through `post_layernorm`.
            # Here we need the whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`.
            normed_states = self.vision_model.model.post_layernorm(vision_model_output.last_hidden_state)
            # normalized features
            unit_features = nn.functional.normalize(normed_states, dim=-1)
            img_features, image_connector_attention = self.image_to_text_connector(unit_features)

        outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            img_features=img_features,
            img_attn_mask=img_attn_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            # Append the image-related outputs to the text outputs, dropping entries that are `None`.
            combined = outputs + (img_features, image_connector_attention, vision_model_output)
            return tuple(item for item in combined if item is not None)

        return Kosmos2ModelOutput(
            last_hidden_states=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_features=img_features,
            image_connector_attention=image_connector_attention,
            vision_model_output=vision_model_output,
        )
1575
+
1576
+
1577
@add_start_docstrings(
    """
    KOSMOS-2 Model for generating text and bounding boxes given an image. The model consists of a vision encoder (CLIP)
    and a language model.
    """,
    KOSMOS2_START_DOCSTRING,
)
class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel):
    # Combines a causal text model, a vision model and the connector that maps vision outputs to image features.
    config_class = Kosmos2Config
    _tied_weights_keys = ["text_model.lm_head.weight"]

    def __init__(self, config: Kosmos2Config):
        super().__init__(config)

        self.text_model = Kosmos2TextForCausalLM(config.text_config)
        self.vision_model = Kosmos2VisionModel(config.vision_config)

        self.image_to_text_connector = Kosmos2ImageToTextConnector(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.model.embed_tokens

    def set_input_embeddings(self, value):
        self.text_model.model.embed_tokens = value

    def get_output_embeddings(self) -> nn.Module:
        return self.text_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.text_model.set_output_embeddings(new_embeddings)

    def _encode_image(self, pixel_values):
        """
        Compute the text-side image features for `pixel_values`.

        Returns:
            A tuple `(img_features, image_connector_attention, vision_model_output)`.
        """
        vision_model_output = self.vision_model(pixel_values)
        # HF's CLIP has `last_hidden_state` without going through `post_layernorm`.
        # Here we need the whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`.
        img_features = self.vision_model.model.post_layernorm(vision_model_output.last_hidden_state)
        # normalized features
        img_features = nn.functional.normalize(img_features, dim=-1)
        img_features, image_connector_attention = self.image_to_text_connector(img_features)
        return img_features, image_connector_attention, vision_model_output

    @add_start_docstrings_to_model_forward(KOSMOS2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Kosmos2ForConditionalGenerationModelOutput, config_class=Kosmos2Config)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        img_attn_mask=None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask=None,
        head_mask: Optional[torch.Tensor] = None,
        img_features: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Kosmos2ForConditionalGenerationModelOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> from transformers import AutoProcessor, Kosmos2ForConditionalGeneration

        >>> model = Kosmos2ForConditionalGeneration.from_pretrained("ydshieh/kosmos-2-patch14-224")
        >>> processor = AutoProcessor.from_pretrained("ydshieh/kosmos-2-patch14-224")

        >>> prompt = "<grounding> An image of"
        >>> image = Image.open("snowman.jpg")

        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")

        >>> generated_ids = model.generate(
        ...     pixel_values=inputs["pixel_values"],
        ...     input_ids=inputs["input_ids"][:, :-1],
        ...     attention_mask=inputs["attention_mask"][:, :-1],
        ...     img_features=None,
        ...     img_attn_mask=inputs["img_attn_mask"][:, :-1],
        ...     use_cache=True,
        ...     max_new_tokens=64,
        ... )

        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> result = processor.post_processor_generation(generated_text)
        >>> result
        <grounding> An image of<phrase> a snowman</phrase><object><patch_index_0044><patch_index_0863></object> warming himself by<phrase> a fire</phrase><object><patch_index_0005><patch_index_0911></object>.
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_model_output = None
        image_connector_attention = None
        # Image features are computed from `pixel_values` only when the caller did not supply them.
        if img_features is None:
            if pixel_values is None:
                raise ValueError("You have to specify either `pixel_values` or `img_features`.")

            img_features, image_connector_attention, vision_model_output = self._encode_image(pixel_values)

        lm_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            img_features=img_features,
            img_attn_mask=img_attn_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            # Append the image-related outputs to the text outputs, dropping entries that are `None`.
            outputs = lm_outputs + (img_features, image_connector_attention, vision_model_output)
            return tuple(output for output in outputs if output is not None)

        return Kosmos2ForConditionalGenerationModelOutput(
            loss=lm_outputs.loss,
            logits=lm_outputs.logits,
            past_key_values=lm_outputs.past_key_values,
            hidden_states=lm_outputs.hidden_states,
            attentions=lm_outputs.attentions,
            image_features=img_features,
            image_connector_attention=image_connector_attention,
            vision_model_output=vision_model_output,
        )

    def generate(
        self,
        input_ids=None,
        attention_mask=None,
        img_features=None,
        inputs_embeds=None,
        pixel_values=None,
        **kwargs,
    ):
        """
        Encode the image (unless `img_features` is given) and delegate to `self.text_model.generate`.

        `pixel_values` may alternatively be passed through the `inputs` keyword. Remaining keyword arguments are
        forwarded to the text model's `generate` unchanged.

        Raises:
            ValueError: if both `inputs` and `pixel_values` are given, or if neither image features nor pixel
                values are available.
        """
        # in order to allow `inputs` argument (as in `GenerationMixin`)
        inputs = kwargs.pop("inputs", None)
        if pixel_values is not None and inputs is not None:
            raise ValueError(
                f"`inputs`: {inputs} were passed alongside `pixel_values` which is not allowed. "
                "Make sure to either pass `inputs` or pixel_values=..."
            )
        if pixel_values is None and inputs is not None:
            pixel_values = inputs

        if img_features is None:
            # Same explicit check as in `forward`, instead of failing inside the vision model.
            if pixel_values is None:
                raise ValueError("You have to specify either `pixel_values` or `img_features`.")
            img_features, _, _ = self._encode_image(pixel_values)

        # NOTE(review): the keyword below is spelled `input_embeds` (not `inputs_embeds`). It looks like a typo,
        # but it is kept as-is because renaming it may change how the text model's `generate` validates its
        # keyword arguments -- confirm against `GenerationMixin` before changing.
        output = self.text_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            img_features=img_features,
            input_embeds=inputs_embeds,
            **kwargs,
        )

        return output
processing_kosmos2.py ADDED
@@ -0,0 +1,498 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Processor class for KOSMOS-2."""
16
+
17
+ import copy
18
+ import math
19
+ import re
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+
24
+ from ...image_processing_utils import BatchFeature
25
+ from ...image_utils import ImageInput, is_batched
26
+ from ...processing_utils import ProcessorMixin
27
+ from ...tokenization_utils_base import PaddingStrategy, TextInput, TruncationStrategy
28
+ from ...utils import TensorType, is_tf_available, is_torch_available
29
+
30
+
31
+ if is_torch_available():
32
+ import torch
33
+
34
+ if is_tf_available():
35
+ import tensorflow as tf
36
+
37
+
38
# Accepted bounding-box inputs: each box is either a pair of ints or a 4-tuple of floats, given either as a flat
# list of boxes or as a list of lists of boxes.
# NOTE(review): the pair-of-ints form presumably holds patch indices and the 4-float form box coordinates --
# confirm against the processor implementation.
BboxInput = Union[
    List[Tuple[int, int]],
    List[Tuple[float, float, float, float]],
    List[List[Tuple[int, int]]],
    # The nested float variant uses 4-tuples, matching the flat float variant above.
    List[List[Tuple[float, float, float, float]]],
]
44
+
45
+
46
class Kosmos2Processor(ProcessorMixin):
    r"""
    Constructs a KOSMOS-2 processor which wraps a CLIP image processor and a KOSMOS-2 tokenizer into a single
    processor.

    [`Kosmos2Processor`] offers all the functionalities of [`CLIPImageProcessor`] and [`Kosmos2TokenizerFast`]. See the
    docstring of [`~Kosmos2Processor.__call__`] and [`~Kosmos2Processor.decode`] for more information.

    Args:
        image_processor (`CLIPImageProcessor`):
            An instance of [`CLIPImageProcessor`]. The image processor is a required input.
        tokenizer (`Kosmos2TokenizerFast`):
            An instance of [`Kosmos2TokenizerFast`]. The tokenizer is a required input.
    """
    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "CLIPImageProcessor"
    tokenizer_class = ("Kosmos2Tokenizer", "Kosmos2TokenizerFast")

    def __init__(self, image_processor, tokenizer):
        # KOSMOS-2 does not use token type ids, so the tokenizer is told not to return them.
        tokenizer.return_token_type_ids = False
        super().__init__(image_processor, tokenizer)
        self.current_processor = self.image_processor

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, List[TextInput]] = None,
        bboxes: BboxInput = None,
        num_image_tokens: Optional[int] = 64,
        first_image_token_id: Optional[int] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: Optional[int] = None,
        stride: int = 0,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_token_type_ids: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        This method uses [`CLIPImageProcessor.__call__`] method to prepare image(s) for the model, and
        [`Kosmos2TokenizerFast.__call__`] to prepare text for the model.

        Please refer to the docstring of the above two methods for more information.

        The arguments below are specific to KOSMOS-2; every other argument is forwarded unchanged to the tokenizer.

        Args:
            images (`ImageInput`, *optional*):
                The image(s) associated to `text`. When given, image placeholder tokens are prepended to each text.
            text (`Union[TextInput, List[TextInput]]`):
                The text(s) to encode. Required.
            bboxes (`BboxInput`, *optional*):
                The bounding boxes associated to the `<phrase> ... </phrase>` pairs in `text`.
            num_image_tokens (`int`, *optional*, defaults to 64):
                The number of image placeholder tokens inserted between `<image>` and `</image>`.
            first_image_token_id (`int`, *optional*):
                The id given to the first image placeholder; the following placeholders get consecutive ids.
                Defaults to `tokenizer.unk_token_id + 1`.

        Returns:
            [`BatchFeature`]: The tokenizer outputs. When `images` is given, it also contains the image processor
            outputs and `img_attn_mask` (`1` at the image placeholder positions, `0` elsewhere), and the placeholder
            ids in `input_ids` are replaced by `first_image_token_id, ..., first_image_token_id + num_image_tokens - 1`.

        Raises:
            `ValueError`: If `text` is `None`.
        """
        if text is None:
            raise ValueError("You have to specify at least `text`.")

        text = self.preprocess_text(text, images, bboxes, num_image_tokens=num_image_tokens)

        encoding = BatchFeature()

        text_encoding = self.tokenizer(
            text=text,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            stride=stride,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_token_type_ids=return_token_type_ids,
            return_length=return_length,
            verbose=verbose,
            return_tensors=return_tensors,
            **kwargs,
        )
        encoding.update(text_encoding)

        # The placeholder rewriting below only makes sense when `preprocess_text` inserted image tokens, i.e. when
        # images were given.
        if images is not None:
            image_encoding = self.image_processor(images, return_tensors=return_tensors)
            encoding.update(image_encoding)

            # Use the id of the first token after <unk>
            if first_image_token_id is None:
                first_image_token_id = self.tokenizer.unk_token_id + 1

            # To see if we need one more `0` (for `<s>`) at the beginning of `img_attn_mask`.
            with_bos = add_special_tokens

            # The first (actual) `<image>` token is always at the 1st or 2nd place (after `<s>` if any). Here we look
            # for the second `<image>` token (which indicate the first image token).
            start_index = int(with_bos) + 1

            if return_tensors:
                # change the ids for the fake `<image>` tokens in `input_ids`
                input_ids = np.array(encoding["input_ids"])
                input_ids[:, start_index : (start_index + num_image_tokens)] = np.arange(
                    first_image_token_id, first_image_token_id + num_image_tokens
                )

                batch_size, seq_len = input_ids.shape[:2]
                img_attn_mask = []
                if with_bos:
                    # for `<s>`
                    img_attn_mask.append(np.zeros(shape=(batch_size, 1), dtype=np.int64))
                # for `<image>` (the real one)
                img_attn_mask.append(np.zeros(shape=(batch_size, 1), dtype=np.int64))
                # for image tokens: the span must follow `num_image_tokens` (not a fixed 64), otherwise the mask
                # length no longer matches `input_ids`
                img_attn_mask.append(np.ones(shape=(batch_size, num_image_tokens), dtype=np.int64))
                # for `</image>`
                img_attn_mask.append(np.zeros(shape=(batch_size, 1), dtype=np.int64))
                # trailing part (which are not related to the image)
                seq_len -= int(with_bos) + 1 + num_image_tokens + 1
                img_attn_mask.append(np.zeros(shape=(batch_size, seq_len), dtype=np.int64))

                # concatenate along the sequence dimension
                img_attn_mask = np.concatenate(img_attn_mask, axis=1)

                # to the target tensor type
                if return_tensors == "pt":
                    input_ids = torch.from_numpy(input_ids)
                    img_attn_mask = torch.from_numpy(img_attn_mask)
                elif return_tensors == "tf":
                    input_ids = tf.convert_to_tensor(input_ids)
                    img_attn_mask = tf.convert_to_tensor(img_attn_mask)

                encoding["input_ids"] = input_ids
                encoding["img_attn_mask"] = img_attn_mask

            else:
                # Add `img_attn_mask`: the leading and trailing `0` are for `boi` and `eoi` tokens. The `1` indicates
                # the places of image tokens.
                image_token_ids = list(range(first_image_token_id, first_image_token_id + num_image_tokens))
                base_img_attn_mask = [0] + [1] * num_image_tokens + [0]

                # loop over `encoding["input_ids"]`
                input_ids = []
                img_attn_mask = []
                all_input_ids = encoding["input_ids"]
                # not batched -> (changed to) batch of size 1
                if isinstance(text, str):
                    all_input_ids = [all_input_ids]
                for text_ids in all_input_ids:
                    # change the ids for the fake `<image>` tokens in `input_ids`
                    text_ids = text_ids[:start_index] + image_token_ids + text_ids[start_index + num_image_tokens :]
                    input_ids.append(text_ids)

                    mask = copy.copy(base_img_attn_mask)
                    if with_bos:
                        # for `<s>`
                        mask = [0] + mask
                    # trailing part (which are not related to the image)
                    mask += [0] * (len(text_ids) - len(mask))
                    img_attn_mask.append(mask)

                # un-batch if necessary
                if isinstance(text, str):
                    input_ids = input_ids[0]
                    img_attn_mask = img_attn_mask[0]

                encoding["input_ids"] = input_ids
                encoding["img_attn_mask"] = img_attn_mask

        return encoding

    def preprocess_text(
        self,
        texts: Union[TextInput, List[TextInput]],
        images: ImageInput = None,
        bboxes: BboxInput = None,
        num_image_tokens: Optional[int] = 64,
    ) -> Union[str, List[str]]:
        """Add image and bounding box information to `texts` as image and patch index tokens.

        Args:
            texts (`Union[TextInput, List[TextInput]]`): The texts to be processed.
            images (`ImageInput`, *optional*): The images associated to `texts`.
            bboxes (`Union[List[Tuple[int]], List[Tuple[float]], List[List[Tuple[int]]], List[List[Tuple[float]]]]`, *optional*): The bounding bboxes associated to `texts`.
            num_image_tokens (`int`, *optional*, defaults to 64): The number of image tokens (used as latent queries). This should corresponds to the `latent_query_num` attribute in `Kosmos2Config`.

        Returns:
            `Union[TextInput, List[TextInput]]`: The processed texts with image and patch index tokens.

        Raises:
            `ValueError`: If `bboxes` is malformed, or if the number of examples in `texts`, `images` and `bboxes`
            do not agree.
        """
        # These are fake `<image>` tokens enclosed between (the actual) `<image>` token and `</image>`.
        img_tokens = ["<image>"] * num_image_tokens
        img_info = " ".join(["<image>"] + img_tokens + ["</image>"])

        def check_bboxes_for_single_text(bboxes):
            """
            Check `bboxes` for a single text example. It could be
                - `None`: no bounding box associated to a text.
                - A list with each element being the bounding boxes associated to one `<phrase> ... </phrase>` pair
                  found in a text. This could be:
                      - `None`: no bounding box associated to a `<phrase> ... </phrase>` pair.
                      - A tuple of 2 integers: A single bounding box specified by patch indices.
                      - A tuple of 4 float point number: A single bounding box specified by (normalized) coordinates.
                      - A list containing the above 2 tuple types: Multiple bounding boxes for a
                       `<phrase> ... </phrase>` pair.
            """
            if bboxes is None:
                return
            elif not isinstance(bboxes, list):
                raise ValueError("`bboxes` (for a single text example) should be `None` or a list.")

            # `bbox` is the bounding boxes for a single <phrase> </phrase> pair
            for bbox in bboxes:
                if bbox is None:
                    continue
                elif not isinstance(bbox, list):
                    bbox = [bbox]
                for elt in bbox:
                    if not isinstance(elt, tuple) or not (
                        (len(elt) == 2 and all(isinstance(x, int) for x in elt))
                        or (len(elt) == 4 and all(isinstance(x, float) for x in elt))
                    ):
                        raise ValueError(
                            "Each element in `bboxes` (for a single text example) should be `None`, a tuple containing "
                            "2 integers or 4 float point numbers, or a list containing such tuples. Also "
                            "make sure the arguments `texts` and `bboxes` passed to `preprocess_text` are both in "
                            "batches or both for a single example."
                        )

        def preprocess_single(text, image, bboxes):
            if image is not None:
                # Add `<image> ... (fake) image tokens ... </image>`
                text = f"{img_info} {text}"

            # Add `<object> <patch_idx_xxxx> <patch_idx_yyy> </object>` after `<phrase> phrase text </phrase>`
            text = self._insert_patch_index_tokens(text, bboxes)
            text = self._add_remove_spaces_around_tag_tokens(text)

            return text

        # make batch to simplify processing logic
        batched = True
        if isinstance(texts, str):
            batched = False
            texts = [texts]

        if images is None:
            images = [None] * len(texts)
        elif not is_batched(images):
            images = [images]
        if len(texts) != len(images):
            raise ValueError(
                f"The number of examples in `texts` and `images` should be the same. Got {len(texts)} v.s. {len(images)} instead."
            )

        if not batched:
            check_bboxes_for_single_text(bboxes)
            bboxes = [bboxes]
        elif bboxes is not None:
            if not isinstance(bboxes, list):
                raise ValueError("`bboxes` should be `None` or a list (as a batch) when `texts` is passed as a batch.")
            for x in bboxes:
                check_bboxes_for_single_text(x)
        else:
            bboxes = [None] * len(texts)

        if len(bboxes) != len(texts):
            raise ValueError(
                f"The number of examples in `texts` and `bboxes` should be the same. Got {len(texts)} v.s. {len(bboxes)} instead."
            )

        result = [preprocess_single(text, image, bbox) for text, image, bbox in zip(texts, images, bboxes)]
        # un-batch if necessary
        if not batched:
            result = result[0]

        return result

    # Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    # Copied from transformers.models.blip.processing_blip.BlipProcessor.decode with BertTokenizerFast->PreTrainedTokenizer
    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer
        to the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def post_processor_generation(self, text):
        """Return the part of `text` that follows the last `</image>` tag (the whole `text` if there is none)."""
        return text.split("</image>")[-1]

    @property
    # Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

    def _insert_patch_index_tokens(self, text: str, bboxes: Union[List[Tuple[int]], List[Tuple[float]]]) -> str:
        """Append `<object> ... </object>` patch index tokens after each `<phrase> ... </phrase>` pair in `text`.

        Args:
            text (`str`): A single text example.
            bboxes: One entry per `<phrase> ... </phrase>` pair: `None` (no box), a single box tuple, or a list of
                box tuples.

        Returns:
            `str`: `text` with the patch index tokens inserted; unchanged if `bboxes` is `None` or empty.

        Raises:
            `ValueError`: If the number of entries in `bboxes` differs from the number of phrase pairs in `text`.
        """
        if bboxes is None or len(bboxes) == 0:
            return text

        matched_phrases = list(re.finditer(r"<phrase>.+?</phrase>", string=text))
        if len(matched_phrases) != len(bboxes):
            raise ValueError(
                f"The number of elements in `bboxes` should be the same as the number of `<phrase> ... </phrase>` pairs in `text`. Got {len(matched_phrases)} v.s. {len(bboxes)} instead."
            )

        # insert object's patch index tokens
        # the found `<phrase> ... </phrase>` pairs.
        curr_pos = 0
        buffer = []
        for matched, bbox in zip(matched_phrases, bboxes):
            _, end = matched.span()
            buffer.append(text[curr_pos:end])
            curr_pos = end
            # A phrase without bbox
            if bbox is None:
                continue
            # A phrase with a single bbox
            if isinstance(bbox, tuple):
                bbox = [bbox]
            patch_index_strings = []
            # A phrase could have multiple bboxes
            for box in bbox:
                patch_index_1, patch_index_2 = self._convert_bbox_to_patch_index_tokens(box)
                patch_index_strings.append(f"{patch_index_1} {patch_index_2}")
            # `bbox` being an empty list: there is nothing to ground, so don't emit an empty `<object>` pair
            if len(patch_index_strings) == 0:
                continue
            position_str = " </delimiter_of_multi_objects/> ".join(patch_index_strings)
            buffer.append(f"<object> {position_str} </object>")
        # remaining
        if curr_pos < len(text):
            buffer.append(text[curr_pos:])

        text = "".join(buffer)
        return text

    def _convert_bbox_to_patch_index_tokens(
        self, bbox: Union[Tuple[int, int], Tuple[float, float, float, float]]
    ) -> Tuple[str, str]:
        """Convert a single bounding box to its pair of `<patch_index_xxxx>` token strings.

        A 2-tuple is taken as already computed patch indices; anything else is treated as `(x1, y1, x2, y2)`
        coordinates and converted with `coordinate_to_patch_index`.
        """
        # already computed patch indices
        if len(bbox) == 2:
            idx_1, idx_2 = bbox
        # bbox specified with (normalized) coordinates
        else:
            # use `self.tokenizer` to get `num_patches_per_side`
            num_patches_per_side = int(math.sqrt(self.tokenizer.num_patch_index_tokens))
            idx_1, idx_2 = coordinate_to_patch_index(bbox, num_patches_per_side)

        token_1 = f"<patch_index_{str(idx_1).zfill(4)}>"
        token_2 = f"<patch_index_{str(idx_2).zfill(4)}>"

        return token_1, token_2

    def _add_remove_spaces_around_tag_tokens(self, text):
        """
        Remove spaces before tag tokens (e.g. `<x>`). Also ensure a space after a tag token, if it is not followed by
        another tag token (this is not technically necessary, but good for a standard/consistent format). This avoids
        the inconsistency of tokenization results between kosmos-2 slow and fast tokenizers.
        """

        tag_tokens = set(
            self.tokenizer.tag_tokens
            + [f"<patch_index_{str(x).zfill(4)}>" for x in range(self.tokenizer.num_patch_index_tokens)]
        )
        # Escape the tokens so they are always matched literally, whatever characters they contain.
        pattern = "|".join(re.escape(token) for token in tag_tokens)
        # The capturing group makes `re.split` keep the tag tokens in the output list.
        splits = re.split(rf"({pattern})", text)

        output = ""
        prev_str_in_targets = False
        for split in splits:
            if split in tag_tokens:
                prev_str_in_targets = True
                output = output.rstrip() + split
            else:
                # we don't need to ensure a space before a normal token that is after a tag token. But having it and
                # keeps a standard format is good anyway.
                if prev_str_in_targets and not split.startswith(" "):
                    output += " " + split
                else:
                    output += split
                prev_str_in_targets = False

        return output
429
+
430
+
431
def coordinate_to_patch_index(bbox: Tuple[float, float, float, float], num_patches_per_side: int) -> Tuple[int, int]:
    """Convert a bounding box to a pair of patch indices.

    Args:
        bbox (`Tuple[float, float, float, float]`):
            The 4 coordinates of the bounding box, with the format being (x1, y1, x2, y2) specifying the upper-left
            and lower-right corners of the box. It should have x2 > x1 and y2 > y1.
        num_patches_per_side (`int`): the number of patches along each side.

    Returns:
        `Tuple[int, int]`: A pair of patch indices (upper-left patch, lower-right patch), each computed as
        `row * num_patches_per_side + column`.

    Raises:
        `ValueError`: If the box is empty or its corners are given in the wrong order.
    """
    (x1, y1, x2, y2) = bbox

    # A box whose lower-right corner is not strictly below/right of its upper-left corner would yield a
    # lower-right patch that precedes the upper-left one.
    if not (x2 > x1 and y2 > y1):
        raise ValueError("The coordinates in `bbox` should be `(x1, y1, x2, y2)` with `x2 > x1` and `y2 > y1`.")

    # The upper-left corner maps to the patch that contains it.
    ul_x = math.floor(x1 * num_patches_per_side)
    ul_y = math.floor(y1 * num_patches_per_side)

    # The lower-right corner maps to the last patch the box reaches into.
    lr_x = math.ceil(x2 * num_patches_per_side - 1)
    lr_y = math.ceil(y2 * num_patches_per_side - 1)

    ul_idx = ul_y * num_patches_per_side + ul_x
    lr_idx = lr_y * num_patches_per_side + lr_x

    return ul_idx, lr_idx
455
+
456
+
457
+ # copied from https://github.com/microsoft/unilm/blob/97e4923e97d3ee10b57e97013556e3fd0d207a9b/kosmos-2/demo/decode_string.py#L35C1-L75C38
458
def patch_index_to_coordinate(ul_idx: int, lr_idx: int, num_patches_per_side: int):
    """
    Map a pair of patch indices on a `num_patches_per_side` x `num_patches_per_side` grid back to the normalized
    coordinates (x1, y1, x2, y2) of the bounding box they describe.

    Args:
        ul_idx (`int`): the index of the grid cell that corresponds to the upper-left corner of the bounding box.
        lr_idx (`int`): the index of the grid cell that corresponds to the lower-right corner of the bounding box.
        num_patches_per_side (`int`): the number of patches along each side.

    Returns:
        `Tuple[float]`: the normalized coordinates of the bounding box, in the form (x1, y1, x2, y2).
    """
    # Side length of one grid cell in normalized units
    cell_size = 1.0 / num_patches_per_side

    # Row (y) and column (x) of each corner cell
    ul_y, ul_x = divmod(ul_idx, num_patches_per_side)
    lr_y, lr_x = divmod(lr_idx, num_patches_per_side)

    # Corners in distinct rows and distinct columns: use the centers of the two cells.
    if ul_idx != lr_idx and ul_x != lr_x and ul_y != lr_y:
        half_cell = cell_size / 2
        return (
            ul_x * cell_size + half_cell,
            ul_y * cell_size + half_cell,
            lr_x * cell_size + half_cell,
            lr_y * cell_size + half_cell,
        )

    # Same cell, same row or same column: span from the outer edge of the first cell to that of the last.
    return (
        ul_x * cell_size,
        ul_y * cell_size,
        lr_x * cell_size + cell_size,
        lr_y * cell_size + cell_size,
    )
tokenization_kosmos2.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Tokenization classes for KOSMOS-2 model."""
16
+
17
+
18
+ import os
19
+ from shutil import copyfile
20
+ from typing import Any, Dict, List, Optional, Tuple
21
+
22
+ import sentencepiece as spm
23
+
24
+ from ...tokenization_utils import AddedToken, PreTrainedTokenizer
25
+ from ...utils import logging
26
+
27
+
28
# Module-level logger.
logger = logging.get_logger(__name__)

# NOTE(review): presumably the SentencePiece word-boundary marker; it is not used in the part of the file seen here.
SPIECE_UNDERLINE = "▁"

# File name of the SentencePiece model for the `vocab_file` argument of the tokenizer.
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}

# Download location of the vocabulary file, per pretrained checkpoint.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "microsoft/kosmos-2-patch14-224": "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/sentencepiece.bpe.model",
    }
}

# Exposed as `max_model_input_sizes` on the tokenizer class, per pretrained checkpoint.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "microsoft/kosmos-2-patch14-224": 2048,
}
43
+
44
+
45
+ class Kosmos2Tokenizer(PreTrainedTokenizer):
46
+ """
47
+ Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on
48
+ [SentencePiece](https://github.com/google/sentencepiece).
49
+
50
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
51
+ this superclass for more information regarding those methods.
52
+
53
+ Args:
54
+ vocab_file (`str`):
55
+ Path to the vocabulary file.
56
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
57
+ The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
58
+
59
+ <Tip>
60
+
61
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
62
+ sequence. The token used is the `cls_token`.
63
+
64
+ </Tip>
65
+
66
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
67
+ The end of sequence token.
68
+
69
+ <Tip>
70
+
71
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
72
+ The token used is the `sep_token`.
73
+
74
+ </Tip>
75
+
76
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
77
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
78
+ sequence classification or for a text and a question for question answering. It is also used as the last
79
+ token of a sequence built with special tokens.
80
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
81
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
82
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
83
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
84
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
85
+ token instead.
86
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
87
+ The token used for padding, for example when batching sequences of different lengths.
88
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
89
+ The token used for masking values. This is the token used when training this model with masked language
90
+ modeling. This is the token which the model will try to predict.
91
+ additional_special_tokens (`List[str]`, *optional*, defaults to `["<s>NOTUSED", "</s>NOTUSED"]`):
92
+ Additional special tokens used by the tokenizer.
93
+ num_patch_index_tokens (`int`, *optional*, defaults to `1024`):
94
+ The number of tokens used to specify the patch indices of bounding boxes in an image. These tokens have the
95
+ format `<patch_index_xxxx>` where `xxxx` is an integer.
96
+ sp_model_kwargs (`dict`, *optional*):
97
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
98
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
99
+ to set:
100
+
101
+ - `enable_sampling`: Enable subword regularization.
102
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
103
+
104
+ - `nbest_size = {0,1}`: No sampling is performed.
105
+ - `nbest_size > 1`: samples from the nbest_size results.
106
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
107
+ using forward-filtering-and-backward-sampling algorithm.
108
+
109
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
110
+ BPE-dropout.
111
+
112
+ Attributes:
113
+ sp_model (`SentencePieceProcessor`):
114
+ The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
115
+ """
116
+
117
+ vocab_files_names = VOCAB_FILES_NAMES
118
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
119
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
120
+ model_input_names = ["input_ids", "attention_mask"]
121
+
122
def __init__(
    self,
    vocab_file,
    bos_token="<s>",
    eos_token="</s>",
    sep_token="</s>",
    cls_token="<s>",
    unk_token="<unk>",
    pad_token="<pad>",
    mask_token="<mask>",
    num_patch_index_tokens=1024,
    add_tag_and_patch_index_tokens=False,
    sp_model_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> None:
    """Load the SentencePiece model from `vocab_file` and set up the fairseq-aligned vocabulary and tag tokens.

    Args:
        vocab_file (`str`): Path to the SentencePiece model file.
        bos_token, eos_token, sep_token, cls_token, unk_token, pad_token, mask_token:
            The usual special tokens, forwarded to the parent class.
        num_patch_index_tokens (`int`, *optional*, defaults to 1024):
            The number of `<patch_index_xxxx>` tokens.
        add_tag_and_patch_index_tokens (`bool`, *optional*, defaults to `False`):
            Whether to add the tag tokens (`<image>`, `<phrase>`, ...) and the patch index tokens to the tokenizer
            through `add_tokens`.
        sp_model_kwargs (`dict`, *optional*): Passed to `SentencePieceProcessor.__init__()`.
    """
    # Mask token behave like a normal word, i.e. include the space before it
    mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

    self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

    super().__init__(
        bos_token=bos_token,
        eos_token=eos_token,
        unk_token=unk_token,
        sep_token=sep_token,
        cls_token=cls_token,
        pad_token=pad_token,
        mask_token=mask_token,
        sp_model_kwargs=self.sp_model_kwargs,
        **kwargs,
    )

    self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
    self.sp_model.Load(str(vocab_file))
    self.vocab_file = vocab_file

    # Original fairseq vocab and spm vocab must be "aligned":
    # Vocab    |    0    |    1    |   2    |    3    |  4   |   5    |   6    |   7    |   8     |   9
    # -------- | ------- | ------- | ------ | ------- | ---- | ------ | ------ | ------ | ------- | ------
    # fairseq  | '<s>'   | '<pad>' | '</s>' | '<unk>' | '.'  | '_the' | ','    | '▁to'  | '▁and'  | '▁of'
    # spm      | '<unk>' | '<s>'   | '</s>' | '.'     | '_the' | ','  | '▁to'  | '▁and' | '▁of'   | '▁a'

    # Mimic fairseq token-to-id alignment for the first 4 token
    self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}

    # The first "real" token "." has position 4 in the original fairseq vocab and position 3 in the spm vocab
    self.fairseq_offset = 1

    # `<mask>` gets the first id after the (shifted) SentencePiece vocabulary.
    self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + self.fairseq_offset
    self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}

    self.eod_token = "</doc>"

    self.boi_token = "<image>"
    self.eoi_token = "</image>"

    self.eoc_token = "</chunk>"
    self.eol_token = "</line>"

    self.bop_token = "<phrase>"
    self.eop_token = "</phrase>"

    self.boo_token = "<object>"
    self.eoo_token = "</object>"

    self.dom_token = "</delimiter_of_multi_objects/>"

    self.grd_token = "<grounding>"

    self.tag_tokens = [
        self.eod_token,
        self.boi_token,
        self.eoi_token,
        self.eoc_token,
        self.eol_token,
        self.bop_token,
        self.eop_token,
        self.boo_token,
        self.eoo_token,
        self.dom_token,
        self.grd_token,
    ]

    self.num_patch_index_tokens = num_patch_index_tokens
    patch_index_tokens = [f"<patch_index_{str(x).zfill(4)}>" for x in range(self.num_patch_index_tokens)]

    if add_tag_and_patch_index_tokens:
        for token in self.tag_tokens + patch_index_tokens:
            # we can't add them as special tokens, as the slow tokenizer doesn't save the information of a token
            # being special when it is added through `add_tokens`, but the fast tokenizer is able to do so.
            self.add_tokens(AddedToken(token, lstrip=True, rstrip=False), special_tokens=True)
213
+
214
+ def _decode(
215
+ self,
216
+ token_ids: List[int],
217
+ skip_special_tokens: bool = False,
218
+ clean_up_tokenization_spaces: bool = None,
219
+ spaces_between_special_tokens: bool = True,
220
+ **kwargs,
221
+ ) -> str:
222
+ self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
223
+
224
+ filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
225
+
226
+ # To avoid mixing byte-level and unicode for byte-level BPT
227
+ # we need to build string separately for added tokens and byte-level tokens
228
+ # cf. https://github.com/huggingface/transformers/issues/1133
229
+ sub_texts = []
230
+ current_sub_text = []
231
+ is_first_current_sub_text = True
232
+ for token in filtered_tokens:
233
+ if skip_special_tokens and token in self.all_special_ids:
234
+ continue
235
+ if token in self.added_tokens_encoder:
236
+ if current_sub_text:
237
+ sub_text = self.convert_tokens_to_string(current_sub_text)
238
+ # `convert_tokens_to_string` removes the leading space, which is undesired if we are not at the
239
+ # beginning part of the text. We can't use `spaces_between_special_tokens` to add this space back
240
+ # neither, as it will also add a space before a tag/patch_index token (which is not the case with
241
+ # the fast tokenizer - it doesn't even support `spaces_between_special_tokens`), which is not the
242
+ # ideal output format.
243
+ # The condition `not spaces_between_special_tokens` is to avoid double spaces.
244
+ if not is_first_current_sub_text and not spaces_between_special_tokens:
245
+ sub_text = " " + sub_text
246
+ sub_texts.append(sub_text)
247
+ current_sub_text = []
248
+ is_first_current_sub_text = False
249
+ sub_texts.append(token)
250
+ else:
251
+ current_sub_text.append(token)
252
+ if current_sub_text:
253
+ sub_texts.append(self.convert_tokens_to_string(current_sub_text))
254
+
255
+ if spaces_between_special_tokens:
256
+ text = " ".join(sub_texts)
257
+ else:
258
+ text = "".join(sub_texts)
259
+
260
+ clean_up_tokenization_spaces = (
261
+ clean_up_tokenization_spaces
262
+ if clean_up_tokenization_spaces is not None
263
+ else self.clean_up_tokenization_spaces
264
+ )
265
+ if clean_up_tokenization_spaces:
266
+ clean_text = self.clean_up_tokenization(text)
267
+ return clean_text
268
+ else:
269
+ return text
270
+
271
+ def __getstate__(self):
272
+ state = self.__dict__.copy()
273
+ state["sp_model"] = None
274
+ state["sp_model_proto"] = self.sp_model.serialized_model_proto()
275
+ return state
276
+
277
def __setstate__(self, d):
    """Restore the instance dict from `d` and rebuild the SentencePiece processor from the serialized proto."""
    self.__dict__ = d

    # for backward compatibility: a restored state may not carry `sp_model_kwargs`
    self.sp_model_kwargs = getattr(self, "sp_model_kwargs", {})

    self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
    self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
286
+
287
def build_inputs_with_special_tokens(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    """
    Wrap one sequence, or a pair of sequences, with the `cls` / `sep` special token ids. The layout follows the
    XLM-RoBERTa format:

    - single sequence: `<s> X </s>`
    - pair of sequences: `<s> A </s></s> B </s>`

    Args:
        token_ids_0 (`List[int]`):
            List of IDs to which the special tokens will be added.
        token_ids_1 (`List[int]`, *optional*):
            Optional second list of IDs for sequence pairs.

    Returns:
        `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
    """
    start = [self.cls_token_id]
    end = [self.sep_token_id]

    first_segment = start + token_ids_0 + end
    if token_ids_1 is None:
        return first_segment
    # The second sequence is preceded by an extra separator and closed by one more.
    return first_segment + end + token_ids_1 + end
312
+
313
def get_special_tokens_mask(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
    """
    Build the mask that flags where special tokens sit once they are added around the given sequence(s).

    Args:
        token_ids_0 (`List[int]`):
            List of IDs.
        token_ids_1 (`List[int]`, *optional*):
            Optional second list of IDs for sequence pairs.
        already_has_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not the token list is already formatted with special tokens for the model. When `True`,
            the computation is delegated to the parent class.

    Returns:
        `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
    """
    if already_has_special_tokens:
        return super().get_special_tokens_mask(
            token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
        )

    # One special token on each side of the first sequence.
    mask = [1] + [0] * len(token_ids_0) + [1]
    if token_ids_1 is not None:
        # The second sequence is likewise enclosed by one special token on each side.
        mask = mask + [1] + [0] * len(token_ids_1) + [1]
    return mask
340
+
341
def create_token_type_ids_from_sequences(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    """
    Return the token type ids for a sequence or a sequence pair. Token type ids are not differentiated here, so
    the result is a list of zeros as long as the sequence(s) once the special tokens are counted in.

    Args:
        token_ids_0 (`List[int]`):
            List of IDs.
        token_ids_1 (`List[int]`, *optional*):
            Optional second list of IDs for sequence pairs.

    Returns:
        `List[int]`: List of zeros.

    """
    boundary = [self.sep_token_id]
    start = [self.cls_token_id]

    # Lay the ids out exactly as they are laid out with special tokens, then only keep the length.
    layout = start + token_ids_0 + boundary
    if token_ids_1 is not None:
        layout = layout + boundary + token_ids_1 + boundary
    return [0] * len(layout)
365
+
366
@property
def vocab_size(self):
    """Size of the SentencePiece model plus the fairseq offset, plus one slot for the `<mask>` token."""
    mask_token_slot = 1  # Add the <mask> token
    return len(self.sp_model) + self.fairseq_offset + mask_token_slot
369
+
370
def get_vocab(self):
    """
    Return the token -> id mapping: one entry per id below `vocab_size`, then overlaid with
    `added_tokens_encoder` (whose entries win on key collisions).
    """
    token_to_id = {}
    for index in range(self.vocab_size):
        token_to_id[self.convert_ids_to_tokens(index)] = index
    token_to_id.update(self.added_tokens_encoder)
    return token_to_id
374
+
375
def _tokenize(self, text: str) -> List[str]:
    """Encode `text` with the SentencePiece model, asking for string pieces rather than ids."""
    pieces = self.sp_model.encode(text, out_type=str)
    return pieces
377
+
378
def _convert_token_to_id(self, token):
    """Converts a token (str) in an id using the vocab."""
    fairseq_ids = self.fairseq_tokens_to_ids
    if token not in fairseq_ids:
        piece_id = self.sp_model.PieceToId(token)
        if not piece_id:
            # Need to return unknown token if the SP model returned 0
            return self.unk_token_id
        # SentencePiece ids are shifted by the fairseq offset.
        return piece_id + self.fairseq_offset
    return fairseq_ids[token]
386
+
387
def _convert_id_to_token(self, index):
    """Converts an index (integer) in a token (str) using the vocab."""
    fairseq_tokens = self.fairseq_ids_to_tokens
    if index not in fairseq_tokens:
        # Undo the fairseq offset before asking the SentencePiece model for the piece.
        return self.sp_model.IdToPiece(index - self.fairseq_offset)
    return fairseq_tokens[index]
392
+
393
def convert_tokens_to_string(self, tokens):
    """Converts a sequence of tokens (strings for sub-words) in a single string."""
    joined = "".join(tokens)
    # Every SentencePiece underline marker becomes a plain space; outer whitespace is then stripped.
    with_spaces = joined.replace(SPIECE_UNDERLINE, " ")
    return with_spaces.strip()
397
+
398
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
    """
    Write the SentencePiece vocabulary file into `save_directory`.

    Args:
        save_directory (`str`):
            Existing directory to write into. If it is not a directory, an error is logged and `None` is returned.
        filename_prefix (`str`, *optional*):
            When given, it is prepended to the file name, followed by `-`.

    Returns:
        `Tuple[str]`: One-element tuple holding the path of the vocabulary file.
    """
    if not os.path.isdir(save_directory):
        logger.error(f"Vocabulary path ({save_directory}) should be a directory")
        return

    name_prefix = filename_prefix + "-" if filename_prefix else ""
    out_vocab_file = os.path.join(save_directory, name_prefix + VOCAB_FILES_NAMES["vocab_file"])

    if not os.path.isfile(self.vocab_file):
        # The original file is not on disk: dump the model held in memory instead.
        with open(out_vocab_file, "wb") as handle:
            handle.write(self.sp_model.serialized_model_proto())
    elif os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
        # Copy only when source and destination are different paths.
        copyfile(self.vocab_file, out_vocab_file)

    return (out_vocab_file,)
tokenization_kosmos2_fast.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Tokenization classes for KOSMOS-2 model."""
16
+
17
+
18
+ import os
19
+ from shutil import copyfile
20
+ from typing import List, Optional, Tuple
21
+
22
+ from ...tokenization_utils import AddedToken
23
+ from ...tokenization_utils_fast import PreTrainedTokenizerFast
24
+ from ...utils import is_sentencepiece_available, logging
25
+
26
+
27
if is_sentencepiece_available():
    from .tokenization_kosmos2 import Kosmos2Tokenizer
else:
    # The slow tokenizer needs sentencepiece. The fallback must bind `Kosmos2Tokenizer`, the name read by
    # `slow_tokenizer_class` in the class body below; binding any other name leaves it undefined and makes
    # importing this module fail with a `NameError` when sentencepiece is missing.
    Kosmos2Tokenizer = None


logger = logging.get_logger(__name__)

# File names used when saving/loading: the SentencePiece model and the serialized fast tokenizer.
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}

# Checkpoint name -> URL of its SentencePiece model file.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "microsoft/kosmos-2-patch14-224": "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/sentencepiece.bpe.model",
    }
}

# Checkpoint name -> value exposed by the tokenizer class as `max_model_input_sizes`.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "microsoft/kosmos-2-patch14-224": 2048,
}
46
+
47
+
48
class Kosmos2TokenizerFast(PreTrainedTokenizerFast):
    """
    Construct a "fast" KOSMOS-2 tokenizer (backed by HuggingFace's *tokenizers* library). Adapted from
    [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on
    [BPE](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models).

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`, *optional*):
            Path to the vocabulary file. When it is not provided, [`~Kosmos2TokenizerFast.save_vocabulary`] raises
            a `ValueError`.
        tokenizer_file (`str`, *optional*):
            Forwarded to [`PreTrainedTokenizerFast`] as `tokenizer_file`.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the beginning of
            sequence. The token used is the `cls_token`.

            </Tip>

        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        sep_token (`str`, *optional*, defaults to `"</s>"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        cls_token (`str`, *optional*, defaults to `"<s>"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        mask_token (`str`, *optional*, defaults to `"<mask>"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        num_patch_index_tokens (`int`, *optional*, defaults to `1024`):
            The number of tokens used to specify the patch indices of bounding boxes in an image. These tokens have the
            format `<patch_index_xxxx>` where `xxxx` is an integer.
        add_tag_and_patch_index_tokens (`bool`, *optional*, defaults to `False`):
            Whether to add the tag tokens (`</doc>`, `<image>`, `<phrase>`, `<grounding>`, ...) and the
            `<patch_index_xxxx>` tokens to the tokenizer. They are added as regular (non-special) tokens.
        kwargs:
            Remaining keyword arguments, forwarded unchanged to [`PreTrainedTokenizerFast`].
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask"]
    slow_tokenizer_class = Kosmos2Tokenizer

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        num_patch_index_tokens=1024,
        add_tag_and_patch_index_tokens=False,
        **kwargs,
    ):
        # Mask token behave like a normal word, i.e. include the space before it
        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

        super().__init__(
            vocab_file,
            tokenizer_file=tokenizer_file,
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            cls_token=cls_token,
            unk_token=unk_token,
            pad_token=pad_token,
            mask_token=mask_token,
            **kwargs,
        )

        self.vocab_file = vocab_file
        # Saving the slow-tokenizer vocabulary copies `vocab_file`, so it is only possible when one was given.
        self.can_save_slow_tokenizer = bool(self.vocab_file)

        # Tag tokens, grouped the same way as they are listed in `tag_tokens` below.
        self.eod_token = "</doc>"

        self.boi_token = "<image>"
        self.eoi_token = "</image>"

        self.eoc_token = "</chunk>"
        self.eol_token = "</line>"

        self.bop_token = "<phrase>"
        self.eop_token = "</phrase>"

        self.boo_token = "<object>"
        self.eoo_token = "</object>"

        self.dom_token = "</delimiter_of_multi_objects/>"

        self.grd_token = "<grounding>"

        self.tag_tokens = [
            self.eod_token,
            self.boi_token,
            self.eoi_token,
            self.eoc_token,
            self.eol_token,
            self.bop_token,
            self.eop_token,
            self.boo_token,
            self.eoo_token,
            self.dom_token,
            self.grd_token,
        ]

        self.num_patch_index_tokens = num_patch_index_tokens
        # `<patch_index_0000>`, `<patch_index_0001>`, ...: the index is zero-padded to 4 digits.
        patch_index_tokens = [f"<patch_index_{str(x).zfill(4)}>" for x in range(self.num_patch_index_tokens)]

        if add_tag_and_patch_index_tokens:
            # we need to set `special_tokens=False` to be the same as in the slow tokenizer.
            # All tokens go through a single `add_tokens` call (it accepts a list), in the same order as before.
            self.add_tokens(
                [AddedToken(token, lstrip=True, rstrip=False) for token in self.tag_tokens + patch_index_tokens],
                special_tokens=False,
            )

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences by concatenating them and adding the `cls` and
        `sep` token ids. The resulting layout is:

        - single sequence: `<s> X </s>`
        - pair of sequences: `<s> A </s></s> B </s>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """

        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + sep + token_ids_1 + sep

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create the token type ids for a sequence or a pair of sequences. Token type ids are not differentiated
        here, therefore a list of zeros is returned; its length counts the special tokens as well.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.

        """

        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Copy the vocabulary file used by the slow tokenizer into `save_directory`.

        Args:
            save_directory (`str`):
                Existing directory to copy into. If it is not a directory, an error is logged and `None` is
                returned.
            filename_prefix (`str`, *optional*):
                When given, it is prepended to the file name, followed by `-`.

        Returns:
            `Tuple[str]`: One-element tuple holding the path of the vocabulary file.

        Raises:
            ValueError: If the tokenizer was created without a `vocab_file`.
        """
        if not self.can_save_slow_tokenizer:
            raise ValueError(
                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
                "tokenizer."
            )

        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory.")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        # Nothing to copy when the destination is the file the tokenizer was loaded from.
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)