alex-ht committed on
Commit
894cde2
1 Parent(s): af366f6
audio_processing_mllama.py ADDED
@@ -0,0 +1,66 @@
+ import math
+ from typing import Dict, List, Optional, Union
+ import numpy as np
+ import transformers
+ from transformers.tokenization_utils_base import AudioInput
+ from transformers.utils import TensorType
+ from transformers.feature_extraction_utils import BatchFeature
+ from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor, Wav2Vec2Config
+
+
+ def build_audio_tokens(text: List[str], audio_features: Union[Dict, List[List[np.ndarray]]], audio_token="<|audio|>") -> List[str]:
+     if not isinstance(audio_features, list):
+         audio_features = audio_features['audio_features']
+     bs = len(audio_features)
+     for i in range(bs):
+         for j in range(len(audio_features[i])):
+             tgt_token = f"<|audio_{j+1}|>" * get_num_embeddings(audio_features[i][j].shape[0])
+             text[i] = text[i].replace(audio_token, tgt_token, 1)
+     return text
+
+ def calculate_output_length(length_in, kernel_size, stride=1, padding=0, dilation=1):
+     return (length_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
+
+ def get_num_embeddings(wav_length: int, config: Wav2Vec2Config) -> int:
+     curr_len = wav_length
+     for i in range(config.num_feat_extract_layers):
+         curr_len = calculate_output_length(curr_len, config.conv_kernel[i], stride=config.conv_stride[i])
+     curr_len = calculate_output_length(curr_len, config.adapter_kernel_size, stride=config.adapter_stride)
+     return curr_len + 2  # 2 = <|begin_of_audio|>, <|end_of_audio|>
+
+ class MllamaAudioFeatureExtractor(Wav2Vec2FeatureExtractor):
+
+     def __call__(
+         self,
+         batch_audio_clips: List[List[AudioInput]],
+         return_tensors: Optional[Union[str, TensorType]] = None,
+     ) -> BatchFeature:
+         audio_features = [[ super(MllamaAudioFeatureExtractor, self).__call__(audio_j, sampling_rate=16000, return_attention_mask=False)['input_features'][0] for audio_j in audio_i ] for audio_i in batch_audio_clips ]
+         packed_audio_features = self.pack_audio_clips(audio_features)
+
+         encoded_audio_inputs = BatchFeature(
+             data={
+                 "audio_features": packed_audio_features,
+             },
+             tensor_type=return_tensors,
+         )
+
+         return encoded_audio_inputs
+
+     def pack_audio_clips(self, batch_audio_clips: List[List[np.ndarray]]) -> np.ndarray:
+         assert batch_audio_clips[0][0].ndim == 2  # sequence length x feature dimension
+         # Determine output shape: (batch_size, max_num_clips, max_frames, feature_dim)
+         batch_size = len(batch_audio_clips)
+         max_num_clips = max([len(clips) for clips in batch_audio_clips])
+         max_frames = max([clip.shape[0] for clips in batch_audio_clips for clip in clips])
+         feature_dim = batch_audio_clips[0][0].shape[1]
+
+         stacked_audio_clips = np.zeros((batch_size, max_num_clips, max_frames, feature_dim), dtype=np.float32)
+         for i, clips in enumerate(batch_audio_clips):
+             for j, clip in enumerate(clips):
+                 stacked_audio_clips[i, j, :clip.shape[0], :] = clip
+
+         return stacked_audio_clips
+
+ AutoFeatureExtractor.register("MllamaAudioFeatureExtractor", MllamaAudioFeatureExtractor)
+ transformers.MllamaAudioFeatureExtractor = MllamaAudioFeatureExtractor
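
For reference, `get_num_embeddings` above just chains the standard Conv1d output-length formula over the Wav2Vec2 feature-extractor stack and one adapter layer, then adds two slots for the begin/end-of-audio markers. A standalone sketch of the arithmetic (not part of the commit), restating the two helpers so it runs without this repo on the path and using the default `Wav2Vec2Config`:

```python
# Standalone sketch: how many placeholder slots a 5-second, 16 kHz clip maps to.
from transformers import Wav2Vec2Config

def calculate_output_length(length_in, kernel_size, stride=1, padding=0, dilation=1):
    # standard Conv1d output-length formula
    return (length_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

def get_num_embeddings(wav_length, config):
    curr_len = wav_length
    for i in range(config.num_feat_extract_layers):
        curr_len = calculate_output_length(curr_len, config.conv_kernel[i], stride=config.conv_stride[i])
    curr_len = calculate_output_length(curr_len, config.adapter_kernel_size, stride=config.adapter_stride)
    return curr_len + 2  # + <|begin_of_audio|> and <|end_of_audio|>

config = Wav2Vec2Config()                      # conv strides (5, 2, 2, 2, 2, 2, 2), adapter stride 2
print(get_num_embeddings(5 * 16000, config))   # 126: 124 downsampled frames + the 2 boundary slots
```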
configuration_llama3.py ADDED
@@ -0,0 +1,112 @@
+ # coding=utf-8
+ # Copyright 2024 HuggingFace Inc. team. All rights reserved.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Mllama model configuration"""
+
+ import os
+ from typing import Dict, List, Optional, Union
+
+ import transformers
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.modeling_rope_utils import rope_config_validation
+ from transformers.utils import logging
+ from transformers import Wav2Vec2Config, AutoConfig
+ from transformers.models.mllama.configuration_mllama import MllamaVisionConfig, MllamaTextConfig
+
+ logger = logging.get_logger(__name__)
+
+
+ class Llama3Config(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`Llama3ForConditionalGeneration`]. It is used to instantiate an
+     Mllama model according to the specified arguments, defining the model architecture. Instantiating a configuration
+     with the defaults will yield a similar configuration to that of the Mllama-9B.
+
+     e.g. [meta-llama/Llama-3.2-11B-Vision](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision)
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `MllamaVisionConfig`):
+             The config object or dictionary of the vision backbone.
+         text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `MllamaTextConfig`):
+             The config object or dictionary of the text backbone.
+         image_token_index (`int`, *optional*, defaults to 128256):
+             The image token index to encode the image prompt.
+
+     Example:
+
+     ```python
+     >>> from transformers import MllamaForConditionalGeneration, MllamaConfig, MllamaVisionConfig, MllamaTextConfig
+
+     >>> # Initializing a CLIP-vision config
+     >>> vision_config = MllamaVisionConfig()
+
+     >>> # Initializing a Llama config
+     >>> text_config = MllamaTextConfig()
+
+     >>> # Initializing a mllama-11b style configuration
+     >>> configuration = MllamaConfig(vision_config, text_config)
+
+     >>> # Initializing a model from the mllama-11b style configuration
+     >>> model = MllamaForConditionalGeneration(configuration)
+
+     >>> # Accessing the model configuration
+     >>> configuration = model.config
+     ```"""
+
+     model_type = "llama3"
+     is_composition = True
+
+     def __init__(
+         self,
+         vision_config=None,
+         text_config=None,
+         audio_config=None,
+         image_token_index=128256,
+         audio_token_index=128257,
+         **kwargs,
+     ):
+         if vision_config is None:
+             self.vision_config = MllamaVisionConfig()
+             logger.info("vision_config is None, using default mllama vision config")
+         elif isinstance(vision_config, dict):
+             self.vision_config = MllamaVisionConfig(**vision_config)
+         elif isinstance(vision_config, MllamaVisionConfig):
+             self.vision_config = vision_config
+
+         self.image_token_index = image_token_index
+
+         if audio_config is None:
+             self.audio_config = Wav2Vec2Config()
+             logger.info("audio_config is None, using default mllama audio config")
+         elif isinstance(audio_config, dict):
+             self.audio_config = Wav2Vec2Config(**audio_config)
+         elif isinstance(audio_config, Wav2Vec2Config):
+             self.audio_config = audio_config
+
+         self.audio_token_index = audio_token_index
+
+         if text_config is None:
+             self.text_config = MllamaTextConfig()
+             logger.info("text_config is None, using default mllama text config")
+         elif isinstance(text_config, dict):
+             self.text_config = MllamaTextConfig(**text_config)
+         elif isinstance(text_config, MllamaTextConfig):
+             self.text_config = text_config
+
+         super().__init__(**kwargs)
+
+ AutoConfig.register("llama3", Llama3Config)
+ transformers.Llama3Config = Llama3Config
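
As a usage note (not part of the commit), the config composes its three sub-configs the same way the upstream Mllama composite config does; the audio-specific constraint enforced later, in mllama_audio_model.py, is that `audio_config.output_hidden_size` matches `text_config.hidden_size`. A minimal sketch, assuming configuration_llama3.py is importable from the working directory:

```python
# Minimal composition sketch; sizes here only illustrate the hidden-size constraint.
from transformers import Wav2Vec2Config
from transformers.models.mllama.configuration_mllama import MllamaTextConfig, MllamaVisionConfig
from configuration_llama3 import Llama3Config

config = Llama3Config(
    vision_config=MllamaVisionConfig(),
    text_config=MllamaTextConfig(),
    audio_config=Wav2Vec2Config(add_adapter=True, output_hidden_size=4096),  # meant to match text hidden size
)
print(config.audio_token_index)                                              # 128257 by default
print(config.text_config.hidden_size, config.audio_config.output_hidden_size)  # Llama3Embedding asserts these match
```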
mllama_audio_model.py ADDED
@@ -0,0 +1,41 @@
+ from typing import Optional, Tuple, Union
+ import torch
+ from torch import nn
+ from transformers.modeling_outputs import BaseModelOutput
+ from transformers import Wav2Vec2Model, Wav2Vec2Config, MllamaPreTrainedModel
+ from transformers.models.wav2vec2_bert.modeling_wav2vec2_bert import Wav2Vec2BertAdapterLayer
+ from configuration_llama3 import Llama3Config
+
+
+ class Llama3Embedding(MllamaPreTrainedModel):
+     config_class = Llama3Config
+     base_model_prefix = "audio_model"
+     def __init__(self, config: Llama3Config):
+         super().__init__(config)
+         assert config.audio_config.output_hidden_size == config.text_config.hidden_size
+         self.text_embeddings = nn.Embedding(config.text_config.vocab_size, config.text_config.hidden_size, config.text_config.pad_token_id)
+         config.audio_config.add_adapter = False
+         self.audio_model = Wav2Vec2Model(config.audio_config)
+         self.start_of_audio = nn.Parameter(data=torch.zeros((1, config.audio_config.output_hidden_size)), requires_grad=True)
+         self.end_of_audio = nn.Parameter(data=torch.zeros((1, config.audio_config.output_hidden_size)), requires_grad=True)
+         self.text_config = config.text_config
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         audio_features: Optional[torch.Tensor] = None,
+     ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:
+         input_embeddings = self.text_embeddings(input_ids.clamp_min(0).detach())
+         if audio_features is None:
+             return input_embeddings
+         bs, max_num_img, l, d = audio_features.shape
+         audio_embeddings = self.audio_model(input_features=audio_features.view((bs*max_num_img, l, d)))['last_hidden_state']
+         audio_embeddings = audio_embeddings.view((bs, max_num_img, -1, self.start_of_audio.shape[-1]))
+
+         for i in range(bs):
+             for j in range(max_num_img):
+                 audio_id = -1 - j
+                 if torch.any(input_ids[i] == audio_id):
+                     positions = torch.nonzero(input_ids[i] == audio_id, as_tuple=True)
+                     input_embeddings[i] = input_embeddings[i].index_put(positions, torch.concat([self.start_of_audio, audio_embeddings[i, j, :, :], self.end_of_audio]), accumulate=False)
+         return input_embeddings
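
The scatter at the end of `forward` relies on a simple convention: the processor rewrites the k-th clip's placeholder tokens to the id `-k`, and the number of placeholders equals the clip's encoder frames plus the two boundary vectors. A self-contained toy of that `index_put` step (not part of the commit; all sizes are made up):

```python
# Illustrative sketch of the placeholder convention used above: positions in
# input_ids that carry the negative id for a clip are overwritten with
# [start_of_audio, clip_embeddings, end_of_audio], one row per placeholder.
import torch

hidden = 8
seq = torch.tensor([128000, -1, -1, -1, -1, 42, 7])    # 4 slots reserved for clip 1
input_embeddings = torch.zeros(seq.shape[0], hidden)

clip_embeddings = torch.randn(2, hidden)                # 2 frames from the audio encoder
start_of_audio = torch.zeros(1, hidden)
end_of_audio = torch.zeros(1, hidden)

positions = torch.nonzero(seq == -1, as_tuple=True)
filled = torch.concat([start_of_audio, clip_embeddings, end_of_audio])  # 4 rows == 4 slots
input_embeddings = input_embeddings.index_put(positions, filled, accumulate=False)
```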
modeling_llama3.py ADDED
@@ -0,0 +1,285 @@
+ # coding=utf-8
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ import torch.utils.checkpoint
+ from torch import nn
+
+ import transformers
+ from transformers import MllamaPreTrainedModel, MllamaVisionModel, MllamaForCausalLM, AutoModel
+ from transformers.generation import GenerationMixin
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.utils import logging
+ from transformers.models.mllama.modeling_mllama import _prepare_cross_attention_mask
+ from configuration_llama3 import Llama3Config
+ from mllama_audio_model import Llama3Embedding
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class Llama3ForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
+     config_class = Llama3Config
+     base_model_prefix = "model"
+     _supports_quantized_cache = False  # quant cache not supported in encoder-decoder setting
+
+     def __init__(self, config: Llama3Config):
+         super().__init__(config)
+         self.vocab_size = config.text_config.vocab_size
+         self.hidden_size = config.text_config.hidden_size
+         self.max_num_tiles = config.vision_config.max_num_tiles
+         self.vision_output_dim = config.vision_config.vision_output_dim
+         self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+
+         self.vision_model = MllamaVisionModel._from_config(config.vision_config)
+         self.language_model = MllamaForCausalLM._from_config(config.text_config)
+         self.embed_tokens = Llama3Embedding(config)
+         self.multi_modal_projector = nn.Linear(
+             config.vision_config.vision_output_dim,
+             config.text_config.hidden_size,
+             bias=True,
+         )
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.embed_tokens.text_embeddings
+
+     def set_input_embeddings(self, value):
+         self.embed_tokens.text_embeddings = value
+
+     def get_output_embeddings(self):
+         return self.language_model.get_output_embeddings()
+
+     def set_output_embeddings(self, new_embeddings):
+         self.language_model.set_output_embeddings(new_embeddings)
+
+     def set_decoder(self, decoder):
+         self.language_model.set_decoder(decoder)
+
+     def get_decoder(self):
+         return self.language_model.get_decoder()
+
+     def tie_weights(self):
+         return self.language_model.tie_weights()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         audio_features: Optional[torch.FloatTensor] = None,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         aspect_ratio_mask: Optional[torch.Tensor] = None,
+         aspect_ratio_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         cross_attention_mask: Optional[torch.Tensor] = None,
+         cross_attention_states: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         num_logits_to_keep: int = 0,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+         r"""
+         Args:
+             labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                 config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                 (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+             num_logits_to_keep (`int`, *optional*):
+                 Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+                 `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+                 token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+
+
+         Returns:
+
+         Example:
+
+         ```python
+         >>> from PIL import Image
+         >>> import requests
+         >>> from transformers import AutoProcessor, MllamaForConditionalGeneration
+
+         >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
+         >>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint)
+         >>> processor = AutoProcessor.from_pretrained(checkpoint)
+
+         >>> prompt = "<|image|>If I had to write a haiku for this one"
+         >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+         >>> image = Image.open(requests.get(url, stream=True).raw)
+
+         >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+
+         >>> # Generate
+         >>> output = model.generate(**inputs, max_new_tokens=15)
+
+         >>> prompt_len = inputs.input_ids.shape[-1]
+         >>> generated_ids = output[:, prompt_len:]
+         >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+         >>> print(generated_text)
+         [', it would be:.\\nA stop sign in Chinatown.\\n']
+         ```
+         """
+         output_attentions = output_attentions if output_attentions is not None else self.config.text_config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.text_config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.text_config.use_return_dict
+
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+         if pixel_values is not None and inputs_embeds is not None:
+             raise ValueError(
+                 "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+             )
+
+         if pixel_values is not None and cross_attention_states is not None:
+             raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously")
+
+         if pixel_values is not None:
+             if aspect_ratio_ids is None:
+                 raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided")
+             # get vision tokens from vision model
+             vision_outputs = self.vision_model(
+                 pixel_values=pixel_values,
+                 aspect_ratio_ids=aspect_ratio_ids,
+                 aspect_ratio_mask=aspect_ratio_mask,
+                 output_hidden_states=output_hidden_states,
+                 output_attentions=output_attentions,
+                 return_dict=return_dict,
+             )
+             cross_attention_states = vision_outputs[0]
+             cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape(
+                 -1, cross_attention_states.shape[-2], self.hidden_size
+             )
+
+         if cross_attention_mask is not None:
+             cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask(
+                 cross_attention_mask,
+                 num_vision_tokens=self.vision_model.num_patches,
+                 dtype=self.dtype,
+             )
+         else:
+             full_text_row_masked_out_mask = None
+
+         if cross_attention_mask is not None and cache_position is not None:
+             cross_attention_mask = cross_attention_mask[:, :, cache_position]
+             full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids=input_ids, audio_features=audio_features)
+
+         outputs = self.language_model(
+             input_ids=None,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             cross_attention_states=cross_attention_states,
+             cross_attention_mask=cross_attention_mask,
+             full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             inputs_embeds=inputs_embeds,
+             labels=labels,
+             output_hidden_states=output_hidden_states,
+             output_attentions=output_attentions,
+             return_dict=return_dict,
+             cache_position=cache_position,
+             num_logits_to_keep=num_logits_to_keep,
+         )
+
+         return outputs
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids=None,
+         audio_features=None,
+         inputs_embeds=None,
+         attention_mask=None,
+         position_ids=None,
+         pixel_values=None,
+         aspect_ratio_ids=None,
+         aspect_ratio_mask=None,
+         cross_attention_mask=None,
+         past_key_values=None,
+         use_cache=False,
+         cache_position=None,
+         num_logits_to_keep=None,
+         **kwargs,
+     ):
+         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
+         # Exception 1: when passing input_embeds, input_ids may be missing entries
+         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+         if past_key_values is not None:
+             if inputs_embeds is not None:  # Exception 1
+                 input_ids = input_ids[:, -cache_position.shape[0] :]
+             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
+                 input_ids = input_ids[:, cache_position]
+
+         # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way
+         if attention_mask is not None and position_ids is None:
+             # create position_ids on the fly for batch generation
+             position_ids = attention_mask.long().cumsum(-1) - 1
+             position_ids.masked_fill_(attention_mask == 0, 1)
+             if past_key_values:
+                 position_ids = position_ids[:, -input_ids.shape[1] :]
+
+                 # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
+                 position_ids = position_ids.clone(memory_format=torch.contiguous_format)
+
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and cache_position[0] == 0:
+             model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
+         else:
+             # The clone here is for the same reason as for `position_ids`.
+             model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
+
+         if num_logits_to_keep is not None:
+             model_inputs["num_logits_to_keep"] = num_logits_to_keep
+
+         model_inputs.update(
+             {
+                 "audio_features": audio_features,
+                 "position_ids": position_ids,
+                 "cache_position": cache_position,
+                 "past_key_values": past_key_values,
+                 "use_cache": use_cache,
+                 "attention_mask": attention_mask,
+                 "cross_attention_mask": cross_attention_mask,
+             }
+         )
+
+         # If we're in pre-fill or cacheless decoding step, then we need pixel_values and aspect ratios
+         # to compute image hidden states, otherwise they are cached within each cross attn layer
+         if cache_position[0] == 0:
+             model_inputs["pixel_values"] = pixel_values
+             model_inputs["aspect_ratio_ids"] = aspect_ratio_ids
+             model_inputs["aspect_ratio_mask"] = aspect_ratio_mask
+
+         return model_inputs
+
+     def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
+         cross_attention_mask_prev = model_kwargs.get("cross_attention_mask", None)
+         model_kwargs = super()._update_model_kwargs_for_generation(
+             outputs=outputs,
+             model_kwargs=model_kwargs,
+             is_encoder_decoder=is_encoder_decoder,
+             **kwargs,
+         )
+
+         # add cross-attn mask for new token
+         if cross_attention_mask_prev is not None:
+             model_kwargs["cross_attention_mask"] = torch.cat(
+                 [cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], dim=1
+             )
+         return model_kwargs
+
+ AutoModel.register(Llama3Config, Llama3ForConditionalGeneration)
+ transformers.Llama3ForConditionalGeneration = Llama3ForConditionalGeneration
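
For completeness, a hedged sketch of the intended end-to-end call path through the processor and `generate` (not part of the commit). The repo id is a placeholder, and the snippet assumes a published checkpoint that ships these files and whose `auto_map` entries point at them:

```python
# End-to-end sketch under the assumptions above; the silent clip is a stand-in for real audio.
import numpy as np
import torch
from transformers import AutoModel, AutoProcessor

repo = "your-namespace/llama3-vision-audio"   # hypothetical checkpoint id
processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_pretrained(repo, trust_remote_code=True, torch_dtype=torch.bfloat16)

clip = np.zeros(16000, dtype=np.float32)      # 1 s of 16 kHz silence
inputs = processor(text="<|audio|>What is being said?", audio=[[clip]], return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=32)
print(processor.batch_decode(output, skip_special_tokens=True))
```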
preprocessor_config.json ADDED
@@ -0,0 +1,38 @@
+ {
+   "auto_map": {
+     "AutoFeatureExtractor": "audio_processing_mllama.MllamaAudioFeatureExtractor",
+     "AutoProcessor": "processing_mllama.MllamaProcessor"
+   },
+   "do_convert_rgb": true,
+   "do_normalize": true,
+   "do_pad": true,
+   "do_rescale": true,
+   "do_resize": true,
+   "feature_extractor_type": "MllamaAudioFeatureExtractor",
+   "feature_size": 80,
+   "image_mean": [
+     0.48145466,
+     0.4578275,
+     0.40821073
+   ],
+   "image_processor_type": "MllamaImageProcessor",
+   "image_std": [
+     0.26862954,
+     0.26130258,
+     0.27577711
+   ],
+   "max_image_tiles": 4,
+   "num_mel_bins": 80,
+   "padding_side": "right",
+   "padding_value": 0.0,
+   "processor_class": "MllamaProcessor",
+   "resample": 2,
+   "rescale_factor": 0.00392156862745098,
+   "return_attention_mask": true,
+   "sampling_rate": 16000,
+   "size": {
+     "height": 560,
+     "width": 560
+   },
+   "stride": 2
+ }
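
This file mixes the image-processor keys with the audio feature-extractor keys and, via `auto_map`, points the Auto classes at the custom code above. A hedged sketch of how that resolution is exercised (the repo id is a placeholder):

```python
# Loading sketch: auto_map plus trust_remote_code is what pulls in the custom class.
from transformers import AutoFeatureExtractor

fe = AutoFeatureExtractor.from_pretrained(
    "your-namespace/llama3-vision-audio",   # hypothetical repo containing this preprocessor_config.json
    trust_remote_code=True,                 # required to execute audio_processing_mllama.py from the repo
)
print(type(fe).__name__)                    # expected: MllamaAudioFeatureExtractor
```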
processing_mllama.py ADDED
@@ -0,0 +1,377 @@
+ # coding=utf-8
+ # Copyright 2024 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Processor class for Mllama."""
+
+ from typing import List, Optional, Union
+
+ import numpy as np
+ import torch
+ import transformers
+ from transformers import AutoProcessor
+ from transformers.feature_extraction_utils import BatchFeature
+ from transformers.image_utils import ImageInput
+ from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack, AudioKwargs
+ from transformers.tokenization_utils_base import (
+     PreTokenizedInput,
+     TextInput,
+     AudioInput,
+ )
+
+ # TODO: Can we do it that way, or is it better to include it as "Copied from ..."?
+ from transformers.models.mllama.image_processing_mllama import make_list_of_images
+ from .audio_processing_mllama import build_audio_tokens
+
+
+ class MllamaImagesKwargs(ImagesKwargs, total=False):
+     max_image_tiles: Optional[int]
+
+ class MllamaProcessorKwargs(ProcessingKwargs, total=False):
+     images_kwargs: MllamaImagesKwargs
+
+     _defaults = {
+         "images_kwargs": {
+             "max_image_tiles": 4,
+         },
+     }
+
+
+ def get_cross_attention_token_mask(input_ids: List[int], image_token_id: int) -> List[List[int]]:
+     """
+     Generate a cross-attention token mask for image tokens in the input sequence.
+
+     This function identifies the positions of image tokens in the input sequence and creates
+     a mask that defines which subsequent tokens each image token should attend to.
+
+     Args:
+         input_ids (List[int]): A list of token ids representing the input sequence.
+         image_token_id (int): The id of the token used to represent images in the sequence.
+
+     Returns:
+         List[List[int]]: A list of [start, end] pairs, where each pair represents the range
+         of tokens an image token should attend to.
+
+     Notes:
+         - If no image tokens are present, an empty list is returned.
+         - For a single image token, it attends to all subsequent tokens until the end of the sequence.
+         - For multiple image tokens, each attends to tokens up to the next image token or the end of the sequence.
+         - Consecutive image tokens are treated as a group and attend to all subsequent tokens together.
+     """
+
+     image_token_locations = [i for i, token in enumerate(input_ids) if token == image_token_id]
+
+     if len(image_token_locations) == 0:
+         return []
+
+     # only one image present, unmask until end of sequence
+     if len(image_token_locations) == 1:
+         return [[image_token_locations[0], -1]]
+
+     vision_masks = [[loc1, loc2] for loc1, loc2 in zip(image_token_locations[:-1], image_token_locations[1:])]
+
+     # last image will attend to all subsequent text
+     vision_masks.append([image_token_locations[-1], len(input_ids)])
+
+     # if there are two or more consecutive vision tokens,
+     # they should all attend to all subsequent
+     # text present
+     last_mask_end = vision_masks[-1][1]
+     for vision_mask in vision_masks[::-1]:
+         if vision_mask[0] == vision_mask[1] - 1:
+             vision_mask[1] = last_mask_end
+         last_mask_end = vision_mask[1]
+
+     return vision_masks
+
+
+ def convert_sparse_cross_attention_mask_to_dense(
+     cross_attention_token_mask: List[List[List[int]]],
+     num_tiles: List[List[int]],
+     max_num_tiles: int,
+     length: int,
+ ) -> np.ndarray:
+     """
+     Convert the cross attention mask indices to a cross attention mask 4D array.
+
+     This function takes a sparse representation of cross attention masks and converts it to a dense 4D numpy array.
+     The sparse representation is a nested list structure that defines attention ranges for each image in each batch item.
+
+     Args:
+         cross_attention_token_mask (List[List[List[int]]]): A nested list structure where:
+             - The outer list represents the batch dimension.
+             - The middle list represents different images within each batch item.
+             - The inner list contains pairs of integers [start, end] representing token ranges for each image.
+         num_tiles (List[List[int]]): A nested list structure specifying the number of tiles for each image in each batch item.
+         max_num_tiles (int): The maximum possible number of tiles.
+         length (int): The total sequence length of the input.
+
+     Returns:
+         np.ndarray: A 4D numpy array of shape (batch_size, length, max_num_images, max_num_tiles)
+             The array contains `1` where attention is allowed and `0` where it is not.
+
+     Note:
+         - Special handling is done for cases where the end token is -1, which is interpreted as attending to the end of the sequence.
+     """
+
+     batch_size = len(cross_attention_token_mask)
+     max_num_images = max([len(masks) for masks in cross_attention_token_mask])
+
+     cross_attention_mask = np.zeros(
+         shape=(batch_size, length, max_num_images, max_num_tiles),
+         dtype=np.int64,
+     )
+
+     for sample_idx, (sample_masks, sample_num_tiles) in enumerate(zip(cross_attention_token_mask, num_tiles)):
+         for mask_idx, (locations, mask_num_tiles) in enumerate(zip(sample_masks, sample_num_tiles)):
+             if len(locations) == 2:
+                 start, end = locations
+                 end = min(end, length)
+                 if end == -1:
+                     end = length
+                 cross_attention_mask[sample_idx, start:end, mask_idx, :mask_num_tiles] = 1
+     return cross_attention_mask
+
+
+ def build_string_from_input(prompt: str, bos_token: str, image_token: str) -> str:
+     """
+     Builds a string from the input prompt by adding `bos_token` if not already present.
+
+     Args:
+         prompt (`str`):
+             The input prompt string.
+         bos_token (`str`):
+             The beginning of sentence token to be added.
+         image_token (`str`):
+             The image token used to identify the start of an image sequence.
+
+     Returns:
+         str: The modified prompt string with the `bos_token` added if necessary.
+
+     Examples:
+         >>> build_string_from_input("Hello world", "<begin_of_text>", "<|image|>")
+         '<begin_of_text>Hello world'
+
+         >>> build_string_from_input("<|image|>Hello world", "<begin_of_text>", "<|image|>")
+         '<|image|><begin_of_text>Hello world'
+
+         >>> build_string_from_input("<begin_of_text>Hello world", "<begin_of_text>", "<|image|>")
+         '<begin_of_text>Hello world'
+     """
+
+     if bos_token in prompt:
+         return prompt
+
+     num_image_tokens_on_start = 0
+     while prompt.startswith(image_token):
+         prompt = prompt[len(image_token) :]
+         num_image_tokens_on_start += 1
+
+     return f"{image_token * num_image_tokens_on_start}{bos_token}{prompt}"
+
+
+ class MllamaProcessor(ProcessorMixin):
+     r"""
+     Constructs a Mllama processor which wraps [`MllamaImageProcessor`], [`MllamaAudioFeatureExtractor`], and
+     [`PreTrainedTokenizerFast`] into a single processor that inherits the image processor, audio feature extractor,
+     and tokenizer functionalities. See the [`~MllamaProcessor.__call__`] and [`~MllamaProcessor.decode`] for more
+     information.
+     The preferred way of passing kwargs is as a dictionary per modality, see usage example below.
+     ```python
+     from transformers import MllamaProcessor
+     from PIL import Image
+
+     processor = MllamaProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision")
+
+     processor(
+         images=your_pil_image,
+         text=["<|image|>If I had to write a haiku for this one"],
+         images_kwargs = {"size": {"height": 448, "width": 448}},
+         text_kwargs = {"padding": "right"},
+         common_kwargs = {"return_tensors": "pt"},
+     )
+     ```
+
+     Args:
+         image_processor ([`MllamaImageProcessor`]):
+             The image processor is a required input.
+         tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`]):
+             The tokenizer is a required input.
+
+     """
+
+     attributes = ["image_processor", "audio_processor", "tokenizer"]
+     image_processor_class = "MllamaImageProcessor"
+     audio_processor_class = "MllamaAudioFeatureExtractor"
+     tokenizer_class = "PreTrainedTokenizerFast"
+
+     def __init__(self, image_processor, audio_processor, tokenizer):
+         self.image_token = "<|image|>"
+         self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+         self.audio_token = "<|audio|>"
+         self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
+         self.python_token = "<|python_tag|>"
+         self.python_token_id = tokenizer.convert_tokens_to_ids(self.python_token)
+         self.bos_token = tokenizer.bos_token
+         self.chat_template = tokenizer.chat_template
+         super().__init__(image_processor, audio_processor, tokenizer)
+         self.tokenizer.add_tokens([f"<|audio_{i}|>" for i in range(1, 50)])
+
+
+     def __call__(
+         self,
+         images: Optional[ImageInput] = None,
+         text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
+         audio: Optional[Union[AudioInput, List[AudioInput]]] = None,
+         videos=None,
+         **kwargs: Unpack[MllamaProcessorKwargs],
+     ) -> BatchFeature:
+         """
+         Main method to prepare text(s) and image(s) to be fed as input to the model. This method forwards the `text`
+         arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode
+         the text. To prepare the image(s), this method forwards the `images` arguments to
+         MllamaImageProcessor's [`~MllamaImageProcessor.__call__`] if `images` is not `None`. Please refer
+         to the docstring of the above two methods for more information.
+
+         Args:
+             images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                 The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                 tensor. Both channels-first and channels-last formats are supported.
+             text (`str`, `List[str]`, `List[List[str]]`):
+                 The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                 (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                 `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+             return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                 If set, will return tensors of a particular framework. Acceptable values are:
+                     - `'tf'`: Return TensorFlow `tf.constant` objects.
+                     - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                     - `'np'`: Return NumPy `np.ndarray` objects.
+                     - `'jax'`: Return JAX `jnp.ndarray` objects.
+         Returns:
+             [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+             - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+             - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+               `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+               `None`).
+             - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+             - **audio_features** -- Audio features extracted by [`MllamaAudioFeatureExtractor`]. Returned when `audio` is not `None`.
+             TODO: add aspect_ratio_ids and aspect_ratio_mask and cross_attention_mask
+         """
+         if text is None:
+             raise ValueError("You must specify text.")
+
+         output_kwargs = self._merge_kwargs(
+             MllamaProcessorKwargs,
+             tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+             **kwargs,
+         )
+
+         text_kwargs = output_kwargs["text_kwargs"]
+         images_kwargs = output_kwargs["images_kwargs"]
+         common_kwargs = output_kwargs["common_kwargs"]
+
+         data = {}
+
+         if audio is not None:
+             audio_features = self.audio_processor(audio)
+             data.update(audio_features)
+
+         if isinstance(text, str):
+             text = [text]
+         elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
+             raise ValueError("Invalid input text. Please provide a string, or a list of strings")
+         n_images_in_text = [t.count(self.image_token) for t in text]
+         text = [build_string_from_input(text_item, self.bos_token, self.image_token) for text_item in text]
+         _ = text_kwargs.pop("padding_side", None)  # hack until padding-side is an accepted kwarg by tokenizers
+
+         if audio is not None:
+             text = build_audio_tokens(text, audio_features, self.audio_token)
+
+         encoding = self.tokenizer(text, add_special_tokens=False, **text_kwargs)
+         if audio is not None:
+             beg_audio_id = self.tokenizer.convert_tokens_to_ids("<|audio_1|>")
+             idx = torch.where(encoding['input_ids'] >= beg_audio_id)
+             encoding['input_ids'][idx] = beg_audio_id - encoding['input_ids'][idx] - 1
+         data.update(encoding)
+
+         n_images_in_images = [0]
+         if images is not None:
+             images = make_list_of_images(images)
+             n_images_in_images = [len(sample) for sample in images]
+
+         if text is not None:
+             if any(batch_img == 0 for batch_img in n_images_in_text) and not all(
+                 batch_img == 0 for batch_img in n_images_in_text
+             ):
+                 raise ValueError(
+                     "If a batch of text is provided, there should be either no images or at least one image per sample"
+                 )
+             if sum(n_images_in_images) != sum(n_images_in_text):
+                 if images is None:
+                     raise ValueError("No images were provided, but there are image tokens in the prompt")
+                 else:
+                     raise ValueError(
+                         f"The number of image tokens ({sum(n_images_in_text)}) should be the same as the number of provided images ({sum(n_images_in_images)})"
+                     )
+
+         if images is not None:
+             image_features = self.image_processor(images, **images_kwargs)
+             num_tiles = image_features.pop("num_tiles")
+             data.update(image_features)
+
+         # Create cross attention mask
+         if images is not None and text is not None:
+             cross_attention_token_mask = [
+                 get_cross_attention_token_mask(token_ids, self.image_token_id) for token_ids in encoding["input_ids"]
+             ]
+             cross_attention_mask = convert_sparse_cross_attention_mask_to_dense(
+                 cross_attention_token_mask,
+                 num_tiles=num_tiles,
+                 max_num_tiles=self.image_processor.max_image_tiles,
+                 length=max(len(input_ids) for input_ids in encoding["input_ids"]),
+             )
+             data["cross_attention_mask"] = cross_attention_mask
+
+         return_tensors = common_kwargs.pop("return_tensors", None)
+         batch_feature = BatchFeature(data=data, tensor_type=return_tensors)
+
+         return batch_feature
+
+     def batch_decode(self, *args, **kwargs):
+         """
+         This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+         refer to the docstring of this method for more information.
+         """
+         return self.tokenizer.batch_decode(*args, **kwargs)
+
+     def decode(self, *args, **kwargs):
+         """
+         This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+         the docstring of this method for more information.
+         """
+         return self.tokenizer.decode(*args, **kwargs)
+
+     @property
+     def model_input_names(self):
+         tokenizer_input_names = self.tokenizer.model_input_names
+         image_processor_input_names = self.image_processor.model_input_names
+         audio_processor_input_names = self.audio_processor.model_input_names
+         return list(tokenizer_input_names +
+                     image_processor_input_names +
+                     ["cross_attention_mask"] +
+                     audio_processor_input_names)
+
+ AutoProcessor.register("MllamaProcessor", MllamaProcessor)
+ transformers.MllamaProcessor = MllamaProcessor
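
Taken together with audio_processing_mllama.py and mllama_audio_model.py, the audio path in `__call__` boils down to a text rewrite plus an id remap. A toy illustration of just that step (not part of the commit), with a made-up slot count:

```python
# Each <|audio|> marker is expanded to one <|audio_k|> placeholder per embedding
# slot of clip k; after tokenization those placeholder ids are rewritten to the
# negative ids (-1, -2, ...) that Llama3Embedding.forward looks up.
text = "<|audio|>What is being said?"
num_slots = 4                                        # in the real code: get_num_embeddings(...)
expanded = text.replace("<|audio|>", "<|audio_1|>" * num_slots, 1)
print(expanded)   # <|audio_1|><|audio_1|><|audio_1|><|audio_1|>What is being said?

# id remap applied to the tokenized output: new_id = id(<|audio_1|>) - old_id - 1,
# so <|audio_1|> -> -1, <|audio_2|> -> -2, matching audio_id = -1 - j in the model.
```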