davidlvxin committed on
Commit 1ed39ad (1 parent: 2990a09)

Optimize the storage of KV cache

Files changed (2)
  1. README.md +5 -0
  2. modeling_chatglm.py +21 -8
README.md CHANGED
@@ -15,8 +15,13 @@ tags:
 <p align="center">
     👋 Join our <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1y7pqoloy-9b1g6T6JjA8J0KxvUjbwJw" target="_blank">Slack</a> and <a href="https://github.com/THUDM/ChatGLM-6B/blob/main/resources/WECHAT.md" target="_blank">WeChat</a>
 </p>
 
+## Update
+
+- We have optimized how the KV cache is stored, reducing GPU memory fragmentation. With the optimized code, the model can process a context length of 32K within approximately **20 GB** of GPU memory (FP16/BF16).
+
 ## Introduction
+
 ChatGLM**2**-6B-32K builds on [ChatGLM2-6B](https://huggingface.co/THUDM/chatglm2-6b) and further strengthens the model's ability to understand long texts, so it can better handle contexts of up to 32K in length. Concretely, we update the position encoding with [Positional Interpolation](https://arxiv.org/abs/2306.15595) and train with a 32K context length during the dialogue stage. In practice, if your contexts are mostly **within 8K**, we recommend [ChatGLM2-6B](https://huggingface.co/THUDM/chatglm2-6b); if you need to handle contexts **longer than 8K**, we recommend ChatGLM2-6B-32K.
 
 ChatGLM**2**-6B-32K is the long-context version of the open-source bilingual (Chinese-English) chat model [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B). On top of the first-generation model's strengths, such as smooth dialogue and a low deployment barrier, ChatGLM**2**-6B-32K introduces the following new features:
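To make the memory note in the added Update section concrete, here is a minimal sketch, not part of the commit, contrasting the two storage layouts for a fully prefilled cache. The per-layer `[seq_len, batch, kv_heads, head_dim]` layout and the config values (28 layers, 2 multi-query KV groups, 128-dim heads) are assumptions based on ChatGLM2-6B; only the allocation pattern matters for the fragmentation argument.

```python
import torch

# Illustrative sizes; raise seq_len to 32768 for the full 32K case.
num_layers, seq_len, batch, kv_heads, head_dim = 28, 1024, 1, 2, 128

# Before this commit: one (key, value) tuple per layer, i.e. 2 * num_layers
# separate allocations, re-created by torch.cat at every decoding step.
cache_as_tuples = tuple(
    (torch.empty(seq_len, batch, kv_heads, head_dim, dtype=torch.float16),
     torch.empty(seq_len, batch, kv_heads, head_dim, dtype=torch.float16))
    for _ in range(num_layers)
)

# After this commit (prefill only): one contiguous tensor
# [num_layers, 2, seq_len, batch, kv_heads, head_dim], a single allocation
# that leaves the CUDA caching allocator much less room to fragment.
cache_as_tensor = torch.empty(num_layers, 2, seq_len, batch, kv_heads, head_dim,
                              dtype=torch.float16)

# Both layouts hold exactly the same number of elements.
assert cache_as_tensor.numel() == sum(k.numel() + v.numel() for k, v in cache_as_tuples)
```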
modeling_chatglm.py CHANGED
@@ -413,7 +413,10 @@ class SelfAttention(torch.nn.Module):
             key_layer = torch.cat((cache_k, key_layer), dim=0)
             value_layer = torch.cat((cache_v, value_layer), dim=0)
         if use_cache:
-            kv_cache = (key_layer, value_layer)
+            if kv_cache is None:
+                kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)), dim=1)
+            else:
+                kv_cache = (key_layer, value_layer)
         else:
             kv_cache = None
 
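A small sketch, not part of the commit, of what the new prefill branch produces; the `[seq_len, batch, num_heads, head_dim]` layout for `key_layer` and `value_layer` is an assumption about how ChatGLM2 shapes its attention tensors.

```python
import torch

# Illustrative sizes; the [seq_len, batch, num_heads, head_dim] layout is assumed.
seq_len, batch, kv_heads, head_dim = 16, 1, 2, 128
key_layer = torch.randn(seq_len, batch, kv_heads, head_dim)
value_layer = torch.randn(seq_len, batch, kv_heads, head_dim)

# Two unsqueezes give [1, 1, seq, batch, heads, dim]; concatenating on dim=1
# packs K and V into one tensor per layer: [1, 2, seq, batch, heads, dim].
packed = torch.cat(
    (key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)),
    dim=1,
)
assert packed.shape == (1, 2, seq_len, batch, kv_heads, head_dim)
```

The leading singleton dimension is what lets the transformer loop below concatenate the per-layer tensors along dim 0.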
 
@@ -612,12 +615,8 @@
         if not kv_caches:
             kv_caches = [None for _ in range(self.num_layers)]
         presents = () if use_cache else None
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-            use_cache = False
+        if self.training:
+            use_cache = False
 
         all_self_attentions = None
         all_hidden_states = () if output_hidden_states else None
@@ -645,7 +644,15 @@
                 )
             hidden_states, kv_cache = layer_ret
             if use_cache:
-                presents = presents + (kv_cache,)
+                # token by token decoding, use tuple format
+                if kv_caches[0] is not None:
+                    presents = presents + (kv_cache,)
+                # prefilling in decoding, use tensor format to save cuda memory
+                else:
+                    if len(presents) == 0:
+                        presents = kv_cache
+                    else:
+                        presents = torch.cat((presents, kv_cache), dim=0)
 
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
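A sketch, not part of the commit, of how this branch accumulates `presents` during prefill, assuming each layer returns the packed `[1, 2, seq, batch, heads, dim]` tensor from the SelfAttention hunk above; sizes are illustrative.

```python
import torch

num_layers, seq_len, batch, kv_heads, head_dim = 4, 16, 1, 2, 128  # illustrative

presents = ()  # starts as an empty tuple, as in the hunk above
for _ in range(num_layers):
    # stand-in for the packed per-layer cache returned by SelfAttention
    kv_cache = torch.randn(1, 2, seq_len, batch, kv_heads, head_dim)
    if len(presents) == 0:
        presents = kv_cache                                # first layer: keep the tensor
    else:
        presents = torch.cat((presents, kv_cache), dim=0)  # later layers: grow along dim 0

# The whole prefill cache ends up as one tensor instead of num_layers tuples.
assert presents.shape == (num_layers, 2, seq_len, batch, kv_heads, head_dim)
```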
@@ -830,6 +837,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
             kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
         )
+        if presents is not None and type(presents) is torch.Tensor:
+            presents = presents.split(1, dim=0)
+            presents = list(presents)
+            presents = [list(x.squeeze(0).split(1, dim=0)) for x in presents]
+            presents = [tuple([x.squeeze(0) for x in y]) for y in presents]
+            presents = tuple(presents)
 
         if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
 
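Finally, a sketch, not part of the commit, of the unpacking step added in the last hunk: the packed prefill cache is split back into per-layer `(key, value)` tuples, so the `past_key_values` that callers feed back in for token-by-token decoding keeps its previous format. Sizes are illustrative.

```python
import torch

num_layers, seq_len, batch, kv_heads, head_dim = 4, 16, 1, 2, 128  # illustrative
presents = torch.randn(num_layers, 2, seq_len, batch, kv_heads, head_dim)  # packed prefill cache

presents = presents.split(1, dim=0)                                  # num_layers tensors [1, 2, ...]
presents = list(presents)
presents = [list(x.squeeze(0).split(1, dim=0)) for x in presents]    # per layer: [key, value], each [1, ...]
presents = [tuple(x.squeeze(0) for x in y) for y in presents]        # per layer: (key, value)
presents = tuple(presents)

assert len(presents) == num_layers
assert presents[0][0].shape == (seq_len, batch, kv_heads, head_dim)  # back to the tuple format
```

This keeps the change confined to the prefill pass: only the first forward call builds the packed tensor, and everything downstream, including the tuple-format decoding branch in the transformer loop, is unchanged.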