from functools import partial
from typing import Callable, Optional, Union
import torch
import torch.nn.functional as F
# Default expert activation: GELU with the tanh approximation.
DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')

class _UniformExpertAssignment(torch.autograd.Function):
    """Overrides the router's choices: assigns token slots to experts round-robin."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, num_experts: int):
        # Cycle through expert ids 0..num_experts-1 over the flattened input,
        # ignoring the values in x entirely (only its shape/dtype/device matter).
        out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
        out = torch.remainder(out, num_experts)
        return out.view(x.shape)
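
# A minimal usage sketch (illustrative, not from the original file): with 4
# token slots and 2 experts, assignment is round-robin regardless of the
# input values:
#
#     idx = torch.zeros(4, 1, dtype=torch.long)
#     _UniformExpertAssignment.apply(idx, 2)  # tensor([[0], [1], [0], [1]])
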
class LearnedRouter(torch.nn.Module):

    def __init__(
        self,
        hidden_size: int,
        moe_num_experts: int,
        moe_top_k: int,
        moe_jitter_eps: Optional[float],
        moe_normalize_expert_weights: Optional[Union[int, float]],
        uniform_expert_assignment: bool,
        device: Optional[torch.device],
    ) -> None:
        super().__init__()
        self.hidden_size: int = hidden_size
        self.moe_num_experts: int = moe_num_experts
        self.moe_top_k: int = moe_top_k
        self.moe_jitter_eps: Optional[float] = moe_jitter_eps
        self.moe_normalize_expert_weights: Optional[Union[int, float]] = moe_normalize_expert_weights
        self.uniform_expert_assignment: bool = uniform_expert_assignment
        # Linear gate mapping each token's hidden state to one logit per expert.
        self.layer: torch.nn.Module = torch.nn.Linear(hidden_size, moe_num_experts, bias=False, device=device)

    def jitter(self, x: torch.Tensor) -> torch.Tensor:
        # Multiplicative noise drawn uniformly from [1 - eps, 1 + eps],
        # applied to router inputs during training only.
        assert self.moe_jitter_eps is not None
        low: float = 1.0 - self.moe_jitter_eps
        high: float = 1.0 + self.moe_jitter_eps
        noise: torch.Tensor = torch.rand(x.size(), dtype=x.dtype, device=x.device)
        return low + noise * (high - low)

    def _top_k(self, scores: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        if self.moe_top_k == 1:
            # max() is cheaper than topk() for the common top-1 case.
            values, indices = scores.max(dim=-1)
            return (values.unsqueeze(-1), indices.unsqueeze(-1))
        return torch.topk(scores, self.moe_top_k, dim=-1)

    def forward(self, x: torch.Tensor):
        if self.training and self.moe_jitter_eps is not None:
            x = x * self.jitter(x)
        # Flatten to (num_tokens, hidden_size) and softmax over the expert dim.
        scores = self.layer(x.view(-1, x.shape[-1])).softmax(dim=-1)
        expert_weights, top_experts = self._top_k(scores)
        if self.moe_normalize_expert_weights:
            # Renormalize the selected top-k weights with the configured p-norm.
            expert_weights = expert_weights / torch.norm(
                expert_weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True)
        top_experts = _UniformExpertAssignment.apply(
            top_experts, self.moe_num_experts) if self.uniform_expert_assignment else top_experts
        scores = scores.to(x.dtype)
        expert_weights = expert_weights.to(x.dtype)
        return (scores, expert_weights, top_experts)
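
# A shape sketch of the router outputs (illustrative values, not from the
# original file): with hidden_size=8, 4 experts, top_k=2, and a
# (batch=2, seq=3, 8) input, the router flattens to 6 tokens and returns:
#
#     router = LearnedRouter(8, 4, 2, None, None, False, device=None)
#     scores, weights, experts = router(torch.randn(2, 3, 8))
#     # scores:  (6, 4)  softmax over all experts
#     # weights: (6, 2)  top-k routing weights
#     # experts: (6, 2)  top-k expert indices
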
class MLP(torch.nn.Module):

    def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, activation_fn: Callable, device: Optional[torch.device]) -> None:
        super().__init__()
        self.moe_num_experts: int = moe_num_experts
        self.ffn_hidden_size: int = ffn_hidden_size
        self.hidden_size: int = hidden_size
        self.activation_fn: Callable = activation_fn
        # All experts' weights are stored stacked in a single parameter and
        # sliced out per expert in forward().
        self.w1 = torch.nn.Parameter(torch.rand(moe_num_experts * ffn_hidden_size, hidden_size, device=device))
        self.w2 = torch.nn.Parameter(torch.rand(moe_num_experts * ffn_hidden_size, hidden_size, device=device))

    def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
        # Slice out this expert's weights from the stacked parameters.
        expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
        expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
        # Standard two-layer FFN: act(x @ W1^T) @ W2.
        before_activation = x @ expert_w1.t()
        layer_1_output = self.activation_fn(before_activation)
        output = layer_1_output @ expert_w2
        return output
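
# A minimal usage sketch (illustrative sizes, not from the original file):
# apply expert 0 of a 2-expert MLP to 5 tokens.
#
#     mlp = MLP(hidden_size=8, ffn_hidden_size=16, moe_num_experts=2,
#               activation_fn=DEFAULT_ACTIVATION_FN, device=None)
#     y = mlp(torch.randn(5, 8), expert_idx=0)  # y.shape == (5, 8)
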
class GLU(torch.nn.Module):

    def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, activation_fn: Callable, device: Optional[torch.device]):
        super().__init__()
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.moe_num_experts = moe_num_experts
        # w1 feeds the gate branch, v1 the linear branch; w2 projects back down.
        self.w1 = torch.nn.Parameter(torch.rand(moe_num_experts * ffn_hidden_size, hidden_size, device=device))
        self.v1 = torch.nn.Parameter(torch.rand(moe_num_experts * ffn_hidden_size, hidden_size, device=device))
        self.w2 = torch.nn.Parameter(torch.rand(moe_num_experts * ffn_hidden_size, hidden_size, device=device))
        self.activation_fn = activation_fn

    def forward(self, x: torch.Tensor, expert_idx: torch.Tensor):
        # Slice out this expert's weights from the stacked parameters.
        expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
        expert_v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
        expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
        # Gated linear unit: act(x @ W1^T) * (x @ V1^T), projected back with W2.
        x1 = x.matmul(expert_w1.t())
        x2 = x.matmul(expert_v1.t())
        x1 = self.activation_fn(x1)
        x1 = x1 * x2
        x1 = x1.matmul(expert_w2)
        return x1
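
# GLU follows the same calling convention as MLP (a sketch with illustrative
# sizes, not from the original file); with activation_fn=F.silu this is the
# SwiGLU variant:
#
#     glu = GLU(hidden_size=8, ffn_hidden_size=16, moe_num_experts=2,
#               activation_fn=F.silu, device=None)
#     y = glu(torch.randn(5, 8), expert_idx=0)  # y.shape == (5, 8)
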
class DroplessMLP(torch.nn.Module):

    def __init__(self, hidden_size: int, ffn_hidden_size: int, mlp_type: str, moe_num_experts: int, activation_fn: Callable, bias: bool, device: Optional[torch.device]):
        super().__init__()
        self.moe_num_experts = moe_num_experts
        # NOTE: `bias` is accepted for interface compatibility but is not used
        # by the MLP/GLU experts below.
        if mlp_type == 'mlp':
            self.mlp = MLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size, moe_num_experts=moe_num_experts, activation_fn=activation_fn, device=device)
        elif mlp_type == 'glu':
            self.mlp = GLU(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size, moe_num_experts=moe_num_experts, activation_fn=activation_fn, device=device)
        else:
            raise ValueError(f'Received unknown mlp_type={mlp_type!r}')

    def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
        in_shape = x.shape
        hidden_size = in_shape[-1]
        x = x.view(-1, hidden_size)
        out = torch.zeros_like(x)
        # One-hot over experts, permuted to (num_experts, top_k, num_tokens) so
        # each expert can look up exactly the token slots routed to it.
        expert_mask = torch.nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
        for expert_idx in range(0, self.moe_num_experts):
            topk_idx, token_idx = torch.where(expert_mask[expert_idx])
            if token_idx.shape[0] == 0:
                # No tokens routed to this expert; skip it. "Dropless" means no
                # tokens are dropped, though expert loads may be uneven.
                continue
            token_list = token_idx.tolist()
            topk_list = topk_idx.tolist()
            # Gather this expert's tokens and run them through its FFN.
            expert_tokens = x[None, token_list].reshape(-1, hidden_size)
            mlp_output = self.mlp(expert_tokens, expert_idx)
            # Weight each output by its routing weight, then scatter-add the
            # results back into the owning token rows.
            expert_weights = expert_weights.to(mlp_output.device)
            expert_out = mlp_output * expert_weights[token_list, topk_list, None]
            out = out.to(mlp_output.device)
            token_idx = token_idx.to(mlp_output.device)
            out.index_add_(0, token_idx, expert_out)
        out = out.view(in_shape)
        return out
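
# A worked routing sketch (illustrative, not from the original file): with
# 3 tokens, top_k=1, and top_experts == tensor([[1], [0], [1]]), the one-hot
# mask has shape (num_experts, top_k, num_tokens); torch.where then yields
# expert 0 -> token 1 and expert 1 -> tokens 0 and 2, and index_add_ scatters
# each expert's weighted outputs back into the matching rows of `out`.
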
class dMoE(torch.nn.Module):

    def __init__(
        self,
        device: Optional[torch.device],
        hidden_size: int = 1024,
        ffn_hidden_size: int = 4096,
        moe_num_experts: int = 1,
        moe_top_k: int = 1,
        mlp_type: str = 'mlp',
        activation_fn: Callable = DEFAULT_ACTIVATION_FN,
        moe_jitter_eps: Optional[float] = None,
        moe_normalize_expert_weights: Optional[Union[int, float]] = None,
        uniform_expert_assignment: bool = False,
        bias: bool = True,
    ):
        super().__init__()
        # Token router (gating network).
        self.router = LearnedRouter(hidden_size, moe_num_experts=moe_num_experts, moe_top_k=moe_top_k, moe_jitter_eps=moe_jitter_eps, moe_normalize_expert_weights=moe_normalize_expert_weights, uniform_expert_assignment=uniform_expert_assignment, device=device)
        # Expert feed-forward networks.
        self.experts = DroplessMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size, mlp_type=mlp_type, moe_num_experts=moe_num_experts, activation_fn=activation_fn, bias=bias, device=device)

    def forward(self, x: torch.Tensor):
        scores, expert_weights, top_experts = self.router(x)
        return self.experts(x, scores, expert_weights, top_experts)
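
# A minimal end-to-end smoke test (a sketch added for illustration; the sizes
# below are arbitrary and not from the original file).
if __name__ == '__main__':
    moe = dMoE(device=None, hidden_size=16, ffn_hidden_size=32,
               moe_num_experts=4, moe_top_k=2, mlp_type='glu')
    x = torch.randn(2, 5, 16)
    y = moe(x)
    assert y.shape == x.shape  # MoE layers preserve the input shape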