ClownRat committed
Commit 0e04069 · verified · 1 Parent(s): 592e852

Upload model
config.json ADDED
@@ -0,0 +1,21 @@
+ {
+   "architectures": [
+     "Videollama3VisionEncoderModel"
+   ],
+   "attention_dropout": 0.0,
+   "auto_map": {
+     "AutoConfig": "configuration_videollama3_encoder.Videollama3VisionEncoderConfig",
+     "AutoModel": "modeling_videollama3_encoder.Videollama3VisionEncoderModel"
+   },
+   "hidden_act": "gelu_pytorch_tanh",
+   "hidden_size": 1152,
+   "intermediate_size": 4304,
+   "layer_norm_eps": 1e-06,
+   "model_type": "videollama3_vision_encoder",
+   "num_attention_heads": 16,
+   "num_channels": 3,
+   "num_hidden_layers": 27,
+   "patch_size": 14,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.46.3"
+ }
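
Because config.json registers the custom classes through auto_map, the checkpoint can be loaded with the transformers Auto classes once remote code is trusted. A minimal loading sketch; "ClownRat/<repo_name>" is a placeholder for this repository's Hub id:

import torch
from transformers import AutoConfig, AutoModel

repo_id = "ClownRat/<repo_name>"  # placeholder: replace with this repository's Hub id
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, torch_dtype=torch.bfloat16, trust_remote_code=True)
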
configuration_videollama3_encoder.py ADDED
@@ -0,0 +1,33 @@
+ """VideoLLaMA3 vision encoder model configuration."""
+
+ from transformers import PretrainedConfig
+
+
+ class Videollama3VisionEncoderConfig(PretrainedConfig):
+
+     model_type = "videollama3_vision_encoder"
+
+     def __init__(
+         self,
+         hidden_size=768,
+         intermediate_size=3072,
+         num_hidden_layers=12,
+         num_attention_heads=12,
+         num_channels=3,
+         patch_size=16,
+         hidden_act="gelu_pytorch_tanh",
+         layer_norm_eps=1e-6,
+         attention_dropout=0.0,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.num_channels = num_channels
+         self.patch_size = patch_size
+         self.attention_dropout = attention_dropout
+         self.layer_norm_eps = layer_norm_eps
+         self.hidden_act = hidden_act
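
The defaults above are only fallbacks; the uploaded checkpoint overrides them through config.json. As a sketch, the same configuration can be built by hand with the values from the config.json in this commit (assuming the file above is importable from the working directory):

from configuration_videollama3_encoder import Videollama3VisionEncoderConfig

config = Videollama3VisionEncoderConfig(
    hidden_size=1152,
    intermediate_size=4304,
    num_hidden_layers=27,
    num_attention_heads=16,
    patch_size=14,
    hidden_act="gelu_pytorch_tanh",
    layer_norm_eps=1e-6,
)
print(config.model_type)  # videollama3_vision_encoder
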
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:758ae92931ff54c6d278664af3fed5a452f83a2e89f534ab2e3f4ac0c6e9c061
+ size 824342816
modeling_videollama3_encoder.py ADDED
@@ -0,0 +1,534 @@
+ # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py.
+ # Below is the original copyright:
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+ #
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+ # and OPT implementations in this library. It has been modified from its
+ # original forms to accommodate minor architectural differences compared
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """PyTorch VideoLLaMA3 vision encoder model."""
+
+ import importlib.util
+ import os.path as osp
+ import math
+ import warnings
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.utils.checkpoint
+ from torch.nn.init import _calculate_fan_in_and_fan_out
+
+ from transformers.activations import ACT2FN
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.utils import is_flash_attn_2_available
+
+ if is_flash_attn_2_available():
+     from flash_attn import flash_attn_varlen_func
+ else:
+     flash_attn_varlen_func = None
+
+ # Fall back to loading the configuration module by file path, so the model also works
+ # when downloaded as standalone remote code (no package-relative imports available).
+ try:
+     from .configuration_videollama3_encoder import Videollama3VisionEncoderConfig
+ except ImportError:
+     spec = importlib.util.spec_from_file_location(
+         "configuration_videollama3_encoder",
+         osp.join(osp.dirname(__file__), "configuration_videollama3_encoder.py"),
+     )
+     configuration_videollama3_encoder = importlib.util.module_from_spec(spec)
+     spec.loader.exec_module(configuration_videollama3_encoder)
+     Videollama3VisionEncoderConfig = getattr(
+         configuration_videollama3_encoder,
+         "Videollama3VisionEncoderConfig",
+     )
+
+
+ def _trunc_normal_(tensor, mean, std, a, b):
+     # Cut & paste from PyTorch official master until it's in a few official releases - RW
+     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+     def norm_cdf(x):
+         # Computes standard normal cumulative distribution function
+         return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+     if (mean < a - 2 * std) or (mean > b + 2 * std):
+         warnings.warn(
+             "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+             "The distribution of values may be incorrect.",
+             stacklevel=2,
+         )
+
+     # Values are generated by using a truncated uniform distribution and
+     # then using the inverse CDF for the normal distribution.
+     # Get upper and lower cdf values
+     l = norm_cdf((a - mean) / std)
+     u = norm_cdf((b - mean) / std)
+
+     # Uniformly fill tensor with values from [l, u], then translate to
+     # [2l-1, 2u-1].
+     tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+     # Use inverse cdf transform for normal distribution to get truncated
+     # standard normal
+     tensor.erfinv_()
+
+     # Transform to proper mean, std
+     tensor.mul_(std * math.sqrt(2.0))
+     tensor.add_(mean)
+
+     # Clamp to ensure it's in the proper range
+     tensor.clamp_(min=a, max=b)
+
+
+ def trunc_normal_tf_(
+     tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
+ ) -> torch.Tensor:
+     """Fills the input Tensor with values drawn from a truncated
+     normal distribution. The values are effectively drawn from the
+     normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`
+     with values outside :math:`[a, b]` redrawn until they are within
+     the bounds. The method used for generating the random values works
+     best when :math:`a \\leq \\text{mean} \\leq b`.
+
+     NOTE: this 'tf' variant behaves closer to the Tensorflow / JAX impl, where the
+     bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
+     and the result is subsequently scaled and shifted by the mean and std args.
+
+     Args:
+         tensor: an n-dimensional `torch.Tensor`
+         mean: the mean of the normal distribution
+         std: the standard deviation of the normal distribution
+         a: the minimum cutoff value
+         b: the maximum cutoff value
+     """
+     with torch.no_grad():
+         _trunc_normal_(tensor, 0, 1.0, a, b)
+         tensor.mul_(std).add_(mean)
+
+
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
+     fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+     if mode == "fan_in":
+         denom = fan_in
+     elif mode == "fan_out":
+         denom = fan_out
+     elif mode == "fan_avg":
+         denom = (fan_in + fan_out) / 2
+     else:
+         raise ValueError(f"invalid mode {mode}")
+
+     variance = scale / denom
+
+     if distribution == "truncated_normal":
+         # constant is stddev of standard normal truncated to (-2, 2)
+         trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
+     elif distribution == "normal":
+         with torch.no_grad():
+             tensor.normal_(std=math.sqrt(variance))
+     elif distribution == "uniform":
+         bound = math.sqrt(3 * variance)
+         with torch.no_grad():
+             tensor.uniform_(-bound, bound)
+     else:
+         raise ValueError(f"invalid distribution {distribution}")
+
+
+ def lecun_normal_(tensor):
+     variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
+
+
+ def default_flax_embed_init(tensor):
+     variance_scaling_(tensor, mode="fan_in", distribution="normal")
+
+
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+     orig_dtype = tensor.dtype
+     tensor = tensor.float()
+     cos = freqs.cos()
+     sin = freqs.sin()
+     cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+     sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+     output = (tensor * cos) + (rotate_half(tensor) * sin)
+     output = output.to(orig_dtype)
+     return output
+
+
+ class VisionRotaryEmbedding(nn.Module):
+
+     def __init__(self, dim: int, theta: float = 10000.0) -> None:
+         super().__init__()
+         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+     def forward(self, seqlen: int) -> torch.Tensor:
+         seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+         freqs = torch.outer(seq, self.inv_freq)
+         return freqs
+
+
+ class Videollama3VisionEmbeddings(nn.Module):
+
+     def __init__(self, config: Videollama3VisionEncoderConfig):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.patch_size = config.patch_size
+
+         self.patch_embedding = nn.Conv2d(
+             in_channels=config.num_channels,
+             out_channels=self.embed_dim,
+             kernel_size=self.patch_size,
+             stride=self.patch_size,
+             padding="valid",
+         )
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         # Input is a packed sequence of patches: [num_patches, num_channels * patch_size * patch_size].
+         hidden_states = hidden_states.view(
+             -1, self.config.num_channels, self.patch_size, self.patch_size
+         )
+         patch_embeds = self.patch_embedding(hidden_states)  # shape = [num_patches, embed_dim, 1, 1]
+         # embeddings = patch_embeds.flatten(2).transpose(1, 2)
+         embeddings = patch_embeds.view(-1, self.embed_dim)
+
+         return embeddings
+
+
+ class VisionAttention(nn.Module):
+     """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+     # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = self.embed_dim // self.num_heads
+         if self.head_dim * self.num_heads != self.embed_dim:
+             raise ValueError(
+                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                 f" {self.num_heads})."
+             )
+         self.scale = self.head_dim**-0.5
+         self.dropout = config.attention_dropout
+
+         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+         self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         cu_seqlens: torch.Tensor,
+         rotary_pos_emb: torch.Tensor = None,
+     ) -> torch.Tensor:
+         """Input shape: Time x Channel"""
+
+         q_len, _ = hidden_states.size()
+
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         query_states = query_states.view(q_len, self.num_heads, self.head_dim)
+         key_states = key_states.view(q_len, self.num_heads, self.head_dim)
+         value_states = value_states.view(q_len, self.num_heads, self.head_dim)
+
+         query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
+         key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
+
+         # Additive block-diagonal mask: patches may only attend to patches of the same
+         # sample, whose boundaries are given by cu_seqlens.
+         attention_mask = torch.full(
+             [1, q_len, q_len], torch.finfo(query_states.dtype).min,
+             device=query_states.device, dtype=query_states.dtype,
+         )
+         for i in range(1, len(cu_seqlens)):
+             attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
+
+         query_states = query_states.transpose(0, 1)
+         key_states = key_states.transpose(0, 1)
+         value_states = value_states.transpose(0, 1)
+
+         attn_weights = torch.matmul(query_states, key_states.transpose(1, 2)) / math.sqrt(self.head_dim)
+         attn_weights = attn_weights + attention_mask
+
+         # upcast attention to fp32
+         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+         attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+         attn_output = torch.matmul(attn_weights, value_states)
+
+         attn_output = attn_output.transpose(0, 1)
+         attn_output = attn_output.reshape(q_len, -1)
+         attn_output = self.out_proj(attn_output)
+
+         return attn_output
+
+
+ class VisionFlashAttention2(VisionAttention):
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+     # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         cu_seqlens: torch.Tensor,
+         rotary_pos_emb: torch.Tensor = None,
+     ) -> torch.Tensor:
+         q_len, _ = hidden_states.size()
+
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         # Flash attention operates directly on the packed (unpadded) layout
+         # [total_seq_len, num_heads, head_dim]; cu_seqlens marks the sequence
+         # boundaries, so no attention mask has to be materialized.
+         query_states = query_states.view(q_len, self.num_heads, self.head_dim)
+         key_states = key_states.view(q_len, self.num_heads, self.head_dim)
+         value_states = value_states.view(q_len, self.num_heads, self.head_dim)
+         query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
+         key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
+
+         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+         attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
+             q_len, -1
+         )
+         attn_output = self.out_proj(attn_output)
+
+         return attn_output
+
+
+ class VisionSdpaAttention(VisionAttention):
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         cu_seqlens: torch.Tensor,
+         rotary_pos_emb: torch.Tensor = None,
+     ) -> torch.Tensor:
+         seq_length = hidden_states.shape[0]
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         query_states = query_states.view(seq_length, self.num_heads, self.head_dim)
+         key_states = key_states.view(seq_length, self.num_heads, self.head_dim)
+         value_states = value_states.view(seq_length, self.num_heads, self.head_dim)
+
+         query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
+         key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
+
+         # Boolean block-diagonal mask: True marks positions that are allowed to attend to each other.
+         attention_mask = torch.zeros([1, seq_length, seq_length], device=query_states.device, dtype=torch.bool)
+         for i in range(1, len(cu_seqlens)):
+             attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
+
+         query_states = query_states.transpose(0, 1)
+         key_states = key_states.transpose(0, 1)
+         value_states = value_states.transpose(0, 1)
+         attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attention_mask, dropout_p=0.0)
+         attn_output = attn_output.transpose(0, 1)
+         attn_output = attn_output.reshape(seq_length, -1)
+         attn_output = self.out_proj(attn_output)
+         return attn_output
+
+
+ VISION_ATTENTION_CLASSES = {
+     "eager": VisionAttention,
+     "flash_attention_2": VisionFlashAttention2,
+     "sdpa": VisionSdpaAttention,
+ }
+
+
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Videollama3
+ class Videollama3VisionMLP(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.activation_fn = ACT2FN[config.hidden_act]
+         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.fc1(hidden_states)
+         hidden_states = self.activation_fn(hidden_states)
+         hidden_states = self.fc2(hidden_states)
+         return hidden_states
+
+
+ class Videollama3VisionEncoderLayer(nn.Module):
+
+     def __init__(self, config: Videollama3VisionEncoderConfig):
+         super().__init__()
+         self.embed_dim = config.hidden_size
+         self.self_attn = VISION_ATTENTION_CLASSES[config._attn_implementation](config=config)
+         self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+         self.mlp = Videollama3VisionMLP(config)
+         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+     # Ignore copy
+     def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
+         hidden_states = hidden_states + self.self_attn(
+             self.layer_norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+         )
+         hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))
+         return hidden_states
+
+
+ class Videollama3VisionTransformerEncoder(nn.Module):
+
+     def __init__(self, config: Videollama3VisionEncoderConfig):
+         super().__init__()
+         self.config = config
+         head_dim = config.hidden_size // config.num_attention_heads
+         self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
+         self.layers = nn.ModuleList([Videollama3VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+         self.gradient_checkpointing = False
+
+     def rot_pos_emb(self, grid_sizes, merge_sizes):
+         # Build 2D (height, width) position ids for every patch, reordered so that patches
+         # belonging to the same merge window are contiguous in the packed sequence.
+         pos_ids = []
+         for (t, h, w), merge_size in zip(grid_sizes, merge_sizes):
+             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+             hpos_ids = hpos_ids.reshape(
+                 h // merge_size,
+                 merge_size,
+                 w // merge_size,
+                 merge_size,
+             )
+             hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+             hpos_ids = hpos_ids.flatten()
+
+             wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+             wpos_ids = wpos_ids.reshape(
+                 h // merge_size,
+                 merge_size,
+                 w // merge_size,
+                 merge_size,
+             )
+             wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+             wpos_ids = wpos_ids.flatten()
+             pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+
+         pos_ids = torch.cat(pos_ids, dim=0)
+         max_grid_size = grid_sizes[:, 1:].max()
+         rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+
+         return rotary_pos_emb
+
+     def forward(self, hidden_states, grid_sizes, merge_sizes) -> torch.Tensor:
+         rotary_pos_emb = self.rot_pos_emb(grid_sizes, merge_sizes)
+
+         # Cumulative sequence lengths (one sequence per frame), e.g. [0, len_0, len_0 + len_1, ...],
+         # in the format expected by flash_attn_varlen_func.
+         cu_seqlens = torch.repeat_interleave(grid_sizes[:, 1] * grid_sizes[:, 2], grid_sizes[:, 0]).cumsum(dim=0, dtype=torch.int32)
+         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+         for blk in self.layers:
+             if self.gradient_checkpointing and self.training:
+                 hidden_states = self._gradient_checkpointing_func(
+                     blk.__call__,
+                     hidden_states,
+                     cu_seqlens,
+                     rotary_pos_emb
+                 )
+             else:
+                 hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
+
+         return hidden_states
+
+
+ class Videollama3VisionEncoderModel(PreTrainedModel):
+
+     config_class = Videollama3VisionEncoderConfig
+     base_model_prefix = "videollama3"
+     main_input_name = "pixel_values"
+     supports_gradient_checkpointing = True
+     _no_split_modules = [
+         "Videollama3VisionEncoderLayer",
+         "Videollama3VisionEmbeddings",
+     ]
+     _supports_flash_attn_2 = True
+     _supports_sdpa = True
+
+     def __init__(self, config: Videollama3VisionEncoderConfig):
+         super().__init__(config=config)
+         embed_dim = config.hidden_size
+
+         self.embeddings = Videollama3VisionEmbeddings(config)
+         self.encoder = Videollama3VisionTransformerEncoder(config)
+         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+         self.post_init()
+
+     def forward(self, pixel_values, grid_sizes, merge_sizes=None) -> torch.Tensor:
+         hidden_states = self.embeddings(pixel_values)
+         hidden_states = self.encoder(hidden_states, grid_sizes, merge_sizes)
+         hidden_states = self.post_layernorm(hidden_states)
+
+         # Split the packed sequence back into one chunk per image/video before spatial downsampling.
+         hidden_states_chunks = hidden_states.split(grid_sizes.prod(dim=1).tolist(), dim=0)
+         outputs = []
+
+         for hidden_states, grid_size, merge_size in zip(hidden_states_chunks, grid_sizes, merge_sizes):
+             # NOTE: previous implementation, which supports downsampling with any factor
+             c = hidden_states.shape[-1]
+             hidden_states = hidden_states.view(
+                 grid_size[0], grid_size[1] // merge_size, grid_size[2] // merge_size, merge_size, merge_size, c
+             ).permute(0, 1, 3, 2, 4, 5)
+             hidden_states = hidden_states.reshape(
+                 grid_size[0], grid_size[1], grid_size[2], c
+             ).permute(0, 3, 1, 2)
+             hidden_states = torch.nn.functional.interpolate(
+                 hidden_states,
+                 size=(grid_size[1] // merge_size, grid_size[2] // merge_size),
+                 mode='bilinear'
+             )
+             hidden_states = hidden_states.permute(0, 2, 3, 1).view(-1, c)
+
+             # NOTE: simplified implementation, which only supports downsampling with an integer factor
+             # NOTE: this implementation is mathematically equivalent to the previous one when merge_size
+             # is 1 or 2, but may produce slightly different numerical results
+             # hidden_states = hidden_states.view(-1, merge_size * merge_size, hidden_states.size(-1))
+             # hidden_states = hidden_states.mean(dim=1)
+
+             outputs.append(hidden_states)
+
+         return torch.cat(outputs, dim=0)
+
+     def _init_weights(self, module):
+         """Initialize the weights"""
+         if isinstance(module, nn.Embedding):
+             default_flax_embed_init(module.weight)
+         elif isinstance(module, VisionAttention):
+             nn.init.xavier_uniform_(module.q_proj.weight)
+             nn.init.xavier_uniform_(module.k_proj.weight)
+             nn.init.xavier_uniform_(module.v_proj.weight)
+             nn.init.xavier_uniform_(module.out_proj.weight)
+             nn.init.zeros_(module.q_proj.bias)
+             nn.init.zeros_(module.k_proj.bias)
+             nn.init.zeros_(module.v_proj.bias)
+             nn.init.zeros_(module.out_proj.bias)
+         elif isinstance(module, Videollama3VisionMLP):
+             nn.init.xavier_uniform_(module.fc1.weight)
+             nn.init.xavier_uniform_(module.fc2.weight)
+             nn.init.normal_(module.fc1.bias, std=1e-6)
+             nn.init.normal_(module.fc2.bias, std=1e-6)
+         elif isinstance(module, (nn.Linear, nn.Conv2d)):
+             lecun_normal_(module.weight)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
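
For reference, a minimal forward-pass sketch of the encoder defined above, run with randomly initialized weights and dummy inputs (assuming both Python files from this commit are on the import path). The real patchification is handled by the repository's processing code, so the inputs here, a single image cut into a 16x16 grid of 14x14 patches with 2x2 token merging, are illustrative assumptions only:

import torch
from configuration_videollama3_encoder import Videollama3VisionEncoderConfig
from modeling_videollama3_encoder import Videollama3VisionEncoderModel

config = Videollama3VisionEncoderConfig(
    hidden_size=1152, intermediate_size=4304, num_hidden_layers=27,
    num_attention_heads=16, patch_size=14,
)
model = Videollama3VisionEncoderModel(config).eval()

grid_sizes = torch.tensor([[1, 16, 16]])   # one image: t=1, 16x16 patch grid
merge_sizes = torch.tensor([2])            # merge 2x2 neighbouring patches per image
# Flattened patches: [num_patches, num_channels * patch_size * patch_size]
pixel_values = torch.randn(int(grid_sizes.prod(dim=1).sum()), 3 * 14 * 14)

with torch.no_grad():
    features = model(pixel_values, grid_sizes, merge_sizes)

print(features.shape)  # torch.Size([64, 1152]): 1 * (16 // 2) * (16 // 2) tokens of dim hidden_size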