Correct the output dtype of rmsnorm_func (#13)
- Correct the output dtype of rmsnorm_func (f2e665eb9ee6eae4abd08d7cb44fdf6422ee0c84)
Co-authored-by: ag0 <[email protected]>
- modeling_flash_llama.py +1 -1
modeling_flash_llama.py CHANGED

```diff
@@ -68,7 +68,7 @@ def rmsnorm_func(hidden_states, weight, variance_epsilon):
     hidden_states = hidden_states.to(torch.float32)
     variance = hidden_states.pow(2).mean(-1, keepdim=True)
     hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
-    return weight * hidden_states.to(input_dtype)
+    return (weight * hidden_states).to(input_dtype)
 
 
 class LlamaRMSNorm(nn.Module):
```
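Why the parenthesization matters: below is a minimal sketch (not part of the commit), assuming the norm weight may be kept in float32 while the activations are bfloat16. Under PyTorch type promotion, the old expression multiplies a float32 weight by a bfloat16 tensor and therefore returns float32; the corrected version casts the full product back to the input dtype.

```python
import torch

# Hypothetical shapes and epsilon, chosen only for illustration.
hidden_states = torch.randn(2, 8, dtype=torch.bfloat16)
weight = torch.ones(8, dtype=torch.float32)  # assumes the weight is held in fp32
input_dtype = hidden_states.dtype

# Same normalization steps as rmsnorm_func: compute in float32.
hs32 = hidden_states.to(torch.float32)
variance = hs32.pow(2).mean(-1, keepdim=True)
hs32 = hs32 * torch.rsqrt(variance + 1e-6)

old = weight * hs32.to(input_dtype)    # promoted to float32 by the fp32 weight
new = (weight * hs32).to(input_dtype)  # cast last: stays bfloat16

print(old.dtype, new.dtype)  # torch.float32 torch.bfloat16
```

As a side effect, casting after the multiplication also applies the weight while the normalized activations are still in float32, so the only low-precision step is the final cast.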