Upload 5 files

- README.md +95 -0
- quantization.py +188 -0
- tokenization_chatglm.py +235 -0
- tokenizer.model +3 -0
- tokenizer_config.json +12 -0
README.md
ADDED
@@ -0,0 +1,95 @@
---
language:
- zh
- en
tags:
- glm
- chatglm
- thudm
---
# ChatGLM2-6B
<p align="center">
💻 <a href="https://github.com/THUDM/ChatGLM2-6B" target="_blank">Github Repo</a> • 🐦 <a href="https://twitter.com/thukeg" target="_blank">Twitter</a> • 📃 <a href="https://arxiv.org/abs/2103.10360" target="_blank">[GLM@ACL 22]</a> <a href="https://github.com/THUDM/GLM" target="_blank">[GitHub]</a> • 📃 <a href="https://arxiv.org/abs/2210.02414" target="_blank">[GLM-130B@ICLR 23]</a> <a href="https://github.com/THUDM/GLM-130B" target="_blank">[GitHub]</a> <br>
</p>

<p align="center">
👋 Join our <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1th2q5u69-7tURzFuOPanmuHy9hsZnKA" target="_blank">Slack</a> and <a href="https://github.com/THUDM/ChatGLM-6B/blob/main/resources/WECHAT.md" target="_blank">WeChat</a>
</p>

## Introduction
ChatGLM**2**-6B is the second-generation version of the open-source bilingual (Chinese-English) chat model [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B). It retains the smooth conversation flow and low deployment threshold of the first-generation model, while introducing the following new features:

1. **Stronger Performance**: Based on the development experience of the first-generation ChatGLM model, we have fully upgraded the base model of ChatGLM2-6B. ChatGLM2-6B uses the hybrid objective function of [GLM](https://github.com/THUDM/GLM), and has undergone pre-training on 1.4T bilingual tokens followed by human preference alignment training. The [evaluation results](README.md#evaluation-results) show that, compared to the first-generation model, ChatGLM2-6B achieves substantial performance improvements on datasets such as MMLU (+23%), CEval (+33%), GSM8K (+571%), and BBH (+60%), making it strongly competitive among open-source models of the same size.
2. **Longer Context**: Based on the [FlashAttention](https://github.com/HazyResearch/flash-attention) technique, we have extended the context length of the base model from 2K in ChatGLM-6B to 32K, and trained with a context length of 8K during dialogue alignment, allowing for more rounds of dialogue. However, the current version of ChatGLM2-6B has limited understanding of single-turn ultra-long documents, which we will focus on optimizing in future iterations.
3. **More Efficient Inference**: Based on the [Multi-Query Attention](http://arxiv.org/abs/1911.02150) technique, ChatGLM2-6B offers faster inference and lower GPU memory usage: under the official implementation, inference speed is 42% higher than in the first generation, and under INT4 quantization the dialogue length supported by 6G of GPU memory grows from 1K to 8K.

## Dependencies

```shell
pip install protobuf transformers==4.27.1 cpm_kernels torch>=2.0 gradio mdtex2html sentencepiece accelerate
```

## Usage

You can generate dialogue with the ChatGLM2-6B model using the following code:

```ipython
>>> from transformers import AutoTokenizer, AutoModel
>>> tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
>>> model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).half().cuda()
>>> model = model.eval()
>>> response, history = model.chat(tokenizer, "你好", history=[])
>>> print(response)
你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。
>>> response, history = model.chat(tokenizer, "晚上睡不着应该怎么办", history=history)
>>> print(response)
晚上睡不着可能会让你感到焦虑或不舒服,但以下是一些可以帮助你入睡的方法:

1. 制定规律的睡眠时间表:保持规律的睡眠时间表可以帮助你建立健康的睡眠习惯,使你更容易入睡。尽量在每天的相同时间上床,并在同一时间起床。
2. 创造一个舒适的睡眠环境:确保睡眠环境舒适,安静,黑暗且温度适宜。可以使用舒适的床上用品,并保持房间通风。
3. 放松身心:在睡前做些放松的活动,例如泡个热水澡,听些轻柔的音乐,阅读一些有趣的书籍等,有助于缓解紧张和焦虑,使你更容易入睡。
4. 避免饮用含有咖啡因的饮料:咖啡因是一种刺激性物质,会影响你的睡眠质量。尽量避免在睡前饮用含有咖啡因的饮料,例如咖啡,茶和可乐。
5. 避免在床上做与睡眠无关的事情:在床上做些与睡眠无关的事情,例如看电影,玩游戏或工作等,可能会干扰你的睡眠。
6. 尝试呼吸技巧:深呼吸是一种放松技巧,可以帮助你缓解紧张和焦虑,使你更容易入睡。试着慢慢吸气,保持几秒钟,然后缓慢呼气。

如果这些方法无法帮助你入睡,你可以考虑咨询医生或睡眠专家,寻求进一步的建议。
```

For more instructions, including how to run CLI and web demos, and model quantization, please refer to our [Github Repo](https://github.com/THUDM/ChatGLM2-6B).
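
If GPU memory is limited, the weights can also be quantized at load time. A minimal sketch, assuming the model exposes the same `quantize()` helper as the first-generation ChatGLM-6B repo (backed by the `quantization.py` in this upload):

```ipython
>>> from transformers import AutoTokenizer, AutoModel
>>> tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
>>> # 4 selects INT4 weight quantization; pass 8 for INT8 instead
>>> model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).quantize(4).cuda()
>>> model = model.eval()
```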

## Change Log
* v1.0

## License
The code in this repository is open-sourced under the [Apache-2.0](LICENSE) license. Use of the ChatGLM2-6B model weights must comply with the [Model License](MODEL_LICENSE).

## Citation
If you find our work helpful, please consider citing the following papers. The ChatGLM2-6B paper will be released soon, stay tuned!

```
@article{zeng2022glm,
  title={Glm-130b: An open bilingual pre-trained model},
  author={Zeng, Aohan and Liu, Xiao and Du, Zhengxiao and Wang, Zihan and Lai, Hanyu and Ding, Ming and Yang, Zhuoyi and Xu, Yifan and Zheng, Wendi and Xia, Xiao and others},
  journal={arXiv preprint arXiv:2210.02414},
  year={2022}
}
```
```
@inproceedings{du2022glm,
  title={GLM: General Language Model Pretraining with Autoregressive Blank Infilling},
  author={Du, Zhengxiao and Qian, Yujie and Liu, Xiao and Ding, Ming and Qiu, Jiezhong and Yang, Zhilin and Tang, Jie},
  booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
  pages={320--335},
  year={2022}
}
```
quantization.py
ADDED
@@ -0,0 +1,188 @@
from torch.nn import Linear
from torch.nn.parameter import Parameter

import bz2
import torch
import base64
import ctypes
from transformers.utils import logging

from typing import List
from functools import partial

logger = logging.get_logger(__name__)

try:
    from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up

    class Kernel:
        def __init__(self, code: bytes, function_names: List[str]):
            self.code = code
            self._function_names = function_names
            self._cmodule = LazyKernelCModule(self.code)

            for name in self._function_names:
                setattr(self, name, KernelFunction(self._cmodule, name))

quantization_code = "$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+
9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpi
x+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ"

    kernels = Kernel(
        bz2.decompress(base64.b64decode(quantization_code)),
        [
            "int4WeightCompression",
            "int4WeightExtractionFloat",
            "int4WeightExtractionHalf",
            "int8WeightExtractionFloat",
            "int8WeightExtractionHalf",
        ],
    )
except Exception as exception:
    kernels = None
    logger.warning("Failed to load cpm_kernels:" + str(exception))


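# Autograd wrapper for weight-only quantized matmul: the int8/int4 weight is
# dequantized to half precision on the fly for the forward matmul, and again
# in backward to produce gradients for both the input and the weight.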
class W8A16Linear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
        ctx.inp_shape = inp.size()
        ctx.weight_bit_width = weight_bit_width
        out_features = quant_w.size(0)
        inp = inp.contiguous().view(-1, inp.size(-1))
        weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
        ctx.weight_shape = weight.size()
        output = inp.mm(weight.t())
        ctx.save_for_backward(inp, quant_w, scale_w)
        return output.view(*(ctx.inp_shape[:-1] + (out_features,)))

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        inp, quant_w, scale_w = ctx.saved_tensors
        weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width)
        grad_output = grad_output.contiguous().view(-1, weight.size(0))
        grad_input = grad_output.mm(weight)
        grad_weight = grad_output.t().mm(inp)
        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None


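# Packs two int4 values into each int8 byte via the CUDA kernel, halving the
# second dimension of the weight matrix.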
def compress_int4_weight(weight: torch.Tensor):  # (n, m)
    with torch.cuda.device(weight.device):
        n, m = weight.size(0), weight.size(1)
        assert m % 2 == 0
        m = m // 2
        out = torch.empty(n, m, dtype=torch.int8, device="cuda")
        stream = torch.cuda.current_stream()

        gridDim = (n, 1, 1)
        blockDim = (min(round_up(m, 32), 1024), 1, 1)

        kernels.int4WeightCompression(
            gridDim,
            blockDim,
            0,
            stream,
            [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],
        )
        return out


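# Dequantizes an int8/int4 weight matrix back to half precision (or bfloat16)
# by scaling each row with its per-output-channel scale.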
def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):
    assert scale_list.dtype in [torch.half, torch.bfloat16]
    assert weight.dtype in [torch.int8]
    if source_bit_width == 8:
        return weight.to(scale_list.dtype) * scale_list[:, None]
    elif source_bit_width == 4:
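        # NOTE: the bfloat16 branch assumes the compiled kernel blob also
        # exports "int4WeightExtractionBFloat16"; that name is not registered
        # in the Kernel function list above, so it requires a matching build.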
        func = (
            kernels.int4WeightExtractionHalf if scale_list.dtype == torch.half else kernels.int4WeightExtractionBFloat16
        )
    else:
        assert False, "Unsupported bit-width"

    with torch.cuda.device(weight.device):
        n, m = weight.size(0), weight.size(1)
        out = torch.empty(n, m * (8 // source_bit_width), dtype=scale_list.dtype, device="cuda")
        stream = torch.cuda.current_stream()

        gridDim = (n, 1, 1)
        blockDim = (min(round_up(m, 32), 1024), 1, 1)

        func(
            gridDim,
            blockDim,
            0,
            stream,
            [
                ctypes.c_void_p(weight.data_ptr()),
                ctypes.c_void_p(scale_list.data_ptr()),
                ctypes.c_void_p(out.data_ptr()),
                ctypes.c_int32(n),
                ctypes.c_int32(m),
            ],
        )
        return out


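# nn.Module that stores an int8/int4 weight plus a per-output-channel scale;
# the forward pass runs the dequantize-then-matmul autograd function above.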
class QuantizedLinear(torch.nn.Module):
    def __init__(self, weight_bit_width: int, weight, bias=None, device="cpu", dtype=None, empty_init=False, *args,
                 **kwargs):
        super().__init__()
        self.weight_bit_width = weight_bit_width

        # weight must be a real tensor even for empty_init, since its shape
        # sizes the (uninitialized) int8 buffers allocated below.
        shape = weight.shape

        if empty_init:
            self.weight = torch.empty(shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=device)
            self.weight_scale = torch.empty(shape[0], dtype=dtype, device=device)
        else:
            # Symmetric per-row absmax quantization.
            self.weight_scale = weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)
            self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
            if weight_bit_width == 4:
                self.weight = compress_int4_weight(self.weight)

        self.weight = Parameter(self.weight.to(device), requires_grad=False)
        self.weight_scale = Parameter(self.weight_scale.to(device), requires_grad=False)
        self.bias = Parameter(bias.to(device), requires_grad=False) if bias is not None else None

    def forward(self, input):
        output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width)
        if self.bias is not None:
            output = output + self.bias
        return output


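# Swaps the four fp16 Linear layers in every transformer block (QKV projection,
# attention output, and the two MLP projections) for QuantizedLinear modules.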
def quantize(model, weight_bit_width, empty_init=False, device=None):
    """Replace fp16 linear with quantized linear"""
    for layer in model.layers:
        layer.self_attention.query_key_value = QuantizedLinear(
            weight_bit_width=weight_bit_width,
            weight=layer.self_attention.query_key_value.weight.to(torch.cuda.current_device()),
            bias=layer.self_attention.query_key_value.bias,
            dtype=layer.self_attention.query_key_value.weight.dtype,
            device=layer.self_attention.query_key_value.weight.device if device is None else device,
            empty_init=empty_init
        )
        layer.self_attention.dense = QuantizedLinear(
            weight_bit_width=weight_bit_width,
            weight=layer.self_attention.dense.weight.to(torch.cuda.current_device()),
            bias=layer.self_attention.dense.bias,
            dtype=layer.self_attention.dense.weight.dtype,
            device=layer.self_attention.dense.weight.device if device is None else device,
            empty_init=empty_init
        )
        layer.mlp.dense_h_to_4h = QuantizedLinear(
            weight_bit_width=weight_bit_width,
            weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
            bias=layer.mlp.dense_h_to_4h.bias,
            dtype=layer.mlp.dense_h_to_4h.weight.dtype,
            device=layer.mlp.dense_h_to_4h.weight.device if device is None else device,
            empty_init=empty_init
        )
        layer.mlp.dense_4h_to_h = QuantizedLinear(
            weight_bit_width=weight_bit_width,
            weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
            bias=layer.mlp.dense_4h_to_h.bias,
            dtype=layer.mlp.dense_4h_to_h.weight.dtype,
            device=layer.mlp.dense_4h_to_h.weight.device if device is None else device,
            empty_init=empty_init
        )

    return model
tokenization_chatglm.py
ADDED
@@ -0,0 +1,235 @@
import os
import torch
from typing import List, Optional, Union, Dict
from sentencepiece import SentencePieceProcessor
from transformers import PreTrainedTokenizer
from transformers.utils import logging, PaddingStrategy
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding


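# Thin wrapper around a SentencePiece model that appends the GLM special
# tokens ([MASK], [gMASK], [sMASK], sop, eop) after the SentencePiece vocab.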
class SPTokenizer:
    def __init__(self, model_path: str):
        # reload tokenizer
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)

        # BOS / EOS token IDs; the EOS id doubles as the padding id
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.eos_id()
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

        special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"]
        self.special_tokens = {}
        self.index_special_tokens = {}
        for token in special_tokens:
            self.special_tokens[token] = self.n_words
            self.index_special_tokens[self.n_words] = token
            self.n_words += 1

    def tokenize(self, s: str):
        return self.sp_model.EncodeAsPieces(s)

    def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
        assert type(s) is str
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        return self.sp_model.decode(t)

    def decode_tokens(self, tokens: List[str]) -> str:
        text = self.sp_model.DecodePieces(tokens)
        return text

    def convert_token_to_id(self, token):
        """Converts a token (str) into an id using the vocab."""
        if token in self.special_tokens:
            return self.special_tokens[token]
        return self.sp_model.PieceToId(token)

    def convert_id_to_token(self, index):
        """Converts an index (integer) into a token (str) using the vocab."""
        if index in self.index_special_tokens:
            return ""
        return self.sp_model.IdToPiece(index)


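# Hugging Face-facing tokenizer that delegates to SPTokenizer and adds the
# GLM-specific prefix tokens, left padding, and vocab (de)serialization.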
class ChatGLMTokenizer(PreTrainedTokenizer):
    vocab_files_names = {"vocab_file": "tokenizer.model"}

    model_input_names = ["input_ids", "attention_mask", "position_ids"]

    def __init__(self, vocab_file, padding_side="left", **kwargs):
        super().__init__(padding_side=padding_side, **kwargs)
        self.name = "GLMTokenizer"

        # Keep the vocab file path; save_vocabulary() below copies it verbatim.
        self.vocab_file = vocab_file
        self.tokenizer = SPTokenizer(vocab_file)
        self.special_tokens = {
            "<bos>": self.tokenizer.bos_id,
            "<eos>": self.tokenizer.eos_id,
            "<pad>": self.tokenizer.pad_id
        }

    def get_command(self, token):
        if token in self.special_tokens:
            return self.special_tokens[token]
        assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
        return self.tokenizer.special_tokens[token]

    @property
    def pad_token(self) -> str:
        return "</s>"

    @property
    def pad_token_id(self):
        return self.get_command("<pad>")

    @property
    def vocab_size(self):
        return self.tokenizer.n_words

    def get_vocab(self):
        """ Returns vocab as a dict """
        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text, **kwargs):
        return self.tokenizer.tokenize(text)

    def _convert_token_to_id(self, token):
        """Converts a token (str) into an id using the vocab."""
        return self.tokenizer.convert_token_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) into a token (str) using the vocab."""
        return self.tokenizer.convert_id_to_token(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return self.tokenizer.decode_tokens(tokens)

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.
            filename_prefix (`str`, *optional*):
                An optional prefix to add to the names of the saved files.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, self.vocab_files_names["vocab_file"]
            )
        else:
            vocab_file = save_directory

        with open(self.vocab_file, 'rb') as fin:
            proto_str = fin.read()

        with open(vocab_file, "wb") as writer:
            writer.write(proto_str)

        return (vocab_file,)

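    # Every ChatGLM2 sequence starts with the [gMASK] + "sop" prefix; it is
    # prepended to the input ids in build_inputs_with_special_tokens below.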
    def get_prefix_tokens(self):
        prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
        return prefix_tokens

    def build_inputs_with_special_tokens(
            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences by concatenating and adding special
        tokens. A ChatGLM2 sequence has the following format:

        - single sequence: `[gMASK] sop X`
        - pair of sequences: `[gMASK] sop A B <eos>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        prefix_tokens = self.get_prefix_tokens()
        token_ids_0 = prefix_tokens + token_ids_0
        if token_ids_1 is not None:
            token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
        return token_ids_0

    def _pad(
            self,
            encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
            max_length: Optional[int] = None,
            padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
            pad_to_multiple_of: Optional[int] = None,
            return_attention_mask: Optional[bool] = None,
    ) -> dict:
        """
        Pad encoded inputs (on left/right and up to predefined length or max length in the batch)

        Args:
            encoded_inputs:
                Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
            max_length: maximum length of the returned list and optionally padding length (see below).
                Will truncate by taking into account the special tokens.
            padding_strategy: PaddingStrategy to use for padding.

                - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
                - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
                - PaddingStrategy.DO_NOT_PAD: Do not pad
                The tokenizer padding sides are defined in self.padding_side:

                    - 'left': pads on the left of the sequences
                    - 'right': pads on the right of the sequences
            pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
                This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
                `>= 7.5` (Volta).
            return_attention_mask:
                (optional) Set to False to avoid returning attention mask (default: set to model specifics)
        """
        # Load from model defaults
        assert self.padding_side == "left"
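        # ChatGLM2 is decoder-only, so padding always goes on the left; newly
        # generated tokens then line up at the right edge of the batch.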

        required_input = encoded_inputs[self.model_input_names[0]]
        seq_length = len(required_input)

        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = len(required_input)

        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length

        # Initialize attention mask if not present.
        if "attention_mask" not in encoded_inputs:
            encoded_inputs["attention_mask"] = [1] * seq_length

        if "position_ids" not in encoded_inputs:
            encoded_inputs["position_ids"] = list(range(seq_length))

        if needs_to_be_padded:
            difference = max_length - len(required_input)

            if "attention_mask" in encoded_inputs:
                encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
            if "position_ids" in encoded_inputs:
                encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
            encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input

        return encoded_inputs
tokenizer.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e7dc4c393423b76e4373e5157ddc34803a0189ba96b21ddbb40269d31468a6f2
size 1018370
tokenizer_config.json
ADDED
@@ -0,0 +1,12 @@
{
  "name_or_path": "THUDM/chatglm-6b",
  "remove_space": false,
  "do_lower_case": false,
  "tokenizer_class": "ChatGLMTokenizer",
  "auto_map": {
    "AutoTokenizer": [
      "tokenization_chatglm.ChatGLMTokenizer",
      null
    ]
  }
}