|
from functools import partial |
|
from typing import Callable, Optional, Union |
|
import torch |
|
import torch.nn.functional as F |
|
DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') |
|
|
|
class _UniformExpertAssignment(torch.autograd.Function): |
|
|
|
@staticmethod |
|
def forward(ctx, x: torch.Tensor, num_experts: int): |
|
out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) |
|
out = torch.remainder(out, num_experts) |
|
return out.view(x.shape) |
|
|
|
class LearnedRouter(torch.nn.Module):
    """Learned linear router for a mixture-of-experts layer.

    Scores every token against every expert with a bias-free linear layer,
    softmaxes the scores, and keeps the top-k experts per token. Optional
    multiplicative jitter is applied to the input during training, and the
    selected expert weights can be p-norm normalized.
    """

    def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int, moe_jitter_eps: Optional[float], moe_normalize_expert_weights: Optional[Union[int, float]], uniform_expert_assignment: bool, device: Optional[torch.device]) -> None:
        super().__init__()
        self.hidden_size: int = hidden_size
        self.moe_num_experts: int = moe_num_experts
        self.moe_top_k: int = moe_top_k
        self.moe_jitter_eps: Optional[float] = moe_jitter_eps
        self.moe_normalize_expert_weights: Optional[Union[int, float]] = moe_normalize_expert_weights
        self.uniform_expert_assignment: bool = uniform_expert_assignment
        # Bias-free projection from hidden states to per-expert logits.
        self.layer: torch.nn.Module = torch.nn.Linear(hidden_size, moe_num_experts, bias=False, device=device)

    def jitter(self, x: torch.Tensor) -> torch.Tensor:
        """Return multiplicative noise drawn uniformly from [1-eps, 1+eps)."""
        assert self.moe_jitter_eps is not None
        low: float = 1.0 - self.moe_jitter_eps
        high: float = 1.0 + self.moe_jitter_eps
        uniform = torch.rand(x.size(), dtype=x.dtype, device=x.device)
        return low + uniform * (high - low)

    def _top_k(self, scores: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the (values, indices) of the top-k experts along the last dim."""
        if self.moe_top_k == 1:
            # max() is cheaper than topk(k=1); add back the k dimension.
            values, indices = scores.max(dim=-1)
            return values[..., None], indices[..., None]
        return torch.topk(scores, self.moe_top_k, dim=-1)

    def forward(self, x: torch.Tensor):
        """Route tokens; returns (scores, expert_weights, top_experts).

        ``scores`` is (tokens, num_experts); ``expert_weights`` and
        ``top_experts`` are (tokens, top_k).
        """
        if self.training and self.moe_jitter_eps is not None:
            x = x * self.jitter(x)
        flattened = x.view(-1, x.shape[-1])
        scores = torch.softmax(self.layer(flattened), dim=-1)
        expert_weights, top_experts = self._top_k(scores)
        if self.moe_normalize_expert_weights:
            norm = torch.norm(expert_weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True)
            expert_weights = expert_weights / norm
        if self.uniform_expert_assignment:
            top_experts = _UniformExpertAssignment.apply(top_experts, self.moe_num_experts)
        return scores.to(x.dtype), expert_weights.to(x.dtype), top_experts
|
|
|
class MLP(torch.nn.Module):
    """Per-expert two-layer feed-forward network for the MoE layer.

    All experts' weights are stacked into two parameters of shape
    (moe_num_experts * ffn_hidden_size, hidden_size); ``forward`` slices out
    one expert's weights by index. The expert computes
    ``activation_fn(x @ W1.T) @ W2`` — there are no bias terms.
    """

    def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, activation_fn: Callable, device: Optional[torch.device]) -> None:
        super().__init__()
        self.moe_num_experts: int = moe_num_experts
        self.ffn_hidden_size: int = ffn_hidden_size
        self.hidden_size: int = hidden_size
        # Fix: the original assigned self.activation_fn twice; once is enough.
        self.activation_fn: Callable = activation_fn
        # NOTE(review): torch.rand gives uniform-[0, 1) init, which is unusual
        # for linear weights — presumably callers overwrite these with a real
        # init or checkpoint; confirm.
        self.w1 = torch.nn.Parameter(torch.rand(moe_num_experts * ffn_hidden_size, hidden_size, device=device))
        self.w2 = torch.nn.Parameter(torch.rand(moe_num_experts * ffn_hidden_size, hidden_size, device=device))

    def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
        """Apply expert ``expert_idx`` to ``x`` of shape (tokens, hidden_size).

        Returns a tensor of the same shape as ``x``.
        """
        # Slice this expert's (ffn_hidden_size, hidden_size) weight blocks.
        expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
        expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
        before_activation = x @ expert_w1.t()
        layer_1_output = self.activation_fn(before_activation)
        # w2 is stored transposed relative to w1, so no .t() here.
        output = layer_1_output @ expert_w2
        return output
|
|
|
class GLU(torch.nn.Module):
    """Per-expert gated linear unit (GLU-style FFN) for the MoE layer.

    Stores the stacked weights of every expert in three parameters of shape
    (moe_num_experts * ffn_hidden_size, hidden_size). An expert computes
    ``(activation_fn(x @ W1.T) * (x @ V1.T)) @ W2`` with no bias terms.
    """

    def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, activation_fn: Callable, device: Optional[torch.device]):
        super().__init__()
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.moe_num_experts = moe_num_experts
        self.w1 = torch.nn.Parameter(torch.rand(moe_num_experts * ffn_hidden_size, hidden_size, device=device))
        self.v1 = torch.nn.Parameter(torch.rand(moe_num_experts * ffn_hidden_size, hidden_size, device=device))
        self.w2 = torch.nn.Parameter(torch.rand(moe_num_experts * ffn_hidden_size, hidden_size, device=device))
        self.activation_fn = activation_fn

    def forward(self, x: torch.Tensor, expert_idx: torch.Tensor):
        """Apply expert ``expert_idx``'s GLU to ``x`` of shape (tokens, hidden_size)."""
        stacked_shape = (self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)
        weight_1 = self.w1.view(*stacked_shape)[expert_idx]
        gate_1 = self.v1.view(*stacked_shape)[expert_idx]
        weight_2 = self.w2.view(*stacked_shape)[expert_idx]
        # Gated activation: activation(x W1^T) elementwise-scaled by x V1^T.
        activated = self.activation_fn(x.matmul(weight_1.t()))
        gated = activated * x.matmul(gate_1.t())
        # w2 is stored transposed relative to w1/v1, so no .t() here.
        return gated.matmul(weight_2)
|
|
|
class DroplessMLP(torch.nn.Module):
    """Dispatch tokens to expert FFNs without dropping any token.

    Args:
        hidden_size: model hidden dimension.
        ffn_hidden_size: per-expert FFN hidden dimension.
        mlp_type: ``'mlp'`` or ``'glu'`` — which expert architecture to build.
        moe_num_experts: number of experts.
        activation_fn: activation used inside each expert.
        bias: accepted for interface compatibility but currently unused —
            the expert layers here are bias-free.
        device: device for the expert parameters.

    Raises:
        ValueError: if ``mlp_type`` is not ``'mlp'`` or ``'glu'``.
    """

    def __init__(self, hidden_size: int, ffn_hidden_size: int, mlp_type: str, moe_num_experts: int, activation_fn: Callable, bias: bool, device: Optional[torch.device]):
        super().__init__()
        self.moe_num_experts = moe_num_experts
        if mlp_type == 'mlp':
            self.mlp = MLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size, moe_num_experts=moe_num_experts, activation_fn=activation_fn, device=device)
        elif mlp_type == 'glu':
            self.mlp = GLU(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size, moe_num_experts=moe_num_experts, activation_fn=activation_fn, device=device)
        else:
            raise ValueError(f'Received unknown mlp_type={mlp_type!r}')

    def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
        """Run every token through its selected experts and combine the outputs.

        ``expert_weights`` and ``top_experts`` are (tokens, top_k) as produced
        by the router; ``scores`` is accepted for interface parity with the
        router's output but is not used here. Returns a tensor shaped like ``x``.
        """
        in_shape = x.shape
        hidden_size = in_shape[-1]
        x = x.view(-1, hidden_size)
        out = torch.zeros_like(x)
        # (num_experts, top_k, tokens) mask of which (slot, token) pairs each
        # expert is responsible for.
        expert_mask = torch.nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
        for expert_idx in range(self.moe_num_experts):
            topk_idx, token_idx = torch.where(expert_mask[expert_idx])
            if token_idx.numel() == 0:
                continue
            # Fix: index with the tensors directly. The original converted the
            # indices to Python lists (.tolist()), forcing a device->host sync
            # per expert, and gathered rows via x[None, list].reshape(...)
            # where plain row indexing suffices.
            mlp_output = self.mlp(x[token_idx], expert_idx)
            # Keep the original's defensive device moves (no-ops when the
            # input and expert weights already share a device).
            expert_weights = expert_weights.to(mlp_output.device)
            out = out.to(mlp_output.device)
            token_idx = token_idx.to(mlp_output.device)
            topk_idx = topk_idx.to(mlp_output.device)
            # Scale each expert output by its routing weight and accumulate.
            expert_out = mlp_output * expert_weights[token_idx, topk_idx, None]
            out.index_add_(0, token_idx, expert_out)
        return out.view(in_shape)
|
|
|
class dMoE(torch.nn.Module):
    """Dropless mixture-of-experts layer: a learned router feeding expert FFNs."""

    def __init__(self, device: Optional[torch.device], hidden_size: int=1024, ffn_hidden_size: int=4096, moe_num_experts: int=1, moe_top_k: int=1, mlp_type: str='mlp', activation_fn: Callable=DEFAULT_ACTIVATION_FN, moe_jitter_eps: Optional[float]=None, moe_normalize_expert_weights: Optional[Union[int, float]]=None, uniform_expert_assignment: bool=False, bias: bool=True):
        super().__init__()
        # Token-to-expert assignment.
        self.router = LearnedRouter(
            hidden_size,
            moe_num_experts=moe_num_experts,
            moe_top_k=moe_top_k,
            moe_jitter_eps=moe_jitter_eps,
            moe_normalize_expert_weights=moe_normalize_expert_weights,
            uniform_expert_assignment=uniform_expert_assignment,
            device=device,
        )
        # Expert feed-forward networks; "dropless" — no token is discarded.
        self.experts = DroplessMLP(
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            mlp_type=mlp_type,
            moe_num_experts=moe_num_experts,
            activation_fn=activation_fn,
            bias=bias,
            device=device,
        )

    def forward(self, x: torch.Tensor):
        """Route ``x`` and return the weighted combination of expert outputs."""
        scores, expert_weights, top_experts = self.router(x)
        return self.experts(x, scores, expert_weights, top_experts)