zR commited on
Commit
ade85af
·
1 Parent(s): 1127073

fix device problem

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +126 -1
modeling_chatglm.py CHANGED
@@ -332,6 +332,128 @@ class CoreAttention(torch.nn.Module):
332
 
333
  return context_layer
334
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
 
336
  class SelfAttention(torch.nn.Module):
337
  """Parallel self-attention layer abstract class.
@@ -697,6 +819,8 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
697
  config_class = ChatGLMConfig
698
  base_model_prefix = "transformer"
699
  _no_split_modules = ["GLMBlock"]
 
 
700
 
701
  def _init_weights(self, module: nn.Module):
702
  """Initialize the weights."""
@@ -868,7 +992,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
868
  self.config.eoi_token_id)
869
  assert eoi_token_pos - boi_token_pos == 2
870
  new_input_embeds.append(torch.cat(
871
- (inputs_embeds[i, :boi_token_pos], images_features[i], inputs_embeds[i, eoi_token_pos + 1:])))
 
872
  new_position_ids.append(torch.cat(
873
  (position_ids[i, :boi_token_pos + 1], position_ids[i, boi_token_pos + 1].repeat(num_patches),
874
  position_ids[i, eoi_token_pos:])
 
332
 
333
  return context_layer
334
 
335
+ class SdpaAttention(CoreAttention):
336
+ def forward(self, query_layer, key_layer, value_layer, attention_mask):
337
+ if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
338
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
339
+ is_causal=True,
340
+ dropout_p=self.config.attention_dropout if self.training else 0.0)
341
+ else:
342
+ if attention_mask is not None:
343
+ attention_mask = ~attention_mask
344
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
345
+ attention_mask,
346
+ dropout_p=self.config.attention_dropout if self.training else 0.0)
347
+ context_layer = context_layer.transpose(1, 2).contiguous()
348
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
349
+ context_layer = context_layer.reshape(*new_context_layer_shape)
350
+ return context_layer
351
+
352
+
353
+ def _get_unpad_data(attention_mask):
354
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
355
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
356
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
357
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
358
+ return (
359
+ indices,
360
+ cu_seqlens,
361
+ max_seqlen_in_batch,
362
+ )
363
+
364
+
365
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2
366
+ class FlashAttention2(CoreAttention):
367
+ def __init__(self, *args, **kwargs):
368
+ super().__init__(*args, **kwargs)
369
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
370
+
371
+ def forward(self, query_states, key_states, value_states, attention_mask):
372
+ query_states = query_states.transpose(1, 2)
373
+ key_states = key_states.transpose(1, 2)
374
+ value_states = value_states.transpose(1, 2)
375
+ batch_size, query_length = query_states.shape[:2]
376
+ if not self._flash_attn_uses_top_left_mask:
377
+ causal = self.is_causal
378
+ else:
379
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
380
+ causal = self.is_causal and query_length != 1
381
+ dropout = self.config.attention_dropout if self.training else 0.0
382
+ # Contains at least one padding token in the sequence
383
+ if attention_mask is not None:
384
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
385
+ query_states, key_states, value_states, attention_mask, query_length
386
+ )
387
+
388
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
389
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
390
+
391
+ attn_output_unpad = flash_attn_varlen_func(
392
+ query_states,
393
+ key_states,
394
+ value_states,
395
+ cu_seqlens_q=cu_seqlens_q,
396
+ cu_seqlens_k=cu_seqlens_k,
397
+ max_seqlen_q=max_seqlen_in_batch_q,
398
+ max_seqlen_k=max_seqlen_in_batch_k,
399
+ dropout_p=dropout,
400
+ softmax_scale=None,
401
+ causal=causal,
402
+ )
403
+
404
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
405
+ else:
406
+ attn_output = flash_attn_func(
407
+ query_states, key_states, value_states, dropout, softmax_scale=None, causal=causal
408
+ )
409
+ attn_output = attn_output.reshape(batch_size, query_length, self.hidden_size_per_partition).contiguous()
410
+ return attn_output
411
+
412
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
413
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
414
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
415
+
416
+ key_layer = index_first_axis(
417
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
418
+ )
419
+ value_layer = index_first_axis(
420
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
421
+ )
422
+ if query_length == kv_seq_len:
423
+ query_layer = index_first_axis(
424
+ query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim),
425
+ indices_k
426
+ )
427
+ cu_seqlens_q = cu_seqlens_k
428
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
429
+ indices_q = indices_k
430
+ elif query_length == 1:
431
+ max_seqlen_in_batch_q = 1
432
+ cu_seqlens_q = torch.arange(
433
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
434
+ ) # There is a memcpy here, that is very bad.
435
+ indices_q = cu_seqlens_q[:-1]
436
+ query_layer = query_layer.squeeze(1)
437
+ else:
438
+ # The -q_len: slice assumes left padding.
439
+ attention_mask = attention_mask[:, -query_length:]
440
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
441
+
442
+ return (
443
+ query_layer,
444
+ key_layer,
445
+ value_layer,
446
+ indices_q,
447
+ (cu_seqlens_q, cu_seqlens_k),
448
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
449
+ )
450
+
451
+
452
+ CORE_ATTENTION_CLASSES = {
453
+ "eager": CoreAttention,
454
+ "sdpa": SdpaAttention,
455
+ "flash_attention_2": FlashAttention2
456
+ }
457
 
458
  class SelfAttention(torch.nn.Module):
459
  """Parallel self-attention layer abstract class.
 
819
  config_class = ChatGLMConfig
820
  base_model_prefix = "transformer"
821
  _no_split_modules = ["GLMBlock"]
822
+ _supports_flash_attn_2 = True
823
+ _supports_sdpa = True
824
 
825
  def _init_weights(self, module: nn.Module):
826
  """Initialize the weights."""
 
992
  self.config.eoi_token_id)
993
  assert eoi_token_pos - boi_token_pos == 2
994
  new_input_embeds.append(torch.cat(
995
+ (inputs_embeds[i, :boi_token_pos], images_features[i].to(inputs_embeds.device),
996
+ inputs_embeds[i, eoi_token_pos + 1:])))
997
  new_position_ids.append(torch.cat(
998
  (position_ids[i, :boi_token_pos + 1], position_ids[i, boi_token_pos + 1].repeat(num_patches),
999
  position_ids[i, eoi_token_pos:])