lucasjin commited on
Commit
052bc03
·
verified ·
1 Parent(s): fbe6ebd

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_aimv2.py +192 -0
modeling_aimv2.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ from .configuration_aimv2 import AIMv2Config
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+ from transformers.modeling_outputs import BaseModelOutputWithNoAttention
8
+ from transformers.modeling_utils import PreTrainedModel
9
+
10
+ __all__ = ["AIMv2Model"]
11
+
12
+
13
+ class RMSNorm(nn.Module):
14
+ def __init__(self, dim: int, eps: float = 1e-6):
15
+ super().__init__()
16
+ self.weight = nn.Parameter(torch.ones(dim))
17
+ self.eps = eps
18
+
19
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
20
+ output = self._norm(x.float()).type_as(x)
21
+ return output * self.weight
22
+
23
+ def extra_repr(self) -> str:
24
+ return f"{tuple(self.weight.shape)}, eps={self.eps}"
25
+
26
+ def _norm(self, x: torch.Tensor) -> torch.Tensor:
27
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
28
+
29
+
30
+ class AIMv2SwiGLUFFN(nn.Module):
31
+ def __init__(self, config: AIMv2Config):
32
+ super().__init__()
33
+ hidden_features = config.intermediate_size
34
+ in_features = config.hidden_size
35
+ bias = config.use_bias
36
+
37
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
38
+ self.fc2 = nn.Linear(hidden_features, in_features, bias=bias)
39
+ self.fc3 = nn.Linear(in_features, hidden_features, bias=bias)
40
+
41
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
42
+ x = F.silu(self.fc1(x)) * self.fc3(x)
43
+ x = self.fc2(x)
44
+ return x
45
+
46
+
47
+ class AIMv2PatchEmbed(nn.Module):
48
+ def __init__(self, config: AIMv2Config):
49
+ super().__init__()
50
+ self.proj = nn.Conv2d(
51
+ config.num_channels,
52
+ config.hidden_size,
53
+ kernel_size=(config.patch_size, config.patch_size),
54
+ stride=(config.patch_size, config.patch_size),
55
+ )
56
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
57
+
58
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
59
+ x = self.proj(x).flatten(2).transpose(1, 2)
60
+ x = self.norm(x)
61
+ return x
62
+
63
+
64
+ class AIMv2ViTPreprocessor(nn.Module):
65
+ def __init__(self, config: AIMv2Config):
66
+ super().__init__()
67
+ num_patches = (config.image_size // config.patch_size) ** 2
68
+
69
+ self.patchifier = AIMv2PatchEmbed(config)
70
+ self.pos_embed = nn.Parameter(torch.zeros((1, num_patches, config.hidden_size)))
71
+
72
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
73
+ tokens = self.patchifier(x)
74
+ _, N, _ = tokens.shape
75
+ pos_embed = self.pos_embed.to(tokens.device)
76
+ tokens = tokens + pos_embed[:, :N]
77
+ return tokens
78
+
79
+
80
+ class AIMv2Attention(nn.Module):
81
+ def __init__(self, config: AIMv2Config):
82
+ super().__init__()
83
+ dim = config.hidden_size
84
+
85
+ self.num_heads = config.num_attention_heads
86
+ self.qkv = nn.Linear(dim, dim * 3, bias=config.qkv_bias)
87
+ self.attn_drop = nn.Dropout(config.attention_dropout)
88
+ self.proj = nn.Linear(dim, dim, bias=config.use_bias)
89
+ self.proj_drop = nn.Dropout(config.projection_dropout)
90
+
91
+ def forward(
92
+ self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
93
+ ) -> torch.Tensor:
94
+ B, N, C = x.shape
95
+ qkv = (
96
+ self.qkv(x)
97
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
98
+ .permute(2, 0, 3, 1, 4)
99
+ )
100
+ q, k, v = qkv.unbind(0)
101
+
102
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
103
+ x = x.transpose(1, 2).contiguous().reshape(B, N, C)
104
+ x = self.proj(x)
105
+ x = self.proj_drop(x)
106
+ return x
107
+
108
+
109
+ class AIMv2Block(nn.Module):
110
+ def __init__(self, config: AIMv2Config):
111
+ super().__init__()
112
+ self.attn = AIMv2Attention(config)
113
+ self.norm_1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
114
+ self.mlp = AIMv2SwiGLUFFN(config)
115
+ self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
116
+
117
+ def forward(
118
+ self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
119
+ ) -> torch.Tensor:
120
+ x = x + self.attn(self.norm_1(x), mask)
121
+ x = x + self.mlp(self.norm_2(x))
122
+ return x
123
+
124
+
125
+ class AIMv2Transformer(nn.Module):
126
+ def __init__(self, config: AIMv2Config):
127
+ super().__init__()
128
+ self.blocks = nn.ModuleList(
129
+ [AIMv2Block(config) for _ in range(config.num_hidden_layers)]
130
+ )
131
+ self.post_trunk_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
132
+
133
+ def forward(
134
+ self,
135
+ tokens: torch.Tensor,
136
+ mask: Optional[torch.Tensor] = None,
137
+ output_hidden_states: bool = False,
138
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, ...]]]:
139
+ hidden_states = () if output_hidden_states else None
140
+ for block in self.blocks:
141
+ tokens = block(tokens, mask)
142
+ if output_hidden_states:
143
+ hidden_states += (tokens,)
144
+ tokens = self.post_trunk_norm(tokens)
145
+ return tokens, hidden_states
146
+
147
+
148
+ class AIMv2PretrainedModel(PreTrainedModel):
149
+ config_class = AIMv2Config
150
+ base_model_prefix = "aimv2"
151
+ main_input_name = "pixel_values"
152
+ _no_split_modules = ["AIMv2ViTPreprocessor", "AIMv2Block"]
153
+ _supports_sdpa = True
154
+
155
+
156
+ class AIMv2Model(AIMv2PretrainedModel):
157
+ def __init__(self, config: AIMv2Config):
158
+ super().__init__(config)
159
+ self.preprocessor = AIMv2ViTPreprocessor(config)
160
+ self.trunk = AIMv2Transformer(config)
161
+
162
+ def forward(
163
+ self,
164
+ pixel_values: torch.Tensor,
165
+ mask: Optional[torch.Tensor] = None,
166
+ output_hidden_states: Optional[bool] = None,
167
+ return_dict: Optional[bool] = None,
168
+ ) -> Union[
169
+ Tuple[torch.Tensor],
170
+ Tuple[torch.Tensor, Tuple[torch.Tensor, ...]],
171
+ BaseModelOutputWithNoAttention,
172
+ ]:
173
+ if output_hidden_states is None:
174
+ output_hidden_states = self.config.output_hidden_states
175
+ if return_dict is None:
176
+ return_dict = self.config.use_return_dict
177
+
178
+ x = self.preprocessor(pixel_values)
179
+ x, hidden_states = self.trunk(
180
+ x, mask, output_hidden_states=output_hidden_states
181
+ )
182
+
183
+ if not return_dict:
184
+ res = (x,)
185
+ res += (hidden_states,) if output_hidden_states else ()
186
+ return res
187
+
188
+ return BaseModelOutputWithNoAttention(
189
+ last_hidden_state=x,
190
+ hidden_states=hidden_states,
191
+ )
192
+