alex-ht committed
Commit 894cde2
Parent(s): af366f6
code
Files changed:
- audio_processing_mllama.py (+66 -0)
- configuration_llama3.py (+112 -0)
- mllama_audio_model.py (+41 -0)
- modeling_llama3.py (+285 -0)
- preprocessor_config.json (+38 -0)
- processing_mllama.py (+377 -0)
audio_processing_mllama.py
ADDED
@@ -0,0 +1,66 @@
import math
from typing import Dict, List, Optional, Union
import numpy as np
import transformers
from transformers.tokenization_utils_base import AudioInput
from transformers.utils import TensorType
from transformers.feature_extraction_utils import BatchFeature
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor, Wav2Vec2Config


def build_audio_tokens(text: List[str], audio_features: Union[Dict, List[List[np.ndarray]]], audio_token="<|audio|>") -> Dict:
    if not isinstance(audio_features, list):
        audio_features = audio_features['audio_features']
    bs = audio_features.shape[0]
    for i in range(bs):
        for j in range(len(audio_features[i])):
            tgt_token = f"<|audio_{j+1}|>" * get_num_embeddings(audio_features[i][j].shape[0])
            text[i] = text[i].replace(audio_token, tgt_token, 1)
    return text

def calculate_output_length(length_in, kernel_size, stride=1, padding=0, dilation=1):
    return (length_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

def get_num_embeddings(wav_length: int, config: Wav2Vec2Config) -> int:
    curr_len = wav_length
    for i in range(config.num_feat_extract_layers):
        curr_len = calculate_output_length(curr_len, config.conv_kernel[i], stride=config.conv_stride[i])
    curr_len = calculate_output_length(curr_len, config.adapter_kernel_size, stride=config.adapter_stride)
    return curr_len + 2  # 2 = <|begin_of_audio|>, <|end_of_audio|>

class MllamaAudioFeatureExtractor(Wav2Vec2FeatureExtractor):

    def __call__(
        self,
        batch_audio_clips: List[List[AudioInput]],
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        audio_features = [[ super(MllamaAudioFeatureExtractor, self).__call__(audio_j, sampling_rate=16000, return_attention_mask=False)['input_features'][0] for audio_j in audio_i ] for audio_i in batch_audio_clips ]
        packed_audio_features = self.pack_audio_clips(audio_features)

        encoded_audio_inputs = BatchFeature(
            data={
                "audio_features": packed_audio_features,
            },
            tensor_type=return_tensors,
        )

        return encoded_audio_inputs

    def pack_audio_clips(self, batch_audio_clips: List[List[np.ndarray]]) -> np.ndarray:
        assert batch_audio_clips[0][0].ndim == 2  # sequence length x feature dimension
        # Determine output shape: (batch_size, max_num_clips, max_frames, feature_dim)
        batch_size = len(batch_audio_clips)
        max_num_clips = max([len(clips) for clips in batch_audio_clips])
        max_frames = max([clip.shape[0] for clips in batch_audio_clips for clip in clips])
        feature_dim = batch_audio_clips[0][0].shape[1]

        stacked_audio_clips = np.zeros((batch_size, max_num_clips, max_frames, feature_dim), dtype=np.float32)
        for i, clips in enumerate(batch_audio_clips):
            for j, clip in enumerate(clips):
                stacked_audio_clips[i, j, :clip.shape[0], :] = clip

        return stacked_audio_clips

AutoFeatureExtractor.register("MllamaAudioFeatureExtractor", MllamaAudioFeatureExtractor)
transformers.MllamaAudioFeatureExtractor = MllamaAudioFeatureExtractor
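A minimal usage sketch (my addition, not part of the commit): it assumes this file is importable as `audio_processing_mllama` and uses a default `Wav2Vec2Config` to see how many embedding positions a one-second, 16 kHz clip maps to via the convolution arithmetic above.

    # Hypothetical sketch: count embedding slots for a 1-second, 16 kHz waveform.
    # Assumes a default Wav2Vec2Config, whose conv_kernel/conv_stride/adapter_* attributes all exist.
    from transformers import Wav2Vec2Config
    from audio_processing_mllama import get_num_embeddings

    cfg = Wav2Vec2Config()
    num_positions = get_num_embeddings(16000, cfg)  # conv stack + adapter conv, plus the two boundary tokens
    print(num_positions)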
configuration_llama3.py
ADDED
@@ -0,0 +1,112 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mllama model configuration"""

import os
from typing import Dict, List, Optional, Union

import transformers
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from transformers import Wav2Vec2Config, AutoConfig
from transformers.models.mllama.configuration_mllama import MllamaVisionConfig, MllamaTextConfig

logger = logging.get_logger(__name__)


class Llama3Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MllamaForConditionalGeneration`]. It is used to instantiate an
    Mllama model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Mllama-9B.

    e.g. [meta-llama/Llama-3.2-11B-Vision](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `MllamaVisionConfig`):
            The config object or dictionary of the vision backbone.
        text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `MllamaTextConfig`):
            The config object or dictionary of the text backbone.
        image_token_index (`int`, *optional*, defaults to 128256):
            The image token index to encode the image prompt.

    Example:

    ```python
    >>> from transformers import MllamaForConditionalGeneration, MllamaConfig, MllamaVisionConfig, MllamaTextConfig

    >>> # Initializing a CLIP-vision config
    >>> vision_config = MllamaVisionConfig()

    >>> # Initializing a Llama config
    >>> text_config = MllamaTextConfig()

    >>> # Initializing a mllama-11b style configuration
    >>> configuration = MllamaConfig(vision_config, text_config)

    >>> # Initializing a model from the mllama-11b style configuration
    >>> model = MllamaForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "llama3"
    is_composition = True

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        audio_config=None,
        image_token_index=128256,
        audio_token_index=128257,
        **kwargs,
    ):
        if vision_config is None:
            self.vision_config = MllamaVisionConfig()
            logger.info("vision_config is None, using default mllama vision config")
        elif isinstance(vision_config, dict):
            self.vision_config = MllamaVisionConfig(**vision_config)
        elif isinstance(vision_config, MllamaVisionConfig):
            self.vision_config = vision_config

        self.image_token_index = image_token_index

        if audio_config is None:
            self.audio_config = Wav2Vec2Config()
            logger.info("audio_config is None, using default mllama audio config")
        elif isinstance(audio_config, dict):
            self.audio_config = Wav2Vec2Config(**audio_config)
        elif isinstance(audio_config, Wav2Vec2Config):
            self.audio_config = audio_config

        self.audio_token_index = audio_token_index

        if text_config is None:
            self.text_config = MllamaTextConfig()
            logger.info("text_config is None, using default mllama text config")
        elif isinstance(text_config, dict):
            self.text_config = MllamaTextConfig(**text_config)
        elif isinstance(text_config, MllamaTextConfig):
            self.text_config = text_config

        super().__init__(**kwargs)

AutoConfig.register("llama3", Llama3Config)
transformers.Llama3Config = Llama3Config
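A hedged construction sketch (my addition, not part of the commit): each sub-config can be passed as a nested dict, mirroring how `from_pretrained` would supply them; the sizes below are placeholders chosen only for illustration.

    from configuration_llama3 import Llama3Config

    config = Llama3Config(
        vision_config=None,                                        # falls back to MllamaVisionConfig()
        text_config={"vocab_size": 128264, "hidden_size": 4096},
        audio_config={"output_hidden_size": 4096},                 # must equal text hidden_size for Llama3Embedding
    )
    print(config.model_type, config.image_token_index, config.audio_token_index)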
mllama_audio_model.py
ADDED
@@ -0,0 +1,41 @@
from typing import Optional, Tuple, Union
import torch
from torch import nn
from transformers.modeling_outputs import BaseModelOutput
from transformers import Wav2Vec2Model, Wav2Vec2Config, MllamaPreTrainedModel
from transformers.models.wav2vec2_bert.modeling_wav2vec2_bert import Wav2Vec2BertAdapterLayer
from configuration_llama3 import Llama3Config


class Llama3Embedding(MllamaPreTrainedModel):
    config_class = Llama3Config
    base_model_prefix = "audio_model"
    def __init__(self, config: Llama3Config):
        super().__init__(config)
        assert config.audio_config.output_hidden_size == config.text_config.hidden_size
        self.text_embeddings = nn.Embedding(config.text_config.vocab_size, config.text_config.hidden_size, config.text_config.pad_token_id)
        config.audio_config.add_adapter = False
        self.audio_model = Wav2Vec2Model(config.audio_config)
        self.start_of_audio = nn.Parameter(data=torch.zeros((1, config.audio_config.output_hidden_size)), requires_grad=True)
        self.end_of_audio = nn.Parameter(data=torch.zeros((1, config.audio_config.output_hidden_size)), requires_grad=True)
        self.text_config = config.text_config

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        audio_features: Optional[torch.Tensor] = None,
    ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:
        input_embeddings = self.text_embeddings(input_ids.clamp_min(0).detach())
        if audio_features is None:
            return input_embeddings
        bs, max_num_img, l, d = audio_features.shape
        audio_embeddings = self.audio_model(input_features=audio_features.view((bs*max_num_img, l, d)))['last_hidden_state']
        audio_embeddings = audio_embeddings.view((bs, max_num_img, -1, self.start_of_audio.shape[-1]))

        for i in range(bs):
            for j in range(max_num_img):
                audio_id = -1 - j
                if torch.any(input_ids[i] == audio_id):
                    positions = torch.nonzero(input_ids[i] == audio_id, as_tuple=True)
                    input_embeddings[i] = input_embeddings[i].index_put(positions, torch.concat([self.start_of_audio, audio_embeddings[i, j, :, :], self.end_of_audio]), accumulate=False)
        return input_embeddings
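A toy illustration (my addition, not from the commit) of the placeholder convention the loop above relies on: the j-th audio clip of a sample is marked in `input_ids` by the negative id `-(j + 1)`, and those rows of the text embeddings are overwritten with `[start_of_audio, clip embeddings, end_of_audio]`.

    import torch

    hidden = 4
    # one sample: clip 0 occupies ids == -1 (1 frame + 2 boundary slots), clip 1 occupies ids == -2
    input_ids = torch.tensor([128000, -1, -1, -1, 15, -2, -2, -2])
    embeddings = torch.zeros(input_ids.shape[0], hidden)

    start = torch.full((1, hidden), 2.0)  # stand-in for start_of_audio
    end = torch.full((1, hidden), 3.0)    # stand-in for end_of_audio
    clip = torch.ones(1, hidden)          # stand-in for one frame of audio_model output
    for j in range(2):
        positions = torch.nonzero(input_ids == -1 - j, as_tuple=True)
        embeddings = embeddings.index_put(positions, torch.cat([start, clip, end]), accumulate=False)
    print(embeddings)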
modeling_llama3.py
ADDED
@@ -0,0 +1,285 @@
# coding=utf-8
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

import transformers
from transformers import MllamaPreTrainedModel, MllamaVisionModel, MllamaForCausalLM, AutoModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import logging
from transformers.models.mllama.modeling_mllama import _prepare_cross_attention_mask
from configuration_llama3 import Llama3Config
from mllama_audio_model import Llama3Embedding


logger = logging.get_logger(__name__)


class Llama3ForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
    config_class = Llama3Config
    base_model_prefix = "model"
    _supports_quantized_cache = False  # quant cache not supported in encoder-decoder setting

    def __init__(self, config: Llama3Config):
        super().__init__(config)
        self.vocab_size = config.text_config.vocab_size
        self.hidden_size = config.text_config.hidden_size
        self.max_num_tiles = config.vision_config.max_num_tiles
        self.vision_output_dim = config.vision_config.vision_output_dim
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1

        self.vision_model = MllamaVisionModel._from_config(config.vision_config)
        self.language_model = MllamaForCausalLM._from_config(config.text_config)
        self.embed_tokens = Llama3Embedding(config)
        self.multi_modal_projector = nn.Linear(
            config.vision_config.vision_output_dim,
            config.text_config.hidden_size,
            bias=True,
        )
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens.text_embeddings

    def set_input_embeddings(self, value):
        self.embed_tokens.text_embeddings = value

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    def tie_weights(self):
        return self.language_model.tie_weights()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        audio_features: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        aspect_ratio_mask: Optional[torch.Tensor] = None,
        aspect_ratio_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_mask: Optional[torch.Tensor] = None,
        cross_attention_states: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.


        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, MllamaForConditionalGeneration

        >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
        >>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint)
        >>> processor = AutoProcessor.from_pretrained(checkpoint)

        >>> prompt = "<|image|>If I had to write a haiku for this one"
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")

        >>> # Generate
        >>> output = model.generate(**inputs, max_new_tokens=15)

        >>> prompt_len = inputs.input_ids.shape[-1]
        >>> generated_ids = output[:, prompt_len:]
        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        >>> print(generated_text)
        [', it would be:.\\nA stop sign in Chinatown.\\n']
        ```
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.text_config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.text_config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.text_config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if pixel_values is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
            )

        if pixel_values is not None and cross_attention_states is not None:
            raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously")

        if pixel_values is not None:
            if aspect_ratio_ids is None:
                raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided")
            # get vision tokens from vision model
            vision_outputs = self.vision_model(
                pixel_values=pixel_values,
                aspect_ratio_ids=aspect_ratio_ids,
                aspect_ratio_mask=aspect_ratio_mask,
                output_hidden_states=output_hidden_states,
                output_attentions=output_attentions,
                return_dict=return_dict,
            )
            cross_attention_states = vision_outputs[0]
            cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape(
                -1, cross_attention_states.shape[-2], self.hidden_size
            )

        if cross_attention_mask is not None:
            cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask(
                cross_attention_mask,
                num_vision_tokens=self.vision_model.num_patches,
                dtype=self.dtype,
            )
        else:
            full_text_row_masked_out_mask = None

        if cross_attention_mask is not None and cache_position is not None:
            cross_attention_mask = cross_attention_mask[:, :, cache_position]
            full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids=input_ids, audio_features=audio_features)

        outputs = self.language_model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            inputs_embeds=inputs_embeds,
            labels=labels,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
        )

        return outputs

    def prepare_inputs_for_generation(
        self,
        input_ids=None,
        audio_features=None,
        inputs_embeds=None,
        attention_mask=None,
        position_ids=None,
        pixel_values=None,
        aspect_ratio_ids=None,
        aspect_ratio_mask=None,
        cross_attention_mask=None,
        past_key_values=None,
        use_cache=False,
        cache_position=None,
        num_logits_to_keep=None,
        **kwargs,
    ):
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if past_key_values is not None:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]

        # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            # The clone here is for the same reason as for `position_ids`.
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

        if num_logits_to_keep is not None:
            model_inputs["num_logits_to_keep"] = num_logits_to_keep

        model_inputs.update(
            {
                "audio_features": audio_features,
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "cross_attention_mask": cross_attention_mask,
            }
        )

        # If we're in pre-fill or cacheless decoding step, then we need pixel_values and aspect ratios
        # to compute image hidden states, otherwise they are cached within each cross attn layer
        if cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["aspect_ratio_ids"] = aspect_ratio_ids
            model_inputs["aspect_ratio_mask"] = aspect_ratio_mask

        return model_inputs

    def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
        cross_attention_mask_prev = model_kwargs.get("cross_attention_mask", None)
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs=outputs,
            model_kwargs=model_kwargs,
            is_encoder_decoder=is_encoder_decoder,
            **kwargs,
        )

        # add cross-attn mask for new token
        if cross_attention_mask_prev is not None:
            model_kwargs["cross_attention_mask"] = torch.cat(
                [cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], dim=1
            )
        return model_kwargs

AutoModel.register(Llama3Config, Llama3ForConditionalGeneration)
transformers.Llama3ForConditionalGeneration = Llama3ForConditionalGeneration
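A small sanity check (my addition) of the cross-attention mask growth performed in `_update_model_kwargs_for_generation`: the mask row of the last prompt position is repeated for each newly generated token.

    import torch

    cross_attention_mask = torch.ones(1, 5, 1, 4)   # (batch, seq_len, num_images, num_tiles)
    cross_attention_mask = torch.cat(
        [cross_attention_mask, cross_attention_mask[:, -1:, ...]], dim=1
    )                                               # after one decoding step
    print(cross_attention_mask.shape)               # torch.Size([1, 6, 1, 4])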
preprocessor_config.json
ADDED
@@ -0,0 +1,38 @@
{
  "auto_map": {
    "AutoFeatureExtractor": "audio_processing_mllama.MllamaAudioFeatureExtractor",
    "AutoProcessor": "processing_mllama.MllamaProcessor"
  },
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_pad": true,
  "do_rescale": true,
  "do_resize": true,
  "feature_extractor_type": "MllamaAudioFeatureExtractor",
  "feature_size": 80,
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_processor_type": "MllamaImageProcessor",
  "image_std": [
    0.26862954,
    0.26130258,
    0.27577711
  ],
  "max_image_tiles": 4,
  "num_mel_bins": 80,
  "padding_side": "right",
  "padding_value": 0.0,
  "processor_class": "MllamaProcessor",
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "return_attention_mask": true,
  "sampling_rate": 16000,
  "size": {
    "height": 560,
    "width": 560
  },
  "stride": 2
}
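The `auto_map` entries above are what let `AutoFeatureExtractor` and `AutoProcessor` resolve the custom classes shipped in this repository. A hedged loading sketch (my addition; `<repo_id>` is a placeholder for this model repository):

    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("<repo_id>", trust_remote_code=True)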
processing_mllama.py
ADDED
@@ -0,0 +1,377 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Processor class for Mllama."""

from typing import List, Optional, Union

import numpy as np
import torch
import transformers
from transformers import AutoProcessor
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack, AudioKwargs
from transformers.tokenization_utils_base import (
    PreTokenizedInput,
    TextInput,
    AudioInput,
)

# TODO: Can we do it that way or its better include as "Copied from ..."
from transformers.models.mllama.image_processing_mllama import make_list_of_images
from .audio_processing_mllama import build_audio_tokens


class MllamaImagesKwargs(ImagesKwargs, total=False):
    max_image_tiles: Optional[int]

class MllamaProcessorKwargs(ProcessingKwargs, total=False):
    images_kwargs: MllamaImagesKwargs

    _defaults = {
        "image_kwargs": {
            "max_image_tiles": 4,
        },
    }


def get_cross_attention_token_mask(input_ids: List[int], image_token_id: int) -> List[List[int]]:
    """
    Generate a cross-attention token mask for image tokens in the input sequence.

    This function identifies the positions of image tokens in the input sequence and creates
    a mask that defines which subsequent tokens each image token should attend to.

    Args:
        input_ids (List[int]): A list of token ids representing the input sequence.
        image_token_id (int): The id of the token used to represent images in the sequence.

    Returns:
        List[List[int]]: A list of [start, end] pairs, where each pair represents the range
        of tokens an image token should attend to.

    Notes:
        - If no image tokens are present, an empty list is returned.
        - For a single image token, it attends to all subsequent tokens until the end of the sequence.
        - For multiple image tokens, each attends to tokens up to the next image token or the end of the sequence.
        - Consecutive image tokens are treated as a group and attend to all subsequent tokens together.
    """

    image_token_locations = [i for i, token in enumerate(input_ids) if token == image_token_id]

    if len(image_token_locations) == 0:
        return []

    # only one image present, unmask until end of sequence
    if len(image_token_locations) == 1:
        return [[image_token_locations[0], -1]]

    vision_masks = [[loc1, loc2] for loc1, loc2 in zip(image_token_locations[:-1], image_token_locations[1:])]

    # last image will attend to all subsequent text
    vision_masks.append([image_token_locations[-1], len(input_ids)])

    # if there are two or more consecutive vision tokens,
    # they should all attend to all subsequent
    # text present
    last_mask_end = vision_masks[-1][1]
    for vision_mask in vision_masks[::-1]:
        if vision_mask[0] == vision_mask[1] - 1:
            vision_mask[1] = last_mask_end
        last_mask_end = vision_mask[1]

    return vision_masks


def convert_sparse_cross_attention_mask_to_dense(
    cross_attention_token_mask: List[List[List[int]]],
    num_tiles: List[List[int]],
    max_num_tiles: int,
    length: int,
) -> np.ndarray:
    """
    Convert the cross attention mask indices to a cross attention mask 4D array.

    This function takes a sparse representation of cross attention masks and converts it to a dense 4D numpy array.
    The sparse representation is a nested list structure that defines attention ranges for each image in each batch item.

    Args:
        cross_attention_token_mask (List[List[List[int]]]): A nested list structure where:
            - The outer list represents the batch dimension.
            - The middle list represents different images within each batch item.
            - The inner list contains pairs of integers [start, end] representing token ranges for each image.
        num_tiles (List[List[int]]): A nested list structure specifying the number of tiles for each image in each batch item.
        max_num_tiles (int): The maximum possible number of tiles.
        length (int): The total sequence length of the input.

    Returns:
        np.ndarray: A 4D numpy array of shape (batch_size, length, max_num_images, max_num_tiles)
            The array contains `1` where attention is allowed and `0` where it is not.

    Note:
        - Special handling is done for cases where the end token is -1, which is interpreted as attending to the end of the sequence.
    """

    batch_size = len(cross_attention_token_mask)
    max_num_images = max([len(masks) for masks in cross_attention_token_mask])

    cross_attention_mask = np.zeros(
        shape=(batch_size, length, max_num_images, max_num_tiles),
        dtype=np.int64,
    )

    for sample_idx, (sample_masks, sample_num_tiles) in enumerate(zip(cross_attention_token_mask, num_tiles)):
        for mask_idx, (locations, mask_num_tiles) in enumerate(zip(sample_masks, sample_num_tiles)):
            if len(locations) == 2:
                start, end = locations
                end = min(end, length)
                if end == -1:
                    end = length
                cross_attention_mask[sample_idx, start:end, mask_idx, :mask_num_tiles] = 1
    return cross_attention_mask


def build_string_from_input(prompt: str, bos_token: str, image_token: str) -> str:
    """
    Builds a string from the input prompt by adding `bos_token` if not already present.

    Args:
        prompt (`str`):
            The input prompt string.
        bos_token (`str`):
            The beginning of sentence token to be added.
        image_token (`str`):
            The image token used to identify the start of an image sequence.

    Returns:
        str: The modified prompt string with the `bos_token` added if necessary.

    Examples:
        >>> build_string_from_input("Hello world", "<begin_of_text>", "<|image|>")
        '<begin_of_text>Hello world'

        >>> build_string_from_input("<|image|>Hello world", "<begin_of_text>", "<|image|>")
        '<|image|><begin_of_text>Hello world'

        >>> build_string_from_input("<begin_of_text>Hello world", "<begin_of_text>", "<|image|>")
        '<begin_of_text>Hello world'
    """

    if bos_token in prompt:
        return prompt

    num_image_tokens_on_start = 0
    while prompt.startswith(image_token):
        prompt = prompt[len(image_token) :]
        num_image_tokens_on_start += 1

    return f"{image_token * num_image_tokens_on_start}{bos_token}{prompt}"


class MllamaProcessor(ProcessorMixin):
    r"""
    Constructs a Mllama processor which wraps [`MllamaImageProcessor`] and
    [`PretrainedTokenizerFast`] into a single processor that inherits both the image processor and
    tokenizer functionalities. See the [`~MllamaProcessor.__call__`] and [`~OwlViTProcessor.decode`] for more
    information.
    The preferred way of passing kwargs is as a dictionary per modality, see usage example below.
    ```python
    from transformers import MllamaProcessor
    from PIL import Image

    processor = MllamaProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision")

    processor(
        images=your_pil_image,
        text=["<|image|>If I had to write a haiku for this one"],
        images_kwargs = {"size": {"height": 448, "width": 448}},
        text_kwargs = {"padding": "right"},
        common_kwargs = {"return_tensors": "pt"},
    )
    ```

    Args:
        image_processor ([`MllamaImageProcessor`]):
            The image processor is a required input.
        tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`]):
            The tokenizer is a required input.

    """

    attributes = ["image_processor", "audio_processor", "tokenizer"]
    image_processor_class = "MllamaImageProcessor"
    audio_processor_class = "MllamaAudioFeatureExtractor"
    tokenizer_class = "PreTrainedTokenizerFast"

    def __init__(self, image_processor, audio_processor, tokenizer):
        self.image_token = "<|image|>"
        self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
        self.audio_token = "<|audio|>"
        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
        self.python_token = "<|python_tag|>"
        self.python_token_id = tokenizer.convert_tokens_to_ids(self.python_token)
        self.bos_token = tokenizer.bos_token
        self.chat_template = tokenizer.chat_template
        super().__init__(image_processor, audio_processor, tokenizer)
        self.tokenizer.add_tokens([f"<|audio_{i}|>" for i in range(1, 50)])


    def __call__(
        self,
        images: Optional[ImageInput] = None,
        text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
        audio: Optional[Union[AudioInput, List[AudioInput]]] = None,
        videos=None,
        **kwargs: Unpack[MllamaProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare text(s) and image(s) to be fed as input to the model. This method forwards the `text`
        arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the image(s), this method forwards the `images` arguments to
        MllamaImageProcessor's [`~MllamaImageProcessor.__call__`] if `images` is not `None`. Please refer
        to the docstring of the above two methods for more information.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                    - `'tf'`: Return TensorFlow `tf.constant` objects.
                    - `'pt'`: Return PyTorch `torch.Tensor` objects.
                    - `'np'`: Return NumPy `np.ndarray` objects.
                    - `'jax'`: Return JAX `jnp.ndarray` objects.
        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **audio_features** -- Audio features extracted using SeamlessM4TFeatureExtractor. Returned when `audio` is not `None`.
            TODO: add aspect_ratio_ids and aspect_ratio_mask and cross_attention_mask
        """
        if text is None:
            raise ValueError("You must specify text.")

        output_kwargs = self._merge_kwargs(
            MllamaProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        text_kwargs = output_kwargs["text_kwargs"]
        images_kwargs = output_kwargs["images_kwargs"]
        common_kwargs = output_kwargs["common_kwargs"]

        data = {}

        if audio is not None:
            audio_features = self.audio_processor(audio)
            data.update(audio_features)

        if isinstance(text, str):
            text = [text]
        elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
            raise ValueError("Invalid input text. Please provide a string, or a list of strings")
        n_images_in_text = [t.count(self.image_token) for t in text]
        text = [build_string_from_input(text_item, self.bos_token, self.image_token) for text_item in text]
        _ = text_kwargs.pop("padding_side", None)  # hack until padding-side is an accepted kwarg by tokenizers

        if audio is not None:
            text = build_audio_tokens(text, audio_features, self.audio_token)

        encoding = self.tokenizer(text, add_special_tokens=False, **text_kwargs)
        if audio is not None:
            beg_audio_id = self.tokenizer.convert_tokens_to_ids("<|audio_1|>")
            idx = torch.where(encoding['input_ids'] >= beg_audio_id)
            encoding['input_ids'][idx] = beg_audio_id - encoding['input_ids'][idx] - 1
        data.update(encoding)

        n_images_in_images = [0]
        if images is not None:
            images = make_list_of_images(images)
            n_images_in_images = [len(sample) for sample in images]

        if text is not None:
            if any(batch_img == 0 for batch_img in n_images_in_text) and not all(
                batch_img == 0 for batch_img in n_images_in_text
            ):
                raise ValueError(
                    "If a batch of text is provided, there should be either no images or at least one image per sample"
                )
            if sum(n_images_in_images) != sum(n_images_in_text):
                if images is None:
                    raise ValueError("No image were provided, but there are image tokens in the prompt")
                else:
                    raise ValueError(
                        f"The number of image token ({sum(n_images_in_text)}) should be the same as in the number of provided images ({sum(n_images_in_images)})"
                    )

        if images is not None:
            image_features = self.image_processor(images, **images_kwargs)
            num_tiles = image_features.pop("num_tiles")
            data.update(image_features)

        # Create cross attention mask
        if images is not None and text is not None:
            cross_attention_token_mask = [
                get_cross_attention_token_mask(token_ids, self.image_token_id) for token_ids in encoding["input_ids"]
            ]
            cross_attention_mask = convert_sparse_cross_attention_mask_to_dense(
                cross_attention_token_mask,
                num_tiles=num_tiles,
                max_num_tiles=self.image_processor.max_image_tiles,
                length=max(len(input_ids) for input_ids in encoding["input_ids"]),
            )
            data["cross_attention_mask"] = cross_attention_mask

        return_tensors = common_kwargs.pop("return_tensors", None)
        batch_feature = BatchFeature(data=data, tensor_type=return_tensors)

        return batch_feature

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        audio_processor_input_names = self.audio_processor.model_input_names
        return list(tokenizer_input_names +
                    image_processor_input_names +
                    ["cross_attention_mask"] +
                    audio_processor_input_names)

AutoProcessor.register("MllamaProcessor", MllamaProcessor)
transformers.MllamaProcessor = MllamaProcessor
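A hedged end-to-end usage sketch (my addition, not part of the commit): preparing a mixed text/image/audio batch. `<repo_id>` is a placeholder and `your_pil_image` stands in for a real PIL image, mirroring the docstring above; audio clips are raw 16 kHz waveforms, one list of clips per sample.

    import numpy as np
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("<repo_id>", trust_remote_code=True)
    inputs = processor(
        text=["<|image|><|audio|>Describe what you see and hear."],
        images=[[your_pil_image]],                    # one image per sample
        audio=[[np.zeros(16000, dtype=np.float32)]],  # one 1-second clip per sample
        return_tensors="pt",
    )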