Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn | |
from dataclasses import dataclass | |
from enum import Enum | |
from typing import * | |
from math import ceil | |
class AttentionBackend(Enum):
    """Identifies which attention implementation the model should dispatch to.

    The active member is stored under the ``'attn_backend'`` key of
    ``global_config`` and is compared against in the attention code paths.
    """

    # Explicit matmul + softmax attention written with plain torch ops.
    Naive = 0
    # Code path that calls ``flash_attn_func``.
    FlashAttentionCuda = 1
    # Code path that calls ``flash_attn_func_triton``.
    FlashAttentionTriton = 2
# Module-wide runtime switches. 'attn_backend' selects the attention
# implementation that the forward passes in this file branch on.
global_config = {'attn_backend': AttentionBackend.Naive}
@dataclass
class TransformerConfig:
    """Hyper-parameters and shared handles passed to every sub-module.

    The class is a dataclass so that it can be built positionally, in field
    order, the way ``ToyTransformer.__init__`` does:
    ``TransformerConfig(vocab_size, num_layers, num_heads, hidden_size,
    max_seq_len, root_model, device, dtype)``.

    The integer fields default to the sentinel ``-1`` (a plain int, not a
    tuple), meaning "not configured".
    """

    # Number of entries in the token embedding / output projection.
    vocab_size: int = -1
    # Number of stacked decoder layers.
    num_layers: int = -1
    # Number of attention heads per layer.
    num_heads: int = -1
    # Width of the residual stream; the per-head width is
    # hidden_size // num_heads.
    hidden_size: int = -1
    # Number of positions the rotary-embedding cache is built for.
    max_seq_len: int = -1
    # Back-reference to the owning model; attention heads read
    # ``root_model.rope_cache`` through it.
    root_model: 'ToyTransformer' = None
    # Device the model is created on.
    device: torch.device = torch.device('cpu')
    # Parameter dtype used when constructing layers.
    dtype: torch.dtype = torch.float32
def expand_attn_mask(custom_attn_mask: torch.Tensor):
    """Turn a per-token mask of shape (B, T) into a (B, T, T) boolean mask.

    Entry ``[b, i, j]`` of the result is True when all of these hold:

    * ``custom_attn_mask[b, j] == custom_attn_mask[b, i]`` (same mask value),
    * ``j <= i`` (lower triangle, i.e. causal),
    * ``custom_attn_mask[b, j] > 0`` (positions whose value is not positive
      are never attended to).

    Args:
        custom_attn_mask: 2-D tensor of per-token mask values.

    Returns:
        Boolean tensor of shape (B, T, T).
    """
    batch, length = custom_attn_mask.shape
    # key_values[b, i, j] == custom_attn_mask[b, j]
    key_values = custom_attn_mask.unsqueeze(1).repeat((1, length, 1))
    # query_values[b, i, 0] == custom_attn_mask[b, i]
    query_values = custom_attn_mask.reshape(batch, length, 1)
    same_value = key_values == query_values
    causal_and_positive = torch.tril(key_values) > 0
    return same_value & causal_and_positive
# Expand an attention mask into cu_seqlens-style boundaries for flash attn
def expand_attn_mask_to_seq_lengths(attn_mask: torch.Tensor):
    """Return the flattened start offsets of every run of equal mask values.

    The mask is moved to the CPU. A boundary is placed at the first column of
    every row and wherever a value differs from its left neighbour. The
    offsets index into the row-major flattened mask, and the total element
    count is appended as the final entry.

    Args:
        attn_mask: 2-D tensor of per-token mask values.

    Returns:
        ``torch.int32`` tensor of shape (num_boundaries + 1, 1).
    """
    cpu_mask = attn_mask.to('cpu')
    num_rows = cpu_mask.shape[0]
    total_tokens = num_rows * cpu_mask.shape[1]

    # Every row starts a new run regardless of the previous row's last value.
    row_starts = torch.tensor([[True]] * num_rows)
    value_changes = cpu_mask[:, 1:] != cpu_mask[:, :-1]
    is_boundary = torch.cat([row_starts, value_changes], dim=1)

    start_offsets = torch.nonzero(is_boundary.view((-1,)))
    end_marker = torch.tensor([[total_tokens]])
    return torch.cat([start_offsets, end_marker]).to(dtype=torch.int32)
# Naive RoPE implementation following https://arxiv.org/pdf/2104.09864.pdf
def get_rope_cache_slow(seq_len: int, dim: int, theta: int, device: torch.device, dtype: torch.dtype):
    """Precompute the lookup tables consumed by ``apply_rope_slow``.

    Args:
        seq_len: Number of positions to tabulate.
        dim: Feature width; must be even because features rotate in pairs.
        theta: Base of the geometric frequency progression.
        device: Device the returned tensors are moved to.
        dtype: Dtype the cosine/sine tables are cast to.

    Returns:
        A ``(cos_table, signed_sin_table, swap_indices)`` tuple:
        ``cos_table`` and ``signed_sin_table`` have shape (seq_len, dim), with
        the sine negated on odd feature indices. ``swap_indices`` is the
        permutation ``[1, 0, 3, 2, ...]`` that exchanges the two members of
        every feature pair.
    """
    assert dim % 2 == 0
    half = dim // 2

    # One frequency per feature pair, duplicated so both members share it.
    pair_freqs = theta ** (-2 * torch.arange(0, half, 1.) / dim)
    feature_freqs = torch.repeat_interleave(pair_freqs, 2)

    position_column = torch.arange(seq_len, dtype=torch.float).view((seq_len, 1))
    angles = position_column * feature_freqs
    cos_table = torch.cos(angles)
    signed_sin_table = torch.sin(angles) * torch.tensor([1, -1] * half)

    # i ^ 1 maps 0<->1, 2<->3, ... i.e. swaps neighbours inside each pair.
    swap_indices = torch.tensor([i ^ 1 for i in range(dim)])
    return cos_table.to(device, dtype=dtype), signed_sin_table.to(device, dtype=dtype), swap_indices.to(device)
def apply_rope_slow(x, rope_cache, positions: Optional[torch.Tensor] = None):
    """Rotate ``x`` with the tables produced by ``get_rope_cache_slow``.

    Args:
        x: 3-D tensor; its last two dimensions are (seq_len, dim).
        rope_cache: ``(cos_table, signed_sin_table, swap_indices)`` tuple.
        positions: Optional integer tensor of absolute positions, one per
            sequence step of ``x``. When omitted, positions ``0..seq_len-1``
            are used.

    Returns:
        Tensor with the same shape as ``x``. For every feature pair
        ``(a, b)`` at angle ``t`` the result holds
        ``(a*cos t - b*sin t, b*cos t + a*sin t)``.
    """
    v1, v2, indices = rope_cache
    seq_len, dim = x.shape[1:]
    if positions is None:
        v1 = v1[:seq_len, :]
        v2 = v2[:seq_len, :]
    else:
        # Gather whole table rows for the requested positions. Indexing with
        # (positions, arange(dim)) would pair each position with a single
        # feature column, which only works for one position and silently
        # picks a diagonal when len(positions) == dim.
        v1 = v1[positions][..., :dim].reshape((-1, dim))
        v2 = v2[positions][..., :dim].reshape((-1, dim))
    # The permutation swaps pair members so each output mixes both inputs.
    applied_x = x * v1 + (x * v2)[:, :, indices]
    return applied_x
# Optimized RoPE implementation adapted from https://github.com/facebookresearch/llama/blob/main/llama/model.py
def get_rope_cache_fast(seq_len: int, dim: int, theta: int, device: torch.device, dtype: torch.dtype):
    """Precompute unit complex rotations for ``apply_rope_fast``.

    Args:
        seq_len: Number of positions to tabulate.
        dim: Feature width; one rotation is produced per feature pair.
        theta: Base of the geometric frequency progression.
        device: Device the returned tensor is moved to.
        dtype: Accepted for signature parity with ``get_rope_cache_slow``;
            not used by this function.

    Returns:
        Complex tensor of shape (seq_len, dim // 2) whose entry ``[p, k]`` is
        ``exp(1j * p * theta ** (-2k / dim))``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freqs = 1.0 / (theta ** exponents)
    position_ids = torch.arange(seq_len, device=inv_freqs.device)
    angles = torch.outer(position_ids, inv_freqs).float()
    # Magnitude 1, phase = angle: a pure rotation in the complex plane.
    rotations = torch.polar(torch.ones_like(angles), angles)
    return rotations.to(device)
def apply_rope_fast(x, rope_cache, positions: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Rotate ``x`` by multiplying its feature pairs with complex rotations.

    Args:
        x: Tensor whose dimension 1 is the sequence axis and whose last
            dimension has even size; consecutive feature pairs are treated as
            (real, imaginary) parts.
        rope_cache: Complex table from ``get_rope_cache_fast``, indexed by
            position along dimension 0.
        positions: Optional integer tensor selecting rows of ``rope_cache``.
            When omitted, the first ``x.shape[1]`` rows are used.

    Returns:
        Rotated tensor with the same dtype as ``x``; dimensions from index 2
        onward are flattened back into one feature axis.
    """
    paired = x.float().reshape(*x.shape[:-1], -1, 2)
    x_complex = torch.view_as_complex(paired)

    if positions is not None:
        rotations = rope_cache[positions, :]
    elif x.shape[1] < rope_cache.shape[0]:
        rotations = rope_cache[:x.shape[1], :]
    else:
        rotations = rope_cache

    # Keep the sequence axis and the last axis; size 1 elsewhere so the
    # rotations broadcast against x_complex.
    last_axis = x_complex.ndim - 1
    broadcast_shape = [size if axis in (1, last_axis) else 1
                       for axis, size in enumerate(x_complex.shape)]
    rotations = rotations.view(broadcast_shape)

    rotated = torch.view_as_real(x_complex * rotations).flatten(2)
    return rotated.type_as(x)
# RMSNorm implementation following https://arxiv.org/pdf/1910.07467.pdf
class RMSNorm(nn.Module):
    """Scale inputs by the reciprocal root-mean-square of their last axis.

    Args:
        hidden_size: Length of the learned per-feature gain vector.
        dtype: Dtype of the gain parameter.
        eps: Constant added to the mean square before the reciprocal square
            root, which keeps all-zero inputs finite.
    """

    def __init__(self, hidden_size, dtype, eps=1e-6):
        super().__init__()
        self.eps = eps
        # The gain starts at one, so the layer initially only normalizes.
        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=dtype))

    def forward(self, x: torch.Tensor):
        """Return ``weight * x / sqrt(mean(x**2, last axis) + eps)``."""
        mean_square = torch.mean(x.pow(2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return self.weight * (x * inv_rms)
class AttentionHead(nn.Module):
    """One attention head: q/k/v projections, rotary embedding, attention.

    The head width is ``config.hidden_size // config.num_heads``. The rotary
    table is read from ``config.root_model.rope_cache`` on every call.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.head_size = config.hidden_size // config.num_heads
        self.dtype = config.dtype

        def make_projection():
            return nn.Linear(config.hidden_size, self.head_size, dtype=config.dtype)

        # Created in q, k, v order.
        self.q_proj = make_projection()
        self.k_proj = make_projection()
        self.v_proj = make_projection()

    def forward(self, x: torch.Tensor, attn_masked_bias: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Run attention for this head.

        Args:
            x: Input activations; the sequence axis is dimension 1.
            attn_masked_bias: Additive bias applied to the attention scores.
                It is required by the naive backend, optional for the triton
                backend and ignored by the cuda backend.
            kv_cache: ``[keys, values]`` from earlier steps, or None. When
                given, the new keys/values are appended along dimension 1.

        Returns:
            ``(attn_result, [keys, values])`` where keys/values include any
            cached prefix.
        """
        query = self.q_proj(x)
        key = self.k_proj(x)
        value = self.v_proj(x)

        # With a cache, the incoming token sits right after the cached prefix,
        # so its rotary position is the cached length. Without one, positions
        # start at 0.
        if kv_cache is None:
            positions = None
        else:
            positions = torch.tensor([kv_cache[0].shape[1]]).to(query.device)

        rope_cache = self.config.root_model.rope_cache
        query = apply_rope_fast(query, rope_cache, positions)
        key = apply_rope_fast(key, rope_cache, positions)

        if kv_cache is not None:
            key = torch.concat([kv_cache[0], key], dim=1)
            value = torch.concat([kv_cache[1], value], dim=1)

        backend = global_config['attn_backend']
        if backend == AttentionBackend.FlashAttentionCuda:
            # flash_attn_func is expected to be provided elsewhere in the
            # module. A singleton axis is inserted at dim 2 for the call and
            # removed from the result afterwards.
            attn_result = flash_attn_func(
                query.unsqueeze(2), key.unsqueeze(2), value.unsqueeze(2), causal=True
            ).squeeze(2)
        elif backend == AttentionBackend.FlashAttentionTriton:
            bias = None if attn_masked_bias is None else attn_masked_bias.unsqueeze(1)
            # The last argument is True only when no kv cache is in use.
            attn_result = flash_attn_func_triton(
                query.unsqueeze(2), key.unsqueeze(2), value.unsqueeze(2), bias, kv_cache is None
            ).squeeze(2)
        else:
            scores = (query @ key.permute(0, 2, 1) / (self.head_size ** 0.5)) + attn_masked_bias
            attn_result = torch.softmax(scores, dim=2) @ value
        return attn_result, [key, value]
class MultiHeadAttention(nn.Module):
    """Runs ``config.num_heads`` independent heads and mixes their outputs."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.attn_heads = nn.ModuleList([AttentionHead(config) for _ in range(config.num_heads)])
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, dtype=config.dtype)

    def forward(self, x: torch.Tensor, attn_masked_bias: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[List[torch.Tensor]]]:
        """Apply every head to ``x`` and project the concatenated result.

        Args:
            x: Input activations forwarded unchanged to each head.
            attn_masked_bias: Bias forwarded unchanged to each head.
            kv_cache: Per-head caches indexed by head position, or None.

        Returns:
            ``(output, new_kv_cache)`` where ``new_kv_cache[i]`` is the
            ``[keys, values]`` pair returned by head ``i``.
        """
        head_results = []
        new_kv_cache = []
        for idx, head in enumerate(self.attn_heads):
            head_cache = None if kv_cache is None else kv_cache[idx]
            head_result, head_kv = head(x, attn_masked_bias, head_cache)
            head_results.append(head_result)
            new_kv_cache.append(head_kv)
        # Heads are joined along the feature axis (dim 2) before mixing.
        merged = torch.concat(head_results, dim=2)
        return self.o_proj(merged), new_kv_cache
class DecoderLayer(nn.Module):
    """Pre-norm decoder block: attention and a 4x-wide GELU feed-forward.

    Both sub-layers normalize their input with ``nn.LayerNorm`` and add their
    output back onto the residual stream.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        wide = config.hidden_size * 4
        self.mha = MultiHeadAttention(config)
        self.up_proj = nn.Linear(config.hidden_size, wide, dtype=config.dtype)
        self.down_proj = nn.Linear(wide, config.hidden_size, dtype=config.dtype)
        self.ln_mha = nn.LayerNorm(config.hidden_size, dtype=config.dtype)
        self.ln_ffn = nn.LayerNorm(config.hidden_size, dtype=config.dtype)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor, attn_masked_bias: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[List[torch.Tensor]]]:
        """Return ``(layer_output, new_kv_cache)`` for this layer.

        Args:
            x: Residual-stream activations.
            attn_masked_bias: Bias forwarded to the attention sub-layer.
            kv_cache: This layer's per-head caches, or None.
        """
        # Attention sub-layer with residual connection.
        attn_out, new_kv_cache = self.mha(self.ln_mha(x), attn_masked_bias, kv_cache)
        residual = x + attn_out

        # Feed-forward sub-layer with residual connection.
        expanded = self.up_proj(self.ln_ffn(residual))
        ffn_out = self.down_proj(self.act(expanded))
        return residual + ffn_out, new_kv_cache
class ToyTransformer(nn.Module):
    """Decoder-only transformer: embedding, decoder layers, LM head.

    Args:
        vocab_size: Size of the embedding table and of the output logits.
        num_layers: Number of ``DecoderLayer`` blocks.
        num_heads: Attention heads per layer.
        hidden_size: Width of the residual stream.
        max_seq_len: Number of positions in the rotary-embedding cache.
        device: Device the model and the rotary cache are created on.
        dtype: Parameter dtype.
    """

    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, hidden_size: int, max_seq_len: int,
                 device: torch.device = torch.device('cpu'), dtype: torch.dtype = torch.float32):
        super().__init__()
        # The config carries a back-reference to this model so attention
        # heads can reach the shared rope_cache.
        self.config = TransformerConfig(vocab_size, num_layers, num_heads, hidden_size, max_seq_len, self, device,
                                        dtype)
        self.sem_embed = nn.Embedding(vocab_size, hidden_size, dtype=dtype)
        self.rope_cache = get_rope_cache_fast(max_seq_len, hidden_size // num_heads, 10000, device, dtype)
        self.decoder_layers = nn.ModuleList([DecoderLayer(self.config) for _ in range(num_layers)])
        self.lm_head = nn.Linear(hidden_size, vocab_size, dtype=dtype)
        self.to(device)

    def forward(self, seq: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None,
                kv_cache: Optional[List[List[List[torch.Tensor]]]] = None
                ) -> Tuple[torch.Tensor, List[List[List[torch.Tensor]]]]:
        """Compute logits for ``seq`` and return the updated kv cache.

        Args:
            seq: Token ids of shape (batch, seq_len).
            attn_mask: Optional per-token mask of the same shape as ``seq``,
                interpreted by ``expand_attn_mask``. It cannot be combined
                with ``kv_cache`` and is rejected by the cuda flash backend.
            kv_cache: Optional cache from a previous call, indexed as
                ``[layer][head] -> [keys, values]``. Batch size must be 1.

        Returns:
            ``(logits, new_kv_cache)`` where ``new_kv_cache`` uses the same
            nesting as ``kv_cache``.

        Raises:
            AssertionError: If the argument combination is unsupported.
        """
        # sanity checks
        assert attn_mask is None or kv_cache is None  # No support for attn_mask and kv_cache both enabled
        if kv_cache is not None:
            assert seq.shape[0] == 1, 'kv_cache is not supported for batch inference'

        backend = global_config['attn_backend']

        # handle flash-attn triton alignment requirement (actually only needed for backward)
        seq_length = seq.shape[1]
        pad_for_triton = (kv_cache is None
                          and backend == AttentionBackend.FlashAttentionTriton
                          and seq_length % 128 != 0)
        if pad_for_triton:
            if attn_mask is None:  # forcibly enable attn_mask due to padding
                attn_mask = torch.ones(seq.shape, device=self.device)
            pad_length = (ceil(seq_length / 128) * 128) - seq_length
            # Padding uses zeros, which expand_attn_mask treats as masked.
            seq = nn.functional.pad(seq, (0, pad_length))
            attn_mask = nn.functional.pad(attn_mask, (0, pad_length))

        # handle attn_bias
        if backend == AttentionBackend.FlashAttentionCuda:
            assert attn_mask is None, 'FlashAttn-Cuda does not support custom attn_mask'
            attn_masked_bias = None
        elif backend == AttentionBackend.FlashAttentionTriton and attn_mask is None:
            attn_masked_bias = None
        elif attn_mask is not None:
            attn_masked_bias = expand_attn_mask(attn_mask)
        elif kv_cache is None:
            # No mask given: plain causal attention over the whole sequence.
            attn_masked_bias = expand_attn_mask(torch.ones(seq.shape, device=self.device))
        else:
            # With a kv cache, nothing is masked for the incoming tokens.
            attn_masked_bias = torch.ones((1, seq.shape[1], seq.shape[1]), dtype=torch.bool, device=self.device)

        if attn_masked_bias is not None:
            # Convert the boolean mask to an additive bias: 0 where allowed,
            # a large negative value where masked. Half of finfo.min is used
            # rather than finfo.min itself.
            mask_zero = torch.tensor(0, dtype=self.config.dtype)
            mask_val = torch.tensor(torch.finfo(self.config.dtype).min / 2, dtype=self.config.dtype)
            attn_masked_bias = torch.where(attn_masked_bias, mask_zero, mask_val).to(self.device)

        hidden = self.sem_embed(seq)
        new_kv_cache = []
        for idx, decoder in enumerate(self.decoder_layers):
            hidden, layer_kv_cache = decoder(hidden, attn_masked_bias, kv_cache[idx] if kv_cache is not None else None)
            new_kv_cache.append(layer_kv_cache)
        logits = self.lm_head(hidden)

        # remove padding for flash-attn triton
        if pad_for_triton:
            logits = logits[:, :seq_length, :]
            new_kv_cache = [[[cache[:, :seq_length, :] for cache in head] for head in layer] for layer in new_kv_cache]
        return logits, new_kv_cache

    @property
    def device(self):
        """Device of the model's first parameter.

        This is a property because ``forward`` reads it as an attribute
        (``device=self.device``). As a plain method, that expression would
        pass a bound method where a ``torch.device`` is required.
        """
        return next(self.parameters()).device