Update modeling_gpt_refact.py
Browse files- modeling_gpt_refact.py +1 -1
modeling_gpt_refact.py
CHANGED
@@ -420,7 +420,7 @@ class GPTRefactModel(GPTRefactPreTrainedModel):
|
|
420 |
hidden_states = self.wte(input_ids) if inputs_embeds is None else inputs_embeds
|
421 |
|
422 |
alibi = get_alibi_biases(hidden_states.shape[0], seq_length_with_past,
|
423 |
-
self.num_heads, device,
|
424 |
|
425 |
output_shape = input_shape + (hidden_states.size(-1),)
|
426 |
|
|
|
420 |
hidden_states = self.wte(input_ids) if inputs_embeds is None else inputs_embeds
|
421 |
|
422 |
alibi = get_alibi_biases(hidden_states.shape[0], seq_length_with_past,
|
423 |
+
self.num_heads, device, torch.float32)[:, :, -query_length:, :]
|
424 |
|
425 |
output_shape = input_shape + (hidden_states.size(-1),)
|
426 |
|