davda54 commited on
Commit
9b07ad6
·
1 Parent(s): 670918f

Update modeling_nort5.py

Browse files
Files changed (1) hide show
  1. modeling_nort5.py +9 -8
modeling_nort5.py CHANGED
@@ -62,7 +62,7 @@ class Decoder(nn.Module):
62
  self_relative_embedding = self.self_relative_embedding()
63
  cross_relative_embedding = self.cross_relative_embedding()
64
 
65
- if past_key_values is not None:
66
  autoreg_mask = torch.triu(
67
  torch.full((x.size(0), x.size(0)), True, device=x.device),
68
  diagonal=1
@@ -259,12 +259,12 @@ class Attention(nn.Module):
259
 
260
  if past_key_value is not None:
261
  if not self.is_cross_attention:
262
- key = torch.cat([past_key_value[0], key], dim=1)
263
- value = torch.cat([past_key_value[1], value], dim=1)
264
  key_len = key.size(1)
265
  elif past_key_value[0].size(1) == kv.size(0):
266
- key = past_key_value[0]
267
- value = past_key_value[1]
268
 
269
  if self.position_indices.size(0) < max(query_len, key_len):
270
  position_indices = torch.arange(max(query_len, key_len), dtype=torch.long).unsqueeze(1) \
@@ -306,7 +306,10 @@ class Attention(nn.Module):
306
  context = self.post_layer_norm(context)
307
  context = self.dropout(context)
308
 
309
- return context, attention_probs.detach(), (key.detach(), value.detach())
 
 
 
310
 
311
 
312
  class WordEmbedding(nn.Module):
@@ -662,9 +665,7 @@ class NorT5ForConditionalGeneration(NorT5Model):
662
  reordered_layer_past_states = ()
663
  for layer_past_state in layer_past_states:
664
  # need to set correct `past` for each of the four key / value states
665
- layer_past_state = layer_past_state.unflatten(0, (-1, self.config.num_attention_heads))
666
  layer_past_state = layer_past_state.index_select(0, beam_idx.to(layer_past_state.device))
667
- layer_past_state = layer_past_state.flatten(0, 1)
668
  reordered_layer_past_states = reordered_layer_past_states + (layer_past_state,)
669
 
670
  assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
 
62
  self_relative_embedding = self.self_relative_embedding()
63
  cross_relative_embedding = self.cross_relative_embedding()
64
 
65
+ if past_key_values is None:
66
  autoreg_mask = torch.triu(
67
  torch.full((x.size(0), x.size(0)), True, device=x.device),
68
  diagonal=1
 
259
 
260
  if past_key_value is not None:
261
  if not self.is_cross_attention:
262
+ key = torch.cat([past_key_value[0].flatten(0, 1), key], dim=1)
263
+ value = torch.cat([past_key_value[1].flatten(0, 1), value], dim=1)
264
  key_len = key.size(1)
265
  elif past_key_value[0].size(1) == kv.size(0):
266
+ key = past_key_value[0].flatten(0, 1)
267
+ value = past_key_value[1].flatten(0, 1)
268
 
269
  if self.position_indices.size(0) < max(query_len, key_len):
270
  position_indices = torch.arange(max(query_len, key_len), dtype=torch.long).unsqueeze(1) \
 
306
  context = self.post_layer_norm(context)
307
  context = self.dropout(context)
308
 
309
+ key = key.detach().unflatten(0, (-1, self.num_heads))
310
+ value = value.detach().unflatten(0, (-1, self.num_heads))
311
+
312
+ return context, attention_probs.detach(), (key, value)
313
 
314
 
315
  class WordEmbedding(nn.Module):
 
665
  reordered_layer_past_states = ()
666
  for layer_past_state in layer_past_states:
667
  # need to set correct `past` for each of the four key / value states
 
668
  layer_past_state = layer_past_state.index_select(0, beam_idx.to(layer_past_state.device))
 
669
  reordered_layer_past_states = reordered_layer_past_states + (layer_past_state,)
670
 
671
  assert reordered_layer_past_states[0].shape == layer_past_states[0].shape