xingzhaohu committed · Commit 6586b2c · verified · 1 Parent(s): a9a6802

init model files
README.md CHANGED
@@ -1,3 +1,186 @@
- ---
- license: apache-2.0
- ---
+ <div align="center">
+ <h1>
+ 星辰语义大模型-TeleChat2
+ </h1>
+ </div>
+
+
+ <p align="center">
+ 🦉 <a href="https://github.com/Tele-AI/TeleChat2" target="_blank">GitHub</a> • 🤗 <a href="https://huggingface.co/Tele-AI" target="_blank">Hugging Face</a> • 🤖 <a href="https://modelscope.cn/organization/TeleAI" target="_blank">ModelScope</a> • 🏔 <a href="https://gitee.com/mindspore/mindformers/tree/dev/research/telechat2" target="_blank">MindSpore</a> • 🐾 <a href="https://gitee.com/Tele-AI/tele-chat2" target="_blank">Gitee</a> • 💬 <a href="https://github.com/Tele-AI/Telechat/blob/master/images/wechat.jpg" target="_blank">WeChat</a>
+ </p>
+
+ # Table of Contents
+
+ - [Model Introduction](#model-introduction)
+ - [Evaluation Results](#evaluation-results)
+ - [Model Inference](#model-inference)
+ - [Statement, License, and Citation](#statement-license-and-citation)
+
+ # Latest News
+ - 2024.11.08 Open-sourced **TeleChat2-3B**, **TeleChat2-7B**, and **TeleChat2-35B**; all models in this release support **Function Call**.
+ - 2024.10.18 Open-sourced the TeleChat2-35B model.
+ - 2024.9.20 Open-sourced the TeleChat2-115B model, **the first open-source model with over 100B parameters trained entirely on domestic Chinese compute**.
+
+ # Model Introduction
+
+ ### TeleChat2 (星辰语义大模型)
+
+ - **TeleChat2** (星辰语义大模型) is a series of large language models developed and trained by the China Telecom Artificial Intelligence Research Institute (TeleAI); the whole series is trained **entirely on domestic Chinese compute**.
+ - The newly open-sourced **TeleChat2-3B**, **TeleChat2-7B**, and **TeleChat2-35B** models support **tool calling**. We specifically optimized **Function Call** performance, and on the relevant benchmarks the models compare favorably with models of the same size.
+ - The **TeleChat2-115B** model was trained on 10 trillion tokens of high-quality Chinese and English corpus, and the weights of the **TeleChat2-115B** chat model are released in multiple formats for multiple platforms.
+ - **TeleChat2** improves on **TeleChat1** in training data and training methodology, with large gains on general question answering and knowledge, code, and math benchmarks.
+ - **TeleChat2** is trained entirely on domestic compute with a domestic deep-learning framework, making both the hardware and the software stack more self-controllable. We improved the MP, PP, and SP implementations to raise model performance and optimized operators to speed up training.
+ - We ran a large number of small-model experiments to verify scaling-law behavior, searching for the best design across model architectures, data mixtures, and data-cleaning strategies.
+ - We use RingAttention and other sequence-parallel schemes to speed up long-context training, and an ntk-aware + attention-scaling scheme to keep the transition smooth when switching training lengths, so the model trains well on data of different lengths.
+ - For fine-tuning data, we increased instruction complexity and diversity, produced high-quality data through synthesis and human annotation, and used rejection sampling to generate diverse reasoning paths. We also studied a scheme that uses the base model to reversely select preference-alignment data, maximizing model quality on well-matched data.
+ - General capability improves by more than 29% over the TeleChat series, with large gains in logical reasoning, summarization, long-form writing, and mathematical computation.
+
+ ### Model Architecture
+
+ We designed **TeleChat2** as a standard `Decoder-only` model, with [Rotary Embedding](https://arxiv.org/pdf/2104.09864.pdf) positional encoding, the [SwiGLU](https://arxiv.org/pdf/2002.05202.pdf) activation function in place of GELU, and Pre-Normalization based on [RMSNorm](https://arxiv.org/abs/1910.07467). The word-embedding layer and the output lm head are kept untied, which helps training stability and convergence, and we adopt GQA to save parameters and computation in attention and to speed up training and inference.
+
+ The architecture configurations of **TeleChat2** are listed in the table below (a small sketch for cross-checking them against `config.json` follows the table):
+
+ |      | layer_num | hidden_size | ffn_hidden_size | head_num | tie_word_embeddings | GQA |
+ | ---- | --------- | ----------- | --------------- | -------- | ------------------- | --- |
+ | 3B   | 24        | 3072        | 6144            | 24       | No                  | No  |
+ | 7B   | 30        | 4096        | 12288           | 32       | No                  | No  |
+ | 35B  | 64        | 6144        | 20480           | 48       | No                  | No  |
+ | 115B | 96        | 8192        | 40960           | 64       | No                  | Yes |
+
+
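+ For reference, the column names above match fields of the `config.json` shipped with each checkpoint (this repository contains the 7B configuration). A minimal sketch, assuming the file has been downloaded locally:
+
+ ```python
+ import json
+
+ # Read the released configuration and print the fields from the table above.
+ with open("config.json") as f:
+     cfg = json.load(f)
+
+ # For the 7B checkpoint this prints: 30 4096 12288 32 False
+ print(cfg["n_layer"], cfg["hidden_size"], cfg["ffn_hidden_size"], cfg["n_head"], cfg["tie_word_embeddings"])
+
+ # GQA is in use when num_key_value_heads < n_head (for 7B both are 32, i.e. no GQA).
+ print(cfg.get("num_key_value_heads", cfg["n_head"]) < cfg["n_head"])
+ ```
+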
+ The open-source **TeleChat2** models:
+
+ - Support DeepSpeed fine-tuning: we release DeepSpeed-based training code with ZeRO memory optimization and integrated FlashAttention2.
+ - Support multi-turn dialogue: we release our multi-turn data construction method, and training integrates a multi-turn mask-loss scheme that focuses the loss on the answers in each turn, improving QA quality.
+
+ The models in this release and their download links are listed in the table below; a short download sketch follows.
+
+ | Model version  | Download link |
+ | -------------- | -------- |
+ | telechat2-3B   | [modelscope](https://modelscope.cn/models/TeleAI/TeleChat2-3B)|
+ | telechat2-7B   | [modelscope](https://modelscope.cn/models/TeleAI/TeleChat2-7B)|
+ | telechat2-35B  | [modelscope](https://modelscope.cn/models/TeleAI/TeleChat2-35B-Nov)|
+ | telechat2-115B | [modelscope](https://modelscope.cn/models/TeleAI/TeleChat2-115B)|
+
+
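+ A minimal download sketch (assuming the `modelscope` package is installed; the repo ids are the ModelScope links above):
+
+ ```python
+ from modelscope import snapshot_download
+
+ # Downloads the checkpoint to the local ModelScope cache and returns the path.
+ model_dir = snapshot_download("TeleAI/TeleChat2-7B")
+ print(model_dir)
+ ```
+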
+ # Evaluation Results
+
+ **TeleChat2** also performs well on benchmarks compared with models of similar size. Our evaluation suite covers datasets such as MMLU, C-Eval, CMMLU, GSM8K, MATH, HumanEval, and BBH, testing instruction following, exam-style knowledge, mathematical computation and reasoning, code generation, and more.
+
+ ## Benchmark Overview
+
+ ### General capability
+
+ - MMLU is a comprehensive English evaluation dataset covering 57 subjects, including the humanities, social sciences, natural sciences, elementary mathematics, US history, computer science, law, and more.
+
+ - C-Eval is a comprehensive Chinese evaluation suite of multiple-choice questions spanning middle-school, high-school, college, and professional difficulty levels, covering 52 subject areas.
+
+ - CMMLU is likewise a comprehensive Chinese evaluation suite, covering 67 topics from basic subjects to advanced professional levels.
+
+ ### Reasoning and code capability
+
+ - GSM8K contains 8.5K high-quality grade-school math problems and measures a language model's mathematical reasoning ability.
+
+ - HumanEval is a code benchmark released by OpenAI, consisting of 164 programming problems that require generating a correct code snippet from a given problem description and code template.
+
+ - BBH (BIG-Bench Hard) contains 23 challenging BIG-Bench tasks on which earlier language models did not exceed the average human rater.
+
+ - MBPP contains roughly 1,000 crowd-sourced Python programming problems covering programming fundamentals, standard-library usage, and more; each problem includes a task description, a reference solution, and 3 automated test cases.
+
+ ### Open-ended (subjective) capability
+
+ - [AlignBench](https://github.com/THUDM/AlignBench) is a multi-dimensional benchmark for comprehensively evaluating the alignment of Chinese large models, with 638 single-turn subjective questions.
+
+ - [MT-bench](https://github.com/lm-sys/FastChat/blob/main/fastchat/llm_judge/README.md) is a challenging set of multi-turn open-ended questions for evaluating chat assistants, with 80 multi-turn subjective questions.
+
+ ### Instruction following
+
+ - [IFEval](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/ifeval/README.md) evaluates how precisely a language model follows instructions; it contains 500 precisely verifiable instructions and is one of the core benchmarks used in the Open LLM Leaderboard.
+
+ ## Results
+
+ | Dataset    | Llama-3.1-70B | Qwen1.5-110B | Qwen2-72B-Instruct | DeepSeek-v2 | TeleChat2-115B | TeleChat2-35B | TeleChat2-7B | TeleChat2-3B |
+ |:----------:|:-------------:|:------------:|:------------------:|:-----------:|:--------------:|:-------------:|:------------:|:------------:|
+ | C-Eval     | -             | -            | 83.8               | 78          | **86.9**       | 85            | 82           | 75           |
+ | MMLU       | **86**        | 80.4         | 82.3               | 77.8        | 80.9           | 82            | 79.6         | 72.9         |
+ | CMMLU      | 69.01         | 87.64        | 87.47              | 81.6        | **89.94**      | 90.18         | 84.6         | 73           |
+ | BBH        | -             | 74.8         | -                  | 79.7        | **89.04**      | 88.6          | 77.3         | 65.99        |
+ | GSM8K      | **95.1**      | 85.4         | 91.1               | 92.2        | 92.2           | 91            | 86.8         | 64.7         |
+ | HumanEval  | 80.5          | 52.4         | **86**             | 81.1        | 75             | 73            | 56           | 38           |
+ | MBPP       | **86**        | 58.1         | 80.2               | 72          | 78             | 75            | 62.6         | 47           |
+ | AlignBench | -             | 7.86         | **8.27**           | 7.91        | 8.03           | 7.88          | 6.96         | 5.74         |
+ | MT-bench   | 8.79          | 8.88         | **9.12**           | 8.97        | 8.89           | 8.2           | 7.2          | 5.72         |
+ | IFEval     | **87.5**      | -            | 77.6               | 63.8        | 82.81          | 79.63         | 73.1         | 61.29        |
+
+ # Model Inference
+
+ Inference currently supports both single-GPU and multi-GPU setups, and some optimizations have been made for long-context inference.
+
+ **Example: model inference**
+
+ ```python
+ >>> import torch
+ >>> from modelscope import AutoModelForCausalLM, AutoTokenizer
+ >>> tokenizer = AutoTokenizer.from_pretrained('TeleAI/TeleChat2-7B', trust_remote_code=True)
+ >>> model = AutoModelForCausalLM.from_pretrained('TeleAI/TeleChat2-7B', trust_remote_code=True, device_map="auto", torch_dtype=torch.float16)
+ >>> prompt = "生抽与老抽的区别?"
+ >>> messages = [{"role": "user", "content": prompt}]
+ >>> text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ >>> model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+ >>> generated_ids = model.generate(**model_inputs, max_new_tokens=512)
+ >>> generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]
+ >>> response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ >>> print(response)
+ 生抽和老抽是两种不同的酱油,它们在风味、色泽和用途上都有所区别。
+
+ 1.颜色:生抽的颜色比较淡,而老抽的颜色较深。生抽的颜色呈红褐色或棕红色,而老抽的颜色则呈棕黑色。
+
+ 2.味道:生抽具有鲜美的咸味和微甜的味浅,而老抽浓郁,颜色较深。根据个人口味和烹饪需求选择不同的酱油类型可以获得更好的口感和菜肴效果。
+ ```
+
+ # Statement, License, and Citation
+
+ ### Statement
+
+ We hereby declare that the TeleChat models and their derivatives must not be used for any activity that endangers national or public security or violates the law. We also ask users not to deploy the TeleChat models in internet services that have not passed security review and filing. We hope all users abide by these principles so that technological development proceeds in a lawful and compliant environment.
+
+ We have done our best to ensure the compliance of the data used during model training. Nevertheless, despite these extensive efforts, unforeseen issues may still arise due to the complexity of the model and the data. Therefore, we accept no liability for any problems caused by the use of the open-source TeleChat models, including but not limited to data-security issues, public-opinion risks, or any risks and problems arising from the model being misled, misused, disseminated, or improperly exploited.
+
+ ### License
+
+ Community use of the TeleChat models must follow the [TeleChat Model Community License Agreement](./TeleChat模型社区许可协议.pdf). The TeleChat models support commercial use. If you plan to use the TeleChat models or their derivatives for commercial purposes, you must submit the application materials required by the TeleChat Model Community License Agreement via the contact email [email protected]. After approval, you will be granted a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable commercial copyright license.
+
+ ### Citation
+
+ To cite our work, please use the following reference:
+
+ ```
+ @misc{wang2024telechat,
+       title={TeleChat Technical Report},
+       author={Zihan Wang and Xinzhang Liu and Shixuan Liu and Yitong Yao and Yuyao Huang and Zhongjiang He and Xuelong Li and Yongxiang Li and Zhonghao Che and Zhaoxi Zhang and Yan Wang and Xin Wang and Luwen Pu and Huihan Xu and Ruiyu Fang and Yu Zhao and Jie Zhang and Xiaomeng Huang and Zhilong Lu and Jiaxin Peng and Wenjun Zheng and Shiquan Wang and Bingkai Yang and Xuewei he and Zhuoru Jiang and Qiyi Xie and Yanhan Zhang and Zhongqiu Li and Lingling Shi and Weiwei Fu and Yin Zhang and Zilu Huang and Sishi Xiong and Yuxiang Zhang and Chao Wang and Shuangyong Song},
+       year={2024},
+       eprint={2401.03804},
+       archivePrefix={arXiv},
+       primaryClass={cs.CL}
+ }
+ ```
config.json ADDED
@@ -0,0 +1,49 @@
+ {
+   "apply_residual_connection_post_layernorm": false,
+   "architectures": [
+     "TeleChat2ForCausalLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_telechat2.Telechat2Config",
+     "AutoModelForCausalLM": "modeling_telechat2.Telechat2ForCausalLM"
+   },
+   "attention_dropout": 0.0,
+   "attention_softmax_in_fp32": true,
+   "bias_dropout_fusion": true,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "hidden_dropout": 0.0,
+   "hidden_size": 4096,
+   "initializer_range": 0.02,
+   "layer_norm_epsilon": 1e-05,
+   "masked_softmax_fusion": true,
+   "max_position_embeddings": 32768,
+   "model_type": "telechat",
+   "n_head": 32,
+   "n_inner": null,
+   "n_layer": 30,
+   "num_key_value_heads": 32,
+   "offset_alibi": 100,
+   "pad_token_id": 3,
+   "pretraining_tp": 2,
+   "skip_bias_add": true,
+   "skip_bias_add_qkv": false,
+   "slow_but_exact": false,
+   "transformers_version": "4.44.2",
+   "torch_dtype": "bfloat16",
+   "unk_token_id": 0,
+   "use_cache": true,
+   "vocab_size": 131072,
+   "ffn_hidden_size": 12288,
+   "flash_attn": true,
+   "tie_word_embeddings": false,
+   "rope_scaling": {
+     "factor": 1.0,
+     "rope_type": "dynamic"
+   },
+   "rope_theta": 1000000,
+   "training_seqlen": 32768,
+   "base_seqlen": 32768,
+   "seq_length": 32768
+ }
+
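The `auto_map` block above is what lets `trust_remote_code=True` route the standard Auto classes to the custom code shipped in this commit. A minimal sketch, assuming the files from this commit are available in a local directory:

```python
from transformers import AutoConfig

# auto_map routes AutoConfig to configuration_telechat2.Telechat2Config in this directory.
cfg = AutoConfig.from_pretrained(".", trust_remote_code=True)
print(type(cfg).__name__)       # Telechat2Config
print(cfg.n_layer, cfg.n_head)  # 30 32
```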
configuration.json ADDED
@@ -0,0 +1 @@
+ {"framework": "Pytorch", "task": "text-generation"}
configuration_telechat2.py ADDED
@@ -0,0 +1,94 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ Telechat configuration"""
17
+
18
+ from packaging import version
19
+ from collections import OrderedDict
20
+ from transformers.utils import is_torch_available, logging
21
+ from transformers.configuration_utils import PretrainedConfig
22
+ from typing import TYPE_CHECKING, Any, List, Mapping, Optional
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ class Telechat2Config(PretrainedConfig):
27
+ """
28
+ Args:
29
+ vocab_size (`int`, *optional*, defaults to 160256): Vocabulary size of the Telechat model.
30
+ hidden_size (`int`, *optional*, defaults to 4096): Dimensionality of the embeddings and hidden states.
31
+ ffn_hidden_size (`int`, *optional*, defaults to 12288): Dimensionality of the feed-forward hidden states.
32
+ n_layer (`int`, *optional*, defaults to 30): Number of hidden layers in the Transformer
33
+ n_head (`int`, *optional*, defaults to 32): Number of attention heads for each attention layer.
34
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): The epsilon to use in the layer normalization layers.
35
+ initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
36
+ apply_residual_connection_post_layernorm (`bool`, *optional*, defaults to `False`): If enabled, use the layer norm of the hidden states as the residual in the transformer blocks
37
+ hidden_dropout (`float`, *optional*, defaults to 0.0): Dropout rate of the dropout function on the bias dropout.
38
+ attention_dropout (`float`, *optional*, defaults to 0.0): Dropout rate applied to the attention probs
39
+ use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions.
40
+ training_seqlen (`int`, *optional*, defaults to 8192): Sequence length during last finetuning.
41
+ logn (`bool`, *optional*, defaults to `True`): Whether or not to use logN during extrapolation.
42
+ embed_layernorm (`bool`, *optional*, defaults to `True`): Whether or not to use embedding layernorm.
43
+
44
+ """
45
+
46
+ model_type = "telechat"
47
+ keys_to_ignore_at_inference = ["past_key_values"]
48
+ attribute_map = {
49
+ "num_hidden_layers": "n_layer",
50
+ "num_attention_heads": "n_head",
51
+ }
52
+
53
+ def __init__(
54
+ self,
55
+ vocab_size=160256,
56
+ hidden_size=4096,
57
+ n_layer=30,
58
+ n_head=32,
59
+ layer_norm_epsilon=1e-5,
60
+ initializer_range=0.02,
61
+ use_cache=True,
62
+ bos_token_id=1,
63
+ eos_token_id=2,
64
+ apply_residual_connection_post_layernorm=False,
65
+ hidden_dropout=0.0,
66
+ attention_dropout=0.0,
67
+ ffn_hidden_size=12288,
68
+ training_seqlen = 8192,
69
+ logn = True,
70
+ embed_layernorm = False,
71
+ **kwargs,
72
+ ):
73
+ self.vocab_size = vocab_size
74
+ n_embed = kwargs.pop("n_embed", None)
75
+ self.hidden_size = hidden_size if n_embed is None else n_embed
76
+ self.n_layer = n_layer
77
+ self.n_head = n_head
78
+ self.layer_norm_epsilon = layer_norm_epsilon
79
+ self.initializer_range = initializer_range
80
+ self.use_cache = use_cache
81
+ self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
82
+ self.hidden_dropout = hidden_dropout
83
+ self.attention_dropout = attention_dropout
84
+ self.bos_token_id = bos_token_id
85
+ self.eos_token_id = eos_token_id
86
+ self.logn = logn
87
+ self.ffn_hidden_size = ffn_hidden_size
88
+ self.training_seqlen = training_seqlen
89
+ self.embed_layernorm = embed_layernorm
90
+ self.num_key_value_heads= kwargs.pop("num_key_value_heads", None)
91
+
92
+
93
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
94
+
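A minimal usage sketch for the class above (values taken from the 7B `config.json` in this commit; `num_key_value_heads` is read from kwargs, and the import assumes the file is available locally):

```python
from configuration_telechat2 import Telechat2Config

config = Telechat2Config(
    vocab_size=131072, hidden_size=4096, n_layer=30, n_head=32,
    ffn_hidden_size=12288, training_seqlen=32768, num_key_value_heads=32,
)
# attribute_map exposes the standard HF names as aliases:
print(config.num_hidden_layers, config.num_attention_heads)  # 30 32
```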
generation_config.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "max_new_tokens": 1000,
+   "do_sample": false,
+   "use_cache": true,
+   "temperature": 0.3,
+   "top_k": 5,
+   "top_p": 0.85,
+   "repetition_penalty": 1.02,
+   "pad_token_id": 3,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "user_token_id": 4,
+   "bot_token_id": 5,
+   "start_token_id": 1
+ }
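These decoding defaults are attached to the model automatically when it is loaded. Note that `do_sample` is `false`, so `temperature`/`top_k`/`top_p` only take effect if sampling is switched on. A minimal sketch, continuing the README inference example (where `model` and `model_inputs` are already defined):

```python
# The defaults from this generation_config.json are already on the loaded model:
print(model.generation_config.max_new_tokens, model.generation_config.do_sample)  # 1000 False

# Greedy decoding by default; enable sampling to make the sampling parameters above take effect.
outputs = model.generate(**model_inputs, do_sample=True, temperature=0.3, top_k=5, top_p=0.85)
```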
generation_utils.py ADDED
@@ -0,0 +1,162 @@
1
+ from typing import Optional
2
+ from collections import deque
3
+ from queue import Queue
4
+ import copy
5
+
6
+
7
+ class History:
8
+
9
+ def __init__(self, tokenizer, history):
10
+ '''
11
+ init from a list of dict
12
+ '''
13
+ # use deque to meet some special situation
14
+ self.input_history = deque()
15
+ self.tokenizer = tokenizer
16
+ if history:
17
+ self._transfer_from_list(history)
18
+
19
+ def _transfer_from_list(self, history):
20
+ for message in history:
21
+ content = message.get("content")
22
+ # the token result may not be equal to the result model gen
23
+ message.update(self.tokenizer(content))
24
+ self.input_history.append(message)
25
+
26
+ def append(self, message):
27
+ content = message.get("content")
28
+ if "input_ids" not in message or "attention_mask" not in message:
29
+ message.update(self.tokenizer(content))
30
+ self.input_history.append(message)
31
+
32
+ def append_left(self, message):
33
+ content = message.get("content")
34
+ if "input_ids" not in message or "attention_mask" not in message:
35
+ message.update(self.tokenizer(content))
36
+ self.input_history.appendleft(message)
37
+
38
+ def pop(self):
39
+ x = self.input_history.pop()
40
+ return x
41
+
42
+ def pop_left(self):
43
+ x = self.input_history.popleft()
44
+ return x
45
+
46
+ def update(self, message):
47
+ self.input_history.pop()
48
+ self.append(message)
49
+
50
+ def __len__(self):
51
+ return self.input_history.__len__()
52
+
53
+ def __str__(self):
54
+ return self.input_history.__str__()
55
+
56
+ def __copy__(self):
57
+ new_instance = type(self)(self.tokenizer, [])
58
+ new_instance.input_history = copy.copy(self.input_history)
59
+ return new_instance
60
+
61
+ def __deepcopy__(self, memodict={}):
62
+ new_instance = type(self)(self.tokenizer, [])
63
+ new_instance.input_history = copy.deepcopy(self.input_history)
64
+ return new_instance
65
+
66
+
67
+ class TelechatIterTextStreamer:
68
+ """
69
+ Rewritten with reference to the TextIteratorStreamer in transformers.
70
+ """
71
+
72
+ def __init__(
73
+ self, tokenizer, history: History = None, skip_prompt: bool = False, timeout: Optional[float] = None,
74
+ **decode_kwargs
75
+ ):
76
+
77
+ self.tokenizer = tokenizer
78
+ self.history = history
79
+ self.skip_prompt = skip_prompt
80
+ self.timeout = timeout
81
+ self.decode_kwargs = decode_kwargs
82
+
83
+ self.text_queue = Queue()
84
+ self.cache_time = 0
85
+ self.text_until = ""
86
+ self.token_until = []
87
+ self.stop_signal = None
88
+ self.next_tokens_are_prompt = True
89
+
90
+ self.history.append({"role": "bot", "content": self.text_until})
91
+
92
+ def put(self, value):
93
+ """
94
+ put printable text into queue
95
+ """
96
+ if len(value.shape) > 1 and value.shape[0] > 1:
97
+ raise ValueError("TextStreamer only supports batch size 1")
98
+ elif len(value.shape) > 1:
99
+ value = value[0]
100
+
101
+ if self.skip_prompt and self.next_tokens_are_prompt:
102
+ self.next_tokens_are_prompt = False
103
+ return
104
+
105
+ if value[-1] == self.tokenizer.eos_token_id:
106
+ return
107
+
108
+ # there may be some smart way to decode.
109
+ self.token_until.extend(value.tolist())
110
+ text = self.tokenizer.decode(self.token_until, **self.decode_kwargs)
111
+
112
+
113
+ if self._is_printable(text) or self.cache_time >= 6:
114
+ output_text = text[len(self.text_until):]
115
+ self.text_until = text
116
+
117
+ else:
118
+ self.cache_time+=1
119
+ return
120
+
121
+ self.on_finalized_text(output_text)
122
+
123
+ def end(self):
124
+ """Flushes any remaining cache and prints a newline to stdout."""
125
+ # Flush the cache, if it exists
126
+ text = self.tokenizer.decode(self.token_until, **self.decode_kwargs)
127
+ output_text = text[len(self.text_until):]
128
+ self.text_until = text
129
+ self.on_finalized_text(output_text, stream_end=True)
130
+ self.clear_cache()
131
+
132
+ def clear_cache(self):
133
+ self.cache_time = 0
134
+ self.token_until = []
135
+ self.text_until = ""
136
+ self.history = None
137
+ self.next_tokens_are_prompt = True
138
+
139
+ def on_finalized_text(self, text: str, stream_end: bool = False):
140
+ """Put the text tuple in the queue."""
141
+ self.history.update({"role": "bot", "content": self.text_until, "input_ids": self.token_until,
142
+ "attention_mask": [1] * len(self.token_until)})
143
+ self.text_queue.put((text, self.history), timeout=self.timeout)
144
+ if stream_end:
145
+ self.text_queue.put((self.stop_signal, self.history), timeout=self.timeout)
146
+
147
+ @staticmethod
148
+ def _is_printable(cp):
149
+ """Checks whether tokens can be decoded or not"""
150
+ if "�" in cp:
151
+ return False
152
+ return True
153
+
154
+ def __iter__(self):
155
+ return self
156
+
157
+ def __next__(self):
158
+ value_now, history_until = self.text_queue.get(timeout=self.timeout)
159
+ if value_now == self.stop_signal:
160
+ raise StopIteration()
161
+ else:
162
+ return value_now, history_until
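A hedged usage sketch for the two classes above: the streamer follows the `put`/`end` interface that `model.generate(streamer=...)` expects. The import assumes this file is available locally, and `model`/`tokenizer` are loaded as in the README example.

```python
from threading import Thread
from generation_utils import History, TelechatIterTextStreamer

messages = [{"role": "user", "content": "生抽与老抽的区别?"}]
history = History(tokenizer, messages)
streamer = TelechatIterTextStreamer(tokenizer, history, skip_prompt=True, skip_special_tokens=True)

inputs = tokenizer(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True),
                   return_tensors="pt").to(model.device)
# Run generation in a background thread and consume the streamed text in the main thread.
Thread(target=model.generate, kwargs=dict(**inputs, max_new_tokens=512, streamer=streamer)).start()

for text_chunk, hist in streamer:   # yields (new text, updated History) until generation ends
    print(text_chunk, end="", flush=True)
```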
modeling_telechat2.py ADDED
@@ -0,0 +1,854 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 HuggingFace Inc. team and BigScience workshop.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
17
+
18
+ # Copyright (c) 2021 EleutherAI
19
+ # This file is based on code by the authors denoted below and has been modified from its original version.
20
+ #
21
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
22
+ #
23
+ # Licensed under the Apache License, Version 2.0 (the "License");
24
+ # you may not use this file except in compliance with the License.
25
+ # You may obtain a copy of the License at
26
+ #
27
+ # http://www.apache.org/licenses/LICENSE-2.0
28
+ #
29
+ # Unless required by applicable law or agreed to in writing, software
30
+ # distributed under the License is distributed on an "AS IS" BASIS,
31
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32
+ # See the License for the specific language governing permissions and
33
+ # limitations under the License.
34
+
35
+
36
+ """PyTorch TELECHAT model."""
37
+
38
+ import warnings
39
+ from typing import Optional, Tuple, Union, List, Dict
40
+ from threading import Thread
41
+
42
+ import torch
43
+ import math
44
+ import copy
45
+ from torch import nn
46
+ import torch.utils.checkpoint
47
+ from torch.nn import functional as F
48
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
49
+ from transformers.modeling_outputs import (
50
+ BaseModelOutputWithPastAndCrossAttentions,
51
+ CausalLMOutputWithCrossAttentions
52
+ )
53
+ from transformers.modeling_utils import PreTrainedModel
54
+ from transformers.utils import logging
55
+ from transformers import GenerationConfig
56
+
57
+ from .configuration_telechat2 import Telechat2Config
58
+
59
+
60
+ logger = logging.get_logger(__name__)
61
+
62
+ _CHECKPOINT_FOR_DOC = "telechat"
63
+ _CONFIG_FOR_DOC = "Telechat2Config"
64
+
65
+ TELECHAT_PRETRAINED_MODEL_ARCHIVE_LIST = []
66
+
67
+ try:
68
+ from einops import rearrange
69
+ except ImportError:
70
+ rearrange = None
71
+
72
+ use_flash_attn = True
73
+ try:
74
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_func
75
+ except ImportError:
76
+ try:
77
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func
78
+ except ImportError:
79
+ flash_attn_unpadded_func = None
80
+
81
+
82
+ class RotaryEmbedding(torch.nn.Module):
83
+ # Extracted from: https://github.com/EleutherAI/gpt-neox
84
+ def __init__(self, dim, config):
85
+ super().__init__()
86
+ self.config = config
87
+ self.dim = dim
88
+ self.base = config.rope_theta
89
+ self.inv_freq = 1. / (self.base ** (torch.arange(0, dim, 2).float().half() / dim))
90
+ self.max_seq_len_cached = None
91
+ self.cos_cached = None
92
+ self.sin_cached = None
93
+ self.precision = config.torch_dtype
94
+
95
+ def get_mscale(self, scale=1):
96
+ if scale <= 1:
97
+ return 1.0
98
+ return 0.1 * math.log(scale) + 1.0
99
+
100
+ def get_ntk_alpha(self, true_seq_len):
101
+ context_value = math.log(true_seq_len / self.config.base_seqlen, 2) + 1
102
+ # ntk_alpha = 2 ** context_value - 1
103
+ ntk_alpha = 2 ** math.ceil(context_value) - 1
104
+ ntk_alpha = max(ntk_alpha, 1)
105
+ return ntk_alpha
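+ # Worked example (editor's note): with base_seqlen = 32768 (as in config.json) and
+ # true_seq_len = 65536, context_value = log2(65536 / 32768) + 1 = 2, so
+ # ntk_alpha = 2**ceil(2) - 1 = 3; for any length <= base_seqlen it stays at 1,
+ # i.e. the RoPE base is only enlarged once the sequence exceeds the training length.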
106
+
107
+ def forward(self, x, seq_dim=0, seq_len=None):
108
+ if seq_len is None:
109
+ seq_len = x.shape[seq_dim]
110
+ seq_len = max(seq_len, self.config.training_seqlen)
111
+ ntk_alpha = self.get_ntk_alpha(seq_len)
112
+ self.mscale = float(self.get_mscale(seq_len / self.config.training_seqlen))
113
+ if True:
114
+ base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
115
+ self.inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=x.device).float() / self.dim))
116
+ self.max_seq_len_cached = seq_len
117
+ t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
118
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
119
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
120
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
121
+ if self.precision == torch.bfloat16:
122
+ emb = emb.float()
123
+ # [sx, 1 (b * np), hn]
124
+ self.cos_cached = self.mscale * emb.cos()[:, None, :].half()
125
+ self.sin_cached = self.mscale * emb.sin()[:, None, :].half()
126
+ if self.precision == torch.bfloat16:
127
+ self.cos_cached = self.cos_cached.bfloat16()
128
+ self.sin_cached = self.sin_cached.bfloat16()
129
+ return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
130
+
131
+
132
+ # rotary pos emb helpers:
133
+ def rotate_half(x):
134
+ x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
135
+ return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions
136
+
137
+
138
+ def apply_rotary_pos_emb_torch(q, k, cos, sin, offset: int = 0): # jitting fails with bf16
139
+ cos, sin = cos[offset:q.shape[0] + offset, ...], sin[offset:q.shape[0] + offset, ...]
140
+ return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
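+ # Editor's note on shapes: in this file q and k arrive as [seq, batch * heads, head_dim] and
+ # cos/sin are cached as [seq, 1, head_dim], so the element-wise products broadcast over the
+ # middle dimension; `offset` skips the positions already covered by the key/value cache.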
141
+
142
+
143
+ class MixedFusedRMSNorm(nn.Module):
144
+ # Extracted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
145
+ def __init__(self, hidden_size, eps=1e-6):
146
+ super().__init__()
147
+ self.weight = nn.Parameter(torch.ones(hidden_size))
148
+ self.variance_epsilon = eps
149
+
150
+ def forward(self, hidden_states):
151
+ input_dtype = hidden_states.dtype
152
+ hidden_states = hidden_states.to(torch.float32)
153
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
154
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
155
+ return self.weight * hidden_states.to(input_dtype)
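+ # Editor's note: this is standard RMSNorm, y = w * x / sqrt(mean(x^2) + eps),
+ # computed in float32 and cast back to the input dtype.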
156
+
157
+
158
+ class FlashSelfAttention(torch.nn.Module):
159
+ # Extracted from https://github.com/microsoft/Megatron-DeepSpeed/blob/main/megatron/model/transformer.py
160
+ """Implement the scaled dot product attention with softmax.
161
+ Arguments
162
+ ---------
163
+ softmax_scale: The temperature to use for the softmax attention.
164
+ (default: 1/sqrt(d_keys) where d_keys is computed at
165
+ runtime)
166
+ attention_dropout: The dropout rate to apply to the attention
167
+ (default: 0.0)
168
+ """
169
+
170
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
171
+ device=None, dtype=None):
172
+ super().__init__()
173
+ assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, '
174
+ 'e.g., with pip install flash-attn')
175
+ assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
176
+ self.causal = causal
177
+ self.softmax_scale = softmax_scale
178
+ self.dropout_p = attention_dropout
179
+
180
+ def forward(self, q, k, v):
181
+ """Implements the multihead softmax attention.
182
+ Arguments
183
+ ---------
184
+ q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
185
+ """
186
+ assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
187
+ assert all((i.is_cuda for i in (q, k, v)))
188
+
189
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
190
+ seqlen_k = k.shape[1]
191
+
192
+ q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
193
+ cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
194
+ device=q.device)
195
+ self.training = False
196
+ if self.training:
197
+ # during training q,k,v always have same seqlen
198
+ assert seqlen_k == seqlen_q
199
+
200
+ is_causal = self.causal
201
+ cu_seqlens_k = cu_seqlens_q
202
+ dropout_p = self.dropout_p
203
+ else:
204
+ # turn off FA causal mask after first inference autoregressive iteration
205
+ # only on first autoregressive step q,k,v have same seqlen
206
+ is_causal = seqlen_q == seqlen_k
207
+ cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
208
+ device=q.device)
209
+ dropout_p = 0
210
+
211
+ output = flash_attn_unpadded_func(
212
+ q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
213
+ dropout_p=dropout_p,
214
+ softmax_scale=self.softmax_scale, causal=is_causal
215
+ )
216
+
217
+ output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
218
+ return output
219
+
220
+
221
+ def _make_causal_mask(
222
+ input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
223
+ ) -> torch.BoolTensor:
224
+ """
225
+ Make causal mask used for self-attention.
226
+ """
227
+ batch_size, target_length = input_ids_shape
228
+ mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
229
+ # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
230
+ seq_ids = torch.arange(target_length, device=device)
231
+ mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
232
+
233
+ if past_key_values_length > 0:
234
+ mask[:, :past_key_values_length] = False
235
+
236
+ expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
237
+ return expanded_mask
238
+
239
+
240
+ def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
241
+ """
242
+ Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
243
+ """
244
+ batch_size, src_length = mask.shape
245
+ tgt_length = tgt_length if tgt_length is not None else src_length
246
+
247
+ expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
248
+ return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
249
+
250
+
251
+ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
252
+ """
253
+ Dropout add function
254
+
255
+ Args:
256
+ x (`torch.tensor`, *required*):
257
+ input tensor
258
+ residual (`torch.tensor`, *required*):
259
+ residual tensor
260
+ prob (`float`, *required*):
261
+ dropout probability
262
+ training (`bool`, *required*):
263
+ training mode
264
+ """
265
+ out = F.dropout(x, p=prob, training=training)
266
+ out = residual + out
267
+ return out
268
+
269
+
270
+ def telechat_gelu_forward(x: torch.Tensor) -> torch.Tensor:
271
+ """
272
+ Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference) to
273
+ make the model jitable.
274
+
275
+ Args:
276
+ x (`torch.tensor`, *required*):
277
+ input hidden states
278
+ """
279
+ return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
280
+
281
+
282
+ def telechat_gelu_back(g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
283
+ """
284
+ gradient of tanh approximation of gelu gradient of actual gelu is: 0.5 * (1. + torch.erf(x * 0.70710678)) +
285
+ 0.3989423 * x * torch.exp(-0.5 * x * x)
286
+
287
+ Args:
288
+ g (`torch.tensor`, *required*):
289
+ gradient output tensor
290
+ x (`torch.tensor`, *required*):
291
+ input tensor
292
+ """
293
+ x = x[0] # x is a tuple of 1 element, needs to unpack it first
294
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
295
+ # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
296
+ ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
297
+ return ff * g
298
+
299
+
300
+ class GeLUFunction(torch.autograd.Function):
301
+ @staticmethod
302
+ def forward(ctx, input: torch.Tensor) -> torch.Tensor:
303
+ ctx.save_for_backward(input)
304
+ return telechat_gelu_forward(input)
305
+
306
+ @staticmethod
307
+ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
308
+ input = ctx.saved_tensors
309
+ tmp = telechat_gelu_back(grad_output, input)
310
+ return tmp
311
+
312
+
313
+ class TelechatGelu(nn.Module):
314
+ """
315
+ TelechatBiasGelu wrapper function that make use of the simple function on inference mode to make the model
316
+ torchscriptable and use the autograd function in training mode to get the accurate results of the gradients Partly
317
+ copied from Megatron-DeepSpeed code and adapted for our needs
318
+
319
+ See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329
320
+ """
321
+
322
+ def __init__(self):
323
+ super().__init__()
324
+
325
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
326
+ if self.training:
327
+ return GeLUFunction.apply(x)
328
+ else:
329
+ return telechat_gelu_forward(x)
330
+
331
+
332
+ class TelechatAttention(nn.Module):
333
+ def __init__(self, config: Telechat2Config, layer_idx):
334
+ super().__init__()
335
+ self.kv_cache = None
336
+ self.layer_idx = layer_idx
337
+
338
+ self.hidden_size = config.hidden_size
339
+ self.num_heads = config.n_head
340
+ self.head_dim = self.hidden_size // self.num_heads
341
+ self.split_size = self.hidden_size
342
+ self.hidden_dropout = config.hidden_dropout
343
+ self.config = config
344
+
345
+ if self.head_dim * self.num_heads != self.hidden_size:
346
+ raise ValueError(
347
+ f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
348
+ f" {self.num_heads})."
349
+ )
350
+
351
+ # Layer-wise attention scaling
352
+ self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
353
+ self.beta = 1.0
354
+
355
+ self.num_key_value_heads = config.num_key_value_heads if config.num_key_value_heads else self.num_heads
356
+ self.kv_projection_size = self.head_dim * self.num_key_value_heads
357
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
358
+ self.query = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
359
+ self.key_value = nn.Linear(self.hidden_size, self.kv_projection_size * 2, bias=False)
360
+ self.dense = nn.Linear(self.hidden_size, self.hidden_size)
361
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
362
+ self.rotary_emb = RotaryEmbedding(self.head_dim, config=config)
363
+
364
+ self.core_attention_flash = FlashSelfAttention(
365
+ causal=True, attention_dropout=config.attention_dropout
366
+ )
367
+
368
+ self.last_key_layer = None
369
+ # logn_list = [math.log(i, 4096) if i > 4096 else 1 for i in range(1, 32768)]
370
+ # self.logn_tensor = torch.tensor(logn_list)[None, :, None, None].half().cuda()
371
+
372
+ def repeat_kv(self, hidden_states, n_rep):
373
+ slen, batch, num_key_value_heads_per_partition, head_dim = hidden_states.shape
374
+ if n_rep == 1:
375
+ return hidden_states
376
+ hidden_states = hidden_states[:, :, :, None, :].expand(slen, batch, num_key_value_heads_per_partition, n_rep,
377
+ head_dim)
378
+ return hidden_states.reshape(slen, batch, num_key_value_heads_per_partition * n_rep, head_dim)
379
+
380
+ def split_tensor_along_last_dim(self,
381
+ tensor: torch.Tensor,
382
+ num_partitions: int,
383
+ contiguous_split_chunks: bool = False,
384
+ ):
385
+
386
+ # Get the size and dimension.
387
+ last_dim = tensor.dim() - 1
388
+ last_dim_size = tensor.size()[last_dim] // num_partitions
389
+ # Split.
390
+ tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
391
+ # Note: torch.split does not create contiguous tensors by default.
392
+ if contiguous_split_chunks:
393
+ return tuple(chunk.contiguous() for chunk in tensor_list)
394
+
395
+ return tensor_list
396
+
397
+ def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
398
+ batch_size_and_num_heads, seq_length, _ = x.shape
399
+ batch_size = batch_size_and_num_heads // self.num_heads
400
+ x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
401
+ x = x.permute(0, 2, 1, 3)
402
+ return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
403
+
404
+ def forward(
405
+ self,
406
+ hidden_states: torch.Tensor,
407
+ residual: torch.Tensor,
408
+ attention_mask: torch.Tensor,
409
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
410
+ use_cache: bool = False,
411
+ output_attentions: bool = False,
412
+ ):
413
+ hidden_states = hidden_states.transpose(1, 0)
414
+ query_layer = self.query(hidden_states)
415
+ new_tensor_shape = query_layer.size()[:-1] + \
416
+ (self.num_heads,
417
+ self.head_dim)
418
+ query_layer = query_layer.view(*new_tensor_shape)
419
+
420
+ mixed_kv_layer = self.key_value(hidden_states)
421
+ new_tensor_shape = mixed_kv_layer.size()[:-1] + \
422
+ (self.num_key_value_heads,
423
+ 2 * self.head_dim)
424
+ mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
425
+ (key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_kv_layer, 2)
426
+
427
+ output_size = (query_layer.size(1),
428
+ query_layer.size(2),
429
+ query_layer.size(0),
430
+ key_layer.size(0),
431
+ key_layer.size(2)
432
+ )
433
+
434
+ query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
435
+ key_layer = key_layer.view(output_size[3], output_size[0] * output_size[4], -1)
436
+
437
+ apply_rotary_fn = apply_rotary_pos_emb_torch
438
+
439
+ seq_len = key_layer.shape[0]
440
+ offset = 0
441
+
442
+ if use_cache and layer_past != None:
443
+ past_key, past_value = layer_past
444
+ offset = past_key.shape[0]
445
+ seq_len += offset
446
+
447
+ cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
448
+
449
+ query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset)
450
+ if use_cache:
451
+ if layer_past != None:
452
+ past_key, past_value = layer_past
453
+ key_layer = torch.cat((past_key, key_layer[-1, ...].unsqueeze(0)), dim=0)
454
+ value_layer = torch.cat((past_value, value_layer[-1, ...].unsqueeze(0)), dim=0)
455
+ layer_past = key_layer, value_layer
456
+
457
+ s_value, bz, kv_head, dim = value_layer.shape
458
+ s_key = key_layer.shape[0]
459
+ s_query = query_layer.shape[0]
460
+ q_head = output_size[1]
461
+
462
+ query_layer = query_layer.reshape((s_query, bz, q_head, dim))
463
+ key_layer = key_layer.reshape((s_key, bz, kv_head, dim))
464
+
465
+ key_layer = self.repeat_kv(key_layer, self.num_key_value_groups)
466
+ value_layer = self.repeat_kv(value_layer, self.num_key_value_groups)
467
+
468
+ if self.config.flash_attn:
469
+ q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous() for x in
470
+ (query_layer, key_layer, value_layer)]
471
+ context_layer = self.core_attention_flash(q, k, v)
472
+ context_layer = rearrange(context_layer, 'b s h d -> b s (h d)').contiguous()
473
+ else:
474
+ ##[sq, b, np, hn] -> [sq, b * np, hn]
475
+ query_layer = query_layer.reshape(s_query, bz * self.num_heads, dim)
476
+ # [sk, b, np, hn] -> [sk, b * np, hn]
477
+ key_layer = key_layer.reshape(s_key, bz * self.num_heads, dim)
478
+ matmul_result = self.inv_norm_factor * torch.einsum('bik,bkj->bij', query_layer.transpose(0, 1),
479
+ key_layer.transpose(0, 1).transpose(1, 2))
480
+
481
+ attention_scores = matmul_result.view(bz, self.num_heads, s_query, s_key)
482
+
483
+ input_dtype = attention_scores.dtype
484
+ if input_dtype == torch.float16:
485
+ attention_scores = attention_scores.to(torch.float)
486
+ attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
487
+ attention_probs = F.softmax(attn_weights, dim=-1).to(input_dtype) ##dtype = torch.float32
488
+ attention_probs = self.attention_dropout(attention_probs)
489
+ attention_probs_reshaped = attention_probs.view(bz * self.num_heads, s_query, s_key)
490
+
491
+ value_layer = value_layer.reshape(s_key, bz * self.num_heads, dim)
492
+ context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1))
493
+ context_layer = self._merge_heads(context_layer)
494
+ output_tensor = self.dense(context_layer)
495
+
496
+ output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
497
+ present = None
498
+ outputs = (output_tensor, present)
499
+ if output_attentions:
500
+ outputs += (attention_probs,)
501
+
502
+ return output_tensor, layer_past
503
+
504
+
505
+ class TelechatMLP(nn.Module):
506
+ def __init__(self, config: Telechat2Config):
507
+ super().__init__()
508
+ hidden_size = config.hidden_size
509
+ self.gate_proj = nn.Linear(hidden_size, config.ffn_hidden_size, bias=False)
510
+ self.up_proj = nn.Linear(hidden_size, config.ffn_hidden_size, bias=False)
511
+ self.down_proj = nn.Linear(config.ffn_hidden_size, hidden_size, bias=True)
512
+ self.hidden_dropout = config.hidden_dropout
513
+
514
+ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
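+ # Editor's note: SwiGLU MLP as described in the README: down_proj(SiLU(gate_proj(x)) * up_proj(x)).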
515
+ intermediate_output = self.down_proj(F.silu(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
516
+ output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
517
+ return output
518
+
519
+
520
+ class TelechatBlock(nn.Module):
521
+ def __init__(self, config: Telechat2Config, layer_idx):
522
+ super().__init__()
523
+ hidden_size = config.hidden_size
524
+
525
+ self.input_layernorm = MixedFusedRMSNorm(hidden_size, eps=config.layer_norm_epsilon)
526
+ self.num_heads = config.n_head
527
+ self.layer_idx = layer_idx
528
+ self.self_attention = TelechatAttention(config, layer_idx)
529
+ self.post_attention_layernorm = MixedFusedRMSNorm(hidden_size, eps=config.layer_norm_epsilon)
530
+
531
+ self.mlp = TelechatMLP(config)
532
+
533
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
534
+ self.hidden_dropout = config.hidden_dropout
535
+
536
+ def forward(
537
+ self,
538
+ hidden_states: torch.Tensor,
539
+ attention_mask: torch.Tensor,
540
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
541
+ use_cache: bool = False,
542
+ output_attentions: bool = False,
543
+ ):
544
+ layernorm_output = self.input_layernorm(hidden_states)
545
+ if self.apply_residual_connection_post_layernorm:
546
+ residual = layernorm_output
547
+ else:
548
+ residual = hidden_states
549
+
550
+ attn_outputs = self.self_attention(
551
+ layernorm_output,
552
+ residual,
553
+ layer_past=layer_past,
554
+ attention_mask=attention_mask,
555
+ use_cache=use_cache,
556
+ output_attentions=output_attentions,
557
+ )
558
+
559
+ attention_output = attn_outputs[0]
560
+ outputs = attn_outputs[1:]
561
+ layernorm_output = self.post_attention_layernorm(attention_output)
562
+
563
+ if self.apply_residual_connection_post_layernorm:
564
+ residual = layernorm_output
565
+ else:
566
+ residual = attention_output
567
+ output = self.mlp(layernorm_output, residual)
568
+
569
+ if use_cache:
570
+ outputs = (output,) + outputs
571
+ else:
572
+ outputs = (output,) + outputs[1:]
573
+
574
+ return outputs
575
+
576
+
577
+ class TelechatPreTrainedModel(PreTrainedModel):
578
+ config_class = Telechat2Config
579
+ base_model_prefix = "transformer"
580
+ supports_gradient_checkpointing = True
581
+ _no_split_modules = ["TelechatBlock"]
582
+ _skip_keys_device_placement = "past_key_values"
583
+
584
+ def __init__(self, *inputs, **kwargs):
585
+ super().__init__(*inputs, **kwargs)
586
+
587
+ def _init_weights(self, module: nn.Module):
588
+ """Initialize the weights."""
589
+ if isinstance(module, nn.Linear):
590
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
591
+ if module.bias is not None:
592
+ module.bias.data.zero_()
593
+
594
+ elif isinstance(module, nn.Embedding):
595
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
596
+ if module.padding_idx is not None:
597
+ module.weight.data[module.padding_idx].zero_()
598
+
599
+ elif isinstance(module, LayerNorm):
600
+ module.bias.data.zero_()
601
+ module.weight.data.fill_(1.0)
602
+
603
+ def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
604
+ if isinstance(module, TelechatModel):
605
+ module.gradient_checkpointing = value
606
+
607
+
608
+ class TelechatModel(TelechatPreTrainedModel):
609
+ def __init__(self, config: Telechat2Config):
610
+ super().__init__(config)
611
+
612
+ self.embed_dim = config.hidden_size
613
+ self.num_heads = config.n_head
614
+ self.config = config
615
+ self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
616
+ if self.config.embed_layernorm:
617
+ self.word_embeddings_layernorm = MixedFusedRMSNorm(self.embed_dim, eps=config.layer_norm_epsilon)
618
+
619
+ self.h = nn.ModuleList([TelechatBlock(config, _) for _ in range(config.num_hidden_layers)])
620
+ self.ln_f = MixedFusedRMSNorm(self.embed_dim, eps=config.layer_norm_epsilon)
621
+ self.gradient_checkpointing = False
622
+ self.post_init()
623
+
624
+ def get_input_embeddings(self):
625
+ return self.word_embeddings
626
+
627
+ def _prepare_attn_mask(
628
+ self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
629
+ ) -> torch.BoolTensor:
630
+ combined_attention_mask = None
631
+ device = attention_mask.device
632
+ _, src_length = input_shape
633
+
634
+ if src_length > 1:
635
+ combined_attention_mask = _make_causal_mask(
636
+ input_shape, device=device, past_key_values_length=past_key_values_length
637
+ )
638
+ expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
639
+ combined_attention_mask = (
640
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
641
+ )
642
+
643
+ return combined_attention_mask
644
+
645
+ def set_input_embeddings(self, new_embeddings: torch.Tensor):
646
+ self.word_embeddings = new_embeddings
647
+
648
+ def forward(
649
+ self,
650
+ input_ids: Optional[torch.LongTensor] = None,
651
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
652
+ attention_mask: Optional[torch.Tensor] = None,
653
+ inputs_embeds: Optional[torch.LongTensor] = None,
654
+ use_cache: Optional[bool] = None,
655
+ output_attentions: Optional[bool] = None,
656
+ output_hidden_states: Optional[bool] = None,
657
+ return_dict: Optional[bool] = None,
658
+ **deprecated_arguments,
659
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
660
+
661
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
662
+ output_hidden_states = (
663
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
664
+ )
665
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
666
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
667
+
668
+ if input_ids is not None:
669
+ batch_size, seq_length = input_ids.shape
670
+ elif inputs_embeds is not None:
671
+ batch_size, seq_length, _ = inputs_embeds.shape
672
+
673
+ if past_key_values is None:
674
+ past_key_values = tuple([None] * len(self.h))
675
+ # input_ids = torch.load("Megatron-LM-0624-3B/tensors/input_ids.pt").to(input_ids.device)
676
+ if inputs_embeds is None:
677
+ inputs_embeds = self.word_embeddings(input_ids)
678
+ hidden_states = inputs_embeds
679
+ # print(f"[INFO_Telechat]: inputs_embeds={inputs_embeds}")
680
+ if self.config.embed_layernorm:
681
+ hidden_states = self.word_embeddings_layernorm(inputs_embeds)
682
+
683
+ presents = () if use_cache else None
684
+ all_self_attentions = () if output_attentions else None
685
+ all_hidden_states = () if output_hidden_states else None
686
+
687
+ if self.gradient_checkpointing and self.training:
688
+ if use_cache:
689
+ use_cache = False
690
+
691
+ seq_length_with_past = seq_length
692
+ past_key_values_length = 0
693
+ if past_key_values[0] is not None:
694
+ past_key_values_length = past_key_values[0][0].shape[2]
695
+ seq_length_with_past = seq_length_with_past + past_key_values_length
696
+ if attention_mask is None:
697
+ attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
698
+ else:
699
+ attention_mask = attention_mask.to(hidden_states.device)
700
+ causal_mask = self._prepare_attn_mask(
701
+ attention_mask,
702
+ input_shape=(batch_size, seq_length),
703
+ past_key_values_length=past_key_values_length,
704
+ )
705
+
706
+ # print(f"[INFO_Telechat]: word_embeddings_layernorm={hidden_states}")
707
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
708
+ if output_hidden_states:
709
+ all_hidden_states = all_hidden_states + (hidden_states,)
710
+
711
+ if self.gradient_checkpointing and self.training:
712
+
713
+ def create_custom_forward(module):
714
+ def custom_forward(*inputs):
715
+ # None for past_key_value
716
+ return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
717
+
718
+ return custom_forward
719
+
720
+ outputs = torch.utils.checkpoint.checkpoint(
721
+ create_custom_forward(block),
722
+ hidden_states,
723
+ causal_mask,
724
+ layer_past,
725
+ )
726
+ else:
727
+ outputs = block(
728
+ hidden_states,
729
+ layer_past=layer_past,
730
+ attention_mask=causal_mask,
731
+ use_cache=use_cache,
732
+ output_attentions=output_attentions,
733
+ )
734
+
735
+ # print(f"[INFO_Telechat]: outputs{i}={outputs}")
736
+ hidden_states = outputs[0]
737
+ if use_cache is True:
738
+ presents = presents + (outputs[1],)
739
+
740
+ if output_attentions:
741
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
742
+ hidden_states = self.ln_f(hidden_states)
743
+ # print(f"[INFO_Telechat]: hidden_states={hidden_states}")
744
+ # ref = torch.load("Megatron-LM-0624-3B/tensors/final_layernorm.pt")
745
+ # print(hidden_states.squeeze()[2048:])
746
+ # print(ref.squeeze())
747
+ # print(torch.max(hidden_states.squeeze()[2048:] - ref.squeeze().to(hidden_states.device)))
748
+ # exit()
749
+ # print(ref.shape,hidden_states.shape)
750
+ # print(hidden_states)
751
+ # exit()
752
+ if output_hidden_states:
753
+ all_hidden_states = all_hidden_states + (hidden_states,)
754
+ if not return_dict:
755
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
756
+ return BaseModelOutputWithPastAndCrossAttentions(
757
+ last_hidden_state=hidden_states,
758
+ past_key_values=presents,
759
+ hidden_states=all_hidden_states,
760
+ attentions=all_self_attentions,
761
+ )
762
+
763
+
764
+ class Telechat2ForCausalLM(TelechatPreTrainedModel):
765
+ # _tied_weights_keys = ["lm_head.weight"]
766
+ _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
767
+
768
+ def __init__(self, config: Telechat2Config):
769
+ super().__init__(config)
770
+ self.transformer = TelechatModel(config)
771
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
772
+ self.post_init()
773
+
774
+ def get_output_embeddings(self):
775
+ return self.lm_head
776
+
777
+ def set_output_embeddings(self, new_embeddings: torch.Tensor):
778
+ self.lm_head = new_embeddings
779
+
780
+ def prepare_inputs_for_generation(
781
+ self,
782
+ input_ids: torch.LongTensor,
783
+ past_key_values: Optional[torch.Tensor] = None,
784
+ attention_mask: Optional[torch.Tensor] = None,
785
+ inputs_embeds: Optional[torch.Tensor] = None,
786
+ **kwargs,
787
+ ) -> dict:
788
+ if past_key_values:
789
+ input_ids = input_ids[:, -1].unsqueeze(-1)
790
+ if inputs_embeds is not None and past_key_values is None:
791
+ model_inputs = {"inputs_embeds": inputs_embeds}
792
+ else:
793
+ model_inputs = {"input_ids": input_ids}
794
+
795
+ model_inputs.update(
796
+ {
797
+ "past_key_values": past_key_values,
798
+ "use_cache": kwargs.get("use_cache"),
799
+ "attention_mask": attention_mask,
800
+ }
801
+ )
802
+ return model_inputs
803
+
804
+ def forward(
805
+ self,
806
+ input_ids: Optional[torch.LongTensor] = None,
807
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
808
+ attention_mask: Optional[torch.Tensor] = None,
809
+ inputs_embeds: Optional[torch.Tensor] = None,
810
+ labels: Optional[torch.Tensor] = None,
811
+ use_cache: Optional[bool] = None,
812
+ output_attentions: Optional[bool] = None,
813
+ output_hidden_states: Optional[bool] = None,
814
+ return_dict: Optional[bool] = None,
815
+ **deprecated_arguments,
816
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
817
+
818
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
819
+
820
+ transformer_outputs = self.transformer(
821
+ input_ids,
822
+ past_key_values=past_key_values,
823
+ attention_mask=attention_mask,
824
+ inputs_embeds=inputs_embeds,
825
+ use_cache=use_cache,
826
+ output_attentions=output_attentions,
827
+ output_hidden_states=output_hidden_states,
828
+ return_dict=return_dict,
829
+ )
830
+ hidden_states = transformer_outputs[0]
831
+ lm_logits = self.lm_head(hidden_states)
832
+
833
+ loss = None
834
+ if labels is not None:
835
+ labels = labels.to(lm_logits.device)
836
+ shift_logits = lm_logits[..., :-1, :].contiguous()
837
+ shift_labels = labels[..., 1:].contiguous()
838
+ batch_size, seq_length, vocab_size = shift_logits.shape
839
+ loss_fct = CrossEntropyLoss()
840
+ loss = loss_fct(
841
+ shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
842
+ )
843
+
844
+ if not return_dict:
845
+ output = (lm_logits,) + transformer_outputs[1:]
846
+ return ((loss,) + output) if loss is not None else output
847
+
848
+ return CausalLMOutputWithCrossAttentions(
849
+ loss=loss,
850
+ logits=lm_logits,
851
+ past_key_values=transformer_outputs.past_key_values,
852
+ hidden_states=transformer_outputs.hidden_states,
853
+ attentions=transformer_outputs.attentions,
854
+ )
pytorch_model.bin.index.json ADDED
@@ -0,0 +1,310 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 15234703360
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "pytorch_model_00004-of-00004.bin",
7
+ "transformer.h.0.input_layernorm.weight": "pytorch_model_00001-of-00004.bin",
8
+ "transformer.h.0.mlp.down_proj.bias": "pytorch_model_00001-of-00004.bin",
9
+ "transformer.h.0.mlp.down_proj.weight": "pytorch_model_00001-of-00004.bin",
10
+ "transformer.h.0.mlp.gate_proj.weight": "pytorch_model_00001-of-00004.bin",
11
+ "transformer.h.0.mlp.up_proj.weight": "pytorch_model_00001-of-00004.bin",
12
+ "transformer.h.0.post_attention_layernorm.weight": "pytorch_model_00001-of-00004.bin",
13
+ "transformer.h.0.self_attention.dense.bias": "pytorch_model_00001-of-00004.bin",
14
+ "transformer.h.0.self_attention.dense.weight": "pytorch_model_00001-of-00004.bin",
15
+ "transformer.h.0.self_attention.key_value.weight": "pytorch_model_00001-of-00004.bin",
16
+ "transformer.h.0.self_attention.query.weight": "pytorch_model_00001-of-00004.bin",
17
+ "transformer.h.1.input_layernorm.weight": "pytorch_model_00001-of-00004.bin",
18
+ "transformer.h.1.mlp.down_proj.bias": "pytorch_model_00001-of-00004.bin",
19
+ "transformer.h.1.mlp.down_proj.weight": "pytorch_model_00001-of-00004.bin",
20
+ "transformer.h.1.mlp.gate_proj.weight": "pytorch_model_00001-of-00004.bin",
21
+ "transformer.h.1.mlp.up_proj.weight": "pytorch_model_00001-of-00004.bin",
22
+ "transformer.h.1.post_attention_layernorm.weight": "pytorch_model_00001-of-00004.bin",
23
+ "transformer.h.1.self_attention.dense.bias": "pytorch_model_00001-of-00004.bin",
24
+ "transformer.h.1.self_attention.dense.weight": "pytorch_model_00001-of-00004.bin",
25
+ "transformer.h.1.self_attention.key_value.weight": "pytorch_model_00001-of-00004.bin",
26
+ "transformer.h.1.self_attention.query.weight": "pytorch_model_00001-of-00004.bin",
27
+ "transformer.h.10.input_layernorm.weight": "pytorch_model_00002-of-00004.bin",
28
+ "transformer.h.10.mlp.down_proj.bias": "pytorch_model_00002-of-00004.bin",
29
+ "transformer.h.10.mlp.down_proj.weight": "pytorch_model_00002-of-00004.bin",
30
+ "transformer.h.10.mlp.gate_proj.weight": "pytorch_model_00002-of-00004.bin",
31
+ "transformer.h.10.mlp.up_proj.weight": "pytorch_model_00002-of-00004.bin",
32
+ "transformer.h.10.post_attention_layernorm.weight": "pytorch_model_00002-of-00004.bin",
33
+ "transformer.h.10.self_attention.dense.bias": "pytorch_model_00002-of-00004.bin",
34
+ "transformer.h.10.self_attention.dense.weight": "pytorch_model_00002-of-00004.bin",
35
+ "transformer.h.10.self_attention.key_value.weight": "pytorch_model_00002-of-00004.bin",
36
+ "transformer.h.10.self_attention.query.weight": "pytorch_model_00002-of-00004.bin",
37
+ "transformer.h.11.input_layernorm.weight": "pytorch_model_00002-of-00004.bin",
38
+ "transformer.h.11.mlp.down_proj.bias": "pytorch_model_00002-of-00004.bin",
39
+ "transformer.h.11.mlp.down_proj.weight": "pytorch_model_00002-of-00004.bin",
40
+ "transformer.h.11.mlp.gate_proj.weight": "pytorch_model_00002-of-00004.bin",
41
+ "transformer.h.11.mlp.up_proj.weight": "pytorch_model_00002-of-00004.bin",
42
+ "transformer.h.11.post_attention_layernorm.weight": "pytorch_model_00002-of-00004.bin",
43
+ "transformer.h.11.self_attention.dense.bias": "pytorch_model_00002-of-00004.bin",
44
+ "transformer.h.11.self_attention.dense.weight": "pytorch_model_00002-of-00004.bin",
45
+ "transformer.h.11.self_attention.key_value.weight": "pytorch_model_00002-of-00004.bin",
46
+ "transformer.h.11.self_attention.query.weight": "pytorch_model_00002-of-00004.bin",
47
+ "transformer.h.12.input_layernorm.weight": "pytorch_model_00002-of-00004.bin",
48
+ "transformer.h.12.mlp.down_proj.bias": "pytorch_model_00002-of-00004.bin",
49
+ "transformer.h.12.mlp.down_proj.weight": "pytorch_model_00002-of-00004.bin",
50
+ "transformer.h.12.mlp.gate_proj.weight": "pytorch_model_00002-of-00004.bin",
51
+ "transformer.h.12.mlp.up_proj.weight": "pytorch_model_00002-of-00004.bin",
52
+ "transformer.h.12.post_attention_layernorm.weight": "pytorch_model_00002-of-00004.bin",
53
+ "transformer.h.12.self_attention.dense.bias": "pytorch_model_00002-of-00004.bin",
54
+ "transformer.h.12.self_attention.dense.weight": "pytorch_model_00002-of-00004.bin",
55
+ "transformer.h.12.self_attention.key_value.weight": "pytorch_model_00002-of-00004.bin",
56
+ "transformer.h.12.self_attention.query.weight": "pytorch_model_00002-of-00004.bin",
57
+ "transformer.h.13.input_layernorm.weight": "pytorch_model_00002-of-00004.bin",
58
+ "transformer.h.13.mlp.down_proj.bias": "pytorch_model_00002-of-00004.bin",
59
+ "transformer.h.13.mlp.down_proj.weight": "pytorch_model_00002-of-00004.bin",
60
+ "transformer.h.13.mlp.gate_proj.weight": "pytorch_model_00002-of-00004.bin",
61
+ "transformer.h.13.mlp.up_proj.weight": "pytorch_model_00002-of-00004.bin",
62
+ "transformer.h.13.post_attention_layernorm.weight": "pytorch_model_00002-of-00004.bin",
63
+ "transformer.h.13.self_attention.dense.bias": "pytorch_model_00002-of-00004.bin",
64
+ "transformer.h.13.self_attention.dense.weight": "pytorch_model_00002-of-00004.bin",
65
+ "transformer.h.13.self_attention.key_value.weight": "pytorch_model_00002-of-00004.bin",
66
+ "transformer.h.13.self_attention.query.weight": "pytorch_model_00002-of-00004.bin",
67
+ "transformer.h.14.input_layernorm.weight": "pytorch_model_00002-of-00004.bin",
68
+ "transformer.h.14.mlp.down_proj.bias": "pytorch_model_00002-of-00004.bin",
69
+ "transformer.h.14.mlp.down_proj.weight": "pytorch_model_00002-of-00004.bin",
70
+ "transformer.h.14.mlp.gate_proj.weight": "pytorch_model_00002-of-00004.bin",
71
+ "transformer.h.14.mlp.up_proj.weight": "pytorch_model_00002-of-00004.bin",
72
+ "transformer.h.14.post_attention_layernorm.weight": "pytorch_model_00002-of-00004.bin",
73
+ "transformer.h.14.self_attention.dense.bias": "pytorch_model_00002-of-00004.bin",
74
+ "transformer.h.14.self_attention.dense.weight": "pytorch_model_00002-of-00004.bin",
75
+ "transformer.h.14.self_attention.key_value.weight": "pytorch_model_00002-of-00004.bin",
76
+ "transformer.h.14.self_attention.query.weight": "pytorch_model_00002-of-00004.bin",
77
+ "transformer.h.15.input_layernorm.weight": "pytorch_model_00003-of-00004.bin",
78
+ "transformer.h.15.mlp.down_proj.bias": "pytorch_model_00003-of-00004.bin",
79
+ "transformer.h.15.mlp.down_proj.weight": "pytorch_model_00003-of-00004.bin",
80
+ "transformer.h.15.mlp.gate_proj.weight": "pytorch_model_00003-of-00004.bin",
81
+ "transformer.h.15.mlp.up_proj.weight": "pytorch_model_00003-of-00004.bin",
82
+ "transformer.h.15.post_attention_layernorm.weight": "pytorch_model_00003-of-00004.bin",
83
+ "transformer.h.15.self_attention.dense.bias": "pytorch_model_00003-of-00004.bin",
84
+ "transformer.h.15.self_attention.dense.weight": "pytorch_model_00003-of-00004.bin",
85
+ "transformer.h.15.self_attention.key_value.weight": "pytorch_model_00003-of-00004.bin",
86
+ "transformer.h.15.self_attention.query.weight": "pytorch_model_00003-of-00004.bin",
87
+ "transformer.h.16.input_layernorm.weight": "pytorch_model_00003-of-00004.bin",
88
+ "transformer.h.16.mlp.down_proj.bias": "pytorch_model_00003-of-00004.bin",
89
+ "transformer.h.16.mlp.down_proj.weight": "pytorch_model_00003-of-00004.bin",
90
+ "transformer.h.16.mlp.gate_proj.weight": "pytorch_model_00003-of-00004.bin",
91
+ "transformer.h.16.mlp.up_proj.weight": "pytorch_model_00003-of-00004.bin",
92
+ "transformer.h.16.post_attention_layernorm.weight": "pytorch_model_00003-of-00004.bin",
93
+ "transformer.h.16.self_attention.dense.bias": "pytorch_model_00003-of-00004.bin",
94
+ "transformer.h.16.self_attention.dense.weight": "pytorch_model_00003-of-00004.bin",
95
+ "transformer.h.16.self_attention.key_value.weight": "pytorch_model_00003-of-00004.bin",
96
+ "transformer.h.16.self_attention.query.weight": "pytorch_model_00003-of-00004.bin",
97
+ "transformer.h.17.input_layernorm.weight": "pytorch_model_00003-of-00004.bin",
98
+ "transformer.h.17.mlp.down_proj.bias": "pytorch_model_00003-of-00004.bin",
99
+ "transformer.h.17.mlp.down_proj.weight": "pytorch_model_00003-of-00004.bin",
100
+ "transformer.h.17.mlp.gate_proj.weight": "pytorch_model_00003-of-00004.bin",
101
+ "transformer.h.17.mlp.up_proj.weight": "pytorch_model_00003-of-00004.bin",
102
+ "transformer.h.17.post_attention_layernorm.weight": "pytorch_model_00003-of-00004.bin",
103
+ "transformer.h.17.self_attention.dense.bias": "pytorch_model_00003-of-00004.bin",
104
+ "transformer.h.17.self_attention.dense.weight": "pytorch_model_00003-of-00004.bin",
105
+ "transformer.h.17.self_attention.key_value.weight": "pytorch_model_00003-of-00004.bin",
106
+ "transformer.h.17.self_attention.query.weight": "pytorch_model_00003-of-00004.bin",
107
+ "transformer.h.18.input_layernorm.weight": "pytorch_model_00003-of-00004.bin",
108
+ "transformer.h.18.mlp.down_proj.bias": "pytorch_model_00003-of-00004.bin",
109
+ "transformer.h.18.mlp.down_proj.weight": "pytorch_model_00003-of-00004.bin",
110
+ "transformer.h.18.mlp.gate_proj.weight": "pytorch_model_00003-of-00004.bin",
111
+ "transformer.h.18.mlp.up_proj.weight": "pytorch_model_00003-of-00004.bin",
112
+ "transformer.h.18.post_attention_layernorm.weight": "pytorch_model_00003-of-00004.bin",
113
+ "transformer.h.18.self_attention.dense.bias": "pytorch_model_00003-of-00004.bin",
114
+ "transformer.h.18.self_attention.dense.weight": "pytorch_model_00003-of-00004.bin",
115
+ "transformer.h.18.self_attention.key_value.weight": "pytorch_model_00003-of-00004.bin",
116
+ "transformer.h.18.self_attention.query.weight": "pytorch_model_00003-of-00004.bin",
117
+ "transformer.h.19.input_layernorm.weight": "pytorch_model_00003-of-00004.bin",
118
+ "transformer.h.19.mlp.down_proj.bias": "pytorch_model_00003-of-00004.bin",
119
+ "transformer.h.19.mlp.down_proj.weight": "pytorch_model_00003-of-00004.bin",
120
+ "transformer.h.19.mlp.gate_proj.weight": "pytorch_model_00003-of-00004.bin",
121
+ "transformer.h.19.mlp.up_proj.weight": "pytorch_model_00003-of-00004.bin",
122
+ "transformer.h.19.post_attention_layernorm.weight": "pytorch_model_00003-of-00004.bin",
123
+ "transformer.h.19.self_attention.dense.bias": "pytorch_model_00003-of-00004.bin",
124
+ "transformer.h.19.self_attention.dense.weight": "pytorch_model_00003-of-00004.bin",
125
+ "transformer.h.19.self_attention.key_value.weight": "pytorch_model_00003-of-00004.bin",
126
+ "transformer.h.19.self_attention.query.weight": "pytorch_model_00003-of-00004.bin",
127
+ "transformer.h.2.input_layernorm.weight": "pytorch_model_00001-of-00004.bin",
128
+ "transformer.h.2.mlp.down_proj.bias": "pytorch_model_00001-of-00004.bin",
129
+ "transformer.h.2.mlp.down_proj.weight": "pytorch_model_00001-of-00004.bin",
130
+ "transformer.h.2.mlp.gate_proj.weight": "pytorch_model_00001-of-00004.bin",
131
+ "transformer.h.2.mlp.up_proj.weight": "pytorch_model_00001-of-00004.bin",
132
+ "transformer.h.2.post_attention_layernorm.weight": "pytorch_model_00001-of-00004.bin",
133
+ "transformer.h.2.self_attention.dense.bias": "pytorch_model_00001-of-00004.bin",
134
+ "transformer.h.2.self_attention.dense.weight": "pytorch_model_00001-of-00004.bin",
135
+ "transformer.h.2.self_attention.key_value.weight": "pytorch_model_00001-of-00004.bin",
136
+ "transformer.h.2.self_attention.query.weight": "pytorch_model_00001-of-00004.bin",
137
+ "transformer.h.20.input_layernorm.weight": "pytorch_model_00003-of-00004.bin",
138
+ "transformer.h.20.mlp.down_proj.bias": "pytorch_model_00003-of-00004.bin",
139
+ "transformer.h.20.mlp.down_proj.weight": "pytorch_model_00003-of-00004.bin",
140
+ "transformer.h.20.mlp.gate_proj.weight": "pytorch_model_00003-of-00004.bin",
141
+ "transformer.h.20.mlp.up_proj.weight": "pytorch_model_00003-of-00004.bin",
142
+ "transformer.h.20.post_attention_layernorm.weight": "pytorch_model_00003-of-00004.bin",
143
+ "transformer.h.20.self_attention.dense.bias": "pytorch_model_00003-of-00004.bin",
144
+ "transformer.h.20.self_attention.dense.weight": "pytorch_model_00003-of-00004.bin",
145
+ "transformer.h.20.self_attention.key_value.weight": "pytorch_model_00003-of-00004.bin",
146
+ "transformer.h.20.self_attention.query.weight": "pytorch_model_00003-of-00004.bin",
147
+ "transformer.h.21.input_layernorm.weight": "pytorch_model_00003-of-00004.bin",
148
+ "transformer.h.21.mlp.down_proj.bias": "pytorch_model_00003-of-00004.bin",
149
+ "transformer.h.21.mlp.down_proj.weight": "pytorch_model_00003-of-00004.bin",
150
+ "transformer.h.21.mlp.gate_proj.weight": "pytorch_model_00003-of-00004.bin",
151
+ "transformer.h.21.mlp.up_proj.weight": "pytorch_model_00003-of-00004.bin",
152
+ "transformer.h.21.post_attention_layernorm.weight": "pytorch_model_00003-of-00004.bin",
153
+ "transformer.h.21.self_attention.dense.bias": "pytorch_model_00003-of-00004.bin",
154
+ "transformer.h.21.self_attention.dense.weight": "pytorch_model_00003-of-00004.bin",
155
+ "transformer.h.21.self_attention.key_value.weight": "pytorch_model_00003-of-00004.bin",
156
+ "transformer.h.21.self_attention.query.weight": "pytorch_model_00003-of-00004.bin",
157
+ "transformer.h.22.input_layernorm.weight": "pytorch_model_00003-of-00004.bin",
158
+ "transformer.h.22.mlp.down_proj.bias": "pytorch_model_00003-of-00004.bin",
159
+ "transformer.h.22.mlp.down_proj.weight": "pytorch_model_00003-of-00004.bin",
160
+ "transformer.h.22.mlp.gate_proj.weight": "pytorch_model_00003-of-00004.bin",
161
+ "transformer.h.22.mlp.up_proj.weight": "pytorch_model_00003-of-00004.bin",
162
+ "transformer.h.22.post_attention_layernorm.weight": "pytorch_model_00003-of-00004.bin",
163
+ "transformer.h.22.self_attention.dense.bias": "pytorch_model_00003-of-00004.bin",
164
+ "transformer.h.22.self_attention.dense.weight": "pytorch_model_00003-of-00004.bin",
165
+ "transformer.h.22.self_attention.key_value.weight": "pytorch_model_00003-of-00004.bin",
166
+ "transformer.h.22.self_attention.query.weight": "pytorch_model_00003-of-00004.bin",
167
+ "transformer.h.23.input_layernorm.weight": "pytorch_model_00004-of-00004.bin",
168
+ "transformer.h.23.mlp.down_proj.bias": "pytorch_model_00004-of-00004.bin",
169
+ "transformer.h.23.mlp.down_proj.weight": "pytorch_model_00004-of-00004.bin",
170
+ "transformer.h.23.mlp.gate_proj.weight": "pytorch_model_00004-of-00004.bin",
171
+ "transformer.h.23.mlp.up_proj.weight": "pytorch_model_00004-of-00004.bin",
172
+ "transformer.h.23.post_attention_layernorm.weight": "pytorch_model_00004-of-00004.bin",
173
+ "transformer.h.23.self_attention.dense.bias": "pytorch_model_00004-of-00004.bin",
174
+ "transformer.h.23.self_attention.dense.weight": "pytorch_model_00004-of-00004.bin",
175
+ "transformer.h.23.self_attention.key_value.weight": "pytorch_model_00004-of-00004.bin",
176
+ "transformer.h.23.self_attention.query.weight": "pytorch_model_00004-of-00004.bin",
177
+ "transformer.h.24.input_layernorm.weight": "pytorch_model_00004-of-00004.bin",
178
+ "transformer.h.24.mlp.down_proj.bias": "pytorch_model_00004-of-00004.bin",
179
+ "transformer.h.24.mlp.down_proj.weight": "pytorch_model_00004-of-00004.bin",
180
+ "transformer.h.24.mlp.gate_proj.weight": "pytorch_model_00004-of-00004.bin",
181
+ "transformer.h.24.mlp.up_proj.weight": "pytorch_model_00004-of-00004.bin",
182
+ "transformer.h.24.post_attention_layernorm.weight": "pytorch_model_00004-of-00004.bin",
183
+ "transformer.h.24.self_attention.dense.bias": "pytorch_model_00004-of-00004.bin",
184
+ "transformer.h.24.self_attention.dense.weight": "pytorch_model_00004-of-00004.bin",
185
+ "transformer.h.24.self_attention.key_value.weight": "pytorch_model_00004-of-00004.bin",
186
+ "transformer.h.24.self_attention.query.weight": "pytorch_model_00004-of-00004.bin",
187
+ "transformer.h.25.input_layernorm.weight": "pytorch_model_00004-of-00004.bin",
188
+ "transformer.h.25.mlp.down_proj.bias": "pytorch_model_00004-of-00004.bin",
189
+ "transformer.h.25.mlp.down_proj.weight": "pytorch_model_00004-of-00004.bin",
190
+ "transformer.h.25.mlp.gate_proj.weight": "pytorch_model_00004-of-00004.bin",
191
+ "transformer.h.25.mlp.up_proj.weight": "pytorch_model_00004-of-00004.bin",
192
+ "transformer.h.25.post_attention_layernorm.weight": "pytorch_model_00004-of-00004.bin",
193
+ "transformer.h.25.self_attention.dense.bias": "pytorch_model_00004-of-00004.bin",
194
+ "transformer.h.25.self_attention.dense.weight": "pytorch_model_00004-of-00004.bin",
195
+ "transformer.h.25.self_attention.key_value.weight": "pytorch_model_00004-of-00004.bin",
196
+ "transformer.h.25.self_attention.query.weight": "pytorch_model_00004-of-00004.bin",
197
+ "transformer.h.26.input_layernorm.weight": "pytorch_model_00004-of-00004.bin",
198
+ "transformer.h.26.mlp.down_proj.bias": "pytorch_model_00004-of-00004.bin",
199
+ "transformer.h.26.mlp.down_proj.weight": "pytorch_model_00004-of-00004.bin",
200
+ "transformer.h.26.mlp.gate_proj.weight": "pytorch_model_00004-of-00004.bin",
201
+ "transformer.h.26.mlp.up_proj.weight": "pytorch_model_00004-of-00004.bin",
202
+ "transformer.h.26.post_attention_layernorm.weight": "pytorch_model_00004-of-00004.bin",
203
+ "transformer.h.26.self_attention.dense.bias": "pytorch_model_00004-of-00004.bin",
204
+ "transformer.h.26.self_attention.dense.weight": "pytorch_model_00004-of-00004.bin",
205
+ "transformer.h.26.self_attention.key_value.weight": "pytorch_model_00004-of-00004.bin",
206
+ "transformer.h.26.self_attention.query.weight": "pytorch_model_00004-of-00004.bin",
207
+ "transformer.h.27.input_layernorm.weight": "pytorch_model_00004-of-00004.bin",
208
+ "transformer.h.27.mlp.down_proj.bias": "pytorch_model_00004-of-00004.bin",
209
+ "transformer.h.27.mlp.down_proj.weight": "pytorch_model_00004-of-00004.bin",
210
+ "transformer.h.27.mlp.gate_proj.weight": "pytorch_model_00004-of-00004.bin",
211
+ "transformer.h.27.mlp.up_proj.weight": "pytorch_model_00004-of-00004.bin",
212
+ "transformer.h.27.post_attention_layernorm.weight": "pytorch_model_00004-of-00004.bin",
213
+ "transformer.h.27.self_attention.dense.bias": "pytorch_model_00004-of-00004.bin",
214
+ "transformer.h.27.self_attention.dense.weight": "pytorch_model_00004-of-00004.bin",
215
+ "transformer.h.27.self_attention.key_value.weight": "pytorch_model_00004-of-00004.bin",
216
+ "transformer.h.27.self_attention.query.weight": "pytorch_model_00004-of-00004.bin",
217
+ "transformer.h.28.input_layernorm.weight": "pytorch_model_00004-of-00004.bin",
218
+ "transformer.h.28.mlp.down_proj.bias": "pytorch_model_00004-of-00004.bin",
219
+ "transformer.h.28.mlp.down_proj.weight": "pytorch_model_00004-of-00004.bin",
220
+ "transformer.h.28.mlp.gate_proj.weight": "pytorch_model_00004-of-00004.bin",
221
+ "transformer.h.28.mlp.up_proj.weight": "pytorch_model_00004-of-00004.bin",
222
+ "transformer.h.28.post_attention_layernorm.weight": "pytorch_model_00004-of-00004.bin",
223
+ "transformer.h.28.self_attention.dense.bias": "pytorch_model_00004-of-00004.bin",
224
+ "transformer.h.28.self_attention.dense.weight": "pytorch_model_00004-of-00004.bin",
225
+ "transformer.h.28.self_attention.key_value.weight": "pytorch_model_00004-of-00004.bin",
226
+ "transformer.h.28.self_attention.query.weight": "pytorch_model_00004-of-00004.bin",
227
+ "transformer.h.29.input_layernorm.weight": "pytorch_model_00004-of-00004.bin",
228
+ "transformer.h.29.mlp.down_proj.bias": "pytorch_model_00004-of-00004.bin",
229
+ "transformer.h.29.mlp.down_proj.weight": "pytorch_model_00004-of-00004.bin",
230
+ "transformer.h.29.mlp.gate_proj.weight": "pytorch_model_00004-of-00004.bin",
231
+ "transformer.h.29.mlp.up_proj.weight": "pytorch_model_00004-of-00004.bin",
232
+ "transformer.h.29.post_attention_layernorm.weight": "pytorch_model_00004-of-00004.bin",
233
+ "transformer.h.29.self_attention.dense.bias": "pytorch_model_00004-of-00004.bin",
234
+ "transformer.h.29.self_attention.dense.weight": "pytorch_model_00004-of-00004.bin",
235
+ "transformer.h.29.self_attention.key_value.weight": "pytorch_model_00004-of-00004.bin",
236
+ "transformer.h.29.self_attention.query.weight": "pytorch_model_00004-of-00004.bin",
237
+ "transformer.h.3.input_layernorm.weight": "pytorch_model_00001-of-00004.bin",
238
+ "transformer.h.3.mlp.down_proj.bias": "pytorch_model_00001-of-00004.bin",
239
+ "transformer.h.3.mlp.down_proj.weight": "pytorch_model_00001-of-00004.bin",
240
+ "transformer.h.3.mlp.gate_proj.weight": "pytorch_model_00001-of-00004.bin",
241
+ "transformer.h.3.mlp.up_proj.weight": "pytorch_model_00001-of-00004.bin",
242
+ "transformer.h.3.post_attention_layernorm.weight": "pytorch_model_00001-of-00004.bin",
243
+ "transformer.h.3.self_attention.dense.bias": "pytorch_model_00001-of-00004.bin",
244
+ "transformer.h.3.self_attention.dense.weight": "pytorch_model_00001-of-00004.bin",
245
+ "transformer.h.3.self_attention.key_value.weight": "pytorch_model_00001-of-00004.bin",
246
+ "transformer.h.3.self_attention.query.weight": "pytorch_model_00001-of-00004.bin",
247
+ "transformer.h.4.input_layernorm.weight": "pytorch_model_00001-of-00004.bin",
248
+ "transformer.h.4.mlp.down_proj.bias": "pytorch_model_00001-of-00004.bin",
249
+ "transformer.h.4.mlp.down_proj.weight": "pytorch_model_00001-of-00004.bin",
250
+ "transformer.h.4.mlp.gate_proj.weight": "pytorch_model_00001-of-00004.bin",
251
+ "transformer.h.4.mlp.up_proj.weight": "pytorch_model_00001-of-00004.bin",
252
+ "transformer.h.4.post_attention_layernorm.weight": "pytorch_model_00001-of-00004.bin",
253
+ "transformer.h.4.self_attention.dense.bias": "pytorch_model_00001-of-00004.bin",
254
+ "transformer.h.4.self_attention.dense.weight": "pytorch_model_00001-of-00004.bin",
255
+ "transformer.h.4.self_attention.key_value.weight": "pytorch_model_00001-of-00004.bin",
256
+ "transformer.h.4.self_attention.query.weight": "pytorch_model_00001-of-00004.bin",
257
+ "transformer.h.5.input_layernorm.weight": "pytorch_model_00001-of-00004.bin",
258
+ "transformer.h.5.mlp.down_proj.bias": "pytorch_model_00001-of-00004.bin",
259
+ "transformer.h.5.mlp.down_proj.weight": "pytorch_model_00001-of-00004.bin",
260
+ "transformer.h.5.mlp.gate_proj.weight": "pytorch_model_00001-of-00004.bin",
261
+ "transformer.h.5.mlp.up_proj.weight": "pytorch_model_00001-of-00004.bin",
262
+ "transformer.h.5.post_attention_layernorm.weight": "pytorch_model_00001-of-00004.bin",
263
+ "transformer.h.5.self_attention.dense.bias": "pytorch_model_00001-of-00004.bin",
264
+ "transformer.h.5.self_attention.dense.weight": "pytorch_model_00001-of-00004.bin",
265
+ "transformer.h.5.self_attention.key_value.weight": "pytorch_model_00001-of-00004.bin",
266
+ "transformer.h.5.self_attention.query.weight": "pytorch_model_00001-of-00004.bin",
267
+ "transformer.h.6.input_layernorm.weight": "pytorch_model_00001-of-00004.bin",
268
+ "transformer.h.6.mlp.down_proj.bias": "pytorch_model_00001-of-00004.bin",
269
+ "transformer.h.6.mlp.down_proj.weight": "pytorch_model_00001-of-00004.bin",
270
+ "transformer.h.6.mlp.gate_proj.weight": "pytorch_model_00001-of-00004.bin",
271
+ "transformer.h.6.mlp.up_proj.weight": "pytorch_model_00001-of-00004.bin",
272
+ "transformer.h.6.post_attention_layernorm.weight": "pytorch_model_00001-of-00004.bin",
273
+ "transformer.h.6.self_attention.dense.bias": "pytorch_model_00001-of-00004.bin",
274
+ "transformer.h.6.self_attention.dense.weight": "pytorch_model_00001-of-00004.bin",
275
+ "transformer.h.6.self_attention.key_value.weight": "pytorch_model_00001-of-00004.bin",
276
+ "transformer.h.6.self_attention.query.weight": "pytorch_model_00001-of-00004.bin",
277
+ "transformer.h.7.input_layernorm.weight": "pytorch_model_00002-of-00004.bin",
278
+ "transformer.h.7.mlp.down_proj.bias": "pytorch_model_00002-of-00004.bin",
279
+ "transformer.h.7.mlp.down_proj.weight": "pytorch_model_00002-of-00004.bin",
280
+ "transformer.h.7.mlp.gate_proj.weight": "pytorch_model_00002-of-00004.bin",
281
+ "transformer.h.7.mlp.up_proj.weight": "pytorch_model_00002-of-00004.bin",
282
+ "transformer.h.7.post_attention_layernorm.weight": "pytorch_model_00002-of-00004.bin",
283
+ "transformer.h.7.self_attention.dense.bias": "pytorch_model_00002-of-00004.bin",
284
+ "transformer.h.7.self_attention.dense.weight": "pytorch_model_00002-of-00004.bin",
285
+ "transformer.h.7.self_attention.key_value.weight": "pytorch_model_00002-of-00004.bin",
286
+ "transformer.h.7.self_attention.query.weight": "pytorch_model_00002-of-00004.bin",
287
+ "transformer.h.8.input_layernorm.weight": "pytorch_model_00002-of-00004.bin",
288
+ "transformer.h.8.mlp.down_proj.bias": "pytorch_model_00002-of-00004.bin",
289
+ "transformer.h.8.mlp.down_proj.weight": "pytorch_model_00002-of-00004.bin",
290
+ "transformer.h.8.mlp.gate_proj.weight": "pytorch_model_00002-of-00004.bin",
291
+ "transformer.h.8.mlp.up_proj.weight": "pytorch_model_00002-of-00004.bin",
292
+ "transformer.h.8.post_attention_layernorm.weight": "pytorch_model_00002-of-00004.bin",
293
+ "transformer.h.8.self_attention.dense.bias": "pytorch_model_00002-of-00004.bin",
294
+ "transformer.h.8.self_attention.dense.weight": "pytorch_model_00002-of-00004.bin",
295
+ "transformer.h.8.self_attention.key_value.weight": "pytorch_model_00002-of-00004.bin",
296
+ "transformer.h.8.self_attention.query.weight": "pytorch_model_00002-of-00004.bin",
297
+ "transformer.h.9.input_layernorm.weight": "pytorch_model_00002-of-00004.bin",
298
+ "transformer.h.9.mlp.down_proj.bias": "pytorch_model_00002-of-00004.bin",
299
+ "transformer.h.9.mlp.down_proj.weight": "pytorch_model_00002-of-00004.bin",
300
+ "transformer.h.9.mlp.gate_proj.weight": "pytorch_model_00002-of-00004.bin",
301
+ "transformer.h.9.mlp.up_proj.weight": "pytorch_model_00002-of-00004.bin",
302
+ "transformer.h.9.post_attention_layernorm.weight": "pytorch_model_00002-of-00004.bin",
303
+ "transformer.h.9.self_attention.dense.bias": "pytorch_model_00002-of-00004.bin",
304
+ "transformer.h.9.self_attention.dense.weight": "pytorch_model_00002-of-00004.bin",
305
+ "transformer.h.9.self_attention.key_value.weight": "pytorch_model_00002-of-00004.bin",
306
+ "transformer.h.9.self_attention.query.weight": "pytorch_model_00002-of-00004.bin",
307
+ "transformer.ln_f.weight": "pytorch_model_00004-of-00004.bin",
308
+ "transformer.word_embeddings.weight": "pytorch_model_00001-of-00004.bin"
309
+ }
310
+ }
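This index is what `from_pretrained` consults to locate each parameter among the four shard files (`total_size` is roughly 15.2 GB). If a single tensor is needed without instantiating the whole model, the `weight_map` can be used directly; a small sketch, assuming the shard files sit in the current directory:

```python
# Sketch: look up which shard stores a tensor and load just that shard.
import json
import torch

with open("pytorch_model.bin.index.json") as f:
    index = json.load(f)

name = "transformer.h.15.mlp.gate_proj.weight"
shard_file = index["weight_map"][name]  # e.g. "pytorch_model_00003-of-00004.bin"
state_dict = torch.load(shard_file, map_location="cpu")
print(name, tuple(state_dict[name].shape))
```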
tokenization_telechat2.py ADDED
@@ -0,0 +1,221 @@
+ import os
+ from shutil import copyfile
+ from typing import Any, Dict, List, Optional, Tuple
+ import sentencepiece as spm
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
+ from transformers.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
+
+ # TODO: when we get download url from huggingface, refresh the map
+ PRETRAINED_VOCAB_FILES_MAP = {
+     "vocab_file": {},
+     "tokenizer_file": {},
+ }
+
+
+ class Telechat2Tokenizer(PreTrainedTokenizer):
+     vocab_files_names = VOCAB_FILES_NAMES
+     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+     model_input_names = ["input_ids", "attention_mask"]
+
+     def __init__(
+         self,
+         vocab_file,
+         unk_token="<unk>",
+         bos_token="<_start>",
+         eos_token="<_end>",
+         pad_token="<_pad>",
+         sp_model_kwargs: Optional[Dict[str, Any]] = None,
+         add_bos_token=True,
+         add_eos_token=False,
+         clean_up_tokenization_spaces=False,
+         **kwargs,
+     ):
+         self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+         bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+         eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+         pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
+         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+         self.sp_model.Load(vocab_file)
+         super().__init__(
+             bos_token=bos_token,
+             eos_token=eos_token,
+             pad_token=pad_token,
+             add_bos_token=add_bos_token,
+             add_eos_token=add_eos_token,
+             sp_model_kwargs=self.sp_model_kwargs,
+             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+             **kwargs,
+         )
+         self.vocab_file = vocab_file
+         self.add_bos_token = add_bos_token
+         self.add_eos_token = add_eos_token
+
+     def __getstate__(self):
+         state = self.__dict__.copy()
+         state["sp_model"] = None
+         return state
+
+     def __setstate__(self, d):
+         self.__dict__ = d
+         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+         self.sp_model.Load(self.vocab_file)
+
+     @property
+     def vocab_size(self):
+         """Returns the vocab size."""
+         return self.sp_model.get_piece_size()
+
+     def get_vocab(self):
+         """Returns the vocab as a dict."""
+         vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+         vocab.update(self.added_tokens_encoder)
+         return vocab
+
+     @property
+     def vocab(self):
+         return self.get_vocab()
+
+     def _tokenize(self, text):
+         """Returns a tokenized string."""
+         return self.sp_model.encode(text, out_type=str)
+
+     def _convert_token_to_id(self, token):
+         """Converts a token (str) into an id using the vocab."""
+         return self.sp_model.piece_to_id(token)
+
+     def _convert_id_to_token(self, index):
+         """Converts an index (integer) into a token (str) using the vocab."""
+         token = self.sp_model.IdToPiece(index)
+         return token
+
+     def convert_tokens_to_string(self, tokens):
+         """Converts a sequence of tokens (strings) into a single string."""
+         current_sub_tokens = []
+         out_string = ""
+         # prev_is_special = False
+         for i, token in enumerate(tokens):
+             # make sure that special tokens are not decoded using sentencepiece model
+             if token in self.all_special_tokens:
+                 # if not prev_is_special and i != 0:
+                 #     out_string += " "
+                 out_string += self.sp_model.decode(current_sub_tokens) + token
+                 # prev_is_special = True
+                 current_sub_tokens = []
+             else:
+                 current_sub_tokens.append(token)
+                 # prev_is_special = False
+         out_string += self.sp_model.decode(current_sub_tokens)
+         return out_string
+
+     def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
+         """
+         Save the vocabulary and special tokens file to a directory.
+
+         Args:
+             save_directory (`str`):
+                 The directory in which to save the vocabulary.
+
+         Returns:
+             `Tuple(str)`: Paths to the files saved.
+         """
+         if not os.path.isdir(save_directory):
+             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+             return
+         out_vocab_file = os.path.join(
+             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+         )
+
+         if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+             copyfile(self.vocab_file, out_vocab_file)
+         elif not os.path.isfile(self.vocab_file):
+             with open(out_vocab_file, "wb") as fi:
+                 content_spiece_model = self.sp_model.serialized_model_proto()
+                 fi.write(content_spiece_model)
+
+         return (out_vocab_file,)
+
+     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+         bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+         eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+         output = bos_token_id + token_ids_0 + eos_token_id
+
+         if token_ids_1 is not None:
+             output = output + bos_token_id + token_ids_1 + eos_token_id
+
+         return output
+
+     def get_special_tokens_mask(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None,
+         already_has_special_tokens: bool = False
+     ) -> List[int]:
+         """
+         Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+         special tokens using the tokenizer `prepare_for_model` method.
+
+         Args:
+             token_ids_0 (`List[int]`):
+                 List of IDs.
+             token_ids_1 (`List[int]`, *optional*):
+                 Optional second list of IDs for sequence pairs.
+             already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                 Whether or not the token list is already formatted with special tokens for the model.
+
+         Returns:
+             `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+         """
+         if already_has_special_tokens:
+             return super().get_special_tokens_mask(
+                 token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+             )
+
+         bos_token_id = [1] if self.add_bos_token else []
+         eos_token_id = [1] if self.add_eos_token else []
+
+         if token_ids_1 is None:
+             return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
+         return (
+             bos_token_id
+             + ([0] * len(token_ids_0))
+             + eos_token_id
+             + bos_token_id
+             + ([0] * len(token_ids_1))
+             + eos_token_id
+         )
+
+     def create_token_type_ids_from_sequences(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+     ) -> List[int]:
+         """
+         Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
+         sequence pair mask has the following format:
+
+         ```
+         0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+         | first sequence    | second sequence |
+         ```
+
+         If token_ids_1 is None, only the first portion of the mask (0s) is returned.
+
+         Args:
+             token_ids_0 (`List[int]`):
+                 List of ids.
+             token_ids_1 (`List[int]`, *optional*):
+                 Optional second list of IDs for sequence pairs.
+
+         Returns:
+             `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
+         """
+         bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+         eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+         output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
+
+         if token_ids_1 is not None:
+             output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
+
+         return output
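`Telechat2Tokenizer` is a thin wrapper over a SentencePiece model: `_tokenize` defers to `sp_model.encode`, id/token conversion goes through `piece_to_id`/`IdToPiece`, and `build_inputs_with_special_tokens` optionally wraps sequences in `<_start>`/`<_end>`. A round-trip sketch, assuming the files from this commit are in the working directory:

```python
# Sketch: load the custom tokenizer via AutoTokenizer and round-trip a string.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(".", trust_remote_code=True)
ids = tok.encode("星辰语义大模型TeleChat2")
print(ids)
print(tok.convert_ids_to_tokens(ids))
print(tok.decode(ids))
```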
tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a7a5b465bbc9465b214e0962076c1170783a8ee88fb01454b0c33609bd3cf954
+ size 2197499
tokenizer_config.json ADDED
@@ -0,0 +1,114 @@
1
+ {
2
+ "tokenizer_class": "Telechat2Tokenizer",
3
+ "auto_map": {
4
+ "AutoTokenizer": [
5
+ "tokenization_telechat2.Telechat2Tokenizer",
6
+ null
7
+ ]
8
+ },
9
+ "added_tokens_decoder": {
10
+ "1": {
11
+ "content": "<_start>",
12
+ "lstrip": false,
13
+ "normalized": false,
14
+ "rstrip": false,
15
+ "single_word": false,
16
+ "special": true
17
+ },
18
+ "2": {
19
+ "content": "<_end>",
20
+ "lstrip": false,
21
+ "normalized": false,
22
+ "rstrip": false,
23
+ "single_word": false,
24
+ "special": true
25
+ },
26
+ "3": {
27
+ "content": "<_pad>",
28
+ "lstrip": false,
29
+ "normalized": false,
30
+ "rstrip": false,
31
+ "single_word": false,
32
+ "special": true
33
+ },
34
+ "4": {
35
+ "content": "<_user>",
36
+ "lstrip": false,
37
+ "normalized": false,
38
+ "rstrip": false,
39
+ "single_word": false,
40
+ "special": true
41
+ },
42
+ "5": {
43
+ "content": "<_bot>",
44
+ "lstrip": false,
45
+ "normalized": false,
46
+ "rstrip": false,
47
+ "single_word": false,
48
+ "special": true
49
+ },
50
+ "6": {
51
+ "content": "<_system>",
52
+ "lstrip": false,
53
+ "normalized": false,
54
+ "rstrip": false,
55
+ "single_word": false,
56
+ "special": true
57
+ },
58
+ "9": {
59
+ "content": "<tool_call>",
60
+ "lstrip": false,
61
+ "normalized": false,
62
+ "rstrip": false,
63
+ "single_word": false,
64
+ "special": true
65
+ },
66
+ "10": {
67
+ "content": "</tool_call>",
68
+ "lstrip": false,
69
+ "normalized": false,
70
+ "rstrip": false,
71
+ "single_word": false,
72
+ "special": true
73
+ },
74
+ "11": {
75
+ "content": "<tool_response>",
76
+ "lstrip": false,
77
+ "normalized": false,
78
+ "rstrip": false,
79
+ "single_word": false,
80
+ "special": true
81
+ },
82
+ "12": {
83
+ "content": "</tool_response>",
84
+ "lstrip": false,
85
+ "normalized": false,
86
+ "rstrip": false,
87
+ "single_word": false,
88
+ "special": true
89
+ }
90
+ },
91
+ "additional_special_tokens": [
92
+ "<_start>",
93
+ "<_end>",
94
+ "<_pad>",
95
+ "<_user>",
96
+ "<_bot>",
97
+ "<_system>",
98
+ "<tool_call>",
99
+ "</tool_call>",
100
+ "<tool_response>",
101
+ "</tool_response>"
102
+ ],
103
+ "add_bos_token": false,
104
+ "add_eos_token": false,
105
+ "use_fast": false,
106
+ "clean_up_tokenization_spaces": false,
107
+ "split_special_tokens": false,
108
+ "model_max_length": 100000000,
109
+ "sp_model_kwargs": {},
110
+ "bos_token": "<_start>",
111
+ "eos_token": "<_end>",
112
+ "pad_token": "<_pad>",
113
+ "chat_template": "{%- if tools %}\n {%- if messages[0]['role'] == 'system' %}\n {{-'<_system>'+messages[0]['content'] }}\n {%- else %}\n {{- '<_system>'+'你是中国电信星辰语义大模型,英文名是TeleChat,你是由中电信人工智能科技有限公司和中国电信人工智能研究院(TeleAI)研发的人工智能助手。' }}\n {%- endif %}\n {{- '\\n\\n# 可用工具\\n你可以调用<tools></tools>标签中包含的一个或多个工具来辅助你回答问题,以下是可用工具详情:\\n<tools>\\n' }}\n {%- for tool in tools %}\n {{- tool | tojson }}\n {{-'\\n'}}\n {%- endfor %}\n {{- '</tools>\\n\\n# 调用方法\\n你需要遵循工具的要求,使用json格式返回工具名称及参数,并用<tool_call></tool_call>包含。下方是一个调用模板:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call>\\n\\n' }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<_system>' + messages[0]['content'] + '\\n' }}\n {%- else %}\n {{- '<_system>'+'你是中国电信星辰语义大模型,英文名是TeleChat,你是由中电信人工智能科技有限公司和中国电信人工智能研究院(TeleAI)研发的人工智能助手。\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == 'user') %}\n {{- '<_user>' + message.content }}\n {%- elif message.role == 'bot' or message.role == 'assistant' %}\n {{- '<_bot>' }}\n {%- if message.content %}\n {{- message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {%- if loop.index0 == 0 %}\n {{-'<tool_call>'}}\n {%- else %}\n {{-'\\n<tool_call>'}}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}{{ tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<_end>\\n' }}\n {%- elif message.role == 'tool' %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != 'tool') %}\n {{- '<_user>'+'<tool_response>\\n' }}\n {%- else %}\n {{- '\\n<tool_response>\\n' }}\n {%- endif %}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<_bot>' }}\n{%- endif %}"
114
+ }
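The `chat_template` entry ties the `<_system>`/`<_user>`/`<_bot>` markers and the tool-call tags together, so `apply_chat_template` can render both plain conversations and Function Call prompts. A hedged example of rendering a short conversation (the message contents are invented for illustration):

```python
# Sketch: render a prompt with the chat template declared above.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(".", trust_remote_code=True)
messages = [
    {"role": "system", "content": "你是一个乐于助人的助手。"},
    {"role": "user", "content": "简单介绍一下TeleChat2。"},
]
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)  # ends with '<_bot>', so the model continues as the assistant
```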