# NOTE: residue from the hosting site's file viewer, kept only as provenance:
# uploader "alibabasglab", commit "Upload 161 files" (8e8cd3e, verified), 19.2 kB.
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding
from models.mossformer_gan_se.conv_module import ConvModule
# Helper functions
def exists(val):
    """
    Tell whether a value has been supplied.

    Args:
        val: Any object.

    Returns:
        bool: False only when ``val`` is None; True for every other object,
        including falsy ones such as 0 or an empty string.
    """
    if val is None:
        return False
    return True
def default(val, d):
    """
    Substitute a fallback for a missing value.

    Args:
        val: The candidate value.
        d: The fallback used when ``val`` is None.

    Returns:
        ``val`` unless it is None, in which case ``d``.
    """
    # Only None triggers the fallback; falsy values such as 0 or False are kept.
    return d if val is None else val
def padding_to_multiple_of(n, mult):
    """
    Compute how much must be added to ``n`` to reach the next multiple of ``mult``.

    Args:
        n (int): The length to be padded.
        mult (int): The multiple that the padded length must satisfy.

    Returns:
        int: 0 when ``n`` is already a multiple of ``mult``, otherwise the
        gap up to the next multiple.
    """
    _, leftover = divmod(n, mult)
    return mult - leftover if leftover else 0
# ScaleNorm
class ScaleNorm(nn.Module):
    """
    Divide the input by its (dimension-scaled) L2 norm and apply one learnable gain.

    Args:
        dim (int): Size of the last input dimension; sets the fixed factor dim ** -0.5.
        eps (float): Lower bound on the divisor to avoid division by zero (default: 1e-5).
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.scale = dim ** -0.5  # Fixed factor applied to the norm
        self.g = nn.Parameter(torch.ones(1))  # Single learnable gain

    def forward(self, x):
        # L2 norm over the last axis, multiplied by the fixed factor
        scaled_norm = x.norm(dim=-1, keepdim=True) * self.scale
        # Floor the divisor at eps so all-zero rows stay finite
        divisor = scaled_norm.clamp(min=self.eps)
        return x / divisor * self.g
# Absolute positional encodings
class ScaledSinuEmbedding(nn.Module):
    """
    Absolute sinusoidal positional embeddings multiplied by a learnable scalar.

    Args:
        dim (int): The dimension of the positional embedding.
    """

    def __init__(self, dim):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1,))
        # Inverse frequencies 1 / 10000^(2k / dim), one per sin/cos pair
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1. / (10000 ** exponents))

    def forward(self, x):
        # Only the length of axis 1 and the device of x are used
        seq_len = x.shape[1]
        positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        # Outer product position x frequency, written as a broadcasted multiply
        angles = positions[:, None] * self.inv_freq[None, :]
        table = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return table * self.scale
# T5 relative positional bias
class T5RelativePositionBias(nn.Module):
    """
    Bucketed relative-position bias in the style of T5.

    Args:
        scale (float): Factor multiplied onto the looked-up bias.
        causal (bool): If True, positive key offsets are collapsed to bucket 0 (default: False).
        num_buckets (int): Number of relative position buckets (default: 32).
        max_distance (int): Distance at which the logarithmic buckets saturate (default: 128).
    """

    def __init__(self, scale, causal=False, num_buckets=32, max_distance=128):
        super().__init__()
        self.eps = 1e-5
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        # One scalar bias per bucket
        self.relative_attention_bias = nn.Embedding(num_buckets, 1)

    @staticmethod
    def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
        """
        Map signed relative positions onto bucket indices.

        Small distances get one bucket each; larger ones share logarithmically
        spaced buckets, capped at the last bucket.

        Args:
            relative_position (Tensor): Integer tensor of key-minus-query offsets.
            causal (bool): If True, offsets of keys after the query are treated as distance 0.
            num_buckets (int): Total number of buckets.
            max_distance (int): Distance used as the base of the logarithmic range.

        Returns:
            Tensor: Bucket indices with the same shape as ``relative_position``.
        """
        distance = -relative_position
        if causal:
            offset = 0
            distance = distance.clamp(min=0)
        else:
            # Half of the buckets serve each sign of the offset
            num_buckets //= 2
            offset = (distance < 0).long() * num_buckets
            distance = distance.abs()

        exact_limit = num_buckets // 2
        log_ratio = torch.log(distance.float() / exact_limit) / math.log(max_distance / exact_limit)
        log_bucket = exact_limit + (log_ratio * (num_buckets - exact_limit)).long()
        log_bucket = log_bucket.clamp(max=num_buckets - 1)

        # Entries below exact_limit use the distance itself as bucket index
        return offset + torch.where(distance < exact_limit, distance, log_bucket)

    def forward(self, x):
        # Only the last two dimensions of x (queries, keys) and its device are used
        num_q, num_k = x.shape[-2:]
        device = x.device
        q_idx = torch.arange(num_q, dtype=torch.long, device=device)
        k_idx = torch.arange(num_k, dtype=torch.long, device=device)
        offsets = k_idx[None, :] - q_idx[:, None]
        buckets = self._relative_position_bucket(
            offsets,
            causal=self.causal,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
        # Embedding yields (i, j, 1); drop the trailing singleton axis
        bias = self.relative_attention_bias(buckets).squeeze(-1)
        return bias * self.scale
# Relative Position Embeddings
class RelativePosition(nn.Module):
    """
    Learnable embedding table indexed by clipped relative distance.

    Args:
        num_units (int): Width of each embedding vector (default: 32).
        max_relative_position (int): Distances are clipped to +/- this value (default: 128).
    """

    def __init__(self, num_units=32, max_relative_position=128):
        super().__init__()
        self.num_units = num_units
        self.max_relative_position = max_relative_position
        # One row per distance in [-max_relative_position, +max_relative_position]
        table = torch.Tensor(max_relative_position * 2 + 1, num_units)
        self.embeddings_table = nn.Parameter(table)
        nn.init.xavier_uniform_(self.embeddings_table)

    def forward(self, x):
        # Only the last two dimensions of x (queries, keys) and its device are used
        length_q, length_k = x.shape[-2:]
        device = x.device
        q_idx = torch.arange(length_q, dtype=torch.long, device=device)
        k_idx = torch.arange(length_k, dtype=torch.long, device=device)
        limit = self.max_relative_position
        # Key-minus-query distance, clipped and shifted to a non-negative row index
        rows = (k_idx[None, :] - q_idx[:, None]).clamp(-limit, limit) + limit
        return self.embeddings_table[rows]
# Offset and Scale module
class OffsetScale(nn.Module):
    """
    Produce ``heads`` affine variants of the input, one per (gamma, beta) row.

    Args:
        dim (int): Size of the last input dimension.
        heads (int): Number of output variants (default: 1).
    """

    def __init__(self, dim, heads=1):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(heads, dim))  # Per-head scale
        self.beta = nn.Parameter(torch.zeros(heads, dim))  # Per-head offset
        # The ones above are overwritten with small zero-mean random values
        nn.init.normal_(self.gamma, std=0.02)

    def forward(self, x):
        # Insert a head axis before the feature axis and broadcast against (heads, dim)
        per_head = x[..., None, :] * self.gamma + self.beta
        # Tuple with one tensor per head, each shaped like x
        return per_head.unbind(dim=-2)
class FFConvM(nn.Module):
    """
    Feed-forward projection followed by a convolution module.

    The input passes, in order, through a normalization layer, a linear
    projection, a SiLU activation, a ``ConvModule`` and dropout.

    Args:
        dim_in: Input feature dimension.
        dim_out: Output feature dimension.
        norm_klass: Normalization class, called as ``norm_klass(dim_in)`` (default: LayerNorm).
        dropout: Dropout probability (default: 0.1).
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        norm_klass=nn.LayerNorm,
        dropout=0.1
    ):
        super().__init__()
        # Layers are created in application order
        stages = [
            norm_klass(dim_in),          # Normalize the input features
            nn.Linear(dim_in, dim_out),  # Project dim_in -> dim_out
            nn.SiLU(),                   # Non-linearity
            ConvModule(dim_out),         # Convolution module on the projected features
            nn.Dropout(dropout),         # Regularization
        ]
        self.mdl = nn.Sequential(*stages)

    def forward(self, x):
        """
        Run the input through the stacked layers.

        Args:
            x: Input tensor whose last dimension is dim_in.

        Returns:
            Tensor produced by the sequential stack, with last dimension dim_out.
        """
        return self.mdl(x)
class MossFormer(nn.Module):
    """
    Gated attention block operating on a 4-D input of shape (B, T, Q, C).

    The input is flattened to B*T independent slices. For every slice the
    block combines three attention terms:

    * quadratic (ReLU-squared) attention over the Q axis,
    * quadratic attention across the T axis, with each position barred from
      attending to itself (diagonal masked),
    * linear attention over the Q axis.

    The summed attention outputs gate the hidden projections ``v`` and ``u``,
    and the result is projected back to ``dim`` and added to the input.

    Args:
        dim (int): Dimensionality of input features.
        group_size (int): Stored as ``self.group_size`` for interface
            compatibility. The attention treats the whole Q axis as a single
            group, so this value does not affect the computation.
        query_key_dim (int): Dimensionality of the query and key vectors.
        expansion_factor (float): ``int(dim * expansion_factor)`` is the width
            of the hidden projection, which is split in two halves (v and u).
        causal (bool): Whether to apply causal masking.
        dropout (float): Dropout rate.
        norm_klass (nn.Module): Normalization class used by the FFConvM layers.
        shift_tokens (bool): Whether to shift half of the channels by one
            position along the Q axis before the projections.
    """

    def __init__(
        self,
        dim,
        group_size=256,
        query_key_dim=128,
        expansion_factor=4.,
        causal=False,
        dropout=0.1,
        norm_klass=nn.LayerNorm,
        shift_tokens=True
    ):
        super().__init__()
        hidden_dim = int(dim * expansion_factor)
        self.group_size = group_size
        self.causal = causal
        self.shift_tokens = shift_tokens

        # Rotary positional embedding applied to queries and keys.
        self.rotary_pos_emb = RotaryEmbedding(dim=min(32, query_key_dim))

        # Dropout applied to the attention weights.
        self.dropout = nn.Dropout(dropout)

        # Projection of the input to the hidden space (later split into v and u).
        self.to_hidden = FFConvM(
            dim_in=dim,
            dim_out=hidden_dim,
            norm_klass=norm_klass,
            dropout=dropout,
        )

        # Shared query/key projection; OffsetScale derives the four q/k variants.
        self.to_qk = FFConvM(
            dim_in=dim,
            dim_out=query_key_dim,
            norm_klass=norm_klass,
            dropout=dropout,
        )
        self.qk_offset_scale = OffsetScale(query_key_dim, heads=4)

        # Output projection. Its input is the gated tensor, whose width is one
        # half of hidden_dim (v and u are the two halves of the hidden
        # projection), so derive it from hidden_dim directly. This equals the
        # previous dim * int(expansion_factor // 2) for even integer factors
        # and also works for factors such as 3.
        self.to_out = FFConvM(
            dim_in=hidden_dim // 2,
            dim_out=dim,
            norm_klass=norm_klass,
            dropout=dropout,
        )

        self.gateActivate = nn.Sigmoid()

    def forward(
        self,
        x,
        *,
        mask=None
    ):
        """
        Forward pass for the MossFormer module.

        Args:
            x (Tensor): Input tensor of shape (B, T, Q, C). Attention inside a
                slice runs over Q; the cross-slice attention runs over T.
            mask (Tensor, optional): Accepted for interface compatibility.
                NOTE(review): the mask is not forwarded to ``cal_attention``,
                so it currently has no effect on the output.

        Returns:
            Tensor: Output tensor of shape (B * T, Q, C) (the first two input
            axes stay flattened).
        """
        B, T, Q, C = x.size()
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = x.reshape(B * T, Q, C)

        normed_x = x
        if self.shift_tokens:
            # Shift the first half of the channels one step along Q (zero-filled
            # at the start, last step dropped); the second half is left untouched.
            x_shift, x_pass = normed_x.chunk(2, dim=-1)
            x_shift = F.pad(x_shift, (0, 0, 1, -1), value=0.)
            normed_x = torch.cat((x_shift, x_pass), dim=-1)

        # Hidden projection split into the two gated branches.
        v, u = self.to_hidden(normed_x).chunk(2, dim=-1)
        qk = self.to_qk(normed_x)

        # Four affine variants of the shared q/k projection.
        quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
        att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v, u, B)

        # Cross gating of the two branches, then residual connection.
        out = (att_u * v) * self.gateActivate(att_v * u)
        return x + self.to_out(out)

    def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, B, mask=None):
        """
        Compute the summed quadratic and linear attention outputs for v and u.

        Args:
            x (Tensor): Flattened input of shape (B * T, n, d); only its shape
                and device are used.
            quad_q (Tensor): Queries for the quadratic attention.
            lin_q (Tensor): Queries for the linear attention.
            quad_k (Tensor): Keys for the quadratic attention.
            lin_k (Tensor): Keys for the linear attention.
            v (Tensor): First value branch, shape (B * T, n, dv).
            u (Tensor): Second value branch, shape (B * T, n, dv).
            B (int): Original batch size, used to recover the T axis.
            mask (Tensor, optional): Boolean mask of shape (B * T, n); False
                entries are removed from the linear keys and from the
                within-slice quadratic attention.

        Returns:
            Tuple[Tensor, Tensor]: ``(out_v, out_u)``, each of shape
            (B * T, n, dv) and each the sum of the quadratic (within-slice plus
            cross-slice) and linear attention outputs for that branch.
        """
        b, n, device = x.shape[0], x.shape[-2], x.device

        if exists(mask):
            # Zero the linear keys at masked positions.
            lin_mask = rearrange(mask, '... -> ... 1')
            lin_k = lin_k.masked_fill(~lin_mask, 0.)

        # Rotate queries and keys using positional embeddings.
        if exists(self.rotary_pos_emb):
            quad_q, lin_q, quad_k, lin_k = map(
                self.rotary_pos_emb.rotate_queries_or_keys,
                (quad_q, lin_q, quad_k, lin_k),
            )

        # The group length is the full sequence length n, so this padding is
        # zero for any non-empty sequence and the tensors form a single group.
        padding = padding_to_multiple_of(n, n)
        if padding > 0:
            quad_q, quad_k, lin_q, lin_k, v, u = map(
                lambda t: F.pad(t, (0, 0, 0, padding), value=0.),
                (quad_q, quad_k, lin_q, lin_k, v, u),
            )
            mask = default(mask, torch.ones((b, n), device=device, dtype=torch.bool))
            mask = F.pad(mask, (0, padding), value=False)

        # (b, g*n, d) -> (b, g, n, d)
        quad_q, quad_k, lin_q, lin_k, v, u = map(
            lambda t: rearrange(t, 'b (g n) d -> b g n d', n=n),
            (quad_q, quad_k, lin_q, lin_k, v, u),
        )

        BT, K, Q, _ = quad_q.size()

        def to_cross(t):
            # (B*T, K, Q, d) -> (B, Q, T*K, d). Each tensor keeps its OWN
            # channel width, so v/u need not match the query/key width.
            return t.reshape(B, -1, Q, t.shape[-1]).transpose(2, 1)

        def from_cross(t):
            # Inverse of to_cross: (B, Q, T*K, d) -> (B*T, K, Q, d).
            return t.transpose(2, 1).reshape(BT, K, Q, t.shape[-1])

        quad_q_c = to_cross(quad_q)
        quad_k_c = to_cross(quad_k)
        v_c = to_cross(v)
        u_c = to_cross(u)

        if exists(mask):
            mask = rearrange(mask, 'b (g j) -> b g 1 j', j=n)

        # Similarities, normalized by the length of the attended axis.
        sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / n
        sim_c = einsum('... i d, ... j d -> ... i j', quad_q_c, quad_k_c) / quad_q_c.shape[-2]

        # ReLU-squared attention weights (no softmax), followed by dropout.
        attn = F.relu(sim) ** 2
        attn = self.dropout(attn)
        attn_c = F.relu(sim_c) ** 2
        attn_c = self.dropout(attn_c)

        # In the cross-slice attention a position must not attend to itself.
        mask_c = torch.eye(quad_q_c.shape[-2], dtype=torch.bool, device=device)
        attn_c = attn_c.masked_fill(mask_c, 0.)

        if exists(mask):
            attn = attn.masked_fill(~mask, 0.)

        if self.causal:
            # attn is (..., n, n), so the causal mask must be sized by n
            # (sizing it by group_size fails whenever n != group_size).
            causal_mask = torch.ones((n, n), dtype=torch.bool, device=device).triu(1)
            attn = attn.masked_fill(causal_mask, 0.)

        # Within-slice quadratic attention outputs.
        quad_out_v = einsum('... i j, ... j d -> ... i d', attn, v)
        quad_out_u = einsum('... i j, ... j d -> ... i d', attn, u)

        # Cross-slice quadratic attention outputs, mapped back to slice layout.
        quad_out_v_c = from_cross(einsum('... i j, ... j d -> ... i d', attn_c, v_c))
        quad_out_u_c = from_cross(einsum('... i j, ... j d -> ... i d', attn_c, u_c))

        quad_out_v = quad_out_v + quad_out_v_c
        quad_out_u = quad_out_u + quad_out_u_c

        # Linear attention.
        if self.causal:
            # Cumulative sum over groups, shifted by one group so that each
            # group only sees the groups before it (exclusive cumulative sum).
            lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / n
            lin_kv = lin_kv.cumsum(dim=1)
            lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value=0.)
            lin_out_v = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)

            lin_ku = einsum('b g n d, b g n e -> b g d e', lin_k, u) / n
            lin_ku = lin_ku.cumsum(dim=1)
            lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value=0.)
            lin_out_u = einsum('b g d e, b g n d -> b g n e', lin_ku, lin_q)
        else:
            # Global key-value summary shared by all positions of a slice.
            lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
            lin_out_v = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)

            lin_ku = einsum('b g n d, b g n e -> b d e', lin_k, u) / n
            lin_out_u = einsum('b g n d, b d e -> b g n e', lin_q, lin_ku)

        # Merge groups back into one sequence axis and drop any padding.
        quad_attn_out_v, lin_attn_out_v = map(
            lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n],
            (quad_out_v, lin_out_v),
        )
        quad_attn_out_u, lin_attn_out_u = map(
            lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n],
            (quad_out_u, lin_out_u),
        )

        return quad_attn_out_v + lin_attn_out_v, quad_attn_out_u + lin_attn_out_u