Someshfengde committed
Commit: a139ac6
Parent: 31f23f1
Upload folder using huggingface_hub
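The commit message above indicates this change was pushed with the `huggingface_hub` client. As a minimal sketch of how such a commit is typically produced (the repo id and local folder path below are placeholders, not values taken from this page), `HfApi.upload_folder` creates exactly this kind of "Upload folder" commit:

```python
from huggingface_hub import HfApi

api = HfApi()

# Hypothetical repo id and local path -- the real values are not shown on this commit page.
api.upload_folder(
    folder_path="./local_folder",
    repo_id="Someshfengde/example-repo",
    commit_message="Upload folder using huggingface_hub",
)
```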
script.py CHANGED
@@ -8,1890 +8,8 @@ import torchvision.transforms as T
 from PIL import Image
 import torch
 from transformers import AutoImageProcessor
-
 #%%
-# coding=utf-8
-# Copyright 2024 Meta and The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-""" PyTorch Hiera model."""
-
-
-import math
-from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Union
-
-import torch
-import torch.utils.checkpoint
-from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-
-import transformers
-
-from transformers.activations import ACT2FN
-from transformers.modeling_outputs import (
-    BackboneOutput,
-    BaseModelOutput,
-    BaseModelOutputWithPooling,
-    ImageClassifierOutput,
-    ModelOutput,
-)
-from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import (
-    add_code_sample_docstrings,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    logging,
-    replace_return_docstrings,
-)
-from transformers.utils.backbone_utils import BackboneMixin
-# coding=utf-8
-# Copyright 2024 Meta and The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-""" Hiera model configuration"""
-
-from collections import OrderedDict
-from typing import Mapping
-
-from packaging import version
-
-from transformers.configuration_utils import PretrainedConfig
-from transformers.onnx import OnnxConfig
-from transformers.utils import logging
-from transformers.utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
-
-
-logger = logging.get_logger(__name__)
-
-HIERA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    "EduardoPacheco/hiera-tiny-224": "https://huggingface.co/EduardoPacheco/hiera-tiny-224/resolve/main/config.json",
-}
-
-
-class HieraConfig(BackboneConfigMixin, PretrainedConfig):
-    r"""
-    This is the configuration class to store the configuration of a [`HieraModel`]. It is used to instantiate an Hiera
-    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
-    defaults will yield a similar configuration to that of the Hiera
-    [EduardoPacheco/hiera-base-224](https://huggingface.co/EduardoPacheco/hiera-base-224) architecture.
-
-    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
-    documentation from [`PretrainedConfig`] for more information.
-
-
-    Args:
-        embed_dim (`int`, *optional*, defaults to 96):
-            Dimensionality of patch embedding.
-        input_size (`list(int)`, *optional*, defaults to `[224, 224]`):
-            The size (resolution) of input in the format (height, width) for images
-            and (frames, height, width) for videos.
-        patch_kernel (`list(int)`, *optional*, defaults to `[7, 7]`):
-            The size (resolution) of each patch.
-        patch_stride (`list(int)`, *optional*, defaults to `[4, 4]`):
-            The stride of the patch.
-        patch_padding (`list(int)`, *optional*, defaults to `[3, 3]`):
-            The padding of the patch.
-        mlp_ratio (`float`, *optional*, defaults to 4.0):
-            The ratio of mlp hidden dim to embedding dim.
-        depths (`list(int)`, *optional*, defaults to `[2, 3, 16, 3]`):
-            Depth of each layer in the Transformer encoder.
-        initial_num_heads (`int`, *optional*, defaults to 1):
-            Initial number of attention heads in the first layer of the Transformer encoder.
-        num_head_multiplier (`float`, *optional*, defaults to 2.0):
-            The multiplier to the number of attention heads in each layer of the Transformer encoder.
-        embed_dim_multiplier (`float`, *optional*, defaults to 2.0):
-            The multiplier to the dimensionality of patch embedding in each layer of the Transformer encoder.
-        num_query_pool (`int`, *optional*, defaults to 3):
-            The number of query pool stages.
-        query_stride (`list(int)`, *optional*, defaults to `[2, 2]`):
-            The stride of the query pool.
-        masked_unit_size (`list(int)`, *optional*, defaults to `[8, 8]`):
-            The size of the masked unit.
-        masked_unit_attention (`list(bool)`, *optional*, defaults to `[True, True, False, False]`):
-            Whether to use masked unit attention in each layer of the Transformer encoder.
-        drop_path_rate (`float`, *optional*, defaults to 0.0):
-            The drop path rate.
-        sep_pos_embed (`bool`, *optional*, defaults to `False`):
-            Whether to use separate position embedding for temporal and spatial dimensions. Must be `True` for videos.
-            and `False` for images.
-        num_channels (`int`, *optional*, defaults to 3):
-            The number of input channels.
-        hidden_act (`str`, *optional*, defaults to `"gelu"`):
-            The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
-            `"selu"` and `"gelu_new"` are supported.
-        initializer_range (`float`, *optional*, defaults to 0.02):
-            The standard deviation of the truncated_normal_initializer for initializing all weight matrices and
-            the zero_initializer for initializing all bias vectors.
-        layer_norm_init (`float`, *optional*, defaults to 1.0):
-            The initial weight value for layer normalization layers.
-        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
-            The epsilon used by the layer normalization layers.
-        decoder_embed_dim (`int`, *optional*):
-            Dimensionality of decoder embeddings for MAE pretraining.
-        decoder_depth (`int`, *optional*):
-            Depth of the decoder for MAE pretraining.
-        decoder_num_heads (`int`, *optional*):
-            Number of attention heads in each layer of the decoder for MAE pretraining.
-        norm_pix_loss (`bool`, *optional*, defaults to `True`):
-            Whether to normalize the pixel loss by the number of pixels.
-        mask_ratio (`float`, *optional*, defaults to 0.6):
-            The ratio of masked tokens in the input.
-        out_features (`List[str]`, *optional*):
-            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
-            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
-            corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
-            same order as defined in the `stage_names` attribute.
-        out_indices (`List[int]`, *optional*):
-            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
-            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
-            If unset and `out_features` is unset, will default to the last stage. Must be in the
-            same order as defined in the `stage_names` attribute.
-
-
-    Example:
-
-    ```python
-    >>> from transformers import HieraConfig, HieraModel
-
-    >>> # Initializing a Hiera hiera-base-patch16-224 style configuration
-    >>> configuration = HieraConfig()
-
-    >>> # Initializing a model (with random weights) from the hiera-base-patch16-224 style configuration
-    >>> model = HieraModel(configuration)
-
-    >>> # Accessing the model configuration
-    >>> configuration = model.config
-    ```"""
-
-    model_type = "hiera"
-
-    attribute_map = {"num_hidden_layers": "num_layers"}
-
-    def __init__(
-        self,
-        embed_dim=96,
-        input_size=[224, 224],
-        patch_kernel=[7, 7],
-        patch_stride=[4, 4],
-        patch_padding=[3, 3],
-        mlp_ratio=4.0,
-        depths=[2, 3, 16, 3],
-        initial_num_heads=1,
-        num_head_multiplier=2.0,
-        embed_dim_multiplier=2.0,
-        num_query_pool=3,
-        query_stride=[2, 2],
-        masked_unit_size=[8, 8],
-        masked_unit_attention=[True, True, False, False],
-        drop_path_rate=0.0,
-        sep_pos_embed=False,
-        num_channels=3,
-        hidden_act="gelu",
-        initializer_range=0.02,
-        layer_norm_init=1.0,
-        layer_norm_eps=1e-6,
-        decoder_embed_dim=None,
-        decoder_depth=None,
-        decoder_num_heads=None,
-        norm_pix_loss=True,
-        mask_ratio=0.6,
-        out_features=None,
-        out_indices=None,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        if masked_unit_size[0] % query_stride[0] ** (len(depths) - 1) != 0:
-            raise ValueError(
-                f"masked_unit_size[0] ({masked_unit_size[0]}) must be divisible by query_stride[0] ({query_stride[0]}) "
-                f"raised to the power of the number of layers ({len(depths) - 1})"
-            )
-
-        if num_query_pool >= len(depths):
-            raise ValueError(
-                f"num_query_pool ({num_query_pool}) must be less than the number of layers ({len(depths)})"
-            )
-
-        self.embed_dim = embed_dim
-        self.input_size = input_size
-        self.patch_kernel = patch_kernel
-        self.patch_stride = patch_stride
-        self.patch_padding = patch_padding
-        self.mlp_ratio = mlp_ratio
-        self.depths = depths
-        self.num_layers = len(depths)
-        self.initial_num_heads = initial_num_heads
-        self.num_head_multiplier = num_head_multiplier
-        self.embed_dim_multiplier = embed_dim_multiplier
-        self.num_query_pool = num_query_pool
-        self.query_stride = query_stride
-        self.masked_unit_size = masked_unit_size
-        self.masked_unit_attention = masked_unit_attention
-        self.drop_path_rate = drop_path_rate
-        self.sep_pos_embed = sep_pos_embed
-        self.num_channels = num_channels
-        self.hidden_act = hidden_act
-        self.initializer_range = initializer_range
-        self.layer_norm_init = layer_norm_init
-        self.layer_norm_eps = layer_norm_eps
-        self.decoder_embed_dim = decoder_embed_dim
-        self.decoder_depth = decoder_depth
-        self.decoder_num_heads = decoder_num_heads
-        self.norm_pix_loss = norm_pix_loss
-        self.mask_ratio = mask_ratio
-        # we set the hidden_size attribute in order to make Hiera work with VisionEncoderDecoderModel
-        # this indicates the channel dimension after the last stage of the model
-        self.hidden_size = int(embed_dim * embed_dim_multiplier ** (len(depths) - 1))
-        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
-        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
-            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
-        )
-
-
-class HieraOnnxConfig(OnnxConfig):
-    torch_onnx_minimum_version = version.parse("1.11")
-
-    @property
-    def inputs(self) -> Mapping[str, Mapping[int, str]]:
-        return OrderedDict(
-            [
-                ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
-            ]
-        )
-
-    @property
-    def atol_for_validation(self) -> float:
-        return 1e-4
-
-logger = logging.get_logger(__name__)
-
-# General docstring
-_CONFIG_FOR_DOC = "HieraConfig"
-
-# Base docstring
-_CHECKPOINT_FOR_DOC = "EduardoPacheco/hiera-tiny-224"
-_EXPECTED_OUTPUT_SHAPE = [1, 49, 768]
-
-# Image classification docstring
-_IMAGE_CLASS_CHECKPOINT = "EduardoPacheco/hiera-tiny-224-in1k"
-_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
-
-
-HIERA_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "EduardoPacheco/hiera-tiny-224",
-    # See all Hiera models at https://huggingface.co/models?filter=hiera
-]
-
-
-@dataclass
-class HieraEncoderOutput(ModelOutput):
-    """
-    Hiera encoder's outputs, with potential hidden states and attentions.
-
-    Args:
-        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
-            Sequence of hidden-states at the output of the last layer of the model.
-        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
-            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
-            shape `(batch_size, sequence_length, hidden_size)`. Thesre are the unrolled hidden states of the model.
-
-            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
-            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
-            sequence_length)`.
-
-            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
-            heads.
-        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
-            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
-            shape `(batch_size, height, width, hidden_size)`. These are the reshaped and re-rolled hidden states of the model.
-
-            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
-            include the spatial dimensions.
-    """
-
-    last_hidden_state: torch.FloatTensor = None
-    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
-    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
-    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
-
-
-@dataclass
-class HieraModelOutput(ModelOutput):
-    """
-    Hiera model's outputs that also contains a pooling of the last hidden states.
-
-    Args:
-        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
-            Sequence of hidden-states at the output of the last layer of the model.
-        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
-            Average pooling of the last layer hidden-state.
-        mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
-            Tensor indicating which patches are masked (0) and which are not (1).
-        ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-            Tensor containing the original index of the (shuffled) masked patches.
-        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
-            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
-            shape `(batch_size, sequence_length, hidden_size)`. These are the unrolled hidden states of the model.
-
-            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
-            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
-            sequence_length)`.
-
-            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
-            heads.
-        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
-            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
-            shape `(batch_size, height, width, hidden_size)`. These are the reshaped and re-rolled hidden states of the model.
-
-            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
-            include the spatial dimensions.
-    """
-
-    last_hidden_state: torch.FloatTensor = None
-    pooler_output: Optional[torch.FloatTensor] = None
-    mask: torch.LongTensor = None
-    ids_restore: torch.LongTensor = None
-    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
-    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
-    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
-
-
-@dataclass
-class HieraForImageClassificationOutput(ImageClassifierOutput):
-    """
-    Hiera image classification outputs.
-
-    Args:
-        loss (`torch.FloatTensor` of shape `(1,)`, `optional`):
-            Classification loss.
-        logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
-            Prediction scores of the classification head (logits of the output layer).
-        hidden_states (`tuple(torch.FloatTensor)`, `optional`):
-            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
-            shape `(batch_size, sequence_length, hidden_size)`. These are the unrolled hidden states of the model.
-
-            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        attentions (`tuple(torch.FloatTensor)`, `optional`):
-            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
-            sequence_length)`.
-
-            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
-            heads.
-        reshaped_hidden_states (`tuple(torch.FloatTensor)`, `optional`):
-            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
-            shape `(batch_size, height, width, hidden_size)`. These are the reshaped and re-rolled hidden states of the model.
-
-            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
-            include the spatial dimensions.
-    """
-
-    loss: Optional[torch.FloatTensor] = None
-    logits: torch.FloatTensor = None
-    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
-    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
-    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
-
-
-@dataclass
-class HieraForPreTrainingOutput(ModelOutput):
-    """
-    Class for ViTMAEForPreTraining's outputs, with potential hidden states and attentions.
-
-    Args:
-        loss (`torch.FloatTensor` of shape `(1,)`):
-            Pixel reconstruction loss.
-        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
-            Pixel reconstruction logits.
-        mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
-            Tensor indicating which patches are masked (0) and which are not (1).
-        ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-            Tensor containing the original index of the (shuffled) masked patches.
-        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
-            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
-            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
-            plus the initial embedding outputs.
-        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
-            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
-            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
-            the self-attention heads.
-        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
-            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
-            shape `(batch_size, height, width, hidden_size)`. Hidden-states of the model at the output of each layer
-            plus the initial embedding outputs reshaped to include the spatial dimensions.
-    """
-
-    loss: Optional[torch.FloatTensor] = None
-    logits: torch.FloatTensor = None
-    mask: torch.LongTensor = None
-    ids_restore: torch.LongTensor = None
-    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    attentions: Optional[Tuple[torch.FloatTensor]] = None
-    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-
-
-# Taken from https://github.com/facebookresearch/hiera/blob/main/hiera/hiera_utils.py#L73
-def conv_nd(n: int) -> nn.Module:
-    """
-    Returns a conv with nd (e.g., Conv2d for n=2). Work up to n=3.
-    If you wanted a 4d Hiera, you could probably just implement this for n=4. (no promises)
-    """
-    return [nn.Identity, nn.Conv1d, nn.Conv2d, nn.Conv3d][n]
-
-
-# Taken from https://github.com/facebookresearch/hiera/blob/main/hiera/hiera_utils.py#L81
-def do_pool(x: torch.Tensor, stride: int) -> torch.Tensor:
-    # Refer to `Unroll` to see how this performs a maxpool-Nd
-    return x.view(x.shape[0], stride, -1, x.shape[-1]).max(dim=1).values
-
-
-class HieraPatchEmbeddings(nn.Module):
-    """
-    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
-    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
-    Transformer.
-    """
-
-    def __init__(self, config, is_mae: bool = False):
-        super().__init__()
-
-        # Support any number of spatial dimensions
-        self.spatial_dims = len(config.patch_kernel)
-        if self.spatial_dims not in (2, 3):
-            raise ValueError(
-                f"The number of dimensions of the input image should be 2 or 3, but got {self.spatial_dims}."
-            )
-        self.num_channels = config.num_channels
-        self.image_size = config.input_size[-2:]
-        self.tokens_spatial_shape = [i // s for i, s in zip(config.input_size, config.patch_stride)]
-        self.mask_spatial_shape = [i // s for i, s in zip(self.tokens_spatial_shape, config.masked_unit_size)]
-        self.mask_ratio = config.mask_ratio
-        self.is_mae = is_mae
-
-        self.projection = conv_nd(self.spatial_dims)(
-            self.num_channels,
-            config.embed_dim,
-            kernel_size=config.patch_kernel,
-            stride=config.patch_stride,
-            padding=config.patch_padding,
-        )
-
-    def masked_conv(self, pixel_values: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        """Zero-out the masked regions of the input before conv.
-        Prevents leakage of masked regions when using overlapping kernels.
-        """
-        if mask is None:
-            return self.projection(pixel_values)
-
-        target_size = pixel_values.shape[2:]
-        # Reshape mask to (batch_size, 1, mask_unit_height, mask_unit_width)
-        mask = mask.view(pixel_values.shape[0], 1, *self.mask_spatial_shape)
-
-        if len(mask.shape[2:]) != len(target_size):
-            raise ValueError(
-                f"The length of the spatial dimensions of the mask should match the one from input image, but got {len(mask.shape[2:])} and {len(target_size)}."
-            )
-
-        if mask.shape[2:] != target_size:
-            mask = nn.functional.interpolate(mask, size=target_size)
-
-        return self.projection(pixel_values * mask.bool())
-
-    def random_masking(self, pixel_values, noise=None):
-        """
-        Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
-        noise.
-
-        Args:
-            pixel_values (`torch.LongTensor` of shape `(batch_size, num_channels, height, width)`)
-            noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*) which is
-                mainly used for testing purposes to control randomness and maintain the reproducibility
-        """
-        batch_size = pixel_values.shape[0]
-        # Tokens selected for masking at mask unit level
-        num_windows = math.prod(self.mask_spatial_shape)
-        len_keep = int(num_windows * (1 - self.mask_ratio))
-
-        if noise is None:
-            noise = torch.rand(batch_size, num_windows, device=pixel_values.device)
-
-        # Sort noise for each sample
-        ids_shuffle = torch.argsort(noise, dim=1)
-        # ascend: small is keep, large is remove
-        ids_restore = torch.argsort(ids_shuffle, dim=1)
-
-        # Generate the binary mask: 1 is *keep*, 0 is *remove*
-        # Note this is opposite to original MAE
-        mask = torch.zeros([batch_size, num_windows], device=pixel_values.device)
-        mask[:, :len_keep] = 1
-        # Unshuffle to get the binary mask
-        mask = torch.gather(mask, dim=1, index=ids_restore)
-
-        return mask, ids_restore
-
-    def forward(
-        self,
-        pixel_values: torch.Tensor,
-        noise: Optional[torch.FloatTensor] = None,
-        interpolate_pos_encoding: bool = False,
-    ) -> torch.Tensor:
-        num_channels = pixel_values.shape[1]
-        height, width = pixel_values.shape[-2:]
-
-        if num_channels != self.num_channels:
-            raise ValueError(
-                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
-                f" Expected {self.num_channels} but got {num_channels}."
-            )
-
-        if not interpolate_pos_encoding:
-            if height != self.image_size[0] or width != self.image_size[1]:
-                raise ValueError(
-                    f"Input image size ({height}*{width}) doesn't match model"
-                    f" ({self.image_size[0]}*{self.image_size[1]})."
-                )
-
-        (mask, ids_restore) = self.random_masking(pixel_values, noise=noise) if self.is_mae else (None, None)
-
-        embeddings = self.masked_conv(pixel_values, mask)
-        embeddings = embeddings.flatten(2).transpose(2, 1)
-
-        return embeddings, mask, ids_restore
-
-
-class HieraEmbeddings(nn.Module):
-    """
-    Construct position and patch embeddings.
-    """
-
-    def __init__(self, config: HieraConfig, is_mae: bool = False) -> None:
-        super().__init__()
-        self.patch_stride = config.patch_stride
-        self.tokens_spatial_shape = [i // s for i, s in zip(config.input_size, config.patch_stride)]
-        self.mask_spatial_shape = [i // s for i, s in zip(self.tokens_spatial_shape, config.masked_unit_size)]
-        self.num_tokens = math.prod(self.tokens_spatial_shape)
-        self.sep_pos_embed = config.sep_pos_embed
-        self.is_mae = is_mae
-
-        self.patch_embeddings = HieraPatchEmbeddings(config, is_mae=is_mae)
-
-        if self.sep_pos_embed:
-            self.position_embeddings_spatial = nn.Parameter(
-                torch.zeros(
-                    1,
-                    self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2],
-                    config.embed_dim,
-                )
-            )
-            self.position_embeddings_temporal = nn.Parameter(
-                torch.zeros(1, self.tokens_spatial_shape[0], config.embed_dim)
-            )
-        else:
-            self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_tokens, config.embed_dim))
-
-    def interpolate_pos_encoding(
-        self, embeddings: torch.Tensor, pos_embeds: torch.Tensor, height: int, width: int
-    ) -> torch.Tensor:
-        """
-        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
-        resolution images.
-
-        Adapted from:
-        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
-        """
-
-        num_patches = embeddings.shape[1]
-        num_positions = pos_embeds.shape[1]
-        if num_patches == num_positions and height == width:
-            return pos_embeds
-        dim = embeddings.shape[-1]
-        h0 = height // self.patch_stride[0] if not self.sep_pos_embed else height // self.patch_stride[1]
-        w0 = width // self.patch_stride[1] if not self.sep_pos_embed else width // self.patch_stride[2]
-        # we add a small number to avoid floating point error in the interpolation
-        # see discussion at https://github.com/facebookresearch/dino/issues/8
-        h0, w0 = h0 + 0.1, w0 + 0.1
-        pos_embeds = pos_embeds.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
-        pos_embeds = pos_embeds.permute(0, 3, 1, 2)
-        pos_embeds = nn.functional.interpolate(
-            pos_embeds,
-            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
-            mode="bicubic",
-            align_corners=False,
-        )
-        if int(h0) != pos_embeds.shape[-2] or int(w0) != pos_embeds.shape[-1]:
-            raise ValueError("The interpolated position encoding does not have the right size")
-        pos_embeds = pos_embeds.permute(0, 2, 3, 1).view(1, -1, dim)
-        return pos_embeds
-
-    def get_position_embedding(
-        self, embeddings: torch.Tensor, height: int, width: int, interpolate_pos_encoding: bool
-    ) -> torch.Tensor:
-        if self.sep_pos_embed:
-            spatial = self.position_embeddings_spatial
-            spatial = (
-                self.interpolate_pos_encoding(embeddings, spatial, height, width)
-                if interpolate_pos_encoding
-                else spatial
-            )
-            spatial = spatial.repeat(1, self.tokens_spatial_shape[0], 1)
-
-            temporal = torch.repeat_interleave(
-                self.position_embeddings_temporal,
-                self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2],
-                dim=1,
-            )
-
-            return spatial + temporal
-        else:
-            position_embeddings = self.position_embeddings
-            position_embeddings = (
-                self.interpolate_pos_encoding(embeddings, position_embeddings, height, width)
-                if interpolate_pos_encoding
-                else position_embeddings
-            )
-            return position_embeddings
-
-    def forward(
-        self,
-        pixel_values: torch.Tensor,
-        noise: Optional[torch.FloatTensor] = None,
-        interpolate_pos_encoding: bool = False,
-    ) -> torch.Tensor:
-        if len(self.tokens_spatial_shape) == 2:
-            batch_size, num_channels, height, width = pixel_values.shape
-        else:
-            batch_size, num_channels, depth, height, width = pixel_values.shape
-
-        embeddings, mask, ids_restore = self.patch_embeddings(
-            pixel_values, noise=noise, interpolate_pos_encoding=interpolate_pos_encoding
-        )
-
-        embeddings = embeddings + self.get_position_embedding(embeddings, height, width, interpolate_pos_encoding)
-
-        return embeddings, mask, ids_restore
-
-
-class HieraMaskUnitAttention(nn.Module):
-    """
-    Computes either Mask Unit or Global Attention. Also is able to perform q pooling.
-
-    Note: this assumes the tokens have already been flattened and unrolled into mask units.
-    """
-
-    def __init__(
-        self,
-        dim: int,
-        dim_out: int,
-        num_heads: int,
-        query_stride: int = 1,
-        window_size: int = 0,
-        use_mask_unit_attn: bool = False,
-    ):
-        super().__init__()
-
-        self.dim = dim
-        self.dim_out = dim_out
-        self.num_heads = num_heads
-        self.query_stride = query_stride
-
-        self.head_dim = dim_out // num_heads
-        self.scale = (self.head_dim) ** -0.5
-
-        self.qkv = nn.Linear(dim, 3 * dim_out)
-        self.proj = nn.Linear(dim_out, dim_out)
-
-        self.window_size = window_size
-        self.use_mask_unit_attn = use_mask_unit_attn
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        head_mask: Optional[torch.FloatTensor] = None,
-        output_attentions: bool = False,
-    ) -> torch.Tensor:
-        """Input should be of shape [batch, tokens, channels]."""
-        batch_size, seq_len, _ = hidden_states.shape
-
-        num_windows = 1
-        if self.use_mask_unit_attn:
-            num_windows = seq_len // (self.query_stride * self.window_size)
-
-        qkv = self.qkv(hidden_states)
-        qkv = qkv.reshape(batch_size, -1, num_windows, 3, self.num_heads, self.head_dim)
-        qkv = qkv.permute(3, 0, 4, 2, 1, 5)
-
-        query, key, value = qkv.unbind(0)
-
-        if self.query_stride > 1:
-            # Refer to Unroll to see how this performs a maxpool-Nd
-            query = query.view(batch_size, self.num_heads, num_windows, self.query_stride, -1, self.head_dim)
-            query = query.max(dim=3).values
-
-        attn_weights = (query * self.scale) @ key.transpose(-1, -2)
-        attn_weights = attn_weights.softmax(dim=-1)
-
-        # Mask heads if we want to
-        if head_mask is not None:
-            attn_weights = attn_weights * head_mask
-
-        attn_output = attn_weights @ value
-        attn_output = attn_output.transpose(1, 3).reshape(batch_size, -1, self.dim_out)
-        attn_output = self.proj(attn_output)
-
-        return (attn_output, attn_weights) if output_attentions else (attn_output, None)
-
-
-# Copied from transformers.models.beit.modeling_beit.drop_path
-def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
-    """
-    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
-
-    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
-    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
-    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
-    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
-    argument.
-    """
-    if drop_prob == 0.0 or not training:
-        return input
-    keep_prob = 1 - drop_prob
-    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
-    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
-    random_tensor.floor_()  # binarize
-    output = input.div(keep_prob) * random_tensor
-    return output
-
-
-# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Hiera
-class HieraDropPath(nn.Module):
-    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
-
-    def __init__(self, drop_prob: Optional[float] = None) -> None:
-        super().__init__()
-        self.drop_prob = drop_prob
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        return drop_path(hidden_states, self.drop_prob, self.training)
-
-    def extra_repr(self) -> str:
-        return "p={}".format(self.drop_prob)
-
-
-class HieraMlp(nn.Module):
-    def __init__(self, config, dim: int):
-        super().__init__()
-        self.config = config
-        self.activation_fn = ACT2FN[config.hidden_act]
-        self.fc1 = nn.Linear(dim, int(dim * config.mlp_ratio))
-        self.fc2 = nn.Linear(int(dim * config.mlp_ratio), dim)
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.fc1(hidden_states)
-        hidden_states = self.activation_fn(hidden_states)
-        hidden_states = self.fc2(hidden_states)
-        return hidden_states
-
-
-class HieraLayer(nn.Module):
-    def __init__(
-        self,
-        config,
-        dim: int,
-        dim_out: int,
-        num_heads: int,
-        drop_path: float = 0.0,
-        query_stride: int = 1,
-        window_size: int = 0,
-        use_mask_unit_attn: bool = False,
-    ):
-        super().__init__()
-
-        self.dim = dim
-        self.dim_out = dim_out
-        self.query_stride = query_stride
-
-        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
-        self.attn = HieraMaskUnitAttention(dim, dim_out, num_heads, query_stride, window_size, use_mask_unit_attn)
-
-        self.layernorm_after = nn.LayerNorm(dim_out, eps=config.layer_norm_eps)
-        self.mlp = HieraMlp(config, dim_out)
-
-        self.drop_path = HieraDropPath(drop_path) if drop_path > 0 else nn.Identity()
-        if dim != dim_out:
-            self.proj = nn.Linear(dim, dim_out)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        head_mask: Optional[torch.FloatTensor] = None,
-        output_attentions: bool = False,
-    ) -> torch.Tensor:
-        batch_size, seq_len, _ = hidden_states.shape
-        # Attention + Q Pooling
-        hidden_states_norm = self.layernorm_before(hidden_states)
-        if self.dim != self.dim_out:
-            hidden_states = self.proj(hidden_states_norm)
-            # Refer to `HieraUnroll` to see how this performs a maxpool-Nd
-            hidden_states = hidden_states.view(batch_size, self.query_stride, -1, self.dim_out).max(dim=1).values
-
-        (hidden_states_norm, attn_weights) = self.attn(
-            hidden_states_norm, head_mask, output_attentions=output_attentions
-        )
-        hidden_states = hidden_states + self.drop_path(hidden_states_norm)
-
-        residual = hidden_states
-        hidden_states = self.layernorm_after(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + self.drop_path(hidden_states)
-
-        return (hidden_states, attn_weights)
-
-
-class HieraStage(nn.Module):
-    def __init__(
-        self,
-        config,
-        depth: int,
-        dim: int,
-        dim_out: int,
-        num_heads: int,
-        drop_path: List[float],
-        query_stride: List[int],
-        window_size: int,
-        use_mask_unit_attn: bool,
-        stage_num: Optional[int] = None,
-    ) -> None:
-        super().__init__()
-        # we need to know if the previous stage used masked attention
-        # mask unit or global attention.
-        # lag by 1 layer, so that global attention,
-        # applied post pooling on lower resolution
-        previous_stage_used_masked_attention = False
-        if stage_num is not None:
-            previous_stage_used_masked_attention = config.masked_unit_attention[stage_num - 1 if stage_num > 0 else 0]
-        self.layers = nn.ModuleList(
-            [
-                HieraLayer(
-                    config=config,
-                    dim=dim if i == 0 else dim_out,
-                    dim_out=dim_out,
-                    num_heads=num_heads,
-                    drop_path=drop_path[i],
-                    query_stride=query_stride[i],
-                    window_size=window_size,
-                    use_mask_unit_attn=use_mask_unit_attn or (previous_stage_used_masked_attention and i == 0),
-                )
-                for i in range(depth)
-            ]
-        )
-
-    def forward(
-        self, hidden_states: torch.Tensor, head_mask: Optional[torch.FloatTensor], output_attentions: bool = False
-    ) -> torch.Tensor:
-        for i, layer_module in enumerate(self.layers):
-            layer_head_mask = head_mask[i] if head_mask is not None else None
-            (hidden_states, attn_weights) = layer_module(
-                hidden_states, layer_head_mask, output_attentions=output_attentions
-            )
-
-        return hidden_states, attn_weights
-
-
-def undo_windowing(hidden_states: torch.Tensor, shape: List[int], mask_unit_shape: List[int]) -> torch.Tensor:
-    """
-    Restore spatial organization by undoing windowed organization of mask units.
-    """
-    num_dims = len(shape)
-    batch_size, hidden_size = hidden_states.shape[0], hidden_states.shape[-1]
-    # From: [batch_size, num_mask_unit_height*num_#mask_unit_wdith, mask_unit_height, mask_unit_width, hidden_size]
-    # To: [batch_size, num_mask_unit_height, num_mask_unit_width, mask_unit_height, mask_unit_width, hidden_size]
-    num_mask_units = [s // mu for s, mu in zip(shape, mask_unit_shape)]
-    hidden_states = hidden_states.view(batch_size, *num_mask_units, *mask_unit_shape, hidden_size)
-
-    # From: [batch_size, num_mask_unit_height, num_mask_unit_width, mask_unit_height, mask_unit_width, hidden_size]
-    # To: [batch_size, num_mask_unit_height*mask_unit_height, num_mask_unit_width*mask_unit_width, hidden_size]
-    permute = (
-        [0]
-        + sum(
-            [list(p) for p in zip(range(1, 1 + num_dims), range(1 + num_dims, 1 + 2 * num_dims))],
-            [],
-        )
-        + [len(hidden_states.shape) - 1]
-    )
-    hidden_states = hidden_states.permute(permute).reshape(batch_size, *shape, hidden_size)
-
-    return hidden_states
-
-
-class HieraEncoder(nn.Module):
-    def __init__(self, config: HieraConfig) -> None:
-        super().__init__()
-        self.config = config
-
-        # stochastic depth decay rule
-        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
-        # query strides rule
-        stage_ends = [sum(config.depths[:i]) - 1 for i in range(1, len(config.depths) + 1)]
-        query_pool_layer = [stage_end + 1 for stage_end in stage_ends[: config.num_query_pool]]
-        query_strides = [
-            math.prod(config.query_stride) if i in query_pool_layer else 1 for i in range(sum(config.depths))
-        ]
-
-        # Transformer blocks
-        self.stages = nn.ModuleList()
-        embed_dim = config.embed_dim
-
-        for idx_stage, depth in enumerate(config.depths):
-            dim_out = int(config.embed_dim * config.embed_dim_multiplier**idx_stage)
-
-            stage = HieraStage(
-                config=config,
-                depth=depth,
-                dim=embed_dim,
-                dim_out=dim_out,
-                num_heads=int(config.initial_num_heads * config.num_head_multiplier**idx_stage),
-                drop_path=dpr[sum(config.depths[:idx_stage]) : sum(config.depths[: idx_stage + 1])],
-                query_stride=query_strides[sum(config.depths[:idx_stage]) : sum(config.depths[: idx_stage + 1])],
-                window_size=int(math.prod(config.masked_unit_size) * math.prod(config.query_stride) ** -idx_stage),
-                use_mask_unit_attn=config.masked_unit_attention[idx_stage],
-                stage_num=idx_stage,
-            )
-
-            embed_dim = dim_out
-            self.stages.append(stage)
-
-        # Setting reroll schedule
-        # The first stage has to reverse everything
-        # The next stage has to reverse all but the first unroll, etc.
-        stage_size = [i // s for i, s in zip(config.input_size, config.patch_stride)]
-        unroll_schedule = [config.query_stride] * len(config.depths[:-1])
-
-        self.schedule = {}
-        for idx_stage in range(len(config.depths)):
-            self.schedule[idx_stage] = unroll_schedule, stage_size
-            if idx_stage < config.num_query_pool:
-                stage_size = [i // s for i, s in zip(stage_size, config.query_stride)]
-                unroll_schedule = unroll_schedule[1:]
-
-        self.gradient_checkpointing = False
-
-    def reroll(
-        self, hidden_states: torch.Tensor, stage_idx: int, mask: Optional[torch.BoolTensor] = None
-    ) -> torch.Tensor:
-        """
-        Roll the given tensor back up to spatial order assuming it's from the given block.
-
-        If no mask is provided returns:
-            - [batch_size, height, width, hidden_size] for 2d
-            - [batch_size, frames, height, width, hidden_size] for 3d
-        If a mask is provided returns:
-            - [batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size] for 2d
-        """
-        schedule, size = self.schedule[stage_idx]
-        batch_size, seq_len, hidden_size = hidden_states.shape
-
-        num_dim = len(size)
-        mask_unit_shape = [1] * num_dim
-
-        for strides in schedule:
-            # Extract the current patch from seq_len
-            hidden_states = hidden_states.view(
-                batch_size, *strides, seq_len // math.prod(strides), *mask_unit_shape, hidden_size
-            )
-
-            # Move that patch into the current MU
-            # Example in 2d:
-            # Input: [batch_size, stride, stride, seq_len//(stride*stride), mask_unit_height, mask_unit_width, hidden_size]
-            # Output: [batch_size, seq_len//(stride*stride), stride, mask_unit_height, stride, mask_unit_width, hidden_size]
-            L = len(hidden_states.shape)
-            permute = (
-                [0, 1 + num_dim]
-                + sum(
-                    [list(p) for p in zip(range(1, 1 + num_dim), range(1 + num_dim + 1, L - 1))],
-                    [],
-                )
-                + [L - 1]
-            )
-            hidden_states = hidden_states.permute(permute)
-
-            # Reshape to [batch_size, seq_len//(stride*stride), *mask_units, hidden_size]
-            for i in range(num_dim):
-                mask_unit_shape[i] *= strides[i]
-            hidden_states = hidden_states.reshape(batch_size, -1, *mask_unit_shape, hidden_size)
-            seq_len = hidden_states.shape[1]
-
-        # Current shape (e.g., 2d: [batch_size, #num_mask_units_height*#num_mask_units_width, mask_unit_height, mask_unit_width, hidden_size])
-        hidden_states = hidden_states.view(batch_size, seq_len, *mask_unit_shape, hidden_size)
-
-        # If masked, return [batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size]
-        if mask is not None:
-            return hidden_states
-
-        # If not masked, we can return [batch_size, height, width, hidden_size]
-        hidden_states = undo_windowing(hidden_states, size, mask_unit_shape)
-
-        return hidden_states
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        mask: Optional[torch.BoolTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        output_attentions: bool = False,
-        output_hidden_states: bool = False,
-        return_dict: bool = True,
-    ) -> Union[tuple, BaseModelOutput]:
-        all_hidden_states = () if output_hidden_states else None
-        all_reshaped_hidden_states = () if output_hidden_states else None
-        all_self_attentions = () if output_attentions else None
-
-        if output_hidden_states:
-            all_hidden_states = all_hidden_states + (hidden_states,)
-            reshaped_hidden_states = self.reroll(hidden_states, stage_idx=0, mask=mask)
-            all_reshaped_hidden_states = all_reshaped_hidden_states + (reshaped_hidden_states,)
-
-        for i, stage_module in enumerate(self.stages):
-            layer_head_mask = head_mask[i] if head_mask is not None else None
-
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    stage_module.__call__, hidden_states, layer_head_mask, output_attentions
-                )
-            else:
-                layer_outputs = stage_module(hidden_states, layer_head_mask, output_attentions)
-
-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attentions = all_self_attentions + (layer_outputs[1],)
-
-            if output_hidden_states:
-                all_hidden_states = all_hidden_states + (hidden_states,)
-                reshaped_hidden_states = self.reroll(hidden_states, stage_idx=i, mask=mask)
-                all_reshaped_hidden_states = all_reshaped_hidden_states + (reshaped_hidden_states,)
-
-        if not return_dict:
-            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
-        return HieraEncoderOutput(
-            last_hidden_state=hidden_states,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attentions,
-            reshaped_hidden_states=all_reshaped_hidden_states,
-        )
-
-
-def unroll(hidden_states: torch.Tensor, size: List[int], schedule: List[List[int]]) -> torch.Tensor:
-    """
-    Reorders the tokens such that patches are contiguous in memory.
-    E.g., given [batch_size, (height, width), hidden_size] and stride of (stride, stride), this will re-order the tokens as
-    [batch_size, (stride, stride, height // stride, width // stride), hidden_size]
-
-    This allows operations like Max2d to be computed as x.view(batch_size, stride*stride, -1, hidden_size).max(dim=1).
-    Not only is this faster, but it also makes it easy to support inputs of arbitrary
-    dimensions in addition to patch-wise sparsity.
-
-    Performing this operation multiple times in sequence puts entire windows as contiguous
-    in memory. For instance, if you applied the stride (2, 2) 3 times, entire windows of
-    size 8x8 would be contiguous in memory, allowing operations like mask unit attention
-    computed easily and efficiently, while also allowing max to be applied sequentially.
-
-    Note: This means that intermediate values of the model are not in height x width order, so they
-    need to be re-rolled if you want to use the intermediate values as a height x width feature map.
-    The last block of the network is fine though, since by then the strides are all consumed.
-    """
-    batch_size, _, hidden_size = hidden_states.shape
-
-    current_size = size
-    hidden_states = hidden_states.view(*([batch_size] + current_size + [hidden_size]))
-
-    for strides in schedule:
-        # Move patches with the given strides to the batch dimension
-
-        # Create a view of the tensor with the patch stride as separate dims
-        # For example in 2d: [batch_size, height // stride, stride, width // stride, stride, C]
-        current_size = [i // s for i, s in zip(current_size, strides)]
-        # initialize new_shape with [height // stride, stride, width // stride, stride]
-        new_shape = [item for pair in zip(current_size, strides) for item in pair]
-        # add batch_size and hidden_size to new_shape
-        new_shape = [batch_size] + new_shape + [hidden_size]
-        hidden_states = hidden_states.view(new_shape)
-
-        # Move the patch stride into the batch dimension
-        # For example in 2d: [batch_size, stride, stride, height // stride, width // stride, hidden_size]
-        num_dims = len(new_shape)
-        permute = [0] + list(range(2, num_dims - 1, 2)) + list(range(1, num_dims - 1, 2)) + [num_dims - 1]
-        hidden_states = hidden_states.permute(permute)
-
-        # Now finally flatten the relevant dims into the batch dimension
-        hidden_states = hidden_states.flatten(0, len(strides))
-        batch_size *= math.prod(strides)
-
-    hidden_states = hidden_states.reshape(-1, math.prod(size), hidden_size)
-    return hidden_states
-
-
-class HieraPreTrainedModel(PreTrainedModel):
-    """
-    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
-    models.
-    """
-
-    config_class = HieraConfig
-    base_model_prefix = "hiera"
-    main_input_name = "pixel_values"
-    supports_gradient_checkpointing = True
-
-    def _init_weights(self, module) -> None:
-        """Initialize the weights"""
-        std = self.config.initializer_range
-
-        if isinstance(module, HieraEmbeddings):
-            if self.config.sep_pos_embed:
-                nn.init.trunc_normal_(module.position_embeddings_spatial, std=std)
-                nn.init.trunc_normal_(module.position_embeddings_temporal, std=std)
-            else:
-                nn.init.trunc_normal_(module.position_embeddings, std=std)
-
-        elif isinstance(module, HieraDecoder):
-            nn.init.trunc_normal_(module.mask_token, std=std)
-            nn.init.trunc_normal_(module.decoder_position_embeddings, std=std)
-
-        elif isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
-            nn.init.trunc_normal_(module.weight, std=std)
-            if module.bias is not None:
-                nn.init.constant_(module.bias, std)
-
-        elif isinstance(module, nn.LayerNorm):
-            nn.init.constant_(module.bias, std)
-            nn.init.constant_(module.weight, self.config.layer_norm_init)
-
-
-HIERA_START_DOCSTRING = r"""
-    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
-    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
-    behavior.
-
-    Parameters:
-        config ([`HieraConfig`]): Model configuration class with all the parameters of the model.
-            Initializing with a config file does not load the weights associated with the model, only the
-            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-HIERA_INPUTS_DOCSTRING = r"""
-    Args:
-        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
-            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`BitImageProcessor.__call__`]
-            for details.
-
-        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
-            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-
-        output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
-            tensors for more detail.
-        output_hidden_states (`bool`, *optional*):
-            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-            more detail.
-        interpolate_pos_encoding (`bool`, *optional*):
-            Whether to interpolate the pre-trained position encodings.
-        return_dict (`bool`, *optional*):
-            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-"""
-
-
-class HieraPooler(nn.Module):
-    def __init__(self, config: HieraConfig):
-        super().__init__()
-        num_features = int(config.embed_dim * config.embed_dim_multiplier ** (len(config.depths) - 1))
-        self.layernorm = nn.LayerNorm(num_features, eps=config.layer_norm_eps)
-        self.pooler = nn.AdaptiveAvgPool1d(1)
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        hidden_states = hidden_states.transpose(1, 2)
-        pooled_output = self.pooler(hidden_states)
-        pooled_output = torch.flatten(pooled_output, 1)
-        pooled_output = self.layernorm(pooled_output)
-        return pooled_output
-
-
-@add_start_docstrings(
-    "The bare Hiera Model transformer outputting raw hidden-states without any specific head on top.",
-    HIERA_START_DOCSTRING,
-    """
-        add_pooling_layer (`bool`, *optional*, defaults to `True`):
-            Whether or not to apply pooling layer.
-        is_mae (`bool`, *optional*, defaults to `False`):
-            Whether or not to run the model on MAE mode.
|
1250 |
-
)
|
1251 |
-
class HieraModel(HieraPreTrainedModel):
|
1252 |
-
def __init__(self, config: HieraConfig, add_pooling_layer: bool = True, is_mae: bool = False):
|
1253 |
-
super().__init__(config)
|
1254 |
-
self.num_features = int(config.embed_dim * config.embed_dim_multiplier ** (len(config.depths) - 1))
|
1255 |
-
|
1256 |
-
self.embeddings = HieraEmbeddings(config, is_mae=is_mae)
|
1257 |
-
self.encoder = HieraEncoder(config)
|
1258 |
-
|
1259 |
-
self.unroll_size = [i // s for i, s in zip(config.input_size, config.patch_stride)]
|
1260 |
-
self.unroll_schedule = [config.query_stride] * len(config.depths[:-1])
|
1261 |
-
|
1262 |
-
self.pooler = HieraPooler(config) if add_pooling_layer else None
|
1263 |
-
|
1264 |
-
# Initialize weights and apply final processing
|
1265 |
-
self.post_init()
|
1266 |
-
|
1267 |
-
def get_input_embeddings(self) -> HieraPatchEmbeddings:
|
1268 |
-
return self.embeddings.patch_embeddings
|
1269 |
-
|
1270 |
-
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
|
1271 |
-
"""
|
1272 |
-
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
1273 |
-
class PreTrainedModel
|
1274 |
-
"""
|
1275 |
-
for layer, heads in heads_to_prune.items():
|
1276 |
-
self.encoder.layer[layer].attention.prune_heads(heads)
|
1277 |
-
|
1278 |
-
@add_start_docstrings_to_model_forward(HIERA_INPUTS_DOCSTRING)
|
1279 |
-
@add_code_sample_docstrings(
|
1280 |
-
checkpoint=_CHECKPOINT_FOR_DOC,
|
1281 |
-
output_type=HieraModelOutput,
|
1282 |
-
config_class=_CONFIG_FOR_DOC,
|
1283 |
-
modality="vision",
|
1284 |
-
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
1285 |
-
)
|
1286 |
-
def forward(
|
1287 |
-
self,
|
1288 |
-
pixel_values: Optional[torch.Tensor] = None,
|
1289 |
-
noise: Optional[torch.FloatTensor] = None,
|
1290 |
-
head_mask: Optional[torch.Tensor] = None,
|
1291 |
-
output_attentions: Optional[bool] = None,
|
1292 |
-
output_hidden_states: Optional[bool] = None,
|
1293 |
-
interpolate_pos_encoding: Optional[bool] = None,
|
1294 |
-
return_dict: Optional[bool] = None,
|
1295 |
-
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
1296 |
-
r"""
|
1297 |
-
noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*):
|
1298 |
-
Noise tensor mainly used for testing, to control randomness and maintain
|
1299 |
-
reproducibility when `is_mae` is set to `True`.
|
1300 |
-
"""
|
1301 |
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1302 |
-
output_hidden_states = (
|
1303 |
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1304 |
-
)
|
1305 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1306 |
-
|
1307 |
-
if pixel_values is None:
|
1308 |
-
raise ValueError("You have to specify pixel_values")
|
1309 |
-
|
1310 |
-
# Prepare head mask if needed
|
1311 |
-
# 1.0 in head_mask indicate we keep the head
|
1312 |
-
# attention_probs has shape bsz x n_heads x N x N
|
1313 |
-
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
1314 |
-
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
1315 |
-
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
|
1316 |
-
|
1317 |
-
# TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
|
1318 |
-
expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
|
1319 |
-
if pixel_values.dtype != expected_dtype:
|
1320 |
-
pixel_values = pixel_values.to(expected_dtype)
|
1321 |
-
|
1322 |
-
embedding_output, mask, ids_restore = self.embeddings(
|
1323 |
-
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, noise=noise
|
1324 |
-
)
|
1325 |
-
|
1326 |
-
hidden_states = unroll(embedding_output, self.unroll_size, self.unroll_schedule)
|
1327 |
-
|
1328 |
-
# Discard masked tokens if mask is provided
|
1329 |
-
if mask is not None:
|
1330 |
-
mask_unit_area = math.prod(self.config.masked_unit_size)
|
1331 |
-
batch_size, _, hidden_size = hidden_states.shape
|
1332 |
-
positions = mask.unsqueeze(-1).tile(1, mask_unit_area, hidden_size)
|
1333 |
-
positions = positions.bool()
|
1334 |
-
hidden_states = hidden_states[positions]
|
1335 |
-
hidden_states = hidden_states.view(batch_size, -1, hidden_size)
|
1336 |
-
|
1337 |
-
encoder_outputs = self.encoder(
|
1338 |
-
hidden_states,
|
1339 |
-
mask=mask,
|
1340 |
-
head_mask=head_mask,
|
1341 |
-
output_attentions=output_attentions,
|
1342 |
-
output_hidden_states=output_hidden_states,
|
1343 |
-
return_dict=return_dict,
|
1344 |
-
)
|
1345 |
-
sequence_output = encoder_outputs[0]
|
1346 |
-
pooled_output = None
|
1347 |
-
if self.pooler is not None:
|
1348 |
-
pooled_output = self.pooler(sequence_output)
|
1349 |
-
|
1350 |
-
if not return_dict:
|
1351 |
-
head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
|
1352 |
-
head_outputs = head_outputs + (mask, ids_restore) if mask is not None else head_outputs
|
1353 |
-
return head_outputs + encoder_outputs[1:]
|
1354 |
-
|
1355 |
-
return HieraModelOutput(
|
1356 |
-
last_hidden_state=sequence_output,
|
1357 |
-
pooler_output=pooled_output,
|
1358 |
-
mask=mask,
|
1359 |
-
ids_restore=ids_restore,
|
1360 |
-
hidden_states=encoder_outputs.hidden_states,
|
1361 |
-
attentions=encoder_outputs.attentions,
|
1362 |
-
reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
|
1363 |
-
)
|
1364 |
-
|
1365 |
-
|
1366 |
-
class HieraDecoder(nn.Module):
|
1367 |
-
def __init__(self, config: HieraConfig):
|
1368 |
-
super().__init__()
|
1369 |
-
num_features = int(config.embed_dim * config.embed_dim_multiplier ** (len(config.depths) - 1))
|
1370 |
-
self.tokens_spatial_shape = [i // s for i, s in zip(config.input_size, config.patch_stride)]
|
1371 |
-
self.tokens_spatial_shape_final = [
|
1372 |
-
i // s ** (config.num_query_pool) for i, s in zip(self.tokens_spatial_shape, config.query_stride)
|
1373 |
-
]
|
1374 |
-
self.mask_unit_spatial_shape_final = [
|
1375 |
-
i // s ** (config.num_query_pool) for i, s in zip(config.masked_unit_size, config.query_stride)
|
1376 |
-
]
|
1377 |
-
|
1378 |
-
self.decoder_embeddings = nn.Linear(num_features, config.decoder_embed_dim)
|
1379 |
-
|
1380 |
-
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_embed_dim))
|
1381 |
-
|
1382 |
-
self.decoder_position_embeddings = nn.Parameter(
|
1383 |
-
torch.zeros(1, math.prod(self.tokens_spatial_shape_final), config.decoder_embed_dim)
|
1384 |
-
)
|
1385 |
-
|
1386 |
-
self.decoder_block = HieraStage(
|
1387 |
-
config=config,
|
1388 |
-
dim=config.decoder_embed_dim,
|
1389 |
-
dim_out=config.decoder_embed_dim,
|
1390 |
-
num_heads=config.decoder_num_heads,
|
1391 |
-
depth=config.decoder_depth,
|
1392 |
-
use_mask_unit_attn=False,
|
1393 |
-
drop_path=[0.0] * config.decoder_depth,
|
1394 |
-
query_stride=[1] * config.decoder_depth,
|
1395 |
-
window_size=0,
|
1396 |
-
)
|
1397 |
-
|
1398 |
-
self.decoder_norm = nn.LayerNorm(config.decoder_embed_dim, eps=config.layer_norm_eps)
|
1399 |
-
|
1400 |
-
# patch stride of prediction
|
1401 |
-
self.pred_stride = config.patch_stride[-1] * (config.query_stride[-1] ** config.num_query_pool)
|
1402 |
-
pred_dim = (self.pred_stride ** len(config.query_stride)) * config.num_channels
|
1403 |
-
|
1404 |
-
self.decoder_pred = nn.Linear(config.decoder_embed_dim, pred_dim)
|
1405 |
-
|
1406 |
-
def forward(
|
1407 |
-
self,
|
1408 |
-
encoder_hidden_states: torch.Tensor,
|
1409 |
-
mask: torch.BoolTensor,
|
1410 |
-
head_mask: Optional[torch.Tensor] = None,
|
1411 |
-
output_attentions: bool = False,
|
1412 |
-
) -> torch.Tensor:
|
1413 |
-
# Embed tokens
|
1414 |
-
hidden_states = self.decoder_embeddings(encoder_hidden_states)
|
1415 |
-
|
1416 |
-
# Combine visible and mask tokens
|
1417 |
-
|
1418 |
-
# hidden_states : [batch_size, num_mask_units_visible, *mask_unit_spatial_shape_final, decoder_embed_dim]
|
1419 |
-
# mask: [batch_size, num_mask_units]
|
1420 |
-
decoder_hidden_states = torch.zeros(
|
1421 |
-
*mask.shape, *hidden_states.shape[2:], device=hidden_states.device, dtype=hidden_states.dtype
|
1422 |
-
)
|
1423 |
-
mask_tokens = self.mask_token.view((1,) * (len(mask.shape) + len(hidden_states.shape[2:-1])) + (-1,))
|
1424 |
-
new_mask_shape = mask.shape + (1,) * len(hidden_states.shape[2:])
|
1425 |
-
mask = mask.reshape(new_mask_shape)
|
1426 |
-
expand_shape = (-1,) * 2 + hidden_states.shape[2:]
|
1427 |
-
mask = mask.expand(expand_shape)
|
1428 |
-
decoder_hidden_states[mask.bool()] = hidden_states.flatten()
|
1429 |
-
decoder_hidden_states = (1 - mask) * mask_tokens + mask * decoder_hidden_states
|
1430 |
-
|
1431 |
-
# Get back spatial order
|
1432 |
-
hidden_states = undo_windowing(
|
1433 |
-
decoder_hidden_states,
|
1434 |
-
self.tokens_spatial_shape_final,
|
1435 |
-
self.mask_unit_spatial_shape_final,
|
1436 |
-
)
|
1437 |
-
mask = undo_windowing(
|
1438 |
-
mask[..., 0:1],
|
1439 |
-
self.tokens_spatial_shape_final,
|
1440 |
-
self.mask_unit_spatial_shape_final,
|
1441 |
-
)
|
1442 |
-
|
1443 |
-
# Flatten
|
1444 |
-
hidden_states = hidden_states.reshape(hidden_states.shape[0], -1, hidden_states.shape[-1])
|
1445 |
-
mask = mask.view(hidden_states.shape[0], -1)
|
1446 |
-
|
1447 |
-
# Add pos embed
|
1448 |
-
hidden_states = hidden_states + self.decoder_position_embeddings
|
1449 |
-
|
1450 |
-
# Apply decoder blocks
|
1451 |
-
hidden_states, attn_weights = self.decoder_block(
|
1452 |
-
hidden_states, head_mask=head_mask, output_attentions=output_attentions
|
1453 |
-
)
|
1454 |
-
hidden_states = self.decoder_norm(hidden_states)
|
1455 |
-
|
1456 |
-
# Predictor projection
|
1457 |
-
hidden_states = self.decoder_pred(hidden_states)
|
1458 |
-
|
1459 |
-
return hidden_states, mask
|
1460 |
-
|
1461 |
-
|
1462 |
-
class HieraMultiScaleHead(nn.Module):
|
1463 |
-
def __init__(self, config: HieraConfig):
|
1464 |
-
super().__init__()
|
1465 |
-
self.mask_unit_spatial_shape_final = [
|
1466 |
-
i // s ** (config.num_query_pool) for i, s in zip(config.masked_unit_size, config.query_stride)
|
1467 |
-
]
|
1468 |
-
self.stage_dimensions = [
|
1469 |
-
int(config.embed_dim * config.embed_dim_multiplier**i) for i in range(len(config.depths))
|
1470 |
-
]
|
1471 |
-
current_masked_unit_size = config.masked_unit_size
|
1472 |
-
self.multi_scale_fusion_heads = nn.ModuleList()
|
1473 |
-
|
1474 |
-
for idx in range(config.num_query_pool):
|
1475 |
-
kernel = [i // s for i, s in zip(current_masked_unit_size, self.mask_unit_spatial_shape_final)]
|
1476 |
-
current_masked_unit_size = [i // s for i, s in zip(current_masked_unit_size, config.query_stride)]
|
1477 |
-
self.multi_scale_fusion_heads.append(
|
1478 |
-
conv_nd(len(config.query_stride))(
|
1479 |
-
self.stage_dimensions[idx],
|
1480 |
-
self.stage_dimensions[-1],
|
1481 |
-
kernel_size=kernel,
|
1482 |
-
stride=kernel,
|
1483 |
-
)
|
1484 |
-
)
|
1485 |
-
self.multi_scale_fusion_heads.append(nn.Identity())
|
1486 |
-
|
1487 |
-
def apply_fusion_head(self, head: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
|
1488 |
-
if isinstance(head, nn.Identity):
|
1489 |
-
return hidden_states
|
1490 |
-
|
1491 |
-
batch_size, num_mask_units = hidden_states.shape[0:2]
|
1492 |
-
# From: [batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size]
|
1493 |
-
# To: head([batch_size * num_mask_units, hidden_size, mask_unit_height, mask_unit_width])
|
1494 |
-
permute = [0] + [len(hidden_states.shape) - 2] + list(range(1, len(hidden_states.shape) - 2))
|
1495 |
-
hidden_states = hidden_states.reshape(batch_size * num_mask_units, *hidden_states.shape[2:])
|
1496 |
-
hidden_states = hidden_states.permute(permute)
|
1497 |
-
hidden_states = head(hidden_states)
|
1498 |
-
|
1499 |
-
# Restore original layout
|
1500 |
-
permute = [0] + list(range(2, len(hidden_states.shape))) + [1]
|
1501 |
-
hidden_states = hidden_states.permute(permute)
|
1502 |
-
hidden_states = hidden_states.reshape(
|
1503 |
-
batch_size, num_mask_units, *hidden_states.shape[1:-1], hidden_states.shape[-1]
|
1504 |
-
)
|
1505 |
-
return hidden_states
|
1506 |
-
|
1507 |
-
def forward(self, feature_maps: List[torch.Tensor]) -> torch.Tensor:
|
1508 |
-
# Multi-scale fusion
|
1509 |
-
hidden_states = 0.0
|
1510 |
-
for head, feature_map in zip(self.multi_scale_fusion_heads, feature_maps):
|
1511 |
-
hidden_states = hidden_states + self.apply_fusion_head(head, feature_map)
|
1512 |
-
|
1513 |
-
return hidden_states
|
1514 |
-
|
1515 |
-
|
1516 |
-
@add_start_docstrings(
|
1517 |
-
"""The Hiera Model transformer with the decoder on top for self-supervised pre-training.
|
1518 |
-
|
1519 |
-
<Tip>
|
1520 |
-
|
1521 |
-
Note that we provide a script to pre-train this model on custom data in our [examples
|
1522 |
-
directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
|
1523 |
-
|
1524 |
-
</Tip>
|
1525 |
-
""",
|
1526 |
-
HIERA_START_DOCSTRING,
|
1527 |
-
)
|
1528 |
-
class HieraForPreTraining(HieraPreTrainedModel):
|
1529 |
-
def __init__(self, config: HieraConfig) -> None:
|
1530 |
-
super().__init__(config)
|
1531 |
-
# Encoder
|
1532 |
-
self.hiera = HieraModel(config, add_pooling_layer=False, is_mae=True)
|
1533 |
-
self.encoder_norm = nn.LayerNorm(self.hiera.num_features, eps=config.layer_norm_eps)
|
1534 |
-
# Multi-scale fusion heads
|
1535 |
-
self.multiscale_fusion = HieraMultiScaleHead(config)
|
1536 |
-
# Decoder
|
1537 |
-
self.decoder = HieraDecoder(config)
|
1538 |
-
self.pred_stride = self.decoder.pred_stride
|
1539 |
-
|
1540 |
-
# Initialize weights and apply final processing
|
1541 |
-
self.post_init()
|
1542 |
-
|
1543 |
-
def get_pixel_label_2d(self, pixel_values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
1544 |
-
# mask (boolean tensor): True means *masked*
|
1545 |
-
pixel_values = pixel_values.permute(0, 2, 3, 1)
|
1546 |
-
|
1547 |
-
size = self.pred_stride
|
1548 |
-
label = pixel_values.unfold(1, size, size).unfold(2, size, size)
|
1549 |
-
label = label.flatten(1, 2).flatten(2)
|
1550 |
-
label = label[mask.bool()]
|
1551 |
-
if self.config.norm_pix_loss:
|
1552 |
-
mean = label.mean(dim=-1, keepdim=True)
|
1553 |
-
var = label.var(dim=-1, keepdim=True)
|
1554 |
-
label = (label - mean) / (var + 1.0e-6) ** 0.5
|
1555 |
-
|
1556 |
-
return label
|
1557 |
-
|
1558 |
-
def get_pixel_label_3d(self, pixel_values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
1559 |
-
# mask (boolean tensor): True means *masked*
|
1560 |
-
pixel_values = pixel_values[:, :, :: self.patch_stride[0], :, :]
|
1561 |
-
|
1562 |
-
size = self.pred_stride
|
1563 |
-
label = pixel_values.unfold(3, size, size).unfold(4, size, size)
|
1564 |
-
# Different from 2D
|
1565 |
-
label = label.permute(0, 2, 3, 4, 5, 6, 1)
|
1566 |
-
label = label.flatten(1, 3).flatten(2)
|
1567 |
-
label = label[mask.bool()]
|
1568 |
-
if self.config.norm_pix_loss:
|
1569 |
-
mean = label.mean(dim=-1, keepdim=True)
|
1570 |
-
var = label.var(dim=-1, keepdim=True)
|
1571 |
-
label = (label - mean) / (var + 1.0e-6) ** 0.5
|
1572 |
-
|
1573 |
-
return label
|
1574 |
-
|
1575 |
-
def forward_loss(self, pixel_values: torch.Tensor, logits: torch.Tensor, mask: torch.BoolTensor):
|
1576 |
-
# We invert the mask such that 1.0 is *masked*
|
1577 |
-
mask = 1 - mask
|
1578 |
-
if len(self.config.query_stride) == 2:
|
1579 |
-
label = self.get_pixel_label_2d(pixel_values, mask)
|
1580 |
-
elif len(self.config.query_stride) == 3:
|
1581 |
-
label = self.get_pixel_label_3d(pixel_values, mask)
|
1582 |
-
else:
|
1583 |
-
raise NotImplementedError("Only images and videos are supported")
|
1584 |
-
|
1585 |
-
logits = logits[mask.bool()]
|
1586 |
-
loss = (logits - label) ** 2
|
1587 |
-
loss = loss.mean()
|
1588 |
-
|
1589 |
-
return loss
|
1590 |
-
|
1591 |
-
@add_start_docstrings_to_model_forward(HIERA_INPUTS_DOCSTRING)
|
1592 |
-
@replace_return_docstrings(output_type=HieraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
1593 |
-
def forward(
|
1594 |
-
self,
|
1595 |
-
pixel_values: Optional[torch.Tensor] = None,
|
1596 |
-
noise: Optional[torch.FloatTensor] = None,
|
1597 |
-
head_mask: Optional[torch.Tensor] = None,
|
1598 |
-
output_attentions: Optional[bool] = None,
|
1599 |
-
output_hidden_states: Optional[bool] = None,
|
1600 |
-
interpolate_pos_encoding: Optional[bool] = None,
|
1601 |
-
return_dict: Optional[bool] = None,
|
1602 |
-
) -> Union[tuple, HieraForPreTrainingOutput]:
|
1603 |
-
r"""
|
1604 |
-
noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*):
|
1605 |
-
Noise tensor mainly used for testing, to control randomness and maintain
|
1606 |
-
reproducibility when `is_mae` is set to `True`.
|
1607 |
-
|
1608 |
-
Returns:
|
1609 |
-
|
1610 |
-
Examples:
|
1611 |
-
```python
|
1612 |
-
>>> from transformers import AutoImageProcessor, HieraForPreTraining
|
1613 |
-
>>> import torch
|
1614 |
-
>>> from PIL import Image
|
1615 |
-
>>> import requests
|
1616 |
-
|
1617 |
-
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
1618 |
-
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1619 |
-
|
1620 |
-
>>> image_processor = AutoImageProcessor.from_pretrained("EduardoPacheco/hiera-tiny-224-mae")
|
1621 |
-
>>> model = HieraForPreTraining.from_pretrained("EduardoPacheco/hiera-tiny-224-mae")
|
1622 |
-
|
1623 |
-
>>> inputs = image_processor(images=image, return_tensors="pt")
|
1624 |
-
|
1625 |
-
>>> outputs = model(**inputs)
|
1626 |
-
>>> logits = outputs.logits
|
1627 |
-
>>> list(logits.shape)
|
1628 |
-
[1, 196, 768]
|
1629 |
-
```"""
|
1630 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1631 |
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1632 |
-
output_hidden_states = (
|
1633 |
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1634 |
-
)
|
1635 |
-
|
1636 |
-
outputs = self.hiera(
|
1637 |
-
pixel_values,
|
1638 |
-
noise=noise,
|
1639 |
-
head_mask=head_mask,
|
1640 |
-
output_attentions=output_attentions,
|
1641 |
-
output_hidden_states=True,
|
1642 |
-
interpolate_pos_encoding=interpolate_pos_encoding,
|
1643 |
-
return_dict=True,
|
1644 |
-
)
|
1645 |
-
|
1646 |
-
feature_maps = outputs.reshaped_hidden_states
|
1647 |
-
mask = outputs.mask
|
1648 |
-
ids_to_restore = outputs.ids_restore
|
1649 |
-
# Take only the query pooled and last hidden states
|
1650 |
-
feature_maps = feature_maps[1 : self.hiera.config.num_query_pool + 1] + (feature_maps[-1],)
|
1651 |
-
fused_hidden_states = self.multiscale_fusion(feature_maps)
|
1652 |
-
fused_hidden_states = self.encoder_norm(fused_hidden_states)
|
1653 |
-
|
1654 |
-
# Reconstruct pixel values
|
1655 |
-
logits, mask = self.decoder(
|
1656 |
-
fused_hidden_states,
|
1657 |
-
mask=mask,
|
1658 |
-
head_mask=head_mask,
|
1659 |
-
output_attentions=output_attentions,
|
1660 |
-
)
|
1661 |
-
|
1662 |
-
loss = self.forward_loss(pixel_values, logits, mask)
|
1663 |
-
|
1664 |
-
if not return_dict:
|
1665 |
-
output = (logits, mask, ids_to_restore)
|
1666 |
-
if output_hidden_states:
|
1667 |
-
output = output + (outputs.hidden_states,)
|
1668 |
-
if output_attentions:
|
1669 |
-
output = output + (outputs.attentions,)
|
1670 |
-
if output_hidden_states:
|
1671 |
-
output = output + (outputs.reshaped_hidden_states,)
|
1672 |
-
return ((loss,) + output) if loss is not None else output
|
1673 |
-
|
1674 |
-
return HieraForPreTrainingOutput(
|
1675 |
-
loss=loss,
|
1676 |
-
logits=logits,
|
1677 |
-
mask=mask,
|
1678 |
-
ids_restore=ids_to_restore,
|
1679 |
-
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
1680 |
-
attentions=outputs.attentions,
|
1681 |
-
reshaped_hidden_states=outputs.reshaped_hidden_states if output_hidden_states else None,
|
1682 |
-
)
|
1683 |
-
|
1684 |
-
|
1685 |
-
@add_start_docstrings(
|
1686 |
-
"""
|
1687 |
-
Hiera Model transformer with an image classification head on top (a linear layer on top of the final hidden state with
|
1688 |
-
average pooling) e.g. for ImageNet.
|
1689 |
-
|
1690 |
-
<Tip>
|
1691 |
-
|
1692 |
-
Note that it's possible to fine-tune Hiera on higher resolution images than the ones it has been trained on, by
|
1693 |
-
setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
|
1694 |
-
position embeddings to the higher resolution.
|
1695 |
-
|
1696 |
-
</Tip>
|
1697 |
-
""",
|
1698 |
-
HIERA_START_DOCSTRING,
|
1699 |
-
)
|
1700 |
-
class HieraForImageClassification(HieraPreTrainedModel):
|
1701 |
-
def __init__(self, config: HieraConfig) -> None:
|
1702 |
-
super().__init__(config)
|
1703 |
-
|
1704 |
-
self.num_labels = config.num_labels
|
1705 |
-
self.hiera = HieraModel(config, add_pooling_layer=True, is_mae=False)
|
1706 |
-
|
1707 |
-
# Classifier head
|
1708 |
-
self.classifier = (
|
1709 |
-
nn.Linear(self.hiera.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
|
1710 |
-
)
|
1711 |
-
|
1712 |
-
# Initialize weights and apply final processing
|
1713 |
-
self.post_init()
|
1714 |
-
|
1715 |
-
@add_start_docstrings_to_model_forward(HIERA_INPUTS_DOCSTRING)
|
1716 |
-
@add_code_sample_docstrings(
|
1717 |
-
checkpoint=_IMAGE_CLASS_CHECKPOINT,
|
1718 |
-
output_type=HieraForImageClassificationOutput,
|
1719 |
-
config_class=_CONFIG_FOR_DOC,
|
1720 |
-
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
|
1721 |
-
)
|
1722 |
-
def forward(
|
1723 |
-
self,
|
1724 |
-
pixel_values: Optional[torch.Tensor] = None,
|
1725 |
-
head_mask: Optional[torch.Tensor] = None,
|
1726 |
-
labels: Optional[torch.Tensor] = None,
|
1727 |
-
output_attentions: Optional[bool] = None,
|
1728 |
-
output_hidden_states: Optional[bool] = None,
|
1729 |
-
interpolate_pos_encoding: Optional[bool] = None,
|
1730 |
-
return_dict: Optional[bool] = None,
|
1731 |
-
) -> Union[tuple, HieraForImageClassificationOutput]:
|
1732 |
-
r"""
|
1733 |
-
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1734 |
-
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
|
1735 |
-
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1736 |
-
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1737 |
-
"""
|
1738 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1739 |
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1740 |
-
output_hidden_states = (
|
1741 |
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1742 |
-
)
|
1743 |
-
|
1744 |
-
outputs = self.hiera(
|
1745 |
-
pixel_values,
|
1746 |
-
head_mask=head_mask,
|
1747 |
-
output_attentions=output_attentions,
|
1748 |
-
output_hidden_states=output_hidden_states,
|
1749 |
-
interpolate_pos_encoding=interpolate_pos_encoding,
|
1750 |
-
return_dict=return_dict,
|
1751 |
-
)
|
1752 |
-
|
1753 |
-
pooled_output = outputs[1]
|
1754 |
-
|
1755 |
-
logits = self.classifier(pooled_output)
|
1756 |
-
|
1757 |
-
loss = None
|
1758 |
-
if labels is not None:
|
1759 |
-
# move labels to correct device to enable model parallelism
|
1760 |
-
labels = labels.to(logits.device)
|
1761 |
-
if self.config.problem_type is None:
|
1762 |
-
if self.num_labels == 1:
|
1763 |
-
self.config.problem_type = "regression"
|
1764 |
-
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
1765 |
-
self.config.problem_type = "single_label_classification"
|
1766 |
-
else:
|
1767 |
-
self.config.problem_type = "multi_label_classification"
|
1768 |
-
|
1769 |
-
if self.config.problem_type == "regression":
|
1770 |
-
loss_fct = MSELoss()
|
1771 |
-
if self.num_labels == 1:
|
1772 |
-
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
1773 |
-
else:
|
1774 |
-
loss = loss_fct(logits, labels)
|
1775 |
-
elif self.config.problem_type == "single_label_classification":
|
1776 |
-
loss_fct = CrossEntropyLoss()
|
1777 |
-
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
1778 |
-
elif self.config.problem_type == "multi_label_classification":
|
1779 |
-
loss_fct = BCEWithLogitsLoss()
|
1780 |
-
loss = loss_fct(logits, labels)
|
1781 |
-
|
1782 |
-
if not return_dict:
|
1783 |
-
output = (logits,) + outputs[4:]
|
1784 |
-
return ((loss,) + output) if loss is not None else output
|
1785 |
-
|
1786 |
-
return HieraForImageClassificationOutput(
|
1787 |
-
loss=loss,
|
1788 |
-
logits=logits,
|
1789 |
-
hidden_states=outputs.hidden_states,
|
1790 |
-
attentions=outputs.attentions,
|
1791 |
-
reshaped_hidden_states=outputs.reshaped_hidden_states,
|
1792 |
-
)
|
1793 |
-
|
1794 |
-
|
1795 |
-
@add_start_docstrings(
|
1796 |
-
"""
|
1797 |
-
Hiera backbone, to be used with frameworks like DETR and MaskFormer.
|
1798 |
-
""",
|
1799 |
-
HIERA_START_DOCSTRING,
|
1800 |
-
)
|
1801 |
-
class HieraBackbone(HieraPreTrainedModel, BackboneMixin):
|
1802 |
-
def __init__(self, config: HieraConfig):
|
1803 |
-
super().__init__(config)
|
1804 |
-
super()._init_backbone(config)
|
1805 |
-
|
1806 |
-
self.num_features = [config.embed_dim] + [
|
1807 |
-
int(config.embed_dim * config.embed_dim_multiplier**i) for i in range(len(config.depths))
|
1808 |
-
]
|
1809 |
-
self.embeddings = HieraEmbeddings(config, is_mae=False)
|
1810 |
-
self.encoder = HieraEncoder(config)
|
1811 |
-
|
1812 |
-
# Add layer norms to hidden states of out_features
|
1813 |
-
hidden_states_norms = {}
|
1814 |
-
for stage, num_channels in zip(self._out_features, self.channels):
|
1815 |
-
hidden_states_norms[stage] = nn.LayerNorm(num_channels)
|
1816 |
-
self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
|
1817 |
-
|
1818 |
-
# Initialize weights and apply final processing
|
1819 |
-
self.post_init()
|
1820 |
-
|
1821 |
-
def get_input_embeddings(self):
|
1822 |
-
return self.embeddings.patch_embeddings
|
1823 |
-
|
1824 |
-
def forward(
|
1825 |
-
self,
|
1826 |
-
pixel_values: torch.Tensor,
|
1827 |
-
output_hidden_states: Optional[bool] = None,
|
1828 |
-
output_attentions: Optional[bool] = None,
|
1829 |
-
return_dict: Optional[bool] = None,
|
1830 |
-
) -> BackboneOutput:
|
1831 |
-
"""
|
1832 |
-
Returns:
|
1833 |
-
|
1834 |
-
Examples:
|
1835 |
-
|
1836 |
-
```python
|
1837 |
-
>>> from transformers import AutoImageProcessor, AutoBackbone
|
1838 |
-
>>> import torch
|
1839 |
-
>>> from PIL import Image
|
1840 |
-
>>> import requests
|
1841 |
-
|
1842 |
-
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
1843 |
-
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1844 |
-
|
1845 |
-
>>> processor = AutoImageProcessor.from_pretrained("EduardoPacheco/hiera-tiny-224")
|
1846 |
-
>>> model = AutoBackbone.from_pretrained(
|
1847 |
-
... "EduardoPacheco/hiera-tiny-224", out_features=["stage1", "stage2", "stage3", "stage4"]
|
1848 |
-
... )
|
1849 |
-
|
1850 |
-
>>> inputs = processor(image, return_tensors="pt")
|
1851 |
-
>>> outputs = model(**inputs)
|
1852 |
-
>>> feature_maps = outputs.feature_maps
|
1853 |
-
>>> list(feature_maps[-1].shape)
|
1854 |
-
[1, 768, 7, 7]
|
1855 |
-
```"""
|
1856 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1857 |
-
output_hidden_states = (
|
1858 |
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1859 |
-
)
|
1860 |
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1861 |
-
|
1862 |
-
embedding_output, _, _ = self.embeddings(pixel_values)
|
1863 |
-
|
1864 |
-
outputs = self.encoder(
|
1865 |
-
embedding_output,
|
1866 |
-
head_mask=None,
|
1867 |
-
output_attentions=output_attentions,
|
1868 |
-
output_hidden_states=True,
|
1869 |
-
return_dict=True,
|
1870 |
-
)
|
1871 |
-
|
1872 |
-
hidden_states = outputs.reshaped_hidden_states
|
1873 |
-
|
1874 |
-
feature_maps = ()
|
1875 |
-
for stage, hidden_state in zip(self.stage_names, hidden_states):
|
1876 |
-
if stage in self.out_features:
|
1877 |
-
batch_size, height, width, num_channels = hidden_state.shape
|
1878 |
-
hidden_state = hidden_state.view(batch_size, height * width, num_channels)
|
1879 |
-
hidden_state = self.hidden_states_norms[stage](hidden_state)
|
1880 |
-
hidden_state = hidden_state.view(batch_size, height, width, num_channels)
|
1881 |
-
hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
|
1882 |
-
feature_maps += (hidden_state,)
|
1883 |
-
|
1884 |
-
if not return_dict:
|
1885 |
-
output = (feature_maps,)
|
1886 |
-
if output_hidden_states:
|
1887 |
-
output += (outputs.hidden_states,)
|
1888 |
-
return output
|
1889 |
-
|
1890 |
-
return BackboneOutput(
|
1891 |
-
feature_maps=feature_maps,
|
1892 |
-
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
1893 |
-
attentions=outputs.attentions,
|
1894 |
-
)
|
1895 |
# %%
|
1896 |
|
1897 |
|
@@ -1910,8 +28,8 @@ class PytorchWorker:
|
|
1910 |
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
1911 |
print(f"Using devide: {self.device}")
|
1912 |
|
1913 |
-
image_processor = AutoImageProcessor.from_pretrained("./hiera_model/")
|
1914 |
-
model = HieraForImageClassification.from_pretrained("./hiera_model/", num_labels=1784).to(self.device).eval()
|
1915 |
|
1916 |
return model, image_processor
|
1917 |
|
@@ -1922,7 +40,7 @@ class PytorchWorker:
|
|
1922 |
:param image: Input image as numpy array.
|
1923 |
:return: A list with logits and confidences.
|
1924 |
"""
|
1925 |
-
inputs = self.image_processor(images=image, return_tensors="pt")
|
1926 |
outputs = self.model(**inputs)
|
1927 |
logits = outputs.logits
|
1928 |
return logits.tolist()
|
@@ -1968,7 +86,7 @@ if __name__ == "__main__":
|
|
1968 |
model_name=MODEL_NAME
|
1969 |
)
|
1970 |
|
1971 |
-
|
1972 |
# import requests
|
1973 |
# image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
1974 |
# # %%
|
@@ -1982,4 +100,4 @@ if __name__ == "__main__":
|
|
1982 |
# # %%
|
1983 |
# import numpy as np
|
1984 |
# np.argmax(output)
|
1985 |
-
#
|
|
|
8 |
from PIL import Image
|
9 |
import torch
|
10 |
from transformers import AutoImageProcessor
|
11 |
+
from submission.create_model import HieraForImageClassification
|
12 |
#%%
|
13 |
# %%
|
14 |
|
15 |
|
|
|
28 |
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
29 |
print(f"Using devide: {self.device}")
|
30 |
|
31 |
+
image_processor = AutoImageProcessor.from_pretrained("./submission/hiera_model/")
|
32 |
+
model = HieraForImageClassification.from_pretrained("./submission/hiera_model/", num_labels=1784).to(self.device).eval()
|
33 |
|
34 |
return model, image_processor
|
35 |
|
|
|
40 |
:param image: Input image as numpy array.
|
41 |
:return: A list with logits and confidences.
|
42 |
"""
|
43 |
+
inputs = self.image_processor(images=image, return_tensors="pt").to(self.device)
|
44 |
outputs = self.model(**inputs)
|
45 |
logits = outputs.logits
|
46 |
return logits.tolist()
|
|
|
86 |
model_name=MODEL_NAME
|
87 |
)
|
88 |
|
89 |
+
#%%
|
90 |
# import requests
|
91 |
# image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
92 |
# # %%
|
|
|
100 |
# # %%
|
101 |
# import numpy as np
|
102 |
# np.argmax(output)
|
103 |
+
# %%
|
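
For reference, here is a minimal sketch of how the updated inference path can be exercised end to end. It is based only on the lines added in this diff (the `./submission/hiera_model/` checkpoint directory, the `submission.create_model.HieraForImageClassification` import, `num_labels=1784`, and the `.to(self.device)` call) and on the commented-out test snippet above; the COCO image URL and the final argmax mirror that snippet, and no class-index-to-name mapping is shown because the script does not include one.

```python
# Minimal sketch, assuming the ./submission/hiera_model/ checkpoint and the
# submission.create_model module from this repository are importable.
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor

from submission.create_model import HieraForImageClassification

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Same checkpoint directory and label count as in PytorchWorker above.
image_processor = AutoImageProcessor.from_pretrained("./submission/hiera_model/")
model = (
    HieraForImageClassification.from_pretrained("./submission/hiera_model/", num_labels=1784)
    .to(device)
    .eval()
)

# Same test image as the commented-out snippet above.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

with torch.no_grad():
    inputs = image_processor(images=image, return_tensors="pt").to(device)
    logits = model(**inputs).logits

# Mirrors the commented-out np.argmax(output) check; prints the raw class index.
print(logits.argmax(-1).item())
```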