""" Some of the code is adapted from: 1. ByteDance's SALMONN (https://github.com/bytedance/SALMONN) 2. Llama-Omni (https://github.com/ictnlp/LLaMA-Omni/) Please follow the copyright of the original projects. """ # ---------------------------------------------------- # import inspect import copy import torch import torch.nn.functional as F from torch import Tensor, device, nn import numpy as np from transformers import ( WhisperFeatureExtractor, WhisperConfig, WhisperModel, PreTrainedModel, AutoTokenizer, AutoModelForCausalLM, ) from transformers.cache_utils import Cache, StaticCache from transformers.generation.utils import ( GenerationConfig, GenerationMode, LogitsProcessorList, StoppingCriteriaList, GenerateOutput, GenerationMixin, GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput, GenerateNonBeamOutput, is_deepspeed_zero3_enabled, is_torchdynamo_compiling, NEED_SETUP_CACHE_CLASSES_MAPPING, QUANT_BACKEND_CLASSES_MAPPING, is_hqq_available, QuantizedCacheConfig, is_quanto_available, DynamicCache, EncoderDecoderCache, ) from transformers.modeling_outputs import CausalLMOutputWithPast from .configuration_typhoon2audio import Typhoon2AudioConfig, BEATsConfig # ---------------------------------------------------- # # QFormer: https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert import math import warnings from typing import Optional, Tuple, Dict, Union, Callable, List import torch.utils.checkpoint from torch.nn import CrossEntropyLoss from transformers.activations import ACT2FN from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithCrossAttentions, MaskedLMOutput, ) from transformers.modeling_utils import ( apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer, ) from transformers.models.bert.configuration_bert import BertConfig # ---------------------------------------------------------- # # BEATs: 
# https://github.com/microsoft/unilm/tree/master/beats
from torch.nn import LayerNorm, Parameter
import torch.distributed as distributed
import torchaudio.compliance.kaldi as ta_kaldi
import logging

try:
    from einops import rearrange, repeat
except ImportError:
    pass

logger = logging.getLogger(__name__)

# ---------------------------------------------------------- #
# Speech Decoder
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Unit Vocoder
from fairseq.models import BaseFairseqModel
from fairseq.models.text_to_speech.codehifigan import CodeGenerator as CodeHiFiGANModel

# ---------------------------------------------------------- #
import soundfile as sf


class GenerationWithCTC(GenerationMixin):
    """`GenerationMixin` variant that can stream CTC-decoded speech units.

    Only sampling / greedy decoding is supported. When `streaming_unit_gen=True`
    the host class must provide `self.speech_generator` (with a `predict`
    method) and `self.model.config.unit_vocab_size`, which is used as the CTC
    blank id.
    """

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[
            Callable[[int, torch.Tensor], List[int]]
        ] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        streamer_unit: Optional["BaseStreamer"] = None,
        streaming_unit_gen=False,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate tokens, optionally streaming speech units alongside text.

        Args:
            inputs: prompt tensor forwarded to `_prepare_model_inputs`.
            generation_config / logits_processor / stopping_criteria /
            prefix_allowed_tokens_fn / synced_gpus / assistant_model /
            negative_prompt_ids / negative_prompt_attention_mask:
                same meaning as in `GenerationMixin.generate`.
            streamer: receives the prompt ids, then every new text token.
            streamer_unit: receives every newly decoded speech unit
                (only used when `streaming_unit_gen=True`).
            streaming_unit_gen: if True, decode with `_sample_streaming_unit`
                instead of `_sample`.
            **kwargs: extra generation-config / model kwargs; a `tokenizer`
                entry is popped and only used for stopping criteria and
                token healing.

        Returns:
            Whatever `_sample` / `_sample_streaming_unit` return: a
            `GenerateOutput` when `return_dict_in_generate` is set, otherwise
            the generated id tensor.

        Raises:
            ValueError: a streamer is combined with `num_beams > 1`, or the
                cache configuration is inconsistent.
            NotImplementedError: the generation mode is neither sampling nor
                greedy search.
        """
        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
        self._validate_model_class()

        # Pull this out first, we only use it for stopping criteria
        tokenizer = kwargs.pop("tokenizer", None)
        generation_config, model_kwargs = self._prepare_generation_config(
            generation_config, **kwargs
        )
        self._validate_model_kwargs(model_kwargs.copy())
        self._validate_assistant(assistant_model)

        # 2. Set generation parameters if not already defined
        if synced_gpus is None:
            # `dist` is not imported in this module; `torch.distributed` is
            # imported as `distributed`, so use that name (avoids a NameError
            # under DeepSpeed ZeRO-3).
            synced_gpus = bool(
                is_deepspeed_zero3_enabled() and distributed.get_world_size() > 1
            )

        logits_processor = (
            logits_processor if logits_processor is not None else LogitsProcessorList()
        )
        stopping_criteria = (
            stopping_criteria
            if stopping_criteria is not None
            else StoppingCriteriaList()
        )

        accepts_attention_mask = "attention_mask" in set(
            inspect.signature(self.forward).parameters.keys()
        )
        requires_attention_mask = "encoder_outputs" not in model_kwargs
        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

        # 3. Define model inputs
        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
            inputs, generation_config.bos_token_id, model_kwargs
        )
        batch_size = inputs_tensor.shape[0]

        device = inputs_tensor.device
        self._prepare_special_tokens(
            generation_config, kwargs_has_attention_mask, device=device
        )

        # decoder-only models must use left-padding for batched generation.
        if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
            # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
            # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
            if (
                generation_config._pad_token_tensor is not None
                and batch_size > 1
                and len(inputs_tensor.shape) == 2
                and torch.sum(
                    inputs_tensor[:, -1] == generation_config._pad_token_tensor
                )
                > 0
            ):
                logger.warning(
                    "A decoder-only architecture is being used, but right-padding was detected! For correct "
                    "generation results, please set `padding_side='left'` when initializing the tokenizer."
                )

        # 4. Define other model kwargs
        # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
        # generating the first new token or not, and we only want to use the embeddings for the first new token)
        if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
            model_kwargs["use_cache"] = True
        else:
            model_kwargs["use_cache"] = generation_config.use_cache

        if (
            not kwargs_has_attention_mask
            and requires_attention_mask
            and accepts_attention_mask
        ):
            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                inputs_tensor,
                generation_config._pad_token_tensor,
                generation_config._eos_token_tensor,
            )

        if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
            # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, model_kwargs, model_input_name, generation_config
            )

        # 5. Prepare `input_ids` which will be used for auto-regressive generation
        if self.config.is_encoder_decoder:
            input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
                batch_size=batch_size,
                model_input_name=model_input_name,
                model_kwargs=model_kwargs,
                decoder_start_token_id=generation_config._decoder_start_token_tensor,
                device=inputs_tensor.device,
            )
        else:
            # pm574: decoder-only models need explicit `input_ids`, either as
            # `inputs` or as an `input_ids` kwarg; `pop` raises KeyError otherwise.
            input_ids = (
                inputs_tensor
                if model_input_name == "input_ids"
                else model_kwargs.pop("input_ids")
            )

        if generation_config.token_healing:
            input_ids = self.heal_tokens(input_ids, tokenizer)

        if streamer is not None:
            streamer.put(input_ids.cpu())

        # 6. Prepare `max_length` depending on other stopping criteria.
        input_ids_length = input_ids.shape[-1]
        has_default_max_length = (
            kwargs.get("max_length") is None
            and generation_config.max_length is not None
        )
        has_default_min_length = (
            kwargs.get("min_length") is None
            and generation_config.min_length is not None
        )
        generation_config = self._prepare_generated_length(
            generation_config=generation_config,
            has_default_max_length=has_default_max_length,
            has_default_min_length=has_default_min_length,
            model_input_name=model_input_name,
            inputs_tensor=inputs_tensor,
            input_ids_length=input_ids_length,
        )

        if "mamba" in self.__class__.__name__.lower():
            cache_name = "cache_params"
        else:
            cache_name = "past_key_values"

        if generation_config.cache_implementation is not None and (
            model_kwargs.get(cache_name) is not None
        ):
            raise ValueError(
                f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a "
                "Cache object) is unsupported. Please use only one of the two."
            )
        elif generation_config.cache_implementation is not None:
            if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
                if (
                    generation_config.cache_implementation == "static"
                    and not self._supports_static_cache
                ):
                    raise ValueError(
                        "This model does not support `cache_implementation='static'`. Please check the following "
                        "issue: https://github.com/huggingface/transformers/issues/28981"
                    )
                model_kwargs[cache_name] = self._get_cache(
                    generation_config.cache_implementation,
                    getattr(generation_config, "num_beams", 1) * batch_size,
                    generation_config.max_length,
                    model_kwargs,
                )
            elif generation_config.cache_implementation == "quantized":
                if not self._supports_quantized_cache:
                    raise ValueError(
                        "This model does not support the quantized cache. If you want your model to support quantized "
                        "cache, please open an issue."
                    )

                cache_config = (
                    generation_config.cache_config
                    if generation_config.cache_config is not None
                    else QuantizedCacheConfig()
                )
                cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]

                if cache_config.backend == "quanto" and not is_quanto_available():
                    raise ImportError(
                        "You need to install `quanto` in order to use KV cache quantization with quanto backend. "
                        "Please install it via with `pip install quanto`"
                    )
                elif cache_config.backend == "HQQ" and not is_hqq_available():
                    raise ImportError(
                        "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
                        "Please install it via with `pip install hqq`"
                    )

                model_kwargs[cache_name] = cache_class(cache_config)
        # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
        # keeps copying the cache thus using much more memory
        elif (
            generation_config.cache_implementation is None
            and self._supports_default_dynamic_cache()
        ):
            past = model_kwargs.get(cache_name, None)
            requires_cross_attention_cache = (
                self.config.is_encoder_decoder
                or model_kwargs.get("encoder_outputs") is not None
            )
            if past is None:
                model_kwargs[cache_name] = (
                    DynamicCache()
                    if not requires_cross_attention_cache
                    else EncoderDecoderCache(DynamicCache(), DynamicCache())
                )
            elif isinstance(past, tuple):
                model_kwargs[cache_name] = (
                    DynamicCache.from_legacy_cache(past)
                    if not requires_cross_attention_cache
                    else EncoderDecoderCache.from_legacy_cache(past)
                )

        self._validate_generated_length(
            generation_config, input_ids_length, has_default_max_length
        )

        # 7. determine generation mode
        generation_mode = generation_config.get_generation_mode(assistant_model)

        if (streamer is not None or streamer_unit is not None) and (
            generation_config.num_beams > 1
        ):
            raise ValueError(
                "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
            )

        if self.device.type != input_ids.device.type:
            warnings.warn(
                "You are calling .generate() with the `input_ids` being on a device type different"
                f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
                f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
                " Please make sure that you have put `input_ids` to the"
                f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
                " running `.generate()`.",
                UserWarning,
            )

        # 8. prepare distribution pre_processing samplers
        prepared_logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_length,
            encoder_input_ids=inputs_tensor,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
            device=inputs_tensor.device,
            model_kwargs=model_kwargs,
            negative_prompt_ids=negative_prompt_ids,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
        )

        # 9. prepare stopping criteria
        prepared_stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config,
            stopping_criteria=stopping_criteria,
            tokenizer=tokenizer,
            **kwargs,
        )

        # 10. go into different generation modes (only sampling / greedy are implemented)
        if generation_mode not in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
            raise NotImplementedError(
                f"Generation mode {generation_mode} is not supported; "
                "only sampling and greedy search are implemented."
            )

        # 11. prepare logits warper
        prepared_logits_warper = (
            self._get_logits_warper(generation_config, device=input_ids.device)
            if generation_config.do_sample
            else None
        )

        # 12. expand input_ids with `num_return_sequences` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_return_sequences,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )

        # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
        if streaming_unit_gen:
            return self._sample_streaming_unit(
                input_ids,
                logits_processor=prepared_logits_processor,
                logits_warper=prepared_logits_warper,
                stopping_criteria=prepared_stopping_criteria,
                generation_config=generation_config,
                synced_gpus=synced_gpus,
                streamer=streamer,
                streamer_unit=streamer_unit,
                **model_kwargs,
            )
        return self._sample(
            input_ids,
            logits_processor=prepared_logits_processor,
            logits_warper=prepared_logits_warper,
            stopping_criteria=prepared_stopping_criteria,
            generation_config=generation_config,
            synced_gpus=synced_gpus,
            streamer=streamer,
            **model_kwargs,
        )

    def _pack_sample_output(
        self,
        input_ids,
        return_dict_in_generate,
        scores,
        raw_logits,
        decoder_attentions,
        cross_attentions,
        decoder_hidden_states,
        encoder_attentions,
        encoder_hidden_states,
        model_kwargs,
    ):
        """Wrap decoding results shared by `_sample` and `_sample_streaming_unit`.

        Returns `input_ids` unchanged when `return_dict_in_generate` is falsy;
        otherwise a `GenerateEncoderDecoderOutput` or `GenerateDecoderOnlyOutput`
        chosen by `self.config.is_encoder_decoder`.
        """
        if not return_dict_in_generate:
            return input_ids
        past_key_values = model_kwargs.get("past_key_values")
        if self.config.is_encoder_decoder:
            return GenerateEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=past_key_values,
            )
        return GenerateDecoderOnlyOutput(
            sequences=input_ids,
            scores=scores,
            logits=raw_logits,
            attentions=decoder_attentions,
            hidden_states=decoder_hidden_states,
            past_key_values=past_key_values,
        )

    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        streamer: Optional["BaseStreamer"],
        logits_warper: Optional[LogitsProcessorList],
        **model_kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        """Autoregressively extend `input_ids` by sampling (or argmax if `do_sample` is off).

        Each step runs `self(...)`, applies `logits_processor` (and
        `logits_warper` when sampling), appends one token per sequence, and
        stops once every sequence is finished per `stopping_criteria`.
        Finished sequences are padded with `pad_token_id` when an EOS-based
        stopping criterion is present.

        Raises:
            ValueError: `do_sample` is True but `logits_warper` is not a
                `LogitsProcessorList`.
        """
        # init values
        pad_token_id = generation_config._pad_token_tensor
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        has_eos_stopping_criteria = any(
            hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
        )
        do_sample = generation_config.do_sample
        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
            raise ValueError(
                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
                f"{logits_warper})."
            )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = (
            () if (return_dict_in_generate and output_attentions) else None
        )
        cross_attentions = (
            () if (return_dict_in_generate and output_attentions) else None
        )
        decoder_hidden_states = (
            () if (return_dict_in_generate and output_hidden_states) else None
        )

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        encoder_attentions = encoder_hidden_states = None
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = (
                model_kwargs["encoder_outputs"].get("attentions")
                if output_attentions
                else None
            )
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states")
                if output_hidden_states
                else None
            )

        # keep track of which sequences are already finished
        batch_size = input_ids.shape[0]
        this_peer_finished = False
        unfinished_sequences = torch.ones(
            batch_size, dtype=torch.long, device=input_ids.device
        )
        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

        while self._has_unfinished_sequences(
            this_peer_finished, synced_gpus, device=input_ids.device
        ):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # prepare variable output controls (note: some models won't accept all output controls)
            if output_attentions:
                model_inputs["output_attentions"] = output_attentions
            if output_hidden_states:
                model_inputs["output_hidden_states"] = output_hidden_states

            # forward pass to get next token
            outputs = self(**model_inputs, return_dict=True)

            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need

            # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
            # (the clone itself is always small)
            next_token_logits = outputs.logits[:, -1, :].clone()

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            if do_sample:
                next_token_scores = logits_warper(input_ids, next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_logits:
                    raw_logits += (next_token_logits,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,)
                        if self.config.is_encoder_decoder
                        else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)
                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # token selection
            if do_sample:
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_scores, dim=-1)

            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
                    1 - unfinished_sequences
                )

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(
                input_ids, scores
            )
            this_peer_finished = unfinished_sequences.max() == 0

            # This is needed to properly delete outputs.logits which may be very large for first iteration
            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
            del outputs

        if streamer is not None:
            streamer.end()

        return self._pack_sample_output(
            input_ids,
            return_dict_in_generate,
            scores,
            raw_logits,
            decoder_attentions,
            cross_attentions,
            decoder_hidden_states,
            encoder_attentions,
            encoder_hidden_states,
            model_kwargs,
        )

    def _sample_streaming_unit(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        streamer: Optional["BaseStreamer"],
        streamer_unit: Optional["BaseStreamer"],
        logits_warper: Optional[LogitsProcessorList],
        **model_kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        """Same decoding loop as `_sample`, also emitting speech units each step.

        After every text token, the last entry of each step's hidden-state
        tuple (presumably the final decoder layer — confirm) is concatenated.
        The first step contributes only its last position. The sequence is
        passed through `self.speech_generator.predict` and collapsed with
        `self.ctc_postprocess`; units not yet emitted are pushed to
        `streamer_unit`. The `.squeeze(0)` before `predict` presumably assumes
        batch size 1.

        Raises:
            ValueError: `do_sample` is True without a `LogitsProcessorList`
                warper, or per-step hidden states are not being collected
                (they are required for unit prediction).
        """
        # init values
        pad_token_id = generation_config._pad_token_tensor
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        has_eos_stopping_criteria = any(
            hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
        )
        do_sample = generation_config.do_sample
        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
            raise ValueError(
                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
                f"{logits_warper})."
            )
        # Unit prediction reads `decoder_hidden_states` on every step; without
        # these flags it is None and the loop would fail after a forward pass.
        if not (return_dict_in_generate and output_hidden_states):
            raise ValueError(
                "Streaming unit generation needs per-step hidden states: call `generate` with "
                "`return_dict_in_generate=True` and `output_hidden_states=True`."
            )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = (
            () if (return_dict_in_generate and output_attentions) else None
        )
        cross_attentions = (
            () if (return_dict_in_generate and output_attentions) else None
        )
        decoder_hidden_states = ()

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        encoder_attentions = encoder_hidden_states = None
        if self.config.is_encoder_decoder:
            encoder_attentions = (
                model_kwargs["encoder_outputs"].get("attentions")
                if output_attentions
                else None
            )
            encoder_hidden_states = model_kwargs["encoder_outputs"].get(
                "hidden_states"
            )

        # keep track of which sequences are already finished
        batch_size = input_ids.shape[0]
        this_peer_finished = False
        unfinished_sequences = torch.ones(
            batch_size, dtype=torch.long, device=input_ids.device
        )
        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
        # units already pushed to `streamer_unit`
        generated_units = torch.tensor([])

        while self._has_unfinished_sequences(
            this_peer_finished, synced_gpus, device=input_ids.device
        ):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # prepare variable output controls (note: some models won't accept all output controls)
            if output_attentions:
                model_inputs["output_attentions"] = output_attentions
            model_inputs["output_hidden_states"] = output_hidden_states

            # forward pass to get next token
            outputs = self(**model_inputs, return_dict=True)

            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need

            # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
            # (the clone itself is always small)
            next_token_logits = outputs.logits[:, -1, :].clone()

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            if do_sample:
                next_token_scores = logits_warper(input_ids, next_token_scores)

            # Store scores, attentions and hidden_states
            if output_scores:
                scores += (next_token_scores,)
            if output_logits:
                raw_logits += (next_token_logits,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,)
                    if self.config.is_encoder_decoder
                    else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)
            decoder_hidden_states += (
                (outputs.decoder_hidden_states,)
                if self.config.is_encoder_decoder
                else (outputs.hidden_states,)
            )

            # token selection
            if do_sample:
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_scores, dim=-1)

            # speechgen: the first step covers the whole prompt, so keep only
            # its last position; later steps contribute one position each.
            hidden_states = torch.cat(
                [decoder_hidden_states[0][-1][:, -1:, :]]
                + [step[-1] for step in decoder_hidden_states[1:]],
                dim=1,
            )
            ctc_pred = self.speech_generator.predict(hidden_states.squeeze(0))
            cur_units = self.ctc_postprocess(
                ctc_pred, blank=self.model.config.unit_vocab_size
            )

            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
                    1 - unfinished_sequences
                )

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            if streamer_unit is not None:
                # only push units beyond what was already emitted
                for unit in cur_units[len(generated_units):]:
                    streamer_unit.put(unit.unsqueeze(0))
            generated_units = cur_units
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(
                input_ids, scores
            )
            this_peer_finished = unfinished_sequences.max() == 0

            # This is needed to properly delete outputs.logits which may be very large for first iteration
            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
            del outputs

        if streamer is not None:
            streamer.end()

        return self._pack_sample_output(
            input_ids,
            return_dict_in_generate,
            scores,
            raw_logits,
            decoder_attentions,
            cross_attentions,
            decoder_hidden_states,
            encoder_attentions,
            encoder_hidden_states,
            model_kwargs,
        )

    def ctc_postprocess(self, tokens, blank):
        """Collapse a CTC prediction into a unit sequence.

        Squeezes dim 0 of `tokens`, merges consecutive repeated ids, then drops
        every `blank` id.

        Returns:
            A 1-D `torch.Tensor` of the remaining ids (empty if none remain).
        """
        toks = tokens.squeeze(0).tolist()
        # keep an id only when it differs from its predecessor (CTC repeat merge)
        deduplicated = [
            v for i, v in enumerate(toks) if i == 0 or v != toks[i - 1]
        ]
        return torch.tensor([v for v in deduplicated if v != blank])


class Typhoon2AudioForConditionalGeneration(PreTrainedModel, GenerationMixin):
    config_class = Typhoon2AudioConfig
    _supports_cache_class = True

    def __init__(
        self,
        config,
        attn_implementation=None,  # only for the LLM
    ):
        super().__init__(config)
        # 1.
Speech Encoder # 1.1) Whisper Encoder # feature_extractor self.feature_extractor = WhisperFeatureExtractor( feature_size=config.whisper_extractor_feature_size ) # whisper encoder if isinstance(config.whisper, dict): config.whisper = WhisperConfig(**config.whisper) self.speech_encoder = WhisperModel(config.whisper).encoder self.ln_speech = nn.LayerNorm(config.whisper.d_model) # 1.2) BEATs if isinstance(config.beats, dict): config.beats = BEATsConfig(config.beats) self.beats = BEATs(config.beats) self.ln_audio = nn.LayerNorm(config.beats.encoder_embed_dim) # 1.3) Speech QFormer self.speech_Qformer, self.speech_query_tokens = self.init_speech_Qformer( config.speech_qformer_token_num, config.whisper.d_model + config.beats.encoder_embed_dim, config.speech_qformer_layer, ) self.second_per_frame = config.second_per_frame self.second_stride = config.second_stride # 2. LLM (e.g., Llama3) self.llama_model = AutoModelForCausalLM.from_pretrained( config.llama_base_model, attn_implementation=attn_implementation ) # tokenizer self.llama_tokenizer = AutoTokenizer.from_pretrained( config.llama_base_model, use_fast=False ) self.llama_tokenizer.add_special_tokens({"pad_token": "[PAD]"}) self.llama_tokenizer.padding_side = "right" # speech -> LLM projection self.speech_llama_proj = nn.Linear( self.speech_Qformer.config.hidden_size, self.llama_model.config.hidden_size, ) def init_speech_Qformer(self, num_query_token, speech_width, num_hidden_layers=2): encoder_config = BertConfig() encoder_config.num_hidden_layers = num_hidden_layers encoder_config.encoder_width = speech_width encoder_config.add_cross_attention = True encoder_config.cross_attention_freq = 1 encoder_config.query_length = num_query_token Qformer = BertLMHeadModel(config=encoder_config) query_tokens = nn.Parameter( torch.zeros(1, num_query_token, encoder_config.hidden_size), ) query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) return Qformer, query_tokens def encode_speech_only(self, audio): # 
whisper spectrogram = ( self.feature_extractor(audio, return_tensors="pt", sampling_rate=16000) .input_features.to(self.device) .to(self.dtype) ) # [1, 80, 3000] speech_embeds = self.speech_encoder( spectrogram, return_dict=True ).last_hidden_state # beats raw_wav = torch.from_numpy(audio).to(self.device).unsqueeze(0) audio_padding_mask = torch.zeros(raw_wav.shape, device=self.device).bool() audio_embeds, _ = self.beats.extract_features( raw_wav, padding_mask=audio_padding_mask, feature_only=True, torch_dtype=self.dtype, ) # auditory embeds speech_embeds = self.ln_speech(speech_embeds) audio_embeds = self.ln_audio(audio_embeds) audio_embeds = F.pad( audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1)) ) speech_embeds = torch.cat([speech_embeds, audio_embeds], dim=-1) # split frames B, T, C = speech_embeds.shape kernel = round(T * self.second_per_frame / 30.0) stride = round(T * self.second_stride / 30.0) kernel = (1, kernel) stride = (1, stride) speech_embeds_tr = speech_embeds.transpose(1, 2).unsqueeze(2) speech_embeds_overlap = F.unfold( speech_embeds_tr, kernel_size=kernel, dilation=1, padding=0, stride=stride ) _, _, L = speech_embeds_overlap.shape speech_embeds_overlap = speech_embeds_overlap.view(B, -1, kernel[1], L) speech_embeds_overlap = torch.permute(speech_embeds_overlap, [0, 3, 2, 1]) speech_embeds = speech_embeds_overlap.reshape(-1, kernel[1], C) speech_atts = torch.ones( speech_embeds.size()[:-1], dtype=torch.long, device=speech_embeds.device ) # Qformer query_tokens = self.speech_query_tokens.expand(speech_embeds.shape[0], -1, -1) query_output = self.speech_Qformer.bert( query_embeds=query_tokens, encoder_hidden_states=speech_embeds, encoder_attention_mask=speech_atts, return_dict=True, ) speech_embeds = self.speech_llama_proj(query_output.last_hidden_state) speech_embeds = speech_embeds.view(B, -1, speech_embeds.size(2)).contiguous() return speech_embeds def _get_text_from_content_list(self, content_list: List): for content in 
content_list: if content["type"] == "text": return content["text"] return "" def _get_audio_from_content_list(self, content_list: List): for content in content_list: if content["type"] == "audio": return f"{content['audio_url']} " return "" def _get_audio_url_from_string(self, content: str): return content.split("")[1].split("")[0] def _filter_only_audio_content(self, content_list: List): return [ self._get_audio_url_from_string(content) for content in content_list if "" in content ] def _split_conversation_by_speech(self, conversation_str: str): intermediate_list = [conversation_str] if "" in conversation_str: result = conversation_str.split("") intermediate_list = [ item + ("" if i < len(result) - 1 else "") for i, item in enumerate(result) ] processed_list = [] for item in intermediate_list: if "" in item: parts = item.split("") file_path = parts[0] remaining_context = ( "" + parts[1] if len(parts) > 1 else "" ) processed_list.extend([file_path, remaining_context]) else: processed_list.append(item) return processed_list def _convert_conv_to_embeds(self, conversation_list: List, speech_embeds: List): embeds = [] speech_embeds_keys = [speech["audio_url"] for speech in speech_embeds] for item in conversation_list: if item in speech_embeds_keys: selected = [ speech["audio"] for speech in speech_embeds if speech["audio_url"] == item ][0] selected = selected.to(self.device) embeds.append(selected) else: tokenized = self.llama_tokenizer( item, return_tensors="pt", add_special_tokens=False ).input_ids.to(self.device) token_embeds = self.llama_model.model.embed_tokens(tokenized) embeds.append(token_embeds) return embeds def encode_speech_with_text(self, conversation: List): converted_conversation = [ f"<|start_header_id|>{msg['role']}<|end_header_id|>\n\n{msg['content'] if not isinstance(msg['content'], list) else self._get_audio_from_content_list(msg['content']) + self._get_text_from_content_list(msg['content'])}<|eot_id|>" for msg in conversation ] conversation_str = ( 
"".join(converted_conversation) + "<|start_header_id|>assistant<|end_header_id|>\n\n" ) conversation_list = self._split_conversation_by_speech(conversation_str) speech_embeds = [ {"audio_url": audio, "audio": self.encode_speech_only(sf.read(audio)[0])} for audio in self._filter_only_audio_content(converted_conversation) ] bos_embeds = self.llama_model.model.embed_tokens( torch.ones( [1, 1], dtype=torch.long, device=self.device, ) * self.llama_tokenizer.bos_token_id ) embed_list = [bos_embeds] + self._convert_conv_to_embeds( conversation_list, speech_embeds ) embeds = torch.cat(embed_list, dim=1) atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device) return embeds, atts def forward( self, conversation: List, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: # TODO: support batch_size > 1 embeds, atts = self.encode_speech_with_text(conversation) # forward outputs = self.llama_model.forward( inputs_embeds=embeds, attention_mask=atts, labels=labels, return_dict=return_dict, ) return outputs # def forward( # self, # input_ids: torch.LongTensor = None, # attention_mask: Optional[torch.Tensor] = None, # position_ids: Optional[torch.LongTensor] = None, # past_key_values: Optional[List[torch.FloatTensor]] = None, # inputs_embeds: Optional[torch.FloatTensor] = None, # labels: Optional[torch.LongTensor] = None, # use_cache: Optional[bool] = None, # output_attentions: Optional[bool] = None, # output_hidden_states: Optional[bool] = None, # return_dict: Optional[bool] = None, # cache_position: Optional[torch.LongTensor] = None, # ) -> Union[Tuple, CausalLMOutputWithPast]: # llama_output = self.llama_model.forward( # input_ids=input_ids, # attention_mask=attention_mask, # position_ids=position_ids, # past_key_values=past_key_values, # inputs_embeds=inputs_embeds, # labels=labels, # use_cache=use_cache, # output_attentions=output_attentions, # output_hidden_states=True, # 
#         return_dict=return_dict,
    #         cache_position=cache_position,
    #     )
    #     loss = llama_output.loss
    #     return CausalLMOutputWithPast(
    #         loss=loss,
    #         logits=llama_output.logits,
    #         past_key_values=llama_output.past_key_values,
    #         hidden_states=llama_output.hidden_states,
    #         attentions=llama_output.attentions
    #     )

    def generate(
        self,
        conversation: List,
        max_new_tokens=1024,
        num_beams=1,
        do_sample=True,
        top_p=0.9,
        repetition_penalty=1.0,
        length_penalty=1.0,
        temperature=1.0,
        streamer=None,
    ) -> str:
        """Generate a text reply for `conversation`.

        Embeds the conversation with `encode_speech_with_text`, calls
        `llama_model.generate` with the given decoding options and the
        tokenizer's bos/eos/pad ids, and returns the first decoded sequence
        with special tokens skipped.
        """
        embeds, atts = self.encode_speech_with_text(conversation)
        # generate
        output = self.llama_model.generate(
            inputs_embeds=embeds,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            do_sample=do_sample,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
            attention_mask=atts,
            bos_token_id=self.llama_tokenizer.bos_token_id,
            eos_token_id=self.llama_tokenizer.eos_token_id,
            pad_token_id=self.llama_tokenizer.pad_token_id,
            streamer=streamer,
        )
        output_text = self.llama_tokenizer.batch_decode(
            output, add_special_tokens=False, skip_special_tokens=True
        )
        return output_text[0]

    # ------------------------------------------------------------------------------- #
    # November 2024 -- multi-turn
    def init_multiturn(
        self,
        system_prompt="<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant named ไต้ฝุ่น. You always answer in Thai.<|eot_id|>",
        user_prompt_prefix="<|start_header_id|>user<|end_header_id|>\n\n",
        user_prompt_suffix=" <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
    ):
        """Reset the multi-turn state and optionally seed it with a system prompt.

        Clears `self.conversations`, stores the user prefix/suffix templates
        used by `generate_multiturn`, and, when `system_prompt` is not None,
        caches its token embeddings as the first conversation segment.
        """
        self.conversations = []
        self.user_prompt_prefix = user_prompt_prefix
        self.user_prompt_suffix = user_prompt_suffix
        if system_prompt is not None:
            # When self.lora is set the embedding table is reached through one
            # extra `.model` (presumably a PEFT wrapper around the LLM).
            embed_tokens = (
                self.llama_model.model.model.embed_tokens
                if self.lora
                else self.llama_model.model.embed_tokens
            )
            system_prompt_ids = (
                self.llama_tokenizer(
                    system_prompt, return_tensors="pt", add_special_tokens=False
                )
                .to(self.device)
                .input_ids
            )
            system_prompt_embeds = embed_tokens(system_prompt_ids)
            self.add_cache(dtype="text:system_prompt", embeds=system_prompt_embeds)
        print("multi-turn conversation initialized!")

    def add_cache(self, dtype, embeds):
        """Append one segment to the multi-turn cache.

        Args:
            dtype: label for the segment (e.g. "text:system_prompt",
                "wav:user_input"); stored but not interpreted here.
            embeds: embedding tensor for the segment; `generate_multiturn`
                concatenates all cached tensors along dim 1.
        """
        # cache
        # --> for text, cache content = token embeddings
        # --> for wav, cache content = speech embeddings
        self.conversations.append({"dtype": dtype, "embeds": embeds})

    def generate_multiturn(
        self,
        wav_path,
        device="cuda:0",
        max_length=1500,
        num_beams=4,
        do_sample=True,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1.0,
        length_penalty=1.0,
        temperature=1.0,
        streamer=None,
    ):
        """Run one spoken user turn of a multi-turn conversation.

        Appends the user prefix, the speech embeddings of `wav_path` (from
        `self.process_wav`) and the user suffix to `self.conversations`,
        generates a reply from the concatenation of all cached segments, then
        caches the reply (plus "<|eot_id|>") so the next turn sees it.
        `init_multiturn` must have been called first: it creates
        `self.conversations` and the prefix/suffix templates.

        NOTE(review): `device` is accepted but never used; tensors are placed on
        `self.device`.

        Returns:
            The decoded reply text.
        """
        embed_tokens = (
            self.llama_model.model.model.embed_tokens
            if self.lora
            else self.llama_model.model.embed_tokens
        )

        # prefix: <|start_header_id|>user<|end_header_id|>\n\n
        user_prompt_prefix_ids = (
            self.llama_tokenizer(
                self.user_prompt_prefix, return_tensors="pt", add_special_tokens=False
            )
            .to(self.device)
            .input_ids
        )
        user_prompt_prefix_embeds = embed_tokens(user_prompt_prefix_ids)
        self.add_cache(
            dtype="text:user_prompt_prefix", embeds=user_prompt_prefix_embeds
        )

        # process the new wav
        speech_embeds = self.process_wav(wav_path)
        self.add_cache(dtype="wav:user_input", embeds=speech_embeds)

        # suffix: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n
        user_prompt_suffix_ids = (
            self.llama_tokenizer(
                self.user_prompt_suffix, return_tensors="pt", add_special_tokens=False
            )
            .to(self.device)
            .input_ids
        )
user_prompt_suffix_embeds = embed_tokens(user_prompt_suffix_ids)
        self.add_cache(
            dtype="text:user_prompt_suffix", embeds=user_prompt_suffix_embeds
        )

        # --------------------------------------------------------------------------- #
        # Concatenate every cached segment (system prompt, previous turns and the
        # new user turn) along the sequence dimension.
        list_of_embeds = []
        for em in self.conversations:
            list_of_embeds.append(em["embeds"])
        # for em in self.conversations: print(em['dtype'], em['embeds'].shape)
        embeds = torch.cat(list_of_embeds, dim=1)
        atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
        print("seq_length:", embeds.shape[1])

        # generate
        output = self.llama_model.generate(
            inputs_embeds=embeds,
            max_length=max_length,
            num_beams=num_beams,
            do_sample=do_sample,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
            attention_mask=atts,
            bos_token_id=self.llama_tokenizer.bos_token_id,
            eos_token_id=self.llama_tokenizer.eos_token_id,
            pad_token_id=self.llama_tokenizer.pad_token_id,
            streamer=streamer,
        )

        # add assistant generation
        # The reply is re-tokenized with an explicit <|eot_id|> and cached, so the
        # next turn's prompt contains the complete assistant message.
        output_text = self.llama_tokenizer.batch_decode(
            output, add_special_tokens=False, skip_special_tokens=True
        )
        assistant_text_ids = (
            self.llama_tokenizer(
                output_text[0] + "<|eot_id|>",
                return_tensors="pt",
                add_special_tokens=False,
            )
            .to(self.device)
            .input_ids
        )
        assistant_text_embeds = embed_tokens(assistant_text_ids)
        self.add_cache(dtype="text:assistant_generation", embeds=assistant_text_embeds)
        return output_text[0]


class Typhoon2Audio2AudioForConditionalGeneration(
    Typhoon2AudioForConditionalGeneration, GenerationWithCTC
):
    # Speech-to-speech variant: on top of the base model, a CTC speech decoder
    # maps LLM hidden states to speech units and a unit vocoder turns the units
    # into a waveform.
    config_class = Typhoon2AudioConfig

    def __init__(self, config):
        """Build the base model, then the CTC speech generator and the unit vocoder."""
        super().__init__(config)
        # (the string below is a plain expression, not a docstring; it only
        # describes the two components built next)
        """
        Initialize
        1) speech decoder (llm output representation -> speech unit)
        2) unit vocoder (speech unit -> wav)
        """
        self.pretraining_tp = config.pretraining_tp
        self.speech_generator = SpeechGeneratorCTC(config)
        self.init_vocoder(config)

    def init_vocoder(self, config=None, checkpoint_path=None):
        """Create the CodeHiFiGAN unit vocoder and move it to `self.device`.

        Args:
            config: config providing `vocoder_config`; defaults to `self.config`.
            checkpoint_path: optional vocoder checkpoint, passed through to
                `CodeHiFiGANVocoder` (which loads it when not None).
        """
        # separate vocoder initialization as it is supposed to be float32
        # other parts should be in float16
        if config is None:
            config = self.config
        self.vocoder = CodeHiFiGANVocoder(
            model_cfg=config.vocoder_config, checkpoint_path=checkpoint_path
        )
        self.vocoder.to(self.device)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """LLM forward pass that always requests hidden states.

        Arguments are passed to `llama_model.forward` with
        `output_hidden_states=True` regardless of the caller's value, because
        `generate` and `synthesize_speech` feed the last hidden layer to the
        speech generator. The result is repackaged as `CausalLMOutputWithPast`.

        NOTE(review): `cache_position` and `**kwargs` are accepted but not
        forwarded to `llama_model.forward`; attribute access on the LLM output
        presumably requires `return_dict` not to be False.
        """
        llama_output = self.llama_model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )
        loss = llama_output.loss
        return CausalLMOutputWithPast(
            loss=loss,
            logits=llama_output.logits,
            past_key_values=llama_output.past_key_values,
            hidden_states=llama_output.hidden_states,
            attentions=llama_output.attentions,
        )

    @torch.no_grad()
    def generate(
        self,
        # ----------------- #
        inputs_embeds=None,
        attention_mask=None,
        output_hidden_states=True,
        return_dict_in_generate=True,
        streaming_unit_gen=False,
        max_length=8000,
        # ----------------- #
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate a text reply together with its synthesized speech.

        If `inputs_embeds` is None and a `conversation` kwarg is given, the
        prompt is built with `encode_speech_with_text`. Text is produced by
        `GenerationWithCTC.generate`; the last-layer hidden states of the
        generated positions are decoded to speech units by
        `speech_generator.predict`, and the units are vocoded.

        NOTE(review): kwargs other than `conversation` are not forwarded to
        `GenerationWithCTC.generate`, and the annotated return type does not
        match the dict actually returned.

        Returns:
            dict with "text" (first decoded sequence), "unit" (CTC prediction
            tensor) and "audio" (dict from `ctc_pred_to_audio`).
        """
        if "conversation" in kwargs and inputs_embeds is None:
            conversation = kwargs.get("conversation", [])
            inputs_embeds, attention_mask = self.encode_speech_with_text(conversation)
        outputs = GenerationWithCTC.generate(
            self,
            # position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict_in_generate=return_dict_in_generate,
            streaming_unit_gen=streaming_unit_gen,
            # typhoon2 (llama3.1) will set this to 20 somehow otherwise
            max_length=max_length,
            # ------------------- #
            # presumably the Llama-3.1 begin-of-text id and end/eom/eot ids —
            # confirm against the tokenizer
            bos_token_id=128000,
            eos_token_id=[128001, 128008, 128009],
        )
        # Build one sequence of last-layer hidden states for the generated tokens:
        # from the first (prompt) step keep only its last position, then append
        # every later step's states. Assumes the HF layout where
        # outputs["hidden_states"] holds one entry per step, each a tuple of
        # per-layer tensors (so output_hidden_states / return_dict_in_generate
        # presumably must stay True).
        hidden_states = outputs["hidden_states"]
        hidden_states = torch.cat(
            [hidden_states[0][-1][:, -1:, :]]
            + [hidden_states[i][-1] for i in range(1, len(hidden_states))],
            dim=1,
        )
        # squeeze(0): a batch size of 1 is assumed here
        ctc_pred = self.speech_generator.predict(hidden_states.squeeze(0))

        # processing
        output_ids, output_units = outputs.sequences, ctc_pred
        # text
        output_text = self.llama_tokenizer.batch_decode(
            output_ids, add_special_tokens=False, skip_special_tokens=True
        )[0]
        # wav
        output_audio = self.ctc_pred_to_audio(output_units)
        return {"text": output_text, "unit": output_units, "audio": output_audio}

    @torch.no_grad()
    def synthesize_speech(
        self,
        text,
    ):
        """Synthesize speech for a given `text` (no text generation).

        Runs the LLM forward over `text` wrapped as an assistant chat turn,
        feeds the last hidden layer to the CTC speech generator and vocodes the
        predicted units.

        Returns:
            The dict produced by `ctc_pred_to_audio` ("array", "sampling_rate").
        """
        # apply chat template adds (supposed to be) unnecessary tokens
        # however, this was applied during training, so it should be added here
        # in the next version, please consider removing `apply_chat_template`
        text_ = self.llama_tokenizer.apply_chat_template(
            [{"role": "assistant", "content": text}], tokenize=False
        )
        inputs = self.llama_tokenizer(text_, return_tensors="pt").to(self.device)
        outputs = self(**inputs)
        hidden_states = outputs["hidden_states"][-1]
        ctc_pred = self.speech_generator.predict(hidden_states.squeeze(0))
        output_audio = self.ctc_pred_to_audio(ctc_pred)
        return output_audio

    def ctc_pred_to_audio(self, units):
        """Convert CTC unit predictions into a waveform with the vocoder.

        Args:
            units: CTC prediction tensor (as returned by
                `speech_generator.predict`); the blank id is
                `self.config.unit_vocab_size`.

        Returns:
            dict with "array" (numpy waveform, or None when no vocoder is
            attached) and "sampling_rate" from `config.vocoder_config`.
        """
        # vocoder
        if hasattr(self, "vocoder"):
            units = self.ctc_postprocess(units, blank=self.config.unit_vocab_size)
            units = [(list(map(int, units.strip().split())))]
            units_tensor = torch.tensor(units, dtype=torch.int64, device=self.device)
            # second positional argument is the vocoder's `dur_prediction`
            audio_arr = self.vocoder({"code": units_tensor}, True)
            audio_arr = audio_arr.detach().cpu().numpy()
        else:
            audio_arr = None
        return {
            "array": audio_arr,
            "sampling_rate": self.config.vocoder_config["sampling_rate"],
        }

    def ctc_postprocess(self, tokens, blank):
        """Greedy CTC post-processing of a predicted unit sequence.

        Args:
            tokens: tensor of predicted unit ids; `squeeze(0)` is applied, so a
                leading batch dimension of size 1 is allowed.
            blank: id of the CTC blank symbol.

        Returns:
            The unit ids with consecutive repeats merged and blanks removed,
            joined into a space-separated string.
        """
        _toks = tokens.squeeze(0).tolist()
deduplicated_toks = [
            v for i, v in enumerate(_toks) if i == 0 or v != _toks[i - 1]
        ]
        # merge repeats first, then drop blanks (standard greedy CTC collapse)
        hyp = [v for v in deduplicated_toks if v != blank]
        hyp = " ".join(list(map(str, hyp)))
        return hyp

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        # taken from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py
        """
        Prepare the model inputs for generation. It includes operations like computing the 4D attention mask or
        slicing inputs given the existing cache.

        See the forward pass in the model documentation for expected arguments (different models might have different
        requirements for e.g. `past_key_values`). This function should work as is for most LLMs.
        """

        # 1. Handle BC:
        model_inputs = {}
        # - some models don't have `Cache` support (which implies they don't expect `cache_position` in `forward`)
        if self._supports_cache_class:
            model_inputs["cache_position"] = cache_position
        # - `cache_position` was not a mandatory input in `prepare_inputs_for_generation` for those models, and this
        #   function may be called outside of `generate`. Handle most use cases by creating `cache_position` on the fly
        #   (this alternative is not as robust as calling `generate` and letting it create `cache_position`)
        elif cache_position is None:
            past_length = (
                past_key_values[0][0].shape[2] if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                past_length,
                input_ids.shape[1],
                dtype=torch.long,
                device=input_ids.device,
            )

        # 2. Generic cache-dependent input preparation
        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
        # (we can't check exception 3 while compiling)
        if past_key_values is not None:
            model_inputs["past_key_values"] = past_key_values
            if (
                inputs_embeds is not None  # Exception 1
                # Exception 3
                or (
                    is_torchdynamo_compiling()
                    or cache_position[-1] >= input_ids.shape[1]
                )
            ):
                input_ids = input_ids[:, -cache_position.shape[0] :]
            # Default case (the "else", a no op, is Exception 2)
            elif input_ids.shape[1] != cache_position.shape[0]:
                input_ids = input_ids[:, cache_position]

        # 3. Prepare base model inputs
        input_ids_key = (
            "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
        )
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if not self.config.is_encoder_decoder:
            if inputs_embeds is not None and cache_position[0] == 0:
                model_inputs[input_ids_key] = None
                model_inputs["inputs_embeds"] = inputs_embeds
            else:
                # `clone` calls in this function ensure a consistent stride. See #32227
                model_inputs[input_ids_key] = input_ids.clone(
                    memory_format=torch.contiguous_format
                )
                model_inputs["inputs_embeds"] = None
        else:
            model_inputs[input_ids_key] = input_ids.clone(
                memory_format=torch.contiguous_format
            )

        # 4. Create missing `position_ids` on the fly
        if (
            attention_mask is not None
            and kwargs.get("position_ids") is None
            and "position_ids" in set(inspect.signature(self.forward).parameters.keys())
        ):
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            # placed in kwargs for further processing (see below)
            kwargs["position_ids"] = position_ids

        # 5. Slice model inputs if it's an input that should have the same length as `input_ids`
        for model_input_name in ["position_ids", "token_type_ids"]:
            model_input = kwargs.get(model_input_name)
            if model_input is not None:
                if past_key_values is not None:
                    current_input_length = (
                        model_inputs["inputs_embeds"].shape[1]
                        if model_inputs["inputs_embeds"] is not None
                        else model_inputs[input_ids_key].shape[1]
                    )
                    model_input = model_input[:, -current_input_length:]
                    model_input = model_input.clone(
                        memory_format=torch.contiguous_format
                    )
                model_inputs[model_input_name] = model_input

        # 6. Create 4D attention mask if we are using a `StaticCache` (important for performant compiled forward pass)
        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
            if model_inputs["inputs_embeds"] is not None:
                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
                device = model_inputs["inputs_embeds"].device
            else:
                batch_size, sequence_length = model_inputs[input_ids_key].shape
                device = model_inputs[input_ids_key].device

            # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
            # the 4D causal mask exists, it should be present in the base model (XXXModel class).
            base_model = getattr(self, self.base_model_prefix, None)
            if base_model is None:
                causal_mask_creation_function = getattr(
                    self, "_prepare_4d_causal_attention_mask_with_cache_position", None
                )
            else:
                causal_mask_creation_function = getattr(
                    base_model,
                    "_prepare_4d_causal_attention_mask_with_cache_position",
                    None,
                )
            if causal_mask_creation_function is None:
                logger.warning_once(
                    f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
                    "defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
                    "writing code, see Llama for an example implementation. If you're a user, please report this "
                    "issue on GitHub."
                )
            else:
                attention_mask = causal_mask_creation_function(
                    attention_mask,
                    sequence_length=sequence_length,
                    target_length=past_key_values.get_max_cache_shape(),
                    dtype=self.dtype,
                    device=device,
                    cache_position=cache_position,
                    batch_size=batch_size,
                    config=self.config,
                    past_key_values=past_key_values,
                )
        if attention_mask is not None:
            model_inputs["attention_mask"] = attention_mask

        # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
        for key, value in kwargs.items():
            if key not in model_inputs:
                model_inputs[key] = value

        # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
        model_inputs.pop("labels", None)
        return model_inputs

    def _get_logits_warper(
        self,
        generation_config: GenerationConfig,
        device: str,
    ) -> LogitsProcessorList:
        """
        This method returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`]
        instances used for multinomial sampling.
        """

        # instantiate warpers list
        warpers = LogitsProcessorList()
keep len(list(generation_config._eos_token_tensor)) + 1) if generation_config.num_beams > 1: if isinstance(generation_config._eos_token_tensor, list): min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1 elif isinstance(generation_config._eos_token_tensor, torch.Tensor): min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1 else: min_tokens_to_keep = 2 else: min_tokens_to_keep = 1 # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files # all samplers can be found in `generation_utils_samplers.py` if ( generation_config.temperature is not None and generation_config.temperature != 1.0 ): warpers.append(TemperatureLogitsWarper(generation_config.temperature)) if generation_config.top_k is not None and generation_config.top_k != 0: warpers.append( TopKLogitsWarper( top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep ) ) if generation_config.top_p is not None and generation_config.top_p < 1.0: warpers.append( TopPLogitsWarper( top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep ) ) if generation_config.min_p is not None: # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084) warpers.append( MinPLogitsWarper( min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep ) ) if ( generation_config.typical_p is not None and generation_config.typical_p < 1.0 ): warpers.append( TypicalLogitsWarper( mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep, ) ) if ( generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0 ): warpers.append( EpsilonLogitsWarper( epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep, ) ) if ( generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0 ): warpers.append( EtaLogitsWarper( epsilon=generation_config.eta_cutoff, 
min_tokens_to_keep=min_tokens_to_keep, device=device, ) ) # `LogitNormalization` should always be the last logit processor, when present if generation_config.renormalize_logits is True: warpers.append(LogitNormalization()) return warpers # ------------------------------------------------------------------------------------------ # # Speech Decoder Componnt class SpeechGeneratorCTC(nn.Module): def __init__(self, config): super().__init__() n_layers, n_dims, n_heads, n_inter_dims = list( map(int, config.ctc_decoder_config[1:-1].split(",")) ) _config = copy.deepcopy(config) _config.hidden_size = n_dims _config.num_hidden_layers = n_layers _config.num_attention_heads = n_heads _config.num_key_value_heads = n_heads _config.intermediate_size = n_inter_dims _config._attn_implementation = "flash_attention_2" self.upsample_factor = config.ctc_upsample_factor self.input_proj = nn.Linear(config.hidden_size, n_dims) self.layers = nn.ModuleList( [LlamaDecoderLayer(_config, layer_idx) for layer_idx in range(n_layers)] ) self.unit_vocab_size = config.unit_vocab_size self.output_proj = nn.Linear(n_dims, config.unit_vocab_size + 1) self.speech_decoder_ignore_index = config.speech_decoder_ignore_index def upsample(self, reps, tgt_units=None): src_lens = torch.LongTensor([len(rep) for rep in reps]).to(reps[0].device) up_lens = src_lens * self.upsample_factor if tgt_units is not None: tgt_lens = tgt_units.ne(self.speech_decoder_ignore_index).long().sum(dim=-1) up_lens = torch.max(up_lens, tgt_lens) reps = torch.nn.utils.rnn.pad_sequence(reps, batch_first=True) padding_mask = self._lengths_to_padding_mask(up_lens) mapped_inputs = self._uniform_assignment(src_lens, up_lens).masked_fill( padding_mask, 0 ) copied_reps = torch.gather( reps, 1, mapped_inputs.unsqueeze(-1).expand(*mapped_inputs.size(), reps.size(-1)), ) copied_reps = copied_reps.masked_fill(padding_mask.unsqueeze(-1), 0) position_ids = ( torch.arange(0, max(up_lens)) .unsqueeze(0) .expand(len(reps), -1) 
.to(device=copied_reps.device) ) return copied_reps, ~padding_mask, position_ids def forward(self, tgt_reps, labels, tgt_units): tgt_label_reps = [] for tgt_rep, label in zip(tgt_reps, labels): tgt_label_reps.append(tgt_rep[label != self.speech_decoder_ignore_index]) hidden_states, attention_mask, position_ids = self.upsample( tgt_label_reps, tgt_units ) hidden_states = self.input_proj(hidden_states) for layer in self.layers: layer_outputs = layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, ) hidden_states = layer_outputs[0] ctc_logits = self.output_proj(hidden_states) ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32) ctc_lens = attention_mask.long().sum(dim=-1) ctc_tgt_lens = tgt_units.ne(self.speech_decoder_ignore_index).long().sum(dim=-1) ctc_tgt_mask = ~self._lengths_to_padding_mask(ctc_tgt_lens) ctc_tgt_flat = tgt_units.masked_select(ctc_tgt_mask) ctc_loss = F.ctc_loss( ctc_lprobs.transpose(0, 1), ctc_tgt_flat, ctc_lens, ctc_tgt_lens, reduction="sum", zero_infinity=True, blank=self.unit_vocab_size, ) ctc_loss /= ctc_tgt_lens.sum().item() return ctc_loss def predict(self, tgt_reps): hidden_states, attention_mask, position_ids = self.upsample([tgt_reps]) hidden_states = self.input_proj(hidden_states) for layer in self.layers: layer_outputs = layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, ) hidden_states = layer_outputs[0] ctc_logits = self.output_proj(hidden_states) ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32) ctc_pred = ctc_lprobs.argmax(dim=-1).masked_fill_( ~attention_mask, self.unit_vocab_size ) return ctc_pred def _lengths_to_padding_mask(self, lens): bsz, max_lens = lens.size(0), torch.max(lens).item() mask = torch.arange(max_lens).to(lens.device).view(1, max_lens) mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens) return mask def _uniform_assignment(self, src_lens, tgt_lens): tgt_indices = ( 
torch.arange(torch.max(tgt_lens)) .expand(len(tgt_lens), -1) .to(tgt_lens.device) ) ratio = tgt_lens / src_lens index_t = (tgt_indices / ratio.view(-1, 1)).long() return index_t # Code HiFiGAN # https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/text_to_speech/vocoder.py class CodeHiFiGANVocoder(BaseFairseqModel): def __init__( self, model_cfg: Dict[str, str], checkpoint_path: str = None, fp16: bool = False ) -> None: super().__init__() self.model = CodeHiFiGANModel(model_cfg) if checkpoint_path is not None: self.load_checkpoint(checkpoint_path) self.model.eval() if fp16: self.model.half() self.model.remove_weight_norm() logger.info(f"initialized CodeHiFiGAN checkpoint") def load_checkpoint(self, checkpoint_path: str) -> None: if torch.cuda.is_available(): state_dict = torch.load(checkpoint_path) else: state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu")) self.model.load_state_dict(state_dict["generator"]) logger.info(f"loaded CodeHiFiGAN checkpoint from {checkpoint_path}") def forward(self, x: Dict[str, torch.Tensor], dur_prediction=False) -> torch.Tensor: assert "code" in x x["dur_prediction"] = dur_prediction # remove invalid code mask = x["code"] >= 0 x["code"] = x["code"][mask].unsqueeze(dim=0) if "f0" in x: f0_up_ratio = x["f0"].size(1) // x["code"].size(1) mask = mask.unsqueeze(2).repeat(1, 1, f0_up_ratio).view(-1, x["f0"].size(1)) x["f0"] = x["f0"][mask].unsqueeze(dim=0) return self.model(**x).detach().squeeze() # ---------------------------------------------------------------------------------------- # class BertEmbeddings(nn.Module): """Construct the embeddings from word and position embeddings.""" def __init__(self, config): super().__init__() self.word_embeddings = nn.Embedding( config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id ) self.position_embeddings = nn.Embedding( config.max_position_embeddings, config.hidden_size ) # self.LayerNorm is not snake-cased to stick with TensorFlow model 
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
        )
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )

        self.config = config

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
        """Embed text tokens and/or prepend query embeddings.

        With `input_ids`, word embeddings (plus absolute position embeddings
        when `position_embedding_type == "absolute"`) are computed and
        `query_embeds`, if given, is concatenated in front along dim 1. Without
        `input_ids`, `query_embeds` is used directly. LayerNorm and dropout are
        applied to the result.
        """
        if input_ids is not None:
            seq_length = input_ids.size()[1]
        else:
            seq_length = 0

        if position_ids is None:
            position_ids = self.position_ids[
                :, past_key_values_length : seq_length + past_key_values_length
            ].clone()

        if input_ids is not None:
            embeddings = self.word_embeddings(input_ids)
            if self.position_embedding_type == "absolute":
                position_embeddings = self.position_embeddings(position_ids)
                embeddings = embeddings + position_embeddings

            if query_embeds is not None:
                embeddings = torch.cat((query_embeds, embeddings), dim=1)
        else:
            embeddings = query_embeds

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class BertSelfAttention(nn.Module):
    """Multi-head self- or cross-attention used by the Q-Former BERT layers."""

    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
            config, "embedding_size"
        ):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            # keys/values project encoder features of width `config.encoder_width`
            self.key = nn.Linear(config.encoder_width, self.all_head_size)
            self.value = nn.Linear(config.encoder_width, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )
        if (
            self.position_embedding_type == "relative_key"
            or self.position_embedding_type == "relative_key_query"
        ):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1, self.attention_head_size
            )
        # when True (and in cross-attention), forward stores attention maps and
        # registers a hook that stores their gradients
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        """Reshape (batch, seq, all_head_size) to (batch, heads, seq, head_size)."""
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Scaled dot-product attention over `hidden_states`.

        For cross-attention (`encoder_hidden_states` given) keys/values come
        from the encoder and `encoder_attention_mask` replaces
        `attention_mask`. Otherwise keys/values come from `hidden_states`,
        prefixed by `past_key_value` when provided.

        Returns:
            `(context_layer, [attention_probs,] past_key_value)`: the
            attention probabilities appear only when `output_attentions`, and
            `past_key_value` is the `(key_layer, value_layer)` pair just used.
        """
        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            # append the new keys/values after the cached ones (sequence axis)
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)

        past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if (
            self.position_embedding_type == "relative_key"
            or self.position_embedding_type == "relative_key_query"
        ):
            # NOTE(review): both position ranges use the query length, so this
            # presumably assumes key length == query length (no past / no
            # cross-attention).
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(-1, 1)
            position_ids_r = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1
            )
            positional_embedding = positional_embedding.to(
                dtype=query_layer.dtype
            )  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
                relative_position_scores_key = torch.einsum(
                    "bhrd,lrd->bhlr", key_layer, positional_embedding
                )
                attention_scores = (
                    attention_scores
                    + relative_position_scores_query
                    + relative_position_scores_key
                )

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # Apply the additive attention mask (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        # back to (batch, seq, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (
            (context_layer, attention_probs) if output_attentions else (context_layer,)
        )

        outputs = outputs + (past_key_value,)
        return outputs


class BertSelfOutput(nn.Module):
    """Dense projection + dropout + residual LayerNorm after self-attention."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertAttention(nn.Module):
    """Attention block: BertSelfAttention followed by BertSelfOutput, with head pruning."""

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the q/k/v and output projections."""
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads,
            self.self.num_attention_heads,
            self.self.attention_head_size,
            self.pruned_heads,
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = (
            self.self.attention_head_size * self.self.num_attention_heads
        )
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Return `(attention_output,) + self_outputs[1:]` (attentions, if any, then the key/value pair)."""
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[
            1:
        ]  # add attentions if we output them
        return outputs


class BertIntermediate(nn.Module):
    """Feed-forward expansion: dense to `intermediate_size` followed by the activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    """Feed-forward projection back to `hidden_size` + dropout + residual LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertLayer(nn.Module):
    """Q-Former layer: self-attention, optional cross-attention for query tokens,
    and separate feed-forward paths for query tokens and text tokens.

    Cross-attention is built only when `config.add_cross_attention` is set and
    `layer_num` is a multiple of `config.cross_attention_freq`.
    """

    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.layer_num = layer_num
        if (
            self.config.add_cross_attention
            and layer_num % self.config.cross_attention_freq == 0
        ):
            self.crossattention = BertAttention(
                config, is_cross_attention=self.config.add_cross_attention
            )
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

        # separate feed-forward weights for the query-token positions
        self.intermediate_query = BertIntermediate(config)
        self.output_query = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        """Apply the layer.

        The first `query_length` positions are treated as query tokens: they
        go through cross-attention (when this layer has it) and the query
        feed-forward; the remaining positions use the text feed-forward. With
        `query_length == 0`, all positions use the text feed-forward.

        Returns:
            `(layer_output, *attention_outputs, present_key_value)`.
        """
        # decoder uni-directional self-attention cached key/values tuple is in its first two entries
        self_attn_past_key_value = (
            past_key_value[:2] if past_key_value is not None else None
        )
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:-1]

        present_key_value = self_attention_outputs[-1]

        if query_length > 0:
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                assert (
                    encoder_hidden_states is not None
                ), "encoder_hidden_states must be given for cross-attention layers"
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_attention_outputs[0]
                outputs = (
                    outputs + cross_attention_outputs[1:-1]
                )  # add cross attentions if we output attention weights

            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )
            if attention_output.shape[1] > query_length:
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text], dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        """Feed-forward for text-token positions."""
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(self, attention_output):
        """Feed-forward for query-token positions."""
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output


class BertEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` BertLayer modules."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [BertLayer(config, i) for i in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = (
            () if output_attentions and self.config.add_cross_attention else None
        )

        next_decoder_cache = () if use_cache else None

        for i in
range(self.config.num_hidden_layers): layer_module = self.layer[i] if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None if getattr(self.config, "gradient_checkpointing", False) and self.training: if use_cache: logger.warn( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False def create_custom_forward(module): def custom_forward(*inputs): return module( *inputs, past_key_value, output_attentions, query_length ) return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(layer_module), hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, ) else: layer_outputs = layer_module( hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions, query_length, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) all_cross_attentions = all_cross_attentions + (layer_outputs[2],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: return tuple( v for v in [ hidden_states, next_decoder_cache, all_hidden_states, all_self_attentions, all_cross_attentions, ] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, ) class BertPooler(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() def forward(self, hidden_states): # We "pool" the model by simply taking the hidden 
state corresponding # to the first token. first_token_tensor = hidden_states[:, 0] pooled_output = self.dense(first_token_tensor) pooled_output = self.activation(pooled_output) return pooled_output


class BertPredictionHeadTransform(nn.Module):
    """Pre-projection transform used by the language-model head.

    Applies ``dense (hidden_size -> hidden_size)``, then the activation named
    by ``config.hidden_act`` (string key into ``ACT2FN`` or a callable), then
    LayerNorm.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        """Return ``LayerNorm(act(dense(hidden_states)))``."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    """Maps hidden states to vocabulary logits.

    Applies :class:`BertPredictionHeadTransform` followed by a bias-free
    ``hidden_size -> vocab_size`` linear layer. A separately registered
    ``bias`` parameter is attached as that layer's bias.
    """

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        # NOTE(review): the weight tying itself is not done here; presumably it
        # is handled by the PreTrainedModel base class — confirm.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        """Return vocabulary logits with last dimension ``config.vocab_size``."""
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class BertOnlyMLMHead(nn.Module):
    """Thin wrapper exposing :class:`BertLMPredictionHead` as ``self.predictions``.

    The attribute name ``predictions`` is kept so checkpoint keys such as
    ``cls.predictions.*`` line up.
    """

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        """Return vocabulary logits for every position of ``sequence_output``."""
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class BertPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
""" config_class = BertConfig base_model_prefix = "bert" _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Embedding)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() class BertModel(BertPreTrainedModel): """ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of cross-attention is added between the self-attention layers, following the architecture described in `Attention is all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an input to the forward pass. """ def __init__(self, config, add_pooling_layer=False): super().__init__(config) self.config = config self.embeddings = BertEmbeddings(config) self.encoder = BertEncoder(config) self.pooler = BertPooler(config) if add_pooling_layer else None self.init_weights() def get_input_embeddings(self): return self.embeddings.word_embeddings def set_input_embeddings(self, value): self.embeddings.word_embeddings = value def _prune_heads(self, heads_to_prune): """ Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel """ for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) def get_extended_attention_mask( self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool, has_query: bool = False, ) -> Tensor: """ Makes broadcastable attention and causal masks so that future and masked tokens are ignored. Arguments: attention_mask (:obj:`torch.Tensor`): Mask with ones indicating tokens to attend to, zeros for tokens to ignore. input_shape (:obj:`Tuple[int]`): The shape of the input to the model. device: (:obj:`torch.device`): The device of the input to the model. Returns: :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. """ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. 
if attention_mask.dim() == 3: extended_attention_mask = attention_mask[:, None, :, :] elif attention_mask.dim() == 2: # Provided a padding mask of dimensions [batch_size, seq_length] # - if the model is a decoder, apply a causal mask in addition to the padding mask # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] if is_decoder: batch_size, seq_length = input_shape seq_ids = torch.arange(seq_length, device=device) causal_mask = ( seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] ) # add a prefix ones mask to the causal mask # causal and attention masks must have same type with pytorch version < 1.3 causal_mask = causal_mask.to(attention_mask.dtype) if causal_mask.shape[1] < attention_mask.shape[1]: prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] if has_query: # UniLM style attention mask causal_mask = torch.cat( [ torch.zeros( (batch_size, prefix_seq_len, seq_length), device=device, dtype=causal_mask.dtype, ), causal_mask, ], axis=1, ) causal_mask = torch.cat( [ torch.ones( (batch_size, causal_mask.shape[1], prefix_seq_len), device=device, dtype=causal_mask.dtype, ), causal_mask, ], axis=-1, ) extended_attention_mask = ( causal_mask[:, None, :, :] * attention_mask[:, None, None, :] ) else: extended_attention_mask = attention_mask[:, None, None, :] else: raise ValueError( "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( input_shape, attention_mask.shape ) ) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for # positions we want to attend and -10000.0 for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. 
extended_attention_mask = extended_attention_mask.to( dtype=self.dtype ) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 return extended_attention_mask def forward( self, input_ids=None, attention_mask=None, position_ids=None, head_mask=None, query_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_values=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, is_decoder=False, ): r""" encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder. encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. use_cache (:obj:`bool`, `optional`): If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up decoding (see :obj:`past_key_values`). 
""" output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) # use_cache = use_cache if use_cache is not None else self.config.use_cache if input_ids is None: assert ( query_embeds is not None ), "You have to specify query_embeds when input_ids is None" # past_key_values_length past_key_values_length = ( past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0 ) query_length = query_embeds.shape[1] if query_embeds is not None else 0 embedding_output = self.embeddings( input_ids=input_ids, position_ids=position_ids, query_embeds=query_embeds, past_key_values_length=past_key_values_length, ) input_shape = embedding_output.size()[:-1] batch_size, seq_length = input_shape device = embedding_output.device if attention_mask is None: attention_mask = torch.ones( ((batch_size, seq_length + past_key_values_length)), device=device ) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. 
if is_decoder: extended_attention_mask = self.get_extended_attention_mask( attention_mask, input_ids.shape, device, is_decoder, has_query=(query_embeds is not None), ) else: extended_attention_mask = self.get_extended_attention_mask( attention_mask, input_shape, device, is_decoder ) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if encoder_hidden_states is not None: if type(encoder_hidden_states) == list: encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ 0 ].size() else: ( encoder_batch_size, encoder_sequence_length, _, ) = encoder_hidden_states.size() encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if type(encoder_attention_mask) == list: encoder_extended_attention_mask = [ self.invert_attention_mask(mask) for mask in encoder_attention_mask ] elif encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) encoder_extended_attention_mask = self.invert_attention_mask( encoder_attention_mask ) else: encoder_extended_attention_mask = self.invert_attention_mask( encoder_attention_mask ) else: encoder_extended_attention_mask = None # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, 
query_length=query_length, ) sequence_output = encoder_outputs[0] pooled_output = ( self.pooler(sequence_output) if self.pooler is not None else None ) if not return_dict: return (sequence_output, pooled_output) + encoder_outputs[1:] return BaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, pooler_output=pooled_output, past_key_values=encoder_outputs.past_key_values, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, cross_attentions=encoder_outputs.cross_attentions, ) class BertLMHeadModel(BertPreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] def __init__(self, config): super().__init__(config) self.bert = BertModel(config, add_pooling_layer=False) self.cls = BertOnlyMLMHead(config) self.init_weights() def get_output_embeddings(self): return self.cls.predictions.decoder def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings def forward( self, input_ids=None, attention_mask=None, position_ids=None, head_mask=None, query_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, labels=None, past_key_values=None, use_cache=True, output_attentions=None, output_hidden_states=None, return_dict=None, return_logits=False, is_decoder=True, reduction="mean", ): r""" encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder. encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. 
Mask values selected in ``[0, 1]``: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. use_cache (:obj:`bool`, `optional`): If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up decoding (see :obj:`past_key_values`). 
Returns: Example:: >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig >>> import torch >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') >>> config = BertConfig.from_pretrained("bert-base-cased") >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> outputs = model(**inputs) >>> prediction_logits = outputs.logits """ return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if labels is not None: use_cache = False if past_key_values is not None: query_embeds = None outputs = self.bert( input_ids, attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask, query_embeds=query_embeds, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, is_decoder=is_decoder, ) sequence_output = outputs[0] if query_embeds is not None: sequence_output = outputs[0][:, query_embeds.shape[1] :, :] prediction_scores = self.cls(sequence_output) if return_logits: return prediction_scores[:, :-1, :].contiguous() lm_loss = None if labels is not None: # we are doing next-token prediction; shift prediction scores and input ids by one shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() labels = labels[:, 1:].contiguous() loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) lm_loss = loss_fct( shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1), ) if reduction == "none": lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) if not return_dict: output = (prediction_scores,) + outputs[2:] return ((lm_loss,) + output) if lm_loss is not None else output return CausalLMOutputWithCrossAttentions( loss=lm_loss, logits=prediction_scores, 
past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, cross_attentions=outputs.cross_attentions, ) def prepare_inputs_for_generation( self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs ): # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly if attention_mask is None: attention_mask = input_ids.new_ones(input_ids.shape) query_mask = input_ids.new_ones(query_embeds.shape[:-1]) attention_mask = torch.cat([query_mask, attention_mask], dim=-1) # cut decoder_input_ids if past is used if past is not None: input_ids = input_ids[:, -1:] return { "input_ids": input_ids, "query_embeds": query_embeds, "attention_mask": attention_mask, "past_key_values": past, "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), "is_decoder": True, } def _reorder_cache(self, past, beam_idx): reordered_past = () for layer_past in past: reordered_past += ( tuple( past_state.index_select(0, beam_idx) for past_state in layer_past ), ) return reordered_past class BertForMaskedLM(BertPreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] def __init__(self, config): super().__init__(config) self.bert = BertModel(config, add_pooling_layer=False) self.cls = BertOnlyMLMHead(config) self.init_weights() def get_output_embeddings(self): return self.cls.predictions.decoder def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings def forward( self, input_ids=None, attention_mask=None, position_ids=None, head_mask=None, query_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None, return_logits=False, is_decoder=False, ): r""" labels (:obj:`torch.LongTensor` 
of shape :obj:`(batch_size, sequence_length)`, `optional`): Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` """ return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) outputs = self.bert( input_ids, attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask, query_embeds=query_embeds, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, is_decoder=is_decoder, ) if query_embeds is not None: sequence_output = outputs[0][:, query_embeds.shape[1] :, :] prediction_scores = self.cls(sequence_output) if return_logits: return prediction_scores masked_lm_loss = None if labels is not None: loss_fct = CrossEntropyLoss() # -100 index = padding token masked_lm_loss = loss_fct( prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) ) if not return_dict: output = (prediction_scores,) + outputs[2:] return ( ((masked_lm_loss,) + output) if masked_lm_loss is not None else output ) return MaskedLMOutput( loss=masked_lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) # ------------------------------------------------------ # class BEATs(nn.Module): def __init__( self, cfg, ) -> None: super().__init__() logger.info(f"BEATs Config: {cfg.__dict__}") self.cfg = cfg self.embed = cfg.embed_dim self.post_extract_proj = ( nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None ) self.input_patch_size = cfg.input_patch_size self.patch_embedding = nn.Conv2d( 1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size, bias=cfg.conv_bias, ) 
self.dropout_input = nn.Dropout(cfg.dropout_input) assert not cfg.deep_norm or not cfg.layer_norm_first self.encoder = TransformerEncoder(cfg) self.layer_norm = LayerNorm(self.embed) if cfg.finetuned_model: self.predictor_dropout = nn.Dropout(cfg.predictor_dropout) self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class) else: self.predictor = None def forward_padding_mask( self, features: torch.Tensor, padding_mask: torch.Tensor, ) -> torch.Tensor: extra = padding_mask.size(1) % features.size(1) if extra > 0: padding_mask = padding_mask[:, :-extra] padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1) padding_mask = padding_mask.all(-1) return padding_mask def preprocess( self, source: torch.Tensor, fbank_mean: float = 15.41663, fbank_std: float = 6.55582, ) -> torch.Tensor: fbanks = [] for waveform in source: waveform = waveform.unsqueeze(0) * 2**15 fbank = ta_kaldi.fbank( waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10, ) fbanks.append(fbank) fbank = torch.stack(fbanks, dim=0) fbank = (fbank - fbank_mean) / (2 * fbank_std) return fbank def extract_features( self, source: torch.Tensor, padding_mask: Optional[torch.Tensor] = None, fbank_mean: float = 15.41663, fbank_std: float = 6.55582, feature_only=False, torch_dtype=torch.float32, ): fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std).to( torch_dtype ) if padding_mask is not None: padding_mask = self.forward_padding_mask(fbank, padding_mask) fbank = fbank.unsqueeze(1) features = self.patch_embedding(fbank) features = features.reshape(features.shape[0], features.shape[1], -1) features = features.transpose(1, 2) features = self.layer_norm(features) if padding_mask is not None: padding_mask = self.forward_padding_mask(features, padding_mask) if self.post_extract_proj is not None: features = self.post_extract_proj(features) x = self.dropout_input(features) x, layer_results = self.encoder( x, padding_mask=padding_mask, ) 
if not feature_only and self.predictor is not None: x = self.predictor_dropout(x) logits = self.predictor(x) if padding_mask is not None and padding_mask.any(): logits[padding_mask] = 0 logits = logits.sum(dim=1) logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as( logits ) else: logits = logits.mean(dim=1) lprobs = torch.sigmoid(logits) return lprobs, padding_mask else: return x, padding_mask class TransformerEncoder(nn.Module): def __init__(self, args): super().__init__() self.dropout = args.dropout self.embedding_dim = args.encoder_embed_dim self.pos_conv = nn.Conv1d( self.embedding_dim, self.embedding_dim, kernel_size=args.conv_pos, padding=args.conv_pos // 2, groups=args.conv_pos_groups, ) dropout = 0 std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) nn.init.normal_(self.pos_conv.weight, mean=0, std=std) nn.init.constant_(self.pos_conv.bias, 0) self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) if hasattr(args, "relative_position_embedding"): self.relative_position_embedding = args.relative_position_embedding self.num_buckets = args.num_buckets self.max_distance = args.max_distance else: self.relative_position_embedding = False self.num_buckets = 0 self.max_distance = 0 self.layers = nn.ModuleList( [ TransformerSentenceEncoderLayer( embedding_dim=self.embedding_dim, ffn_embedding_dim=args.encoder_ffn_embed_dim, num_attention_heads=args.encoder_attention_heads, dropout=self.dropout, attention_dropout=args.attention_dropout, activation_dropout=args.activation_dropout, activation_fn=args.activation_fn, layer_norm_first=args.layer_norm_first, deep_norm=args.deep_norm, has_relative_attention_bias=self.relative_position_embedding, num_buckets=self.num_buckets, max_distance=self.max_distance, gru_rel_pos=args.gru_rel_pos, encoder_layers=args.encoder_layers, ) for i in range(args.encoder_layers) ] ) if 
self.relative_position_embedding: for i in range(1, args.encoder_layers): del self.layers[i].self_attn.relative_attention_bias self.layers[i].self_attn.relative_attention_bias = self.layers[ 0 ].self_attn.relative_attention_bias self.layer_norm_first = args.layer_norm_first self.layer_norm = LayerNorm(self.embedding_dim) self.layerdrop = args.encoder_layerdrop self.apply(init_bert_params) if args.deep_norm: deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4) for i in range(args.encoder_layers): nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1) nn.init.xavier_normal_( self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta ) nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1) nn.init.xavier_normal_( self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta ) nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta) nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta) self.layer_wise_gradient_decay_ratio = getattr( args, "layer_wise_gradient_decay_ratio", 1 ) def forward(self, x, padding_mask=None, layer=None): x, layer_results = self.extract_features(x, padding_mask, layer) if self.layer_norm_first and layer is None: x = self.layer_norm(x) return x, layer_results def extract_features(self, x, padding_mask=None, tgt_layer=None): if padding_mask is not None: x[padding_mask] = 0 x_conv = self.pos_conv(x.transpose(1, 2)) x_conv = x_conv.transpose(1, 2) x = x + x_conv if not self.layer_norm_first: x = self.layer_norm(x) x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) layer_results = [] z = None if tgt_layer is not None: layer_results.append((x, z)) r = None pos_bias = None for i, layer in enumerate(self.layers): if self.layer_wise_gradient_decay_ratio != 1.0: x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio) dropout_probability = np.random.random() if not self.training or (dropout_probability > self.layerdrop): x, z, 
pos_bias = layer( x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias, ) if tgt_layer is not None: layer_results.append((x, z)) if i == tgt_layer: r = x break if r is not None: x = r # T x B x C -> B x T x C x = x.transpose(0, 1) return x, layer_results class TransformerSentenceEncoderLayer(nn.Module): def __init__( self, embedding_dim: float = 768, ffn_embedding_dim: float = 3072, num_attention_heads: float = 8, dropout: float = 0.1, attention_dropout: float = 0.1, activation_dropout: float = 0.1, activation_fn: str = "relu", layer_norm_first: bool = False, deep_norm: bool = False, has_relative_attention_bias: bool = False, num_buckets: int = 0, max_distance: int = 0, rescale_init: bool = False, gru_rel_pos: bool = False, encoder_layers: int = 0, ) -> None: super().__init__() self.embedding_dim = embedding_dim self.dropout = dropout self.activation_dropout = activation_dropout self.activation_name = activation_fn self.activation_fn = get_activation_fn(activation_fn) self.self_attn = MultiheadAttention( self.embedding_dim, num_attention_heads, dropout=attention_dropout, self_attention=True, has_relative_attention_bias=has_relative_attention_bias, num_buckets=num_buckets, max_distance=max_distance, rescale_init=rescale_init, gru_rel_pos=gru_rel_pos, ) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(self.activation_dropout) self.dropout3 = nn.Dropout(dropout) self.layer_norm_first = layer_norm_first self.self_attn_layer_norm = LayerNorm(self.embedding_dim) if self.activation_name == "glu": self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") else: self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) self.final_layer_norm = LayerNorm(self.embedding_dim) self.deep_norm = deep_norm if self.deep_norm: self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4) else: self.deep_norm_alpha = 1 def forward( self, x: torch.Tensor, self_attn_mask: 
torch.Tensor = None, self_attn_padding_mask: torch.Tensor = None, need_weights: bool = False, pos_bias=None, ): residual = x if self.layer_norm_first: x = self.self_attn_layer_norm(x) x, attn, pos_bias = self.self_attn( query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, need_weights=False, attn_mask=self_attn_mask, position_bias=pos_bias, ) x = self.dropout1(x) x = residual + x residual = x x = self.final_layer_norm(x) if self.activation_name == "glu": x = self.fc1(x) else: x = self.activation_fn(self.fc1(x)) x = self.dropout2(x) x = self.fc2(x) x = self.dropout3(x) x = residual + x else: x, attn, pos_bias = self.self_attn( query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, need_weights=need_weights, attn_mask=self_attn_mask, position_bias=pos_bias, ) x = self.dropout1(x) x = residual * self.deep_norm_alpha + x x = self.self_attn_layer_norm(x) residual = x if self.activation_name == "glu": x = self.fc1(x) else: x = self.activation_fn(self.fc1(x)) x = self.dropout2(x) x = self.fc2(x) x = self.dropout3(x) x = residual * self.deep_norm_alpha + x x = self.final_layer_norm(x) return x, attn, pos_bias class MultiheadAttention(nn.Module): """Multi-headed attention. See "Attention Is All You Need" for more details. 
""" def __init__( self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, self_attention=False, encoder_decoder_attention=False, q_noise=0.0, qn_block_size=8, has_relative_attention_bias=False, num_buckets=32, max_distance=128, gru_rel_pos=False, rescale_init=False, ): super().__init__() self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim self.num_heads = num_heads self.dropout_module = nn.Dropout(dropout) self.has_relative_attention_bias = has_relative_attention_bias self.num_buckets = num_buckets self.max_distance = max_distance if self.has_relative_attention_bias: self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) self.head_dim = embed_dim // num_heads self.q_head_dim = self.head_dim self.k_head_dim = self.head_dim assert ( self.head_dim * num_heads == self.embed_dim ), "embed_dim must be divisible by num_heads" self.scaling = self.head_dim**-0.5 self.self_attention = self_attention self.encoder_decoder_attention = encoder_decoder_attention assert not self.self_attention or self.qkv_same_dim, ( "Self-attention requires query, key and " "value to be of the same size" ) k_bias = True if rescale_init: k_bias = False k_embed_dim = embed_dim q_embed_dim = embed_dim self.k_proj = quant_noise( nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size ) self.v_proj = quant_noise( nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size ) self.q_proj = quant_noise( nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size ) self.out_proj = quant_noise( nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size ) if add_bias_kv: self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) else: self.bias_k = self.bias_v = None self.add_zero_attn = add_zero_attn 
self.gru_rel_pos = gru_rel_pos if self.gru_rel_pos: self.grep_linear = nn.Linear(self.q_head_dim, 8) self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) self.reset_parameters() def reset_parameters(self): if self.qkv_same_dim: # Empirically observed the convergence to be much better with # the scaled initialization nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) else: nn.init.xavier_uniform_(self.k_proj.weight) nn.init.xavier_uniform_(self.v_proj.weight) nn.init.xavier_uniform_(self.q_proj.weight) nn.init.xavier_uniform_(self.out_proj.weight) if self.out_proj.bias is not None: nn.init.constant_(self.out_proj.bias, 0.0) if self.bias_k is not None: nn.init.xavier_normal_(self.bias_k) if self.bias_v is not None: nn.init.xavier_normal_(self.bias_v) if self.has_relative_attention_bias: nn.init.xavier_normal_(self.relative_attention_bias.weight) def _relative_positions_bucket(self, relative_positions, bidirectional=True): num_buckets = self.num_buckets max_distance = self.max_distance relative_buckets = 0 if bidirectional: num_buckets = num_buckets // 2 relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets relative_positions = torch.abs(relative_positions) else: relative_positions = -torch.min( relative_positions, torch.zeros_like(relative_positions) ) max_exact = num_buckets // 2 is_small = relative_positions < max_exact relative_postion_if_large = max_exact + ( torch.log(relative_positions.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) ).to(torch.long) relative_postion_if_large = torch.min( relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1), ) relative_buckets += torch.where( is_small, relative_positions, relative_postion_if_large ) return relative_buckets def compute_bias(self, query_length, key_length): 
context_position = torch.arange(query_length, dtype=torch.long)[:, None] memory_position = torch.arange(key_length, dtype=torch.long)[None, :] relative_position = memory_position - context_position relative_position_bucket = self._relative_positions_bucket( relative_position, bidirectional=True ) relative_position_bucket = relative_position_bucket.to( self.relative_attention_bias.weight.device ) values = self.relative_attention_bias(relative_position_bucket) values = values.permute([2, 0, 1]) return values def forward( self, query, key: Optional[Tensor], value: Optional[Tensor], key_padding_mask: Optional[Tensor] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, need_weights: bool = True, static_kv: bool = False, attn_mask: Optional[Tensor] = None, before_softmax: bool = False, need_head_weights: bool = False, position_bias: Optional[Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: """Input shape: Time x Batch x Channel Args: key_padding_mask (ByteTensor, optional): mask to exclude keys that are pads, of shape `(batch, src_len)`, where padding elements are indicated by 1s. need_weights (bool, optional): return the attention weights, averaged over heads (default: False). attn_mask (ByteTensor, optional): typically used to implement causal attention, where the mask prevents the attention from looking forward in time (default: None). before_softmax (bool, optional): return the raw attention weights and values before the attention softmax. need_head_weights (bool, optional): return the attention weights for each head. Implies *need_weights*. Default: return the average attention weights over all heads. 
""" if need_head_weights: need_weights = True is_tpu = query.device.type == "xla" tgt_len, bsz, embed_dim = query.size() src_len = tgt_len assert embed_dim == self.embed_dim assert list(query.size()) == [tgt_len, bsz, embed_dim] if key is not None: src_len, key_bsz, _ = key.size() if not torch.jit.is_scripting(): assert key_bsz == bsz assert value is not None assert src_len, bsz == value.shape[:2] if self.has_relative_attention_bias and position_bias is None: position_bias = self.compute_bias(tgt_len, src_len) position_bias = ( position_bias.unsqueeze(0) .repeat(bsz, 1, 1, 1) .view(bsz * self.num_heads, tgt_len, src_len) ) if incremental_state is not None: saved_state = self._get_input_buffer(incremental_state) if saved_state is not None and "prev_key" in saved_state: # previous time steps are cached - no need to recompute # key and value if they are static if static_kv: assert self.encoder_decoder_attention and not self.self_attention key = value = None else: saved_state = None if self.self_attention: q = self.q_proj(query) k = self.k_proj(query) v = self.v_proj(query) elif self.encoder_decoder_attention: # encoder-decoder attention q = self.q_proj(query) if key is None: assert value is None k = v = None else: k = self.k_proj(key) v = self.v_proj(key) else: assert key is not None and value is not None q = self.q_proj(query) k = self.k_proj(key) v = self.v_proj(value) q *= self.scaling alpha = 32 q *= 1 / alpha if self.bias_k is not None: assert self.bias_v is not None k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) if attn_mask is not None: attn_mask = torch.cat( [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 ) if key_padding_mask is not None: key_padding_mask = torch.cat( [ key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1), ], dim=1, ) q = ( q.contiguous() .view(tgt_len, bsz * self.num_heads, self.q_head_dim) .transpose(0, 1) ) if k is not None: k = ( k.contiguous() 
.view(-1, bsz * self.num_heads, self.k_head_dim) .transpose(0, 1) ) if v is not None: v = ( v.contiguous() .view(-1, bsz * self.num_heads, self.head_dim) .transpose(0, 1) ) if saved_state is not None: # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) if "prev_key" in saved_state: _prev_key = saved_state["prev_key"] assert _prev_key is not None prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) if static_kv: k = prev_key else: assert k is not None k = torch.cat([prev_key, k], dim=1) src_len = k.size(1) if "prev_value" in saved_state: _prev_value = saved_state["prev_value"] assert _prev_value is not None prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) if static_kv: v = prev_value else: assert v is not None v = torch.cat([prev_value, v], dim=1) prev_key_padding_mask: Optional[Tensor] = None if "prev_key_padding_mask" in saved_state: prev_key_padding_mask = saved_state["prev_key_padding_mask"] assert k is not None and v is not None key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( key_padding_mask=key_padding_mask, prev_key_padding_mask=prev_key_padding_mask, batch_size=bsz, src_len=k.size(1), static_kv=static_kv, ) saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) saved_state["prev_key_padding_mask"] = key_padding_mask # In this branch incremental_state is never None assert incremental_state is not None incremental_state = self._set_input_buffer(incremental_state, saved_state) assert k is not None assert k.size(1) == src_len # This is part of a workaround to get around fork/join parallelism # not supporting Optional types. 
if key_padding_mask is not None and key_padding_mask.dim() == 0: key_padding_mask = None if key_padding_mask is not None: assert key_padding_mask.size(0) == bsz assert key_padding_mask.size(1) == src_len if self.add_zero_attn: assert v is not None src_len += 1 k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) if attn_mask is not None: attn_mask = torch.cat( [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 ) if key_padding_mask is not None: key_padding_mask = torch.cat( [ key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as( key_padding_mask ), ], dim=1, ) attn_weights = torch.bmm(q, k.transpose(1, 2)) attn_weights = ( attn_weights - attn_weights.max(dim=-1, keepdim=True)[0] ) * alpha attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] if attn_mask is not None: attn_mask = attn_mask.unsqueeze(0) attn_weights += attn_mask if key_padding_mask is not None: # don't attend to padding symbols attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) if not is_tpu: attn_weights = attn_weights.masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf"), ) else: attn_weights = attn_weights.transpose(0, 2) attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) attn_weights = attn_weights.transpose(0, 2) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) if before_softmax: return attn_weights, v, position_bias if position_bias is not None: attn_mask_rel_pos = position_bias if self.gru_rel_pos == 1: query_layer = ( q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling ) _B, _H, _L, __ = query_layer.size() gate_a, gate_b = torch.sigmoid( self.grep_linear(query_layer) .view(_B, _H, _L, 2, 4) .sum(-1, keepdim=False) ).chunk(2, dim=-1) gate_a_1 = gate_a * (gate_b * 
self.grep_a - 1.0) + 2.0 attn_mask_rel_pos = ( gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias ) attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size()) attn_weights = attn_weights + attn_mask_rel_pos attn_weights_float = F.softmax(attn_weights, dim=-1) attn_weights = attn_weights_float.type_as(attn_weights) attn_probs = self.dropout_module(attn_weights) assert v is not None attn = torch.bmm(attn_probs, v) assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) attn = self.out_proj(attn) attn_weights: Optional[Tensor] = None if need_weights: attn_weights = attn_weights_float.view( bsz, self.num_heads, tgt_len, src_len ).transpose(1, 0) if not need_head_weights: # average attention weights over heads attn_weights = attn_weights.mean(dim=0) return attn, attn_weights, position_bias @staticmethod def _append_prev_key_padding_mask( key_padding_mask: Optional[Tensor], prev_key_padding_mask: Optional[Tensor], batch_size: int, src_len: int, static_kv: bool, ) -> Optional[Tensor]: # saved key padding masks have shape (bsz, seq_len) if prev_key_padding_mask is not None and static_kv: new_key_padding_mask = prev_key_padding_mask elif prev_key_padding_mask is not None and key_padding_mask is not None: new_key_padding_mask = torch.cat( [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 ) # During incremental decoding, as the padding token enters and # leaves the frame, there will be a time when prev or current # is None elif prev_key_padding_mask is not None: if src_len > prev_key_padding_mask.size(1): filler = torch.zeros( (batch_size, src_len - prev_key_padding_mask.size(1)), device=prev_key_padding_mask.device, ) new_key_padding_mask = torch.cat( [prev_key_padding_mask.float(), filler.float()], dim=1 ) else: new_key_padding_mask = prev_key_padding_mask.float() elif key_padding_mask is not None: if src_len > key_padding_mask.size(1): filler = 
torch.zeros( (batch_size, src_len - key_padding_mask.size(1)), device=key_padding_mask.device, ) new_key_padding_mask = torch.cat( [filler.float(), key_padding_mask.float()], dim=1 ) else: new_key_padding_mask = key_padding_mask.float() else: new_key_padding_mask = prev_key_padding_mask return new_key_padding_mask def _get_input_buffer( self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] ) -> Dict[str, Optional[Tensor]]: result = self.get_incremental_state(incremental_state, "attn_state") if result is not None: return result else: empty_result: Dict[str, Optional[Tensor]] = {} return empty_result def _set_input_buffer( self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], buffer: Dict[str, Optional[Tensor]], ): return self.set_incremental_state(incremental_state, "attn_state", buffer) def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): return attn_weights def init_bert_params(module): """ Initialize the weights specific to the BERT Model. This overrides the default initializations depending on the specified arguments. 1. If normal_init_linear_weights is set then weights of linear layer will be initialized using the normal distribution and bais will be set to the specified value. 2. If normal_init_embed_weights is set then weights of embedding layer will be initialized using the normal distribution. 3. If normal_init_proj_weights is set then weights of in_project_weight for MultiHeadAttention initialized using the normal distribution (to be validated). 
""" def normal_(data): # with FSDP, module params will be on CUDA, so we cast them back to CPU # so that the RNG is consistent with and without FSDP data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) if isinstance(module, nn.Linear): normal_(module.weight.data) if module.bias is not None: module.bias.data.zero_() if isinstance(module, nn.Embedding): normal_(module.weight.data) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() if isinstance(module, MultiheadAttention): normal_(module.q_proj.weight.data) normal_(module.k_proj.weight.data) normal_(module.v_proj.weight.data) class GradMultiply(torch.autograd.Function): @staticmethod def forward(ctx, x, scale): ctx.scale = scale res = x.new(x) return res @staticmethod def backward(ctx, grad): return grad * ctx.scale, None class SamePad(nn.Module): def __init__(self, kernel_size, causal=False): super().__init__() if causal: self.remove = kernel_size - 1 else: self.remove = 1 if kernel_size % 2 == 0 else 0 def forward(self, x): if self.remove > 0: x = x[:, :, : -self.remove] return x class Swish(nn.Module): def __init__(self): super(Swish, self).__init__() self.act = torch.nn.Sigmoid() def forward(self, x): return x * self.act(x) class GLU_Linear(nn.Module): def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True): super(GLU_Linear, self).__init__() self.glu_type = glu_type self.output_dim = output_dim if glu_type == "sigmoid": self.glu_act = torch.nn.Sigmoid() elif glu_type == "swish": self.glu_act = Swish() elif glu_type == "relu": self.glu_act = torch.nn.ReLU() elif glu_type == "gelu": self.glu_act = torch.nn.GELU() if bias_in_glu: self.linear = nn.Linear(input_dim, output_dim * 2, True) else: self.linear = nn.Linear(input_dim, output_dim * 2, False) def forward(self, x): # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case x 
= self.linear(x) if self.glu_type == "bilinear": x = ( x[:, :, 0 : self.output_dim] * x[:, :, self.output_dim : self.output_dim * 2] ) else: x = x[:, :, 0 : self.output_dim] * self.glu_act( x[:, :, self.output_dim : self.output_dim * 2] ) return x def gelu_accurate(x): if not hasattr(gelu_accurate, "_a"): gelu_accurate._a = math.sqrt(2 / math.pi) return ( 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) ) def gelu(x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.gelu(x.float()).type_as(x) def get_activation_fn(activation: str): """Returns the activation function corresponding to `activation`""" if activation == "relu": return F.relu elif activation == "gelu": return gelu elif activation == "gelu_fast": warnings.warn("--activation-fn=gelu_fast has been renamed to gelu_accurate") return gelu_accurate elif activation == "gelu_accurate": return gelu_accurate elif activation == "tanh": return torch.tanh elif activation == "linear": return lambda x: x elif activation == "glu": return lambda x: x else: raise RuntimeError("--activation-fn {} not supported".format(activation)) def quant_noise(module, p, block_size): """ Wraps modules and applies quantization noise to the weights for subsequent quantization with Iterative Product Quantization as described in "Training with Quantization Noise for Extreme Model Compression" Args: - module: nn.Module - p: amount of Quantization Noise - block_size: size of the blocks for subsequent quantization with iPQ Remarks: - Module weights must have the right sizes wrt the block size - Only Linear, Embedding and Conv2d modules are supported for the moment - For more detail on how to quantize by blocks with convolutional weights, see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" - We implement the simplest form of noise here as stated in the paper which consists in randomly dropping blocks """ # if no quantization noise, don't register hook if p <= 0: return module # 
supported modules assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) # test whether module.weight has the right sizes wrt block_size is_conv = module.weight.ndim == 4 # 2D matrix if not is_conv: assert ( module.weight.size(1) % block_size == 0 ), "Input features must be a multiple of block sizes" # 4D matrix else: # 1x1 convolutions if module.kernel_size == (1, 1): assert ( module.in_channels % block_size == 0 ), "Input channels must be a multiple of block sizes" # regular convolutions else: k = module.kernel_size[0] * module.kernel_size[1] assert k % block_size == 0, "Kernel size must be a multiple of block size" def _forward_pre_hook(mod, input): # no noise for evaluation if mod.training: if not is_conv: # gather weight and sizes weight = mod.weight in_features = weight.size(1) out_features = weight.size(0) # split weight matrix into blocks and randomly drop selected blocks mask = torch.zeros( in_features // block_size * out_features, device=weight.device ) mask.bernoulli_(p) mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) else: # gather weight and sizes weight = mod.weight in_channels = mod.in_channels out_channels = mod.out_channels # split weight matrix into blocks and randomly drop selected blocks if mod.kernel_size == (1, 1): mask = torch.zeros( int(in_channels // block_size * out_channels), device=weight.device, ) mask.bernoulli_(p) mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) else: mask = torch.zeros( weight.size(0), weight.size(1), device=weight.device ) mask.bernoulli_(p) mask = ( mask.unsqueeze(2) .unsqueeze(3) .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) ) # scale weights and apply mask mask = mask.to( torch.bool ) # x.bool() is not currently supported in TorchScript s = 1 / (1 - p) mod.weight.data = s * weight.masked_fill(mask, 0) module.register_forward_pre_hook(_forward_pre_hook) return module class TokenizersConfig: def __init__(self, cfg=None): self.input_patch_size: int = -1 # path 
size of patch embedding self.embed_dim: int = 512 # patch embedding dimension self.conv_bias: bool = False # include bias in conv encoder self.encoder_layers: int = 12 # num encoder layers in the transformer self.encoder_embed_dim: int = 768 # encoder embedding dimension self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN self.encoder_attention_heads: int = 12 # num encoder attention heads self.activation_fn: str = "gelu" # activation function to use self.layer_norm_first: bool = False # apply layernorm first in the transformer self.deep_norm: bool = False # apply deep_norm first in the transformer # dropouts self.dropout: float = 0.1 # dropout probability for the transformer self.attention_dropout: float = 0.1 # dropout probability for attention weights # dropout probability after activation in FFN self.activation_dropout: float = 0.0 # probability of dropping a tarnsformer layer self.encoder_layerdrop: float = 0.0 # dropout to apply to the input (after feat extr) self.dropout_input: float = 0.0 # positional embeddings self.conv_pos: int = ( 128 # number of filters for convolutional positional embeddings ) # number of groups for convolutional positional embedding self.conv_pos_groups: int = 16 # relative position embedding # apply relative position embedding self.relative_position_embedding: bool = False self.num_buckets: int = 320 # number of buckets for relative position embedding self.max_distance: int = ( 1280 # maximum distance for relative position embedding ) self.gru_rel_pos: bool = False # apply gated relative position embedding # quantizer self.quant_n: int = 1024 # codebook number in quantizer self.quant_dim: int = 256 # codebook dimension in quantizer if cfg is not None: self.update(cfg) def update(self, cfg: dict): self.__dict__.update(cfg) class Tokenizers(nn.Module): def __init__( self, cfg: TokenizersConfig, ) -> None: super().__init__() logger.info(f"Tokenizers Config: {cfg.__dict__}") self.cfg = cfg self.embed = 
cfg.embed_dim self.post_extract_proj = ( nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None ) self.input_patch_size = cfg.input_patch_size self.patch_embedding = nn.Conv2d( 1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size, bias=cfg.conv_bias, ) self.dropout_input = nn.Dropout(cfg.dropout_input) assert not cfg.deep_norm or not cfg.layer_norm_first self.encoder = TransformerEncoder(cfg) self.layer_norm = LayerNorm(self.embed) self.quantize = NormEMAVectorQuantizer( n_embed=cfg.quant_n, embedding_dim=cfg.quant_dim, beta=1.0, kmeans_init=True, decay=0.99, ) self.quant_n = cfg.quant_n self.quantize_layer = nn.Sequential( nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim), nn.Tanh(), nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim), # for quantize ) def forward_padding_mask( self, features: torch.Tensor, padding_mask: torch.Tensor, ) -> torch.Tensor: extra = padding_mask.size(1) % features.size(1) if extra > 0: padding_mask = padding_mask[:, :-extra] padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1) padding_mask = padding_mask.all(-1) return padding_mask def preprocess( self, source: torch.Tensor, fbank_mean: float = 15.41663, fbank_std: float = 6.55582, ) -> torch.Tensor: fbanks = [] for waveform in source: waveform = waveform.unsqueeze(0) * 2**15 fbank = ta_kaldi.fbank( waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10, ) fbanks.append(fbank) fbank = torch.stack(fbanks, dim=0) fbank = (fbank - fbank_mean) / (2 * fbank_std) return fbank def extract_labels( self, source: torch.Tensor, padding_mask: Optional[torch.Tensor] = None, fbank_mean: float = 15.41663, fbank_std: float = 6.55582, ): fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std) if padding_mask is not None: padding_mask = self.forward_padding_mask(fbank, padding_mask) fbank = fbank.unsqueeze(1) features = self.patch_embedding(fbank) features = 
features.reshape(features.shape[0], features.shape[1], -1) features = features.transpose(1, 2) features = self.layer_norm(features) if padding_mask is not None: padding_mask = self.forward_padding_mask(features, padding_mask) if self.post_extract_proj is not None: features = self.post_extract_proj(features) x = self.dropout_input(features) x, layer_results = self.encoder( x, padding_mask=padding_mask, ) quantize_input = self.quantize_layer(x) quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input) return embed_ind def l2norm(t): return F.normalize(t, p=2, dim=-1) def ema_inplace(moving_avg, new, decay): moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) def sample_vectors(samples, num): num_samples, device = samples.shape[0], samples.device if num_samples >= num: indices = torch.randperm(num_samples, device=device)[:num] else: indices = torch.randint(0, num_samples, (num,), device=device) return samples[indices] def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False): dim, dtype, device = samples.shape[-1], samples.dtype, samples.device means = sample_vectors(samples, num_clusters) for _ in range(num_iters): if use_cosine_sim: dists = samples @ means.t() else: diffs = rearrange(samples, "n d -> n () d") - rearrange( means, "c d -> () c d" ) dists = -(diffs**2).sum(dim=-1) buckets = dists.max(dim=-1).indices bins = torch.bincount(buckets, minlength=num_clusters) zero_mask = bins == 0 bins_min_clamped = bins.masked_fill(zero_mask, 1) new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) new_means = new_means / bins_min_clamped[..., None] if use_cosine_sim: new_means = l2norm(new_means) means = torch.where(zero_mask[..., None], means, new_means) return means, bins class EmbeddingEMA(nn.Module): def __init__( self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path="", ): super().__init__() self.num_tokens = num_tokens 
self.codebook_dim = codebook_dim self.decay = decay self.eps = eps if codebook_init_path == "": if not kmeans_init: weight = torch.randn(num_tokens, codebook_dim) weight = l2norm(weight) else: weight = torch.zeros(num_tokens, codebook_dim) self.register_buffer("initted", torch.Tensor([not kmeans_init])) else: print(f"load init codebook weight from {codebook_init_path}") codebook_ckpt_weight = torch.load(codebook_init_path, map_location="cpu") weight = codebook_ckpt_weight.clone() self.register_buffer("initted", torch.Tensor([True])) self.weight = nn.Parameter(weight, requires_grad=False) self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False) self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False) # self.register_buffer('initted', torch.Tensor([not kmeans_init])) self.update = True @torch.jit.ignore def init_embed_(self, data): if self.initted: return print("Performing Kemans init for codebook") embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim=True) self.weight.data.copy_(embed) self.cluster_size.data.copy_(cluster_size) self.initted.data.copy_(torch.Tensor([True])) def forward(self, embed_id): return F.embedding(embed_id, self.weight) def cluster_size_ema_update(self, new_cluster_size): self.cluster_size.data.mul_(self.decay).add_( new_cluster_size, alpha=1 - self.decay ) def embed_avg_ema_update(self, new_embed_avg): self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) def weight_update(self, num_tokens): n = self.cluster_size.sum() smoothed_cluster_size = ( (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n ) # normalize embedding average with smoothed cluster size embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) # embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1)) self.weight.data.copy_(embed_normalized) def norm_ema_inplace(moving_avg, new, decay): moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) 
moving_avg.data.copy_(l2norm(moving_avg.data)) class NormEMAVectorQuantizer(nn.Module): def __init__( self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, statistic_code_usage=True, kmeans_init=False, codebook_init_path="", ): super().__init__() self.codebook_dim = embedding_dim self.num_tokens = n_embed self.beta = beta self.decay = decay # learnable = True if orthogonal_reg_weight > 0 else False self.embedding = EmbeddingEMA( self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path, ) self.statistic_code_usage = statistic_code_usage if statistic_code_usage: self.register_buffer("cluster_size", torch.zeros(n_embed)) if distributed.is_available() and distributed.is_initialized(): print( "ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!" ) self.all_reduce_fn = distributed.all_reduce else: self.all_reduce_fn = nn.Identity() def reset_cluster_size(self, device): if self.statistic_code_usage: self.register_buffer("cluster_size", torch.zeros(self.num_tokens)) self.cluster_size = self.cluster_size.to(device) def forward(self, z): # reshape z -> (batch, height, width, channel) and flatten # z, 'b c h w -> b h w c' # z = rearrange(z, 'b c h w -> b h w c') # z = z.transpose(1, 2) z = l2norm(z) z_flattened = z.reshape(-1, self.codebook_dim) self.embedding.init_embed_(z_flattened) d = ( z_flattened.pow(2).sum(dim=1, keepdim=True) + self.embedding.weight.pow(2).sum(dim=1) - 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight) ) # 'n d -> d n' encoding_indices = torch.argmin(d, dim=1) z_q = self.embedding(encoding_indices).view(z.shape) encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) if not self.training: with torch.no_grad(): cluster_size = encodings.sum(0) self.all_reduce_fn(cluster_size) ema_inplace(self.cluster_size, cluster_size, self.decay) if self.training and self.embedding.update: # EMA cluster size bins = encodings.sum(0) self.all_reduce_fn(bins) # 
self.embedding.cluster_size_ema_update(bins) ema_inplace(self.cluster_size, bins, self.decay) zero_mask = bins == 0 bins = bins.masked_fill(zero_mask, 1.0) embed_sum = z_flattened.t() @ encodings self.all_reduce_fn(embed_sum) embed_normalized = (embed_sum / bins.unsqueeze(0)).t() embed_normalized = l2norm(embed_normalized) embed_normalized = torch.where( zero_mask[..., None], self.embedding.weight, embed_normalized ) norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay) # compute loss for embedding loss = self.beta * F.mse_loss(z_q.detach(), z) # preserve gradients z_q = z + (z_q - z).detach() # reshape back to match original input shape # z_q, 'b h w c -> b c h w' # z_q = rearrange(z_q, 'b h w c -> b c h w') # z_q = z_q.transpose(1, 2) return z_q, loss, encoding_indices