import math
import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding
from models.mossformer_gan_se.conv_module import ConvModule

# Helper functions

def exists(val):
    """
    Check if a value is not None.
    
    Args:
        val: The value to check.
        
    Returns:
        bool: True if the value exists (is not None), False otherwise.
    """
    return val is not None

def default(val, d):
    """
    Return the value if it exists, otherwise return a default value.
    
    Args:
        val: The value to check.
        d: The default value to return if val is None.
        
    Returns:
        The original value or the default value.
    """
    return val if exists(val) else d

def padding_to_multiple_of(n, mult):
    """
    Calculate padding to make a number a multiple of another number.
    
    Args:
        n (int): The number to pad.
        mult (int): The multiple to pad to.
    
    Returns:
        int: The padding value.
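
    Example:
        >>> padding_to_multiple_of(10, 4)
        2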
    """
    remainder = n % mult
    if remainder == 0:
        return 0
    return mult - remainder

# ScaleNorm
class ScaleNorm(nn.Module):
    """
    Normalization layer that scales inputs based on the dimensionality of the input.
    
    Args:
        dim (int): The input dimension.
        eps (float): A small value to prevent division by zero (default: 1e-5).
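
    Example (illustrative shapes, not taken from any particular configuration):
        norm = ScaleNorm(dim=512)
        x = torch.randn(8, 100, 512)   # (batch, seq_len, dim)
        y = norm(x)                    # same shape as x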
    """
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.scale = dim ** -0.5  # Scale factor based on input dimension
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1))  # Learnable scale parameter

    def forward(self, x):
        # Normalize the input along the last dimension and apply scaling
        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        return x / norm.clamp(min=self.eps) * self.g

# Absolute positional encodings
class ScaledSinuEmbedding(nn.Module):
    """
    Sine-cosine absolute positional embeddings with scaling.
    
    Args:
        dim (int): The dimension of the positional embedding.
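
    Example (illustrative shapes):
        pos_emb = ScaledSinuEmbedding(dim=64)
        x = torch.randn(8, 100, 64)    # only x.shape[1] and x.device are used
        emb = pos_emb(x)               # shape: (100, 64)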
    """
    def __init__(self, dim):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1,))
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)  # Store frequency values for sine and cosine

    def forward(self, x):
        # Generate sine and cosine positional encodings
        n, device = x.shape[1], x.device
        t = torch.arange(n, device=device).type_as(self.inv_freq)
        sinu = einsum('i , j -> i j', t, self.inv_freq)
        emb = torch.cat((sinu.sin(), sinu.cos()), dim=-1)
        return emb * self.scale  # Apply scaling to the positional embeddings

# T5 relative positional bias
class T5RelativePositionBias(nn.Module):
    """
    Relative positional bias based on T5 model design.
    
    Args:
        scale (float): Scaling factor for the bias.
        causal (bool): Whether to apply a causal mask (default: False).
        num_buckets (int): Number of relative position buckets (default: 32).
        max_distance (int): Maximum distance for relative positions (default: 128).
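
    Example (illustrative; the scale value is arbitrary here):
        rel_bias = T5RelativePositionBias(scale=1.0)
        sim = torch.randn(8, 100, 100)  # only the last two dimensions are used
        bias = rel_bias(sim)            # shape: (100, 100), typically added to attention logits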
    """
    def __init__(self, scale, causal=False, num_buckets=32, max_distance=128):
        super().__init__()
        self.eps = 1e-5
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, 1)  # Bias embedding for relative positions

    @staticmethod
    def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
        """
        Bucket relative positions into discrete ranges for bias calculation.
        
        Args:
            relative_position (Tensor): The relative position tensor.
            causal (bool): Whether to consider causality.
            num_buckets (int): Number of relative position buckets.
            max_distance (int): Maximum distance for the position.

        Returns:
            Tensor: Bucketed relative positions.
        """
        ret = 0
        n = -relative_position
        if not causal:
            num_buckets //= 2
            ret += (n < 0).long() * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, torch.zeros_like(n))

        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, x):
        # Calculate relative position bias for attention
        i, j, device = *x.shape[-2:], x.device
        q_pos = torch.arange(i, dtype=torch.long, device=device)
        k_pos = torch.arange(j, dtype=torch.long, device=device)
        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
        rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets, max_distance=self.max_distance)
        values = self.relative_attention_bias(rp_bucket)  # Get bias values
        bias = rearrange(values, 'i j 1 -> i j')
        return bias * self.scale  # Apply scaling to the bias

# Relative Position Embeddings
class RelativePosition(nn.Module):
    """
    Relative positional embeddings with configurable number of units and max position.
    
    Args:
        num_units (int): The number of embedding units (default: 32).
        max_relative_position (int): The maximum relative position (default: 128).
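
    Example (illustrative shapes):
        rel_pos = RelativePosition(num_units=32)
        sim = torch.randn(8, 100, 100)   # only the last two dimensions are used
        emb = rel_pos(sim)               # shape: (100, 100, 32)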
    """
    def __init__(self, num_units=32, max_relative_position=128):
        super().__init__()
        self.num_units = num_units
        self.max_relative_position = max_relative_position
        self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units))
        nn.init.xavier_uniform_(self.embeddings_table)  # Initialize embedding weights

    def forward(self, x):
        # Generate relative position embeddings
        length_q, length_k, device = *x.shape[-2:], x.device
        range_vec_q = torch.arange(length_q, dtype=torch.long, device=device)
        range_vec_k = torch.arange(length_k, dtype=torch.long, device=device)
        distance_mat = range_vec_k[None, :] - range_vec_q[:, None]  # Compute relative distances
        distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
        final_mat = (distance_mat_clipped + self.max_relative_position)
        embeddings = self.embeddings_table[final_mat]  # Get embeddings based on distances

        return embeddings

# Offset and Scale module
class OffsetScale(nn.Module):
    """
    Offset and scale operation applied across heads and dimensions.
    
    Args:
        dim (int): Input dimensionality.
        heads (int): Number of attention heads (default: 1).
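
    Example (illustrative shapes, mirroring how MossFormer uses it with heads=4):
        offset_scale = OffsetScale(dim=128, heads=4)
        qk = torch.randn(8, 100, 128)
        quad_q, lin_q, quad_k, lin_k = offset_scale(qk)  # four tensors, each (8, 100, 128)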
    """
    def __init__(self, dim, heads=1):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(heads, dim))  # Learnable scaling parameter
        self.beta = nn.Parameter(torch.zeros(heads, dim))  # Learnable offset parameter
        nn.init.normal_(self.gamma, std=0.02)  # Initialize gamma with small random values

    def forward(self, x):
        # Apply offset and scale across heads
        out = einsum('... d, h d -> ... h d', x, self.gamma) + self.beta
        return out.unbind(dim=-2)  # Split into one tensor per head along the heads dimension

class FFConvM(nn.Module):
    """
    FFConvM is a feedforward convolutional module that applies a series of transformations
    to an input tensor. The transformations include normalization, linear projection,
    activation, convolution, and dropout. It combines feedforward layers with a convolutional
    module to enhance the feature extraction process.

    Args:
        dim_in: Input feature dimension.
        dim_out: Output feature dimension.
        norm_klass: Normalization class to apply (default is LayerNorm).
        dropout: Dropout probability to prevent overfitting (default is 0.1).
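
    Example (illustrative shapes; assumes ConvModule preserves the (batch, seq, channel) shape):
        ff = FFConvM(dim_in=64, dim_out=128)
        x = torch.randn(8, 100, 64)
        y = ff(x)                      # shape: (8, 100, 128)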
    """
    def __init__(
        self,
        dim_in,    # Input feature dimension
        dim_out,   # Output feature dimension
        norm_klass=nn.LayerNorm,  # Normalization class (default: LayerNorm)
        dropout=0.1  # Dropout probability
    ):
        super().__init__()

        # Sequentially apply normalization, linear transformation, activation, convolution, and dropout
        self.mdl = nn.Sequential(
            norm_klass(dim_in),  # Apply normalization (LayerNorm by default)
            nn.Linear(dim_in, dim_out),  # Linear projection from dim_in to dim_out
            nn.SiLU(),  # Activation function (SiLU - Sigmoid Linear Unit)
            ConvModule(dim_out),  # Apply convolution using ConvModule
            nn.Dropout(dropout)  # Apply dropout for regularization
        )

    def forward(self, x):
        """
        Forward pass through the module.
        
        Args:
            x: Input tensor of shape (batch_size, seq_length, dim_in)
        
        Returns:
            output: Transformed output tensor of shape (batch_size, seq_length, dim_out)
        """
        output = self.mdl(x)  # Pass the input through the sequential model
        return output  # Return the processed output

class MossFormer(nn.Module):
    """
    The MossFormer class implements a transformer-style block built around a gated
    attention unit that combines quadratic (grouped) and linear attention components.
    Inputs are processed through optional token shifting, convolution-augmented
    projections (FFConvM), and a gated output, with optional causal masking for
    autoregressive tasks.
    
    Args:
        dim (int): Dimensionality of input features.
        group_size (int): Size of the group dimension for attention.
        query_key_dim (int): Dimensionality of the query and key vectors for attention.
        expansion_factor (float): Expansion factor for the hidden dimensions.
        causal (bool): Whether to apply causal masking for autoregressive tasks.
        dropout (float): Dropout rate for regularization.
        norm_klass (nn.Module): Normalization layer to be applied.
        shift_tokens (bool): Whether to apply token shifting as a preprocessing step.
    """
    def __init__(
        self,
        dim,
        group_size = 256,
        query_key_dim = 128,
        expansion_factor = 4.,
        causal = False,
        dropout = 0.1,
        norm_klass = nn.LayerNorm,
        shift_tokens = True
    ):
        super().__init__()
        hidden_dim = int(dim * expansion_factor)        
        self.group_size = group_size
        self.causal = causal
        self.shift_tokens = shift_tokens

        # Positional embeddings for attention.
        self.rotary_pos_emb = RotaryEmbedding(dim = min(32, query_key_dim))
        # Dropout layer for regularization.
        self.dropout = nn.Dropout(dropout)
        
        # Projection layers for input features to hidden dimensions.
        self.to_hidden = FFConvM(
            dim_in = dim,
            dim_out = hidden_dim,
            norm_klass = norm_klass,
            dropout = dropout,
        )
        self.to_qk = FFConvM(
            dim_in = dim,
            dim_out = query_key_dim,
            norm_klass = norm_klass,
            dropout = dropout,
        )

        self.qk_offset_scale = OffsetScale(query_key_dim, heads = 4)

        # Output projection layer to return to original feature dimensions.
        self.to_out = FFConvM(
            dim_in = dim * int(expansion_factor // 2),
            dim_out = dim,
            norm_klass = norm_klass,
            dropout = dropout,
        )
        
        self.gateActivate = nn.Sigmoid() 

    def forward(
        self,
        x,
        *,
        mask = None
    ):
        """
        Forward pass for the MossFormer module.

        Args:
            x (Tensor): Input tensor of shape (B, T, Q, C) where:
                B = batch size,
                T = total sequence length,
                Q = number of query features,
                C = feature dimension.
            mask (Tensor, optional): Attention mask for padding.

        Returns:
            Tensor: Output tensor of shape (B*T, Q, C).
        """
        
        # Unpack input dimensions
        B, T, Q, C = x.size()
        x = x.view(B*T, Q, C)  # Reshape input for processing

        # No separate pre-normalization here: FFConvM applies LayerNorm internally
        normed_x = x

        # Store the input for the residual (skip) connection
        residual = x

        if self.shift_tokens:
            # Split and shift tokens for enhanced information flow
            x_shift, x_pass = normed_x.chunk(2, dim = -1)
            x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)  # Shift one step along the sequence axis (pad front, drop last)
            normed_x = torch.cat((x_shift, x_pass), dim = -1)

        # Initial projections to hidden space
        v, u = self.to_hidden(normed_x).chunk(2, dim = -1)  # Split into two tensors
        qk = self.to_qk(normed_x)  # Project to query/key dimensions

        # Offset and scale for attention
        quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
        att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v, u, B, mask=mask)

        # Gate the outputs and apply skip connection
        out = (att_u * v) * self.gateActivate(att_v * u)       
        x = x + self.to_out(out)  # Combine with residual
        return x

    def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, B, mask = None):
        """
        Calculates both quadratic and linear attention outputs.

        Args:
            x (Tensor): Input tensor of shape (B, n, d).
            quad_q (Tensor): Quadratic queries tensor.
            lin_q (Tensor): Linear queries tensor.
            quad_k (Tensor): Quadratic keys tensor.
            lin_k (Tensor): Linear keys tensor.
            v (Tensor): Value tensor for attention.
            u (Tensor): Auxiliary tensor for attention.
            B (int): Batch size.
            mask (Tensor, optional): Attention mask for padding.

        Returns:
            Tuple[Tensor, Tensor]: Quadratic and linear attention outputs.
        """
        
        b, n, device, g = x.shape[0], x.shape[-2],  x.device, self.group_size

        if exists(mask):
            # Apply mask to linear keys if provided
            lin_mask = rearrange(mask, '... -> ... 1')
            lin_k = lin_k.masked_fill(~lin_mask, 0.)

        # Rotate queries and keys using positional embeddings
        if exists(self.rotary_pos_emb):
            quad_q, lin_q, quad_k, lin_k = map(self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k))

        # Pad the sequence length up to a multiple of the group length
        # (here the multiple is n itself, so this is effectively a no-op)
        padding = padding_to_multiple_of(n, n)

        if padding > 0:
            # Pad tensors to accommodate group sizes
            quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: F.pad(t, (0, 0, 0, padding), value = 0.), (quad_q, quad_k, lin_q, lin_k, v, u))

            mask = default(mask, torch.ones((b, n), device = device, dtype = torch.bool))
            mask = F.pad(mask, (0, padding), value = False)

        # Reshape for grouped attention
        quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: rearrange(t, 'b (g n) d -> b g n d', n = n), (quad_q, quad_k, lin_q, lin_k, v, u))

        BT, K, Q, C = quad_q.size()
        quad_q_c = quad_q.view(B, -1, Q, C).transpose(2, 1)  # Prepare for computation
        quad_k_c = quad_k.view(B, -1, Q, C).transpose(2, 1)  
        v_c = v.view(B, -1, Q, C).transpose(2, 1)  
        u_c = u.view(B, -1, Q, C).transpose(2, 1)  

        if exists(mask):
            mask = rearrange(mask, 'b (g j) -> b g 1 j', j = n)  # Adjust mask dimensions

        # Calculate quadratic attention output
        sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / n
        sim_c = einsum('... i d, ... j d -> ... i j', quad_q_c, quad_k_c) / quad_q_c.shape[-2]
        
        # Squared-ReLU attention kernel (used in place of softmax)
        attn = F.relu(sim) ** 2
        attn = self.dropout(attn)  # Apply dropout for regularization
        
        attn_c = F.relu(sim_c) ** 2
        attn_c = self.dropout(attn_c)  # Apply dropout for the computed attention
        mask_c = torch.eye(quad_q_c.shape[-2], dtype = torch.bool, device = device)
        attn_c = attn_c.masked_fill(mask_c, 0.)  # Mask diagonal for attention

        if exists(mask):
            attn = attn.masked_fill(~mask, 0.)  # Apply the mask to the main attention

        if self.causal:
            # Create a causal mask for the attention
            causal_mask = torch.ones((g, g), dtype = torch.bool, device = device).triu(1)
            attn = attn.masked_fill(causal_mask, 0.)  # Apply causal mask

        # Calculate the output for quadratic attention
        quad_out_v = einsum('... i j, ... j d -> ... i d', attn, v)
        quad_out_u = einsum('... i j, ... j d -> ... i d', attn, u)

        # Calculate output for the diagonal-masked quadratic attention branch
        quad_out_v_c = einsum('... i j, ... j d -> ... i d', attn_c, v_c)
        quad_out_u_c = einsum('... i j, ... j d -> ... i d', attn_c, u_c)
        quad_out_v_c = quad_out_v_c.transpose(2, 1).contiguous().view(BT, K, Q, C)
        quad_out_u_c = quad_out_u_c.transpose(2, 1).contiguous().view(BT, K, Q, C)

        # Combine the outputs from quadratic attention
        quad_out_v = quad_out_v + quad_out_v_c
        quad_out_u = quad_out_u + quad_out_u_c

        # Calculate linear attention output
        if self.causal:
            # Handle causal linear attention
            lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / n
            lin_kv = lin_kv.cumsum(dim = 1)  # Exclusive cumulative sum
            lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value = 0.)
            lin_out_v = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)

            lin_ku = einsum('b g n d, b g n e -> b g d e', lin_k, u) / n
            lin_ku = lin_ku.cumsum(dim = 1)  # Exclusive cumulative sum
            lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value = 0.)
            lin_out_u = einsum('b g d e, b g n d -> b g n e', lin_ku, lin_q)
        else:
            # Handle non-causal linear attention
            lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
            lin_out_v = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)

            lin_ku = einsum('b g n d, b g n e -> b d e', lin_k, u) / n
            lin_out_u = einsum('b g n d, b d e -> b g n e', lin_q, lin_ku)

        # Reshape and excise out padding
        quad_attn_out_v, lin_attn_out_v = map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out_v, lin_out_v))
        quad_attn_out_u, lin_attn_out_u = map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out_u, lin_out_u))

        return quad_attn_out_v + lin_attn_out_v, quad_attn_out_u + lin_attn_out_u
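

# ----------------------------------------------------------------------
# Minimal smoke test (illustrative only). The shapes below are assumptions
# chosen so that the internal reshapes in cal_attention line up
# (2 * dim == query_key_dim with the default expansion_factor of 4);
# they are not taken from this repository's training configuration.
# Requires ConvModule and rotary_embedding_torch, imported at the top.
# ----------------------------------------------------------------------
if __name__ == '__main__':
    model = MossFormer(dim=64)           # defaults: group_size=256, query_key_dim=128
    dummy = torch.randn(2, 3, 100, 64)   # (B, T, Q, C)
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)                     # expected: torch.Size([6, 100, 64]), i.e. (B*T, Q, C)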