|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" Flax Wav2Vec2 model.""" |
|
|
|
from functools import partial |
|
from typing import Optional, Tuple, Union |
|
|
|
import flax |
|
import flax.linen as nn |
|
import jax |
|
import jax.numpy as jnp |
|
from flax.core.frozen_dict import FrozenDict |
|
from flax.linen import partitioning as nn_partitioning |
|
from flax.linen.attention import dot_product_attention_weights |
|
from jax import lax |
|
|
|
from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput |
|
from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel |
|
from transformers.utils import ModelOutput |
|
|
|
from models import Wav2Vec2Config |
|
|
|
# Partitioning-aware lifted transforms used below: `scan_with_axes` folds the encoder
# layers into one scanned module with stacked parameters, `remat` wraps modules for
# gradient checkpointing when `config.gradient_checkpointing` is set.
scan_with_axes, remat = nn_partitioning.scan_with_axes, nn_partitioning.remat
|
|
|
|
|
@flax.struct.dataclass
class FlaxWav2Vec2BaseModelOutput(ModelOutput):
    """
    Output type of the base Wav2Vec2 module (`FlaxWav2Vec2Module`), with potential hidden states and attentions.

    Args:
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model. If an adapter is configured,
            this is the adapter output (whose sequence length and width may differ from the encoder's).
        extract_features (`jnp.ndarray` of shape `(batch_size, sequence_length, last_conv_dim)`):
            Sequence of extracted feature vectors of the last convolutional layer of the model with `last_conv_dim`
            being the dimension of the last convolutional layer. These are the layer-normed features computed
            inside the feature projection, before the dense projection to `hidden_size`.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Keep this field order aligned with the tuple FlaxWav2Vec2Module returns when
    # return_dict is falsy: (last_hidden_state, extract_features, *encoder extras).
    last_hidden_state: Optional[jnp.ndarray] = None
    extract_features: Optional[jnp.ndarray] = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None
|
|
|
|
|
# Model-level docstring template (presumably meant for `add_start_docstrings`; it is not
# referenced by any code visible in this module).
WAV_2_VEC_2_START_DOCSTRING = r"""
    Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
    Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
    Auli.

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a Flax Linen
    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
    a regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`Wav2Vec2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""
|
|
|
|
|
# Docstring template describing the arguments of `FlaxWav2Vec2PreTrainedModel.__call__`.
WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
    Args:
        input_values (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
            into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
            soundfile*). To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should be used for padding
            and conversion into a tensor of type *jnp.ndarray*. See [`Wav2Vec2Processor.__call__`] for details.
        attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
            1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            <Tip warning={true}>

            `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
            True`. For all models whose processor has `config.return_attention_mask == False`, such as
            [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), `attention_mask` should **not** be
            passed to avoid degraded performance when doing batched inference. For such models `input_values` should
            simply be padded with 0 and passed without `attention_mask`. Be aware that these models also yield slightly
            different results depending on whether `input_values` is padded or not.

            </Tip>

        mask_time_indices (`jnp.ndarray` of shape `(batch_size, feature_sequence_length)`, *optional*):
            Boolean mask over the feature frames produced by the convolutional feature encoder (not over raw
            samples). Masked frames are replaced by the learned `masked_spec_embed` vector before the transformer
            encoder.
        extract_features (`jnp.ndarray` of shape `(batch_size, feature_sequence_length, last_conv_dim)`, *optional*):
            Pre-computed output of the convolutional feature encoder. When given, the feature encoder is skipped and
            these features are used instead of being re-computed from `input_values`.
        params (`dict`, *optional*):
            Model parameters to use instead of the parameters stored on the model.
        dropout_rng (`jax.random.PRNGKey`, *optional*):
            PRNG key used for dropout. Needed when `train=True` and any dropout rate is non-zero.
        train (`bool`, *optional*, defaults to `False`):
            Whether to run the model in training mode (i.e. with dropout enabled).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        output_features (`bool`, *optional*):
            Whether to return only the output of the convolutional feature encoder instead of running the rest of the
            model.
        freeze_feature_encoder (`bool`, *optional*, defaults to `False`):
            Whether to stop gradients from flowing back into the convolutional feature encoder.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
|
|
|
|
|
class FlaxWav2Vec2LayerNormConvLayer(nn.Module):
    """One block of the convolutional feature encoder: Conv1D -> LayerNorm -> activation.

    Attributes:
        config: model configuration; reads `conv_dim`, `conv_kernel`, `conv_stride`, `conv_bias`,
            `layer_norm_eps` and `feat_extract_activation`.
        layer_id: index of this block within the feature encoder.
        dtype: computation dtype.
    """

    config: Wav2Vec2Config
    layer_id: int = 0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Channel bookkeeping. flax's nn.Conv infers its input channels from the data, so these
        # are informational. The first block consumes the single-channel waveform; every later
        # block consumes the output channels of the *previous* block (hence `layer_id - 1`).
        self.in_conv_dim = self.config.conv_dim[self.layer_id - 1] if self.layer_id > 0 else 1
        self.out_conv_dim = self.config.conv_dim[self.layer_id]

        self.conv = nn.Conv(
            features=self.out_conv_dim,
            kernel_size=(self.config.conv_kernel[self.layer_id],),
            strides=(self.config.conv_stride[self.layer_id],),
            use_bias=self.config.conv_bias,
            kernel_init=jax.nn.initializers.he_normal(),
            padding="VALID",
            dtype=self.dtype,
        )
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.activation = ACT2FN[self.config.feat_extract_activation]

    def __call__(self, hidden_states):
        """Apply conv, layer norm and activation to channel-last `hidden_states`."""
        hidden_states = self.conv(hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states
|
|
|
|
|
class FlaxConvWithWeightNorm(nn.Module):
    """Grouped 1D convolution with a weight-normalised kernel, `kernel = weight_g * weight_v / ||weight_v||`.

    The norm of `weight_v` is taken over its first two axes, giving one magnitude per kernel
    position. The inner `nn.Conv` only supplies the convolution itself; its kernel and bias are
    passed in explicitly on every call.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            features=self.config.hidden_size,
            kernel_size=(self.config.num_conv_pos_embeddings,),
            kernel_init=jax.nn.initializers.he_normal(),
            padding="VALID",
            feature_group_count=self.config.num_conv_pos_embedding_groups,
            dtype=self.dtype,
        )
        out_channels = self.conv.features
        width = self.conv.kernel_size[0]
        # Stored as (out_channels, in_channels_per_group, width); `.T` at call time turns this
        # into flax's (width, in_channels_per_group, out_channels) kernel layout.
        v_shape = (out_channels, out_channels // self.conv.feature_group_count, width)

        # Parameter creation order (v, g, bias) is kept fixed.
        self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(), v_shape)
        self.weight_g = self.param("weight_g", lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :])
        self.bias = self.param("bias", jax.nn.initializers.zeros, (out_channels,))
        # Half the kernel width of zero padding on each side of the time axis.
        self.prev_padding = width // 2

    def _get_normed_weights(self):
        """Return the weight-normalised kernel in (out_channels, in_channels_per_group, width) layout."""
        direction_norm = jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]
        return (self.weight_v / direction_norm) * self.weight_g

    def __call__(self, hidden_states):
        """Pad the time axis (axis 1) symmetrically and convolve with the normalised kernel."""
        pad = self.prev_padding
        padded = jnp.pad(hidden_states, ((0, 0), (pad, pad), (0, 0)))
        conv_params = {"kernel": self._get_normed_weights().T, "bias": self.bias}
        return self.conv.apply({"params": conv_params}, padded)
|
|
|
|
|
class FlaxWav2Vec2PositionalConvEmbedding(nn.Module):
    """Convolutional positional embedding: weight-normalised grouped conv followed by an activation."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = FlaxConvWithWeightNorm(self.config, dtype=self.dtype)
        self.activation = ACT2FN[self.config.feat_extract_activation]
        # FlaxConvWithWeightNorm pads kernel_width // 2 frames on both sides. For an even kernel
        # width that yields one extra output frame, which is trimmed from the end below.
        self.num_pad_remove = 1 if self.config.num_conv_pos_embeddings % 2 == 0 else 0

    def __call__(self, hidden_states):
        """Return positional embeddings with the same (batch, time, hidden) shape as the input.

        The input is already channel-last, which is the layout flax convolutions expect, so no
        transposition is needed (the previous identity `transpose((0, 1, 2))` calls were no-ops).
        """
        hidden_states = self.conv(hidden_states)
        if self.num_pad_remove > 0:
            hidden_states = hidden_states[:, : -self.num_pad_remove, :]
        return self.activation(hidden_states)
|
|
|
|
|
class FlaxConvLayersCollection(nn.Module):
    """Stack of `config.num_feat_extract_layers` convolutional feature-encoder blocks.

    Only `config.feat_extract_norm == "layer"` is implemented.

    Raises:
        NotImplementedError: if `config.feat_extract_norm == "group"`.
        ValueError: for any other unknown `config.feat_extract_norm`.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        norm = self.config.feat_extract_norm
        if norm == "layer":
            # Optionally wrap each block for gradient checkpointing.
            BlockLayer = (
                remat(FlaxWav2Vec2LayerNormConvLayer)
                if self.config.gradient_checkpointing
                else FlaxWav2Vec2LayerNormConvLayer
            )
            self.layers = [
                BlockLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype)
                for i in range(self.config.num_feat_extract_layers)
            ]
        elif norm == "group":
            raise NotImplementedError("At the moment only ``config.feat_extract_norm == 'layer'`` is supported")
        else:
            raise ValueError(
                f"`config.feat_extract_norm` is {norm}, but has to be one of ['group', 'layer']"
            )

    def __call__(self, hidden_states):
        """Run `hidden_states` through every conv block in order."""
        for conv_layer in self.layers:
            hidden_states = conv_layer(hidden_states)
        return hidden_states
|
|
|
|
|
class FlaxWav2Vec2FeatureEncoder(nn.Module):
    """Convolutional feature encoder: turns raw audio waveforms into frame-level feature vectors."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv_layers = FlaxConvLayersCollection(self.config, dtype=self.dtype)

    def __call__(self, input_values, freeze_feature_encoder=False):
        """Encode a batch of waveforms.

        Args:
            input_values: raw waveforms indexed as (batch, samples); a trailing channel axis of
                size 1 is added before the first convolution.
            freeze_feature_encoder: when True, gradients are stopped at the encoder output.

        Returns:
            The channel-last output of the last convolutional block.
        """
        features = self.conv_layers(input_values[:, :, None])
        return jax.lax.stop_gradient(features) if freeze_feature_encoder else features
|
|
|
|
|
class FlaxWav2Vec2FeatureProjection(nn.Module):
    """LayerNorm the conv features, project them to `hidden_size` and apply dropout."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        self.layer_norm = nn.LayerNorm(epsilon=cfg.layer_norm_eps, dtype=self.dtype)
        self.projection = nn.Dense(
            cfg.hidden_size,
            kernel_init=jax.nn.initializers.normal(cfg.initializer_range),
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=cfg.feat_proj_dropout)

    def __call__(self, hidden_states, deterministic=True):
        """Return `(projected, normalized)`.

        `normalized` is the layer-normed input (before projection). `projected` is its dense
        projection after dropout, which is disabled when `deterministic` is True.
        """
        normalized = self.layer_norm(hidden_states)
        projected = self.dropout(self.projection(normalized), deterministic=deterministic)
        return projected, normalized
|
|
|
|
|
class FlaxWav2Vec2Attention(nn.Module):
    """Multi-head self-attention used by the Wav2Vec2 encoder layers.

    Attributes:
        config: model configuration; reads `initializer_range` and `fuse_matmuls`.
        embed_dim: model width; must be divisible by `num_heads`.
        num_heads: number of attention heads.
        dropout: dropout rate applied to the attention probabilities.
        bias: whether the projections use a bias term.
        dtype: computation dtype.
    """

    config: Wav2Vec2Config
    embed_dim: int
    num_heads: int
    dropout: float = 0.0
    bias: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
            )

        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        # Separate Q/K/V projections, used when `config.fuse_matmuls` is False.
        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()

        # Single projection producing Q, K and V at once, used when `config.fuse_matmuls` is True.
        self.fused_proj = nn.Dense(
            self.embed_dim * 3,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        self.out_proj = dense()

        # Not used by __call__: attention dropout is applied inside dot_product_attention_weights.
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def _split_heads(self, hidden_states):
        """(batch, time, embed_dim) -> (batch, time, num_heads, head_dim)."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        """(batch, time, num_heads, head_dim) -> (batch, time, embed_dim)."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: inputs; queries, keys and values are all projected from them.
            key_value_states: accepted for interface compatibility; not used (self-attention only).
            attention_mask: optional (batch, time) mask; positions with a value > 0 may be attended to.
            deterministic: disables attention dropout when True.

        Returns:
            `(attn_output, attn_weights)`. `attn_output` has the input's (batch, time, embed_dim)
            shape. `attn_weights` is indexed (batch, heads, query, key).
        """
        if self.config.fuse_matmuls:
            attention_states = self.fused_proj(hidden_states)
            query_states, key_states, value_states = jnp.split(attention_states, 3, axis=-1)
        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        if attention_mask is not None:
            # (batch, kv_len) -> (batch, 1, 1, kv_len): broadcast over heads and query positions.
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            # Masked positions get the most negative finite value of the compute dtype rather than
            # -inf. A row that is entirely masked (e.g. an all-padding example) then softmaxes to a
            # finite uniform distribution instead of NaN (-inf - -inf), which would otherwise
            # propagate into the loss and gradients.
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
|
|
|
|
|
class FlaxWav2Vec2FeedForward(nn.Module):
    """Position-wise MLP: Dense(intermediate_size) -> act -> dropout -> Dense(hidden_size) -> dropout."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        init = jax.nn.initializers.normal(cfg.initializer_range)

        self.intermediate_dropout = nn.Dropout(rate=cfg.activation_dropout)
        self.intermediate_dense = nn.Dense(cfg.intermediate_size, kernel_init=init, dtype=self.dtype)
        # `hidden_act` may be either an ACT2FN key or a callable.
        act = cfg.hidden_act
        self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act

        self.output_dense = nn.Dense(cfg.hidden_size, kernel_init=init, dtype=self.dtype)
        self.output_dropout = nn.Dropout(rate=cfg.hidden_dropout)

    def __call__(self, hidden_states, deterministic=True):
        """Apply the MLP; dropout is disabled when `deterministic` is True."""
        x = self.intermediate_act_fn(self.intermediate_dense(hidden_states))
        x = self.intermediate_dropout(x, deterministic=deterministic)
        x = self.output_dense(x)
        return self.output_dropout(x, deterministic=deterministic)
|
|
|
|
|
class FlaxWav2Vec2EncoderLayerStableLayerNorm(nn.Module):
    """Pre-LayerNorm transformer layer: LN -> attention -> residual, then LN -> feed-forward -> residual."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        self.attention = FlaxWav2Vec2Attention(
            config=cfg,
            embed_dim=cfg.hidden_size,
            num_heads=cfg.num_attention_heads,
            dropout=cfg.attention_dropout,
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=cfg.hidden_dropout)
        self.layer_norm = nn.LayerNorm(epsilon=cfg.layer_norm_eps, dtype=self.dtype)
        self.feed_forward = FlaxWav2Vec2FeedForward(cfg, dtype=self.dtype)
        self.final_layer_norm = nn.LayerNorm(epsilon=cfg.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states, attention_mask=None, deterministic=True, output_attentions=False):
        """Run one encoder layer.

        When `config.use_scan` is True the layer is the body of a scan. `hidden_states` then
        arrives as a 1-tuple carry, and the result is returned as `(carry, None)`.

        Returns:
            `(hidden_states,)` or `(hidden_states, attn_weights)` if `output_attentions`, wrapped
            as `(outputs, None)` under scan.
        """
        use_scan = self.config.use_scan
        if use_scan:
            hidden_states = hidden_states[0]

        # Self-attention sub-block with residual connection.
        attn_out, attn_weights = self.attention(
            self.layer_norm(hidden_states), attention_mask=attention_mask, deterministic=deterministic
        )
        hidden_states = hidden_states + self.dropout(attn_out, deterministic=deterministic)

        # Feed-forward sub-block with residual connection.
        hidden_states = hidden_states + self.feed_forward(
            self.final_layer_norm(hidden_states), deterministic=deterministic
        )

        outputs = (hidden_states, attn_weights) if output_attentions else (hidden_states,)
        return (outputs, None) if use_scan else outputs
|
|
|
|
|
class FlaxWav2Vec2EncoderLayerStableLayerNormCollection(nn.Module):
    """Stack of `config.num_hidden_layers` pre-LayerNorm encoder layers.

    With `config.use_scan` the layers are folded into one `scan_with_axes` over stacked
    parameters, which cannot return per-layer hidden states or attentions. Otherwise they are
    unrolled in a Python loop. With `config.gradient_checkpointing` each layer is wrapped in
    `remat`.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        cfg = self.config
        n_layers = cfg.num_hidden_layers

        if cfg.gradient_checkpointing:
            # static_argnums=(2, 3) presumably marks `deterministic` and `output_attentions`
            # (Python bools) as static -- NOTE(review): confirm against the remat argnum convention.
            layer_cls = remat(
                FlaxWav2Vec2EncoderLayerStableLayerNorm,
                static_argnums=(2, 3),
                prevent_cse=not cfg.use_scan,
            )
        else:
            layer_cls = FlaxWav2Vec2EncoderLayerStableLayerNorm

        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        if cfg.use_scan:
            assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`"
            assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`"
            # Params get a leading layer axis; mask and flags are broadcast to every step.
            scanned_layers = scan_with_axes(
                layer_cls,
                variable_axes={"params": 0, "cache": 0},
                split_rngs={"params": True, "dropout": True},
                in_axes=(nn.broadcast, nn.broadcast, nn.broadcast),
                length=n_layers,
            )(cfg, dtype=self.dtype, name="FlaxWav2Vec2EncoderLayers")
            (hidden_states,), _ = scanned_layers(
                (hidden_states,), attention_mask, deterministic, output_attentions
            )
        else:
            for idx in range(n_layers):
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)
                layer_outputs = layer_cls(cfg, dtype=self.dtype, name=str(idx))(
                    hidden_states, attention_mask, deterministic, output_attentions
                )
                hidden_states = layer_outputs[0]
                if output_attentions:
                    all_attentions += (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                item for item in (hidden_states, all_hidden_states, all_attentions) if item is not None
            )

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )
|
|
|
|
|
class FlaxWav2Vec2StableLayerNormEncoder(nn.Module):
    """Transformer encoder with pre-LayerNorm ("stable layer norm") layers.

    Adds convolutional positional embeddings, applies dropout, runs the encoder layers and
    finishes with a LayerNorm.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.pos_conv_embed = FlaxWav2Vec2PositionalConvEmbedding(self.config, dtype=self.dtype)
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
        self.layers = FlaxWav2Vec2EncoderLayerStableLayerNormCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic=True,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """Encode projected features.

        Args:
            hidden_states: projected features, presumably (batch, frames, hidden_size).
            attention_mask: optional (batch, frames) mask; frames where it is falsy are padding.
            deterministic: when True, dropout is disabled here and in every encoder layer.
            output_attentions: whether to also return per-layer attention weights.
            output_hidden_states: whether to also return per-layer hidden states; the last entry
                is replaced by the final layer-normed output.
            return_dict: whether to return a `FlaxBaseModelOutput` instead of a tuple.
        """
        if attention_mask is not None:
            # Zero padded frames so they do not leak into the positional convolution.
            hidden_states = jnp.where(
                jnp.broadcast_to(attention_mask[:, :, None], hidden_states.shape), hidden_states, 0
            )

        position_embeddings = self.pos_conv_embed(hidden_states)
        hidden_states = hidden_states + position_embeddings
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)

        # `deterministic` must be forwarded. Otherwise the layers use their default
        # (deterministic=True), and attention, hidden and activation dropout inside the
        # transformer layers are never applied during training.
        outputs = self.layers(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = self.layer_norm(outputs[0])

        # Replace the last collected hidden state with its layer-normed version.
        hidden_states = None
        if output_hidden_states:
            hidden_states = outputs[1]
            hidden_states = hidden_states[:-1] + (last_hidden_state,)

        if not return_dict:
            outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=last_hidden_state, hidden_states=hidden_states, attentions=outputs.attentions
        )
|
|
|
|
|
class FlaxWav2Vec2Adapter(nn.Module):
    """Adapter on top of the encoder: optional width projection, then a stack of strided conv + GLU layers.

    The Dense + LayerNorm projection is only created when `config.output_hidden_size` differs
    from `config.hidden_size`.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        if cfg.output_hidden_size == cfg.hidden_size:
            self.proj = None
            self.proj_layer_norm = None
        else:
            self.proj = nn.Dense(
                cfg.output_hidden_size,
                kernel_init=jax.nn.initializers.normal(cfg.initializer_range),
                dtype=self.dtype,
            )
            self.proj_layer_norm = nn.LayerNorm(epsilon=cfg.layer_norm_eps, dtype=self.dtype)

        self.layers = FlaxWav2Vec2AdapterLayersCollection(cfg, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        """Apply the adapter. `deterministic` is accepted but unused (no dropout here)."""
        has_projection = self.proj is not None and self.proj_layer_norm is not None
        if has_projection:
            hidden_states = self.proj_layer_norm(self.proj(hidden_states))
        return self.layers(hidden_states)
|
|
|
|
|
class FlaxWav2Vec2AdapterLayer(nn.Module):
    """Strided 1D conv to `2 * output_hidden_size` channels (one frame of padding per side) + GLU.

    The GLU halves the channel count back to `output_hidden_size`.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        self.conv = nn.Conv(
            features=2 * cfg.output_hidden_size,
            kernel_size=(cfg.adapter_kernel_size,),
            strides=(cfg.adapter_stride,),
            padding=((1, 1),),
            kernel_init=jax.nn.initializers.normal(cfg.initializer_range),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        """Convolve, then gate over axis 2 (the channel axis of a (batch, time, channels) input)."""
        return nn.glu(self.conv(hidden_states), axis=2)
|
|
|
|
|
class FlaxWav2Vec2AdapterLayersCollection(nn.Module):
    """Sequential stack of `config.num_adapter_layers` adapter layers (remat-wrapped if checkpointing)."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        if self.config.gradient_checkpointing:
            layer_cls = remat(FlaxWav2Vec2AdapterLayer)
        else:
            layer_cls = FlaxWav2Vec2AdapterLayer
        self.layers = [
            layer_cls(self.config, name=str(idx), dtype=self.dtype)
            for idx in range(self.config.num_adapter_layers)
        ]

    def __call__(self, hidden_states):
        """Feed `hidden_states` through each adapter layer in order."""
        out = hidden_states
        for adapter_layer in self.layers:
            out = adapter_layer(out)
        return out
|
|
|
|
|
class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.

    Subclasses set `module_class`. `__call__` casts the inputs, fills in defaults from the config
    and applies the wrapped Flax module with the stored (or supplied) parameters.
    """

    config_class = Wav2Vec2Config
    base_model_prefix: str = "wav2vec2"
    main_input_name = "input_values"
    module_class: nn.Module = None

    def __init__(
        self,
        config: Wav2Vec2Config,
        input_shape: Tuple = (1, 1024),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        wrapped_module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, wrapped_module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        """Initialise parameters by tracing the module on a dummy all-zeros batch of `input_shape`."""
        # NOTE(review): the dummy waveform is int32 while real inputs are cast to float32 in
        # __call__; presumably only the shapes matter for initialisation.
        dummy_values = jnp.zeros(input_shape, dtype="i4")
        dummy_mask = jnp.ones_like(dummy_values)
        params_rng, dropout_rng = jax.random.split(rng, 2)
        variables = self.module.init(
            {"params": params_rng, "dropout": dropout_rng}, dummy_values, dummy_mask, return_dict=False
        )
        return variables["params"]

    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        extract_features=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_features: Optional[bool] = None,
        freeze_feature_encoder: bool = False,
        return_dict: Optional[bool] = None,
    ):
        """Run the wrapped module. `train=True` disables deterministic mode (dropout on)."""
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.return_dict

        if attention_mask is None:
            # Without a mask every sample is treated as valid.
            batch_size, sequence_length = input_values.shape
            attention_mask = jnp.ones((batch_size, sequence_length))

        if extract_features is not None:
            extract_features = jnp.array(extract_features, dtype="f4")

        rngs = {} if dropout_rng is None else {"dropout": dropout_rng}
        variables = {"params": params or self.params}

        # Positional order must match the module's __call__ signature.
        return self.module.apply(
            variables,
            jnp.array(input_values, dtype="f4"),
            jnp.array(attention_mask, dtype="i4"),
            mask_time_indices,
            extract_features,
            not train,
            output_attentions,
            output_hidden_states,
            output_features,
            freeze_feature_encoder,
            return_dict,
            rngs=rngs,
        )

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        """Delegate to the wrapped module's output-length computation."""
        compute = self.module._get_feat_extract_output_lengths
        return compute(input_lengths, add_adapter=add_adapter)

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
    ):
        """Delegate to the wrapped module's frame-level attention-mask computation."""
        compute = self.module._get_feature_vector_attention_mask
        return compute(feature_vector_length, attention_mask, add_adapter=add_adapter)
|
|
|
|
|
class FlaxWav2Vec2Module(nn.Module):
    """Core Wav2Vec2 network.

    Pipeline: conv feature encoder -> feature projection -> optional time masking ->
    stable-layer-norm transformer encoder -> optional adapter. Only
    `config.do_stable_layer_norm=True` is supported.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.feature_extractor = FlaxWav2Vec2FeatureEncoder(self.config, dtype=self.dtype)
        self.feature_projection = FlaxWav2Vec2FeatureProjection(self.config, dtype=self.dtype)
        # Learned vector substituted at the frames selected by `mask_time_indices`.
        self.masked_spec_embed = self.param(
            "masked_spec_embed", jax.nn.initializers.uniform(), (self.config.hidden_size,)
        )

        if self.config.do_stable_layer_norm:
            self.encoder = FlaxWav2Vec2StableLayerNormEncoder(self.config, dtype=self.dtype)
        else:
            raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")

        self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None

    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        extract_features=None,
        deterministic=True,
        output_attentions=None,
        output_hidden_states=None,
        output_features=False,
        freeze_feature_encoder=False,
        return_dict=None,
    ):
        """Run the model.

        Returns:
            If `output_features` is True, only the feature-encoder output. Otherwise a
            `FlaxWav2Vec2BaseModelOutput`, or the tuple
            `(last_hidden_state, extract_features, *encoder extras)` when `return_dict` is falsy.
        """
        # Pre-computed conv features may be supplied to skip the feature encoder.
        if extract_features is None:
            extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder)

        if output_features:
            return extract_features

        if attention_mask is not None:
            # Downsample the sample-level mask to the conv frame rate. The adapter runs after the
            # encoder, so it is not accounted for here.
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
        if mask_time_indices is not None:
            # Replace masked frames with the learned mask embedding.
            hidden_states = jnp.where(
                jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape),
                jnp.broadcast_to(self.masked_spec_embed[None, None, :], hidden_states.shape),
                hidden_states,
            )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, extract_features) + encoder_outputs[1:]

        return FlaxWav2Vec2BaseModelOutput(
            last_hidden_state=hidden_states,
            extract_features=extract_features,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers

        Args:
            input_lengths: number of valid input samples (int or integer array).
            add_adapter: whether to also apply the adapter's strided convolutions; defaults to
                `config.add_adapter`.

        Returns:
            The number of output frames, with the same type/shape as `input_lengths`.
        """
        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride, padding=0):
            # Output length of a 1D convolution with `padding` frames on each side.
            return (input_length + 2 * padding - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        if add_adapter:
            for _ in range(self.config.num_adapter_layers):
                # FlaxWav2Vec2AdapterLayer convolves with `adapter_kernel_size` and one frame of
                # padding per side. Using the configured kernel keeps this correct for any kernel
                # size; for a kernel of 3 it equals (L - 1) // stride + 1.
                input_lengths = _conv_out_length(
                    input_lengths, self.config.adapter_kernel_size, self.config.adapter_stride, padding=1
                )

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
    ):
        """Convert a (batch, samples) waveform mask into a boolean (batch, feature_vector_length) frame mask.

        For each example, frames `0 .. output_length - 1` are True. `output_length` is the conv
        output length for that example's count of non-padded samples (assumes a 0/1 mask).
        """
        # Effectively attention_mask.sum(-1).
        non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]

        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)

        batch_size = attention_mask.shape[0]

        attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
        # Mark the last valid frame, then a reversed cumulative sum makes every frame up to and
        # including it non-zero.
        # NOTE(review): an example with output_length 0 indexes -1, which wraps to the last frame
        # and marks the whole row valid.
        attention_mask = attention_mask.at[jnp.arange(batch_size), output_lengths - 1].set(1)
        attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
        return attention_mask
|
|
|
|
|
class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
    """Bare Wav2Vec2 model (feature encoder, transformer encoder, optional adapter) with no task head."""

    module_class = FlaxWav2Vec2Module
|
|
|
|
|
class FlaxWav2Vec2ForCTCModule(nn.Module):
    """Wav2Vec2 with a dropout + linear language-modelling head producing per-frame CTC logits."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.final_dropout)
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        extract_features=None,
        deterministic=True,
        output_attentions=None,
        output_hidden_states=None,
        output_features=False,
        freeze_feature_encoder=False,
        return_dict=None,
    ):
        """Compute CTC logits.

        `extract_features` (pre-computed conv features) and `output_features` are forwarded to the
        base model. Previously both were accepted but silently ignored.

        Returns:
            If `output_features` is True, only the feature-encoder output. Otherwise a
            `FlaxCausalLMOutput`, or `(logits, *encoder extras)` when `return_dict` is falsy.
        """
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            mask_time_indices=mask_time_indices,
            extract_features=extract_features,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_features=output_features,
            freeze_feature_encoder=freeze_feature_encoder,
            return_dict=return_dict,
        )

        if output_features:
            # The base module short-circuits and returns only the conv features.
            return outputs

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)

        logits = self.lm_head(hidden_states)

        if not return_dict:
            # outputs = (last_hidden_state, extract_features, *extras); keep only the extras.
            return (logits,) + outputs[2:]

        return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers

        Args:
            input_lengths: number of valid input samples (int or integer array).
            add_adapter: whether to also apply the adapter's strided convolutions; defaults to
                `config.add_adapter`.

        Returns:
            The number of output frames, with the same type/shape as `input_lengths`.
        """
        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride, padding=0):
            # Output length of a 1D convolution with `padding` frames on each side.
            return (input_length + 2 * padding - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        if add_adapter:
            for _ in range(self.config.num_adapter_layers):
                # FlaxWav2Vec2AdapterLayer convolves with `adapter_kernel_size` and one frame of
                # padding per side. For a kernel of 3 this equals (L - 1) // stride + 1.
                input_lengths = _conv_out_length(
                    input_lengths, self.config.adapter_kernel_size, self.config.adapter_stride, padding=1
                )

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
    ):
        """Convert a (batch, samples) waveform mask into a boolean (batch, feature_vector_length) frame mask.

        For each example, frames `0 .. output_length - 1` are True (assumes a 0/1 input mask).
        """
        # Effectively attention_mask.sum(-1).
        non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]

        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)

        batch_size = attention_mask.shape[0]

        attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
        # Mark the last valid frame, then a reversed cumulative sum makes every earlier frame non-zero.
        # NOTE(review): output_length 0 indexes -1, which wraps and marks the whole row valid.
        attention_mask = attention_mask.at[jnp.arange(batch_size), output_lengths - 1].set(1)
        attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
        return attention_mask
|
|
|
|
|
class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel):
    """Wav2Vec2 with a CTC language-modelling head on top (see `FlaxWav2Vec2ForCTCModule`)."""

    module_class = FlaxWav2Vec2ForCTCModule
|
|