# Source: xcodec2/vq/bs_roformer5.py (author: yezhen; commit 574a515, "Initial commit"; 3.26 kB)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module, ModuleList
import torchaudio
from einops import rearrange
import numpy as np
# from rotary_embedding_torch import RotaryEmbedding
from torchtune.modules import RotaryPositionalEmbeddings
class RMSNorm(torch.nn.Module):
    r"""Root-mean-square normalisation over the last dimension.

    Reference implementation:
    https://github.com/meta-llama/llama/blob/main/llama/model.py

    Args:
        dim: Size of the last dimension; one learnable gain is kept per feature.
        eps: Constant added to the mean square before the reciprocal square
            root is taken.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Per-feature gain, initialised to ones.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        r"""Return ``x * rsqrt(mean(x^2, dim=-1) + eps) * weight``."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight
class MLP(nn.Module):
    r"""Two-layer feed-forward network: Linear -> SiLU -> Linear.

    The hidden width is ``4 * dim`` and neither linear layer has a bias.

    Args:
        dim: Size of the input and output feature dimension.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        hidden_dim = 4 * dim
        self.fc1 = nn.Linear(dim, hidden_dim, bias=False)
        self.silu = nn.SiLU()
        self.fc2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        r"""Expand, apply SiLU, and project back to ``dim`` features."""
        return self.fc2(self.silu(self.fc1(x)))
class Attention(nn.Module):
    r"""Non-causal multi-head self-attention with a rotary embedding on q and k.

    A single bias-free linear layer produces q, k and v; attention itself is
    computed by ``torch.nn.functional.scaled_dot_product_attention`` without a
    mask and without dropout, followed by a bias-free output projection.

    Args:
        dim: Model width (``heads_num * heads_dim``).
        n_heads: Number of attention heads; must divide ``dim``.
        rotary_embed: Module called on q and on k (each shaped ``(b, h, t, d)``)
            that returns a tensor of the same shape. The rest of this file
            passes a torchtune ``RotaryPositionalEmbeddings``.

    Raises:
        ValueError: If ``dim`` is not divisible by ``n_heads``.
        RuntimeError: If this PyTorch build has no
            ``scaled_dot_product_attention``.
    """

    def __init__(self, dim: int, n_heads: int, rotary_embed: nn.Module):
        super().__init__()

        # Raise explicitly instead of using `assert`: assertions are removed
        # under `python -O`, which would let a bad configuration through.
        if dim % n_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by n_heads ({n_heads})"
            )

        self.n_heads = n_heads
        self.dim = dim
        self.rotary_embed = rotary_embed

        # Kept as an attribute for backward compatibility; it is always True
        # once construction succeeds.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            raise RuntimeError("Must have flash attention.")

        self.c_attn = nn.Linear(dim, 3 * dim, bias=False)
        self.c_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        r"""
        Args:
            x: (b, t, h*d)

        Constants:
            b: batch_size
            t: time steps
            r: 3
            h: heads_num
            d: heads_dim

        Returns:
            Tensor of shape (b, t, h*d).
        """
        B, T, C = x.size()

        # The projected feature axis is laid out as (r, h, d) with r outermost:
        # (b, t, r*h*d) -> (b, t, r, h, d) -> (r, b, h, t, d) -> q, k, v.
        qkv = self.c_attn(x).view(B, T, 3, self.n_heads, -1)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        # q, k, v: (b, h, t, d)

        # NOTE(review): torchtune documents the input of
        # RotaryPositionalEmbeddings as [b, s, n_h, h_d], while q and k are
        # (b, h, t, d) here, so the head axis sits where that module expects
        # the sequence axis. The layout is deliberately left unchanged because
        # altering it would change the outputs of weights trained with this
        # code -- TODO confirm whether this was intended.
        q = self.rotary_embed(q)
        k = self.rotary_embed(k)

        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
        )

        # (b, h, t, d) -> (b, t, h*d)
        y = y.transpose(1, 2).reshape(B, T, C)

        y = self.c_proj(y)
        # shape: (b, t, h*d)
        return y
class TransformerBlock(nn.Module):
    r"""Pre-norm transformer block: attention and MLP, each with a residual.

    Computes ``x = x + att(att_norm(x))`` followed by
    ``x = x + mlp(ffn_norm(x))``.

    Args:
        dim: Model width shared by both norms, the attention and the MLP.
        n_heads: Number of attention heads.
        rotary_embed: Rotary embedding module forwarded to ``Attention``.
    """

    def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings):
        super().__init__()

        self.dim = dim
        self.n_heads = n_heads

        # One RMSNorm ahead of each sub-layer.
        self.att_norm = RMSNorm(dim)
        self.ffn_norm = RMSNorm(dim)

        self.att = Attention(dim, n_heads, rotary_embed)
        self.mlp = MLP(dim)

    def forward(self, x: torch.Tensor):
        r"""Apply the attention and MLP residual branches in that order."""
        attn_out = self.att(self.att_norm(x))
        x = x + attn_out

        ffn_out = self.mlp(self.ffn_norm(x))
        return x + ffn_out
if __name__ == '__main__':
    # Smoke test: run one TransformerBlock on random input and print the
    # output shape, which should equal the input shape (2, 128, 1024).
    dim = 1024
    n_heads = 8

    # The rotary embedding is sized to one head's width (1024 // 8 = 128), so
    # it is derived from dim and n_heads rather than hard-coded separately.
    rotary_embed_128 = RotaryPositionalEmbeddings(dim=dim // n_heads)

    transformer_block = TransformerBlock(
        dim=dim,
        n_heads=n_heads,
        rotary_embed=rotary_embed_128,
    )

    x = torch.randn(2, 128, dim)
    y = transformer_block(x)
    print(y.shape)