Someshfengde committed on
Commit
a139ac6
1 Parent(s): 31f23f1

Upload folder using huggingface_hub

Files changed (1)
  1. script.py +6 -1888
script.py CHANGED
@@ -8,1890 +8,8 @@ import torchvision.transforms as T
8
  from PIL import Image
9
  import torch
10
  from transformers import AutoImageProcessor
11
-
12
  #%%
13
- # coding=utf-8
14
- # Copyright 2024 Meta and The HuggingFace Inc. team. All rights reserved.
15
- #
16
- # Licensed under the Apache License, Version 2.0 (the "License");
17
- # you may not use this file except in compliance with the License.
18
- # You may obtain a copy of the License at
19
- #
20
- # http://www.apache.org/licenses/LICENSE-2.0
21
- #
22
- # Unless required by applicable law or agreed to in writing, software
23
- # distributed under the License is distributed on an "AS IS" BASIS,
24
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25
- # See the License for the specific language governing permissions and
26
- # limitations under the License.
27
- """ PyTorch Hiera model."""
28
-
29
-
30
- import math
31
- from dataclasses import dataclass
32
- from typing import Dict, List, Optional, Tuple, Union
33
-
34
- import torch
35
- import torch.utils.checkpoint
36
- from torch import nn
37
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
38
-
39
- import transformers
40
-
41
- from transformers.activations import ACT2FN
42
- from transformers.modeling_outputs import (
43
- BackboneOutput,
44
- BaseModelOutput,
45
- BaseModelOutputWithPooling,
46
- ImageClassifierOutput,
47
- ModelOutput,
48
- )
49
- from transformers.modeling_utils import PreTrainedModel
50
- from transformers.utils import (
51
- add_code_sample_docstrings,
52
- add_start_docstrings,
53
- add_start_docstrings_to_model_forward,
54
- logging,
55
- replace_return_docstrings,
56
- )
57
- from transformers.utils.backbone_utils import BackboneMixin
58
- # coding=utf-8
59
- # Copyright 2024 Meta and The HuggingFace Inc. team. All rights reserved.
60
- #
61
- # Licensed under the Apache License, Version 2.0 (the "License");
62
- # you may not use this file except in compliance with the License.
63
- # You may obtain a copy of the License at
64
- #
65
- # http://www.apache.org/licenses/LICENSE-2.0
66
- #
67
- # Unless required by applicable law or agreed to in writing, software
68
- # distributed under the License is distributed on an "AS IS" BASIS,
69
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70
- # See the License for the specific language governing permissions and
71
- # limitations under the License.
72
- """ Hiera model configuration"""
73
-
74
- from collections import OrderedDict
75
- from typing import Mapping
76
-
77
- from packaging import version
78
-
79
- from transformers.configuration_utils import PretrainedConfig
80
- from transformers.onnx import OnnxConfig
81
- from transformers.utils import logging
82
- from transformers.utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
83
-
84
-
85
- logger = logging.get_logger(__name__)
86
-
87
- HIERA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
88
- "EduardoPacheco/hiera-tiny-224": "https://huggingface.co/EduardoPacheco/hiera-tiny-224/resolve/main/config.json",
89
- }
90
-
91
-
92
- class HieraConfig(BackboneConfigMixin, PretrainedConfig):
93
- r"""
94
- This is the configuration class to store the configuration of a [`HieraModel`]. It is used to instantiate a Hiera
95
- model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
96
- defaults will yield a similar configuration to that of the Hiera
97
- [EduardoPacheco/hiera-base-224](https://huggingface.co/EduardoPacheco/hiera-base-224) architecture.
98
-
99
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
100
- documentation from [`PretrainedConfig`] for more information.
101
-
102
-
103
- Args:
104
- embed_dim (`int`, *optional*, defaults to 96):
105
- Dimensionality of patch embedding.
106
- input_size (`list(int)`, *optional*, defaults to `[224, 224]`):
107
- The size (resolution) of input in the format (height, width) for images
108
- and (frames, height, width) for videos.
109
- patch_kernel (`list(int)`, *optional*, defaults to `[7, 7]`):
110
- The size (resolution) of each patch.
111
- patch_stride (`list(int)`, *optional*, defaults to `[4, 4]`):
112
- The stride of the patch.
113
- patch_padding (`list(int)`, *optional*, defaults to `[3, 3]`):
114
- The padding of the patch.
115
- mlp_ratio (`float`, *optional*, defaults to 4.0):
116
- The ratio of mlp hidden dim to embedding dim.
117
- depths (`list(int)`, *optional*, defaults to `[2, 3, 16, 3]`):
118
- Depth of each layer in the Transformer encoder.
119
- initial_num_heads (`int`, *optional*, defaults to 1):
120
- Initial number of attention heads in the first layer of the Transformer encoder.
121
- num_head_multiplier (`float`, *optional*, defaults to 2.0):
122
- The multiplier to the number of attention heads in each layer of the Transformer encoder.
123
- embed_dim_multiplier (`float`, *optional*, defaults to 2.0):
124
- The multiplier to the dimensionality of patch embedding in each layer of the Transformer encoder.
125
- num_query_pool (`int`, *optional*, defaults to 3):
126
- The number of query pool stages.
127
- query_stride (`list(int)`, *optional*, defaults to `[2, 2]`):
128
- The stride of the query pool.
129
- masked_unit_size (`list(int)`, *optional*, defaults to `[8, 8]`):
130
- The size of the masked unit.
131
- masked_unit_attention (`list(bool)`, *optional*, defaults to `[True, True, False, False]`):
132
- Whether to use masked unit attention in each layer of the Transformer encoder.
133
- drop_path_rate (`float`, *optional*, defaults to 0.0):
134
- The drop path rate.
135
- sep_pos_embed (`bool`, *optional*, defaults to `False`):
136
- Whether to use separate position embedding for temporal and spatial dimensions. Must be `True` for videos
137
- and `False` for images.
138
- num_channels (`int`, *optional*, defaults to 3):
139
- The number of input channels.
140
- hidden_act (`str`, *optional*, defaults to `"gelu"`):
141
- The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
142
- `"selu"` and `"gelu_new"` are supported.
143
- initializer_range (`float`, *optional*, defaults to 0.02):
144
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices and
145
- the zero_initializer for initializing all bias vectors.
146
- layer_norm_init (`float`, *optional*, defaults to 1.0):
147
- The initial weight value for layer normalization layers.
148
- layer_norm_eps (`float`, *optional*, defaults to 1e-06):
149
- The epsilon used by the layer normalization layers.
150
- decoder_embed_dim (`int`, *optional*):
151
- Dimensionality of decoder embeddings for MAE pretraining.
152
- decoder_depth (`int`, *optional*):
153
- Depth of the decoder for MAE pretraining.
154
- decoder_num_heads (`int`, *optional*):
155
- Number of attention heads in each layer of the decoder for MAE pretraining.
156
- norm_pix_loss (`bool`, *optional*, defaults to `True`):
157
- Whether to normalize the pixel loss by the number of pixels.
158
- mask_ratio (`float`, *optional*, defaults to 0.6):
159
- The ratio of masked tokens in the input.
160
- out_features (`List[str]`, *optional*):
161
- If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
162
- (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
163
- corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
164
- same order as defined in the `stage_names` attribute.
165
- out_indices (`List[int]`, *optional*):
166
- If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
167
- many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
168
- If unset and `out_features` is unset, will default to the last stage. Must be in the
169
- same order as defined in the `stage_names` attribute.
170
-
171
-
172
- Example:
173
-
174
- ```python
175
- >>> from transformers import HieraConfig, HieraModel
176
-
177
- >>> # Initializing a Hiera hiera-base-patch16-224 style configuration
178
- >>> configuration = HieraConfig()
179
-
180
- >>> # Initializing a model (with random weights) from the hiera-base-patch16-224 style configuration
181
- >>> model = HieraModel(configuration)
182
-
183
- >>> # Accessing the model configuration
184
- >>> configuration = model.config
185
- ```"""
186
-
187
- model_type = "hiera"
188
-
189
- attribute_map = {"num_hidden_layers": "num_layers"}
190
-
191
- def __init__(
192
- self,
193
- embed_dim=96,
194
- input_size=[224, 224],
195
- patch_kernel=[7, 7],
196
- patch_stride=[4, 4],
197
- patch_padding=[3, 3],
198
- mlp_ratio=4.0,
199
- depths=[2, 3, 16, 3],
200
- initial_num_heads=1,
201
- num_head_multiplier=2.0,
202
- embed_dim_multiplier=2.0,
203
- num_query_pool=3,
204
- query_stride=[2, 2],
205
- masked_unit_size=[8, 8],
206
- masked_unit_attention=[True, True, False, False],
207
- drop_path_rate=0.0,
208
- sep_pos_embed=False,
209
- num_channels=3,
210
- hidden_act="gelu",
211
- initializer_range=0.02,
212
- layer_norm_init=1.0,
213
- layer_norm_eps=1e-6,
214
- decoder_embed_dim=None,
215
- decoder_depth=None,
216
- decoder_num_heads=None,
217
- norm_pix_loss=True,
218
- mask_ratio=0.6,
219
- out_features=None,
220
- out_indices=None,
221
- **kwargs,
222
- ):
223
- super().__init__(**kwargs)
224
- if masked_unit_size[0] % query_stride[0] ** (len(depths) - 1) != 0:
225
- raise ValueError(
226
- f"masked_unit_size[0] ({masked_unit_size[0]}) must be divisible by query_stride[0] ({query_stride[0]}) "
227
- f"raised to the power of the number of layers ({len(depths) - 1})"
228
- )
229
-
230
- if num_query_pool >= len(depths):
231
- raise ValueError(
232
- f"num_query_pool ({num_query_pool}) must be less than the number of layers ({len(depths)})"
233
- )
234
-
235
- self.embed_dim = embed_dim
236
- self.input_size = input_size
237
- self.patch_kernel = patch_kernel
238
- self.patch_stride = patch_stride
239
- self.patch_padding = patch_padding
240
- self.mlp_ratio = mlp_ratio
241
- self.depths = depths
242
- self.num_layers = len(depths)
243
- self.initial_num_heads = initial_num_heads
244
- self.num_head_multiplier = num_head_multiplier
245
- self.embed_dim_multiplier = embed_dim_multiplier
246
- self.num_query_pool = num_query_pool
247
- self.query_stride = query_stride
248
- self.masked_unit_size = masked_unit_size
249
- self.masked_unit_attention = masked_unit_attention
250
- self.drop_path_rate = drop_path_rate
251
- self.sep_pos_embed = sep_pos_embed
252
- self.num_channels = num_channels
253
- self.hidden_act = hidden_act
254
- self.initializer_range = initializer_range
255
- self.layer_norm_init = layer_norm_init
256
- self.layer_norm_eps = layer_norm_eps
257
- self.decoder_embed_dim = decoder_embed_dim
258
- self.decoder_depth = decoder_depth
259
- self.decoder_num_heads = decoder_num_heads
260
- self.norm_pix_loss = norm_pix_loss
261
- self.mask_ratio = mask_ratio
262
- # we set the hidden_size attribute in order to make Hiera work with VisionEncoderDecoderModel
263
- # this indicates the channel dimension after the last stage of the model
264
- self.hidden_size = int(embed_dim * embed_dim_multiplier ** (len(depths) - 1))
265
- self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
266
- self._out_features, self._out_indices = get_aligned_output_features_output_indices(
267
- out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
268
- )
269
-
270
-
271
- class HieraOnnxConfig(OnnxConfig):
272
- torch_onnx_minimum_version = version.parse("1.11")
273
-
274
- @property
275
- def inputs(self) -> Mapping[str, Mapping[int, str]]:
276
- return OrderedDict(
277
- [
278
- ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
279
- ]
280
- )
281
-
282
- @property
283
- def atol_for_validation(self) -> float:
284
- return 1e-4
285
-
286
- logger = logging.get_logger(__name__)
287
-
288
- # General docstring
289
- _CONFIG_FOR_DOC = "HieraConfig"
290
-
291
- # Base docstring
292
- _CHECKPOINT_FOR_DOC = "EduardoPacheco/hiera-tiny-224"
293
- _EXPECTED_OUTPUT_SHAPE = [1, 49, 768]
294
-
295
- # Image classification docstring
296
- _IMAGE_CLASS_CHECKPOINT = "EduardoPacheco/hiera-tiny-224-in1k"
297
- _IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
298
-
299
-
300
- HIERA_PRETRAINED_MODEL_ARCHIVE_LIST = [
301
- "EduardoPacheco/hiera-tiny-224",
302
- # See all Hiera models at https://huggingface.co/models?filter=hiera
303
- ]
304
-
305
-
306
- @dataclass
307
- class HieraEncoderOutput(ModelOutput):
308
- """
309
- Hiera encoder's outputs, with potential hidden states and attentions.
310
-
311
- Args:
312
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
313
- Sequence of hidden-states at the output of the last layer of the model.
314
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
315
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
316
- shape `(batch_size, sequence_length, hidden_size)`. These are the unrolled hidden states of the model.
317
-
318
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
319
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
320
- Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
321
- sequence_length)`.
322
-
323
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
324
- heads.
325
- reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
326
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
327
- shape `(batch_size, height, width, hidden_size)`. These are the reshaped and re-rolled hidden states of the model.
328
-
329
- Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
330
- include the spatial dimensions.
331
- """
332
-
333
- last_hidden_state: torch.FloatTensor = None
334
- hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
335
- attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
336
- reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
337
-
338
-
339
- @dataclass
340
- class HieraModelOutput(ModelOutput):
341
- """
342
- Hiera model's outputs that also contains a pooling of the last hidden states.
343
-
344
- Args:
345
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
346
- Sequence of hidden-states at the output of the last layer of the model.
347
- pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
348
- Average pooling of the last layer hidden-state.
349
- mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
350
- Tensor indicating which patches are masked (0) and which are not (1).
351
- ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
352
- Tensor containing the original index of the (shuffled) masked patches.
353
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
354
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
355
- shape `(batch_size, sequence_length, hidden_size)`. These are the unrolled hidden states of the model.
356
-
357
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
358
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
359
- Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
360
- sequence_length)`.
361
-
362
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
363
- heads.
364
- reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
365
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
366
- shape `(batch_size, height, width, hidden_size)`. These are the reshaped and re-rolled hidden states of the model.
367
-
368
- Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
369
- include the spatial dimensions.
370
- """
371
-
372
- last_hidden_state: torch.FloatTensor = None
373
- pooler_output: Optional[torch.FloatTensor] = None
374
- mask: torch.LongTensor = None
375
- ids_restore: torch.LongTensor = None
376
- hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
377
- attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
378
- reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
379
-
380
-
381
- @dataclass
382
- class HieraForImageClassificationOutput(ImageClassifierOutput):
383
- """
384
- Hiera image classification outputs.
385
-
386
- Args:
387
- loss (`torch.FloatTensor` of shape `(1,)`, `optional`):
388
- Classification loss.
389
- logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
390
- Prediction scores of the classification head (logits of the output layer).
391
- hidden_states (`tuple(torch.FloatTensor)`, `optional`):
392
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
393
- shape `(batch_size, sequence_length, hidden_size)`. These are the unrolled hidden states of the model.
394
-
395
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
396
- attentions (`tuple(torch.FloatTensor)`, `optional`):
397
- Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
398
- sequence_length)`.
399
-
400
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
401
- heads.
402
- reshaped_hidden_states (`tuple(torch.FloatTensor)`, `optional`):
403
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
404
- shape `(batch_size, height, width, hidden_size)`. These are the reshaped and re-rolled hidden states of the model.
405
-
406
- Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
407
- include the spatial dimensions.
408
- """
409
-
410
- loss: Optional[torch.FloatTensor] = None
411
- logits: torch.FloatTensor = None
412
- hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
413
- attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
414
- reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
415
-
416
-
417
- @dataclass
418
- class HieraForPreTrainingOutput(ModelOutput):
419
- """
420
- Class for HieraForPreTraining's outputs, with potential hidden states and attentions.
421
-
422
- Args:
423
- loss (`torch.FloatTensor` of shape `(1,)`):
424
- Pixel reconstruction loss.
425
- logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
426
- Pixel reconstruction logits.
427
- mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
428
- Tensor indicating which patches are masked (0) and which are not (1).
429
- ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
430
- Tensor containing the original index of the (shuffled) masked patches.
431
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
432
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
433
- shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
434
- plus the initial embedding outputs.
435
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
436
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
437
- sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
438
- the self-attention heads.
439
- reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
440
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
441
- shape `(batch_size, height, width, hidden_size)`. Hidden-states of the model at the output of each layer
442
- plus the initial embedding outputs reshaped to include the spatial dimensions.
443
- """
444
-
445
- loss: Optional[torch.FloatTensor] = None
446
- logits: torch.FloatTensor = None
447
- mask: torch.LongTensor = None
448
- ids_restore: torch.LongTensor = None
449
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
450
- attentions: Optional[Tuple[torch.FloatTensor]] = None
451
- reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
452
-
453
-
454
- # Taken from https://github.com/facebookresearch/hiera/blob/main/hiera/hiera_utils.py#L73
455
- def conv_nd(n: int) -> nn.Module:
456
- """
457
- Returns a conv with nd (e.g., Conv2d for n=2). Works up to n=3.
458
- If you wanted a 4d Hiera, you could probably just implement this for n=4. (no promises)
459
- """
460
- return [nn.Identity, nn.Conv1d, nn.Conv2d, nn.Conv3d][n]
461
-
462
-
463
- # Taken from https://github.com/facebookresearch/hiera/blob/main/hiera/hiera_utils.py#L81
464
- def do_pool(x: torch.Tensor, stride: int) -> torch.Tensor:
465
- # Refer to `Unroll` to see how this performs a maxpool-Nd
466
- return x.view(x.shape[0], stride, -1, x.shape[-1]).max(dim=1).values
467
-
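# Illustrative sketch (toy sizes, nothing here comes from the original file): `do_pool`
# relies on the token order produced by `unroll`, where the offsets within each pooling
# window are the slowest-varying index, so a plain view + max equals a 2x2 max pool.
import torch

batch, height, width, hidden = 1, 4, 4, 3
feature_map = torch.randn(batch, height, width, hidden)

# Re-create the order `unroll` produces for a single (2, 2) stride: window offsets first.
unrolled = feature_map.view(batch, height // 2, 2, width // 2, 2, hidden)
unrolled = unrolled.permute(0, 2, 4, 1, 3, 5).reshape(batch, -1, hidden)

pooled = unrolled.view(batch, 4, -1, hidden).max(dim=1).values      # what `do_pool` computes
reference = torch.nn.functional.max_pool2d(feature_map.permute(0, 3, 1, 2), kernel_size=2)
reference = reference.permute(0, 2, 3, 1).reshape(batch, -1, hidden)
assert torch.allclose(pooled, reference)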
468
-
469
- class HieraPatchEmbeddings(nn.Module):
470
- """
471
- This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
472
- `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
473
- Transformer.
474
- """
475
-
476
- def __init__(self, config, is_mae: bool = False):
477
- super().__init__()
478
-
479
- # Support any number of spatial dimensions
480
- self.spatial_dims = len(config.patch_kernel)
481
- if self.spatial_dims not in (2, 3):
482
- raise ValueError(
483
- f"The number of dimensions of the input image should be 2 or 3, but got {self.spatial_dims}."
484
- )
485
- self.num_channels = config.num_channels
486
- self.image_size = config.input_size[-2:]
487
- self.tokens_spatial_shape = [i // s for i, s in zip(config.input_size, config.patch_stride)]
488
- self.mask_spatial_shape = [i // s for i, s in zip(self.tokens_spatial_shape, config.masked_unit_size)]
489
- self.mask_ratio = config.mask_ratio
490
- self.is_mae = is_mae
491
-
492
- self.projection = conv_nd(self.spatial_dims)(
493
- self.num_channels,
494
- config.embed_dim,
495
- kernel_size=config.patch_kernel,
496
- stride=config.patch_stride,
497
- padding=config.patch_padding,
498
- )
499
-
500
- def masked_conv(self, pixel_values: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
501
- """Zero-out the masked regions of the input before conv.
502
- Prevents leakage of masked regions when using overlapping kernels.
503
- """
504
- if mask is None:
505
- return self.projection(pixel_values)
506
-
507
- target_size = pixel_values.shape[2:]
508
- # Reshape mask to (batch_size, 1, mask_unit_height, mask_unit_width)
509
- mask = mask.view(pixel_values.shape[0], 1, *self.mask_spatial_shape)
510
-
511
- if len(mask.shape[2:]) != len(target_size):
512
- raise ValueError(
513
- f"The length of the spatial dimensions of the mask should match the one from input image, but got {len(mask.shape[2:])} and {len(target_size)}."
514
- )
515
-
516
- if mask.shape[2:] != target_size:
517
- mask = nn.functional.interpolate(mask, size=target_size)
518
-
519
- return self.projection(pixel_values * mask.bool())
520
-
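# Hedged sketch of the mask handling in `masked_conv` above (toy numbers, illustrative
# names): a per-mask-unit keep mask is upsampled to pixel resolution (nearest neighbour)
# and applied before the patch convolution, so dropped regions cannot leak through
# overlapping kernels.
import torch

batch, mask_units, image_size = 1, 7, 224
mask = (torch.rand(batch, mask_units * mask_units) > 0.6).float()   # 1 = keep, 0 = drop
mask = mask.view(batch, 1, mask_units, mask_units)
mask = torch.nn.functional.interpolate(mask, size=(image_size, image_size))  # nearest by default
pixel_values = torch.randn(batch, 3, image_size, image_size)
masked_pixels = pixel_values * mask.bool()                           # zeroed where dropped
print(masked_pixels.shape)                                           # torch.Size([1, 3, 224, 224])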
521
- def random_masking(self, pixel_values, noise=None):
522
- """
523
- Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
524
- noise.
525
-
526
- Args:
527
- pixel_values (`torch.LongTensor` of shape `(batch_size, num_channels, height, width)`)
528
- noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*): mainly
529
- used for testing purposes, to control randomness and maintain reproducibility.
530
- """
531
- batch_size = pixel_values.shape[0]
532
- # Tokens selected for masking at mask unit level
533
- num_windows = math.prod(self.mask_spatial_shape)
534
- len_keep = int(num_windows * (1 - self.mask_ratio))
535
-
536
- if noise is None:
537
- noise = torch.rand(batch_size, num_windows, device=pixel_values.device)
538
-
539
- # Sort noise for each sample
540
- ids_shuffle = torch.argsort(noise, dim=1)
541
- # ascend: small is keep, large is remove
542
- ids_restore = torch.argsort(ids_shuffle, dim=1)
543
-
544
- # Generate the binary mask: 1 is *keep*, 0 is *remove*
545
- # Note this is opposite to original MAE
546
- mask = torch.zeros([batch_size, num_windows], device=pixel_values.device)
547
- mask[:, :len_keep] = 1
548
- # Unshuffle to get the binary mask
549
- mask = torch.gather(mask, dim=1, index=ids_restore)
550
-
551
- return mask, ids_restore
552
-
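# Hedged, standalone sketch of the argsort trick used in `random_masking` above
# (toy sizes; variable names mirror the method but nothing here touches the model).
import torch

batch_size, num_windows, mask_ratio = 2, 8, 0.6
len_keep = int(num_windows * (1 - mask_ratio))          # 3 mask units survive per sample

noise = torch.rand(batch_size, num_windows)
ids_shuffle = torch.argsort(noise, dim=1)               # low noise ends up in the kept slots
ids_restore = torch.argsort(ids_shuffle, dim=1)         # inverse permutation

mask = torch.zeros(batch_size, num_windows)
mask[:, :len_keep] = 1                                   # 1 = keep, 0 = drop (opposite of MAE)
mask = torch.gather(mask, dim=1, index=ids_restore)      # back in original window order
print(mask.sum(dim=1))                                   # tensor([3., 3.])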
553
- def forward(
554
- self,
555
- pixel_values: torch.Tensor,
556
- noise: Optional[torch.FloatTensor] = None,
557
- interpolate_pos_encoding: bool = False,
558
- ) -> torch.Tensor:
559
- num_channels = pixel_values.shape[1]
560
- height, width = pixel_values.shape[-2:]
561
-
562
- if num_channels != self.num_channels:
563
- raise ValueError(
564
- "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
565
- f" Expected {self.num_channels} but got {num_channels}."
566
- )
567
-
568
- if not interpolate_pos_encoding:
569
- if height != self.image_size[0] or width != self.image_size[1]:
570
- raise ValueError(
571
- f"Input image size ({height}*{width}) doesn't match model"
572
- f" ({self.image_size[0]}*{self.image_size[1]})."
573
- )
574
-
575
- (mask, ids_restore) = self.random_masking(pixel_values, noise=noise) if self.is_mae else (None, None)
576
-
577
- embeddings = self.masked_conv(pixel_values, mask)
578
- embeddings = embeddings.flatten(2).transpose(2, 1)
579
-
580
- return embeddings, mask, ids_restore
581
-
582
-
583
- class HieraEmbeddings(nn.Module):
584
- """
585
- Construct position and patch embeddings.
586
- """
587
-
588
- def __init__(self, config: HieraConfig, is_mae: bool = False) -> None:
589
- super().__init__()
590
- self.patch_stride = config.patch_stride
591
- self.tokens_spatial_shape = [i // s for i, s in zip(config.input_size, config.patch_stride)]
592
- self.mask_spatial_shape = [i // s for i, s in zip(self.tokens_spatial_shape, config.masked_unit_size)]
593
- self.num_tokens = math.prod(self.tokens_spatial_shape)
594
- self.sep_pos_embed = config.sep_pos_embed
595
- self.is_mae = is_mae
596
-
597
- self.patch_embeddings = HieraPatchEmbeddings(config, is_mae=is_mae)
598
-
599
- if self.sep_pos_embed:
600
- self.position_embeddings_spatial = nn.Parameter(
601
- torch.zeros(
602
- 1,
603
- self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2],
604
- config.embed_dim,
605
- )
606
- )
607
- self.position_embeddings_temporal = nn.Parameter(
608
- torch.zeros(1, self.tokens_spatial_shape[0], config.embed_dim)
609
- )
610
- else:
611
- self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_tokens, config.embed_dim))
612
-
613
- def interpolate_pos_encoding(
614
- self, embeddings: torch.Tensor, pos_embeds: torch.Tensor, height: int, width: int
615
- ) -> torch.Tensor:
616
- """
617
- This method allows interpolating the pre-trained position encodings so that the model can be used on higher
618
- resolution images.
619
-
620
- Adapted from:
621
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
622
- """
623
-
624
- num_patches = embeddings.shape[1]
625
- num_positions = pos_embeds.shape[1]
626
- if num_patches == num_positions and height == width:
627
- return pos_embeds
628
- dim = embeddings.shape[-1]
629
- h0 = height // self.patch_stride[0] if not self.sep_pos_embed else height // self.patch_stride[1]
630
- w0 = width // self.patch_stride[1] if not self.sep_pos_embed else width // self.patch_stride[2]
631
- # we add a small number to avoid floating point error in the interpolation
632
- # see discussion at https://github.com/facebookresearch/dino/issues/8
633
- h0, w0 = h0 + 0.1, w0 + 0.1
634
- pos_embeds = pos_embeds.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
635
- pos_embeds = pos_embeds.permute(0, 3, 1, 2)
636
- pos_embeds = nn.functional.interpolate(
637
- pos_embeds,
638
- scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
639
- mode="bicubic",
640
- align_corners=False,
641
- )
642
- if int(h0) != pos_embeds.shape[-2] or int(w0) != pos_embeds.shape[-1]:
643
- raise ValueError("The interpolated position encoding does not have the right size")
644
- pos_embeds = pos_embeds.permute(0, 2, 3, 1).view(1, -1, dim)
645
- return pos_embeds
646
-
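# Hedged sketch of the interpolation above (toy sizes; uses `size=` instead of the
# original `scale_factor` + epsilon trick, purely for brevity): a position-embedding
# table learned for a square token grid is resized for a larger input by treating it
# as a 2D map and interpolating bicubically.
import math
import torch

dim, grid = 8, 56                                   # e.g. 224x224 input with patch stride 4
pos_embeds = torch.randn(1, grid * grid, dim)
new_height, new_width = 64, 64                      # e.g. 256x256 input with patch stride 4

grid_size = int(math.sqrt(pos_embeds.shape[1]))
resized = pos_embeds.reshape(1, grid_size, grid_size, dim).permute(0, 3, 1, 2)
resized = torch.nn.functional.interpolate(
    resized, size=(new_height, new_width), mode="bicubic", align_corners=False
)
resized = resized.permute(0, 2, 3, 1).reshape(1, -1, dim)
print(resized.shape)                                # torch.Size([1, 4096, 8])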
647
- def get_position_embedding(
648
- self, embeddings: torch.Tensor, height: int, width: int, interpolate_pos_encoding: bool
649
- ) -> torch.Tensor:
650
- if self.sep_pos_embed:
651
- spatial = self.position_embeddings_spatial
652
- spatial = (
653
- self.interpolate_pos_encoding(embeddings, spatial, height, width)
654
- if interpolate_pos_encoding
655
- else spatial
656
- )
657
- spatial = spatial.repeat(1, self.tokens_spatial_shape[0], 1)
658
-
659
- temporal = torch.repeat_interleave(
660
- self.position_embeddings_temporal,
661
- self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2],
662
- dim=1,
663
- )
664
-
665
- return spatial + temporal
666
- else:
667
- position_embeddings = self.position_embeddings
668
- position_embeddings = (
669
- self.interpolate_pos_encoding(embeddings, position_embeddings, height, width)
670
- if interpolate_pos_encoding
671
- else position_embeddings
672
- )
673
- return position_embeddings
674
-
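# Hedged sketch of how the separate spatial and temporal tables above are combined for
# video inputs when `sep_pos_embed=True` (toy sizes, illustrative names).
import torch

frames, spatial_positions, dim = 4, 49, 8
pos_spatial = torch.randn(1, spatial_positions, dim)
pos_temporal = torch.randn(1, frames, dim)

spatial = pos_spatial.repeat(1, frames, 1)                                  # tile over frames
temporal = torch.repeat_interleave(pos_temporal, spatial_positions, dim=1)  # repeat each frame
combined = spatial + temporal                                               # one embedding per token
print(combined.shape)                                                       # torch.Size([1, 196, 8])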
675
- def forward(
676
- self,
677
- pixel_values: torch.Tensor,
678
- noise: Optional[torch.FloatTensor] = None,
679
- interpolate_pos_encoding: bool = False,
680
- ) -> torch.Tensor:
681
- if len(self.tokens_spatial_shape) == 2:
682
- batch_size, num_channels, height, width = pixel_values.shape
683
- else:
684
- batch_size, num_channels, depth, height, width = pixel_values.shape
685
-
686
- embeddings, mask, ids_restore = self.patch_embeddings(
687
- pixel_values, noise=noise, interpolate_pos_encoding=interpolate_pos_encoding
688
- )
689
-
690
- embeddings = embeddings + self.get_position_embedding(embeddings, height, width, interpolate_pos_encoding)
691
-
692
- return embeddings, mask, ids_restore
693
-
694
-
695
- class HieraMaskUnitAttention(nn.Module):
696
- """
697
- Computes either Mask Unit or Global Attention. It can also perform query (q) pooling.
698
-
699
- Note: this assumes the tokens have already been flattened and unrolled into mask units.
700
- """
701
-
702
- def __init__(
703
- self,
704
- dim: int,
705
- dim_out: int,
706
- num_heads: int,
707
- query_stride: int = 1,
708
- window_size: int = 0,
709
- use_mask_unit_attn: bool = False,
710
- ):
711
- super().__init__()
712
-
713
- self.dim = dim
714
- self.dim_out = dim_out
715
- self.num_heads = num_heads
716
- self.query_stride = query_stride
717
-
718
- self.head_dim = dim_out // num_heads
719
- self.scale = (self.head_dim) ** -0.5
720
-
721
- self.qkv = nn.Linear(dim, 3 * dim_out)
722
- self.proj = nn.Linear(dim_out, dim_out)
723
-
724
- self.window_size = window_size
725
- self.use_mask_unit_attn = use_mask_unit_attn
726
-
727
- def forward(
728
- self,
729
- hidden_states: torch.Tensor,
730
- head_mask: Optional[torch.FloatTensor] = None,
731
- output_attentions: bool = False,
732
- ) -> torch.Tensor:
733
- """Input should be of shape [batch, tokens, channels]."""
734
- batch_size, seq_len, _ = hidden_states.shape
735
-
736
- num_windows = 1
737
- if self.use_mask_unit_attn:
738
- num_windows = seq_len // (self.query_stride * self.window_size)
739
-
740
- qkv = self.qkv(hidden_states)
741
- qkv = qkv.reshape(batch_size, -1, num_windows, 3, self.num_heads, self.head_dim)
742
- qkv = qkv.permute(3, 0, 4, 2, 1, 5)
743
-
744
- query, key, value = qkv.unbind(0)
745
-
746
- if self.query_stride > 1:
747
- # Refer to Unroll to see how this performs a maxpool-Nd
748
- query = query.view(batch_size, self.num_heads, num_windows, self.query_stride, -1, self.head_dim)
749
- query = query.max(dim=3).values
750
-
751
- attn_weights = (query * self.scale) @ key.transpose(-1, -2)
752
- attn_weights = attn_weights.softmax(dim=-1)
753
-
754
- # Mask heads if we want to
755
- if head_mask is not None:
756
- attn_weights = attn_weights * head_mask
757
-
758
- attn_output = attn_weights @ value
759
- attn_output = attn_output.transpose(1, 3).reshape(batch_size, -1, self.dim_out)
760
- attn_output = self.proj(attn_output)
761
-
762
- return (attn_output, attn_weights) if output_attentions else (attn_output, None)
763
-
764
-
765
- # Copied from transformers.models.beit.modeling_beit.drop_path
766
- def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
767
- """
768
- Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
769
-
770
- Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
771
- however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
772
- See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
773
- layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
774
- argument.
775
- """
776
- if drop_prob == 0.0 or not training:
777
- return input
778
- keep_prob = 1 - drop_prob
779
- shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
780
- random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
781
- random_tensor.floor_() # binarize
782
- output = input.div(keep_prob) * random_tensor
783
- return output
784
-
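# Hedged sketch of the stochastic-depth behaviour implemented by `drop_path` above
# (toy tensor; one keep/drop decision per sample, rescaled by 1/keep_prob so the
# expected value matches evaluation mode).
import torch

hidden_states = torch.ones(4, 3, 8)                       # (batch, tokens, channels)
drop_prob = 0.5
keep_prob = 1 - drop_prob
shape = (hidden_states.shape[0],) + (1,) * (hidden_states.ndim - 1)
random_tensor = (keep_prob + torch.rand(shape)).floor_()  # 0 or 1 per sample
output = hidden_states.div(keep_prob) * random_tensor
print(output[:, 0, 0])                                    # e.g. tensor([2., 0., 2., 2.])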
785
-
786
- # Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Hiera
787
- class HieraDropPath(nn.Module):
788
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
789
-
790
- def __init__(self, drop_prob: Optional[float] = None) -> None:
791
- super().__init__()
792
- self.drop_prob = drop_prob
793
-
794
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
795
- return drop_path(hidden_states, self.drop_prob, self.training)
796
-
797
- def extra_repr(self) -> str:
798
- return "p={}".format(self.drop_prob)
799
-
800
-
801
- class HieraMlp(nn.Module):
802
- def __init__(self, config, dim: int):
803
- super().__init__()
804
- self.config = config
805
- self.activation_fn = ACT2FN[config.hidden_act]
806
- self.fc1 = nn.Linear(dim, int(dim * config.mlp_ratio))
807
- self.fc2 = nn.Linear(int(dim * config.mlp_ratio), dim)
808
-
809
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
810
- hidden_states = self.fc1(hidden_states)
811
- hidden_states = self.activation_fn(hidden_states)
812
- hidden_states = self.fc2(hidden_states)
813
- return hidden_states
814
-
815
-
816
- class HieraLayer(nn.Module):
817
- def __init__(
818
- self,
819
- config,
820
- dim: int,
821
- dim_out: int,
822
- num_heads: int,
823
- drop_path: float = 0.0,
824
- query_stride: int = 1,
825
- window_size: int = 0,
826
- use_mask_unit_attn: bool = False,
827
- ):
828
- super().__init__()
829
-
830
- self.dim = dim
831
- self.dim_out = dim_out
832
- self.query_stride = query_stride
833
-
834
- self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
835
- self.attn = HieraMaskUnitAttention(dim, dim_out, num_heads, query_stride, window_size, use_mask_unit_attn)
836
-
837
- self.layernorm_after = nn.LayerNorm(dim_out, eps=config.layer_norm_eps)
838
- self.mlp = HieraMlp(config, dim_out)
839
-
840
- self.drop_path = HieraDropPath(drop_path) if drop_path > 0 else nn.Identity()
841
- if dim != dim_out:
842
- self.proj = nn.Linear(dim, dim_out)
843
-
844
- def forward(
845
- self,
846
- hidden_states: torch.Tensor,
847
- head_mask: Optional[torch.FloatTensor] = None,
848
- output_attentions: bool = False,
849
- ) -> torch.Tensor:
850
- batch_size, seq_len, _ = hidden_states.shape
851
- # Attention + Q Pooling
852
- hidden_states_norm = self.layernorm_before(hidden_states)
853
- if self.dim != self.dim_out:
854
- hidden_states = self.proj(hidden_states_norm)
855
- # Refer to `HieraUnroll` to see how this performs a maxpool-Nd
856
- hidden_states = hidden_states.view(batch_size, self.query_stride, -1, self.dim_out).max(dim=1).values
857
-
858
- (hidden_states_norm, attn_weights) = self.attn(
859
- hidden_states_norm, head_mask, output_attentions=output_attentions
860
- )
861
- hidden_states = hidden_states + self.drop_path(hidden_states_norm)
862
-
863
- residual = hidden_states
864
- hidden_states = self.layernorm_after(hidden_states)
865
- hidden_states = self.mlp(hidden_states)
866
- hidden_states = residual + self.drop_path(hidden_states)
867
-
868
- return (hidden_states, attn_weights)
869
-
870
-
871
- class HieraStage(nn.Module):
872
- def __init__(
873
- self,
874
- config,
875
- depth: int,
876
- dim: int,
877
- dim_out: int,
878
- num_heads: int,
879
- drop_path: List[float],
880
- query_stride: List[int],
881
- window_size: int,
882
- use_mask_unit_attn: bool,
883
- stage_num: Optional[int] = None,
884
- ) -> None:
885
- super().__init__()
886
- # we need to know whether the previous stage used mask unit attention
887
- # or global attention. The switch lags by one layer,
888
- # so that global attention is applied after pooling,
889
- # on the lower resolution features.
890
- previous_stage_used_masked_attention = False
891
- if stage_num is not None:
892
- previous_stage_used_masked_attention = config.masked_unit_attention[stage_num - 1 if stage_num > 0 else 0]
893
- self.layers = nn.ModuleList(
894
- [
895
- HieraLayer(
896
- config=config,
897
- dim=dim if i == 0 else dim_out,
898
- dim_out=dim_out,
899
- num_heads=num_heads,
900
- drop_path=drop_path[i],
901
- query_stride=query_stride[i],
902
- window_size=window_size,
903
- use_mask_unit_attn=use_mask_unit_attn or (previous_stage_used_masked_attention and i == 0),
904
- )
905
- for i in range(depth)
906
- ]
907
- )
908
-
909
- def forward(
910
- self, hidden_states: torch.Tensor, head_mask: Optional[torch.FloatTensor], output_attentions: bool = False
911
- ) -> torch.Tensor:
912
- for i, layer_module in enumerate(self.layers):
913
- layer_head_mask = head_mask[i] if head_mask is not None else None
914
- (hidden_states, attn_weights) = layer_module(
915
- hidden_states, layer_head_mask, output_attentions=output_attentions
916
- )
917
-
918
- return hidden_states, attn_weights
919
-
920
-
921
- def undo_windowing(hidden_states: torch.Tensor, shape: List[int], mask_unit_shape: List[int]) -> torch.Tensor:
922
- """
923
- Restore spatial organization by undoing windowed organization of mask units.
924
- """
925
- num_dims = len(shape)
926
- batch_size, hidden_size = hidden_states.shape[0], hidden_states.shape[-1]
927
- # From: [batch_size, num_mask_unit_height*num_mask_unit_width, mask_unit_height, mask_unit_width, hidden_size]
928
- # To: [batch_size, num_mask_unit_height, num_mask_unit_width, mask_unit_height, mask_unit_width, hidden_size]
929
- num_mask_units = [s // mu for s, mu in zip(shape, mask_unit_shape)]
930
- hidden_states = hidden_states.view(batch_size, *num_mask_units, *mask_unit_shape, hidden_size)
931
-
932
- # From: [batch_size, num_mask_unit_height, num_mask_unit_width, mask_unit_height, mask_unit_width, hidden_size]
933
- # To: [batch_size, num_mask_unit_height*mask_unit_height, num_mask_unit_width*mask_unit_width, hidden_size]
934
- permute = (
935
- [0]
936
- + sum(
937
- [list(p) for p in zip(range(1, 1 + num_dims), range(1 + num_dims, 1 + 2 * num_dims))],
938
- [],
939
- )
940
- + [len(hidden_states.shape) - 1]
941
- )
942
- hidden_states = hidden_states.permute(permute).reshape(batch_size, *shape, hidden_size)
943
-
944
- return hidden_states
945
-
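# Hedged sketch of what `undo_windowing` computes in the 2D case (toy sizes, assumed
# names): mask-unit-major tokens (batch, n_mu_h * n_mu_w, mu_h, mu_w, hidden) are
# stitched back into a spatial map (batch, height, width, hidden).
import torch

batch, n_mu_h, n_mu_w, mu_h, mu_w, hidden = 1, 2, 2, 8, 8, 4
windows = torch.randn(batch, n_mu_h * n_mu_w, mu_h, mu_w, hidden)

x = windows.view(batch, n_mu_h, n_mu_w, mu_h, mu_w, hidden)
x = x.permute(0, 1, 3, 2, 4, 5)                           # interleave unit index and offset per axis
spatial = x.reshape(batch, n_mu_h * mu_h, n_mu_w * mu_w, hidden)
print(spatial.shape)                                       # torch.Size([1, 16, 16, 4])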
946
-
947
- class HieraEncoder(nn.Module):
948
- def __init__(self, config: HieraConfig) -> None:
949
- super().__init__()
950
- self.config = config
951
-
952
- # stochastic depth decay rule
953
- dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
954
- # query strides rule
955
- stage_ends = [sum(config.depths[:i]) - 1 for i in range(1, len(config.depths) + 1)]
956
- query_pool_layer = [stage_end + 1 for stage_end in stage_ends[: config.num_query_pool]]
957
- query_strides = [
958
- math.prod(config.query_stride) if i in query_pool_layer else 1 for i in range(sum(config.depths))
959
- ]
960
-
961
- # Transformer blocks
962
- self.stages = nn.ModuleList()
963
- embed_dim = config.embed_dim
964
-
965
- for idx_stage, depth in enumerate(config.depths):
966
- dim_out = int(config.embed_dim * config.embed_dim_multiplier**idx_stage)
967
-
968
- stage = HieraStage(
969
- config=config,
970
- depth=depth,
971
- dim=embed_dim,
972
- dim_out=dim_out,
973
- num_heads=int(config.initial_num_heads * config.num_head_multiplier**idx_stage),
974
- drop_path=dpr[sum(config.depths[:idx_stage]) : sum(config.depths[: idx_stage + 1])],
975
- query_stride=query_strides[sum(config.depths[:idx_stage]) : sum(config.depths[: idx_stage + 1])],
976
- window_size=int(math.prod(config.masked_unit_size) * math.prod(config.query_stride) ** -idx_stage),
977
- use_mask_unit_attn=config.masked_unit_attention[idx_stage],
978
- stage_num=idx_stage,
979
- )
980
-
981
- embed_dim = dim_out
982
- self.stages.append(stage)
983
-
984
- # Setting reroll schedule
985
- # The first stage has to reverse everything
986
- # The next stage has to reverse all but the first unroll, etc.
987
- stage_size = [i // s for i, s in zip(config.input_size, config.patch_stride)]
988
- unroll_schedule = [config.query_stride] * len(config.depths[:-1])
989
-
990
- self.schedule = {}
991
- for idx_stage in range(len(config.depths)):
992
- self.schedule[idx_stage] = unroll_schedule, stage_size
993
- if idx_stage < config.num_query_pool:
994
- stage_size = [i // s for i, s in zip(stage_size, config.query_stride)]
995
- unroll_schedule = unroll_schedule[1:]
996
-
997
- self.gradient_checkpointing = False
998
-
999
- def reroll(
1000
- self, hidden_states: torch.Tensor, stage_idx: int, mask: Optional[torch.BoolTensor] = None
1001
- ) -> torch.Tensor:
1002
- """
1003
- Roll the given tensor back up to spatial order assuming it's from the given block.
1004
-
1005
- If no mask is provided returns:
1006
- - [batch_size, height, width, hidden_size] for 2d
1007
- - [batch_size, frames, height, width, hidden_size] for 3d
1008
- If a mask is provided returns:
1009
- - [batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size] for 2d
1010
- """
1011
- schedule, size = self.schedule[stage_idx]
1012
- batch_size, seq_len, hidden_size = hidden_states.shape
1013
-
1014
- num_dim = len(size)
1015
- mask_unit_shape = [1] * num_dim
1016
-
1017
- for strides in schedule:
1018
- # Extract the current patch from seq_len
1019
- hidden_states = hidden_states.view(
1020
- batch_size, *strides, seq_len // math.prod(strides), *mask_unit_shape, hidden_size
1021
- )
1022
-
1023
- # Move that patch into the current MU
1024
- # Example in 2d:
1025
- # Input: [batch_size, stride, stride, seq_len//(stride*stride), mask_unit_height, mask_unit_width, hidden_size]
1026
- # Output: [batch_size, seq_len//(stride*stride), stride, mask_unit_height, stride, mask_unit_width, hidden_size]
1027
- L = len(hidden_states.shape)
1028
- permute = (
1029
- [0, 1 + num_dim]
1030
- + sum(
1031
- [list(p) for p in zip(range(1, 1 + num_dim), range(1 + num_dim + 1, L - 1))],
1032
- [],
1033
- )
1034
- + [L - 1]
1035
- )
1036
- hidden_states = hidden_states.permute(permute)
1037
-
1038
- # Reshape to [batch_size, seq_len//(stride*stride), *mask_units, hidden_size]
1039
- for i in range(num_dim):
1040
- mask_unit_shape[i] *= strides[i]
1041
- hidden_states = hidden_states.reshape(batch_size, -1, *mask_unit_shape, hidden_size)
1042
- seq_len = hidden_states.shape[1]
1043
-
1044
- # Current shape (e.g., 2d: [batch_size, num_mask_units_height*num_mask_units_width, mask_unit_height, mask_unit_width, hidden_size])
1045
- hidden_states = hidden_states.view(batch_size, seq_len, *mask_unit_shape, hidden_size)
1046
-
1047
- # If masked, return [batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size]
1048
- if mask is not None:
1049
- return hidden_states
1050
-
1051
- # If not masked, we can return [batch_size, height, width, hidden_size]
1052
- hidden_states = undo_windowing(hidden_states, size, mask_unit_shape)
1053
-
1054
- return hidden_states
1055
-
1056
- def forward(
1057
- self,
1058
- hidden_states: torch.Tensor,
1059
- mask: Optional[torch.BoolTensor] = None,
1060
- head_mask: Optional[torch.FloatTensor] = None,
1061
- output_attentions: bool = False,
1062
- output_hidden_states: bool = False,
1063
- return_dict: bool = True,
1064
- ) -> Union[tuple, BaseModelOutput]:
1065
- all_hidden_states = () if output_hidden_states else None
1066
- all_reshaped_hidden_states = () if output_hidden_states else None
1067
- all_self_attentions = () if output_attentions else None
1068
-
1069
- if output_hidden_states:
1070
- all_hidden_states = all_hidden_states + (hidden_states,)
1071
- reshaped_hidden_states = self.reroll(hidden_states, stage_idx=0, mask=mask)
1072
- all_reshaped_hidden_states = all_reshaped_hidden_states + (reshaped_hidden_states,)
1073
-
1074
- for i, stage_module in enumerate(self.stages):
1075
- layer_head_mask = head_mask[i] if head_mask is not None else None
1076
-
1077
- if self.gradient_checkpointing and self.training:
1078
- layer_outputs = self._gradient_checkpointing_func(
1079
- stage_module.__call__, hidden_states, layer_head_mask, output_attentions
1080
- )
1081
- else:
1082
- layer_outputs = stage_module(hidden_states, layer_head_mask, output_attentions)
1083
-
1084
- hidden_states = layer_outputs[0]
1085
-
1086
- if output_attentions:
1087
- all_self_attentions = all_self_attentions + (layer_outputs[1],)
1088
-
1089
- if output_hidden_states:
1090
- all_hidden_states = all_hidden_states + (hidden_states,)
1091
- reshaped_hidden_states = self.reroll(hidden_states, stage_idx=i, mask=mask)
1092
- all_reshaped_hidden_states = all_reshaped_hidden_states + (reshaped_hidden_states,)
1093
-
1094
- if not return_dict:
1095
- return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
1096
- return HieraEncoderOutput(
1097
- last_hidden_state=hidden_states,
1098
- hidden_states=all_hidden_states,
1099
- attentions=all_self_attentions,
1100
- reshaped_hidden_states=all_reshaped_hidden_states,
1101
- )
1102
-
1103
-
1104
- def unroll(hidden_states: torch.Tensor, size: List[int], schedule: List[List[int]]) -> torch.Tensor:
1105
- """
1106
- Reorders the tokens such that patches are contiguous in memory.
1107
- E.g., given [batch_size, (height, width), hidden_size] and stride of (stride, stride), this will re-order the tokens as
1108
- [batch_size, (stride, stride, height // stride, width // stride), hidden_size]
1109
-
1110
- This allows operations like Max2d to be computed as x.view(batch_size, stride*stride, -1, hidden_size).max(dim=1).
1111
- Not only is this faster, but it also makes it easy to support inputs of arbitrary
1112
- dimensions in addition to patch-wise sparsity.
1113
-
1114
- Performing this operation multiple times in sequence puts entire windows as contiguous
1115
- in memory. For instance, if you applied the stride (2, 2) 3 times, entire windows of
1116
- size 8x8 would be contiguous in memory, allowing operations like mask unit attention
1117
- to be computed easily and efficiently, while also allowing max to be applied sequentially.
1118
-
1119
- Note: This means that intermediate values of the model are not in height x width order, so they
1120
- need to be re-rolled if you want to use the intermediate values as a height x width feature map.
1121
- The last block of the network is fine though, since by then the strides are all consumed.
1122
- """
1123
- batch_size, _, hidden_size = hidden_states.shape
1124
-
1125
- current_size = size
1126
- hidden_states = hidden_states.view(*([batch_size] + current_size + [hidden_size]))
1127
-
1128
- for strides in schedule:
1129
- # Move patches with the given strides to the batch dimension
1130
-
1131
- # Create a view of the tensor with the patch stride as separate dims
1132
- # For example in 2d: [batch_size, height // stride, stride, width // stride, stride, C]
1133
- current_size = [i // s for i, s in zip(current_size, strides)]
1134
- # initialize new_shape with [height // stride, stride, width // stride, stride]
1135
- new_shape = [item for pair in zip(current_size, strides) for item in pair]
1136
- # add batch_size and hidden_size to new_shape
1137
- new_shape = [batch_size] + new_shape + [hidden_size]
1138
- hidden_states = hidden_states.view(new_shape)
1139
-
1140
- # Move the patch stride into the batch dimension
1141
- # For example in 2d: [batch_size, stride, stride, height // stride, width // stride, hidden_size]
1142
- num_dims = len(new_shape)
1143
- permute = [0] + list(range(2, num_dims - 1, 2)) + list(range(1, num_dims - 1, 2)) + [num_dims - 1]
1144
- hidden_states = hidden_states.permute(permute)
1145
-
1146
- # Now finally flatten the relevant dims into the batch dimension
1147
- hidden_states = hidden_states.flatten(0, len(strides))
1148
- batch_size *= math.prod(strides)
1149
-
1150
- hidden_states = hidden_states.reshape(-1, math.prod(size), hidden_size)
1151
- return hidden_states
1152
-
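# Hedged illustration of the reordering described in the docstring above, for a single
# (2, 2) stride on a 4x4 token grid (toy example; the tensor values just encode positions).
import torch

tokens = torch.arange(16).view(1, 16, 1)              # value = row-major position in the 4x4 grid

x = tokens.view(1, 2, 2, 2, 2, 1)                     # (batch, h//2, h_offset, w//2, w_offset, hidden)
x = x.permute(0, 2, 4, 1, 3, 5)                       # offsets become the slowest-varying index
x = x.reshape(1, 16, 1)
print(x.flatten().tolist())
# [0, 2, 8, 10, 1, 3, 9, 11, 4, 6, 12, 14, 5, 7, 13, 15]
# The four members of each 2x2 window (e.g. positions 0, 1, 4, 5) now sit a fixed stride
# of 4 apart, so `x.view(1, 4, -1, 1).max(dim=1)` pools exactly those windows.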
1153
-
1154
- class HieraPreTrainedModel(PreTrainedModel):
1155
- """
1156
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
1157
- models.
1158
- """
1159
-
1160
- config_class = HieraConfig
1161
- base_model_prefix = "hiera"
1162
- main_input_name = "pixel_values"
1163
- supports_gradient_checkpointing = True
1164
-
1165
- def _init_weights(self, module) -> None:
1166
- """Initialize the weights"""
1167
- std = self.config.initializer_range
1168
-
1169
- if isinstance(module, HieraEmbeddings):
1170
- if self.config.sep_pos_embed:
1171
- nn.init.trunc_normal_(module.position_embeddings_spatial, std=std)
1172
- nn.init.trunc_normal_(module.position_embeddings_temporal, std=std)
1173
- else:
1174
- nn.init.trunc_normal_(module.position_embeddings, std=std)
1175
-
1176
- elif isinstance(module, HieraDecoder):
1177
- nn.init.trunc_normal_(module.mask_token, std=std)
1178
- nn.init.trunc_normal_(module.decoder_position_embeddings, std=std)
1179
-
1180
- elif isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
1181
- nn.init.trunc_normal_(module.weight, std=std)
1182
- if module.bias is not None:
1183
- nn.init.constant_(module.bias, std)
1184
-
1185
- elif isinstance(module, nn.LayerNorm):
1186
- nn.init.constant_(module.bias, std)
1187
- nn.init.constant_(module.weight, self.config.layer_norm_init)
1188
-
1189
-
1190
- HIERA_START_DOCSTRING = r"""
1191
- This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
1192
- as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
1193
- behavior.
1194
-
1195
- Parameters:
1196
- config ([`HieraConfig`]): Model configuration class with all the parameters of the model.
1197
- Initializing with a config file does not load the weights associated with the model, only the
1198
- configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1199
- """
1200
-
1201
- HIERA_INPUTS_DOCSTRING = r"""
1202
- Args:
1203
- pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
1204
- Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`BitImageProcessor.__call__`]
1205
- for details.
1206
-
1207
- head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
1208
- Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
1209
-
1210
- - 1 indicates the head is **not masked**,
1211
- - 0 indicates the head is **masked**.
1212
-
1213
- output_attentions (`bool`, *optional*):
1214
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1215
- tensors for more detail.
1216
- output_hidden_states (`bool`, *optional*):
1217
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1218
- more detail.
1219
- interpolate_pos_encoding (`bool`, *optional*):
1220
- Whether to interpolate the pre-trained position encodings.
1221
- return_dict (`bool`, *optional*):
1222
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1223
- """
1224
-
1225
-
1226
- class HieraPooler(nn.Module):
1227
- def __init__(self, config: HieraConfig):
1228
- super().__init__()
1229
- num_features = int(config.embed_dim * config.embed_dim_multiplier ** (len(config.depths) - 1))
1230
- self.layernorm = nn.LayerNorm(num_features, eps=config.layer_norm_eps)
1231
- self.pooler = nn.AdaptiveAvgPool1d(1)
1232
-
1233
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1234
- hidden_states = hidden_states.transpose(1, 2)
1235
- pooled_output = self.pooler(hidden_states)
1236
- pooled_output = torch.flatten(pooled_output, 1)
1237
- pooled_output = self.layernorm(pooled_output)
1238
- return pooled_output
1239
-
1240
-
1241
- @add_start_docstrings(
1242
- "The bare Hiera Model transformer outputting raw hidden-states without any specific head on top.",
1243
- HIERA_START_DOCSTRING,
1244
- """
1245
- add_pooling_layer (`bool`, *optional*, defaults to `True`):
1246
- Whether or not to apply pooling layer.
1247
- is_mae (`bool`, *optional*, defaults to `False`):
1248
- Whether or not to run the model on MAE mode.
1249
- """,
1250
- )
1251
- class HieraModel(HieraPreTrainedModel):
1252
- def __init__(self, config: HieraConfig, add_pooling_layer: bool = True, is_mae: bool = False):
1253
- super().__init__(config)
1254
- self.num_features = int(config.embed_dim * config.embed_dim_multiplier ** (len(config.depths) - 1))
1255
-
1256
- self.embeddings = HieraEmbeddings(config, is_mae=is_mae)
1257
- self.encoder = HieraEncoder(config)
1258
-
1259
- self.unroll_size = [i // s for i, s in zip(config.input_size, config.patch_stride)]
1260
- self.unroll_schedule = [config.query_stride] * len(config.depths[:-1])
1261
-
1262
- self.pooler = HieraPooler(config) if add_pooling_layer else None
1263
-
1264
- # Initialize weights and apply final processing
1265
- self.post_init()
1266
-
1267
- def get_input_embeddings(self) -> HieraPatchEmbeddings:
1268
- return self.embeddings.patch_embeddings
1269
-
1270
- def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
1271
- """
1272
- Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1273
- class PreTrainedModel
1274
- """
1275
- for layer, heads in heads_to_prune.items():
1276
- self.encoder.layer[layer].attention.prune_heads(heads)
1277
-
1278
- @add_start_docstrings_to_model_forward(HIERA_INPUTS_DOCSTRING)
1279
- @add_code_sample_docstrings(
1280
- checkpoint=_CHECKPOINT_FOR_DOC,
1281
- output_type=HieraModelOutput,
1282
- config_class=_CONFIG_FOR_DOC,
1283
- modality="vision",
1284
- expected_output=_EXPECTED_OUTPUT_SHAPE,
1285
- )
1286
- def forward(
1287
- self,
1288
- pixel_values: Optional[torch.Tensor] = None,
1289
- noise: Optional[torch.FloatTensor] = None,
1290
- head_mask: Optional[torch.Tensor] = None,
1291
- output_attentions: Optional[bool] = None,
1292
- output_hidden_states: Optional[bool] = None,
1293
- interpolate_pos_encoding: Optional[bool] = None,
1294
- return_dict: Optional[bool] = None,
1295
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
1296
- r"""
1297
- noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*):
1298
- Noise tensor used mainly for testing, to control randomness and ensure reproducibility
1299
- when `is_mae` is set to `True`.
1300
- """
1301
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1302
- output_hidden_states = (
1303
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1304
- )
1305
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1306
-
1307
- if pixel_values is None:
1308
- raise ValueError("You have to specify pixel_values")
1309
-
1310
- # Prepare head mask if needed
1311
- # 1.0 in head_mask indicate we keep the head
1312
- # attention_probs has shape bsz x n_heads x N x N
1313
- # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1314
- # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1315
- head_mask = self.get_head_mask(head_mask, len(self.config.depths))
1316
-
1317
- # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
1318
- expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
1319
- if pixel_values.dtype != expected_dtype:
1320
- pixel_values = pixel_values.to(expected_dtype)
1321
-
1322
- embedding_output, mask, ids_restore = self.embeddings(
1323
- pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, noise=noise
1324
- )
1325
-
1326
- hidden_states = unroll(embedding_output, self.unroll_size, self.unroll_schedule)
1327
-
1328
- # Discard masked tokens if mask is provided
1329
- if mask is not None:
1330
- mask_unit_area = math.prod(self.config.masked_unit_size)
1331
- batch_size, _, hidden_size = hidden_states.shape
1332
- positions = mask.unsqueeze(-1).tile(1, mask_unit_area, hidden_size)
1333
- positions = positions.bool()
1334
- hidden_states = hidden_states[positions]
1335
- hidden_states = hidden_states.view(batch_size, -1, hidden_size)
1336
-
1337
- encoder_outputs = self.encoder(
1338
- hidden_states,
1339
- mask=mask,
1340
- head_mask=head_mask,
1341
- output_attentions=output_attentions,
1342
- output_hidden_states=output_hidden_states,
1343
- return_dict=return_dict,
1344
- )
1345
- sequence_output = encoder_outputs[0]
1346
- pooled_output = None
1347
- if self.pooler is not None:
1348
- pooled_output = self.pooler(sequence_output)
1349
-
1350
- if not return_dict:
1351
- head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
1352
- head_outputs = head_outputs + (mask, ids_restore) if mask is not None else head_outputs
1353
- return head_outputs + encoder_outputs[1:]
1354
-
1355
- return HieraModelOutput(
1356
- last_hidden_state=sequence_output,
1357
- pooler_output=pooled_output,
1358
- mask=mask,
1359
- ids_restore=ids_restore,
1360
- hidden_states=encoder_outputs.hidden_states,
1361
- attentions=encoder_outputs.attentions,
1362
- reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
1363
- )
1364
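In MAE mode, `HieraModel.forward` above drops the tokens of masked units before running the encoder. A toy illustration of that boolean selection with hypothetical sizes (in the real model the token order additionally depends on `unroll`):

```python
import torch

# Hypothetical sizes: 4 mask units, 4 tokens per unit, hidden size 8.
batch_size, num_units, unit_area, hidden_size = 1, 4, 4, 8
hidden_states = torch.randn(batch_size, num_units * unit_area, hidden_size)

# 1 = visible (kept), 0 = masked (dropped), matching the encoder's convention.
mask = torch.tensor([[1, 0, 1, 0]])

positions = mask.unsqueeze(-1).tile(1, unit_area, hidden_size).bool()
visible = hidden_states[positions].view(batch_size, -1, hidden_size)
print(visible.shape)  # torch.Size([1, 8, 8]); 8 = 2 kept units * 4 tokens each
```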
-
1365
-
1366
- class HieraDecoder(nn.Module):
1367
- def __init__(self, config: HieraConfig):
1368
- super().__init__()
1369
- num_features = int(config.embed_dim * config.embed_dim_multiplier ** (len(config.depths) - 1))
1370
- self.tokens_spatial_shape = [i // s for i, s in zip(config.input_size, config.patch_stride)]
1371
- self.tokens_spatial_shape_final = [
1372
- i // s ** (config.num_query_pool) for i, s in zip(self.tokens_spatial_shape, config.query_stride)
1373
- ]
1374
- self.mask_unit_spatial_shape_final = [
1375
- i // s ** (config.num_query_pool) for i, s in zip(config.masked_unit_size, config.query_stride)
1376
- ]
1377
-
1378
- self.decoder_embeddings = nn.Linear(num_features, config.decoder_embed_dim)
1379
-
1380
- self.mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_embed_dim))
1381
-
1382
- self.decoder_position_embeddings = nn.Parameter(
1383
- torch.zeros(1, math.prod(self.tokens_spatial_shape_final), config.decoder_embed_dim)
1384
- )
1385
-
1386
- self.decoder_block = HieraStage(
1387
- config=config,
1388
- dim=config.decoder_embed_dim,
1389
- dim_out=config.decoder_embed_dim,
1390
- num_heads=config.decoder_num_heads,
1391
- depth=config.decoder_depth,
1392
- use_mask_unit_attn=False,
1393
- drop_path=[0.0] * config.decoder_depth,
1394
- query_stride=[1] * config.decoder_depth,
1395
- window_size=0,
1396
- )
1397
-
1398
- self.decoder_norm = nn.LayerNorm(config.decoder_embed_dim, eps=config.layer_norm_eps)
1399
-
1400
- # patch stride of prediction
1401
- self.pred_stride = config.patch_stride[-1] * (config.query_stride[-1] ** config.num_query_pool)
1402
- pred_dim = (self.pred_stride ** len(config.query_stride)) * config.num_channels
1403
-
1404
- self.decoder_pred = nn.Linear(config.decoder_embed_dim, pred_dim)
1405
-
1406
- def forward(
1407
- self,
1408
- encoder_hidden_states: torch.Tensor,
1409
- mask: torch.BoolTensor,
1410
- head_mask: Optional[torch.Tensor] = None,
1411
- output_attentions: bool = False,
1412
- ) -> torch.Tensor:
1413
- # Embed tokens
1414
- hidden_states = self.decoder_embeddings(encoder_hidden_states)
1415
-
1416
- # Combine visible and mask tokens
1417
-
1418
- # hidden_states : [batch_size, num_mask_units_visible, *mask_unit_spatial_shape_final, decoder_embed_dim]
1419
- # mask: [batch_size, num_mask_units]
1420
- decoder_hidden_states = torch.zeros(
1421
- *mask.shape, *hidden_states.shape[2:], device=hidden_states.device, dtype=hidden_states.dtype
1422
- )
1423
- mask_tokens = self.mask_token.view((1,) * (len(mask.shape) + len(hidden_states.shape[2:-1])) + (-1,))
1424
- new_mask_shape = mask.shape + (1,) * len(hidden_states.shape[2:])
1425
- mask = mask.reshape(new_mask_shape)
1426
- expand_shape = (-1,) * 2 + hidden_states.shape[2:]
1427
- mask = mask.expand(expand_shape)
1428
- decoder_hidden_states[mask.bool()] = hidden_states.flatten()
1429
- decoder_hidden_states = (1 - mask) * mask_tokens + mask * decoder_hidden_states
1430
-
1431
- # Get back spatial order
1432
- hidden_states = undo_windowing(
1433
- decoder_hidden_states,
1434
- self.tokens_spatial_shape_final,
1435
- self.mask_unit_spatial_shape_final,
1436
- )
1437
- mask = undo_windowing(
1438
- mask[..., 0:1],
1439
- self.tokens_spatial_shape_final,
1440
- self.mask_unit_spatial_shape_final,
1441
- )
1442
-
1443
- # Flatten
1444
- hidden_states = hidden_states.reshape(hidden_states.shape[0], -1, hidden_states.shape[-1])
1445
- mask = mask.view(hidden_states.shape[0], -1)
1446
-
1447
- # Add pos embed
1448
- hidden_states = hidden_states + self.decoder_position_embeddings
1449
-
1450
- # Apply decoder blocks
1451
- hidden_states, attn_weights = self.decoder_block(
1452
- hidden_states, head_mask=head_mask, output_attentions=output_attentions
1453
- )
1454
- hidden_states = self.decoder_norm(hidden_states)
1455
-
1456
- # Predictor projection
1457
- hidden_states = self.decoder_pred(hidden_states)
1458
-
1459
- return hidden_states, mask
1460
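The decoder above scatters the visible tokens back into a dense grid of mask units and fills the masked positions with a learned mask token. A small sketch of that combination step, with hypothetical sizes and a zero-initialized mask token:

```python
import torch

# Hypothetical sizes: 4 mask units (2 visible), 3 tokens per unit, embed dim 8.
batch_size, num_units, tokens_per_unit, dim = 1, 4, 3, 8
mask = torch.tensor([[1.0, 0.0, 1.0, 0.0]])                  # 1 = visible unit
visible = torch.randn(batch_size, 2, tokens_per_unit, dim)   # encoder output for visible units
mask_token = torch.zeros(1, 1, 1, dim)                       # learned [MASK] embedding

dense = torch.zeros(batch_size, num_units, tokens_per_unit, dim)
unit_mask = mask.view(batch_size, num_units, 1, 1).expand(-1, -1, tokens_per_unit, dim)
dense[unit_mask.bool()] = visible.flatten()                  # place visible tokens
dense = (1 - unit_mask) * mask_token + unit_mask * dense     # fill the rest with the mask token
print(dense.shape)  # torch.Size([1, 4, 3, 8])
```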
-
1461
-
1462
- class HieraMultiScaleHead(nn.Module):
1463
- def __init__(self, config: HieraConfig):
1464
- super().__init__()
1465
- self.mask_unit_spatial_shape_final = [
1466
- i // s ** (config.num_query_pool) for i, s in zip(config.masked_unit_size, config.query_stride)
1467
- ]
1468
- self.stage_dimensions = [
1469
- int(config.embed_dim * config.embed_dim_multiplier**i) for i in range(len(config.depths))
1470
- ]
1471
- current_masked_unit_size = config.masked_unit_size
1472
- self.multi_scale_fusion_heads = nn.ModuleList()
1473
-
1474
- for idx in range(config.num_query_pool):
1475
- kernel = [i // s for i, s in zip(current_masked_unit_size, self.mask_unit_spatial_shape_final)]
1476
- current_masked_unit_size = [i // s for i, s in zip(current_masked_unit_size, config.query_stride)]
1477
- self.multi_scale_fusion_heads.append(
1478
- conv_nd(len(config.query_stride))(
1479
- self.stage_dimensions[idx],
1480
- self.stage_dimensions[-1],
1481
- kernel_size=kernel,
1482
- stride=kernel,
1483
- )
1484
- )
1485
- self.multi_scale_fusion_heads.append(nn.Identity())
1486
-
1487
- def apply_fusion_head(self, head: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
1488
- if isinstance(head, nn.Identity):
1489
- return hidden_states
1490
-
1491
- batch_size, num_mask_units = hidden_states.shape[0:2]
1492
- # From: [batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size]
1493
- # To: head([batch_size * num_mask_units, hidden_size, mask_unit_height, mask_unit_width])
1494
- permute = [0] + [len(hidden_states.shape) - 2] + list(range(1, len(hidden_states.shape) - 2))
1495
- hidden_states = hidden_states.reshape(batch_size * num_mask_units, *hidden_states.shape[2:])
1496
- hidden_states = hidden_states.permute(permute)
1497
- hidden_states = head(hidden_states)
1498
-
1499
- # Restore original layout
1500
- permute = [0] + list(range(2, len(hidden_states.shape))) + [1]
1501
- hidden_states = hidden_states.permute(permute)
1502
- hidden_states = hidden_states.reshape(
1503
- batch_size, num_mask_units, *hidden_states.shape[1:-1], hidden_states.shape[-1]
1504
- )
1505
- return hidden_states
1506
-
1507
- def forward(self, feature_maps: List[torch.Tensor]) -> torch.Tensor:
1508
- # Multi-scale fusion
1509
- hidden_states = 0.0
1510
- for head, feature_map in zip(self.multi_scale_fusion_heads, feature_maps):
1511
- hidden_states = hidden_states + self.apply_fusion_head(head, feature_map)
1512
-
1513
- return hidden_states
1514
-
1515
-
1516
- @add_start_docstrings(
1517
- """The Hiera Model transformer with the decoder on top for self-supervised pre-training.
1518
-
1519
- <Tip>
1520
-
1521
- Note that we provide a script to pre-train this model on custom data in our [examples
1522
- directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
1523
-
1524
- </Tip>
1525
- """,
1526
- HIERA_START_DOCSTRING,
1527
- )
1528
- class HieraForPreTraining(HieraPreTrainedModel):
1529
- def __init__(self, config: HieraConfig) -> None:
1530
- super().__init__(config)
1531
- # Encoder
1532
- self.hiera = HieraModel(config, add_pooling_layer=False, is_mae=True)
1533
- self.encoder_norm = nn.LayerNorm(self.hiera.num_features, eps=config.layer_norm_eps)
1534
- # Multi-scale fusion heads
1535
- self.multiscale_fusion = HieraMultiScaleHead(config)
1536
- # Decoder
1537
- self.decoder = HieraDecoder(config)
1538
- self.pred_stride = self.decoder.pred_stride
1539
-
1540
- # Initialize weights and apply final processing
1541
- self.post_init()
1542
-
1543
- def get_pixel_label_2d(self, pixel_values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
1544
- # mask (boolean tensor): True means *masked*
1545
- pixel_values = pixel_values.permute(0, 2, 3, 1)
1546
-
1547
- size = self.pred_stride
1548
- label = pixel_values.unfold(1, size, size).unfold(2, size, size)
1549
- label = label.flatten(1, 2).flatten(2)
1550
- label = label[mask.bool()]
1551
- if self.config.norm_pix_loss:
1552
- mean = label.mean(dim=-1, keepdim=True)
1553
- var = label.var(dim=-1, keepdim=True)
1554
- label = (label - mean) / (var + 1.0e-6) ** 0.5
1555
-
1556
- return label
1557
-
1558
- def get_pixel_label_3d(self, pixel_values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
1559
- # mask (boolean tensor): True means *masked*
1560
- pixel_values = pixel_values[:, :, :: self.patch_stride[0], :, :]
1561
-
1562
- size = self.pred_stride
1563
- label = pixel_values.unfold(3, size, size).unfold(4, size, size)
1564
- # Different from 2D
1565
- label = label.permute(0, 2, 3, 4, 5, 6, 1)
1566
- label = label.flatten(1, 3).flatten(2)
1567
- label = label[mask.bool()]
1568
- if self.config.norm_pix_loss:
1569
- mean = label.mean(dim=-1, keepdim=True)
1570
- var = label.var(dim=-1, keepdim=True)
1571
- label = (label - mean) / (var + 1.0e-6) ** 0.5
1572
-
1573
- return label
1574
-
1575
- def forward_loss(self, pixel_values: torch.Tensor, logits: torch.Tensor, mask: torch.BoolTensor):
1576
- # We invert the mask such that 1.0 is *masked*
1577
- mask = 1 - mask
1578
- if len(self.config.query_stride) == 2:
1579
- label = self.get_pixel_label_2d(pixel_values, mask)
1580
- elif len(self.config.query_stride) == 3:
1581
- label = self.get_pixel_label_3d(pixel_values, mask)
1582
- else:
1583
- raise NotImplementedError("Only images and videos are supported")
1584
-
1585
- logits = logits[mask.bool()]
1586
- loss = (logits - label) ** 2
1587
- loss = loss.mean()
1588
-
1589
- return loss
1590
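`forward_loss` compares the decoder output against per-patch pixel targets built with `unfold`, normalizing each patch when `norm_pix_loss` is set. A sketch of the 2D target construction with hypothetical sizes (224x224 RGB input, prediction stride 32); the real method additionally keeps only the masked patches:

```python
import torch

batch_size, num_channels, height, width, size = 1, 3, 224, 224, 32
pixel_values = torch.randn(batch_size, num_channels, height, width)
pixel_values = pixel_values.permute(0, 2, 3, 1)                    # [B, H, W, C]

label = pixel_values.unfold(1, size, size).unfold(2, size, size)   # [B, 7, 7, C, 32, 32]
label = label.flatten(1, 2).flatten(2)                             # [B, 49, C*32*32]

# Per-patch normalization, as done when norm_pix_loss is enabled.
mean = label.mean(dim=-1, keepdim=True)
var = label.var(dim=-1, keepdim=True)
label = (label - mean) / (var + 1.0e-6) ** 0.5
print(label.shape)  # torch.Size([1, 49, 3072])
```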
-
1591
- @add_start_docstrings_to_model_forward(HIERA_INPUTS_DOCSTRING)
1592
- @replace_return_docstrings(output_type=HieraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
1593
- def forward(
1594
- self,
1595
- pixel_values: Optional[torch.Tensor] = None,
1596
- noise: Optional[torch.FloatTensor] = None,
1597
- head_mask: Optional[torch.Tensor] = None,
1598
- output_attentions: Optional[bool] = None,
1599
- output_hidden_states: Optional[bool] = None,
1600
- interpolate_pos_encoding: Optional[bool] = None,
1601
- return_dict: Optional[bool] = None,
1602
- ) -> Union[tuple, HieraForPreTrainingOutput]:
1603
- r"""
1604
- noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*):
1605
- Noise tensor used mainly for testing, to control randomness and ensure reproducibility
1606
- when `is_mae` is set to `True`.
1607
-
1608
- Returns:
1609
-
1610
- Examples:
1611
- ```python
1612
- >>> from transformers import AutoImageProcessor, HieraForPreTraining
1613
- >>> import torch
1614
- >>> from PIL import Image
1615
- >>> import requests
1616
-
1617
- >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1618
- >>> image = Image.open(requests.get(url, stream=True).raw)
1619
-
1620
- >>> image_processor = AutoImageProcessor.from_pretrained("EduardoPacheco/hiera-tiny-224-mae")
1621
- >>> model = HieraForPreTraining.from_pretrained("EduardoPacheco/hiera-tiny-224-mae")
1622
-
1623
- >>> inputs = image_processor(images=image, return_tensors="pt")
1624
-
1625
- >>> outputs = model(**inputs)
1626
- >>> logits = outputs.logits
1627
- >>> list(logits.shape)
1628
- [1, 196, 768]
1629
- ```"""
1630
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1631
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1632
- output_hidden_states = (
1633
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1634
- )
1635
-
1636
- outputs = self.hiera(
1637
- pixel_values,
1638
- noise=noise,
1639
- head_mask=head_mask,
1640
- output_attentions=output_attentions,
1641
- output_hidden_states=True,
1642
- interpolate_pos_encoding=interpolate_pos_encoding,
1643
- return_dict=True,
1644
- )
1645
-
1646
- feature_maps = outputs.reshaped_hidden_states
1647
- mask = outputs.mask
1648
- ids_to_restore = outputs.ids_restore
1649
- # Take only the query pooled and last hidden states
1650
- feature_maps = feature_maps[1 : self.hiera.config.num_query_pool + 1] + (feature_maps[-1],)
1651
- fused_hidden_states = self.multiscale_fusion(feature_maps)
1652
- fused_hidden_states = self.encoder_norm(fused_hidden_states)
1653
-
1654
- # Reconstruct pixel values
1655
- logits, mask = self.decoder(
1656
- fused_hidden_states,
1657
- mask=mask,
1658
- head_mask=head_mask,
1659
- output_attentions=output_attentions,
1660
- )
1661
-
1662
- loss = self.forward_loss(pixel_values, logits, mask)
1663
-
1664
- if not return_dict:
1665
- output = (logits, mask, ids_to_restore)
1666
- if output_hidden_states:
1667
- output = output + (outputs.hidden_states,)
1668
- if output_attentions:
1669
- output = output + (outputs.attentions,)
1670
- if output_hidden_states:
1671
- output = output + (outputs.reshaped_hidden_states,)
1672
- return ((loss,) + output) if loss is not None else output
1673
-
1674
- return HieraForPreTrainingOutput(
1675
- loss=loss,
1676
- logits=logits,
1677
- mask=mask,
1678
- ids_restore=ids_to_restore,
1679
- hidden_states=outputs.hidden_states if output_hidden_states else None,
1680
- attentions=outputs.attentions,
1681
- reshaped_hidden_states=outputs.reshaped_hidden_states if output_hidden_states else None,
1682
- )
1683
-
1684
-
1685
- @add_start_docstrings(
1686
- """
1687
- Hiera Model transformer with an image classification head on top (a linear layer on top of the final hidden state with
1688
- average pooling) e.g. for ImageNet.
1689
-
1690
- <Tip>
1691
-
1692
- Note that it's possible to fine-tune Hiera on higher resolution images than the ones it has been trained on, by
1693
- setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
1694
- position embeddings to the higher resolution.
1695
-
1696
- </Tip>
1697
- """,
1698
- HIERA_START_DOCSTRING,
1699
- )
1700
- class HieraForImageClassification(HieraPreTrainedModel):
1701
- def __init__(self, config: HieraConfig) -> None:
1702
- super().__init__(config)
1703
-
1704
- self.num_labels = config.num_labels
1705
- self.hiera = HieraModel(config, add_pooling_layer=True, is_mae=False)
1706
-
1707
- # Classifier head
1708
- self.classifier = (
1709
- nn.Linear(self.hiera.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
1710
- )
1711
-
1712
- # Initialize weights and apply final processing
1713
- self.post_init()
1714
-
1715
- @add_start_docstrings_to_model_forward(HIERA_INPUTS_DOCSTRING)
1716
- @add_code_sample_docstrings(
1717
- checkpoint=_IMAGE_CLASS_CHECKPOINT,
1718
- output_type=HieraForImageClassificationOutput,
1719
- config_class=_CONFIG_FOR_DOC,
1720
- expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
1721
- )
1722
- def forward(
1723
- self,
1724
- pixel_values: Optional[torch.Tensor] = None,
1725
- head_mask: Optional[torch.Tensor] = None,
1726
- labels: Optional[torch.Tensor] = None,
1727
- output_attentions: Optional[bool] = None,
1728
- output_hidden_states: Optional[bool] = None,
1729
- interpolate_pos_encoding: Optional[bool] = None,
1730
- return_dict: Optional[bool] = None,
1731
- ) -> Union[tuple, HieraForImageClassificationOutput]:
1732
- r"""
1733
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1734
- Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
1735
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1736
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1737
- """
1738
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1739
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1740
- output_hidden_states = (
1741
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1742
- )
1743
-
1744
- outputs = self.hiera(
1745
- pixel_values,
1746
- head_mask=head_mask,
1747
- output_attentions=output_attentions,
1748
- output_hidden_states=output_hidden_states,
1749
- interpolate_pos_encoding=interpolate_pos_encoding,
1750
- return_dict=return_dict,
1751
- )
1752
-
1753
- pooled_output = outputs[1]
1754
-
1755
- logits = self.classifier(pooled_output)
1756
-
1757
- loss = None
1758
- if labels is not None:
1759
- # move labels to correct device to enable model parallelism
1760
- labels = labels.to(logits.device)
1761
- if self.config.problem_type is None:
1762
- if self.num_labels == 1:
1763
- self.config.problem_type = "regression"
1764
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1765
- self.config.problem_type = "single_label_classification"
1766
- else:
1767
- self.config.problem_type = "multi_label_classification"
1768
-
1769
- if self.config.problem_type == "regression":
1770
- loss_fct = MSELoss()
1771
- if self.num_labels == 1:
1772
- loss = loss_fct(logits.squeeze(), labels.squeeze())
1773
- else:
1774
- loss = loss_fct(logits, labels)
1775
- elif self.config.problem_type == "single_label_classification":
1776
- loss_fct = CrossEntropyLoss()
1777
- loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1778
- elif self.config.problem_type == "multi_label_classification":
1779
- loss_fct = BCEWithLogitsLoss()
1780
- loss = loss_fct(logits, labels)
1781
-
1782
- if not return_dict:
1783
- output = (logits,) + outputs[4:]
1784
- return ((loss,) + output) if loss is not None else output
1785
-
1786
- return HieraForImageClassificationOutput(
1787
- loss=loss,
1788
- logits=logits,
1789
- hidden_states=outputs.hidden_states,
1790
- attentions=outputs.attentions,
1791
- reshaped_hidden_states=outputs.reshaped_hidden_states,
1792
- )
1793
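For completeness, a hedged inference sketch for the `HieraForImageClassification` class defined above, mirroring the doc examples elsewhere in this file; the checkpoint name is an assumption (substitute any Hiera checkpoint with a classification head):

```python
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Checkpoint name is assumed for illustration; use a fine-tuned Hiera classifier.
checkpoint = "EduardoPacheco/hiera-tiny-224"
processor = AutoImageProcessor.from_pretrained(checkpoint)
model = HieraForImageClassification.from_pretrained(checkpoint)

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
predicted_class = logits.argmax(-1).item()
print(model.config.id2label.get(predicted_class, predicted_class))
```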
-
1794
-
1795
- @add_start_docstrings(
1796
- """
1797
- Hiera backbone, to be used with frameworks like DETR and MaskFormer.
1798
- """,
1799
- HIERA_START_DOCSTRING,
1800
- )
1801
- class HieraBackbone(HieraPreTrainedModel, BackboneMixin):
1802
- def __init__(self, config: HieraConfig):
1803
- super().__init__(config)
1804
- super()._init_backbone(config)
1805
-
1806
- self.num_features = [config.embed_dim] + [
1807
- int(config.embed_dim * config.embed_dim_multiplier**i) for i in range(len(config.depths))
1808
- ]
1809
- self.embeddings = HieraEmbeddings(config, is_mae=False)
1810
- self.encoder = HieraEncoder(config)
1811
-
1812
- # Add layer norms to hidden states of out_features
1813
- hidden_states_norms = {}
1814
- for stage, num_channels in zip(self._out_features, self.channels):
1815
- hidden_states_norms[stage] = nn.LayerNorm(num_channels)
1816
- self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
1817
-
1818
- # Initialize weights and apply final processing
1819
- self.post_init()
1820
-
1821
- def get_input_embeddings(self):
1822
- return self.embeddings.patch_embeddings
1823
-
1824
- def forward(
1825
- self,
1826
- pixel_values: torch.Tensor,
1827
- output_hidden_states: Optional[bool] = None,
1828
- output_attentions: Optional[bool] = None,
1829
- return_dict: Optional[bool] = None,
1830
- ) -> BackboneOutput:
1831
- """
1832
- Returns:
1833
-
1834
- Examples:
1835
-
1836
- ```python
1837
- >>> from transformers import AutoImageProcessor, AutoBackbone
1838
- >>> import torch
1839
- >>> from PIL import Image
1840
- >>> import requests
1841
-
1842
- >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1843
- >>> image = Image.open(requests.get(url, stream=True).raw)
1844
-
1845
- >>> processor = AutoImageProcessor.from_pretrained("EduardoPacheco/hiera-tiny-224")
1846
- >>> model = AutoBackbone.from_pretrained(
1847
- ... "EduardoPacheco/hiera-tiny-224", out_features=["stage1", "stage2", "stage3", "stage4"]
1848
- ... )
1849
-
1850
- >>> inputs = processor(image, return_tensors="pt")
1851
- >>> outputs = model(**inputs)
1852
- >>> feature_maps = outputs.feature_maps
1853
- >>> list(feature_maps[-1].shape)
1854
- [1, 768, 7, 7]
1855
- ```"""
1856
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1857
- output_hidden_states = (
1858
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1859
- )
1860
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1861
-
1862
- embedding_output, _, _ = self.embeddings(pixel_values)
1863
-
1864
- outputs = self.encoder(
1865
- embedding_output,
1866
- head_mask=None,
1867
- output_attentions=output_attentions,
1868
- output_hidden_states=True,
1869
- return_dict=True,
1870
- )
1871
-
1872
- hidden_states = outputs.reshaped_hidden_states
1873
-
1874
- feature_maps = ()
1875
- for stage, hidden_state in zip(self.stage_names, hidden_states):
1876
- if stage in self.out_features:
1877
- batch_size, height, width, num_channels = hidden_state.shape
1878
- hidden_state = hidden_state.view(batch_size, height * width, num_channels)
1879
- hidden_state = self.hidden_states_norms[stage](hidden_state)
1880
- hidden_state = hidden_state.view(batch_size, height, width, num_channels)
1881
- hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
1882
- feature_maps += (hidden_state,)
1883
-
1884
- if not return_dict:
1885
- output = (feature_maps,)
1886
- if output_hidden_states:
1887
- output += (outputs.hidden_states,)
1888
- return output
1889
-
1890
- return BackboneOutput(
1891
- feature_maps=feature_maps,
1892
- hidden_states=outputs.hidden_states if output_hidden_states else None,
1893
- attentions=outputs.attentions,
1894
- )
1895
  # %%
1896
 
1897
 
@@ -1910,8 +28,8 @@ class PytorchWorker:
1910
  self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1911
  print(f"Using device: {self.device}")
1912
 
1913
- image_processor = AutoImageProcessor.from_pretrained("./hiera_model/")
1914
- model = HieraForImageClassification.from_pretrained("./hiera_model/", num_labels =1784 ).to(self.device).eval()
1915
 
1916
  return model, image_processor
1917
 
@@ -1922,7 +40,7 @@ class PytorchWorker:
1922
  :param image: Input image as numpy array.
1923
  :return: A list with logits and confidences.
1924
  """
1925
- inputs = self.image_processor(images=image, return_tensors="pt")
1926
  outputs = self.model(**inputs)
1927
  logits = outputs.logits
1928
  return logits.tolist()
@@ -1968,7 +86,7 @@ if __name__ == "__main__":
1968
  model_name=MODEL_NAME
1969
  )
1970
 
1971
-
1972
  # import requests
1973
  # image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
1974
  # # %%
@@ -1982,4 +100,4 @@ if __name__ == "__main__":
1982
  # # %%
1983
  # import numpy as np
1984
  # np.argmax(output)
1985
- # # %%
 
8
  from PIL import Image
9
  import torch
10
  from transformers import AutoImageProcessor
11
+ from submission.create_model import HieraForImageClassification
12
  #%%
 
13
  # %%
14
 
15
 
 
28
  self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
29
  print(f"Using device: {self.device}")
30
 
31
+ image_processor = AutoImageProcessor.from_pretrained("./submission/hiera_model/")
32
+ model = HieraForImageClassification.from_pretrained("./submission/hiera_model/", num_labels =1784 ).to(self.device).eval()
33
 
34
  return model, image_processor
35
 
 
40
  :param image: Input image as numpy array.
41
  :return: A list with logits and confidences.
42
  """
43
+ inputs = self.image_processor(images=image, return_tensors="pt").to(self.device)
44
  outputs = self.model(**inputs)
45
  logits = outputs.logits
46
  return logits.tolist()
 
86
  model_name=MODEL_NAME
87
  )
88
 
89
+ #%%
90
  # import requests
91
  # image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
92
  # # %%
 
100
  # # %%
101
  # import numpy as np
102
  # np.argmax(output)
103
+ # %%
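Taken together, the hunks above slim `script.py` down to the inference wrapper: the vendored Hiera code moves out to `submission.create_model`, the checkpoint and processor are loaded from `./submission/hiera_model/`, and inputs are moved to the worker's device before the forward pass. A hedged sketch of the resulting worker (method names outside the visible hunks, and the `torch.no_grad()` guard, are assumptions):

```python
import torch
from transformers import AutoImageProcessor
from submission.create_model import HieraForImageClassification

class PytorchWorker:
    """Minimal sketch of the post-diff worker; only the lines shown in the hunks are verbatim."""

    def __init__(self):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.model, self.image_processor = self._load_model()

    def _load_model(self):
        image_processor = AutoImageProcessor.from_pretrained("./submission/hiera_model/")
        model = (
            HieraForImageClassification.from_pretrained("./submission/hiera_model/", num_labels=1784)
            .to(self.device)
            .eval()
        )
        return model, image_processor

    def predict_image(self, image):
        """:param image: Input image as numpy array. :return: A list with logits."""
        inputs = self.image_processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.logits.tolist()
```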