Commit 695bee3 · Update modeling_phi.py
Parent(s): 6ac85d6

modeling_phi.py · CHANGED · +6 -1
@@ -355,8 +355,10 @@ class SelfAttention(nn.Module):
         key_padding_mask: Optional[torch.BoolTensor] = None,
         **kwargs,
     ) -> torch.FloatTensor:
+        print(qkv.shape)
         batch_size, seqlen = qkv.shape[0], qkv.shape[1]
         q, k, v = qkv.unbind(dim=2)
+        print(q.shape, k.shape, v.shape)
 
         q = q.to(torch.float32)
         k = k.to(torch.float32)
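The additions in this commit are all shape-debugging prints. The two in this hunk report the packed-QKV layout; a minimal sketch of what they would emit, assuming the (batch_size, seqlen, 3, n_heads, head_dim) packing implied by `unbind(dim=2)` (sizes are hypothetical, not from the model):

import torch

# Hypothetical sizes, chosen only to illustrate the shapes the new prints report.
batch_size, seqlen, n_heads, head_dim = 2, 16, 4, 32
qkv = torch.randn(batch_size, seqlen, 3, n_heads, head_dim)

print(qkv.shape)                  # torch.Size([2, 16, 3, 4, 32])
q, k, v = qkv.unbind(dim=2)       # split the packed q/k/v dimension
print(q.shape, k.shape, v.shape)  # each torch.Size([2, 16, 4, 32])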
@@ -367,6 +369,7 @@ class SelfAttention(nn.Module):
         # Autocast is manually disabled to avoid `torch.einsum` performing the operation
         # using float16, which might lead to overflow
         scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
+        print(scores.shape)
 
         if key_padding_mask is not None:
             padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device)
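The comment retained in this hunk explains the float32 upcast of q and k: float16 saturates at 65504, so attention logits computed in half precision can overflow. A standalone illustration (values arbitrary, not from the commit):

import torch

# float16 cannot represent values above 65504, so this dot product
# overflows to inf; the same computation in float32 stays finite.
x = torch.full((4096,), 4.0, dtype=torch.float16)
print((x * x).sum())                  # tensor(inf, dtype=torch.float16)
print((x.float() * x.float()).sum())  # tensor(65536.)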
@@ -376,13 +379,15 @@ class SelfAttention(nn.Module):
 
         if causal:
             causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
+            print(causal_mask.shape)
             scores = scores + causal_mask.to(dtype=scores.dtype)
 
         attention = torch.softmax(scores, dim=-1).to(v.dtype)
         attention = self.drop(attention)
 
         output = torch.einsum("bhts,bshd->bthd", attention, v)
-
+        print(output.shape)
+
         return output
 
 
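For reference, the mask built by the `torch.triu(..., 1)` line above is zero on and below the diagonal and -10000.0 strictly above it, so the subsequent softmax effectively zeroes out attention to future positions. For a toy seqlen of 4:

import torch

causal_mask = torch.triu(torch.full((4, 4), -10000.0), 1)
print(causal_mask.shape)  # torch.Size([4, 4]) -- the shape the new print reports
print(causal_mask)
# tensor([[     0., -10000., -10000., -10000.],
#         [     0.,      0., -10000., -10000.],
#         [     0.,      0.,      0., -10000.],
#         [     0.,      0.,      0.,      0.]])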
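Taken together, under the hypothetical sizes used above (batch_size=2, seqlen=16, n_heads=4, head_dim=32), the five new prints would report:

# qkv:         torch.Size([2, 16, 3, 4, 32])
# q, k, v:     torch.Size([2, 16, 4, 32]) each
# scores:      torch.Size([2, 4, 16, 16])   -- "bthd,bshd->bhts": (batch, heads, query_pos, key_pos)
# causal_mask: torch.Size([16, 16])
# output:      torch.Size([2, 16, 4, 32])   -- "bhts,bshd->bthd": back to (batch, seq, heads, head_dim)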