Update modeling_gpt_refact.py
modeling_gpt_refact.py  CHANGED  +1 -1
@@ -151,7 +151,7 @@ class Attention(nn.Module):
         upcast = dtype != softmax_dtype
         unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
 
-        attn_weights = alibi + torch.matmul(query * self.scale, key)
+        attn_weights = (alibi + torch.matmul(query * self.scale, key)).to(query.dtype)
 
         if upcast:
             if attention_mask is None:
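For context, a minimal sketch of the dtype behaviour this one-line change targets. The tensor shapes and dtypes below are illustrative assumptions, not taken from the model; "scores" stands in for torch.matmul(query * self.scale, key).

import torch

# Minimal sketch of the dtype mismatch the cast addresses (toy tensors;
# shapes and dtypes here are assumptions, not the model's real configuration).
query_dtype = torch.float16
alibi = torch.zeros(4, 4, dtype=torch.float32)    # positional bias kept in fp32 (assumed)
scores = torch.zeros(4, 4, dtype=query_dtype)     # stand-in for the matmul output

# Without the cast, PyTorch type promotion makes the sum float32.
promoted = alibi + scores
print(promoted.dtype)                             # torch.float32

# With the cast added in this commit, attn_weights stays in query's dtype,
# so the later `if upcast:` branch sees the dtype it expects.
attn_weights = (alibi + scores).to(query_dtype)
print(attn_weights.dtype)                         # torch.float16

If alibi and the matmul result already share query's dtype, the added .to(query.dtype) is a no-op, so the change only matters when the bias is held in a higher precision than the attention computation.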