# coding=utf-8
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

import transformers
from transformers import MllamaPreTrainedModel, MllamaVisionModel, MllamaForCausalLM, AutoModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import logging
from transformers.models.mllama.modeling_mllama import _prepare_cross_attention_mask

from .configuration_llama3 import Llama3Config
from .mllama_audio_model import Llama3Embedding


logger = logging.get_logger(__name__)


class Llama3ForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
    """Mllama-style conditional generation model with a custom input embedding layer.

    The model is assembled from three sub-modules built from `config`:

    * `vision_model` (`MllamaVisionModel`) encodes `pixel_values`,
    * `multi_modal_projector` (`nn.Linear`) maps the vision output to the text hidden size,
    * `language_model` (`MllamaForCausalLM`) consumes `inputs_embeds` and the projected
      vision states through cross-attention.

    Unlike the text model's own embedding lookup, token embeddings are produced here by
    `embed_tokens` (`Llama3Embedding`), which is called with both `input_ids` and
    `audio_features`.
    """

    config_class = Llama3Config
    base_model_prefix = "model"
    _supports_quantized_cache = False  # quant cache not supported in encoder-decoder setting

    def __init__(self, config: Llama3Config):
        """Build the vision tower, language model, embedding layer and projector from `config`.

        Args:
            config: Composite configuration exposing `text_config` and `vision_config`.
        """
        super().__init__(config)
        self.vocab_size = config.text_config.vocab_size
        self.hidden_size = config.text_config.hidden_size
        self.max_num_tiles = config.vision_config.max_num_tiles
        self.vision_output_dim = config.vision_config.vision_output_dim
        # -1 is used as the sentinel when the config does not define a pad token.
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1

        self.vision_model = MllamaVisionModel._from_config(config.vision_config)
        self.language_model = MllamaForCausalLM._from_config(config.text_config)
        self.embed_tokens = Llama3Embedding(config)
        self.multi_modal_projector = nn.Linear(
            config.vision_config.vision_output_dim,
            config.text_config.hidden_size,
            bias=True,
        )
        self.post_init()

    def get_input_embeddings(self):
        """Return the text embedding module held by `embed_tokens`."""
        return self.embed_tokens.text_embeddings

    def set_input_embeddings(self, value):
        """Replace the text embedding module held by `embed_tokens` with `value`."""
        self.embed_tokens.text_embeddings = value

    def get_output_embeddings(self):
        """Return the output embeddings of the wrapped language model."""
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        """Set the output embeddings of the wrapped language model."""
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        """Set the decoder of the wrapped language model."""
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        """Return the decoder of the wrapped language model."""
        return self.language_model.get_decoder()

    def tie_weights(self):
        """Delegate weight tying to the wrapped language model."""
        return self.language_model.tie_weights()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        audio_features: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        aspect_ratio_mask: Optional[torch.Tensor] = None,
        aspect_ratio_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_mask: Optional[torch.Tensor] = None,
        cross_attention_states: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Run the vision tower (when `pixel_values` is given), embed the tokens and call the language model.

        Args:
            audio_features (`torch.FloatTensor`, *optional*):
                Passed together with `input_ids` to `self.embed_tokens` when `inputs_embeds` is not supplied.
                It is not used when `inputs_embeds` is supplied.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

        Returns:
            The output of `self.language_model(...)`, returned unchanged.

        Raises:
            ValueError: If not exactly one of `input_ids` / `inputs_embeds` is given, if `pixel_values` is combined
                with `inputs_embeds` or with `cross_attention_states`, or if `pixel_values` is given without
                `aspect_ratio_ids`.

        Example:

        The snippet below shows the upstream `MllamaForConditionalGeneration` usage (image + text); it does not
        exercise `audio_features`.

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, MllamaForConditionalGeneration

        >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
        >>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint)
        >>> processor = AutoProcessor.from_pretrained(checkpoint)

        >>> prompt = "<|image|>If I had to write a haiku for this one"
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")

        >>> # Generate
        >>> output = model.generate(**inputs, max_new_tokens=15)

        >>> prompt_len = inputs.input_ids.shape[-1]
        >>> generated_ids = output[:, prompt_len:]
        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        >>> print(generated_text)
        [', it would be:.\\nA stop sign in Chinatown.\\n']
        ```
        """
        # Fall back to the text config for any output flag the caller left unset.
        output_attentions = output_attentions if output_attentions is not None else self.config.text_config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.text_config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.text_config.use_return_dict

        # XOR is true when both are missing or both are present.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if pixel_values is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
            )

        if pixel_values is not None and cross_attention_states is not None:
            raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously")

        if pixel_values is not None:
            if aspect_ratio_ids is None:
                raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided")
            # get vision tokens from vision model
            vision_outputs = self.vision_model(
                pixel_values=pixel_values,
                aspect_ratio_ids=aspect_ratio_ids,
                aspect_ratio_mask=aspect_ratio_mask,
                output_hidden_states=output_hidden_states,
                output_attentions=output_attentions,
                return_dict=return_dict,
            )
            cross_attention_states = vision_outputs[0]
            # Project to the text hidden size, then fold all leading dims into one.
            cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape(
                -1, cross_attention_states.shape[-2], self.hidden_size
            )

        if cross_attention_mask is not None:
            cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask(
                cross_attention_mask,
                num_vision_tokens=self.vision_model.num_patches,
                dtype=self.dtype,
            )
        else:
            full_text_row_masked_out_mask = None

        # Keep only the mask rows selected by `cache_position`.
        if cross_attention_mask is not None and cache_position is not None:
            cross_attention_mask = cross_attention_mask[:, :, cache_position]
            full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]

        # Embeddings are computed here, so the language model is always called with `input_ids=None`.
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids=input_ids, audio_features=audio_features)

        outputs = self.language_model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            inputs_embeds=inputs_embeds,
            labels=labels,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
        )

        return outputs

    def prepare_inputs_for_generation(
        self,
        input_ids=None,
        audio_features=None,
        inputs_embeds=None,
        attention_mask=None,
        position_ids=None,
        pixel_values=None,
        aspect_ratio_ids=None,
        aspect_ratio_mask=None,
        cross_attention_mask=None,
        past_key_values=None,
        use_cache=False,
        cache_position=None,
        num_logits_to_keep=None,
        **kwargs,
    ):
        """Assemble the keyword arguments passed to `forward` for one generation step.

        Image inputs (`pixel_values`, `aspect_ratio_ids`, `aspect_ratio_mask`) are only included when
        `cache_position[0] == 0`. `cache_position` is indexed unconditionally, so it must not be `None`.

        Returns:
            dict: The model inputs for this step.
        """
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if past_key_values is not None:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]

        # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
                # `mode="reduce-overhead"`, as otherwise the input `position_ids` would have various stride
                # during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
                # batch size = 1 case, `position_ids` is already contiguous but with varying stride
                # which retriggers a capture.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            # The clone here is for the same reason as for `position_ids`.
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

        if num_logits_to_keep is not None:
            model_inputs["num_logits_to_keep"] = num_logits_to_keep

        # NOTE(review): `audio_features` is forwarded on every step, unlike the image inputs below which
        # are only forwarded on the pre-fill step -- confirm `Llama3Embedding` copes with audio features
        # arriving alongside already-sliced `input_ids`.
        model_inputs.update(
            {
                "audio_features": audio_features,
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "cross_attention_mask": cross_attention_mask,
            }
        )

        # If we're in pre-fill or cacheless decoding step, then we need pixel_values and aspect ratios
        # to compute image hidden states, otherwise they are cached within each cross attn layer
        if cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["aspect_ratio_ids"] = aspect_ratio_ids
            model_inputs["aspect_ratio_mask"] = aspect_ratio_mask

        return model_inputs

    def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
        """Update `model_kwargs` after a generation step and extend the cross-attention mask.

        The parent implementation is called first. If a `cross_attention_mask` was present before that
        call, it is replaced by the previous mask with its last entry along dim 1 appended once more.

        Returns:
            dict: The updated `model_kwargs`.
        """
        # Read the mask before the parent call, which returns the dict that is updated below.
        cross_attention_mask_prev = model_kwargs.get("cross_attention_mask", None)
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs=outputs,
            model_kwargs=model_kwargs,
            is_encoder_decoder=is_encoder_decoder,
            **kwargs,
        )

        # add cross-attn mask for new token
        if cross_attention_mask_prev is not None:
            model_kwargs["cross_attention_mask"] = torch.cat(
                [cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], dim=1
            )
        return model_kwargs


# Make the model resolvable through `AutoModel` for `Llama3Config`, and expose the class on the
# `transformers` namespace.
AutoModel.register(Llama3Config, Llama3ForConditionalGeneration)
transformers.Llama3ForConditionalGeneration = Llama3ForConditionalGeneration