# File: llama3.1-typhoon2-audio-8b-instruct / modeling_typhoon2audio.py
# Uploaded by potsawee: "Upload Typhoon2Audio2AudioForConditionalGeneration"
# (commit 679b242, verified; 177 kB)
"""
Some of the code is adapted from:
1. ByteDance's SALMONN (https://github.com/bytedance/SALMONN)
2. Llama-Omni (https://github.com/ictnlp/LLaMA-Omni/)
Please comply with the copyright and license terms of the original projects.
"""
# ---------------------------------------------------- #
import inspect
import copy
import torch
import torch.nn.functional as F
from torch import Tensor, device, nn
import numpy as np
from transformers import (
WhisperFeatureExtractor,
WhisperConfig,
WhisperModel,
PreTrainedModel,
AutoTokenizer,
AutoModelForCausalLM,
)
from transformers.cache_utils import Cache, StaticCache
from transformers.generation.utils import (
GenerationConfig,
GenerationMode,
LogitsProcessorList,
StoppingCriteriaList,
GenerateOutput,
GenerationMixin,
GenerateEncoderDecoderOutput,
GenerateDecoderOnlyOutput,
GenerateNonBeamOutput,
is_deepspeed_zero3_enabled,
is_torchdynamo_compiling,
NEED_SETUP_CACHE_CLASSES_MAPPING,
QUANT_BACKEND_CLASSES_MAPPING,
is_hqq_available,
QuantizedCacheConfig,
is_quanto_available,
DynamicCache,
EncoderDecoderCache,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_typhoon2audio import Typhoon2AudioConfig, BEATsConfig
# ---------------------------------------------------- #
# QFormer: https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
import math
import warnings
from typing import Optional, Tuple, Dict, Union, Callable, List
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
MaskedLMOutput,
)
from transformers.modeling_utils import (
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from transformers.models.bert.configuration_bert import BertConfig
# ---------------------------------------------------------- #
# BEATs: https://github.com/microsoft/unilm/tree/master/beats
from torch.nn import LayerNorm, Parameter
import torch.distributed as distributed
import torchaudio.compliance.kaldi as ta_kaldi
import logging
try:
from einops import rearrange, repeat
except ImportError:
pass
logger = logging.getLogger(__name__)
# ---------------------------------------------------------- #
# Speech Decoder
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
# Unit Vocoder
from fairseq.models import BaseFairseqModel
from fairseq.models.text_to_speech.codehifigan import CodeGenerator as CodeHiFiGANModel
# ---------------------------------------------------------- #
import soundfile as sf
class GenerationWithCTC(GenerationMixin):
    """Generation mixin that can stream discrete speech units alongside text.

    Adapted from LLaMA-Omni. ``generate`` follows the structure of
    ``transformers.GenerationMixin.generate`` but only supports greedy search
    and multinomial sampling. With ``streaming_unit_gen=True`` it decodes via
    ``_sample_streaming_unit``, which after every text token re-runs
    ``self.speech_generator.predict`` on the accumulated last-layer hidden
    states and pushes newly decoded units to ``streamer_unit``.

    The concrete model class must provide ``speech_generator`` and
    ``model.config.unit_vocab_size`` (used as the CTC blank id).
    """

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[
            Callable[[int, torch.Tensor], List[int]]
        ] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        streamer_unit: Optional["BaseStreamer"] = None,
        streaming_unit_gen=False,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate text tokens, optionally streaming speech units as they appear.

        Steps 1-9 (config validation, special tokens, attention mask, cache
        setup, logits processors and stopping criteria) follow
        ``transformers.GenerationMixin.generate``; only greedy search and
        multinomial sampling are dispatched.

        Args:
            streamer: receives the prompt ``input_ids`` and then each new text
                token.
            streamer_unit: receives newly decoded speech units; only used when
                ``streaming_unit_gen`` is true.
            streaming_unit_gen: decode with ``_sample_streaming_unit`` (text +
                speech units) instead of ``_sample`` (text only). Requires
                ``return_dict_in_generate=True`` and ``output_hidden_states=True``.
            Other arguments are as in ``transformers.GenerationMixin.generate``.

        Returns:
            The result of ``_sample`` / ``_sample_streaming_unit``: a
            ``GenerateDecoderOnlyOutput`` / ``GenerateEncoderDecoderOutput``
            when ``return_dict_in_generate`` is set, otherwise the token ids.

        Raises:
            ValueError: if a streamer is combined with beam search, or for the
                invalid cache configurations checked below.
            NotImplementedError: for any generation mode other than greedy
                search or sampling.
        """
        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
        self._validate_model_class()
        # Pull this out first, we only use it for stopping criteria
        tokenizer = kwargs.pop("tokenizer", None)
        generation_config, model_kwargs = self._prepare_generation_config(
            generation_config, **kwargs
        )
        self._validate_model_kwargs(model_kwargs.copy())
        self._validate_assistant(assistant_model)
        # 2. Set generation parameters if not already defined
        if synced_gpus is None:
            # `torch.distributed` is imported at module level as `distributed`
            # (there is no `dist` alias in this file).
            if is_deepspeed_zero3_enabled() and distributed.get_world_size() > 1:
                synced_gpus = True
            else:
                synced_gpus = False
        logits_processor = (
            logits_processor if logits_processor is not None else LogitsProcessorList()
        )
        stopping_criteria = (
            stopping_criteria
            if stopping_criteria is not None
            else StoppingCriteriaList()
        )
        accepts_attention_mask = "attention_mask" in set(
            inspect.signature(self.forward).parameters.keys()
        )
        requires_attention_mask = "encoder_outputs" not in model_kwargs
        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
        # 3. Define model inputs
        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
            inputs, generation_config.bos_token_id, model_kwargs
        )
        batch_size = inputs_tensor.shape[0]
        device = inputs_tensor.device
        self._prepare_special_tokens(
            generation_config, kwargs_has_attention_mask, device=device
        )
        # decoder-only models must use left-padding for batched generation.
        if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
            # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
            # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
            if (
                generation_config._pad_token_tensor is not None
                and batch_size > 1
                and len(inputs_tensor.shape) == 2
                and torch.sum(
                    inputs_tensor[:, -1] == generation_config._pad_token_tensor
                )
                > 0
            ):
                logger.warning(
                    "A decoder-only architecture is being used, but right-padding was detected! For correct "
                    "generation results, please set `padding_side='left'` when initializing the tokenizer."
                )
        # 4. Define other model kwargs
        # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
        # generating the first new token or not, and we only want to use the embeddings for the first new token)
        if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
            model_kwargs["use_cache"] = True
        else:
            model_kwargs["use_cache"] = generation_config.use_cache
        if (
            not kwargs_has_attention_mask
            and requires_attention_mask
            and accepts_attention_mask
        ):
            model_kwargs["attention_mask"] = (
                self._prepare_attention_mask_for_generation(
                    inputs_tensor,
                    generation_config._pad_token_tensor,
                    generation_config._eos_token_tensor,
                )
            )
        if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
            # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, model_kwargs, model_input_name, generation_config
            )
        # 5. Prepare `input_ids` which will be used for auto-regressive generation
        if self.config.is_encoder_decoder:
            input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
                batch_size=batch_size,
                model_input_name=model_input_name,
                model_kwargs=model_kwargs,
                decoder_start_token_id=generation_config._decoder_start_token_tensor,
                device=inputs_tensor.device,
            )
        else:
            # When generating from `inputs_embeds`, the caller must also pass
            # `input_ids` (they are used as the running token sequence).
            input_ids = (
                inputs_tensor
                if model_input_name == "input_ids"
                else model_kwargs.pop("input_ids")
            )
        if generation_config.token_healing:
            input_ids = self.heal_tokens(input_ids, tokenizer)
        if streamer is not None:
            streamer.put(input_ids.cpu())
        # 6. Prepare `max_length` depending on other stopping criteria.
        input_ids_length = input_ids.shape[-1]
        has_default_max_length = (
            kwargs.get("max_length") is None
            and generation_config.max_length is not None
        )
        has_default_min_length = (
            kwargs.get("min_length") is None
            and generation_config.min_length is not None
        )
        generation_config = self._prepare_generated_length(
            generation_config=generation_config,
            has_default_max_length=has_default_max_length,
            has_default_min_length=has_default_min_length,
            model_input_name=model_input_name,
            inputs_tensor=inputs_tensor,
            input_ids_length=input_ids_length,
        )
        use_dynamic_cache_by_default = False
        if "mamba" in self.__class__.__name__.lower():
            cache_name = "cache_params"
        else:
            cache_name = "past_key_values"
        if generation_config.cache_implementation is not None and (
            model_kwargs.get(cache_name) is not None
        ):
            raise ValueError(
                f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a "
                "Cache object) is unsupported. Please use only one of the two."
            )
        elif generation_config.cache_implementation is not None:
            if (
                generation_config.cache_implementation
                in NEED_SETUP_CACHE_CLASSES_MAPPING
            ):
                if (
                    generation_config.cache_implementation == "static"
                    and not self._supports_static_cache
                ):
                    raise ValueError(
                        "This model does not support `cache_implementation='static'`. Please check the following "
                        "issue: https://github.com/huggingface/transformers/issues/28981"
                    )
                model_kwargs[cache_name] = self._get_cache(
                    generation_config.cache_implementation,
                    getattr(generation_config, "num_beams", 1) * batch_size,
                    generation_config.max_length,
                    model_kwargs,
                )
            elif generation_config.cache_implementation == "quantized":
                if not self._supports_quantized_cache:
                    raise ValueError(
                        "This model does not support the quantized cache. If you want your model to support quantized "
                        "cache, please open an issue."
                    )
                cache_config = (
                    generation_config.cache_config
                    if generation_config.cache_config is not None
                    else QuantizedCacheConfig()
                )
                cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]
                if cache_config.backend == "quanto" and not is_quanto_available():
                    raise ImportError(
                        "You need to install `quanto` in order to use KV cache quantization with quanto backend. "
                        "Please install it via  with `pip install quanto`"
                    )
                elif cache_config.backend == "HQQ" and not is_hqq_available():
                    raise ImportError(
                        "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
                        "Please install it via  with `pip install hqq`"
                    )
                model_kwargs[cache_name] = cache_class(cache_config)
        # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
        # keeps copying the cache thus using much more memory
        elif (
            generation_config.cache_implementation is None
            and self._supports_default_dynamic_cache()
        ):
            past = model_kwargs.get(cache_name, None)
            requires_cross_attention_cache = (
                self.config.is_encoder_decoder
                or model_kwargs.get("encoder_outputs") is not None
            )
            if past is None:
                model_kwargs[cache_name] = (
                    DynamicCache()
                    if not requires_cross_attention_cache
                    else EncoderDecoderCache(DynamicCache(), DynamicCache())
                )
                use_dynamic_cache_by_default = True
            elif isinstance(past, tuple):
                model_kwargs[cache_name] = (
                    DynamicCache.from_legacy_cache(past)
                    if not requires_cross_attention_cache
                    else EncoderDecoderCache.from_legacy_cache(past)
                )
                use_dynamic_cache_by_default = True
        self._validate_generated_length(
            generation_config, input_ids_length, has_default_max_length
        )
        # 7. determine generation mode
        generation_mode = generation_config.get_generation_mode(assistant_model)
        if (streamer is not None or streamer_unit is not None) and (
            generation_config.num_beams > 1
        ):
            raise ValueError(
                "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
            )
        if self.device.type != input_ids.device.type:
            warnings.warn(
                "You are calling .generate() with the `input_ids` being on a device type different"
                f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
                f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
                " Please make sure that you have put `input_ids` to the"
                f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
                " running `.generate()`.",
                UserWarning,
            )
        # 8. prepare distribution pre_processing samplers
        prepared_logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_length,
            encoder_input_ids=inputs_tensor,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
            device=inputs_tensor.device,
            model_kwargs=model_kwargs,
            negative_prompt_ids=negative_prompt_ids,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
        )
        # 9. prepare stopping criteria
        prepared_stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config,
            stopping_criteria=stopping_criteria,
            tokenizer=tokenizer,
            **kwargs,
        )
        # 10. go into different generation modes
        if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
            # 11. prepare logits warper
            prepared_logits_warper = (
                self._get_logits_warper(generation_config, device=input_ids.device)
                if generation_config.do_sample
                else None
            )
            # 12. expand input_ids with `num_return_sequences` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )
            # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
            if streaming_unit_gen:
                return self._sample_streaming_unit(
                    input_ids,
                    logits_processor=prepared_logits_processor,
                    logits_warper=prepared_logits_warper,
                    stopping_criteria=prepared_stopping_criteria,
                    generation_config=generation_config,
                    synced_gpus=synced_gpus,
                    streamer=streamer,
                    streamer_unit=streamer_unit,
                    **model_kwargs,
                )
            else:
                return self._sample(
                    input_ids,
                    logits_processor=prepared_logits_processor,
                    logits_warper=prepared_logits_warper,
                    stopping_criteria=prepared_stopping_criteria,
                    generation_config=generation_config,
                    synced_gpus=synced_gpus,
                    streamer=streamer,
                    **model_kwargs,
                )
        else:
            raise NotImplementedError

    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        streamer: Optional["BaseStreamer"],
        logits_warper: Optional[LogitsProcessorList],
        **model_kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        """Text-only greedy / multinomial decoding loop.

        Each step runs a forward pass, applies ``logits_processor`` (and
        ``logits_warper`` when sampling), picks the next token, appends it to
        ``input_ids`` and updates the cache via
        ``_update_model_kwargs_for_generation`` until ``stopping_criteria``
        marks every sequence finished.

        Returns:
            A ``GenerateEncoderDecoderOutput`` / ``GenerateDecoderOnlyOutput``
            when ``return_dict_in_generate`` is set, otherwise ``input_ids``
            (prompt plus generated tokens).

        Raises:
            ValueError: if sampling is requested without a
                ``LogitsProcessorList`` warper.
        """
        # init values
        pad_token_id = generation_config._pad_token_tensor
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        has_eos_stopping_criteria = any(
            hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
        )
        do_sample = generation_config.do_sample
        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
            raise ValueError(
                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
                f"{logits_warper})."
            )
        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = (
            () if (return_dict_in_generate and output_attentions) else None
        )
        cross_attentions = (
            () if (return_dict_in_generate and output_attentions) else None
        )
        decoder_hidden_states = (
            () if (return_dict_in_generate and output_hidden_states) else None
        )
        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = (
                model_kwargs["encoder_outputs"].get("attentions")
                if output_attentions
                else None
            )
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states")
                if output_hidden_states
                else None
            )
        # keep track of which sequences are already finished
        batch_size = input_ids.shape[0]
        this_peer_finished = False
        unfinished_sequences = torch.ones(
            batch_size, dtype=torch.long, device=input_ids.device
        )
        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
        while self._has_unfinished_sequences(
            this_peer_finished, synced_gpus, device=input_ids.device
        ):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # prepare variable output controls (note: some models won't accept all output controls)
            model_inputs.update(
                {"output_attentions": output_attentions} if output_attentions else {}
            )
            model_inputs.update(
                {"output_hidden_states": output_hidden_states}
                if output_hidden_states
                else {}
            )
            # forward pass to get next token
            outputs = self(**model_inputs, return_dict=True)
            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need
            # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
            # (the clone itself is always small)
            next_token_logits = outputs.logits[:, -1, :].clone()
            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            if do_sample:
                next_token_scores = logits_warper(input_ids, next_token_scores)
            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_logits:
                    raw_logits += (next_token_logits,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,)
                        if self.config.is_encoder_decoder
                        else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)
                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )
            # token selection
            if do_sample:
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_scores, dim=-1)
            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
                    1 - unfinished_sequences
                )
            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )
            unfinished_sequences = unfinished_sequences & ~stopping_criteria(
                input_ids, scores
            )
            this_peer_finished = unfinished_sequences.max() == 0
            # This is needed to properly delete outputs.logits which may be very large for first iteration
            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
            del outputs
        if streamer is not None:
            streamer.end()
        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return GenerateEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
            else:
                return GenerateDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
        else:
            return input_ids

    def _sample_streaming_unit(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        streamer: Optional["BaseStreamer"],
        streamer_unit: Optional["BaseStreamer"],
        logits_warper: Optional[LogitsProcessorList],
        **model_kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        """Decode text like ``_sample`` while incrementally emitting speech units.

        After every step, the last-layer hidden state of the prompt's final
        position and those of every generated token so far are concatenated
        and passed to ``self.speech_generator.predict``. The CTC output is
        collapsed with ``ctc_postprocess`` (blank id
        ``self.model.config.unit_vocab_size``), and any units beyond those
        already emitted are pushed one at a time to ``streamer_unit``. The
        whole unit sequence is re-predicted at every step.

        Assumes batch size 1, because the hidden states are squeezed on dim 0
        before prediction (TODO confirm batching is unsupported).

        Raises:
            ValueError: if ``return_dict_in_generate`` or
                ``output_hidden_states`` is disabled (the per-step hidden
                states are required), or if sampling is requested without a
                ``LogitsProcessorList`` warper.
        """
        # init values
        pad_token_id = generation_config._pad_token_tensor
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        has_eos_stopping_criteria = any(
            hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
        )
        do_sample = generation_config.do_sample
        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
            raise ValueError(
                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
                f"{logits_warper})."
            )
        # The unit generator reads `decoder_hidden_states`, which is only
        # collected when both flags are set; fail fast with a clear message
        # instead of a TypeError on `None[0]` inside the loop.
        if not (return_dict_in_generate and output_hidden_states):
            raise ValueError(
                "Streaming unit generation requires `return_dict_in_generate=True` and "
                "`output_hidden_states=True`, because speech units are predicted from the "
                "LLM hidden states."
            )
        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = (
            () if (return_dict_in_generate and output_attentions) else None
        )
        cross_attentions = (
            () if (return_dict_in_generate and output_attentions) else None
        )
        decoder_hidden_states = (
            () if (return_dict_in_generate and output_hidden_states) else None
        )
        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = (
                model_kwargs["encoder_outputs"].get("attentions")
                if output_attentions
                else None
            )
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states")
                if output_hidden_states
                else None
            )
        # keep track of which sequences are already finished
        batch_size = input_ids.shape[0]
        this_peer_finished = False
        unfinished_sequences = torch.ones(
            batch_size, dtype=torch.long, device=input_ids.device
        )
        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
        # Units already pushed to `streamer_unit`; used to emit only the new tail.
        generated_units = torch.tensor([])
        while self._has_unfinished_sequences(
            this_peer_finished, synced_gpus, device=input_ids.device
        ):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # prepare variable output controls (note: some models won't accept all output controls)
            model_inputs.update(
                {"output_attentions": output_attentions} if output_attentions else {}
            )
            model_inputs.update(
                {"output_hidden_states": output_hidden_states}
                if output_hidden_states
                else {}
            )
            # forward pass to get next token
            outputs = self(**model_inputs, return_dict=True)
            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need
            # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
            # (the clone itself is always small)
            next_token_logits = outputs.logits[:, -1, :].clone()
            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            if do_sample:
                next_token_scores = logits_warper(input_ids, next_token_scores)
            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_logits:
                    raw_logits += (next_token_logits,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,)
                        if self.config.is_encoder_decoder
                        else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)
                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )
            # token selection
            if do_sample:
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_scores, dim=-1)
            # speechgen: last-layer states of the prompt's final position (step 0)
            # followed by the single new position of every later step.
            hidden_states = torch.cat(
                [decoder_hidden_states[0][-1][:, -1:, :]]
                + [
                    decoder_hidden_states[i][-1]
                    for i in range(1, len(decoder_hidden_states))
                ],
                dim=1,
            )
            ctc_pred = self.speech_generator.predict(hidden_states.squeeze(0))
            # `ctc_postprocess` is a method of this class, so it must be called via `self`.
            cur_units = self.ctc_postprocess(
                ctc_pred, blank=self.model.config.unit_vocab_size
            )
            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
                    1 - unfinished_sequences
                )
            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            if streamer_unit is not None:
                for i in range(len(generated_units), len(cur_units)):
                    streamer_unit.put(cur_units[i].unsqueeze(0))
            generated_units = cur_units
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )
            unfinished_sequences = unfinished_sequences & ~stopping_criteria(
                input_ids, scores
            )
            this_peer_finished = unfinished_sequences.max() == 0
            # This is needed to properly delete outputs.logits which may be very large for first iteration
            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
            del outputs
        if streamer is not None:
            streamer.end()
        # Signal end-of-stream to the unit consumer as well; otherwise an
        # iterator-style streamer would wait forever for more units.
        if streamer_unit is not None:
            streamer_unit.end()
        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return GenerateEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
            else:
                return GenerateDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
        else:
            return input_ids

    def ctc_postprocess(self, tokens, blank):
        """Collapse a greedy CTC prediction into a 1-D ``long`` tensor of unit ids.

        Consecutive repeats are merged first, then every ``blank`` id is
        dropped.

        Args:
            tokens: predicted ids; a leading dimension of size 1 is squeezed.
            blank: id of the CTC blank symbol.
        """
        ids = tokens.squeeze(0).tolist()
        deduplicated_ids = [
            v for i, v in enumerate(ids) if i == 0 or v != ids[i - 1]
        ]
        # Pin the dtype so an all-blank prediction still yields a long tensor
        # (torch.tensor([]) would otherwise default to float32).
        return torch.tensor(
            [v for v in deduplicated_ids if v != blank], dtype=torch.long
        )
class Typhoon2AudioForConditionalGeneration(PreTrainedModel, GenerationMixin):
config_class = Typhoon2AudioConfig
_supports_cache_class = True
def __init__(
    self,
    config,
    attn_implementation=None,  # only for the LLM
):
    """Build the audio encoders, the speech Q-Former, the LLM, and the projection.

    Args:
        config: a ``Typhoon2AudioConfig``. ``config.whisper`` and
            ``config.beats`` may be plain dicts; they are converted in place
            to ``WhisperConfig`` / ``BEATsConfig``.
        attn_implementation: forwarded to ``AutoModelForCausalLM.from_pretrained``
            for the LLM only.

    Note:
        The LLM and its tokenizer are loaded with ``from_pretrained`` from
        ``config.llama_base_model`` inside the constructor, so building this
        module loads pretrained LLM weights.
    """
    super().__init__(config)
    # 1. Speech Encoder
    # 1.1) Whisper Encoder
    # feature_extractor
    self.feature_extractor = WhisperFeatureExtractor(
        feature_size=config.whisper_extractor_feature_size
    )
    # whisper encoder
    # config.whisper may be a plain dict; it is replaced in place by a WhisperConfig.
    if isinstance(config.whisper, dict):
        config.whisper = WhisperConfig(**config.whisper)
    # Only the encoder half of the Whisper model is kept.
    self.speech_encoder = WhisperModel(config.whisper).encoder
    self.ln_speech = nn.LayerNorm(config.whisper.d_model)
    # 1.2) BEATs
    if isinstance(config.beats, dict):
        config.beats = BEATsConfig(config.beats)
    self.beats = BEATs(config.beats)
    self.ln_audio = nn.LayerNorm(config.beats.encoder_embed_dim)
    # 1.3) Speech QFormer
    # The Q-Former attends over the concatenated Whisper + BEATs features,
    # hence the summed width.
    self.speech_Qformer, self.speech_query_tokens = self.init_speech_Qformer(
        config.speech_qformer_token_num,
        config.whisper.d_model + config.beats.encoder_embed_dim,
        config.speech_qformer_layer,
    )
    # Window length and hop (in seconds) used by encode_speech_only to chunk
    # the encoder frames before the Q-Former.
    self.second_per_frame = config.second_per_frame
    self.second_stride = config.second_stride
    # 2. LLM (e.g., Llama3)
    self.llama_model = AutoModelForCausalLM.from_pretrained(
        config.llama_base_model, attn_implementation=attn_implementation
    )
    # tokenizer
    self.llama_tokenizer = AutoTokenizer.from_pretrained(
        config.llama_base_model, use_fast=False
    )
    # NOTE(review): if "[PAD]" is new to the vocabulary, the LLM embedding
    # matrix is not resized here — presumably the pad id is never embedded;
    # confirm.
    self.llama_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    self.llama_tokenizer.padding_side = "right"
    # speech -> LLM projection
    self.speech_llama_proj = nn.Linear(
        self.speech_Qformer.config.hidden_size,
        self.llama_model.config.hidden_size,
    )
def init_speech_Qformer(self, num_query_token, speech_width, num_hidden_layers=2):
    """Create the speech Q-Former and its learnable query tokens.

    Args:
        num_query_token: number of learnable query vectors (also stored as
            ``query_length`` on the config).
        speech_width: feature width of the encoder states the Q-Former
            cross-attends to (stored as ``encoder_width``).
        num_hidden_layers: number of BERT layers in the Q-Former.

    Returns:
        ``(qformer, query_tokens)``: a ``BertLMHeadModel`` built from a
        default ``BertConfig`` with cross-attention enabled
        (``cross_attention_freq=1``), and an ``nn.Parameter`` of shape
        ``(1, num_query_token, hidden_size)`` drawn from
        ``N(0, initializer_range)``.
    """
    qformer_config = BertConfig()
    overrides = (
        ("num_hidden_layers", num_hidden_layers),
        ("encoder_width", speech_width),
        ("add_cross_attention", True),
        ("cross_attention_freq", 1),
        ("query_length", num_query_token),
    )
    for attr_name, attr_value in overrides:
        setattr(qformer_config, attr_name, attr_value)
    # Build the Q-Former before sampling the queries: module construction
    # typically draws from the global RNG, so this order keeps initialisation
    # reproducible.
    qformer = BertLMHeadModel(config=qformer_config)
    queries = nn.Parameter(
        torch.zeros(1, num_query_token, qformer_config.hidden_size),
    )
    queries.data.normal_(mean=0.0, std=qformer_config.initializer_range)
    return qformer, queries
def encode_speech_only(self, audio):
    """Encode one waveform into a sequence of LLM-space speech embeddings.

    The waveform is encoded twice: by the Whisper encoder (on the log-mel
    spectrogram from ``self.feature_extractor``) and by BEATs (on the raw
    waveform). Both outputs are layer-normed and concatenated along the
    feature axis. The fused sequence is then cut into overlapping windows,
    each window is summarised by the speech Q-Former's learned queries, and
    the query outputs are projected to the LLM hidden size.

    Args:
        audio: waveform given to ``WhisperFeatureExtractor`` with
            ``sampling_rate=16000`` and, via ``torch.from_numpy``, to BEATs,
            so it must be a 16 kHz NumPy array. Presumably 1-D mono — TODO
            confirm.

    Returns:
        Tensor of shape ``(B, -1, llm_hidden_size)``, where ``B`` is the batch
        dimension of the encoder output and the middle dimension presumably
        equals ``num_windows * speech_qformer_token_num`` (TODO confirm).
    """
    # whisper
    spectrogram = (
        self.feature_extractor(audio, return_tensors="pt", sampling_rate=16000)
        .input_features.to(self.device)
        .to(self.dtype)
    )  # [1, 80, 3000]
    speech_embeds = self.speech_encoder(
        spectrogram, return_dict=True
    ).last_hidden_state
    # beats
    raw_wav = torch.from_numpy(audio).to(self.device).unsqueeze(0)
    # All-False mask over every sample (presumably "no padding" in BEATs'
    # convention — a single unbatched waveform has none).
    audio_padding_mask = torch.zeros(raw_wav.shape, device=self.device).bool()
    audio_embeds, _ = self.beats.extract_features(
        raw_wav,
        padding_mask=audio_padding_mask,
        feature_only=True,
        torch_dtype=self.dtype,
    )
    # auditory embeds
    speech_embeds = self.ln_speech(speech_embeds)
    audio_embeds = self.ln_audio(audio_embeds)
    # Right-pad (or, when the difference is negative, crop) the BEATs sequence
    # along time so it matches the Whisper sequence length before concatenation.
    audio_embeds = F.pad(
        audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1))
    )
    speech_embeds = torch.cat([speech_embeds, audio_embeds], dim=-1)
    # split frames
    B, T, C = speech_embeds.shape
    # Window / hop sizes in frames. The 30.0 treats the T encoder frames as
    # spanning 30 seconds (cf. the 3000-frame spectrogram above).
    kernel = round(T * self.second_per_frame / 30.0)
    stride = round(T * self.second_stride / 30.0)
    kernel = (1, kernel)
    stride = (1, stride)
    # (B, C, 1, T): a height-1 "image" so F.unfold slides along time only.
    speech_embeds_tr = speech_embeds.transpose(1, 2).unsqueeze(2)
    # -> (B, C * kernel_len, L), L = number of windows
    speech_embeds_overlap = F.unfold(
        speech_embeds_tr, kernel_size=kernel, dilation=1, padding=0, stride=stride
    )
    _, _, L = speech_embeds_overlap.shape
    # (B, C, kernel_len, L) -> (B, L, kernel_len, C) -> (B * L, kernel_len, C)
    speech_embeds_overlap = speech_embeds_overlap.view(B, -1, kernel[1], L)
    speech_embeds_overlap = torch.permute(speech_embeds_overlap, [0, 3, 2, 1])
    speech_embeds = speech_embeds_overlap.reshape(-1, kernel[1], C)
    speech_atts = torch.ones(
        speech_embeds.size()[:-1], dtype=torch.long, device=speech_embeds.device
    )
    # Qformer
    # Each window is an independent Q-Former input; the same learned queries
    # are shared across all windows.
    query_tokens = self.speech_query_tokens.expand(speech_embeds.shape[0], -1, -1)
    query_output = self.speech_Qformer.bert(
        query_embeds=query_tokens,
        encoder_hidden_states=speech_embeds,
        encoder_attention_mask=speech_atts,
        return_dict=True,
    )
    speech_embeds = self.speech_llama_proj(query_output.last_hidden_state)
    # Regroup the per-window outputs into one sequence per batch item.
    speech_embeds = speech_embeds.view(B, -1, speech_embeds.size(2)).contiguous()
    return speech_embeds
def _get_text_from_content_list(self, content_list: List):
    """Return the ``"text"`` of the first ``type == "text"`` item, or ``""`` if none."""
    texts = (item["text"] for item in content_list if item["type"] == "text")
    return next(texts, "")
def _get_audio_from_content_list(self, content_list: List):
    """Wrap the first audio item's URL as ``"<Speech>{url}</Speech> "``.

    Only the first ``type == "audio"`` item is used; returns ``""`` when the
    list has no audio item.
    """
    not_found = object()
    first_url = next(
        (item["audio_url"] for item in content_list if item["type"] == "audio"),
        not_found,
    )
    if first_url is not_found:
        return ""
    return f"<Speech>{first_url}</Speech> "
def _get_audio_url_from_string(self, content: str):
    """Return the text between the first ``<Speech>`` and the next ``</Speech>``.

    If no closing tag follows, everything after the first ``<Speech>`` is
    returned. Raises ``IndexError`` when ``content`` has no ``<Speech>`` tag.
    """
    after_open_tag = content.split("<Speech>")[1]
    url, *_ = after_open_tag.split("</Speech>")
    return url
def _filter_only_audio_content(self, content_list: List):
    """Collect every audio path wrapped in ``<Speech>…</Speech>`` tags.

    Each string in ``content_list`` may contain several ``<Speech>`` spans.
    All of them are returned, in order of appearance, so that every path that
    ``_split_conversation_by_speech`` isolates has a matching encoded
    embedding. Previously only the first span per string was extracted, and
    later paths were tokenized as plain text. A span without a closing tag
    yields everything after its ``<Speech>``.

    Args:
        content_list: rendered message strings.

    Returns:
        List of path strings (duplicates kept).
    """
    audio_urls = []
    for content in content_list:
        # Every piece after a "<Speech>" starts with a path that ends at "</Speech>".
        for after_open_tag in content.split("<Speech>")[1:]:
            audio_urls.append(after_open_tag.split("</Speech>")[0])
    return audio_urls
def _split_conversation_by_speech(self, conversation_str: str):
    """Split a prompt string into text segments and bare audio paths.

    Every ``<Speech>PATH</Speech>`` span becomes three consecutive items: the
    preceding text ending in ``"<Speech>"``, the bare ``PATH``, and the
    following text starting with ``"</Speech>"``. The tags stay attached to
    the neighbouring text so they are tokenized as part of the prompt.
    Concatenating the result reproduces ``conversation_str`` exactly.

    Example:
        ``"A<Speech>f.wav</Speech> B"`` -> ``["A<Speech>", "f.wav", "</Speech> B"]``
    """
    pieces = conversation_str.split("<Speech>")
    # Re-attach the opening tag to every piece except the last.
    intermediate_list = [piece + "<Speech>" for piece in pieces[:-1]] + [pieces[-1]]
    processed_list = []
    for item in intermediate_list:
        if "</Speech>" in item:
            # Split on the first closing tag only, so text after a stray
            # second "</Speech>" is kept instead of being silently dropped.
            file_path, _, remaining = item.partition("</Speech>")
            processed_list.extend([file_path, "</Speech>" + remaining])
        else:
            processed_list.append(item)
    return processed_list
def _convert_conv_to_embeds(self, conversation_list: List, speech_embeds: List):
    """Map each conversation segment to an embedding tensor on ``self.device``.

    A segment equal to some entry's ``"audio_url"`` in ``speech_embeds`` is
    replaced by that entry's ``"audio"`` embedding. The first entry wins when
    a URL appears more than once. Every other segment is tokenized without
    special tokens and embedded with ``self.llama_model.model.embed_tokens``.

    Returns:
        List of tensors, one per segment, in the same order.
    """
    entry_by_url = {}
    for entry in speech_embeds:
        # setdefault keeps the first occurrence (first-match semantics).
        entry_by_url.setdefault(entry["audio_url"], entry)
    embeds = []
    for segment in conversation_list:
        matched_entry = entry_by_url.get(segment)
        if matched_entry is not None:
            embeds.append(matched_entry["audio"].to(self.device))
            continue
        token_ids = self.llama_tokenizer(
            segment, return_tensors="pt", add_special_tokens=False
        ).input_ids.to(self.device)
        embeds.append(self.llama_model.model.embed_tokens(token_ids))
    return embeds
def encode_speech_with_text(self, conversation: List):
    """Encode a chat conversation (text and/or audio turns) into LLM input embeddings.

    Each message is rendered with the Llama-3 header/``<|eot_id|>`` markup used here.
    List-valued content becomes ``<Speech>URL</Speech> `` followed by the text part.
    An open assistant header is appended at the end. Audio files are read with
    ``sf.read`` and encoded once per distinct URL. The result is prefixed with the
    BOS embedding.

    Args:
        conversation: list of ``{"role": str, "content": str | list[dict]}`` messages.

    Returns:
        ``(embeds, atts)``: the concatenated embeddings (batch dimension 1) and an
        all-ones ``long`` attention mask of the same leading shape.
    """
    converted_conversation = []
    for msg in conversation:
        content = msg["content"]
        if isinstance(content, list):
            content = self._get_audio_from_content_list(
                content
            ) + self._get_text_from_content_list(content)
        converted_conversation.append(
            f"<|start_header_id|>{msg['role']}<|end_header_id|>\n\n{content}<|eot_id|>"
        )
    conversation_str = (
        "".join(converted_conversation)
        + "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    conversation_list = self._split_conversation_by_speech(conversation_str)

    # Read + encode each distinct audio file once, even if several turns reference it
    # (dict.fromkeys de-duplicates while keeping first-seen order).
    audio_urls = list(
        dict.fromkeys(self._filter_only_audio_content(converted_conversation))
    )
    speech_embeds = [
        {"audio_url": url, "audio": self.encode_speech_only(sf.read(url)[0])}
        for url in audio_urls
    ]

    # Same embedding-table selection as the multi-turn helpers (LoRA wraps one level deeper).
    embed_tokens = (
        self.llama_model.model.model.embed_tokens
        if getattr(self, "lora", False)
        else self.llama_model.model.embed_tokens
    )
    bos_ids = torch.full(
        (1, 1),
        self.llama_tokenizer.bos_token_id,
        dtype=torch.long,
        device=self.device,
    )
    embed_list = [embed_tokens(bos_ids)] + self._convert_conv_to_embeds(
        conversation_list, speech_embeds
    )
    embeds = torch.cat(embed_list, dim=1)
    atts = torch.ones(embeds.shape[:-1], dtype=torch.long, device=embeds.device)
    return embeds, atts
def forward(
    self,
    conversation: List,
    labels: Optional[torch.LongTensor] = None,
    return_dict: Optional[bool] = None,
    **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Encode ``conversation`` and run the LLM on the resulting embeddings.

    Args:
        conversation: chat messages accepted by ``encode_speech_with_text``.
        labels: optional LM labels, forwarded unchanged to ``llama_model.forward``.
        return_dict: forwarded unchanged to ``llama_model.forward``.
        **kwargs: accepted for API compatibility and ignored.

    Returns:
        Whatever ``llama_model.forward`` returns.
    """
    # TODO: support batch_size > 1
    inputs_embeds, attention_mask = self.encode_speech_with_text(conversation)
    return self.llama_model.forward(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        labels=labels,
        return_dict=return_dict,
    )
# def forward(
# self,
# input_ids: torch.LongTensor = None,
# attention_mask: Optional[torch.Tensor] = None,
# position_ids: Optional[torch.LongTensor] = None,
# past_key_values: Optional[List[torch.FloatTensor]] = None,
# inputs_embeds: Optional[torch.FloatTensor] = None,
# labels: Optional[torch.LongTensor] = None,
# use_cache: Optional[bool] = None,
# output_attentions: Optional[bool] = None,
# output_hidden_states: Optional[bool] = None,
# return_dict: Optional[bool] = None,
# cache_position: Optional[torch.LongTensor] = None,
# ) -> Union[Tuple, CausalLMOutputWithPast]:
# llama_output = self.llama_model.forward(
# input_ids=input_ids,
# attention_mask=attention_mask,
# position_ids=position_ids,
# past_key_values=past_key_values,
# inputs_embeds=inputs_embeds,
# labels=labels,
# use_cache=use_cache,
# output_attentions=output_attentions,
# output_hidden_states=True,
# return_dict=return_dict,
# cache_position=cache_position,
# )
# loss = llama_output.loss
# return CausalLMOutputWithPast(
# loss=loss,
# logits=llama_output.logits,
# past_key_values=llama_output.past_key_values,
# hidden_states=llama_output.hidden_states,
# attentions=llama_output.attentions
# )
def generate(
    self,
    conversation: List,
    max_new_tokens=1024,
    num_beams=1,
    do_sample=True,
    top_p=0.9,
    repetition_penalty=1.0,
    length_penalty=1.0,
    temperature=1.0,
    streamer=None,
) -> str:
    """Generate a text reply for ``conversation``.

    The conversation is encoded with ``encode_speech_with_text``. The LLM's
    ``generate`` is called with the given decoding settings and the tokenizer's
    BOS/EOS/PAD ids, and the first decoded sequence is returned.

    Returns:
        The decoded text of the first returned sequence, with special tokens skipped.
    """
    inputs_embeds, attention_mask = self.encode_speech_with_text(conversation)
    tokenizer = self.llama_tokenizer
    decoding_options = dict(
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        do_sample=do_sample,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        temperature=temperature,
        streamer=streamer,
    )
    special_ids = dict(
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    generated = self.llama_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        **decoding_options,
        **special_ids,
    )
    decoded = tokenizer.batch_decode(
        generated, add_special_tokens=False, skip_special_tokens=True
    )
    return decoded[0]
# ------------------------------------------------------------------------------- #
# November 2024 -- multi-turn
def init_multiturn(
    self,
    system_prompt="<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant named ไต้ฝุ่น. You always answer in Thai.<|eot_id|>",
    user_prompt_prefix="<|start_header_id|>user<|end_header_id|>\n\n",
    user_prompt_suffix="</Speech> <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
):
    """Reset the multi-turn history and store the per-turn prompt wrappers.

    Args:
        system_prompt: text embedded once and cached as ``"text:system_prompt"``;
            pass ``None`` to start without a system turn.
        user_prompt_prefix: text placed before each user's audio in
            ``generate_multiturn``.
        user_prompt_suffix: text placed after each user's audio in
            ``generate_multiturn``.
    """
    self.conversations = []
    self.user_prompt_prefix, self.user_prompt_suffix = (
        user_prompt_prefix,
        user_prompt_suffix,
    )
    if system_prompt is not None:
        # With LoRA the base LLM sits one `.model` deeper.
        if self.lora:
            token_embedding = self.llama_model.model.model.embed_tokens
        else:
            token_embedding = self.llama_model.model.embed_tokens
        encoded = self.llama_tokenizer(
            system_prompt, return_tensors="pt", add_special_tokens=False
        )
        prompt_ids = encoded.to(self.device).input_ids
        self.add_cache(
            dtype="text:system_prompt", embeds=token_embedding(prompt_ids)
        )
    print("multi-turn conversation initialized!")
def add_cache(self, dtype, embeds):
    """Append one segment to the multi-turn history ``self.conversations``.

    Args:
        dtype: segment label, e.g. ``"text:system_prompt"`` or ``"wav:user_input"``.
        embeds: the segment's embeddings (token embeddings for text segments,
            speech embeddings for audio segments).
    """
    entry = dict(dtype=dtype, embeds=embeds)
    self.conversations.append(entry)
def generate_multiturn(
    self,
    wav_path,
    device="cuda:0",
    max_length=1500,
    num_beams=4,
    do_sample=True,
    min_length=1,
    top_p=0.9,
    repetition_penalty=1.0,
    length_penalty=1.0,
    temperature=1.0,
    streamer=None,
):
    """Run one spoken user turn of a multi-turn conversation and return the reply.

    The turn is ``user_prompt_prefix`` + speech embeddings of ``wav_path`` +
    ``user_prompt_suffix``, appended to the cached history (see ``init_multiturn``).
    The turn and the generated assistant reply (plus ``<|eot_id|>``) are committed
    to ``self.conversations`` only after generation succeeds. A failed call
    therefore leaves the history untouched instead of holding a half-written turn.

    Args:
        wav_path: audio input passed to ``self.process_wav``.
        device: unused; kept for backward compatibility (``self.device`` is used).
        max_length, num_beams, do_sample, min_length, top_p, repetition_penalty,
        length_penalty, temperature, streamer: forwarded to ``llama_model.generate``.

    Returns:
        The decoded reply text (first sequence, special tokens skipped).
    """
    # With LoRA the base LLM sits one `.model` deeper.
    embed_tokens = (
        self.llama_model.model.model.embed_tokens
        if self.lora
        else self.llama_model.model.embed_tokens
    )

    def _embed_text(text):
        # Special tokens are spelled out in the prompt strings, so none are added here.
        ids = (
            self.llama_tokenizer(text, return_tensors="pt", add_special_tokens=False)
            .to(self.device)
            .input_ids
        )
        return embed_tokens(ids)

    # This turn's segments, kept local until generation succeeds.
    pending = [
        # prefix: <|start_header_id|>user<|end_header_id|>\n\n
        ("text:user_prompt_prefix", _embed_text(self.user_prompt_prefix)),
        ("wav:user_input", self.process_wav(wav_path)),
        # suffix: </Speech> <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n
        ("text:user_prompt_suffix", _embed_text(self.user_prompt_suffix)),
    ]
    list_of_embeds = [em["embeds"] for em in self.conversations]
    list_of_embeds += [segment for _, segment in pending]
    embeds = torch.cat(list_of_embeds, dim=1)
    atts = torch.ones(embeds.shape[:-1], dtype=torch.long, device=embeds.device)
    print("seq_length:", embeds.shape[1])

    output = self.llama_model.generate(
        inputs_embeds=embeds,
        max_length=max_length,
        num_beams=num_beams,
        do_sample=do_sample,
        min_length=min_length,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        temperature=temperature,
        attention_mask=atts,
        bos_token_id=self.llama_tokenizer.bos_token_id,
        eos_token_id=self.llama_tokenizer.eos_token_id,
        pad_token_id=self.llama_tokenizer.pad_token_id,
        streamer=streamer,
    )
    output_text = self.llama_tokenizer.batch_decode(
        output, add_special_tokens=False, skip_special_tokens=True
    )
    reply = output_text[0]
    assistant_text_embeds = _embed_text(reply + "<|eot_id|>")

    # Generation succeeded: commit the user turn and the assistant reply together.
    for dtype, segment in pending:
        self.add_cache(dtype=dtype, embeds=segment)
    self.add_cache(dtype="text:assistant_generation", embeds=assistant_text_embeds)
    return reply
class Typhoon2Audio2AudioForConditionalGeneration(
    Typhoon2AudioForConditionalGeneration, GenerationWithCTC
):
    """
    Speech-in / speech-out variant of Typhoon2Audio.

    Adds on top of ``Typhoon2AudioForConditionalGeneration``:
    1) ``SpeechGeneratorCTC``: maps the LLM's last-layer hidden states to speech-unit ids;
    2) ``CodeHiFiGANVocoder``: turns speech units into a waveform.
    Text decoding is delegated to ``GenerationWithCTC.generate``.
    """

    config_class = Typhoon2AudioConfig

    def __init__(self, config):
        super().__init__(config)
        """
        Initialize
        1) speech decoder (llm output representation -> speech unit)
        2) unit vocoder (speech unit -> wav)
        """
        self.pretraining_tp = config.pretraining_tp
        self.speech_generator = SpeechGeneratorCTC(config)
        self.init_vocoder(config)

    def init_vocoder(self, config=None, checkpoint_path=None):
        """
        Create ``self.vocoder`` from ``config.vocoder_config`` and move it to ``self.device``.

        Args:
            config: model config; ``self.config`` is used when ``None``.
            checkpoint_path: optional vocoder checkpoint forwarded to ``CodeHiFiGANVocoder``.
        """
        # separate vocoder initialization as it is supposed to be float32
        # other parts should be in float16
        if config is None:
            config = self.config
        self.vocoder = CodeHiFiGANVocoder(
            model_cfg=config.vocoder_config, checkpoint_path=checkpoint_path
        )
        self.vocoder.to(self.device)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Run ``self.llama_model`` and repackage its output as ``CausalLMOutputWithPast``.

        ``output_hidden_states=True`` is always passed to the LLM (the argument of the
        same name is ignored) because ``generate`` and ``synthesize_speech`` feed the
        last-layer hidden states to the CTC speech generator.

        NOTE(review): ``cache_position`` and ``**kwargs`` are accepted but not forwarded,
        and the attribute accesses on ``llama_output`` assume ``return_dict`` is not False.
        """
        llama_output = self.llama_model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )
        loss = llama_output.loss
        return CausalLMOutputWithPast(
            loss=loss,
            logits=llama_output.logits,
            past_key_values=llama_output.past_key_values,
            hidden_states=llama_output.hidden_states,
            attentions=llama_output.attentions,
        )

    @torch.no_grad()
    def generate(
        self,
        # ----------------- #
        inputs_embeds=None,
        attention_mask=None,
        output_hidden_states=True,
        return_dict_in_generate=True,
        streaming_unit_gen=False,
        max_length=8000,
        # ----------------- #
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """
        Generate a text reply together with speech units and audio.

        Pass either ``inputs_embeds``/``attention_mask`` or, when ``inputs_embeds`` is
        None, ``conversation=[...]`` in ``kwargs`` (encoded via ``encode_speech_with_text``).

        Returns:
            dict with ``"text"`` (first decoded sequence), ``"unit"`` (CTC predictions of
            ``self.speech_generator``) and ``"audio"`` (output of ``ctc_pred_to_audio``).
            Note that a plain dict is returned despite the annotation.

        NOTE(review): other ``kwargs`` entries are not forwarded to
        ``GenerationWithCTC.generate``; ``squeeze(0)`` and ``[0]`` below assume batch size 1.
        """
        if "conversation" in kwargs and inputs_embeds is None:
            conversation = kwargs.get("conversation", [])
            inputs_embeds, attention_mask = self.encode_speech_with_text(conversation)
        outputs = GenerationWithCTC.generate(
            self,
            # position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict_in_generate=return_dict_in_generate,
            streaming_unit_gen=streaming_unit_gen,
            # typhoon2 (llama3.1) will set this to 20 somehow otherwise
            max_length=max_length,
            # ------------------- #
            # hard-coded BOS / EOS ids (see the llama3.1 note above)
            bos_token_id=128000,
            eos_token_id=[128001, 128008, 128009],
        )
        # hidden_states[i][-1] is the last layer's output for entry i. Entry 0 covers the
        # whole prompt, so only its final position is kept; later entries are appended in
        # full, giving the representations the speech generator decodes.
        hidden_states = outputs["hidden_states"]
        hidden_states = torch.cat(
            [hidden_states[0][-1][:, -1:, :]]
            + [hidden_states[i][-1] for i in range(1, len(hidden_states))],
            dim=1,
        )
        ctc_pred = self.speech_generator.predict(hidden_states.squeeze(0))
        # processing
        output_ids, output_units = outputs.sequences, ctc_pred
        # text
        output_text = self.llama_tokenizer.batch_decode(
            output_ids, add_special_tokens=False, skip_special_tokens=True
        )[0]
        # wav
        output_audio = self.ctc_pred_to_audio(output_units)
        return {"text": output_text, "unit": output_units, "audio": output_audio}

    @torch.no_grad()
    def synthesize_speech(
        self,
        text,
    ):
        """
        Text-to-speech: run ``text`` through the LLM in a single forward pass and decode
        its last-layer hidden states with the CTC speech generator and the vocoder.

        Returns:
            The dict produced by ``ctc_pred_to_audio`` (``"array"``, ``"sampling_rate"``).
        """
        # apply chat template adds (supposed to be) unnecessary tokens
        # however, this was applied during training, so it should be added here
        # in the next version, please consider removing `apply_chat_template`
        text_ = self.llama_tokenizer.apply_chat_template(
            [{"role": "assistant", "content": text}], tokenize=False
        )
        inputs = self.llama_tokenizer(text_, return_tensors="pt").to(self.device)
        outputs = self(**inputs)
        hidden_states = outputs["hidden_states"][-1]
        ctc_pred = self.speech_generator.predict(hidden_states.squeeze(0))
        output_audio = self.ctc_pred_to_audio(ctc_pred)
        return output_audio

    def ctc_pred_to_audio(self, units):
        """
        Convert CTC unit predictions into a waveform.

        Repeats are collapsed and blanks (id ``config.unit_vocab_size``) removed by
        ``ctc_postprocess``; the result is fed to ``self.vocoder``.

        Returns:
            ``{"array": numpy waveform or None, "sampling_rate": int}``; ``"array"`` is
            None when the model has no ``vocoder`` attribute.
        """
        # vocoder
        if hasattr(self, "vocoder"):
            units = self.ctc_postprocess(units, blank=self.config.unit_vocab_size)
            units = [(list(map(int, units.strip().split())))]
            units_tensor = torch.tensor(units, dtype=torch.int64, device=self.device)
            # second positional argument is the vocoder's `dur_prediction` flag
            audio_arr = self.vocoder({"code": units_tensor}, True)
            audio_arr = audio_arr.detach().cpu().numpy()
        else:
            audio_arr = None
        return {
            "array": audio_arr,
            "sampling_rate": self.config.vocoder_config["sampling_rate"],
        }

    def ctc_postprocess(self, tokens, blank):
        """
        Greedy CTC collapse: merge consecutive duplicates, then drop ``blank`` ids.

        Args:
            tokens: tensor of predicted ids (a leading dim of size 1 is squeezed away).
            blank: the CTC blank id.

        Returns:
            The remaining ids as a space-separated string.
        """
        _toks = tokens.squeeze(0).tolist()
        deduplicated_toks = [
            v for i, v in enumerate(_toks) if i == 0 or v != _toks[i - 1]
        ]
        hyp = [v for v in deduplicated_toks if v != blank]
        hyp = " ".join(list(map(str, hyp)))
        return hyp

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        # taken from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py
        """
        Prepare the model inputs for generation. It includes operations like computing the 4D attention mask or
        slicing inputs given the existing cache.

        See the forward pass in the model documentation for expected arguments (different models might have different
        requirements for e.g. `past_key_values`). This function should work as is for most LLMs.
        """
        # 1. Handle BC:
        model_inputs = {}
        # - some models don't have `Cache` support (which implies they don't expect `cache_position` in `forward`)
        if self._supports_cache_class:
            model_inputs["cache_position"] = cache_position
        # - `cache_position` was not a mandatory input in `prepare_inputs_for_generation` for those models, and this
        #   function may be called outside of `generate`. Handle most use cases by creating `cache_position` on the fly
        #   (this alternative is not as robust as calling `generate` and letting it create `cache_position`)
        elif cache_position is None:
            past_length = (
                past_key_values[0][0].shape[2] if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                past_length,
                input_ids.shape[1],
                dtype=torch.long,
                device=input_ids.device,
            )
        # 2. Generic cache-dependent input preparation
        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
        #              (we can't check exception 3 while compiling)
        if past_key_values is not None:
            model_inputs["past_key_values"] = past_key_values
            if (
                inputs_embeds is not None  # Exception 1
                # Exception 3
                or (
                    is_torchdynamo_compiling()
                    or cache_position[-1] >= input_ids.shape[1]
                )
            ):
                input_ids = input_ids[:, -cache_position.shape[0] :]
            # Default case (the "else", a no op, is Exception 2)
            elif input_ids.shape[1] != cache_position.shape[0]:
                input_ids = input_ids[:, cache_position]
        # 3. Prepare base model inputs
        input_ids_key = (
            "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
        )
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if not self.config.is_encoder_decoder:
            if inputs_embeds is not None and cache_position[0] == 0:
                model_inputs[input_ids_key] = None
                model_inputs["inputs_embeds"] = inputs_embeds
            else:
                # `clone` calls in this function ensure a consistent stride. See #32227
                model_inputs[input_ids_key] = input_ids.clone(
                    memory_format=torch.contiguous_format
                )
                model_inputs["inputs_embeds"] = None
        else:
            model_inputs[input_ids_key] = input_ids.clone(
                memory_format=torch.contiguous_format
            )
        # 4. Create missing `position_ids` on the fly
        if (
            attention_mask is not None
            and kwargs.get("position_ids") is None
            and "position_ids" in set(inspect.signature(self.forward).parameters.keys())
        ):
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            # placed in kwargs for further processing (see below)
            kwargs["position_ids"] = position_ids
        # 5. Slice model inputs if it's an input that should have the same length as `input_ids`
        for model_input_name in ["position_ids", "token_type_ids"]:
            model_input = kwargs.get(model_input_name)
            if model_input is not None:
                if past_key_values is not None:
                    current_input_length = (
                        model_inputs["inputs_embeds"].shape[1]
                        if model_inputs["inputs_embeds"] is not None
                        else model_inputs[input_ids_key].shape[1]
                    )
                    model_input = model_input[:, -current_input_length:]
                    model_input = model_input.clone(
                        memory_format=torch.contiguous_format
                    )
                model_inputs[model_input_name] = model_input
        # 6. Create 4D attention mask is we are using a `StaticCache` (important for performant compiled forward pass)
        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
            if model_inputs["inputs_embeds"] is not None:
                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
                device = model_inputs["inputs_embeds"].device
            else:
                batch_size, sequence_length = model_inputs[input_ids_key].shape
                device = model_inputs[input_ids_key].device
            # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
            # the 4D causal mask exists, it should be present in the base model (XXXModel class).
            base_model = getattr(self, self.base_model_prefix, None)
            if base_model is None:
                causal_mask_creation_function = getattr(
                    self, "_prepare_4d_causal_attention_mask_with_cache_position", None
                )
            else:
                causal_mask_creation_function = getattr(
                    base_model,
                    "_prepare_4d_causal_attention_mask_with_cache_position",
                    None,
                )
            if causal_mask_creation_function is None:
                logger.warning_once(
                    f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
                    "defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
                    "writing code, see Llama for an example implementation. If you're a user, please report this "
                    "issue on GitHub."
                )
            else:
                attention_mask = causal_mask_creation_function(
                    attention_mask,
                    sequence_length=sequence_length,
                    target_length=past_key_values.get_max_cache_shape(),
                    dtype=self.dtype,
                    device=device,
                    cache_position=cache_position,
                    batch_size=batch_size,
                    config=self.config,
                    past_key_values=past_key_values,
                )
        if attention_mask is not None:
            model_inputs["attention_mask"] = attention_mask
        # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
        for key, value in kwargs.items():
            if key not in model_inputs:
                model_inputs[key] = value
        # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
        model_inputs.pop("labels", None)
        return model_inputs

    def _get_logits_warper(
        self,
        generation_config: GenerationConfig,
        device: str,
    ) -> LogitsProcessorList:
        """
        This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
        used for multinomial sampling.
        """
        # instantiate warpers list
        warpers = LogitsProcessorList()
        # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
        # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
        if generation_config.num_beams > 1:
            if isinstance(generation_config._eos_token_tensor, list):
                min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
            elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
                min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
            else:
                min_tokens_to_keep = 2
        else:
            min_tokens_to_keep = 1
        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        if (
            generation_config.temperature is not None
            and generation_config.temperature != 1.0
        ):
            warpers.append(TemperatureLogitsWarper(generation_config.temperature))
        if generation_config.top_k is not None and generation_config.top_k != 0:
            warpers.append(
                TopKLogitsWarper(
                    top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep
                )
            )
        if generation_config.top_p is not None and generation_config.top_p < 1.0:
            warpers.append(
                TopPLogitsWarper(
                    top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep
                )
            )
        if generation_config.min_p is not None:
            # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
            warpers.append(
                MinPLogitsWarper(
                    min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep
                )
            )
        if (
            generation_config.typical_p is not None
            and generation_config.typical_p < 1.0
        ):
            warpers.append(
                TypicalLogitsWarper(
                    mass=generation_config.typical_p,
                    min_tokens_to_keep=min_tokens_to_keep,
                )
            )
        if (
            generation_config.epsilon_cutoff is not None
            and 0.0 < generation_config.epsilon_cutoff < 1.0
        ):
            warpers.append(
                EpsilonLogitsWarper(
                    epsilon=generation_config.epsilon_cutoff,
                    min_tokens_to_keep=min_tokens_to_keep,
                )
            )
        if (
            generation_config.eta_cutoff is not None
            and 0.0 < generation_config.eta_cutoff < 1.0
        ):
            warpers.append(
                EtaLogitsWarper(
                    epsilon=generation_config.eta_cutoff,
                    min_tokens_to_keep=min_tokens_to_keep,
                    device=device,
                )
            )
        # `LogitNormalization` should always be the last logit processor, when present
        if generation_config.renormalize_logits is True:
            warpers.append(LogitNormalization())
        return warpers
# ------------------------------------------------------------------------------------------ #
# Speech Decoder Component
class SpeechGeneratorCTC(nn.Module):
    """
    Non-autoregressive speech-unit decoder trained with CTC.

    LLM hidden states are upsampled (each position repeated ``ctc_upsample_factor``
    times, or stretched to the target-unit length during training if that is longer).
    They are projected into a small stack of ``LlamaDecoderLayer`` blocks and
    classified into ``unit_vocab_size`` units plus one CTC blank whose id is
    ``unit_vocab_size``.
    """

    def __init__(self, config):
        super().__init__()
        # ``ctc_decoder_config`` is a string whose first and last characters are
        # stripped before splitting on "," into (layers, dims, heads, intermediate dims);
        # assumed to look like "(L,D,H,I)".
        n_layers, n_dims, n_heads, n_inter_dims = list(
            map(int, config.ctc_decoder_config[1:-1].split(","))
        )
        _config = copy.deepcopy(config)
        _config.hidden_size = n_dims
        _config.num_hidden_layers = n_layers
        _config.num_attention_heads = n_heads
        _config.num_key_value_heads = n_heads
        _config.intermediate_size = n_inter_dims
        _config._attn_implementation = "flash_attention_2"
        self.upsample_factor = config.ctc_upsample_factor
        self.input_proj = nn.Linear(config.hidden_size, n_dims)
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(_config, layer_idx) for layer_idx in range(n_layers)]
        )
        self.unit_vocab_size = config.unit_vocab_size
        # +1 output for the CTC blank (id == unit_vocab_size)
        self.output_proj = nn.Linear(n_dims, config.unit_vocab_size + 1)
        self.speech_decoder_ignore_index = config.speech_decoder_ignore_index

    def upsample(self, reps, tgt_units=None):
        """
        Stretch each sequence in ``reps`` by uniform repetition and pad to a batch.

        Args:
            reps: list of tensors whose first dim is the sequence length.
            tgt_units: optional padded unit targets; if a sample has more valid
                targets than ``len * upsample_factor``, it is stretched to that length.

        Returns:
            ``(copied_reps, valid_mask, position_ids)``, where ``valid_mask`` is True
            on real (non-padding) positions.
        """
        src_lens = torch.LongTensor([len(rep) for rep in reps]).to(reps[0].device)
        up_lens = src_lens * self.upsample_factor
        if tgt_units is not None:
            tgt_lens = tgt_units.ne(self.speech_decoder_ignore_index).long().sum(dim=-1)
            up_lens = torch.max(up_lens, tgt_lens)
        reps = torch.nn.utils.rnn.pad_sequence(reps, batch_first=True)
        padding_mask = self._lengths_to_padding_mask(up_lens)
        # Source index to copy for every output position (0 on padding).
        mapped_inputs = self._uniform_assignment(src_lens, up_lens).masked_fill(
            padding_mask, 0
        )
        copied_reps = torch.gather(
            reps,
            1,
            mapped_inputs.unsqueeze(-1).expand(*mapped_inputs.size(), reps.size(-1)),
        )
        copied_reps = copied_reps.masked_fill(padding_mask.unsqueeze(-1), 0)
        # Built directly on the target device (previously arange on CPU, then copied).
        position_ids = (
            torch.arange(int(up_lens.max()), device=copied_reps.device)
            .unsqueeze(0)
            .expand(reps.size(0), -1)
        )
        return copied_reps, ~padding_mask, position_ids

    def forward(self, tgt_reps, labels, tgt_units):
        """
        Compute the CTC training loss.

        Args:
            tgt_reps: per-sample LLM hidden states.
            labels: per-sample labels aligned with ``tgt_reps``; positions equal to
                ``speech_decoder_ignore_index`` are dropped before upsampling.
            tgt_units: padded unit targets (padding == ``speech_decoder_ignore_index``).

        Returns:
            CTC loss summed over the batch and divided by the total number of
            target units.
        """
        tgt_label_reps = []
        for tgt_rep, label in zip(tgt_reps, labels):
            tgt_label_reps.append(tgt_rep[label != self.speech_decoder_ignore_index])
        hidden_states, attention_mask, position_ids = self.upsample(
            tgt_label_reps, tgt_units
        )
        hidden_states = self.input_proj(hidden_states)
        for layer in self.layers:
            layer_outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )
            hidden_states = layer_outputs[0]
        ctc_logits = self.output_proj(hidden_states)
        ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32)
        ctc_lens = attention_mask.long().sum(dim=-1)
        ctc_tgt_lens = tgt_units.ne(self.speech_decoder_ignore_index).long().sum(dim=-1)
        ctc_tgt_mask = ~self._lengths_to_padding_mask(ctc_tgt_lens)
        ctc_tgt_flat = tgt_units.masked_select(ctc_tgt_mask)
        ctc_loss = F.ctc_loss(
            ctc_lprobs.transpose(0, 1),  # F.ctc_loss expects (T, N, C)
            ctc_tgt_flat,
            ctc_lens,
            ctc_tgt_lens,
            reduction="sum",
            zero_infinity=True,
            blank=self.unit_vocab_size,
        )
        ctc_loss /= ctc_tgt_lens.sum().item()
        return ctc_loss

    def predict(self, tgt_reps):
        """
        Greedy (argmax) CTC prediction for a single sequence of hidden states.

        ``tgt_reps`` is treated as one sample (its first dim is the length).
        Returns a ``(1, T_up)`` tensor of unit ids, with padded positions set to the
        blank id.
        """
        hidden_states, attention_mask, position_ids = self.upsample([tgt_reps])
        hidden_states = self.input_proj(hidden_states)
        for layer in self.layers:
            layer_outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )
            hidden_states = layer_outputs[0]
        ctc_logits = self.output_proj(hidden_states)
        ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32)
        ctc_pred = ctc_lprobs.argmax(dim=-1).masked_fill_(
            ~attention_mask, self.unit_vocab_size
        )
        return ctc_pred

    def _lengths_to_padding_mask(self, lens):
        """Return a ``(B, max(lens))`` bool mask that is True on padding positions."""
        bsz, max_lens = lens.size(0), int(lens.max())
        positions = torch.arange(max_lens, device=lens.device).view(1, max_lens)
        return positions.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)

    def _uniform_assignment(self, src_lens, tgt_lens):
        """Map each of ``tgt_lens`` output positions to a source index by uniform stretching."""
        tgt_indices = torch.arange(int(tgt_lens.max()), device=tgt_lens.device).expand(
            len(tgt_lens), -1
        )
        ratio = tgt_lens / src_lens
        index_t = (tgt_indices / ratio.view(-1, 1)).long()
        return index_t
# Code HiFiGAN
# https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/text_to_speech/vocoder.py
class CodeHiFiGANVocoder(BaseFairseqModel):
    """
    Speech-unit-to-waveform vocoder wrapping ``CodeHiFiGANModel``.

    Adapted from fairseq ``models/text_to_speech/vocoder.py``.
    """

    def __init__(
        self, model_cfg: Dict[str, str], checkpoint_path: str = None, fp16: bool = False
    ) -> None:
        """
        Args:
            model_cfg: configuration forwarded to ``CodeHiFiGANModel``.
            checkpoint_path: optional checkpoint whose ``"generator"`` entry holds the
                generator state dict; skipped when ``None``.
            fp16: cast the generator to half precision.
        """
        super().__init__()
        self.model = CodeHiFiGANModel(model_cfg)
        if checkpoint_path is not None:
            self.load_checkpoint(checkpoint_path)
        self.model.eval()
        if fp16:
            self.model.half()
        # Weight norm is removed after the optional checkpoint load.
        self.model.remove_weight_norm()
        logger.info("initialized CodeHiFiGAN checkpoint")

    def load_checkpoint(self, checkpoint_path: str) -> None:
        """Load ``state_dict["generator"]`` from ``checkpoint_path`` into ``self.model``."""
        # Always deserialize onto the CPU. load_state_dict copies the values into the
        # module's existing parameters anyway. This avoids a transient GPU copy and
        # failures when the checkpoint references a CUDA device missing on this host.
        # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
        state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
        self.model.load_state_dict(state_dict["generator"])
        logger.info(f"loaded CodeHiFiGAN checkpoint from {checkpoint_path}")

    def forward(self, x: Dict[str, torch.Tensor], dur_prediction=False) -> torch.Tensor:
        """
        Synthesize a waveform from speech units.

        Args:
            x: dict with a ``"code"`` tensor of unit ids and optionally ``"f0"``.
                The caller's dict is not modified.
            dur_prediction: forwarded to the generator as ``dur_prediction``.

        Returns:
            The generator output, detached and squeezed.
        """
        assert "code" in x
        # Shallow copy so the caller's dict is not mutated by the keys set below.
        x = dict(x)
        x["dur_prediction"] = dur_prediction
        # remove invalid (negative) codes; boolean indexing flattens, so the
        # remaining codes are packed into a single batch row
        mask = x["code"] >= 0
        x["code"] = x["code"][mask].unsqueeze(dim=0)
        if "f0" in x:
            f0_up_ratio = x["f0"].size(1) // x["code"].size(1)
            mask = mask.unsqueeze(2).repeat(1, 1, f0_up_ratio).view(-1, x["f0"].size(1))
            x["f0"] = x["f0"][mask].unsqueeze(dim=0)
        return self.model(**x).detach().squeeze()
# ---------------------------------------------------------------------------------------- #
class BertEmbeddings(nn.Module):
    """
    Word + (absolute) position embeddings, optionally prefixed by query embeddings.

    The result always goes through ``LayerNorm`` and dropout. When ``input_ids`` is
    None, only ``query_embeds`` are normalized.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        # Attribute name `LayerNorm` (not snake_case) matches the TensorFlow variable
        # name so existing checkpoints load.
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # (1, max_position_embeddings) row of position indices, saved with the module.
        all_positions = torch.arange(config.max_position_embeddings)
        self.register_buffer("position_ids", all_positions.expand((1, -1)))
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )
        self.config = config

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
        """
        Args:
            input_ids: optional ``(B, L)`` token ids.
            position_ids: optional explicit positions; defaults to
                ``past_key_values_length .. past_key_values_length + L - 1``.
            query_embeds: optional embeddings prepended (along dim 1) to the tokens.
            past_key_values_length: offset for the default positions.
        """
        if input_ids is None:
            embeddings = query_embeds
        else:
            seq_length = input_ids.size(1)
            if position_ids is None:
                start = past_key_values_length
                position_ids = self.position_ids[:, start : start + seq_length].clone()
            embeddings = self.word_embeddings(input_ids)
            if self.position_embedding_type == "absolute":
                embeddings = embeddings + self.position_embeddings(position_ids)
            if query_embeds is not None:
                embeddings = torch.cat((query_embeds, embeddings), dim=1)
        return self.dropout(self.LayerNorm(embeddings))
class BertSelfAttention(nn.Module):
    """
    Multi-head self- or cross-attention used by the Q-Former BERT layers.

    With ``position_embedding_type`` of ``"relative_key"`` / ``"relative_key_query"``,
    learned distance embeddings are added to the attention scores. Relative distances
    are computed between the actual query positions and the actual key positions.
    The earlier version used the query length for both axes, which raised a shape
    error in cross-attention (different encoder length) and in cached self-attention.
    """

    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
            config, "embedding_size"
        ):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            # keys/values are projected from encoder features of width `encoder_width`
            self.key = nn.Linear(config.encoder_width, self.all_head_size)
            self.value = nn.Linear(config.encoder_width, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )
        if (
            self.position_embedding_type == "relative_key"
            or self.position_embedding_type == "relative_key_query"
        ):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1, self.attention_head_size
            )
        # When True (cross-attention only), attention maps and their gradients are stored.
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        """Reshape ``(B, L, all_head_size)`` to ``(B, heads, L, head_size)``."""
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """
        Returns:
            ``(context,)`` or ``(context, attention_probs)`` when ``output_attentions``,
            followed by the ``(key, value)`` tensors that were attended to.
        """
        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None
        # Cached self-attention: the new queries sit after the cached key positions.
        uses_self_attn_cache = not is_cross_attention and past_key_value is not None
        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
        mixed_query_layer = self.query(hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer)
        past_key_value = (key_layer, value_layer)
        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            query_length = query_layer.size(2)
            key_length = key_layer.size(2)
            # Absolute query positions: the last `query_length` slots when extending a
            # cache, otherwise 0..query_length-1 (identical to before for plain self-attention).
            query_start = key_length - query_length if uses_self_attn_cache else 0
            position_ids_l = torch.arange(
                query_start,
                query_start + query_length,
                dtype=torch.long,
                device=hidden_states.device,
            ).view(-1, 1)
            position_ids_r = torch.arange(
                key_length, dtype=torch.long, device=hidden_states.device
            ).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1
            )
            positional_embedding = positional_embedding.to(
                dtype=query_layer.dtype
            )  # fp16 compatibility
            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
                attention_scores = attention_scores + relative_position_scores
            else:
                relative_position_scores_query = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
                relative_position_scores_key = torch.einsum(
                    "bhrd,lrd->bhlr", key_layer, positional_embedding
                )
                attention_scores = (
                    attention_scores
                    + relative_position_scores_query
                    + relative_position_scores_key
                )
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Additive mask (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask
        # Normalize the attention scores to probabilities.
        attention_probs = torch.softmax(attention_scores, dim=-1)
        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)
        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask
        context_layer = torch.matmul(attention_probs_dropped, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        outputs = (
            (context_layer, attention_probs) if output_attentions else (context_layer,)
        )
        outputs = outputs + (past_key_value,)
        return outputs
class BertSelfOutput(nn.Module):
    """Post-attention projection: dense -> dropout -> residual add -> LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project the attention output and merge it with the residual stream.

        Args:
            hidden_states: attention context (last dim ``hidden_size``).
            input_tensor: residual input, same shape as ``hidden_states``.

        Returns:
            ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``.
        """
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class BertAttention(nn.Module):
    """Attention sub-layer: multi-head (self or cross) attention plus output projection."""

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from this sub-layer.

        Args:
            heads: indices of heads to drop; heads already pruned are ignored
                by ``find_pruneable_heads_and_indices``.
        """
        if len(heads) == 0:
            return
        attn = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads, attn.num_attention_heads, attn.attention_head_size, self.pruned_heads
        )
        # Q/K/V lose output rows; the output projection loses matching input columns.
        attn.query = prune_linear_layer(attn.query, index)
        attn.key = prune_linear_layer(attn.key, index)
        attn.value = prune_linear_layer(attn.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
        # Keep the bookkeeping in sync with the shrunken layers.
        attn.num_attention_heads -= len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run attention and return ``(attention_output, *extras)``.

        ``extras`` are whatever ``BertSelfAttention`` returns after its context
        tensor (attention probabilities when requested, then the key/value cache).
        """
        context, *extras = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        return (self.output(context, hidden_states), *extras)
class BertIntermediate(nn.Module):
    """Feed-forward expansion: ``hidden_size -> intermediate_size`` followed by the activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        act = config.hidden_act
        # `hidden_act` is either a key into ACT2FN (e.g. "gelu") or already a callable.
        self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))
class BertOutput(nn.Module):
    """Feed-forward contraction: ``intermediate_size -> hidden_size`` with residual LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        contracted = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(contracted + input_tensor)
class BertLayer(nn.Module):
    """One Q-Former transformer layer.

    Self-attention runs over the whole input sequence. When ``query_length > 0``,
    the first ``query_length`` positions additionally pass through cross-attention
    (only on layers that own a ``crossattention`` module) and through a dedicated
    query feed-forward (``intermediate_query``/``output_query``); any positions
    after them use the regular feed-forward (``intermediate``/``output``).
    """

    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.layer_num = layer_num
        # Cross-attention is only created on every `cross_attention_freq`-th layer.
        if (
            self.config.add_cross_attention
            and layer_num % self.config.cross_attention_freq == 0
        ):
            self.crossattention = BertAttention(
                config, is_cross_attention=self.config.add_cross_attention
            )
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False
        # Two independent feed-forward stacks: one for non-query positions, one for query positions.
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)
        self.intermediate_query = BertIntermediate(config)
        self.output_query = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        """Run self-attention, optional cross-attention, and the feed-forward(s).

        Args:
            hidden_states: layer input; the first ``query_length`` positions along
                dim 1 are treated as query positions.
            attention_mask: additive self-attention mask (also forwarded to the
                cross-attention call).
            head_mask: per-head mask forwarded to the attention modules.
            encoder_hidden_states: required when this layer has cross-attention
                and ``query_length > 0``.
            encoder_attention_mask: mask for ``encoder_hidden_states``.
            past_key_value: cached self-attention key/value; only its first two
                entries are used.
            output_attentions: whether attention probabilities are returned.
            query_length: number of leading query positions.

        Returns:
            A tuple ``(layer_output, *attention_probs, present_key_value)`` where
            ``attention_probs`` holds the self-attention probabilities and, if
            cross-attention ran, the cross-attention probabilities (both only
            when ``output_attentions`` is set).
        """
        # The self-attention cache is the first two entries (positions 0 and 1) of past_key_value.
        self_attn_past_key_value = (
            past_key_value[:2] if past_key_value is not None else None
        )
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]
        # Middle entries are optional attention probs; the last entry is always the cache.
        outputs = self_attention_outputs[1:-1]
        present_key_value = self_attention_outputs[-1]
        if query_length > 0:
            query_attention_output = attention_output[:, :query_length, :]
            if self.has_cross_attention:
                assert (
                    encoder_hidden_states is not None
                ), "encoder_hidden_states must be given for cross-attention layers"
                # Only the query positions attend to the encoder states.
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_attention_outputs[0]
                outputs = (
                    outputs + cross_attention_outputs[1:-1]
                )  # add cross attentions if we output attention weights
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )
            # Remaining (non-query) positions use the regular feed-forward and are re-appended.
            if attention_output.shape[1] > query_length:
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text], dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )
        outputs = (layer_output,) + outputs
        outputs = outputs + (present_key_value,)
        return outputs

    def feed_forward_chunk(self, attention_output):
        """Regular feed-forward (intermediate + output with residual) for non-query positions."""
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(self, attention_output):
        """Query-specific feed-forward (intermediate_query + output_query with residual)."""
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output
class BertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` :class:`BertLayer` modules."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [BertLayer(config, i) for i in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        """Run every layer in sequence.

        Args:
            hidden_states: embedding output fed to the first layer.
            attention_mask: extended (additive) self-attention mask.
            head_mask: per-layer head masks, indexed by layer, or ``None``.
            encoder_hidden_states / encoder_attention_mask: inputs for the
                cross-attention layers.
            past_key_values: per-layer caches, indexed by layer, or ``None``.
            use_cache: collect each layer's present key/value. Forced off when
                gradient checkpointing is active during training.
            output_attentions: collect self-attention probabilities, and
                cross-attention probabilities from layers that ran cross-attention.
            output_hidden_states: collect the input of every layer plus the
                final output.
            return_dict: return a ``BaseModelOutputWithPastAndCrossAttentions``
                instead of a tuple.
            query_length: number of leading query positions (see ``BertLayer``).

        Returns:
            ``BaseModelOutputWithPastAndCrossAttentions``, or, when
            ``return_dict`` is false, a tuple of the non-``None`` values among
            (hidden_states, cache, all hidden states, self attentions,
            cross attentions).
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = (
            () if output_attentions and self.config.add_cross_attention else None
        )
        next_decoder_cache = () if use_cache else None
        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None
            if getattr(self.config, "gradient_checkpointing", False) and self.training:
                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module, layer_past):
                    # Bind this layer's cache explicitly: checkpointing re-runs the
                    # closure during backward, when the loop variable
                    # `past_key_value` already refers to a later layer.
                    def custom_forward(*inputs):
                        return module(
                            *inputs, layer_past, output_attentions, query_length
                        )

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module, past_key_value),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )
            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                # BertLayer returns (output, self_attn, cross_attn, cache) only when
                # cross-attention actually ran; otherwise (output, self_attn, cache).
                # Guard so a cache tuple is never mistaken for cross-attention
                # weights, and so add_cross_attention=False does not hit None.
                if all_cross_attentions is not None and len(layer_outputs) == 4:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
class BertPooler(nn.Module):
    """Pool a sequence into one vector: ``tanh(dense(hidden_states[:, 0]))``."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # Use only the first position of each sequence as its summary.
        lead_state = hidden_states[:, 0]
        return self.activation(self.dense(lead_state))
class BertPredictionHeadTransform(nn.Module):
    """Pre-decoder transform for the LM head: dense -> activation -> LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        act = config.hidden_act
        # Accept either an ACT2FN key or a ready-made callable.
        self.transform_act_fn = ACT2FN[act] if isinstance(act, str) else act
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)
class BertLMPredictionHead(nn.Module):
    """Maps hidden states to vocabulary logits via a transform and a linear decoder."""

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)
        # The decoder has no bias of its own; a separate output-only bias
        # parameter is created and attached to it below.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        # Share the parameter so `resize_token_embeddings` also resizes the bias.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        return self.decoder(self.transform(hidden_states))
class BertOnlyMLMHead(nn.Module):
    """Thin wrapper exposing :class:`BertLMPredictionHead` under the ``predictions`` name."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        return self.predictions(sequence_output)
class BertPreTrainedModel(PreTrainedModel):
    """
    Abstract base that wires the BERT config class, checkpoint prefix and weight
    initialisation into the Hugging Face ``PreTrainedModel`` machinery.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialise one submodule's weights in place."""
        is_linear = isinstance(module, nn.Linear)
        if is_linear or isinstance(module, nn.Embedding):
            # Plain normal init; the TF reference uses a truncated normal instead
            # (cf https://github.com/pytorch/pytorch/pull/5617).
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if is_linear and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
class BertModel(BertPreTrainedModel):
    """
    BERT-style transformer (Q-Former variant) following the architecture described in `Attention is all you need
    <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
    Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    The model behaves as an encoder (bidirectional self-attention) by default, or as a decoder (causal self-attention)
    when :meth:`forward` is called with ``is_decoder=True``. Cross-attention layers exist only when
    :obj:`config.add_cross_attention` is :obj:`True` (one every ``config.cross_attention_freq`` layers); an
    :obj:`encoder_hidden_states` is then expected as an input to the forward pass.
    """

    def __init__(self, config, add_pooling_layer=False):
        super().__init__(config)
        self.config = config
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config) if add_pooling_layer else None
        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def get_extended_attention_mask(
        self,
        attention_mask: Tensor,
        input_shape: Tuple[int],
        device: device,
        is_decoder: bool,
        has_query: bool = False,
    ) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore. Either 2-D
                ``[batch_size, seq_length]`` or 3-D ``[batch_size, from_seq_length, to_seq_length]``.
            input_shape (:obj:`Tuple[int]`):
                ``(batch_size, seq_length)`` of the tokens the causal mask is built for (decoder mode only).
            device: (:obj:`torch.device`):
                The device on which the causal mask is created.
            is_decoder (:obj:`bool`):
                If true and ``attention_mask`` is 2-D, combine it with a causal (lower-triangular) mask.
            has_query (:obj:`bool`):
                If true, the extra prefix positions of ``attention_mask`` (e.g. query tokens) are visible to every
                row but cannot themselves see the causal part (UniLM-style mask).

        Returns:
            :obj:`torch.Tensor`: additive mask of dtype ``self.dtype``, 0.0 where attention is allowed and -10000.0
            where it is masked, broadcastable to ``[batch_size, num_heads, from_seq_length, to_seq_length]``.

        Raises:
            ValueError: if ``attention_mask`` is neither 2-D nor 3-D.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if is_decoder:
                batch_size, seq_length = input_shape
                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = (
                    seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
                    <= seq_ids[None, :, None]
                )
                # causal and attention masks must have same type with pytorch version < 1.3
                causal_mask = causal_mask.to(attention_mask.dtype)
                # attention_mask may be longer than the causal part (a prefix such as query tokens).
                if causal_mask.shape[1] < attention_mask.shape[1]:
                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                    if has_query:  # UniLM style attention mask
                        # Prefix rows: the prefix does not attend to the causal positions.
                        causal_mask = torch.cat(
                            [
                                torch.zeros(
                                    (batch_size, prefix_seq_len, seq_length),
                                    device=device,
                                    dtype=causal_mask.dtype,
                                ),
                                causal_mask,
                            ],
                            dim=1,
                        )
                    # Prefix columns: every row may attend to the prefix.
                    causal_mask = torch.cat(
                        [
                            torch.ones(
                                (batch_size, causal_mask.shape[1], prefix_seq_len),
                                device=device,
                                dtype=causal_mask.dtype,
                            ),
                            causal_mask,
                        ],
                        dim=-1,
                    )
                extended_attention_mask = (
                    causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
                )
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )
        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(
            dtype=self.dtype
        )  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        is_decoder=False,
    ):
        r"""
        input_ids (:obj:`torch.LongTensor`, `optional`):
            Text token ids. May be omitted when :obj:`query_embeds` is given.
        query_embeds (:obj:`torch.FloatTensor`, `optional`):
            Query embeddings placed before the text tokens; their count sets ``query_length`` for the encoder.
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            layers. A list of such tensors is also accepted (its first element determines the default mask shape).
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention layers. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`). Not defaulted from the config.
        is_decoder (:obj:`bool`):
            Apply a causal self-attention mask over the text tokens (requires :obj:`input_ids`).
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        # NOTE: use_cache is passed through as given; it is not defaulted from the config.
        if input_ids is None:
            assert (
                query_embeds is not None
            ), "You have to specify query_embeds when input_ids is None"
        # past_key_values_length (cached positions excluding the query tokens)
        past_key_values_length = (
            past_key_values[0][0].shape[2] - self.config.query_length
            if past_key_values is not None
            else 0
        )
        query_length = query_embeds.shape[1] if query_embeds is not None else 0
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            query_embeds=query_embeds,
            past_key_values_length=past_key_values_length,
        )
        input_shape = embedding_output.size()[:-1]
        batch_size, seq_length = input_shape
        device = embedding_output.device
        if attention_mask is None:
            attention_mask = torch.ones(
                ((batch_size, seq_length + past_key_values_length)), device=device
            )
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if is_decoder:
            # The causal part covers only the text tokens; query tokens are a prefix.
            extended_attention_mask = self.get_extended_attention_mask(
                attention_mask,
                input_ids.shape,
                device,
                is_decoder,
                has_query=(query_embeds is not None),
            )
        else:
            extended_attention_mask = self.get_extended_attention_mask(
                attention_mask, input_shape, device, is_decoder
            )
        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            if isinstance(encoder_hidden_states, list):
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
                    0
                ].size()
            else:
                (
                    encoder_batch_size,
                    encoder_sequence_length,
                    _,
                ) = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if isinstance(encoder_attention_mask, list):
                encoder_extended_attention_mask = [
                    self.invert_attention_mask(mask) for mask in encoder_attention_mask
                ]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(
                    encoder_attention_mask
                )
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(
                    encoder_attention_mask
                )
        else:
            encoder_extended_attention_mask = None
        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            query_length=query_length,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = (
            self.pooler(sequence_output) if self.pooler is not None else None
        )
        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]
        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
class BertLMHeadModel(BertPreTrainedModel):
    """Q-Former BERT with a causal language-modelling head on the text positions."""

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)
        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction="mean",
    ):
        r"""
        Run the backbone and score the text positions with the LM head.

        encoder_hidden_states / encoder_attention_mask:
            Encoder states (and their ``[0, 1]`` padding mask, 1 = keep) consumed by the cross-attention layers.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Targets for left-to-right (next-token) prediction. Positions set to ``-100`` are ignored; the loss is
            computed only over labels in ``[0, ..., config.vocab_size]``. Providing labels disables the cache.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`):
            Cached key/value states from earlier steps. When given, :obj:`query_embeds` is dropped because the
            queries are already part of the cache.
        use_cache (:obj:`bool`):
            Return updated :obj:`past_key_values` for faster decoding.
        return_logits (:obj:`bool`):
            Return only the shifted logits ``prediction_scores[:, :-1]``.
        reduction (:obj:`str`):
            Loss reduction passed to ``CrossEntropyLoss`` (label smoothing 0.1). With ``"none"`` the per-token
            losses are summed per sample.

        Returns:
            ``CausalLMOutputWithCrossAttentions`` when ``return_dict`` is true, otherwise a tuple
            ``([loss,] logits, *backbone_outputs[2:])``.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        # Computing a training loss never needs a KV cache.
        if labels is not None:
            use_cache = False
        # With a cache present, the query tokens were consumed on an earlier step.
        if past_key_values is not None:
            query_embeds = None

        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        hidden = bert_outputs[0]
        if query_embeds is not None:
            # Drop the leading query positions; only text positions are scored.
            hidden = hidden[:, query_embeds.shape[1] :, :]
        logits = self.cls(hidden)

        if return_logits:
            return logits[:, :-1, :].contiguous()

        loss = None
        if labels is not None:
            # Next-token prediction: position t predicts label t + 1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            criterion = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
            loss = criterion(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
            )
            if reduction == "none":
                loss = loss.view(logits.size(0), -1).sum(1)

        if return_dict:
            return CausalLMOutputWithCrossAttentions(
                loss=loss,
                logits=logits,
                past_key_values=bert_outputs.past_key_values,
                hidden_states=bert_outputs.hidden_states,
                attentions=bert_outputs.attentions,
                cross_attentions=bert_outputs.cross_attentions,
            )
        tail = (logits,) + bert_outputs[2:]
        return tail if loss is None else (loss,) + tail

    def prepare_inputs_for_generation(
        self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
    ):
        """Build the keyword arguments for one generation step.

        The attention mask is prefixed with ones for the query positions; when a
        cache exists only the last token id is fed.
        """
        # Default text mask: attend to every provided token.
        text_mask = (
            attention_mask
            if attention_mask is not None
            else input_ids.new_ones(input_ids.shape)
        )
        query_mask = input_ids.new_ones(query_embeds.shape[:-1])
        full_mask = torch.cat([query_mask, text_mask], dim=-1)
        if past is not None:
            input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "query_embeds": query_embeds,
            "attention_mask": full_mask,
            "past_key_values": past,
            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
            "is_decoder": True,
        }

    def _reorder_cache(self, past, beam_idx):
        """Reorder every cached tensor along the batch axis to follow ``beam_idx``."""
        return tuple(
            tuple(state.index_select(0, beam_idx) for state in layer_past)
            for layer_past in past
        )
class BertForMaskedLM(BertPreTrainedModel):
    """Q-Former BERT with a masked-language-modelling head."""

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)
        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=False,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        return_logits (:obj:`bool`):
            Return only the prediction scores.

        Returns:
            ``MaskedLMOutput`` when ``return_dict`` is true, otherwise a tuple
            ``([loss,] prediction_scores, *backbone_outputs[2:])``.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )
        # Always bind the backbone output first; previously this name was only
        # assigned when query_embeds was given, so text-only calls raised
        # UnboundLocalError.
        sequence_output = outputs[0]
        if query_embeds is not None:
            # Drop the leading query positions; only text positions are scored.
            sequence_output = sequence_output[:, query_embeds.shape[1] :, :]
        prediction_scores = self.cls(sequence_output)
        if return_logits:
            return prediction_scores
        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )
        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return (
                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
            )
        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# ------------------------------------------------------ #
class BEATs(nn.Module):
    """BEATs audio encoder: log-mel fbank -> conv patch embedding -> transformer.

    When ``cfg.finetuned_model`` is set, a linear classification predictor is
    also built and used by :meth:`extract_features` (unless ``feature_only``).
    """

    def __init__(
        self,
        cfg,
    ) -> None:
        super().__init__()
        logger.info(f"BEATs Config: {cfg.__dict__}")
        self.cfg = cfg
        self.embed = cfg.embed_dim
        # Project patch embeddings to the encoder width only when they differ.
        if self.embed != cfg.encoder_embed_dim:
            self.post_extract_proj = nn.Linear(self.embed, cfg.encoder_embed_dim)
        else:
            self.post_extract_proj = None
        self.input_patch_size = cfg.input_patch_size
        patch = self.input_patch_size
        # Non-overlapping patches: kernel size equals stride.
        self.patch_embedding = nn.Conv2d(
            1,
            self.embed,
            kernel_size=patch,
            stride=patch,
            bias=cfg.conv_bias,
        )
        self.dropout_input = nn.Dropout(cfg.dropout_input)
        # deep_norm and layer_norm_first are mutually exclusive.
        assert not (cfg.deep_norm and cfg.layer_norm_first)
        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)
        if cfg.finetuned_model:
            self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
            self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
        else:
            self.predictor = None

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Downsample a per-step padding mask to one flag per feature frame.

        Trailing steps that do not fill a whole frame are dropped; a frame is
        marked padded only if every step in its group is padded.
        """
        num_frames = features.size(1)
        remainder = padding_mask.size(1) % num_frames
        if remainder:
            padding_mask = padding_mask[:, :-remainder]
        grouped = padding_mask.view(padding_mask.size(0), num_frames, -1)
        return grouped.all(-1)

    def preprocess(
        self,
        source: torch.Tensor,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
    ) -> torch.Tensor:
        """Convert each waveform to a normalised 128-bin Kaldi fbank and stack them."""
        # Waveforms are rescaled by 2**15 before the Kaldi fbank call
        # (presumably to 16-bit PCM range - NOTE(review): confirm input range).
        per_clip = [
            ta_kaldi.fbank(
                clip.unsqueeze(0) * 2**15,
                num_mel_bins=128,
                sample_frequency=16000,
                frame_length=25,
                frame_shift=10,
            )
            for clip in source
        ]
        stacked = torch.stack(per_clip, dim=0)
        return (stacked - fbank_mean) / (2 * fbank_std)

    def extract_features(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
        feature_only=False,
        torch_dtype=torch.float32,
    ):
        """Encode raw audio.

        Returns ``(features, padding_mask)`` when ``feature_only`` is true or no
        predictor exists; otherwise ``(sigmoid(pooled_logits), padding_mask)``
        where logits are averaged over non-padded frames.
        """
        fbank = self.preprocess(
            source, fbank_mean=fbank_mean, fbank_std=fbank_std
        ).to(torch_dtype)
        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(fbank, padding_mask)

        # Treat the fbank as a single-channel image and flatten the patch grid.
        patches = self.patch_embedding(fbank.unsqueeze(1))
        features = patches.reshape(patches.shape[0], patches.shape[1], -1).transpose(1, 2)
        features = self.layer_norm(features)
        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)
        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        x, _ = self.encoder(
            self.dropout_input(features),
            padding_mask=padding_mask,
        )
        if feature_only or self.predictor is None:
            return x, padding_mask

        x = self.predictor_dropout(x)
        logits = self.predictor(x)
        if padding_mask is not None and padding_mask.any():
            # Masked mean over frames: zero padded frames, divide by valid count.
            logits[padding_mask] = 0
            summed = logits.sum(dim=1)
            valid = (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(summed)
            logits = summed / valid
        else:
            logits = logits.mean(dim=1)
        return torch.sigmoid(logits), padding_mask
class TransformerEncoder(nn.Module):
    """Stack of ``TransformerSentenceEncoderLayer`` with a convolutional
    positional embedding.

    ``args`` is a config object (e.g. ``TokenizersConfig``) that provides the
    ``encoder_*``, ``conv_pos*``, dropout, normalization and optional
    relative-position fields read in ``__init__``.

    ``forward`` takes batch-major ``(B, T, C)`` input, runs the layers
    time-major internally, and returns ``(x, layer_results)`` with ``x`` back
    in ``(B, T, C)``.
    """

    def __init__(self, args):
        super().__init__()
        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim

        # Convolutional positional embedding: grouped Conv1d over time with
        # weight norm. SamePad trims the extra frame produced by an even kernel
        # (padding = k // 2), so the time length is unchanged; then GELU.
        self.pos_conv = nn.Conv1d(
            self.embedding_dim,
            self.embedding_dim,
            kernel_size=args.conv_pos,
            padding=args.conv_pos // 2,
            groups=args.conv_pos_groups,
        )
        dropout = 0
        std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
        nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
        nn.init.constant_(self.pos_conv.bias, 0)
        self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
        self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())

        if hasattr(args, "relative_position_embedding"):
            self.relative_position_embedding = args.relative_position_embedding
            self.num_buckets = args.num_buckets
            self.max_distance = args.max_distance
        else:
            self.relative_position_embedding = False
            self.num_buckets = 0
            self.max_distance = 0

        self.layers = nn.ModuleList(
            [
                TransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=args.encoder_ffn_embed_dim,
                    num_attention_heads=args.encoder_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=args.attention_dropout,
                    activation_dropout=args.activation_dropout,
                    activation_fn=args.activation_fn,
                    layer_norm_first=args.layer_norm_first,
                    deep_norm=args.deep_norm,
                    has_relative_attention_bias=self.relative_position_embedding,
                    num_buckets=self.num_buckets,
                    max_distance=self.max_distance,
                    gru_rel_pos=args.gru_rel_pos,
                    encoder_layers=args.encoder_layers,
                )
                for i in range(args.encoder_layers)
            ]
        )
        if self.relative_position_embedding:
            # All layers share layer 0's relative-position bias table.
            for i in range(1, args.encoder_layers):
                del self.layers[i].self_attn.relative_attention_bias
                self.layers[i].self_attn.relative_attention_bias = self.layers[
                    0
                ].self_attn.relative_attention_bias

        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop

        self.apply(init_bert_params)

        if args.deep_norm:
            # DeepNorm init: value/out/FFN weights scaled by (8 * N) ** -1/4.
            deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
            for i in range(args.encoder_layers):
                nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1)
                nn.init.xavier_normal_(
                    self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta
                )
                nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1)
                nn.init.xavier_normal_(
                    self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta
                )
                nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta)
                nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta)

        self.layer_wise_gradient_decay_ratio = getattr(
            args, "layer_wise_gradient_decay_ratio", 1
        )

    def forward(self, x, padding_mask=None, layer=None):
        """Encode ``x`` (B, T, C). If ``layer`` is given, stop after that layer
        index (and skip the final pre-norm LayerNorm)."""
        x, layer_results = self.extract_features(x, padding_mask, layer)
        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)
        return x, layer_results

    def extract_features(self, x, padding_mask=None, tgt_layer=None):
        """Run the positional conv and the layer stack.

        Returns ``(x, layer_results)``. ``layer_results`` collects
        ``(x, attn)`` tuples (time-major) only when ``tgt_layer`` is given.
        """
        if padding_mask is not None:
            # Zero padded frames out of place; the previous `x[padding_mask] = 0`
            # overwrote the caller's tensor.
            x = x.masked_fill(padding_mask.unsqueeze(-1).to(torch.bool), 0)

        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        layer_results = []
        z = None
        if tgt_layer is not None:
            layer_results.append((x, z))
        r = None
        pos_bias = None
        for i, layer in enumerate(self.layers):
            if self.layer_wise_gradient_decay_ratio != 1.0:
                x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
            # LayerDrop: only consult the global NumPy RNG when a layer can
            # actually be dropped, so inference does not consume random state.
            drop_layer = (
                self.training
                and self.layerdrop > 0
                and np.random.random() <= self.layerdrop
            )
            if not drop_layer:
                x, z, pos_bias = layer(
                    x,
                    self_attn_padding_mask=padding_mask,
                    need_weights=False,
                    pos_bias=pos_bias,
                )
            if tgt_layer is not None:
                layer_results.append((x, z))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, layer_results
class TransformerSentenceEncoderLayer(nn.Module):
    """One transformer encoder block: self-attention followed by a feed-forward net.

    Inputs are time-major ``(T, B, C)`` (the layout ``MultiheadAttention``
    expects). Two layouts are supported:

    * ``layer_norm_first=True`` (pre-norm): LayerNorm is applied before the
      attention and before the FFN, with plain residual additions.
    * ``layer_norm_first=False`` (post-norm): LayerNorm is applied after each
      residual addition. The residual is scaled by ``deep_norm_alpha``, which
      is ``(2 * encoder_layers) ** 0.25`` when ``deep_norm`` is set and 1
      otherwise.

    With ``activation_fn == "glu"``, ``fc1`` is a ``GLU_Linear`` (swish gate)
    and the separate activation is skipped.

    ``forward`` returns ``(x, attn, pos_bias)``. ``pos_bias`` is the
    relative-position bias from the attention module, so a following layer
    can reuse it.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
        deep_norm: bool = False,
        has_relative_attention_bias: bool = False,
        num_buckets: int = 0,
        max_distance: int = 0,
        rescale_init: bool = False,
        gru_rel_pos: bool = False,
        encoder_layers: int = 0,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        self.activation_name = activation_fn
        self.activation_fn = get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True,
            has_relative_attention_bias=has_relative_attention_bias,
            num_buckets=num_buckets,
            max_distance=max_distance,
            rescale_init=rescale_init,
            gru_rel_pos=gru_rel_pos,
        )

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.layer_norm_first = layer_norm_first

        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)

        if self.activation_name == "glu":
            self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
        else:
            self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        self.final_layer_norm = LayerNorm(self.embedding_dim)

        self.deep_norm = deep_norm
        if self.deep_norm:
            self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
        else:
            self.deep_norm_alpha = 1

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
        need_weights: bool = False,
        pos_bias=None,
    ):
        """Apply the block to time-major ``x``.

        Args:
            x: input of shape ``(T, B, C)``.
            self_attn_mask: optional additive attention mask.
            self_attn_padding_mask: optional ``(B, T)`` key padding mask.
            need_weights: request attention weights from the attention module
                (honored in both the pre-norm and post-norm layouts).
            pos_bias: relative-position bias from a previous layer, or None.

        Returns:
            ``(x, attn, pos_bias)``.
        """
        residual = x

        if self.layer_norm_first:
            x = self.self_attn_layer_norm(x)
            x, attn, pos_bias = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                # Previously hard-coded to False, silently ignoring the argument.
                need_weights=need_weights,
                attn_mask=self_attn_mask,
                position_bias=pos_bias,
            )
            x = self.dropout1(x)
            x = residual + x

            residual = x
            x = self.final_layer_norm(x)
            if self.activation_name == "glu":
                x = self.fc1(x)
            else:
                x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual + x
        else:
            x, attn, pos_bias = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=need_weights,
                attn_mask=self_attn_mask,
                position_bias=pos_bias,
            )

            x = self.dropout1(x)
            x = residual * self.deep_norm_alpha + x

            x = self.self_attn_layer_norm(x)

            residual = x
            if self.activation_name == "glu":
                x = self.fc1(x)
            else:
                x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual * self.deep_norm_alpha + x
            x = self.final_layer_norm(x)

        return x, attn, pos_bias
class MultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.

    Adds two optional features on top of standard scaled dot-product attention:

    * a learned, bucketed relative-position bias
      (``has_relative_attention_bias``, ``num_buckets``, ``max_distance``);
    * a per-query gate on that bias (``gru_rel_pos``).

    Inputs are time-major ``(T, B, C)``. ``forward`` returns
    ``(attn, attn_weights, position_bias)``.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        q_noise=0.0,
        qn_block_size=8,
        has_relative_attention_bias=False,
        num_buckets=32,
        max_distance=128,
        gru_rel_pos=False,
        rescale_init=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout_module = nn.Dropout(dropout)

        self.has_relative_attention_bias = has_relative_attention_bias
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        if self.has_relative_attention_bias:
            # One learned scalar bias per (bucket, head).
            self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)

        self.head_dim = embed_dim // num_heads
        self.q_head_dim = self.head_dim
        self.k_head_dim = self.head_dim
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and " "value to be of the same size"
        )

        k_bias = True
        if rescale_init:
            k_bias = False

        k_embed_dim = embed_dim
        q_embed_dim = embed_dim

        self.k_proj = quant_noise(
            nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
        )
        self.v_proj = quant_noise(
            nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.q_proj = quant_noise(
            nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
        )

        self.out_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
        )

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.gru_rel_pos = gru_rel_pos
        if self.gru_rel_pos:
            self.grep_linear = nn.Linear(self.q_head_dim, 8)
            self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialize projections, zero the output bias, init optional biases."""
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)
        if self.has_relative_attention_bias:
            nn.init.xavier_normal_(self.relative_attention_bias.weight)

    def _relative_positions_bucket(self, relative_positions, bidirectional=True):
        """Map integer offsets (key - query) to bucket ids in ``[0, num_buckets)``.

        When bidirectional, half of the buckets are reserved for positive
        offsets. Within a half, offsets smaller than ``max_exact`` get their
        own bucket. Larger ones are log-spaced up to ``max_distance`` and
        clamped to the last bucket.
        """
        num_buckets = self.num_buckets
        max_distance = self.max_distance
        relative_buckets = 0

        if bidirectional:
            num_buckets = num_buckets // 2
            relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
            relative_positions = torch.abs(relative_positions)
        else:
            relative_positions = -torch.min(
                relative_positions, torch.zeros_like(relative_positions)
            )

        max_exact = num_buckets // 2
        is_small = relative_positions < max_exact

        # log(0) = -inf for offset 0; those entries are discarded by the
        # torch.where below because they are always "small".
        relative_postion_if_large = max_exact + (
            torch.log(relative_positions.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_postion_if_large = torch.min(
            relative_postion_if_large,
            torch.full_like(relative_postion_if_large, num_buckets - 1),
        )

        relative_buckets += torch.where(
            is_small, relative_positions, relative_postion_if_large
        )
        return relative_buckets

    def compute_bias(self, query_length, key_length):
        """Return the relative-position bias of shape (num_heads, query_length, key_length)."""
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position
        relative_position_bucket = self._relative_positions_bucket(
            relative_position, bidirectional=True
        )
        relative_position_bucket = relative_position_bucket.to(
            self.relative_attention_bias.weight.device
        )
        values = self.relative_attention_bias(relative_position_bucket)
        values = values.permute([2, 0, 1])
        return values

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
        position_bias: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: False).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
            position_bias (Tensor, optional): precomputed relative-position
                bias of shape `(batch * num_heads, tgt_len, src_len)`. It is
                computed here when None and relative bias is enabled.

        Returns:
            ``(attn, attn_weights, position_bias)``. ``attn`` is
            `(tgt_len, batch, embed_dim)`; ``attn_weights`` is None unless
            weights were requested. With ``before_softmax`` the tuple is
            ``(raw_attn_weights, v, position_bias)`` instead.
        """
        if need_head_weights:
            need_weights = True

        is_tpu = query.device.type == "xla"

        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                assert key_bsz == bsz
                assert value is not None
                # Was `assert src_len, bsz == value.shape[:2]`, which only
                # asserted that src_len is truthy (the rest was the message).
                assert (src_len, bsz) == tuple(value.shape[:2])

        if self.has_relative_attention_bias and position_bias is None:
            position_bias = self.compute_bias(tgt_len, src_len)
            position_bias = (
                position_bias.unsqueeze(0)
                .repeat(bsz, 1, 1, 1)
                .view(bsz * self.num_heads, tgt_len, src_len)
            )

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)

        # q is additionally divided by alpha. Scores are later shifted by their
        # row max and multiplied back by alpha. Softmax is invariant to the
        # shift, so the result is unchanged (presumably done to limit overflow
        # in reduced precision).
        q *= self.scaling
        alpha = 32
        q *= 1 / alpha

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        # (T, B, C) -> (B * heads, T, head_dim)
        q = (
            q.contiguous()
            .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
            .transpose(0, 1)
        )
        if k is not None:
            k = (
                k.contiguous()
                .view(-1, bsz * self.num_heads, self.k_head_dim)
                .transpose(0, 1)
            )
        if v is not None:
            v = (
                v.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
                src_len = k.size(1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        assert k.size(1) == src_len

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0), 1).type_as(
                            key_padding_mask
                        ),
                    ],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = (
            attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]
        ) * alpha
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v, position_bias

        if position_bias is not None:
            attn_mask_rel_pos = position_bias
            if self.gru_rel_pos == 1:
                # Undo the alpha / scaling applied to q to gate on the raw query.
                query_layer = (
                    q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
                    * alpha
                    / self.scaling
                )
                _B, _H, _L, __ = query_layer.size()
                gate_a, gate_b = torch.sigmoid(
                    self.grep_linear(query_layer)
                    .view(_B, _H, _L, 2, 4)
                    .sum(-1, keepdim=False)
                ).chunk(2, dim=-1)
                gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
                attn_mask_rel_pos = (
                    gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias
                )

            attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())

            attn_weights = attn_weights + attn_mask_rel_pos

        attn_weights_float = F.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights, position_bias

    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        """Combine a cached key padding mask with the current one (incremental decoding).

        Missing parts are filled with zeros (not padded). Concatenated results
        are returned as float.
        """
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
            )
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:
            if src_len > prev_key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - prev_key_padding_mask.size(1)),
                    device=prev_key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat(
                    [prev_key_padding_mask.float(), filler.float()], dim=1
                )
            else:
                new_key_padding_mask = prev_key_padding_mask.float()
        elif key_padding_mask is not None:
            if src_len > key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - key_padding_mask.size(1)),
                    device=key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat(
                    [filler.float(), key_padding_mask.float()], dim=1
                )
            else:
                new_key_padding_mask = key_padding_mask.float()
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    # NOTE(review): `get_incremental_state` / `set_incremental_state` are not
    # defined on this class, and nn.Module does not provide them either (fairseq
    # supplies them via a class decorator). Passing a non-None
    # `incremental_state` to forward will therefore raise AttributeError unless
    # they are attached elsewhere -- confirm before relying on incremental
    # decoding.
    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        """Return the cached attention state dict, or an empty dict."""
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        """Store ``buffer`` as this module's cached attention state."""
        return self.set_incremental_state(incremental_state, "attn_state", buffer)

    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
        """Hook for sparse-attention subclasses; identity here."""
        return attn_weights
def init_bert_params(module):
    """
    Re-initialize ``module``'s weights BERT-style (meant for ``nn.Module.apply``).

    The rules below are applied unconditionally; there are no configuration
    flags:

    1. ``nn.Linear``: weight ~ N(0, 0.02); bias (if present) set to zero.
    2. ``nn.Embedding``: weight ~ N(0, 0.02); the ``padding_idx`` row (if
       set) is zeroed.
    3. ``MultiheadAttention``: the ``q_proj``, ``k_proj`` and ``v_proj``
       weights are redrawn from N(0, 0.02).

    Samples are drawn on CPU and copied to the parameter's device, so the
    values come from the CPU RNG wherever the module lives.
    """

    def normal_(data):
        # with FSDP, module params will be on CUDA, so we cast them back to CPU
        # so that the RNG is consistent with and without FSDP
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    if isinstance(module, MultiheadAttention):
        normal_(module.q_proj.weight.data)
        normal_(module.k_proj.weight.data)
        normal_(module.v_proj.weight.data)
class GradMultiply(torch.autograd.Function):
    """Identity in the forward pass; scales the incoming gradient by ``scale``.

    Usage: ``y = GradMultiply.apply(x, scale)``.
    """

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        # Explicit copy instead of the undocumented legacy constructor
        # `x.new(x)`, which aliases the input's storage.
        return x.clone()

    @staticmethod
    def backward(ctx, grad):
        # No gradient flows to the (non-tensor) `scale` argument.
        return grad * ctx.scale, None
class SamePad(nn.Module):
    """Drop trailing time steps so a padded conv output keeps its input length.

    Non-causal: removes one step when ``kernel_size`` is even, otherwise none.
    Causal: removes ``kernel_size - 1`` steps. Operates on the last dimension
    of a 3-D ``(N, C, T)`` tensor.
    """

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            self.remove = kernel_size - 1
        else:
            self.remove = int(kernel_size % 2 == 0)

    def forward(self, x):
        if self.remove <= 0:
            return x
        return x[:, :, : -self.remove]
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``."""

    def __init__(self):
        super().__init__()
        self.act = torch.nn.Sigmoid()

    def forward(self, x):
        gate = self.act(x)
        return x * gate
class GLU_Linear(nn.Module):
    """Linear projection followed by a gated linear unit over the last dim.

    The input is projected from ``input_dim`` to ``2 * output_dim`` features
    and split into a value half and a gate half. The result is
    ``value * act(gate)``, or ``value * gate`` for ``glu_type="bilinear"``.
    Works for any input rank as long as features are in the last dimension.

    Raises:
        ValueError: if ``glu_type`` is not one of the supported names.
    """

    _SUPPORTED_GLU_TYPES = ("sigmoid", "swish", "relu", "gelu", "bilinear")

    def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
        super(GLU_Linear, self).__init__()
        # Fail at construction instead of with an AttributeError on the
        # first forward (glu_act would otherwise be left undefined).
        if glu_type not in self._SUPPORTED_GLU_TYPES:
            raise ValueError(
                f"Unsupported glu_type {glu_type!r}; "
                f"expected one of {self._SUPPORTED_GLU_TYPES}"
            )
        self.glu_type = glu_type
        self.output_dim = output_dim

        if glu_type == "sigmoid":
            self.glu_act = torch.nn.Sigmoid()
        elif glu_type == "swish":
            self.glu_act = Swish()
        elif glu_type == "relu":
            self.glu_act = torch.nn.ReLU()
        elif glu_type == "gelu":
            self.glu_act = torch.nn.GELU()
        # "bilinear" uses the raw gate, so no activation module is created.

        self.linear = nn.Linear(input_dim, output_dim * 2, bias=bool(bias_in_glu))

    def forward(self, x):
        # Features are assumed to be in the last dimension; `...` indexing
        # makes this rank-agnostic (previously only 3-D inputs worked).
        x = self.linear(x)
        value = x[..., 0 : self.output_dim]
        gate = x[..., self.output_dim : self.output_dim * 2]
        if self.glu_type == "bilinear":
            return value * gate
        return value * self.glu_act(gate)
def gelu_accurate(x):
    """Tanh approximation of GELU.

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``.
    The constant ``sqrt(2/pi)`` is computed once and cached on the function
    as ``gelu_accurate._a``.
    """
    coeff = getattr(gelu_accurate, "_a", None)
    if coeff is None:
        coeff = math.sqrt(2 / math.pi)
        gelu_accurate._a = coeff
    inner = coeff * (x + 0.044715 * torch.pow(x, 3))
    return 0.5 * x * (1 + torch.tanh(inner))
def gelu(x: torch.Tensor) -> torch.Tensor:
    """Exact GELU evaluated in float32, with the result cast back to ``x``'s dtype."""
    upcast = x.float()
    activated = torch.nn.functional.gelu(upcast)
    return activated.type_as(x)
def get_activation_fn(activation: str):
    """Return the activation callable named by ``activation``.

    ``"gelu_fast"`` is a deprecated alias of ``"gelu_accurate"`` (a warning is
    emitted). ``"linear"`` and ``"glu"`` both map to the identity. For
    ``"glu"``, the encoder layer in this file does the gating inside its
    ``GLU_Linear`` fc1 instead.

    Raises:
        RuntimeError: for an unknown activation name.
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return gelu
    if activation in ("gelu_fast", "gelu_accurate"):
        if activation == "gelu_fast":
            warnings.warn("--activation-fn=gelu_fast has been renamed to gelu_accurate")
        return gelu_accurate
    if activation == "tanh":
        return torch.tanh
    if activation in ("linear", "glu"):
        return lambda x: x
    raise RuntimeError(f"--activation-fn {activation} not supported")
def quant_noise(module, p, block_size):
    """
    Register Quant-Noise on ``module``: in training mode, random blocks of its
    weight are dropped before each forward. This follows "Training with
    Quantization Noise for Extreme Model Compression" and prepares the layer
    for later Iterative Product Quantization.

    Args:
        module: an ``nn.Linear``, ``nn.Embedding`` or ``nn.Conv2d``.
        p: probability of dropping each block. ``p <= 0`` registers nothing
            and returns ``module`` as is.
        block_size: block length; it must divide the input features (2-D
            weights), the input channels (1x1 convs) or
            ``kernel_h * kernel_w`` (other convs).

    Returns:
        The same ``module`` object.

    Raises:
        AssertionError: unsupported module type or incompatible block size.

    Remarks:
        For block-wise quantization of convolutional weights see "And the
        Bit Goes Down: Revisiting the Quantization of Neural Networks". This
        is the simplest form of the noise: whole blocks are dropped and the
        rest rescaled by ``1 / (1 - p)``.

    NOTE(review): the hook assigns the masked, rescaled weight back to
    ``mod.weight.data``, so each training forward compounds on the previous
    one. Dropped blocks stay zero and kept blocks are rescaled again. This
    mirrors the upstream implementation; confirm it is intended before using
    ``p > 0`` (the attention layers in this file default to ``q_noise=0.0``).
    """
    # No noise requested: leave the module untouched, register no hook.
    if p <= 0:
        return module

    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))

    # 4-D weights are convolutions; everything else is a 2-D matrix.
    is_conv = module.weight.ndim == 4

    if is_conv:
        if module.kernel_size == (1, 1):
            assert (
                module.in_channels % block_size == 0
            ), "Input channels must be a multiple of block sizes"
        else:
            kernel_area = module.kernel_size[0] * module.kernel_size[1]
            assert (
                kernel_area % block_size == 0
            ), "Kernel size must be a multiple of block size"
    else:
        assert (
            module.weight.size(1) % block_size == 0
        ), "Input features must be a multiple of block sizes"

    def _forward_pre_hook(mod, input):
        # Evaluation: weights are used as they are.
        if not mod.training:
            return

        weight = mod.weight
        if not is_conv:
            # One Bernoulli draw per block of `block_size` input features.
            n_in = weight.size(1)
            n_out = weight.size(0)
            mask = torch.zeros(n_in // block_size * n_out, device=weight.device)
            mask.bernoulli_(p)
            mask = mask.repeat_interleave(block_size, -1).view(-1, n_in)
        elif mod.kernel_size == (1, 1):
            # 1x1 conv: blocks of `block_size` input channels.
            n_in = mod.in_channels
            n_out = mod.out_channels
            mask = torch.zeros(
                int(n_in // block_size * n_out),
                device=weight.device,
            )
            mask.bernoulli_(p)
            mask = mask.repeat_interleave(block_size, -1).view(-1, n_in)
        else:
            # Regular conv: drop whole (out_channel, in_channel) kernels.
            mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
            mask.bernoulli_(p)
            mask = (
                mask.unsqueeze(2)
                .unsqueeze(3)
                .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
            )

        # x.bool() is not currently supported in TorchScript, hence .to(torch.bool)
        mask = mask.to(torch.bool)
        keep_scale = 1 / (1 - p)
        mod.weight.data = keep_scale * weight.masked_fill(mask, 0)

    module.register_forward_pre_hook(_forward_pre_hook)
    return module
class TokenizersConfig:
    """Hyper-parameters for ``Tokenizers`` and its ``TransformerEncoder``.

    All fields are plain instance attributes with the defaults below.
    ``cfg`` (a dict) may override any of them or add new keys.
    """

    def __init__(self, cfg=None):
        # --- patch embedding ---
        self.input_patch_size: int = -1  # patch size of the patch embedding
        self.embed_dim: int = 512  # patch embedding dimension
        self.conv_bias: bool = False  # whether the patch conv has a bias

        # --- transformer encoder ---
        self.encoder_layers: int = 12  # number of encoder layers
        self.encoder_embed_dim: int = 768  # encoder hidden size
        self.encoder_ffn_embed_dim: int = 3072  # FFN inner size
        self.encoder_attention_heads: int = 12  # attention heads per layer
        self.activation_fn: str = "gelu"  # FFN activation name
        self.layer_norm_first: bool = False  # pre-norm (True) vs post-norm
        self.deep_norm: bool = False  # DeepNorm residual scaling/init

        # --- dropouts ---
        self.dropout: float = 0.1  # transformer dropout
        self.attention_dropout: float = 0.1  # dropout on attention weights
        self.activation_dropout: float = 0.0  # dropout after the FFN activation
        self.encoder_layerdrop: float = 0.0  # probability of dropping a transformer layer
        self.dropout_input: float = 0.0  # dropout on encoder input (after feature extraction)

        # --- convolutional positional embedding ---
        self.conv_pos: int = 128  # kernel size of the positional conv
        self.conv_pos_groups: int = 16  # groups of the positional conv

        # --- relative position embedding ---
        self.relative_position_embedding: bool = False  # enable relative position bias
        self.num_buckets: int = 320  # number of relative-position buckets
        self.max_distance: int = 1280  # largest distinguished relative distance
        self.gru_rel_pos: bool = False  # gate the relative position bias

        # --- quantizer ---
        self.quant_n: int = 1024  # codebook size
        self.quant_dim: int = 256  # codebook vector dimension

        if cfg is not None:
            self.update(cfg)

    def update(self, cfg: dict):
        """Overwrite/add attributes from the mapping ``cfg``."""
        vars(self).update(cfg)
class Tokenizers(nn.Module):
    """Acoustic tokenizer that maps waveforms to codebook indices.

    ``extract_labels`` runs: waveform -> normalized Kaldi fbank -> Conv2d
    patch embedding -> LayerNorm (+ optional ``post_extract_proj``) ->
    ``TransformerEncoder`` -> ``quantize_layer`` -> ``NormEMAVectorQuantizer``,
    and returns the quantizer's code indices.
    """

    def __init__(
        self,
        cfg: TokenizersConfig,
    ) -> None:
        super().__init__()
        logger.info(f"Tokenizers Config: {cfg.__dict__}")

        self.cfg = cfg

        # Submodules are created in a fixed order; it determines parameter
        # registration order and how the init RNG is consumed.
        self.embed = cfg.embed_dim
        if self.embed != cfg.encoder_embed_dim:
            self.post_extract_proj = nn.Linear(self.embed, cfg.encoder_embed_dim)
        else:
            self.post_extract_proj = None

        # Non-overlapping square patches over the 1-channel fbank "image".
        self.input_patch_size = cfg.input_patch_size
        patch = self.input_patch_size
        self.patch_embedding = nn.Conv2d(
            1, self.embed, kernel_size=patch, stride=patch, bias=cfg.conv_bias
        )

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        # deep_norm and layer_norm_first are mutually exclusive: the encoder
        # layer only applies deep_norm_alpha in its post-norm branch.
        assert not cfg.deep_norm or not cfg.layer_norm_first
        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.quantize = NormEMAVectorQuantizer(
            n_embed=cfg.quant_n,
            embedding_dim=cfg.quant_dim,
            beta=1.0,
            kmeans_init=True,
            decay=0.99,
        )
        self.quant_n = cfg.quant_n
        hidden = cfg.encoder_embed_dim
        self.quantize_layer = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, cfg.quant_dim),  # into the codebook space
        )

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Reduce ``padding_mask`` to one entry per ``features`` position.

        Trailing entries that do not fill a whole group are dropped. The rest
        are split into ``features.size(1)`` equal groups, and a position counts
        as padding only if its whole group is padding.
        """
        n_positions = features.size(1)
        usable = padding_mask.size(1) - padding_mask.size(1) % n_positions
        grouped = padding_mask[:, :usable].view(padding_mask.size(0), n_positions, -1)
        return grouped.all(-1)

    def preprocess(
        self,
        source: torch.Tensor,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
    ) -> torch.Tensor:
        """Waveforms -> stacked 128-bin Kaldi fbanks normalized by
        ``(fbank - fbank_mean) / (2 * fbank_std)``. Samples are scaled by
        ``2**15`` first."""
        pcm_scale = 2**15
        per_clip = [
            ta_kaldi.fbank(
                clip.unsqueeze(0) * pcm_scale,
                num_mel_bins=128,
                sample_frequency=16000,
                frame_length=25,
                frame_shift=10,
            )
            for clip in source
        ]
        stacked = torch.stack(per_clip, dim=0)
        return (stacked - fbank_mean) / (2 * fbank_std)

    def extract_labels(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
    ):
        """Return the quantizer code indices for the waveforms in ``source``."""
        fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(fbank, padding_mask)

        patches = self.patch_embedding(fbank.unsqueeze(1))
        batch, channels = patches.shape[0], patches.shape[1]
        features = patches.reshape(batch, channels, -1).transpose(1, 2)
        features = self.layer_norm(features)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        hidden, _ = self.encoder(
            self.dropout_input(features),
            padding_mask=padding_mask,
        )

        _, _, embed_ind = self.quantize(self.quantize_layer(hidden))
        return embed_ind
def l2norm(t):
    """Return ``t`` with every vector along its last axis scaled to unit L2 norm.

    Thin wrapper over ``F.normalize``; all-zero vectors stay zero.
    """
    return F.normalize(t, dim=-1, p=2)
def ema_inplace(moving_avg, new, decay):
    """In place: ``moving_avg <- decay * moving_avg + (1 - decay) * new``.

    Writes through ``.data`` so autograd does not track the update (this also
    works on ``nn.Parameter`` objects). Returns None.
    """
    buf = moving_avg.data
    buf.mul_(decay)
    buf.add_(new, alpha=(1 - decay))
def sample_vectors(samples, num):
    """Pick ``num`` rows of ``samples`` at random to seed k-means centroids.

    Rows are drawn without replacement (``randperm``) when there are at least
    ``num`` of them, and with replacement (``randint``) otherwise.

    Args:
        samples: tensor whose first dimension indexes candidate vectors.
        num: number of rows to return.

    Returns:
        ``samples[indices]`` where ``indices`` has length ``num``.
    """
    count, dev = samples.shape[0], samples.device
    if count < num:
        # Not enough distinct rows: sample with replacement.
        picked = torch.randint(0, count, (num,), device=dev)
    else:
        picked = torch.randperm(count, device=dev)[:num]
    return samples[picked]


def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
    """Cluster the rows of a 2-D ``samples`` tensor (n, d) with Lloyd's k-means.

    Args:
        samples: (n, d) tensor of vectors to cluster.
        num_clusters: number of centroids.
        num_iters: number of assignment/update rounds. Values < 1 skip
            refinement and return the randomly chosen seeds.
        use_cosine_sim: when True, similarity is ``samples @ means.T`` and
            the updated centroids are re-normalised with ``l2norm``. The dot
            product equals cosine similarity only for unit-norm ``samples``
            (the caller presumably normalises them -- confirm). When False,
            negative squared Euclidean distance is used.

    Returns:
        (means, bins): ``means`` is (num_clusters, d); ``bins`` holds the
        number of samples assigned to each centroid in the last assignment
        step (int64, length num_clusters).
    """
    dim, dtype = samples.shape[-1], samples.dtype

    def _assign(centroids):
        # Index of the most similar centroid for each sample.
        if use_cosine_sim:
            similarity = samples @ centroids.t()
        else:
            diffs = samples.unsqueeze(1) - centroids.unsqueeze(0)  # (n, c, d)
            similarity = -(diffs**2).sum(dim=-1)
        return similarity.max(dim=-1).indices

    means = sample_vectors(samples, num_clusters)
    if num_iters < 1:
        # Previously this fell through to an unbound ``bins``; report the
        # cluster sizes induced by the seed centroids instead.
        return means, torch.bincount(_assign(means), minlength=num_clusters)

    for _ in range(num_iters):
        buckets = _assign(means)
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        # Avoid division by zero for empty clusters (their result is discarded below).
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, buckets.unsqueeze(1).expand(-1, dim), samples)
        new_means = new_means / bins_min_clamped[..., None]
        if use_cosine_sim:
            new_means = l2norm(new_means)

        # Empty clusters keep their previous centroid.
        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins
class EmbeddingEMA(nn.Module):
    """Non-trainable codebook with exponential-moving-average update helpers.

    Holds ``num_tokens`` code vectors of width ``codebook_dim`` in ``weight``,
    plus per-code running counts (``cluster_size``) and running sums
    (``embed_avg``). All three are ``nn.Parameter`` objects created with
    ``requires_grad=False``, so they live in the state dict but receive no
    gradients. The ``initted`` buffer (1.0 / 0.0) records whether the codebook
    has been initialised or still awaits k-means seeding via ``init_embed_``.
    """

    def __init__(
        self,
        num_tokens,
        codebook_dim,
        decay=0.99,
        eps=1e-5,
        kmeans_init=True,
        codebook_init_path="",
    ):
        super().__init__()
        self.num_tokens = num_tokens
        self.codebook_dim = codebook_dim
        self.decay = decay
        self.eps = eps

        # Pick the starting codebook and whether k-means seeding is still pending.
        if codebook_init_path != "":
            print(f"load init codebook weight from {codebook_init_path}")
            # NOTE(review): depending on the torch version's weights_only
            # default, torch.load may unpickle arbitrary objects -- only point
            # this at trusted files.
            init_weight = torch.load(codebook_init_path, map_location="cpu").clone()
            already_initted = True
        elif kmeans_init:
            # Placeholder zeros; init_embed_ overwrites them from the first batch.
            init_weight = torch.zeros(num_tokens, codebook_dim)
            already_initted = False
        else:
            init_weight = l2norm(torch.randn(num_tokens, codebook_dim))
            already_initted = True
        self.register_buffer("initted", torch.Tensor([already_initted]))

        self.weight = nn.Parameter(init_weight, requires_grad=False)
        self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
        self.embed_avg = nn.Parameter(init_weight.clone(), requires_grad=False)
        # NormEMAVectorQuantizer.forward only applies its EMA codebook update
        # while this flag is True.
        self.update = True

    @torch.jit.ignore
    def init_embed_(self, data):
        """Seed the codebook from ``data`` with cosine k-means, at most once.

        A no-op when ``initted`` is already set. Otherwise copies the k-means
        centroids into ``weight``, the cluster counts into ``cluster_size``,
        and marks the codebook as initialised.
        """
        if not self.initted:
            print("Performing Kemans init for codebook")
            centroids, counts = kmeans(data, self.num_tokens, 10, use_cosine_sim=True)
            self.weight.data.copy_(centroids)
            self.cluster_size.data.copy_(counts)
            self.initted.data.copy_(torch.Tensor([True]))

    def forward(self, embed_id):
        """Look up the code vectors for the integer indices ``embed_id``."""
        return F.embedding(embed_id, self.weight)

    def cluster_size_ema_update(self, new_cluster_size):
        """In place: ``cluster_size <- decay * cluster_size + (1 - decay) * new``."""
        keep = self.decay
        self.cluster_size.data.mul_(keep).add_(new_cluster_size, alpha=1 - keep)

    def embed_avg_ema_update(self, new_embed_avg):
        """In place: ``embed_avg <- decay * embed_avg + (1 - decay) * new``."""
        keep = self.decay
        self.embed_avg.data.mul_(keep).add_(new_embed_avg, alpha=1 - keep)

    def weight_update(self, num_tokens):
        """Set ``weight`` to ``embed_avg`` divided by smoothed cluster sizes.

        Additive (Laplace-style) smoothing with ``eps`` keeps empty clusters
        from producing a division by zero while preserving the total count.
        """
        total = self.cluster_size.sum()
        smoothed = (self.cluster_size + self.eps) / (total + num_tokens * self.eps) * total
        self.weight.data.copy_(self.embed_avg / smoothed.unsqueeze(1))
def norm_ema_inplace(moving_avg, new, decay):
    """EMA-blend ``new`` into ``moving_avg`` in place, then L2-normalise it.

    Computes ``decay * moving_avg + (1 - decay) * new`` and rescales every
    vector along the last dimension to unit norm via ``l2norm``. Writes go
    through ``.data`` so autograd does not track them. Returns None.
    """
    storage = moving_avg.data
    storage.mul_(decay).add_(new, alpha=(1 - decay))
    # Project back onto the unit sphere so entries stay normalised.
    storage.copy_(l2norm(storage))
class NormEMAVectorQuantizer(nn.Module):
    """Vector quantizer over an L2-normalised codebook updated by EMA.

    Inputs are L2-normalised and matched to the nearest codebook entry. A
    straight-through estimator (``z + (z_q - z).detach()``) passes gradients
    back to the input. While training with ``self.embedding.update`` set,
    each used codebook entry is moved towards the normalised mean of its
    assigned inputs via ``norm_ema_inplace``.

    When ``statistic_code_usage`` is True, an EMA of per-code usage counts is
    kept in the ``cluster_size`` buffer. It is updated in both train and eval
    mode.
    """

    def __init__(
        self,
        n_embed,
        embedding_dim,
        beta,
        decay=0.99,
        eps=1e-5,
        statistic_code_usage=True,
        kmeans_init=False,
        codebook_init_path="",
    ):
        super().__init__()
        self.codebook_dim = embedding_dim
        self.num_tokens = n_embed
        self.beta = beta
        self.decay = decay
        self.embedding = EmbeddingEMA(
            self.num_tokens,
            self.codebook_dim,
            decay,
            eps,
            kmeans_init,
            codebook_init_path,
        )
        self.statistic_code_usage = statistic_code_usage
        if statistic_code_usage:
            # Running (EMA) count of how often each code is selected.
            self.register_buffer("cluster_size", torch.zeros(n_embed))
        if distributed.is_available() and distributed.is_initialized():
            print(
                "ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!"
            )
            self.all_reduce_fn = distributed.all_reduce
        else:
            # No-op stand-in so forward can call all_reduce_fn unconditionally.
            self.all_reduce_fn = nn.Identity()

    def reset_cluster_size(self, device):
        """Zero the code-usage statistics and move them to ``device``.

        Does nothing when ``statistic_code_usage`` is False.
        """
        if self.statistic_code_usage:
            self.register_buffer("cluster_size", torch.zeros(self.num_tokens))
            self.cluster_size = self.cluster_size.to(device)

    def forward(self, z):
        """Quantize ``z`` against the codebook.

        Args:
            z: input whose elements are reshaped to (-1, codebook_dim); it is
                presumably shaped (..., codebook_dim) -- confirm with callers.

        Returns:
            (z_q, loss, encoding_indices):
                z_q: quantized vectors shaped like ``z`` carrying the
                    straight-through gradient of the normalised ``z``.
                loss: ``beta * mse(z_q.detach(), z)``, the commitment term.
                encoding_indices: flat tensor with the chosen code per vector.
        """
        z = l2norm(z)
        z_flattened = z.reshape(-1, self.codebook_dim)
        # Runs k-means seeding once if the codebook has not been initialised.
        self.embedding.init_embed_(z_flattened)

        # Squared Euclidean distance between every input and every code: (N, num_tokens).
        d = (
            z_flattened.pow(2).sum(dim=1, keepdim=True)
            + self.embedding.weight.pow(2).sum(dim=1)
            - 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight)
        )
        encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(encoding_indices).view(z.shape)
        encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)

        # Usage statistics exist only when statistic_code_usage is enabled;
        # reading self.cluster_size otherwise raises AttributeError.
        if not self.training and self.statistic_code_usage:
            with torch.no_grad():
                cluster_size = encodings.sum(0)
                self.all_reduce_fn(cluster_size)
                ema_inplace(self.cluster_size, cluster_size, self.decay)

        if self.training and self.embedding.update:
            # Pure bookkeeping: no autograd graph is needed for the EMA update.
            with torch.no_grad():
                bins = encodings.sum(0)
                self.all_reduce_fn(bins)
                if self.statistic_code_usage:
                    ema_inplace(self.cluster_size, bins, self.decay)
                zero_mask = bins == 0
                bins = bins.masked_fill(zero_mask, 1.0)

                # Mean of the inputs assigned to each code, re-normalised.
                embed_sum = z_flattened.t() @ encodings
                self.all_reduce_fn(embed_sum)
                embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
                embed_normalized = l2norm(embed_normalized)
                # Codes that received no inputs keep their current vector.
                embed_normalized = torch.where(
                    zero_mask[..., None], self.embedding.weight, embed_normalized
                )
                norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay)

        # Commitment loss pulls inputs towards their (frozen) codes.
        loss = self.beta * F.mse_loss(z_q.detach(), z)
        # Straight-through estimator: forward uses z_q, backward flows into z.
        z_q = z + (z_q - z).detach()
        return z_q, loss, encoding_indices