zxdu20 committed
Commit cde457b
1 Parent(s): acd41f7

Fix attention score on mps

Files changed (1)
  1. modeling_chatglm.py +2 -4
modeling_chatglm.py CHANGED
@@ -280,10 +280,8 @@ def attention_fn(
     # [sk, b, np, hn] -> [sk, b * np, hn]
     key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
 
-    matmul_result = torch.empty(
-        output_size[0] * output_size[1],
-        output_size[2],
-        output_size[3],
+    matmul_result = torch.zeros(
+        1, 1, 1,
         dtype=query_layer.dtype,
         device=query_layer.device,
     )
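For context, the allocated buffer is only a placeholder: further down in attention_fn it is presumably passed as the input argument to torch.baddbmm with beta=0.0, so its contents should never affect the attention scores. On mps, however, the uninitialized values left by torch.empty (possibly NaN/Inf) could still leak into the result, hence the switch to a small, broadcastable all-zeros tensor. A minimal sketch of that pattern, with illustrative shapes rather than the actual ChatGLM dimensions:

import torch

# Hypothetical sizes: b*np batched heads, query/key lengths, head dim.
b_np, sq, sk, hn = 4, 8, 8, 16
query_layer = torch.randn(b_np, sq, hn)
key_layer = torch.randn(b_np, sk, hn)

# After the fix: a broadcastable all-zeros placeholder instead of an
# uninitialized (b*np, sq, sk) buffer from torch.empty.
matmul_result = torch.zeros(
    1, 1, 1,
    dtype=query_layer.dtype,
    device=query_layer.device,
)

# Raw attention scores: with beta=0.0 only alpha * (query @ key^T) remains,
# so the placeholder's values are irrelevant as long as they are finite.
attention_scores = torch.baddbmm(
    matmul_result,
    query_layer,                  # [b*np, sq, hn]
    key_layer.transpose(1, 2),    # [b*np, hn, sk]
    beta=0.0,
    alpha=1.0,
)
print(attention_scores.shape)    # torch.Size([4, 8, 8])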