svakhreev committed
Commit 1814817
Parent: 7485d16

Update modeling_gpt_refact.py

Files changed (1):
  1. modeling_gpt_refact.py +92 -76
modeling_gpt_refact.py CHANGED
@@ -21,29 +21,23 @@ logger = logging.get_logger(__name__)
 
 @torch.jit.script
 def upcast_masked_softmax(
-    x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
+    x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, softmax_dtype: torch.dtype
 ):
     input_dtype = x.dtype
-    x = x.to(softmax_dtype) * scale
+    x = x.to(softmax_dtype)
     x = torch.where(mask, x, mask_value)
     x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
     return x
 
 
 @torch.jit.script
-def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype):
+def upcast_softmax(x: torch.Tensor, softmax_dtype: torch.dtype):
     input_dtype = x.dtype
-    x = x.to(softmax_dtype) * scale
+    x = x.to(softmax_dtype)
     x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
     return x
 
 
-@torch.jit.script
-def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
-    x = torch.where(mask, x, mask_value)
-    x = torch.nn.functional.softmax(x, dim=-1)
-    return x
-
 @torch.jit.script
 def _get_slopes(attn_heads: int, dev: torch.device) -> torch.Tensor:
     """
@@ -76,7 +70,6 @@ def _get_slopes(attn_heads: int, dev: torch.device) -> torch.Tensor:
     m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (attn_heads - n), 2, device=dev))
     # Concatenate the slopes with the remaining slopes.
     m = torch.cat([m, m_hat])
-
     return m
 
 @torch.jit.script
@@ -85,8 +78,7 @@ def get_alibi_biases(
         T: int,
         attn_heads: int,
         dev: torch.device,
-        dtype: torch.dtype,
-        causal: bool = True) -> torch.Tensor:
+        dtype: torch.dtype) -> torch.Tensor:
     """
     ## Calculate the attention biases matrix
     * `n_heads` is the number of heads in the attention layer
@@ -95,28 +87,25 @@ def get_alibi_biases(
     """
 
     # Get slopes $m$ for each head
-    if causal:
-        mask = (torch.triu(torch.ones((T, T), device=dev)) == 1).transpose(0, 1)
-    else:
-        mask = torch.ones((T, T), device=dev, dtype=torch.bool)
+    mask = torch.ones((T, T), device=dev, dtype=torch.bool)
 
-    m = _get_slopes(attn_heads, dev)
+    m = _get_slopes(attn_heads, dev).to(dtype)
 
     # Calculate distances $[0, 1, \dots, N]$
     # Here we calculate the distances using the mask.
    #
     # Since it's causal mask we can just use $[0, 1, \dots, N]$ too.
     # `distance = torch.arange(mask.shape[1], dtype=torch.long, device=mask.device)[None, :]`
-    distance = mask.cumsum(dim=-1)
+    distance = mask.cumsum(dim=-1).to(dtype)
 
     # Multiply them pair-wise to get the AliBi bias matrix
     biases = distance[:, :, None] * m[None, None, :]
     biases = biases.permute(2, 0, 1)[None, :, :T, :T]
-    biases = biases.repeat(B, 1, 1, 1)
-    return biases.to(dtype).contiguous()
+    return biases.contiguous()
 
 
 class Attention(nn.Module):
+
     def __init__(self, config, layer_idx=None):
         super().__init__()
         self.mask_value = None
@@ -126,7 +115,7 @@ class Attention(nn.Module):
         self.head_dim = self.embed_dim // self.num_heads
         self.kv_attn_heads = 1
 
-        self.scale = self.head_dim ** -0.5
+        self.scale_factor = self.head_dim ** -0.5
 
         if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(
@@ -139,41 +128,63 @@ class Attention(nn.Module):
         self.scale_attention_softmax_in_fp32 = (
             config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
         )
+        self.attention_bias_in_fp32 = config.attention_bias_in_fp32
 
         self.q = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
-        self.k = nn.Linear(self.embed_dim, self.head_dim, bias=False)
-        self.v = nn.Linear(self.embed_dim, self.head_dim, bias=False)
+        self.kv = nn.Linear(self.embed_dim, self.head_dim * 2, bias=False)
         self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
 
+    def _get_mask_value(self, device, dtype):
+        # torch.where expects a tensor. We use a cache to avoid recreating it every time.
+        if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
+            self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
+        return self.mask_value
+
     def _attn(self, query, key, value, attention_mask=None, alibi=None):
         dtype = query.dtype
         softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
+        mask_value = self._get_mask_value(query.device, softmax_dtype)
         upcast = dtype != softmax_dtype
-        unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
 
-        attn_weights = alibi + torch.matmul(query * self.scale, key)
+        query_shape = query.shape
+        batch_size = query_shape[0]
+        key_length = key.size(-1)
+
+        # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length)
+        # -> (batch_size, query_length, num_heads, key_length)
+        query_length = query_shape[1]
+        attn_shape = (batch_size, query_length, self.num_heads, key_length)
+        attn_view = (batch_size, query_length * self.num_heads, key_length)
+        # No copy needed for MQA 2, or when layer_past is provided.
+        query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim)
+
+        alibi = alibi.transpose(2, 1).reshape(alibi.shape[0], -1, alibi.shape[-1])
+        initial_dtype = query.dtype
+        new_dtype = torch.float32 if self.attention_bias_in_fp32 else initial_dtype
+        attn_weights = alibi.baddbmm(
+            batch1=query.to(new_dtype),
+            batch2=key.to(new_dtype),
+            beta=1,
+            alpha=self.scale_factor
+        ).view(attn_shape).to(initial_dtype)
 
         if upcast:
+            # Use a fused kernel to prevent a large overhead from casting and scaling.
+            # Sub-optimal when the key length is not a multiple of 8.
            if attention_mask is None:
-                attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype)
+                attn_weights = upcast_softmax(attn_weights, softmax_dtype)
            else:
-                mask_value = self._get_mask_value(attn_weights.device, softmax_dtype)
-                attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype)
         else:
            if attention_mask is not None:
-                attn_weights = torch.masked_fill(attn_weights, attention_mask, -10000)
-
+                # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion.
+                attn_weights = torch.where(attention_mask, attn_weights, mask_value)
            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
 
-        attn_output = torch.matmul(attn_weights, value)
+        attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape)
 
         return attn_output, attn_weights
 
-    def _split_heads(self, tensor):
-        new_shape = tensor.shape[:-1] + (self.num_heads, self.head_dim)
-        tensor = tensor.view(new_shape)
-        return tensor.permute(0, 2, 1, 3)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -186,13 +197,9 @@ class Attention(nn.Module):
         Tuple[torch.Tensor, Optional[torch.Tensor]],
         Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
     ]:
-        b, t, _ = hidden_states.shape
         query = self.q(hidden_states)
-        key = self.k(hidden_states)
-        value = self.v(hidden_states)
-        query = self._split_heads(query)
-        key = key.view(b, t, self.kv_attn_heads, self.head_dim).permute(0, 2, 1, 3)
-        value = value.view(b, t, self.kv_attn_heads, self.head_dim).permute(0, 2, 1, 3)
+        kv = self.kv(hidden_states)
+        key, value = kv.split(self.head_dim, dim=-1)
 
         if layer_past is not None:
             past_key, past_value = layer_past
@@ -205,32 +212,31 @@ class Attention(nn.Module):
             present = None
 
         attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, alibi)
-
-        attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
         attn_output = self.c_proj(attn_output)
 
         outputs = (attn_output, present)
         if output_attentions:
+            attn_weights = attn_weights.transpose(1, 2)
            outputs += (attn_weights,)
 
         return outputs  # a, present, (attentions)
 
 
 class MLP(nn.Module):
+
     def __init__(self, intermediate_size, config, multiple_of: int = 256):
         super().__init__()
         embed_dim = config.hidden_size
         hidden_dim = intermediate_size
         hidden_dim = int(2 * hidden_dim / 3)
-        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
-        self.linear_1 = nn.Linear(embed_dim, hidden_dim, bias=False)
-        self.linear_3 = nn.Linear(embed_dim, hidden_dim, bias=False)
-        self.c_proj = nn.Linear(hidden_dim, embed_dim, bias=False)
-
-    def forward(self, x: Optional[Tuple[torch.Tensor]]) -> torch.Tensor:
-        x1 = F.silu(self.linear_1(x))
-        x2 = self.linear_3(x)
-        x = self.c_proj(x1 * x2)
+        self.hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+        self.gate_up_proj = nn.Linear(embed_dim, self.hidden_dim * 2, bias=False)
+        self.c_proj = nn.Linear(self.hidden_dim, embed_dim, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        up_proj = self.gate_up_proj(x)
+        x1, x2 = torch.split(up_proj, self.hidden_dim, dim=-1)
+        x = self.c_proj(F.silu(x1) * x2)
         return x
 
 
@@ -255,7 +261,6 @@ class GPTRefactBlock(nn.Module):
         self.ln_1 = LayerNormNoBias(hidden_size, eps=config.layer_norm_epsilon)
         self.attn = Attention(config, layer_idx=layer_idx)
         self.ln_2 = LayerNormNoBias(hidden_size, eps=config.layer_norm_epsilon)
-
         self.mlp = MLP(self.inner_dim, config)
 
     def forward(
@@ -297,6 +302,7 @@ class GPTRefactBlock(nn.Module):
 
 
 class GPTRefactPreTrainedModel(PreTrainedModel):
+
     config_class = GPTRefactConfig
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
@@ -331,12 +337,9 @@ class GPTRefactPreTrainedModel(PreTrainedModel):
         elif isinstance(module, LayerNormNoBias):
             module.weight.data.fill_(1.0)
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, GPTRefactModel):
-            module.gradient_checkpointing = value
-
 
 class GPTRefactModel(GPTRefactPreTrainedModel):
+
     def __init__(self, config):
         super().__init__(config)
         self.embed_dim = config.hidden_size
@@ -347,6 +350,7 @@ class GPTRefactModel(GPTRefactPreTrainedModel):
         self.h = nn.ModuleList([GPTRefactBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
 
         self.max_positions = config.max_position_embeddings
+        self.attention_bias_in_fp32 = config.attention_bias_in_fp32
         self.register_buffer(
             "bias", torch.tril(torch.ones((self.max_positions, self.max_positions), dtype=torch.bool)),
             persistent=False
@@ -357,15 +361,8 @@ class GPTRefactModel(GPTRefactPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
-    @staticmethod
-    def _make_mask(seq_len: int, past_key_values_length: int):
-        # prompt
-        if past_key_values_length == 0:
-            mask = torch.ones((seq_len, seq_len + past_key_values_length), dtype=torch.bool)
-            mask = torch.triu(mask, 1)
-        else:
-            mask = torch.zeros((seq_len, seq_len + past_key_values_length), dtype=torch.bool)
-        return mask
+    def get_input_embeddings(self):
+        return self.wte
 
     def forward(
         self,
@@ -408,19 +405,25 @@ class GPTRefactModel(GPTRefactPreTrainedModel):
         else:
             past_length = past_key_values[0][0].size(-2)
 
-        # Self-attention mask.
         query_length = input_shape[-1]
-
         seq_length_with_past = past_length + query_length
-        if attention_mask is None:
-            attention_mask = self._make_mask(query_length, past_length).to(device)
-        else:
-            attention_mask = attention_mask.to(device)
+
+        # Self-attention mask.
+        key_length = past_length + query_length
+        self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]
+        if attention_mask is not None:
+            self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to(
+                dtype=torch.bool, device=self_attention_mask.device
+            )
+
+        # MQA models: (batch_size, query_length, n_heads, key_length)
+        attention_mask = self_attention_mask.unsqueeze(2)
 
         hidden_states = self.wte(input_ids) if inputs_embeds is None else inputs_embeds
 
+        alibi_dtype = torch.float32 if self.attention_bias_in_fp32 else self.wte.weight.dtype
         alibi = get_alibi_biases(hidden_states.shape[0], seq_length_with_past,
-                                 self.num_heads, device, self.wte.weight.dtype)[:, :, -query_length:, :]
+                                 self.num_heads, device, alibi_dtype)[:, :, -query_length:, :]
 
         output_shape = input_shape + (hidden_states.size(-1),)
 
@@ -489,6 +492,7 @@ class GPTRefactModel(GPTRefactPreTrainedModel):
 
 
 class GPTRefactForCausalLM(GPTRefactPreTrainedModel):
+
     _tied_weights_keys = ["lm_head.weight", "ln_f.weight"]
 
     def __init__(self, config):
@@ -499,6 +503,18 @@ class GPTRefactForCausalLM(GPTRefactPreTrainedModel):
 
         # Initialize weights and apply final processing
         self.post_init()
+
+        # gradient checkpointing support for lower versions of transformers
+        import transformers
+        from packaging import version
+
+        def _set_gradient_checkpointing(module, value=False):
+            if isinstance(module, GPTRefactModel):
+                module.gradient_checkpointing = value
+
+        v = version.parse(transformers.__version__)
+        if v.major <= 4 and v.minor < 35:
+            self._set_gradient_checkpointing = _set_gradient_checkpointing
 
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
         if inputs_embeds is not None and past_key_values is None:
@@ -583,4 +599,4 @@ class GPTRefactForCausalLM(GPTRefactPreTrainedModel):
         [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
         beam_idx at every generation step.
         """
-        return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)
+        return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)
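Note for anyone carrying over weights saved with the previous layout (separate attn.k / attn.v and mlp.linear_1 / mlp.linear_3 matrices): the diff above replaces them with fused attn.kv and mlp.gate_up_proj projections. Below is a minimal conversion sketch, not part of this commit; the helper name is hypothetical, and it assumes all affected layers are bias-free (bias=False, as in this file), so only weight matrices need to be concatenated.

import torch

def fuse_old_state_dict(old_sd: dict) -> dict:
    """Hypothetical helper: fold separate k/v and linear_1/linear_3 weights into kv and gate_up_proj."""
    new_sd = {}
    for name, tensor in old_sd.items():
        if name.endswith(".attn.k.weight"):
            prefix = name[: -len("k.weight")]
            # kv.split(head_dim, dim=-1) returns (key, value), so the key rows go first.
            new_sd[prefix + "kv.weight"] = torch.cat([tensor, old_sd[prefix + "v.weight"]], dim=0)
        elif name.endswith(".mlp.linear_1.weight"):
            prefix = name[: -len("linear_1.weight")]
            # MLP.forward() splits gate_up_proj into (x1, x2): x1 is the SiLU gate (old linear_1),
            # x2 is the up projection (old linear_3), so linear_1 rows go first.
            new_sd[prefix + "gate_up_proj.weight"] = torch.cat([tensor, old_sd[prefix + "linear_3.weight"]], dim=0)
        elif name.endswith(".attn.v.weight") or name.endswith(".mlp.linear_3.weight"):
            continue  # already folded into the fused weight above
        else:
            new_sd[name] = tensor
    return new_sd

The concatenation order matters because kv.split(self.head_dim, dim=-1) yields (key, value) and the MLP splits gate_up_proj into (gate, up); reversing either order would silently scramble the projections.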