from typing import Optional

import torch
from torch import nn

from transformers import Wav2Vec2Model, MllamaPreTrainedModel

from .configuration_llama3 import Llama3Config


class Llama3Embedding(MllamaPreTrainedModel):
    """Input embedding layer that fuses text token embeddings with Wav2Vec2 audio embeddings.

    Audio clips are referenced in ``input_ids`` through negative placeholder ids:
    -1 marks the positions reserved for the first clip, -2 the second, and so on.
    Each clip's positions are overwritten with a learned start-of-audio vector,
    the encoded audio frames, and a learned end-of-audio vector.
    """

    config_class = Llama3Config
    base_model_prefix = "audio_model"

    def __init__(self, config: Llama3Config):
        super().__init__(config)
        # The reshape in `forward` assumes the audio encoder's output dimension divides
        # evenly into the text hidden size, e.g.
        # assert config.audio_config.output_hidden_size * 2 == config.text_config.hidden_size
        self.text_embeddings = nn.Embedding(
            config.text_config.vocab_size,
            config.text_config.hidden_size,
            config.text_config.pad_token_id,
        )
        # The Wav2Vec2 adapter is required so the audio frames match what the reshape
        # in `forward` expects.
        assert config.audio_config.add_adapter
        self.audio_model = Wav2Vec2Model(config.audio_config)
        # Learned boundary embeddings inserted before and after each audio clip.
        self.start_of_audio = nn.Parameter(torch.zeros((1, config.text_config.hidden_size)))
        self.end_of_audio = nn.Parameter(torch.zeros((1, config.text_config.hidden_size)))
        self.text_config = config.text_config

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        audio_features: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Negative ids are audio placeholders; clamp them to 0 so the embedding lookup
        # succeeds. Those positions are overwritten with audio embeddings below.
        input_embeddings = self.text_embeddings(input_ids.clamp_min(0).detach())
        if audio_features is None:
            return input_embeddings

        bs, max_num_clips, num_samples = audio_features.shape
        # Encode all clips in one batch, then regroup the encoder output per clip into
        # vectors of the text hidden size.
        audio_embeddings = self.audio_model(
            input_values=audio_features.view((bs * max_num_clips, num_samples))
        )["last_hidden_state"]
        audio_embeddings = audio_embeddings.view(
            (bs, max_num_clips, -1, self.start_of_audio.shape[-1])
        )

        # Scatter each clip's embeddings, framed by the start/end-of-audio vectors,
        # into the positions reserved by its placeholder id.
        for i in range(bs):
            for j in range(max_num_clips):
                audio_id = -1 - j
                if torch.any(input_ids[i] == audio_id):
                    positions = torch.nonzero(input_ids[i] == audio_id, as_tuple=True)
                    input_embeddings[i] = input_embeddings[i].index_put(
                        positions,
                        torch.cat([self.start_of_audio, audio_embeddings[i, j, :, :], self.end_of_audio]),
                        accumulate=False,
                    )
        return input_embeddings
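

# Usage sketch (not part of the original module): a minimal illustration of how the
# forward pass is intended to be called, assuming the default Llama3Config provides
# valid text/audio sub-configs and satisfies the adapter assertion above. The clip
# length, token ids, and shapes below are hypothetical; run as a module
# (python -m <package>.<module>) so the relative import resolves.
if __name__ == "__main__":
    config = Llama3Config()  # assumption: defaults are usable; substitute a real config in practice
    model = Llama3Embedding(config).eval()

    batch_size, num_clips, num_samples = 1, 1, 16000  # one second of 16 kHz audio (hypothetical)
    audio_features = torch.randn(batch_size, num_clips, num_samples)

    # Determine how many sequence positions the clip occupies after encoding and
    # regrouping, plus two slots for the start/end-of-audio embeddings.
    with torch.no_grad():
        frames = model.audio_model(
            input_values=audio_features.view(batch_size * num_clips, num_samples)
        )["last_hidden_state"]
    num_slots = frames.view(batch_size, num_clips, -1, config.text_config.hidden_size).shape[2] + 2

    # Placeholder id -1 reserves the positions for the first (and only) clip.
    text_ids = torch.tensor([[1, 2, 3]])  # arbitrary text token ids
    input_ids = torch.cat([text_ids, torch.full((1, num_slots), -1, dtype=torch.long)], dim=1)

    embeddings = model(input_ids=input_ids, audio_features=audio_features)
    print(embeddings.shape)  # (batch_size, sequence_length, text hidden_size)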