from typing import Dict, List, Optional, Type, Union

import torch


def _cast_if_autocast_enabled(tensor: torch.Tensor) -> torch.Tensor:
    # If autocast is active, cast the tensor to the autocast dtype for its device
    # so the norm can later be computed in low precision with autocast disabled.
    if torch.is_autocast_enabled():
        if tensor.device.type == "cuda":
            dtype = torch.get_autocast_gpu_dtype()
        elif tensor.device.type == "cpu":
            dtype = torch.get_autocast_cpu_dtype()
        else:
            raise NotImplementedError()
        return tensor.to(dtype=dtype)
    return tensor

class LPLayerNorm(torch.nn.LayerNorm):
    """LayerNorm that runs in the autocast (low) precision instead of fp32."""

    def __init__(
        self,
        normalized_shape: Union[int, List[int], torch.Size],
        eps: float = 1e-05,
        elementwise_affine: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__(
            normalized_shape=normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
            device=device,
            dtype=dtype,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        module_device = x.device
        # Downcast the input and affine parameters to the autocast dtype, then
        # disable autocast so layer_norm runs directly in that precision.
        downcast_x = _cast_if_autocast_enabled(x)
        downcast_weight = (
            _cast_if_autocast_enabled(self.weight)
            if self.weight is not None
            else self.weight
        )
        downcast_bias = (
            _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
        )
        with torch.autocast(enabled=False, device_type=module_device.type):
            return torch.nn.functional.layer_norm(
                downcast_x,
                self.normalized_shape,
                downcast_weight,
                downcast_bias,
                self.eps,
            )

def rms_norm(
    x: torch.Tensor, weight: Optional[torch.Tensor] = None, eps: float = 1e-05
) -> torch.Tensor:
    # RMSNorm: x / sqrt(mean(x^2) + eps) over the last dimension,
    # optionally scaled by a learned elementwise weight.
    output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    if weight is not None:
        return output * weight
    return output

class RMSNorm(torch.nn.Module):
    def __init__(
        self,
        normalized_shape: Union[int, List[int], torch.Size],
        eps: float = 1e-05,
        weight: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.eps = eps
        if weight:
            self.weight = torch.nn.Parameter(
                torch.ones(normalized_shape, dtype=dtype, device=device)
            )
        else:
            self.register_parameter("weight", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute the norm in fp32 for stability, then cast back to the input dtype.
        return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)

class LPRMSNorm(RMSNorm):
    def __init__(
        self,
        normalized_shape: Union[int, List[int], torch.Size],
        eps: float = 1e-05,
        weight: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__(
            normalized_shape=normalized_shape,
            eps=eps,
            weight=weight,
            dtype=dtype,
            device=device,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Low-precision variant: compute the norm in the autocast dtype
        # (rather than fp32) with autocast disabled.
        downcast_x = _cast_if_autocast_enabled(x)
        downcast_weight = (
            _cast_if_autocast_enabled(self.weight)
            if self.weight is not None
            else self.weight
        )
        with torch.autocast(enabled=False, device_type=x.device.type):
            return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)

NORM_CLASS_REGISTRY: Dict[str, Type[torch.nn.Module]] = {
    "layernorm": torch.nn.LayerNorm,
    "low_precision_layernorm": LPLayerNorm,
    "rmsnorm": RMSNorm,
    "low_precision_rmsnorm": LPRMSNorm,
}
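

if __name__ == "__main__":
    # Minimal usage sketch: look up a norm class in NORM_CLASS_REGISTRY by name
    # and apply it to a dummy activation tensor on CPU. The chosen key and the
    # hidden size of 8 are arbitrary illustrative values, not requirements.
    norm_cls = NORM_CLASS_REGISTRY["low_precision_rmsnorm"]
    norm = norm_cls(normalized_shape=8)
    x = torch.randn(2, 4, 8)
    print(norm(x).shape)  # torch.Size([2, 4, 8])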