nathanrchn committed
Commit 695bee3 · 1 Parent(s): 6ac85d6

Update modeling_phi.py

Files changed (1)
  1. modeling_phi.py +6 -1
modeling_phi.py CHANGED
@@ -355,8 +355,10 @@ class SelfAttention(nn.Module):
         key_padding_mask: Optional[torch.BoolTensor] = None,
         **kwargs,
     ) -> torch.FloatTensor:
+        print(qkv.shape)
         batch_size, seqlen = qkv.shape[0], qkv.shape[1]
         q, k, v = qkv.unbind(dim=2)
+        print(q.shape, k.shape, v.shape)
 
         q = q.to(torch.float32)
         k = k.to(torch.float32)
@@ -367,6 +369,7 @@ class SelfAttention(nn.Module):
         # Autocast is manually disabled to avoid `torch.einsum` performing the operation
         # using float16, which might lead to overflow
         scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
+        print(scores.shape)
 
         if key_padding_mask is not None:
             padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device)
@@ -376,13 +379,15 @@ class SelfAttention(nn.Module):
 
         if causal:
             causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
+            print(causal_mask.shape)
             scores = scores + causal_mask.to(dtype=scores.dtype)
 
         attention = torch.softmax(scores, dim=-1).to(v.dtype)
         attention = self.drop(attention)
 
         output = torch.einsum("bhts,bshd->bthd", attention, v)
-
+        print(output.shape)
+
         return output
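
As a reference for reading the new debug output, the shape trace can be replayed standalone. The sketch below mirrors the math in the hunks above; the tensor sizes (batch_size=2, seqlen=16, n_heads=4, head_dim=8) and the softmax_scale value are hypothetical illustration choices, not values from the commit.

# Minimal standalone sketch of the shapes the added prints trace.
# Sizes here are hypothetical, not taken from the commit.
import math
import torch

batch_size, seqlen, n_heads, head_dim = 2, 16, 4, 8
softmax_scale = 1.0 / math.sqrt(head_dim)  # assumed scale, not shown in the diff

# qkv packs query, key and value along dim=2, as qkv.unbind(dim=2) implies.
qkv = torch.randn(batch_size, seqlen, 3, n_heads, head_dim)
print(qkv.shape)                  # torch.Size([2, 16, 3, 4, 8])

q, k, v = qkv.unbind(dim=2)
print(q.shape, k.shape, v.shape)  # each torch.Size([2, 16, 4, 8])

# Same contraction as the commit: (b, t, h, d) x (b, s, h, d) -> (b, h, t, s)
scores = torch.einsum("bthd,bshd->bhts", q.float(), k.float() * softmax_scale)
print(scores.shape)               # torch.Size([2, 4, 16, 16])

# The (seqlen, seqlen) mask broadcasts over the batch and head dims of scores.
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0), 1)
print(causal_mask.shape)          # torch.Size([16, 16])
scores = scores + causal_mask.to(dtype=scores.dtype)

attention = torch.softmax(scores, dim=-1).to(v.dtype)
output = torch.einsum("bhts,bshd->bthd", attention, v)
print(output.shape)               # torch.Size([2, 16, 4, 8])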