nengrenjie83 commited on
Commit
b78b52f
1 Parent(s): 247f626

Upload 28 files

Browse files
.gitignore ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
CITATION.cff ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ cff-version: 1.2.0
2
+ message: "If you use this software, please cite it as below."
3
+ authors:
4
+ - family-names: "Xu"
5
+ given-names: "Ming"
6
+ title: "MedicalGPT: Training Your Own Medical GPT Model with ChatGPT Training Pipeline"
7
+ url: "https://github.com/shibing624/MedicalGPT"
8
+ data-released: 2023-06-02
9
+ version: 0.0.4
CONTRIBUTING.md ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing
2
+
3
+ We are happy to accept your contributions to make this repo better and more awesome! To avoid unnecessary work on either
4
+ side, please stick to the following process:
5
+
6
+ 1. Check if there is already an issue for your concern.
7
+ 2. If there is not, open a new one to start a discussion. We hate to close finished PRs!
8
+ 3. If we decide your concern needs code changes, we would be happy to accept a pull request. Please consider the
9
+ commit guidelines below.
DISCLAIMER ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ The software project, data, and models provided by our GitHub project are provided "as is," without warranty of any kind, express or implied, including but not limited to the warranties of merchantability, fitness for a particular purpose, and non-infringement.
2
+
3
+ In no event shall the project owners or contributors be liable for any direct, indirect, incidental, special, exemplary, or consequential damages (including, but not limited to, procurement of substitute goods or services; loss of use, data, or profits; or business interruption) however caused and on any theory of liability, whether in contract, strict liability, or tort (including negligence or otherwise) arising in any way out of the use of this software project, data, or models, even if advised of the possibility of such damage.
4
+
5
+ Users of this software project, data, and models are solely responsible for any consequences of their use. The project owners and contributors shall not be held responsible for any subsequent or potential harm caused by the use of this software project, data, or models.
6
+
7
+ By using this software project, data, or models, users accept and agree to this disclaimer. If users do not agree to the terms of this disclaimer, they should not use this software project, data, or models.
8
+
9
+ It is important to note that this software project, data, and models are still in the research phase and are provided for experimental purposes only. As such, the project owners and contributors do not guarantee the accuracy, completeness, or usefulness of the software project, data, or models.
10
+
11
+ Furthermore, due to the experimental nature of this software project, data, and models, it is possible that they may contain or generate inappropriate responses, errors, or inconsistencies. Users should exercise caution when using this software project, data, or models, and should not rely solely on them for any critical or sensitive tasks.
12
+
13
+ The project owners and contributors shall not be held responsible for any damages, losses, or liabilities arising from the use of this software project, data, or models, including but not limited to, any inappropriate responses generated by the software project, data, or models.
14
+
15
+ By using this software project, data, or models, users acknowledge and accept the experimental nature of the software project, data, and models, and understand the potential risks and limitations associated with their use. If users do not agree to the terms of this disclaimer, they should not use this software project, data, or models.
16
+
17
+ The software project, data, and models provided by our GitHub project are intended for research purposes only. They should not be used for any commercial, business, or legal purposes, and should not be relied upon as a substitute for professional advice or judgment.
18
+
19
+ Users of this software project, data, and models are strictly prohibited from using them for any commercial purposes, including but not limited to, selling, licensing, or distributing the software project, data, or models to third parties.
20
+
21
+ The project owners and contributors shall not be held responsible for any damages, losses, or liabilities arising from the use of this software project, data, or models for any commercial or business purposes.
22
+
23
+ By using this software project, data, or models, users agree to use them for research purposes only, and not for any commercial or business purposes. If users do not agree to the terms of this disclaimer, they should not use this software project, data, or models.
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,13 +1,326 @@
1
- ---
2
- title: MedicalGPT Main
3
- emoji: 😻
4
- colorFrom: purple
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 3.41.2
8
- app_file: app.py
9
- pinned: false
10
- license: other
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [**🇨🇳中文**](https://github.com/shibing624/MedicalGPT/blob/main/README.md) | [**🌐English**](https://github.com/shibing624/MedicalGPT/blob/main/README_EN.md) | [**📖文档/Docs**](https://github.com/shibing624/MedicalGPT/wiki) | [**🤖模型/Models**](https://huggingface.co/shibing624)
2
+
3
+ <div align="center">
4
+ <a href="https://github.com/shibing624/MedicalGPT">
5
+ <img src="https://github.com/shibing624/MedicalGPT/blob/main/docs/logo.png" height="100" alt="Logo">
6
+ </a>
7
+ </div>
8
+
9
+ -----------------
10
+
11
+ # MedicalGPT: Training Medical GPT Model
12
+ [![HF Models](https://img.shields.io/badge/Hugging%20Face-shibing624-green)](https://huggingface.co/shibing624)
13
+ [![Github Stars](https://img.shields.io/github/stars/shibing624/MedicalGPT?color=yellow)](https://star-history.com/#shibing624/MedicalGPT&Timeline)
14
+ [![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
15
+ [![License Apache 2.0](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE)
16
+ [![python_version](https://img.shields.io/badge/Python-3.8%2B-green.svg)](requirements.txt)
17
+ [![GitHub issues](https://img.shields.io/github/issues/shibing624/MedicalGPT.svg)](https://github.com/shibing624/MedicalGPT/issues)
18
+ [![Wechat Group](http://vlog.sfyc.ltd/wechat_everyday/wxgroup_logo.png?imageView2/0/w/60/h/20)](#Contact)
19
+
20
+ ## 📖 Introduction
21
+
22
+ **MedicalGPT** training medical GPT model with ChatGPT training pipeline, implemantation of Pretraining,
23
+ Supervised Finetuning, RLHF(Reward Modeling and Reinforcement Learning) and DPO(Direct Preference Optimization).
24
+
25
+ **MedicalGPT** 训练医疗大模型,实现了包括增量预训练、有监督微调、RLHF(奖励建模、强化学习训练)和DPO(直接偏好优化)。
26
+
27
+ <img src="https://github.com/shibing624/MedicalGPT/blob/main/docs/dpo.jpg" width="860" />
28
+
29
+ - RLHF training pipeline来自Andrej Karpathy的演讲PDF [State of GPT](https://karpathy.ai/stateofgpt.pdf),视频 [Video](https://build.microsoft.com/en-US/sessions/db3f4859-cd30-4445-a0cd-553c3304f8e2)
30
+ - DPO方法来自论文[Direct Preference Optimization:Your Language Model is Secretly a Reward Model](https://arxiv.org/pdf/2305.18290.pdf)
31
+
32
+ ## 🔥 News
33
+ [2023/08/28] v1.5版本: 新增[DPO(直接偏好优化)](https://arxiv.org/pdf/2305.18290.pdf)方法,DPO通过直接优化语言模型来实现对其行为的精确控制,可以有效学习到人类偏好。详见[Release-v1.5](https://github.com/shibing624/MedicalGPT/releases/tag/1.5.0)
34
+
35
+ [2023/08/08] v1.4版本: 发布基于ShareGPT4数据集微调的中英文Vicuna-13B模型[shibing624/vicuna-baichuan-13b-chat](https://huggingface.co/shibing624/vicuna-baichuan-13b-chat),和对应的LoRA模型[shibing624/vicuna-baichuan-13b-chat-lora](https://huggingface.co/shibing624/vicuna-baichuan-13b-chat-lora),详见[Release-v1.4](https://github.com/shibing624/MedicalGPT/releases/tag/1.4.0)
36
+
37
+ [2023/08/02] v1.3版本: 新增LLaMA, LLaMA2, Bloom, ChatGLM, ChatGLM2, Baichuan模型的多轮对话微调训练;新增领域词表扩充功能;新增中文预训练数据集和中文ShareGPT微调训练集,详见[Release-v1.3](https://github.com/shibing624/MedicalGPT/releases/tag/1.3.0)
38
+
39
+ [2023/07/13] v1.1版本: 发布中文医疗LLaMA-13B模型[shibing624/ziya-llama-13b-medical-merged](https://huggingface.co/shibing624/ziya-llama-13b-medical-merged),基于Ziya-LLaMA-13B-v1模型,SFT微调了一版医疗模型,医疗问答效果有提升,发布微调后的完整模型权重,详见[Release-v1.1](https://github.com/shibing624/MedicalGPT/releases/tag/1.1)
40
+
41
+ [2023/06/15] v1.0版本: 发布中文医疗LoRA模型[shibing624/ziya-llama-13b-medical-lora](https://huggingface.co/shibing624/ziya-llama-13b-medical-lora),基于Ziya-LLaMA-13B-v1模型,SFT微调了一版医疗模型,医疗问答效果有提升,发布微调后的LoRA权重,详见[Release-v1.0](https://github.com/shibing624/MedicalGPT/releases/tag/1.0.0)
42
+
43
+ [2023/06/05] v0.2版本: 以医疗为例,训练领域大模型,实现了四阶段训练:包括二次预训练、有监督微调、奖励建模、强化学习训练。详见[Release-v0.2](https://github.com/shibing624/MedicalGPT/releases/tag/0.2.0)
44
+
45
+
46
+ ## 😊 Features
47
+
48
+
49
+ 基于ChatGPT Training Pipeline,本项目实现了领域模型--医疗行业语言大模型的训练:
50
+
51
+
52
+ - 第一阶段:PT(Continue PreTraining)增量预训练,在海量领域文档数据上二次预训练GPT模型,以注入领域知识(可选)
53
+ - 第二阶段:SFT(Supervised Fine-tuning)有监督微调,构造指令微调数据集,在预训练模型基础上做指令精调,以对齐指令意图
54
+ - 第三阶段
55
+ - RLHF(Reinforcement Learning from Human Feedback)基于人类反馈对语言模型进行强化学习,分为两步:
56
+ - RM(Reward Model)奖励模型建模,构造人类偏好排序数据集,训练奖励模型,用来建模人类偏好,主要是"HHH"原则,具体是"helpful, honest, harmless"
57
+ - RL(Reinforcement Learning)强化学习,用奖励模型来训练SFT模型,生成模型使用奖励或惩罚来更新其策略,以便生成更高质量、更符合人类偏好的文本
58
+ - [DPO(Direct Preference Optimization)](https://arxiv.org/pdf/2305.18290.pdf)直接偏好优化方法,DPO通过直接优化语言模型来实现对其行为的精确控制,而无需使用复杂的强化学习,也可以有效学习到人类偏好,DPO相较于RLHF更容易实现且易于训练,效果更好
59
+
60
+
61
+
62
+ ### Release Models
63
+
64
+
65
+ | Model | Base Model | Introduction |
66
+ |:------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
67
+ | [shibing624/ziya-llama-13b-medical-lora](https://huggingface.co/shibing624/ziya-llama-13b-medical-lora) | [IDEA-CCNL/Ziya-LLaMA-13B-v1](https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1) | 在240万条中英文医疗数据集[shibing624/medical](https://huggingface.co/datasets/shibing624/medical)上SFT微调了一版Ziya-LLaMA-13B模型,医疗问答效果有提升,发布微调后的LoRA权重(单轮对话) |
68
+ | [shibing624/ziya-llama-13b-medical-merged](https://huggingface.co/shibing624/ziya-llama-13b-medical-merged) | [IDEA-CCNL/Ziya-LLaMA-13B-v1](https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1) | 在240万条中英文医疗数据集[shibing624/medical](https://huggingface.co/datasets/shibing624/medical)上SFT微调了一版Ziya-LLaMA-13B模型,医疗问答效果有提升,发布微调后的完整模型权重(单轮对话) |
69
+ | [shibing624/vicuna-baichuan-13b-chat-lora](https://huggingface.co/shibing624/vicuna-baichuan-13b-chat-lora) | [baichuan-inc/Baichuan-13B-Chat](https://huggingface.co/baichuan-inc/Baichuan-13B-Chat) | 在10万条多语言ShareGPT GPT4多轮对话数据集[shibing624/sharegpt_gpt4](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)上SFT微调了一版baichuan-13b-chat多轮问答模型,日常问答和医疗问答效果有提升,发布微调后的LoRA权重 |
70
+ | [shibing624/vicuna-baichuan-13b-chat](https://huggingface.co/shibing624/vicuna-baichuan-13b-chat) | [baichuan-inc/Baichuan-13B-Chat](https://huggingface.co/baichuan-inc/Baichuan-13B-Chat) | 在10万条多语言ShareGPT GPT4多轮对话数据集[shibing624/sharegpt_gpt4](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)上SFT微调了一版baichuan-13b-chat多轮问答模型,日常问答和医疗问答效果有提升,发布微调后的完整模型权重 |
71
+
72
+ 演示[shibing624/vicuna-baichuan-13b-chat](https://huggingface.co/shibing624/vicuna-baichuan-13b-chat)模型效果:
73
+ <img src="https://github.com/shibing624/MedicalGPT/blob/main/docs/demo-screen.gif" width="860" />
74
+ 具体case见[Inference Examples](#inference-examples)
75
+
76
+ ## ▶️ Demo
77
+
78
+
79
+ 我们提供了一个简洁的基于gradio的交互式web界面,启动服务后,可通过浏览器访问,输入问题,模型会返回答案。
80
+
81
+ 启动服务,命令如下:
82
+ ```shell
83
+ CUDA_VISIBLE_DEVICES=0 python gradio_demo.py --model_type base_model_type --base_model path_to_llama_hf_dir --lora_model path_to_lora_dir
84
+ ```
85
+
86
+ 参数说明:
87
+
88
+ - `--model_type {base_model_type}`:预训练模型类型,如llama、bloom、chatglm等
89
+ - `--base_model {base_model}`:存放HF格式的LLaMA模型权重和配置文件的目录,也可使用HF Model Hub模型调用名称
90
+ - `--lora_model {lora_model}`:LoRA文件所在目录,也可使用HF Model Hub模型调用名称。若lora权重已经合并到预训练模型,则删除--lora_model参数
91
+ - `--tokenizer_path {tokenizer_path}`:存放对应tokenizer的目录。若不提供此参数,则其默认值与--base_model相同
92
+ - `--template_name`:模板名称,如`vicuna`、`alpaca`等。若不提供此参数,则其默认值是vicuna
93
+ - `--only_cpu`: 仅使用CPU进行推理
94
+ - `--gpus {gpu_ids}`: 指定使用的GPU设备编号,默认为0。如使用多张GPU,以逗号分隔,如0,1,2
95
+ - `--resize_emb`:是否调整embedding大小,若不调整,则使用预训练模型的embedding大小,默认不��整
96
+
97
+
98
+ ## 💾 Install
99
+ #### Updating the requirements
100
+ From time to time, the `requirements.txt` changes. To update, use this command:
101
+
102
+ ```markdown
103
+ git clone https://github.com/shibing624/MedicalGPT
104
+ conda activate gpt
105
+ cd MedicalGPT
106
+ pip install -r requirements.txt --upgrade
107
+ ```
108
+
109
+ ## 🚀 Training Pipeline
110
+
111
+ Training Stage:
112
+
113
+ | Stage | Introduction | Python script | Shell script |
114
+ |:--------------------------------|:-------------|:--------------------------------------------------------------------------------------------------------|:----------------------------------------------------------------------------|
115
+ | Continue Pretraining | 增量预训练 | [pretraining.py](https://github.com/shibing624/MedicalGPT/blob/main/pretraining.py) | [run_pt.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_pt.sh) |
116
+ | Supervised Fine-tuning | 有监督微调 | [supervised_finetuning.py](https://github.com/shibing624/MedicalGPT/blob/main/supervised_finetuning.py) | [run_sft.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_sft.sh) |
117
+ | Direct Preference Optimization | 直接偏好优化 | [dpo_training.py](https://github.com/shibing624/MedicalGPT/blob/main/dpo_training.py) | [run_dpo.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_dpo.sh) |
118
+ | Reward Modeling | 奖励模型建模 | [reward_modeling.py](https://github.com/shibing624/MedicalGPT/blob/main/reward_modeling.py) | [run_rm.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_rm.sh) |
119
+ | Reinforcement Learning | 强化学习 | [rl_training.py](https://github.com/shibing624/MedicalGPT/blob/main/rl_training.py) | [run_rl.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_rl.sh) |
120
+
121
+ - 提供完整PT+SFT+DPO全阶段串起来训练的pipeline:[run_training_dpo_pipeline.ipynb](https://github.com/shibing624/MedicalGPT/blob/main/run_training_dpo_pipeline.ipynb) ,其对应的colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/shibing624/MedicalGPT/blob/main/run_training_dpo_pipeline.ipynb),运行完大概需要15分钟,我运行成功后的副本colab:[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kMIe3pTec2snQvLBA00Br8ND1_zwy3Gr?usp=sharing)
122
+ - 提供完整PT+SFT+RLHF全阶段串起来训练的pipeline:[run_training_pipeline.ipynb](https://github.com/shibing624/MedicalGPT/blob/main/run_training_pipeline.ipynb) ,其对应的colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/shibing624/MedicalGPT/blob/main/run_training_pipeline.ipynb) ,运行完大概需要20分钟,我运行成功后的副本colab:[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1RGkbev8D85gR33HJYxqNdnEThODvGUsS?usp=sharing)
123
+ - [训练参数说明wiki](https://github.com/shibing624/MedicalGPT/wiki/%E8%AE%AD%E7%BB%83%E5%8F%82%E6%95%B0%E8%AF%B4%E6%98%8E)
124
+ - [数据集wiki](https://github.com/shibing624/MedicalGPT/wiki/%E6%95%B0%E6%8D%AE%E9%9B%86)
125
+ - [扩充词表wiki](https://github.com/shibing624/MedicalGPT/wiki/%E6%89%A9%E5%85%85%E4%B8%AD%E6%96%87%E8%AF%8D%E8%A1%A8)
126
+ - [FAQ](https://github.com/shibing624/MedicalGPT/wiki/FAQ)
127
+
128
+ #### Supported Models
129
+
130
+ | 模型名 | 模型大小 | Template |
131
+ | ------------------------------------------------------- | --------------------------- |---------------|
132
+ | [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | vicuna |
133
+ | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | alpaca |
134
+ | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
135
+ | [Baichuan](https://github.com/baichuan-inc/baichuan-13B) | 7B/13B | baichuan-chat |
136
+ | [InternLM](https://github.com/InternLM/InternLM) | 7B | intern |
137
+ | [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | chatml |
138
+ | [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | xverse |
139
+ | [ChatGLM](https://github.com/THUDM/ChatGLM-6B) | 6B | chatglm |
140
+ | [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | chatglm2 |
141
+
142
+ The following models are tested:
143
+
144
+ bloom:
145
+ - [bigscience/bloomz-560m](https://huggingface.co/bigscience/bloomz-560m)
146
+ - [bigscience/bloomz-1b7](https://huggingface.co/bigscience/bloomz-1b7)
147
+ - [bigscience/bloomz-7b1](https://huggingface.co/bigscience/bloomz-7b1)
148
+
149
+ llama:
150
+ - [shibing624/chinese-alpaca-plus-7b-hf](https://huggingface.co/shibing624/chinese-alpaca-plus-7b-hf)
151
+ - [shibing624/chinese-alpaca-plus-13b-hf](https://huggingface.co/shibing624/chinese-alpaca-plus-13b-hf)
152
+ - [minlik/chinese-llama-plus-7b-merged](https://huggingface.co/minlik/chinese-llama-plus-7b-merged)
153
+ - [shibing624/chinese-llama-plus-13b-hf](https://huggingface.co/shibing624/chinese-llama-plus-13b-hf)
154
+ - [decapoda-research/llama-7b-hf](https://huggingface.co/decapoda-research/llama-7b-hf)
155
+ - [IDEA-CCNL/Ziya-LLaMA-13B-v1](https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1)
156
+
157
+ llama2:
158
+ - [daryl149/llama-2-7b-chat-hf](https://huggingface.co/daryl149/llama-2-7b-chat-hf)
159
+ - [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
160
+ - [ziqingyang/chinese-alpaca-2-7b](https://huggingface.co/ziqingyang/chinese-alpaca-2-7b)
161
+
162
+ chatglm:
163
+ - [THUDM/chatglm-6b](https://huggingface.co/THUDM/chatglm-6b)
164
+ - [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b)
165
+
166
+ baichuan:
167
+ - [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B)
168
+ - [baichuan-inc/Baichuan-13B-Base](https://huggingface.co/baichuan-inc/Baichuan-13B-Base)
169
+ - [baichuan-inc/Baichuan-13B-Chat](https://huggingface.co/baichuan-inc/Baichuan-13B-Chat)
170
+
171
+ xverse:
172
+ - [xverse/XVERSE-13B-Chat](https://huggingface.co/xverse/XVERSE-13B-Chat)
173
+
174
+ Qwen:
175
+ - [Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat)
176
+
177
+ ## 💻 Inference
178
+ 训练完成后,现在我们加载训练好的模型,验证模型生成文本的效果。
179
+
180
+ ```shell
181
+ CUDA_VISIBLE_DEVICES=0 python inference.py \
182
+ --model_type base_model_type \
183
+ --base_model path_to_model_hf_dir \
184
+ --tokenizer_path path_to_model_hf_dir \
185
+ --lora_model path_to_lora \
186
+ --interactive
187
+ ```
188
+
189
+ 参数说明:
190
+
191
+ - `--model_type {base_model_type}`:预训练模型类型,如llama、bloom、chatglm等
192
+ - `--base_model {base_model}`:存放HF格式的LLaMA模型权重和配置文件的目录
193
+ - `--tokenizer_path {base_model}`:存放HF格式的LLaMA模型权重和配置文件的目录
194
+ - `--lora_model {lora_model}`:LoRA解压后文件所在目录,也可使用HF Model Hub模型调用名称。如果已经合并了LoRA权重到预训练模型,则可以不提供此参数
195
+ - `--tokenizer_path {tokenizer_path}`:存放对应tokenizer的目录。若不提供此参数,则其默认值与--base_model相同
196
+ - `--template_name`:模板名称,如`vicuna`、`alpaca`等。若不提供此参数,则其默认值是vicuna
197
+ - `--interactive`:以交互方式启动多轮问答,使用流式推理
198
+ - `--data_file {file_name}`:非交互方式启动下,按行读取file_name中的的内容进行预测
199
+ - `--predictions_file {file_name}`:非交互式方式下,将预测的结果以json格式写入file_name
200
+ - `--resize_emb`:是否调整embedding大小,若不调整,则使用预训练模型的embedding大小,默认不调整
201
+ - `--only_cpu`:仅使用CPU进行推理
202
+ - `--gpus {gpu_ids}`:指定使用的GPU设备编号,默认为0。如使用多张GPU,以逗号分隔,如0,1,2
203
+
204
+
205
+ #### Inference Examples
206
+ [shibing624/vicuna-baichuan-13b-chat](https://huggingface.co/shibing624/vicuna-baichuan-13b-chat) inference examples:
207
+
208
+ <details><summary>Model output case</summary>
209
+
210
+ | Human | Assitant |
211
+ |------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
212
+ | 小孩发烧怎么办 | 发烧是身体对感染或其他疾病的自然反应。以下是一些可以帮助缓解发烧症状的建议:<br/>1. 保持身体水分充足:发烧时,身体会失去水分,因此需要保持水分充足。<br/>2. 使用冰袋:将冰袋放在额头或颈部可以帮助降低体温。<br/>3. 使用退烧药:退烧药可以帮助降低体温。<br/>4. 保持身体温度:保持身体温度可以帮助降低体温。<br/>5. 使用冷水或温水擦拭身体:使用冷水或温水擦拭身体可以帮助降低体温。<br/>6. 使用退烧药:退烧药可以帮助降低体温。<br/>7. 保持身体温度:保持身体温度可以帮助降低体温。<br/>8. 使用冷水或温水擦拭身体:使用冷水或温水擦拭身体可以帮助降低体温。<br/>9. 使用退烧药:退烧药可以帮助降低体温。<br/>10. 保持身体温度:保持身体温度可以帮助降低体温。 |
213
+ | 一岁宝宝发烧能吃啥药? | 1岁宝宝发烧时,可以考虑以下几种药物:<br/>1. 退烧药:如布洛芬(Ibuprofen)或对乙酰氨基酚(Paracetamol)。这些药物可以帮助降低体温,缓解发烧症状。<br/>2. 抗生素:如果发烧是由细菌感染引起的,可以考虑使用抗生素。但请务必在医生的建议下使用。<br/>3. 抗病毒药物:如果发烧是由病毒感染引起的,可以考虑使用抗病毒药物。但请务必在医生的建议下使用。<br/>4. 解热镇痛药:如阿司匹林(Aspirin)或对乙酰氨基酚(Paracetamol)。这些药物可以帮助缓解发烧症状。<br/>5. 抗组胺药:如苯海拉明(Diphenhydramine)或非索非那定(Naproxen)。这些药物可以帮助缓解发烧症状。<br/>请务必在医生的建议下使用这些药物,并遵循药物说明书上的用量和用法。同时,保持宝宝充足的水分摄入,并确保宝宝有足够的休息。 |
214
+ | 肛门病变可能是什么疾病的症状 | 肛门病变可能由多种原因引起,包括感染、炎症、肿瘤或其他疾病。一些常见的肛门病变包括:<br/>1. 肛裂:这是一种肛门皮肤撕裂,通常由便秘或其他压力因素引起。<br/>2. 肛瘘:这是一种肛门感染,导致肛门皮肤和直肠之间形成瘘管。<br/>3. 肛门疣:这是一种肛门皮肤上形成的疣状病变,通常由人类乳头状瘤病毒(HPV)引起。<br/>4. 肛门脓肿:这是一种肛门感染,导致肛门皮肤下形成脓肿。<br/>5. 肛门肿瘤:这是一种肛门皮肤或直肠内形成的肿瘤。<br/>6. 肛门溃疡:这是一种肛门皮肤上形成的溃疡,通常由感染或炎症引起。<br/>7. 肛门瘙痒:这是一种肛门皮肤瘙痒的症状,可能由感染、炎症或其他疾病引起。<br/>如果您有肛门病变的症状,请咨询医疗保健专业人员以获得适当的诊断和治疗。 |
215
+ | 医生,您好白带增多泛黄怎么回事? | 您好,白带增多和发黄可能是由多种原因引起的,包括感染、荷尔蒙变化、阴道干燥或其他妇科问题。以下是一些可能的原因:<br/>1. 感染:白带增多和发黄可能是由阴道或尿道感染引起的。感染可能由细菌、病毒或真菌引起。<br/>2. 荷尔蒙变化:荷尔蒙变化,如月经周期或更年期,可能会导致白带增多和发黄。<br/>3. 阴道干燥:阴道干燥可能会导致白带增多和发黄。<br/>4. 其他妇科问题:其他妇科问题,如子宫颈炎、子宫内膜炎或子宫内膜异位症,也可能导致白带增多和发黄。<br/>如果您担心自己的白带,建议您咨询妇科医生进行检查和诊断。医生可以根据您的症状和检查结果确定原因并提供适当的治疗。 |
216
+ | 药流后有巧克力白带是没流干净吗 | 药流后出现巧克力白带并不一定意味着没有流干净。巧克力白带是由于子宫内膜脱落和出血引起的,通常在药流后的一段时间内会出现。<br/><br/>药流后出现巧克力白带的原因可能包括:<br/><br/>1. 子宫内膜��落:药流后,子宫内膜会脱落,导致出血。这种出血可能表现为巧克力白带。<br/>2. 感染:药流后,子宫内膜和宫颈口可能存在创伤,容易受到细菌感染。感染可能导致巧克力白带。<br/>3. 激素水平变化:药流后,激素水平可能发生变化,导致巧克力白带。<br/><br/>如果巧克力白带持续时间较长,或者伴有其他症状,如腹痛、发热、分泌物异味等,建议及时就诊,进行相关检查,如妇科检查、B超等,以确定是否存在感染或其他问题。在医生的指导下进行相应的治疗。 |
217
+ | 维胺酯维E乳膏能治理什么疾病 | 维胺酯维E乳膏是一种外用药物,主要用于治疗痤疮、脂溢性皮炎等皮肤病。它可以帮助减少皮肤油脂分泌,减轻炎症,促进皮肤修复。<br/><br/>需要注意的是,维胺酯维E乳膏并非万能药,对于某些皮肤病可能并不适用。在使用前,建议咨询医生或药师,了解药物的适应症、禁忌症、用法用量等信息。同时,使用药物时应遵循医嘱,避免自行滥用。 |
218
+ | 痤疮是什么原因导致的? | 痤疮,也称为粉刺,是由皮肤毛孔堵塞引起的。堵塞毛孔的原因有很多,包括:<br/>1. 皮脂腺过度活跃:皮脂腺产生过多油脂,导致毛孔堵塞。<br/>2. 毛孔堵塞:皮肤表面的死皮细胞、油脂和细菌堵塞毛孔。<br/>3. 炎症:痤疮通常伴随着炎症,这可能是由于皮肤感染或免疫系统对痤疮的反应引起的。<br/>4. 遗传因素:某些人可能更容易患痤疮,这可能是由于遗传因素引起的。<br/>5. 荷尔蒙变化:痤疮在青春期和月经期间更为常见,这可能是由于荷尔蒙变化引起的。<br/>6. 生活方式因素:某些生活方式因素,如压力、不良饮食和缺乏运动,也可能导致痤疮。<br/>7. 化妆品和护肤品:某些化妆品和护肤品可能导致痤疮,特别是那些含有刺激性成分的产品。<br/>8. 药物:某些药物,如口服避孕药,也可能导致痤疮。<br/>要预防痤疮,可以采取一些措施,如保持皮肤清洁、使用温和的护肤品、避免过度使用化妆品和护肤品、保持健康的生活方式和饮食习惯、避免压力和荷尔蒙变化。如果痤疮严重,可能需要寻求皮肤科医生的帮助。 |
219
+ | 脸上长青春痘怎么办,能吃啥药,需要注意啥 | 青春痘,又称痤疮,是一种常见的皮肤病。青春痘的治疗需要综合考虑饮食、生活习惯、药物等多个方面。以下是一些建议:<br/><br/>1. 饮食:<br/>* 避免高糖、高脂、高油的食物,如快餐、甜点、油炸食品等。<br/>* 多吃新鲜蔬菜、水果、全谷类食物,以增加纤维摄入,促进肠道蠕动。<br/>* 避免摄入过多的咖啡因、酒精、辛辣食物等刺激性食物。<br/><br/>2. 生活习惯:<br/>* 保持良好的作息时间,保证充足的睡眠。<br/>* 避免长时间使用电脑、手机等电子产品,以免辐射对皮肤造成刺激。<br/>* 保持皮肤清洁,每天早晚使用温和的洁面产品清洁皮肤。<br/>* 避免挤压青春痘,以免感染和留下疤痕。<br/><br/>3. 药物:<br/>* 外用药物:可以使用含有水杨酸、苯甲酸、壬二酸等成分的护肤品,帮助去除皮肤表面的角质,减少痘痘的形成。<br/>* 口服药物:如果青春痘严重,可以考虑口服抗生素、维生素B族、锌等药物,但需在医生指导下使用。<br/><br/>4. 注意:<br/>* 不要随意使用化妆品、护肤品,以免刺激皮肤。<br/>* 不要过度清洁皮肤,以免破坏皮肤屏障。<br/>* 保持心情舒畅,避免焦虑、抑郁等负面情绪。<br/><br/>总之,青春痘的治疗需要综合考虑饮食、生活习惯、药物等多个方面。在治疗过程中,要保持耐心和信心,遵循医生的建议,逐步改善皮肤状况。 |
220
+ | 由于输精管堵塞会不会导致男性不育?���医生您好,我同事这几天不开心,我问她怎么了,她说她老公输精管堵塞,所以一直没孩子,不是不想要,请问输精管堵塞会不会导致男性不育? | 输精管堵塞可能会导致男性不育。输精管是连接睾丸和尿道的一条管道,负责将精子从睾丸输送到尿道。如果输精管堵塞,精子就无法通过输精管到达尿道,从而导致不育。<br/><br/>输精管堵塞的原因有很多,包括先天性发育异常、炎症、肿瘤、外伤等。输精管堵塞的治疗方法包括手术疏通、药物治疗等。如果输精管堵塞严重,可能需要进行手术疏通。<br/><br/>需要注意的是,输精管堵塞并不一定导致不育,因为还有其他途径可以将精子输送到女性体内。例如,精子可以通过附睾、前列腺等组织到达女性体内。因此,如果输精管堵塞,建议及时就医,寻求专业医生的帮助。 |
221
+ | 心肌梗塞能治疗吗? | 心肌梗死是一种严重的疾病,如果及时治疗,可以挽救生命并防止长期残疾。治疗心肌梗死的关键是尽早识别和治疗。<br/><br/>心肌梗死的治疗通常包括以下步骤:<br/><br/>1. 立即就医:如果出现胸痛、呼吸急促或晕厥等症状,立即就医。<br/><br/>2. 药物治疗:医生可能会开处方药物,如硝酸甘油或阿司匹林,以帮助缓解症状并防止进一步的损伤。<br/><br/>3. 溶栓治疗:如果患者在到达医院前出现心肌梗死,医生可能会进行溶栓治疗,通过注射溶解血栓的药物来帮助恢复血流。<br/><br/>4. 手术治疗:如果药物治疗或溶栓治疗无效,医生可能会建议进行手术,如冠状动脉搭桥手术或心脏支架植入术。<br/><br/>5. 康复和预防:心肌梗死的治疗通常包括康复计划,以帮助患者恢复身体功能并预防未来的心脏事件。此外,患者还应遵循医生的建议,以降低患心脏病的风险。<br/><br/>需要注意的是,心肌梗死的治疗取决于患者的具体情况和症状。因此,如果您或您认识的人出现心肌梗死症状,请立即就医。 |
222
+ | 怎么让自己精力充沛,列5点建议 | 1. 保持充足的睡眠:睡眠对于保持精力充沛至关重要。确保每晚获得7-9小时的高质量睡眠。<br/>2. 均衡饮食:均衡饮食对于保持能量水平至关重要。确保摄入足够的蛋白质、碳水化合物和健康脂肪。<br/>3. 定期锻炼:定期锻炼可以帮助提高能量水平。尝试每天进行至少30分钟的有氧运动,如散步、跑步或骑自行车。<br/>4. 管理压力:压力会消耗能量,因此管理压力对于保持精力充沛至关重要。尝试找到放松和减压的方法,如冥想、瑜伽或深呼吸练习。<br/>5. 保持水分:脱水会降低能量水平。确保每天喝足够的水,并根据需要补充水分。 |
223
+ | 写个春天的诗 | 春天,<br/>是大地苏醒的季节,<br/>是万物复苏的季节,<br/>是生命绽放的季节。<br/>春天,<br/>是阳光明媚的季节,<br/>是鸟语花香的季节,<br/>是绿意盎然的季节,<br/>是希望无限的季节。<br/>春天,<br/>是人们欢笑的季节,<br/>是孩子们嬉戏的季节,<br/>是恋人们相拥的季节,<br/>是人们追求梦想的季节。 |
224
+ | 南极洲为啥没有北极熊? | 南极洲没有北极熊,因为北极熊生活在北极地区,而南极洲位于南半球,气候和环境与北极地区有很大不同。南极洲��气候非常寒冷,常年被冰雪覆盖,而北极熊则生活在北极地区,那里有丰富的海洋生物和陆地哺乳动物。此外,南极洲的生态系统也与北极地区不同,没有北极熊所需的生存条件。 |
225
+
226
+ </details>
227
+
228
+ ## 📚 Dataset
229
+ ### 医疗数据集
230
+
231
+ - 240万条中文医疗数据集(包括预训练、指令微调和奖励数据集):[shibing624/medical](https://huggingface.co/datasets/shibing624/medical)
232
+ - 22万条中文医疗对话数据集(华佗项目):[FreedomIntelligence/HuatuoGPT-sft-data-v1](https://huggingface.co/datasets/FreedomIntelligence/HuatuoGPT-sft-data-v1)
233
+
234
+ ### 通用数据集
235
+
236
+ #### Pretraining datasets
237
+ - 16GB中英文无监督、平行语料[Linly-AI/Chinese-pretraining-dataset](https://huggingface.co/datasets/Linly-AI/Chinese-pretraining-dataset)
238
+ - 524MB中文维基百科语料[wikipedia-cn-20230720-filtered](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
239
+ #### SFT datasets
240
+ - 10万条多语言ShareGPT GPT4多轮对话数据集:[shibing624/sharegpt_gpt4](https://huggingface.co/datasets/shibing624/sharegpt_gpt4) [本项目支持格式]
241
+ - 9万条英文ShareGPT多轮对话数集:[anon8231489123/ShareGPT_Vicuna_unfiltered](https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered) [本项目支持格式]
242
+ - 50万条中文ChatGPT指令Belle数据集:[BelleGroup/train_0.5M_CN](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
243
+ - 100万条中文ChatGPT指令Belle数据集:[BelleGroup/train_1M_CN](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
244
+ - 5万条英文ChatGPT指令Alpaca数据集:[50k English Stanford Alpaca dataset](https://github.com/tatsu-lab/stanford_alpaca#data-release)
245
+ - 2万条中文ChatGPT指令Alpaca数据集:[shibing624/alpaca-zh](https://huggingface.co/datasets/shibing624/alpaca-zh)
246
+ - 69万条中文指令Guanaco数据集(Belle50万条+Guanaco19万条):[Chinese-Vicuna/guanaco_belle_merge_v1.0](https://huggingface.co/datasets/Chinese-Vicuna/guanaco_belle_merge_v1.0)
247
+ - 5万条英文ChatGPT多轮对话数据集:[RyokoAI/ShareGPT52K](https://huggingface.co/datasets/RyokoAI/ShareGPT52K)
248
+ - 80万条中文ChatGPT多轮对话数据集:[BelleGroup/multiturn_chat_0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
249
+ - 116万条中文ChatGPT多轮对话数据集:[fnlp/moss-002-sft-data](https://huggingface.co/datasets/fnlp/moss-002-sft-data)
250
+ - 3.8万条中文ShareGPT多轮对话数据集:[FreedomIntelligence/ShareGPT-CN](https://huggingface.co/datasets/FreedomIntelligence/ShareGPT-CN)
251
+
252
+ #### Reward Model datasets
253
+ - 原版的oasst1数据集:[OpenAssistant/oasst1](https://huggingface.co/datasets/OpenAssistant/oasst1)
254
+ - 2万条多语言oasst1的reward数据集:[tasksource/oasst1_pairwise_rlhf_reward](https://huggingface.co/datasets/tasksource/oasst1_pairwise_rlhf_reward)[本项目支持格式]
255
+ - 11万条英文hh-rlhf的reward数据集:[Dahoas/full-hh-rlhf](https://huggingface.co/datasets/Dahoas/full-hh-rlhf)
256
+ - 9万条英文reward数据集(来自Anthropic's Helpful Harmless dataset):[Dahoas/static-hh](https://huggingface.co/datasets/Dahoas/static-hh)
257
+ - 7万条英文reward数据集(来源同上):[Dahoas/rm-static](https://huggingface.co/datasets/Dahoas/rm-static)
258
+ - 7万条繁体中文的reward数据集(翻译自rm-static)[liswei/rm-static-m2m100-zh](https://huggingface.co/datasets/liswei/rm-static-m2m100-zh)
259
+ - 7万条英文Reward数据集:[yitingxie/rlhf-reward-datasets](https://huggingface.co/datasets/yitingxie/rlhf-reward-datasets)
260
+ - 3千条中文知乎问答偏好数据集:[liyucheng/zhihu_rlhf_3k](https://huggingface.co/datasets/liyucheng/zhihu_rlhf_3k)
261
+
262
+ ## ✅ Todo
263
+
264
+ 1. [x] add multi-round dialogue data fine-tuning method
265
+ 2. [x] add reward model fine-tuning
266
+ 3. [x] add rl fine-tuning
267
+ 4. [x] add medical reward dataset
268
+ 5. [x] add llama in8/int4 training
269
+ 6. [x] add all training and predict demo in colab
270
+ 7. [x] add dpo training
271
+
272
+ ## ☎️ Contact
273
+
274
+ - Issue(建议)
275
+ :[![GitHub issues](https://img.shields.io/github/issues/shibing624/MedicalGPT.svg)](https://github.com/shibing624/MedicalGPT/issues)
276
+ - 邮件我:xuming: [email protected]
277
+ - 微信我: 加我*微信号:xuming624, 备注:姓名-公司名-NLP* 进NLP交流群。
278
+
279
+ <img src="https://github.com/shibing624/MedicalGPT/blob/main/docs/wechat.jpeg" width="200" />
280
+
281
+ ## ⚠️ 局限性、使用限制与免责声明
282
+
283
+ 基于当前数据和基础模型训练得到的SFT模型,在效果上仍存在以下问题:
284
+
285
+ 1. 在涉及事实性的指令上可能会产生违背事实的错���回答。
286
+
287
+ 2. 对于具备危害性的指令无法很好的鉴别,由此会产生危害性言论。
288
+
289
+ 3. 在一些涉及推理、代码、多轮对话等场景下模型的能力仍有待提高。
290
+
291
+ 基于以上模型局限性,我们要求开发者仅将我们开源的模型权重及后续用此项目生成的衍生物用于研究目的,不得用于商业,以及其他会对社会带来危害的用途。
292
+
293
+ 本项目仅可应用于研究目的,项目开发者不承担任何因使用本项目(包含但不限于数据、模型、代码等)导致的危害或损失。详细请参考[免责声明](https://github.com/shibing624/MedicalGPT/blob/main/DISCLAIMER)。
294
+
295
+ 项目代码的授权协议为 [The Apache License 2.0](/LICENSE),代码可免费用做商业用途,模型权重和数据只能用于研究目的。请在产品说明中附加MedicalGPT的链接和授权协议。
296
+
297
+
298
+ ## 😇 Citation
299
+
300
+ 如果你在研究中使用了MedicalGPT,请按如下格式引用:
301
+
302
+ ```latex
303
+ @misc{MedicalGPT,
304
+ title={MedicalGPT: Training Medical GPT Model},
305
+ author={Ming Xu},
306
+ year={2023},
307
+ howpublished={\url{https://github.com/shibing624/MedicalGPT}},
308
+ }
309
+ ```
310
+
311
+ ## 😍 Contribute
312
+
313
+ 项目代码还很粗糙,如果大家对代码有所改进,欢迎提交回本项目,在提交之前,注意以下两点:
314
+
315
+ - 在`tests`添加相应的单元测试
316
+ - 使用`python -m pytest`来运行所有单元测试,确保所有单测都是通过的
317
+
318
+ 之后即可提交PR。
319
+
320
+ ## 💕 Acknowledgements
321
+
322
+ - [Direct Preference Optimization:Your Language Model is Secretly a Reward Model](https://arxiv.org/pdf/2305.18290.pdf)
323
+ - [tloen/alpaca-lora](https://github.com/tloen/alpaca-lora/blob/main/finetune.py)
324
+ - [ymcui/Chinese-LLaMA-Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
325
+
326
+ Thanks for their great work!
README_EN.md ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [**🇨🇳中文**](https://github.com/shibing624/MedicalGPT/blob/main/README.md) | [**🌐English**](https://github.com/shibing624/MedicalGPT/blob/main/README_EN.md) | [**📖文档/Docs**](https://github.com/shibing624/MedicalGPT/wiki) | [**🤖模型/Models**](https://huggingface.co/shibing624)
2
+
3
+ <div align="center">
4
+ <a href="https://github.com/shibing624/MedicalGPT">
5
+ <img src="https://github.com/shibing624/MedicalGPT/blob/main/docs/logo.png" width="120" alt="Logo">
6
+ </a>
7
+ </div>
8
+
9
+ -----------------
10
+
11
+ # MedicalGPT: Training Medical GPT Model
12
+ [![HF Models](https://img.shields.io/badge/Hugging%20Face-shibing624-green)](https://huggingface.co/shibing624)
13
+ [![Github Stars](https://img.shields.io/github/stars/shibing624/MedicalGPT?color=yellow)](https://star-history.com/#shibing624/MedicalGPT&Timeline)
14
+ [![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
15
+ [![License Apache 2.0](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE)
16
+ [![python_version](https://img.shields.io/badge/Python-3.8%2B-green.svg)](requirements.txt)
17
+ [![GitHub issues](https://img.shields.io/github/issues/shibing624/MedicalGPT.svg)](https://github.com/shibing624/MedicalGPT/issues)
18
+ [![Wechat Group](http://vlog.sfyc.ltd/wechat_everyday/wxgroup_logo.png?imageView2/0/w/60/h/20)](#Contact)
19
+
20
+ ## 📖 Introduction
21
+
22
+ **MedicalGPT** training medical GPT model with ChatGPT training pipeline, implemantation of Pretraining,
23
+ Supervised Finetuning, Reward Modeling and Reinforcement Learning.
24
+
25
+
26
+ <img src="https://github.com/shibing624/MedicalGPT/blob/main/docs/GPT_Training.jpg" width="860" />
27
+
28
+ Training MedicalGPT model:
29
+
30
+ - Stage 1:PT(Continue PreTraining), Pre-training the LLaMA model on massive domain document data to inject domain knowledge
31
+ - Stage 2: SFT (Supervised Fine-tuning) has supervised fine-tuning, constructs instruction fine-tuning data sets, and performs instruction fine-tuning on the basis of pre-trained models to align instruction intentions
32
+ - Stage 3: RM (Reward Model) reward model modeling, constructing a human preference ranking data set, training the reward model to align human preferences, mainly the "HHH" principle, specifically "helpful, honest, harmless"
33
+ - Stage 4: RL (Reinforcement Learning) is based on human feedback reinforcement learning (RLHF), using the reward model to train the SFT model, and the generation model uses rewards or penalties to update its strategy in order to generate higher quality, more in line with human preferences text
34
+
35
+ ## ▶️ Demo
36
+
37
+ - Hugging Face Demo: doing
38
+
39
+ We provide a simple Gradio-based interactive web interface. After the service is started, it can be accessed through a browser, enter a question, and the model will return an answer. The command is as follows:
40
+ ```shell
41
+ python scripts/gradio_demo.py --base_model path_to_llama_hf_dir --lora_model path_to_lora_dir
42
+ ```
43
+
44
+ Parameter Description:
45
+
46
+ - `--base_model {base_model}`: directory to store LLaMA model weights and configuration files in HF format, or use the HF Model Hub model call name
47
+ - `--lora_model {lora_model}`: The directory where the LoRA file is located, and the name of the HF Model Hub model can also be used. If the lora weights have been merged into the pre-trained model, delete the --lora_model parameter
48
+ - `--tokenizer_path {tokenizer_path}`: Store the directory corresponding to the tokenizer. If this parameter is not provided, its default value is the same as --lora_model; if the --lora_model parameter is not provided, its default value is the same as --base_model
49
+ - `--use_cpu`: use only CPU for inference
50
+ - `--gpus {gpu_ids}`: Specifies the number of GPU devices used, the default is 0. If using multiple GPUs, separate them with commas, such as 0,1,2
51
+
52
+
53
+ ## 🚀 Training Pipeline
54
+
55
+ ### Stage 1: Continue Pretraining
56
+
57
+ Based on the llama-7b model, use medical encyclopedia data to continue pre-training, and expect to inject medical knowledge into the pre-training model to obtain the llama-7b-pt model. This step is optional
58
+
59
+
60
+ ```shell
61
+ cd scripts
62
+ sh run_pt.sh
63
+ ```
64
+
65
+ [Training Detail wiki](https://github.com/shibing624/MedicalGPT/wiki/Training-Details)
66
+
67
+ ### Stage 2: Supervised FineTuning
68
+ Based on the llama-7b-pt model, the llama-7b-sft model is obtained by using medical question-and-answer data for supervised fine-tuning. This step is required
69
+
70
+ Supervised fine-tuning of the base llama-7b-pt model to create llama-7b-sft
71
+
72
+ ```shell
73
+ cd scripts
74
+ sh run_sft.sh
75
+ ```
76
+
77
+ [Training Detail wiki](https://github.com/shibing624/MedicalGPT/wiki/Training-Details)
78
+
79
+ ### Stage 3: Reward Modeling
80
+ RM(Reward Model): reward model modeling
81
+
82
+ In principle, we can directly use human annotations to fine-tune the model with RLHF.
83
+
84
+ However, this will require us to send some samples to humans to be scored after each round of optimization. This is expensive and slow due to the large number of training samples required for convergence and the limited speed at which humans can read and annotate them.
85
+ A better strategy than direct feedback is to train a reward model RM on the human annotated set before entering the RL loop. The purpose of the reward model is to simulate human scoring of text.
86
+
87
+ The best practice for building a reward model is to rank the prediction results, that is, for each prompt (input text) corresponding to two results (yk, yj), the model predicts which score the human annotation is higher.
88
+ The RM model is trained by manually marking the scoring results of the SFT model. The purpose is to replace manual scoring. It is essentially a regression model used to align human preferences, mainly based on the "HHH" principle, specifically "helpful, honest, harmless".
89
+
90
+
91
+ Based on the llama-7b-sft model, the reward preference model is trained using medical question and answer preference data, and the llama-7b-reward model is obtained after training. This step is required
92
+
93
+ Reward modeling using dialog pairs from the reward dataset using the llama-7b-sft to create llama-7b-reward:
94
+
95
+ ```shell
96
+ cd scripts
97
+ sh run_rm.sh
98
+ ```
99
+ [Training Detail wiki](https://github.com/shibing624/MedicalGPT/wiki/Training-Details)
100
+
101
+ ### Stage 4: Reinforcement Learning
102
+ The purpose of the RL (Reinforcement Learning) model is to maximize the output of the reward model. Based on the above steps, we have a fine-tuned language model (llama-7b-sft) and reward model (llama-7b-reward).
103
+ The RL loop is ready to execute.
104
+
105
+ This process is roughly divided into three steps:
106
+
107
+ 1. Enter prompt, the model generates a reply
108
+ 2. Use a reward model to score responses
109
+ 3. Based on the score, a round of reinforcement learning for policy optimization (PPO)
110
+
111
+ <img src=https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/blog/stackllama/trl_loop.png height=400 />
112
+
113
+ Reinforcement Learning fine-tuning of llama-7b-sft with the llama-7b-reward reward model to create llama-7b-rl
114
+
115
+ ```shell
116
+ cd scripts
117
+ sh run_rl.sh
118
+ ```
119
+ [Training Detail wiki](https://github.com/shibing624/MedicalGPT/wiki/Training-Details)
120
+
121
+ ## 🔥 Inference
122
+ After the training is complete, now we load the trained model to verify the effect of the model generating text.
123
+
124
+ ```shell
125
+ python scripts/inference.py \
126
+ --base_model path_to_llama_hf_dir \
127
+ --lora_model path_to_lora \
128
+ --with_prompt \
129
+ --interactive
130
+ ```
131
+
132
+ Parameter Description:
133
+
134
+ - `--base_model {base_model}`: Directory to store LLaMA model weights and configuration files in HF format
135
+ - `--lora_model {lora_model}`: The directory where the LoRA file is decompressed, and the name of the HF Model Hub model can also be used. If you have incorporated LoRA weights into the pre-trained model, you can not provide this parameter
136
+ - `--tokenizer_path {tokenizer_path}`: Store the directory corresponding to the tokenizer. If this parameter is not provided, its default value is the same as --lora_model; if the --lora_model parameter is not provided, its default value is the same as --base_model
137
+ - `--with_prompt`: Whether to merge the input with the prompt template. Be sure to enable this option if loading an Alpaca model!
138
+ - `--interactive`: start interactively for multiple single rounds of question and answer
139
+ - `--data_file {file_name}`: Start in non-interactive mode, read the contents of file_name line by line for prediction
140
+ - `--predictions_file {file_name}`: In non-interactive mode, write the predicted results to file_name in json format
141
+ - `--use_cpu`: use only CPU for inference
142
+ - `--gpus {gpu_ids}`: Specifies the number of GPU devices used, the default is 0. If using multiple GPUs, separate them with commas, such as 0,1,2
143
+
144
+
145
+ #### Inference Examples
146
+
147
+ <details>
148
+ | Human | Assistant |
149
+ | ----- | ------ |
150
+ | How to make yourself energetic, list 5 suggestions | Here are some suggestions to keep you energetic: <br/>1. Get enough sleep. Make sure to get enough sleep every night. <br/>2. Exercise more. Doing moderate aerobic exercise, such as jogging and swimming, can help improve your energy and stamina. 3. Eat healthy food. Choose protein-rich foods such as chicken, fish, and eggs, as well as fresh fruits and vegetables. 4. Reduce caffeine intake. Try to avoid drinking caffeinated beverages or eating caffeinated foods during the day. <br/>5. Relax. Try meditation, yoga, or other forms of relaxation to reduce stress and anxiety. |
151
+ </details>
152
+ <br/>
153
+
154
+
155
+ ## 📚 Dataset
156
+
157
+ - 2.4 million Chinese medical datasets (including pre-training, instruction fine-tuning and reward datasets): [shibing624/medical](https://huggingface.co/datasets/shibing624/medical)
158
+
159
+ **Attach links to some general datasets and medical datasets**
160
+
161
+ - Belle dataset of 500,000 Chinese ChatGPT commands: [BelleGroup/train_0.5M_CN](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
162
+ - Belle dataset of 1 million Chinese ChatGPT commands: [BelleGroup/train_1M_CN](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
163
+ - Alpaca dataset of 50,000 English ChatGPT commands: [50k English Stanford Alpaca dataset](https://github.com/tatsu-lab/stanford_alpaca#data-release)
164
+ - Alpaca dataset of 20,000 Chinese GPT-4 instructions: [shibing624/alpaca-zh](https://huggingface.co/datasets/shibing624/alpaca-zh)
165
+ - Guanaco dataset with 690,000 Chinese instructions (500,000 Belle + 190,000 Guanaco): [Chinese-Vicuna/guanaco_belle_merge_v1.0](https://huggingface.co/datasets/Chinese-Vicuna/guanaco_belle_merge_v1.0)
166
+ - 220,000 Chinese medical dialogue datasets (HuatuoGPT project): [FreedomIntelligence/HuatuoGPT-sft-data-v1](https://huggingface.co/datasets/FreedomIntelligence/HuatuoGPT-sft-data-v1)
167
+
168
+ ## ✅ Todo
169
+
170
+ 1. [ ] Added multi-round dialogue data fine-tuning method
171
+ 2. [x] add reward model finetuning
172
+ 3. [x] add rl finetuning
173
+ 4. [x] add medical reward dataset
174
+ 5. [x] add llama in8/int4 training
175
+ 6. [ ] add all training and predict demo in colab
176
+ ## ☎️ Contact
177
+
178
+ - Issue (suggestion)
179
+ : [![GitHub issues](https://img.shields.io/github/issues/shibing624/MedicalGPT.svg)](https://github.com/shibing624/MedicalGPT/issues)
180
+ - Email me: xuming: [email protected]
181
+ - WeChat Me: Add me* WeChat ID: xuming624, Remarks: Name-Company Name-NLP* Enter the NLP exchange group.
182
+
183
+ <img src="https://github.com/shibing624/MedicalGPT/blob/main/docs/wechat.jpeg" width="200" />
184
+
185
+ ## ⚠️ Limitations, Restrictions of Use and Disclaimer
186
+
187
+ The SFT model trained based on the current data and the basic model still has the following problems in terms of effect:
188
+
189
+ 1. Wrong answers that contradict the facts may be generated on the factual instructions.
190
+ 2. Unable to identify harmful instructions well, resulting in harmful speech.
191
+ 3. The ability of the model still needs to be improved in some scenarios involving reasoning, code, and multiple rounds of dialogue.
192
+
193
+ Based on the limitations of the above models, we require developers to only use our open source model weights and subsequent derivatives generated by this project for research purposes, and not for commercial use, and other purposes that will cause harm to society.
194
+ This project can only be used for research purposes, and the project developer is not responsible for any harm or loss caused by the use of this project (including but not limited to data, models, codes, etc.). For details, please refer to [Disclaimer](https://github.com/shibing624/MedicalGPT/blob/main/DISCLAIMER).
195
+ The license agreement for the project code is [The Apache License 2.0](/LICENSE), the code is free for commercial use, and the model weights and data can only be used for research purposes. Please attach MedicalGPT's link and license agreement in the product description.
196
+
197
+ ## 😇 Citation
198
+
199
+ If you used MedicalGPT in your research, please cite as follows:
200
+
201
+ ```latex
202
+ @misc{MedicalGPT,
203
+ title={MedicalGPT: Training Medical GPT Model},
204
+ author={Ming Xu},
205
+ year={2023},
206
+ howpublished={\url{https://github.com/shibing624/MedicalGPT}},
207
+ }
208
+ ```
209
+
210
+ ## 😍 Contribute
211
+
212
+ The project code is still very rough. If you have improved the code, you are welcome to submit it back to this project. Before submitting, please pay attention to the following two points:
213
+
214
+ - Add corresponding unit tests in `tests`
215
+ - Use `python -m pytest` to run all unit tests to ensure that all unit tests are passed
216
+
217
+ Then you can submit a PR.
218
+
219
+ ## 💕 Acknowledgements
220
+
221
+ - [tloen/alpaca-lora](https://github.com/tloen/alpaca-lora/blob/main/finetune.py)
222
+ - [ymcui/Chinese-LLaMA-Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
223
+
224
+ Thanks for their great work!
_config.yml ADDED
@@ -0,0 +1 @@
 
 
1
+ theme: jekyll-theme-cayman
build_domain_tokenizer.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ @author:XuMing([email protected])
4
+ @description: Build chinese tokenizer from corpus txt
5
+
6
+ # train sentencepiece model from `corpus.txt` and makes `m.model` and `m.vocab`
7
+ # `m.vocab` is just a reference. not used in the segmentation.
8
+ # spm.SentencePieceTrainer.train('--input=data/pretrain/tianlongbabu.txt --model_prefix=m --vocab_size=20000')
9
+ """
10
+ import argparse
11
+
12
+ import sentencepiece as spm
13
+
14
+
15
+ def main():
16
+ parser = argparse.ArgumentParser()
17
+ parser.add_argument('--in_file', default='data/pretrain/fever.txt', type=str)
18
+ parser.add_argument('--domain_sp_model_name', default='domain_sp', type=str)
19
+ parser.add_argument('--max_sentence_length', default=16384, type=int)
20
+ parser.add_argument('--pad_id', default=3, type=int)
21
+ parser.add_argument('--vocab_size', default=2236, type=int)
22
+ parser.add_argument('--model_type', default="BPE", type=str)
23
+
24
+ args = parser.parse_args()
25
+ print(args)
26
+
27
+ spm.SentencePieceTrainer.train(
28
+ input=args.in_file,
29
+ model_prefix=args.domain_sp_model_name,
30
+ shuffle_input_sentence=False,
31
+ train_extremely_large_corpus=True,
32
+ max_sentence_length=args.max_sentence_length,
33
+ pad_id=args.pad_id,
34
+ model_type=args.model_type,
35
+ vocab_size=args.vocab_size,
36
+ split_digits=True,
37
+ split_by_unicode_script=True,
38
+ byte_fallback=True,
39
+ allow_whitespace_only_pieces=True,
40
+ remove_extra_whitespaces=False,
41
+ normalization_rule_name="nfkc",
42
+ )
43
+
44
+ # makes segmenter instance and loads the model file (m.model)
45
+ sp = spm.SentencePieceProcessor()
46
+ model_file = args.domain_sp_model_name + '.model'
47
+ sp.load(model_file)
48
+
49
+ # encode: text => id
50
+ print(sp.encode_as_pieces('潜伏性感染又称潜在性感染。慕容复来到河边,this is a test'))
51
+ print(sp.encode_as_ids('this is a test'))
52
+
53
+ # decode: id => text
54
+ print(sp.decode_pieces(['▁This', '▁is', '▁a', '▁t', 'est']))
55
+ # print(sp.decode_ids([209, 31, 9, 375, 586]))
56
+
57
+
58
+ if __name__ == '__main__':
59
+ main()
convert_dataset.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Convert alpaca dataset into sharegpt format.
3
+
4
+ Usage: python convert_alpaca.py --in_file alpaca_data.json --out_file alpaca_data_sharegpt.json
5
+ """
6
+
7
+ import argparse
8
+
9
+ from datasets import load_dataset
10
+
11
+ if __name__ == "__main__":
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument("--in_file", type=str)
14
+ parser.add_argument("--out_file", type=str)
15
+ parser.add_argument("--data_type", type=str, default='alpaca')
16
+ args = parser.parse_args()
17
+ print(args)
18
+ data_files = {"train": args.in_file}
19
+ raw_datasets = load_dataset('json', data_files=data_files)
20
+ ds = raw_datasets['train']
21
+
22
+
23
+ def process_alpaca(examples):
24
+ convs = []
25
+ for instruction, inp, output in zip(examples['instruction'], examples['input'], examples['output']):
26
+ if len(inp.strip()) > 1:
27
+ instruction = instruction + '\n\n' + inp
28
+ q = instruction
29
+ a = output
30
+ convs.append([
31
+ {"from": "human", "value": q},
32
+ {"from": "gpt", "value": a}
33
+ ])
34
+ return {"conversations": convs}
35
+
36
+
37
+ if args.data_type in ['alpaca']:
38
+ ds = ds.map(process_alpaca, batched=True, remove_columns=ds.column_names, desc="Running process")
39
+ else:
40
+ # Other sharegpt dataset, need rename to conversations and remove unused columns
41
+ if "items" in ds.column_names:
42
+ ds = ds.rename(columns={"items": "conversations"})
43
+ columns_to_remove = ds.column_names.copy()
44
+ columns_to_remove.remove('conversations')
45
+ ds = ds.remove_columns(columns_to_remove)
46
+
47
+ ds.to_json(f"{args.out_file}", lines=True, force_ascii=False)
deepspeed_config.json ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "optimizer": {
3
+ "type": "AdamW",
4
+ "params": {
5
+ "lr": "auto",
6
+ "weight_decay": "auto",
7
+ "torch_adam": true,
8
+ "adam_w_mode": true
9
+ }
10
+ },
11
+ "scheduler": {
12
+ "type": "WarmupDecayLR",
13
+ "params": {
14
+ "warmup_min_lr": "auto",
15
+ "warmup_max_lr": "auto",
16
+ "warmup_num_steps": "auto",
17
+ "total_num_steps": "auto"
18
+ }
19
+ },
20
+ "fp16": {
21
+ "enabled": true,
22
+ "loss_scale": 0,
23
+ "loss_scale_window": 1000,
24
+ "initial_scale_power": 16,
25
+ "hysteresis": 2,
26
+ "min_loss_scale": 1
27
+ },
28
+ "zero_optimization": {
29
+ "stage": 2,
30
+ "allgather_partitions": true,
31
+ "allgather_bucket_size": 2e8,
32
+ "reduce_scatter": true,
33
+ "reduce_bucket_size": "auto",
34
+ "overlap_comm": true,
35
+ "contiguous_gradients": true
36
+ },
37
+ "gradient_accumulation_steps": "auto",
38
+ "gradient_clipping": "auto",
39
+ "steps_per_print": 1000,
40
+ "train_batch_size": "auto",
41
+ "train_micro_batch_size_per_gpu": "auto",
42
+ "wall_clock_breakdown": false
43
+ }
dpo_training.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ @author:XuMing([email protected])
4
+ @description: Train a model from SFT using DPO
5
+ """
6
+
7
+ import os
8
+ from dataclasses import dataclass, field
9
+ from glob import glob
10
+ from typing import Dict, Optional
11
+
12
+ import torch
13
+ from datasets import load_dataset
14
+ from loguru import logger
15
+ from peft import LoraConfig, TaskType
16
+ from transformers import (
17
+ AutoConfig,
18
+ BloomForCausalLM,
19
+ AutoModelForCausalLM,
20
+ AutoModel,
21
+ LlamaTokenizer,
22
+ LlamaForCausalLM,
23
+ BloomTokenizerFast,
24
+ AutoTokenizer,
25
+ HfArgumentParser,
26
+ TrainingArguments,
27
+ BitsAndBytesConfig,
28
+ )
29
+ from transformers.deepspeed import is_deepspeed_zero3_enabled
30
+ from trl import DPOTrainer
31
+
32
+ os.environ["TOKENIZERS_PARALLELISM"] = "FALSE"
33
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
34
+
35
+ MODEL_CLASSES = {
36
+ "bloom": (AutoConfig, BloomForCausalLM, BloomTokenizerFast),
37
+ "chatglm": (AutoConfig, AutoModel, AutoTokenizer),
38
+ "llama": (AutoConfig, LlamaForCausalLM, LlamaTokenizer),
39
+ "baichuan": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
40
+ "auto": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
41
+ }
42
+
43
+
44
+ @dataclass
45
+ class ScriptArguments:
46
+ """
47
+ The name of the Casual LM model we wish to fine with DPO
48
+ """
49
+ # Model arguments
50
+ model_type: str = field(
51
+ default=None,
52
+ metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())}
53
+ )
54
+ model_name_or_path: Optional[str] = field(
55
+ default=None, metadata={"help": "The model checkpoint for weights initialization."}
56
+ )
57
+ tokenizer_name_or_path: Optional[str] = field(
58
+ default=None, metadata={"help": "The tokenizer for weights initialization."}
59
+ )
60
+ load_in_8bit: bool = field(default=False, metadata={"help": "Whether to load the model in 8bit mode or not."})
61
+ load_in_4bit: bool = field(default=False, metadata={"help": "Whether to load the model in 4bit mode or not."})
62
+ cache_dir: Optional[str] = field(
63
+ default=None,
64
+ metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
65
+ )
66
+ use_fast_tokenizer: bool = field(
67
+ default=False,
68
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
69
+ )
70
+ torch_dtype: Optional[str] = field(
71
+ default=None,
72
+ metadata={
73
+ "help": (
74
+ "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
75
+ "dtype will be automatically derived from the model's weights."
76
+ ),
77
+ "choices": ["auto", "bfloat16", "float16", "float32"],
78
+ },
79
+ )
80
+ device_map: Optional[str] = field(
81
+ default="auto",
82
+ metadata={"help": "Device to map model to. If `auto` is passed, the device will be selected automatically. "},
83
+ )
84
+ trust_remote_code: bool = field(
85
+ default=True,
86
+ metadata={"help": "Whether to trust remote code when loading a model from a remote checkpoint."},
87
+ )
88
+ # Dataset arguments
89
+ dataset_name: Optional[str] = field(
90
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
91
+ )
92
+ dataset_config_name: Optional[str] = field(
93
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
94
+ )
95
+ train_file_dir: Optional[str] = field(default=None, metadata={"help": "The input jsonl data file folder."})
96
+ validation_file_dir: Optional[str] = field(default=None, metadata={"help": "The evaluation jsonl file folder."}, )
97
+ template_name: Optional[str] = field(default="vicuna", metadata={"help": "The prompt template name."})
98
+ per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "Train batch size per device"})
99
+ per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "Eval batch size per device"})
100
+ max_source_length: Optional[int] = field(default=256, metadata={"help": "Max length of prompt input text"})
101
+ max_target_length: Optional[int] = field(default=256, metadata={"help": "Max length of output text"})
102
+ min_target_length: Optional[int] = field(default=4, metadata={"help": "Min length of output text"})
103
+ max_train_samples: Optional[int] = field(
104
+ default=None,
105
+ metadata={
106
+ "help": (
107
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
108
+ "value if set."
109
+ )
110
+ },
111
+ )
112
+ max_eval_samples: Optional[int] = field(
113
+ default=None,
114
+ metadata={
115
+ "help": (
116
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
117
+ "value if set."
118
+ )
119
+ },
120
+ )
121
+ overwrite_cache: bool = field(
122
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
123
+ )
124
+ validation_split_percentage: Optional[int] = field(
125
+ default=1,
126
+ metadata={
127
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
128
+ },
129
+ )
130
+ preprocessing_num_workers: Optional[int] = field(
131
+ default=4, metadata={"help": "The number of processes to use for the preprocessing."},
132
+ )
133
+ # Training arguments
134
+ use_peft: bool = field(default=True, metadata={"help": "Whether to use peft"})
135
+ qlora: bool = field(default=False, metadata={"help": "Whether to use qlora"})
136
+ target_modules: Optional[str] = field(default=None)
137
+ lora_rank: Optional[int] = field(default=8)
138
+ lora_dropout: Optional[float] = field(default=0.05)
139
+ lora_alpha: Optional[float] = field(default=16.0)
140
+ peft_path: Optional[str] = field(default=None)
141
+ do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
142
+ do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the validation set."})
143
+ beta: Optional[float] = field(default=0.1, metadata={"help": "The beta parameter for DPO loss"})
144
+ learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "Learning rate"})
145
+ lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "The lr scheduler type"})
146
+ warmup_steps: Optional[int] = field(default=100, metadata={"help": "The number of warmup steps"})
147
+ weight_decay: Optional[float] = field(default=0.05, metadata={"help": "The weight decay"})
148
+ optim: Optional[str] = field(default="adamw_hf", metadata={"help": "The optimizer type"})
149
+ fp16: Optional[bool] = field(default=True, metadata={"help": "Whether to use fp16"})
150
+ bf16: Optional[bool] = field(default=False, metadata={"help": "Whether to use bf16"})
151
+ gradient_checkpointing: Optional[bool] = field(
152
+ default=True, metadata={"help": "Whether to use gradient checkpointing"}
153
+ )
154
+ gradient_accumulation_steps: Optional[int] = field(
155
+ default=4, metadata={"help": "The number of gradient accumulation steps"}
156
+ )
157
+ save_steps: Optional[int] = field(default=50, metadata={"help": "X steps to save the model"})
158
+ eval_steps: Optional[int] = field(default=50, metadata={"help": "X steps to evaluate the model"})
159
+ logging_steps: Optional[int] = field(default=1, metadata={"help": "X steps to log the model"})
160
+ output_dir: Optional[str] = field(default="outputs-dpo", metadata={"help": "The output directory"})
161
+ max_steps: Optional[int] = field(default=200, metadata={"help": "Number of steps to train"})
162
+ eval_strategy: Optional[str] = field(default="steps", metadata={"help": "Evaluation strategy"})
163
+ remove_unused_columns: Optional[bool] = field(
164
+ default=False,
165
+ metadata={"help": "Remove unused columns from the dataset if `datasets.Dataset` is used"},
166
+ )
167
+ report_to: Optional[str] = field(default="tensorboard", metadata={"help": "Report to wandb or tensorboard"})
168
+
169
+ def __post_init__(self):
170
+ if self.model_type is None:
171
+ raise ValueError("You must specify a valid model_type to run training.")
172
+ if self.model_name_or_path is None:
173
+ raise ValueError("You must specify a valid model_name_or_path to run training.")
174
+
175
+
176
+ def print_trainable_parameters(model):
177
+ """
178
+ Prints the number of trainable parameters in the model.
179
+ """
180
+ trainable_params = 0
181
+ all_param = 0
182
+ for _, param in model.named_parameters():
183
+ all_param += param.numel()
184
+ if param.requires_grad:
185
+ trainable_params += param.numel()
186
+ print(
187
+ f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
188
+ )
189
+
190
+
191
+ def find_all_linear_names(peft_model, int4=False, int8=False):
192
+ """Find all linear layer names in the model. reference from qlora paper."""
193
+ cls = torch.nn.Linear
194
+ if int4 or int8:
195
+ import bitsandbytes as bnb
196
+ if int4:
197
+ cls = bnb.nn.Linear4bit
198
+ elif int8:
199
+ cls = bnb.nn.Linear8bitLt
200
+ lora_module_names = set()
201
+ for name, module in peft_model.named_modules():
202
+ if isinstance(module, cls):
203
+ # last layer is not add to lora_module_names
204
+ if 'lm_head' in name:
205
+ continue
206
+ if 'output_layer' in name:
207
+ continue
208
+ names = name.split('.')
209
+ lora_module_names.add(names[0] if len(names) == 1 else names[-1])
210
+ return sorted(lora_module_names)
211
+
212
+
213
+ def return_prompt_and_responses(examples) -> Dict[str, str]:
214
+ """Load the paired dataset and convert it to the necessary format.
215
+
216
+ The dataset is converted to a dictionary with the following structure:
217
+ {
218
+ 'prompt': List[str],
219
+ 'chosen': List[str],
220
+ 'rejected': List[str],
221
+ }
222
+
223
+ Prompts are structured as follows:
224
+ "Question: " + <prompt> + "\n\nAnswer: "
225
+ """
226
+ return {
227
+ "prompt": ["Question: " + question + "\n\nAnswer: " for question in examples["question"]],
228
+ "chosen": examples["response_chosen"],
229
+ "rejected": examples["response_rejected"],
230
+ }
231
+
232
+
233
+ def main():
234
+ parser = HfArgumentParser(ScriptArguments)
235
+ args = parser.parse_args_into_dataclasses()[0]
236
+ logger.info(f"Parse args: {args}")
237
+
238
+ config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
239
+ if args.model_type == 'bloom':
240
+ args.use_fast_tokenizer = True
241
+ # Load tokenizer
242
+ tokenizer_kwargs = {
243
+ "cache_dir": args.cache_dir,
244
+ "use_fast": args.use_fast_tokenizer,
245
+ "trust_remote_code": args.trust_remote_code,
246
+ }
247
+ tokenizer_name_or_path = args.tokenizer_name_or_path
248
+ if not tokenizer_name_or_path:
249
+ tokenizer_name_or_path = args.model_name_or_path
250
+ tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
251
+ if tokenizer.pad_token_id is None:
252
+ tokenizer.pad_token_id = 0 # set as the <unk> token
253
+
254
+ # Get datasets
255
+ if args.dataset_name is not None:
256
+ # Downloading and loading a dataset from the hub.
257
+ raw_datasets = load_dataset(
258
+ args.dataset_name,
259
+ args.dataset_config_name,
260
+ cache_dir=args.cache_dir,
261
+ )
262
+ if "validation" not in raw_datasets.keys():
263
+ raw_datasets["validation"] = load_dataset(
264
+ args.dataset_name,
265
+ args.dataset_config_name,
266
+ split=f"train[:{args.validation_split_percentage}%]",
267
+ cache_dir=args.cache_dir,
268
+ )
269
+ raw_datasets["train"] = load_dataset(
270
+ args.dataset_name,
271
+ args.dataset_config_name,
272
+ split=f"train[{args.validation_split_percentage}%:]",
273
+ cache_dir=args.cache_dir,
274
+ )
275
+ else:
276
+ data_files = {}
277
+ if args.train_file_dir is not None and os.path.exists(args.train_file_dir):
278
+ train_data_files = glob(f'{args.train_file_dir}/**/*.json', recursive=True) + glob(
279
+ f'{args.train_file_dir}/**/*.jsonl', recursive=True)
280
+ logger.info(f"train files: {', '.join(train_data_files)}")
281
+ data_files["train"] = train_data_files
282
+ if args.validation_file_dir is not None and os.path.exists(args.validation_file_dir):
283
+ eval_data_files = glob(f'{args.validation_file_dir}/**/*.json', recursive=True) + glob(
284
+ f'{args.validation_file_dir}/**/*.jsonl', recursive=True)
285
+ logger.info(f"eval files: {', '.join(eval_data_files)}")
286
+ data_files["validation"] = eval_data_files
287
+ raw_datasets = load_dataset(
288
+ 'json',
289
+ data_files=data_files,
290
+ cache_dir=args.cache_dir,
291
+ )
292
+ # If no validation data is there, validation_split_percentage will be used to divide the dataset.
293
+ if "validation" not in raw_datasets.keys():
294
+ raw_datasets["validation"] = load_dataset(
295
+ 'json',
296
+ data_files=data_files,
297
+ split=f"train[:{args.validation_split_percentage}%]",
298
+ cache_dir=args.cache_dir,
299
+ )
300
+ raw_datasets["train"] = load_dataset(
301
+ 'json',
302
+ data_files=data_files,
303
+ split=f"train[{args.validation_split_percentage}%:]",
304
+ cache_dir=args.cache_dir,
305
+ )
306
+ logger.info(f"Raw datasets: {raw_datasets}")
307
+
308
+ # Preprocessing the datasets
309
+ max_source_length = args.max_source_length
310
+ max_target_length = args.max_target_length
311
+ full_max_length = max_source_length + max_target_length
312
+
313
+ # Preprocess the dataset
314
+ train_dataset = None
315
+ max_train_samples = 0
316
+ if args.do_train:
317
+ if "train" not in raw_datasets:
318
+ raise ValueError("--do_train requires a train dataset")
319
+ train_dataset = raw_datasets['train']
320
+ max_train_samples = len(train_dataset)
321
+ if args.max_train_samples is not None and args.max_train_samples > 0:
322
+ max_train_samples = min(len(train_dataset), args.max_train_samples)
323
+ train_dataset = train_dataset.select(range(max_train_samples))
324
+ logger.debug(f"Example train_dataset[0]: {train_dataset[0]}")
325
+ tokenized_dataset = train_dataset.shuffle().map(
326
+ return_prompt_and_responses,
327
+ batched=True,
328
+ num_proc=args.preprocessing_num_workers,
329
+ remove_columns=train_dataset.column_names,
330
+ load_from_cache_file=not args.overwrite_cache,
331
+ desc="Running tokenizer on dataset",
332
+ )
333
+ train_dataset = tokenized_dataset.filter(
334
+ lambda x: 0 < len(x['prompt'] + x['chosen']) <= full_max_length
335
+ and 0 < len(x['prompt'] + x['rejected']) <= full_max_length
336
+ )
337
+ logger.debug(f"Num train_samples: {len(train_dataset)}")
338
+ logger.debug("First train example:")
339
+ logger.debug(train_dataset[0]['prompt'] + train_dataset[0]['chosen'])
340
+
341
+ eval_dataset = None
342
+ max_eval_samples = 0
343
+ if args.do_eval:
344
+ if "validation" not in raw_datasets:
345
+ raise ValueError("--do_eval requires a validation dataset")
346
+ eval_dataset = raw_datasets["validation"]
347
+ max_eval_samples = len(eval_dataset)
348
+ if args.max_eval_samples is not None and args.max_eval_samples > 0:
349
+ max_eval_samples = min(len(eval_dataset), args.max_eval_samples)
350
+ eval_dataset = eval_dataset.select(range(max_eval_samples))
351
+ logger.debug(f"Example eval_dataset[0]: {eval_dataset[0]}")
352
+ eval_dataset = eval_dataset.map(
353
+ return_prompt_and_responses,
354
+ batched=True,
355
+ num_proc=args.preprocessing_num_workers,
356
+ remove_columns=eval_dataset.column_names,
357
+ load_from_cache_file=not args.overwrite_cache,
358
+ desc="Running tokenizer on dataset",
359
+ )
360
+ eval_dataset = eval_dataset.filter(
361
+ lambda x: 0 < len(x['prompt'] + x['chosen']) <= full_max_length
362
+ and 0 < len(x['prompt'] + x['rejected']) <= full_max_length
363
+ )
364
+ logger.debug(f"Num eval_samples: {len(eval_dataset)}")
365
+ logger.debug("First eval example:")
366
+ logger.debug(eval_dataset[0]['prompt'] + eval_dataset[0]['chosen'])
367
+
368
+ logger.info("Loading model")
369
+ torch_dtype = (
370
+ args.torch_dtype
371
+ if args.torch_dtype in ["auto", None]
372
+ else getattr(torch, args.torch_dtype)
373
+ )
374
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
375
+ ddp = world_size != 1
376
+ if ddp:
377
+ args.device_map = {"": int(os.environ["LOCAL_RANK"]) or 0}
378
+ if args.qlora and is_deepspeed_zero3_enabled():
379
+ logger.warning("ZeRO3 are both currently incompatible with QLoRA.")
380
+ config = config_class.from_pretrained(
381
+ args.model_name_or_path,
382
+ trust_remote_code=args.trust_remote_code,
383
+ torch_dtype=torch_dtype,
384
+ cache_dir=args.cache_dir
385
+ )
386
+ model = model_class.from_pretrained(
387
+ args.model_name_or_path,
388
+ config=config,
389
+ low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
390
+ device_map=args.device_map,
391
+ trust_remote_code=args.trust_remote_code,
392
+ quantization_config=BitsAndBytesConfig(
393
+ load_in_4bit=True,
394
+ bnb_4bit_use_double_quant=True,
395
+ bnb_4bit_quant_type="nf4",
396
+ bnb_4bit_compute_dtype=torch_dtype,
397
+ ) if args.qlora else None,
398
+ )
399
+ model_ref = model_class.from_pretrained(
400
+ args.model_name_or_path,
401
+ config=config,
402
+ low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
403
+ device_map=args.device_map,
404
+ trust_remote_code=args.trust_remote_code,
405
+ quantization_config=BitsAndBytesConfig(
406
+ load_in_4bit=True,
407
+ bnb_4bit_use_double_quant=True,
408
+ bnb_4bit_quant_type="nf4",
409
+ bnb_4bit_compute_dtype=torch_dtype,
410
+ ) if args.qlora else None,
411
+ )
412
+
413
+ # Initialize our Trainer
414
+ if args.gradient_checkpointing:
415
+ model.gradient_checkpointing_enable()
416
+ model.config.use_cache = False
417
+ else:
418
+ model.config.use_cache = True
419
+
420
+ training_args = TrainingArguments(
421
+ per_device_train_batch_size=args.per_device_train_batch_size,
422
+ per_device_eval_batch_size=args.per_device_eval_batch_size,
423
+ max_steps=args.max_steps,
424
+ logging_steps=args.logging_steps,
425
+ save_steps=args.save_steps,
426
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
427
+ gradient_checkpointing=args.gradient_checkpointing,
428
+ learning_rate=args.learning_rate,
429
+ evaluation_strategy=args.eval_strategy,
430
+ eval_steps=args.eval_steps,
431
+ output_dir=args.output_dir,
432
+ report_to=args.report_to,
433
+ lr_scheduler_type=args.lr_scheduler_type,
434
+ warmup_steps=args.warmup_steps,
435
+ optim=args.optim,
436
+ bf16=args.bf16,
437
+ fp16=args.fp16,
438
+ remove_unused_columns=args.remove_unused_columns,
439
+ run_name=f"dpo_{args.model_type}",
440
+ )
441
+
442
+ # Initialize DPO trainer
443
+ target_modules = args.target_modules.split(',') if args.target_modules else None
444
+ if target_modules and 'all' in target_modules:
445
+ target_modules = find_all_linear_names(model, int4=args.load_in_4bit, int8=args.load_in_8bit)
446
+ logger.info(f"Peft target_modules: {target_modules}")
447
+ peft_config = LoraConfig(
448
+ task_type=TaskType.CAUSAL_LM,
449
+ target_modules=target_modules,
450
+ inference_mode=False,
451
+ r=args.lora_rank,
452
+ lora_alpha=args.lora_alpha,
453
+ lora_dropout=args.lora_dropout,
454
+ )
455
+ trainer = DPOTrainer(
456
+ model,
457
+ model_ref,
458
+ args=training_args,
459
+ beta=args.beta,
460
+ train_dataset=train_dataset,
461
+ eval_dataset=eval_dataset,
462
+ tokenizer=tokenizer,
463
+ peft_config=peft_config if args.use_peft else None,
464
+ max_prompt_length=args.max_source_length,
465
+ max_length=full_max_length,
466
+ )
467
+ print_trainable_parameters(trainer.model)
468
+
469
+ # Training
470
+ if args.do_train:
471
+ logger.info("*** Train ***")
472
+ train_result = trainer.train()
473
+ metrics = train_result.metrics
474
+ metrics["train_samples"] = max_train_samples
475
+ logger.debug(f"Training metrics: {metrics}")
476
+ trainer.log_metrics("train", metrics)
477
+ trainer.save_metrics("train", metrics)
478
+ trainer.save_state()
479
+ logger.info(f"Saving model checkpoint to {args.output_dir}")
480
+ trainer.save_model(args.output_dir)
481
+ tokenizer.save_pretrained(args.output_dir)
482
+ trainer.model.save_pretrained(args.output_dir)
483
+
484
+ # Evaluation
485
+ if args.do_eval and trainer.is_world_process_zero():
486
+ logger.info("*** Evaluate ***")
487
+ metrics = trainer.evaluate()
488
+ metrics["eval_samples"] = max_eval_samples
489
+ logger.debug(f"Eval metrics: {metrics}")
490
+ trainer.log_metrics("eval", metrics)
491
+ trainer.save_metrics("eval", metrics)
492
+
493
+
494
+ if __name__ == "__main__":
495
+ main()
gradio_demo.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ @author:XuMing([email protected])
4
+ @description:
5
+
6
+ pip install gradio
7
+ pip install mdtex2html
8
+ """
9
+ import argparse
10
+ import os
11
+ from threading import Thread
12
+
13
+ import gradio as gr
14
+ import mdtex2html
15
+ import torch
16
+ from peft import PeftModel
17
+ from transformers import (
18
+ AutoModel,
19
+ AutoTokenizer,
20
+ AutoModelForCausalLM,
21
+ BloomForCausalLM,
22
+ BloomTokenizerFast,
23
+ LlamaTokenizer,
24
+ LlamaForCausalLM,
25
+ GenerationConfig,
26
+ TextIteratorStreamer,
27
+ )
28
+
29
+ from supervised_finetuning import get_conv_template
30
+
31
+ MODEL_CLASSES = {
32
+ "bloom": (BloomForCausalLM, BloomTokenizerFast),
33
+ "chatglm": (AutoModel, AutoTokenizer),
34
+ "llama": (LlamaForCausalLM, LlamaTokenizer),
35
+ "baichuan": (AutoModelForCausalLM, AutoTokenizer),
36
+ "auto": (AutoModelForCausalLM, AutoTokenizer),
37
+ }
38
+
39
+
40
+ @torch.inference_mode()
41
+ def stream_generate_answer(
42
+ model,
43
+ tokenizer,
44
+ prompt,
45
+ device,
46
+ max_new_tokens=512,
47
+ temperature=0.7,
48
+ top_p=0.8,
49
+ repetition_penalty=1.0,
50
+ context_len=2048,
51
+ ):
52
+ streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=False)
53
+ input_ids = tokenizer(prompt).input_ids
54
+ max_src_len = context_len - max_new_tokens - 8
55
+ input_ids = input_ids[-max_src_len:]
56
+ generation_kwargs = dict(
57
+ input_ids=torch.as_tensor([input_ids]).to(device),
58
+ max_new_tokens=max_new_tokens,
59
+ temperature=temperature,
60
+ top_p=top_p,
61
+ repetition_penalty=repetition_penalty,
62
+ streamer=streamer,
63
+ )
64
+
65
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
66
+ thread.start()
67
+
68
+ yield from streamer
69
+
70
+
71
+ def main():
72
+ parser = argparse.ArgumentParser()
73
+ parser.add_argument('--model_type', default=None, type=str, required=True)
74
+ parser.add_argument('--base_model', default=None, type=str, required=True)
75
+ parser.add_argument('--lora_model', default="", type=str, help="If None, perform inference on the base model")
76
+ parser.add_argument('--tokenizer_path', default=None, type=str)
77
+ parser.add_argument('--template_name', default="vicuna", type=str,
78
+ help="Prompt template name, eg: alpaca, vicuna, baichuan-chat, chatglm2 etc.")
79
+ parser.add_argument('--gpus', default="0", type=str)
80
+ parser.add_argument('--only_cpu', action='store_true', help='only use CPU for inference')
81
+ parser.add_argument('--resize_emb', action='store_true', help='Whether to resize model token embeddings')
82
+ args = parser.parse_args()
83
+ if args.only_cpu is True:
84
+ args.gpus = ""
85
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
86
+
87
+ def postprocess(self, y):
88
+ if y is None:
89
+ return []
90
+ for i, (message, response) in enumerate(y):
91
+ y[i] = (
92
+ None if message is None else mdtex2html.convert((message)),
93
+ None if response is None else mdtex2html.convert(response),
94
+ )
95
+ return y
96
+
97
+ gr.Chatbot.postprocess = postprocess
98
+
99
+ load_type = torch.float16
100
+ if torch.cuda.is_available():
101
+ device = torch.device(0)
102
+ else:
103
+ device = torch.device('cpu')
104
+
105
+ if args.tokenizer_path is None:
106
+ args.tokenizer_path = args.base_model
107
+ model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
108
+ tokenizer = tokenizer_class.from_pretrained(args.tokenizer_path, trust_remote_code=True)
109
+ base_model = model_class.from_pretrained(
110
+ args.base_model,
111
+ load_in_8bit=False,
112
+ torch_dtype=load_type,
113
+ low_cpu_mem_usage=True,
114
+ device_map='auto',
115
+ trust_remote_code=True,
116
+ )
117
+ try:
118
+ base_model.generation_config = GenerationConfig.from_pretrained(args.base_model, trust_remote_code=True)
119
+ except OSError:
120
+ print("Failed to load generation config, use default.")
121
+ if args.resize_emb:
122
+ model_vocab_size = base_model.get_input_embeddings().weight.size(0)
123
+ tokenzier_vocab_size = len(tokenizer)
124
+ print(f"Vocab of the base model: {model_vocab_size}")
125
+ print(f"Vocab of the tokenizer: {tokenzier_vocab_size}")
126
+ if model_vocab_size != tokenzier_vocab_size:
127
+ print("Resize model embeddings to fit tokenizer")
128
+ base_model.resize_token_embeddings(tokenzier_vocab_size)
129
+ if args.lora_model:
130
+ model = PeftModel.from_pretrained(base_model, args.lora_model, torch_dtype=load_type, device_map='auto')
131
+ print("loaded lora model")
132
+ else:
133
+ model = base_model
134
+ if device == torch.device('cpu'):
135
+ model.float()
136
+
137
+ model.eval()
138
+
139
+ def reset_user_input():
140
+ return gr.update(value='')
141
+
142
+ def reset_state():
143
+ return [], []
144
+
145
+ prompt_template = get_conv_template(args.template_name)
146
+ stop_str = tokenizer.eos_token if tokenizer.eos_token else prompt_template.stop_str
147
+ history = []
148
+
149
+ def predict(
150
+ input,
151
+ chatbot,
152
+ history,
153
+ max_new_tokens,
154
+ temperature,
155
+ top_p
156
+ ):
157
+ now_input = input
158
+ chatbot.append((input, ""))
159
+ history = history or []
160
+ history.append([now_input, ''])
161
+
162
+ prompt = prompt_template.get_prompt(messages=history)
163
+ response = ""
164
+
165
+ for new_text in stream_generate_answer(
166
+ model,
167
+ tokenizer,
168
+ prompt,
169
+ device,
170
+ max_new_tokens=max_new_tokens,
171
+ temperature=temperature,
172
+ top_p=top_p,
173
+ ):
174
+ stop = False
175
+ pos = new_text.find(stop_str)
176
+ if pos != -1:
177
+ new_text = new_text[:pos]
178
+ stop = True
179
+ response += new_text
180
+ new_history = history + [(now_input, response)]
181
+ chatbot[-1] = (now_input, response)
182
+ yield chatbot, new_history
183
+ if stop:
184
+ break
185
+
186
+ with gr.Blocks() as demo:
187
+ gr.HTML("""<h1 align="center">MedicalGPT</h1>""")
188
+ gr.Markdown(
189
+ "> 为了促进医疗行业大模型的开放研究,本项目开源了MedicalGPT医疗大模型")
190
+ chatbot = gr.Chatbot()
191
+ with gr.Row():
192
+ with gr.Column(scale=4):
193
+ with gr.Column(scale=12):
194
+ user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
195
+ container=False)
196
+ with gr.Column(min_width=32, scale=1):
197
+ submitBtn = gr.Button("Submit", variant="primary")
198
+ with gr.Column(scale=1):
199
+ emptyBtn = gr.Button("Clear History")
200
+ max_length = gr.Slider(
201
+ 0, 4096, value=512, step=1.0, label="Maximum length", interactive=True)
202
+ top_p = gr.Slider(0, 1, value=0.8, step=0.01,
203
+ label="Top P", interactive=True)
204
+ temperature = gr.Slider(
205
+ 0, 1, value=0.7, step=0.01, label="Temperature", interactive=True)
206
+ history = gr.State([])
207
+ submitBtn.click(predict, [user_input, chatbot, history, max_length, temperature, top_p], [chatbot, history],
208
+ show_progress=True)
209
+ submitBtn.click(reset_user_input, [], [user_input])
210
+ emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
211
+ demo.queue().launch(share=False, inbrowser=True, server_name='0.0.0.0', server_port=8082)
212
+
213
+
214
+ if __name__ == '__main__':
215
+ main()
inference.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ @author:XuMing([email protected])
4
+ @description:
5
+ """
6
+ import argparse
7
+ import json
8
+ import os
9
+ from threading import Thread
10
+
11
+ import torch
12
+ from peft import PeftModel
13
+ from transformers import (
14
+ AutoModel,
15
+ AutoModelForCausalLM,
16
+ AutoTokenizer,
17
+ BloomForCausalLM,
18
+ BloomTokenizerFast,
19
+ LlamaTokenizer,
20
+ LlamaForCausalLM,
21
+ TextIteratorStreamer,
22
+ GenerationConfig,
23
+ )
24
+
25
+ from supervised_finetuning import get_conv_template
26
+
27
+ MODEL_CLASSES = {
28
+ "bloom": (BloomForCausalLM, BloomTokenizerFast),
29
+ "chatglm": (AutoModel, AutoTokenizer),
30
+ "llama": (LlamaForCausalLM, LlamaTokenizer),
31
+ "baichuan": (AutoModelForCausalLM, AutoTokenizer),
32
+ "auto": (AutoModelForCausalLM, AutoTokenizer),
33
+ }
34
+
35
+
36
+ @torch.inference_mode()
37
+ def stream_generate_answer(
38
+ model,
39
+ tokenizer,
40
+ prompt,
41
+ device,
42
+ do_print=True,
43
+ max_new_tokens=512,
44
+ temperature=0.7,
45
+ repetition_penalty=1.0,
46
+ context_len=2048,
47
+ stop_str="</s>",
48
+ ):
49
+ streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=False)
50
+ input_ids = tokenizer(prompt).input_ids
51
+ max_src_len = context_len - max_new_tokens - 8
52
+ input_ids = input_ids[-max_src_len:]
53
+ generation_kwargs = dict(
54
+ input_ids=torch.as_tensor([input_ids]).to(device),
55
+ max_new_tokens=max_new_tokens,
56
+ temperature=temperature,
57
+ repetition_penalty=repetition_penalty,
58
+ streamer=streamer,
59
+ )
60
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
61
+ thread.start()
62
+
63
+ generated_text = ""
64
+ for new_text in streamer:
65
+ stop = False
66
+ pos = new_text.find(stop_str)
67
+ if pos != -1:
68
+ new_text = new_text[:pos]
69
+ stop = True
70
+ generated_text += new_text
71
+ if do_print:
72
+ print(new_text, end="", flush=True)
73
+ if stop:
74
+ break
75
+ if do_print:
76
+ print()
77
+ return generated_text
78
+
79
+
80
+ def main():
81
+ parser = argparse.ArgumentParser()
82
+ parser.add_argument('--model_type', default=None, type=str, required=True)
83
+ parser.add_argument('--base_model', default=None, type=str, required=True)
84
+ parser.add_argument('--lora_model', default="", type=str, help="If None, perform inference on the base model")
85
+ parser.add_argument('--tokenizer_path', default=None, type=str)
86
+ parser.add_argument('--template_name', default="vicuna", type=str,
87
+ help="Prompt template name, eg: alpaca, vicuna, baichuan-chat, chatglm2 etc.")
88
+ parser.add_argument("--temperature", type=float, default=0.7)
89
+ parser.add_argument("--repetition_penalty", type=float, default=1.0)
90
+ parser.add_argument("--max_new_tokens", type=int, default=512)
91
+ parser.add_argument('--data_file', default=None, type=str,
92
+ help="A file that contains instructions (one instruction per line)")
93
+ parser.add_argument('--interactive', action='store_true', help="run in the instruction mode (single-turn)")
94
+ parser.add_argument('--predictions_file', default='./predictions.json', type=str)
95
+ parser.add_argument('--resize_emb', action='store_true', help='Whether to resize model token embeddings')
96
+ parser.add_argument('--gpus', default="0", type=str)
97
+ parser.add_argument('--only_cpu', action='store_true', help='only use CPU for inference')
98
+ args = parser.parse_args()
99
+ print(args)
100
+ if args.only_cpu is True:
101
+ args.gpus = ""
102
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
103
+ load_type = torch.float16
104
+ if torch.cuda.is_available():
105
+ device = torch.device(0)
106
+ else:
107
+ device = torch.device('cpu')
108
+ if args.tokenizer_path is None:
109
+ args.tokenizer_path = args.base_model
110
+
111
+ model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
112
+ tokenizer = tokenizer_class.from_pretrained(args.tokenizer_path, trust_remote_code=True)
113
+ base_model = model_class.from_pretrained(
114
+ args.base_model,
115
+ load_in_8bit=False,
116
+ torch_dtype=load_type,
117
+ low_cpu_mem_usage=True,
118
+ device_map='auto',
119
+ trust_remote_code=True,
120
+ )
121
+ try:
122
+ base_model.generation_config = GenerationConfig.from_pretrained(args.base_model, trust_remote_code=True)
123
+ except OSError:
124
+ print("Failed to load generation config, use default.")
125
+ if args.resize_emb:
126
+ model_vocab_size = base_model.get_input_embeddings().weight.size(0)
127
+ tokenzier_vocab_size = len(tokenizer)
128
+ print(f"Vocab of the base model: {model_vocab_size}")
129
+ print(f"Vocab of the tokenizer: {tokenzier_vocab_size}")
130
+ if model_vocab_size != tokenzier_vocab_size:
131
+ print("Resize model embeddings to fit tokenizer")
132
+ base_model.resize_token_embeddings(tokenzier_vocab_size)
133
+
134
+ if args.lora_model:
135
+ model = PeftModel.from_pretrained(base_model, args.lora_model, torch_dtype=load_type, device_map='auto')
136
+ print("Loaded lora model")
137
+ else:
138
+ model = base_model
139
+ if device == torch.device('cpu'):
140
+ model.float()
141
+ model.eval()
142
+ print(tokenizer)
143
+ # test data
144
+ if args.data_file is None:
145
+ examples = ["介绍下北京", "乙肝和丙肝的区别?"]
146
+ else:
147
+ with open(args.data_file, 'r') as f:
148
+ examples = [l.strip() for l in f.readlines()]
149
+ print("first 10 examples:")
150
+ for example in examples[:10]:
151
+ print(example)
152
+
153
+ # Chat
154
+ prompt_template = get_conv_template(args.template_name)
155
+ stop_str = tokenizer.eos_token if tokenizer.eos_token else prompt_template.stop_str
156
+
157
+ if args.interactive:
158
+ print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
159
+ history = []
160
+ while True:
161
+ try:
162
+ query = input(f"{prompt_template.roles[0]}: ")
163
+ except UnicodeDecodeError:
164
+ print("Detected decoding error at the inputs, please try again.")
165
+ continue
166
+ except Exception:
167
+ raise
168
+ if query == "":
169
+ print("Please input text, try again.")
170
+ continue
171
+ if query.strip() == "exit":
172
+ print("exit...")
173
+ break
174
+ if query.strip() == "clear":
175
+ history = []
176
+ print("history cleared.")
177
+ continue
178
+
179
+ print(f"{prompt_template.roles[1]}: ", end="", flush=True)
180
+
181
+ history.append([query, ''])
182
+ prompt = prompt_template.get_prompt(messages=history)
183
+ response = stream_generate_answer(
184
+ model,
185
+ tokenizer,
186
+ prompt,
187
+ device,
188
+ do_print=True,
189
+ max_new_tokens=args.max_new_tokens,
190
+ temperature=args.temperature,
191
+ repetition_penalty=args.repetition_penalty,
192
+ stop_str=stop_str,
193
+ )
194
+ if history:
195
+ history[-1][-1] = response.strip()
196
+ else:
197
+ print("Start inference.")
198
+ results = []
199
+ for index, example in enumerate(examples):
200
+ # Single turn inference
201
+ history = [[example, '']]
202
+ prompt = prompt_template.get_prompt(messages=history)
203
+ response = stream_generate_answer(
204
+ model,
205
+ tokenizer,
206
+ prompt,
207
+ device,
208
+ do_print=False,
209
+ max_new_tokens=args.max_new_tokens,
210
+ temperature=args.temperature,
211
+ repetition_penalty=args.repetition_penalty,
212
+ stop_str=stop_str,
213
+ )
214
+ response = response.strip()
215
+ print(f"======={index}=======")
216
+ print(f"Input: {example}\n")
217
+ print(f"Output: {response}\n")
218
+ results.append({"Input": prompt, "Output": response})
219
+
220
+ with open(args.predictions_file, 'w', encoding='utf-8') as f:
221
+ json.dump(results, f, ensure_ascii=False, indent=2)
222
+
223
+
224
+ if __name__ == '__main__':
225
+ main()
merge_peft_adapter.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ @author:XuMing([email protected])
4
+ @description:
5
+
6
+ Usage:
7
+ python merge_peft_adapter.py \
8
+ --base_model_name_or_path path/to/llama/model \
9
+ --tokenizer_path path/to/llama/tokenizer \
10
+ --peft_model_path path/to/lora/model \
11
+ --output_dir path/to/output/dir
12
+
13
+ after merged, chatglm and baichuan model need copy python script to output dir.
14
+ """
15
+
16
+ import argparse
17
+
18
+ import torch
19
+ from peft import PeftModel, PeftConfig
20
+ from transformers import (
21
+ AutoModel,
22
+ AutoTokenizer,
23
+ BloomForCausalLM,
24
+ BloomTokenizerFast,
25
+ AutoModelForCausalLM,
26
+ LlamaTokenizer,
27
+ LlamaForCausalLM,
28
+ AutoModelForSequenceClassification,
29
+ )
30
+
31
+ MODEL_CLASSES = {
32
+ "bloom": (BloomForCausalLM, BloomTokenizerFast),
33
+ "chatglm": (AutoModel, AutoTokenizer),
34
+ "llama": (LlamaForCausalLM, LlamaTokenizer),
35
+ "baichuan": (AutoModelForCausalLM, AutoTokenizer),
36
+ "auto": (AutoModelForCausalLM, AutoTokenizer),
37
+ }
38
+
39
+
40
+ def main():
41
+ parser = argparse.ArgumentParser()
42
+ parser.add_argument('--model_type', default=None, type=str, required=True)
43
+ parser.add_argument('--base_model_name_or_path', default=None, required=True, type=str,
44
+ help="Base model name or path")
45
+ parser.add_argument('--tokenizer_path', default=None, type=str,
46
+ help="Please specify tokenization path.")
47
+ parser.add_argument('--peft_model_path', default=None, required=True, type=str,
48
+ help="Please specify LoRA model to be merged.")
49
+ parser.add_argument('--resize_emb', action='store_true', help='Whether to resize model token embeddings')
50
+ parser.add_argument('--output_dir', default='./merged', type=str)
51
+ args = parser.parse_args()
52
+ print(args)
53
+
54
+ base_model_path = args.base_model_name_or_path
55
+ peft_model_path = args.peft_model_path
56
+ output_dir = args.output_dir
57
+ print(f"Base model: {base_model_path}")
58
+ print(f"LoRA model: {peft_model_path}")
59
+ peft_config = PeftConfig.from_pretrained(peft_model_path)
60
+
61
+ model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
62
+ if peft_config.task_type == "SEQ_CLS":
63
+ print("Loading LoRA for sequence classification model")
64
+ if args.model_type == "chatglm":
65
+ raise ValueError("chatglm does not support sequence classification")
66
+ base_model = AutoModelForSequenceClassification.from_pretrained(
67
+ base_model_path,
68
+ load_in_8bit=False,
69
+ torch_dtype=torch.float16,
70
+ trust_remote_code=True,
71
+ device_map="auto",
72
+ )
73
+ else:
74
+ print("Loading LoRA for causal language model")
75
+ base_model = model_class.from_pretrained(
76
+ base_model_path,
77
+ load_in_8bit=False,
78
+ torch_dtype=torch.float16,
79
+ trust_remote_code=True,
80
+ device_map="auto",
81
+ )
82
+ if args.tokenizer_path:
83
+ tokenizer = tokenizer_class.from_pretrained(args.tokenizer_path, trust_remote_code=True)
84
+ else:
85
+ tokenizer = tokenizer_class.from_pretrained(peft_model_path, trust_remote_code=True)
86
+ if args.resize_emb:
87
+ base_model_token_size = base_model.get_input_embeddings().weight.size(0)
88
+ if base_model_token_size != len(tokenizer):
89
+ base_model.resize_token_embeddings(len(tokenizer))
90
+ print(f"Resize vocabulary size {base_model_token_size} to {len(tokenizer)}")
91
+
92
+ lora_model = PeftModel.from_pretrained(
93
+ base_model,
94
+ peft_model_path,
95
+ device_map="auto",
96
+ torch_dtype=torch.float16,
97
+ )
98
+ lora_model.eval()
99
+ print(f"Merging with merge_and_unload...")
100
+ base_model = lora_model.merge_and_unload()
101
+
102
+ print("Saving to Hugging Face format...")
103
+ tokenizer.save_pretrained(output_dir)
104
+ base_model.save_pretrained(output_dir)
105
+ print(f"Done! model saved to {output_dir}")
106
+
107
+
108
+ if __name__ == '__main__':
109
+ main()
merge_tokenizers.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ @author:XuMing([email protected])
4
+ @description:
5
+ """
6
+ import os
7
+
8
+ os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
9
+ from transformers import LlamaTokenizer
10
+ from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
11
+ import sentencepiece as spm
12
+ import argparse
13
+
14
+
15
+ def is_chinese(uchar):
16
+ """判断一个unicode是否是汉字"""
17
+ return '\u4e00' <= uchar <= '\u9fa5'
18
+
19
+
20
+ def is_chinese_string(string):
21
+ """判断是否全为汉字"""
22
+ return all(is_chinese(c) for c in string)
23
+
24
+
25
+ def load_baichuan_vocab(vocab_file):
26
+ words = set()
27
+ with open(vocab_file, "r", encoding="utf-8") as f:
28
+ for line in f:
29
+ if line.strip():
30
+ words.add(line.strip().split()[0])
31
+ return words
32
+
33
+
34
+ def load_jieba_vocab(jieba_vocab_file):
35
+ # Read jieba vocab and sort by freq
36
+ with open(jieba_vocab_file, "r", encoding="utf-8") as f:
37
+ lines = f.readlines()
38
+ word_freqs = [line.strip().split() for line in lines]
39
+ word_freqs.sort(key=lambda x: int(x[1]), reverse=True)
40
+ return word_freqs
41
+
42
+
43
+ def main():
44
+ parser = argparse.ArgumentParser()
45
+ parser.add_argument('--base_tokenizer_dir', default=None, type=str, required=True)
46
+ parser.add_argument('--domain_sp_model_file', default='./domain_sp.model', type=str)
47
+ parser.add_argument('--baichuan_vocab_file', default="data/vocab/baichuan_vocab.txt", type=str)
48
+ parser.add_argument('--add_jieba', action='store_true', help='Whether to add jieba vocab.')
49
+ parser.add_argument('--jieba_word_freq_file', default='data/vocab/word_freq.txt', type=str)
50
+ parser.add_argument('--jieba_word_size', default=20000, type=int)
51
+
52
+ args = parser.parse_args()
53
+ print(args)
54
+
55
+ # load
56
+ llama_tokenizer = LlamaTokenizer.from_pretrained(args.base_tokenizer_dir)
57
+ chinese_sp_model = spm.SentencePieceProcessor()
58
+ chinese_sp_model.Load(args.domain_sp_model_file)
59
+
60
+ llama_spm = sp_pb2_model.ModelProto()
61
+ llama_spm.ParseFromString(llama_tokenizer.sp_model.serialized_model_proto())
62
+ chinese_spm = sp_pb2_model.ModelProto()
63
+ chinese_spm.ParseFromString(chinese_sp_model.serialized_model_proto())
64
+
65
+ # print number of tokens
66
+ print(len(llama_tokenizer), len(chinese_sp_model))
67
+ print(llama_tokenizer.all_special_tokens)
68
+ print(llama_tokenizer.all_special_ids)
69
+ print(llama_tokenizer.special_tokens_map)
70
+
71
+ # Add Chinese tokens to LLaMA tokenizer
72
+ llama_spm_tokens_set = set(p.piece for p in llama_spm.pieces)
73
+
74
+ print(len(llama_spm_tokens_set))
75
+ print(f"Before:{len(llama_spm_tokens_set)}")
76
+ added_set = set()
77
+ for p in chinese_spm.pieces:
78
+ piece = p.piece
79
+ if piece not in llama_spm_tokens_set:
80
+ # print('picec', piece)
81
+ new_p = sp_pb2_model.ModelProto().SentencePiece()
82
+ new_p.piece = piece
83
+ new_p.score = 0
84
+ llama_spm.pieces.append(new_p)
85
+ added_set.add(piece)
86
+ print(f"[add domain tokens]New model pieces: {len(llama_spm.pieces)}")
87
+
88
+ vocab = load_baichuan_vocab(args.baichuan_vocab_file)
89
+ print('baichuan vocab len:', len(vocab))
90
+ baichuan_vocab_set = set([i for i in vocab if is_chinese_string(i)])
91
+ print('baichuan chinese vocab size:', len(baichuan_vocab_set))
92
+ print('baichuan vocab head:', list(baichuan_vocab_set)[:10])
93
+ for p in baichuan_vocab_set:
94
+ piece = p
95
+ if piece not in llama_spm_tokens_set and piece not in added_set:
96
+ # print('baichuan picec', piece)
97
+ new_p = sp_pb2_model.ModelProto().SentencePiece()
98
+ new_p.piece = piece
99
+ new_p.score = 0
100
+ llama_spm.pieces.append(new_p)
101
+ added_set.add(piece)
102
+ print(f"[add baichuan tokens]New model pieces: {len(llama_spm.pieces)}")
103
+
104
+ if args.add_jieba:
105
+ word_freqs = load_jieba_vocab(args.jieba_word_freq_file)
106
+ top_words = word_freqs[:args.jieba_word_size]
107
+ print('jieba top10 freq words:', top_words[:10])
108
+ jieba_vocab_set = set([i[0] for i in top_words if i])
109
+ print('jieba_vocab_set size:', len(jieba_vocab_set))
110
+ print('jieba_vocab head:', list(jieba_vocab_set)[:3])
111
+ for p in jieba_vocab_set:
112
+ piece = p
113
+ if piece not in llama_spm_tokens_set and piece not in added_set:
114
+ # print('jieba picec', piece)
115
+ new_p = sp_pb2_model.ModelProto().SentencePiece()
116
+ new_p.piece = piece
117
+ new_p.score = 0
118
+ llama_spm.pieces.append(new_p)
119
+ print(f"[add jieba tokens]New model pieces: {len(llama_spm.pieces)}")
120
+
121
+ # Save
122
+ output_sp_dir = 'merged_tokenizer_sp'
123
+ output_hf_dir = 'merged_tokenizer_hf' # the path to save Chinese-LLaMA tokenizer
124
+ os.makedirs(output_sp_dir, exist_ok=True)
125
+ with open(output_sp_dir + '/chinese_llama.model', 'wb') as f:
126
+ f.write(llama_spm.SerializeToString())
127
+ tokenizer = LlamaTokenizer(vocab_file=output_sp_dir + '/chinese_llama.model')
128
+
129
+ tokenizer.save_pretrained(output_hf_dir)
130
+ print(f"Chinese-LLaMA tokenizer has been saved to {output_hf_dir}")
131
+
132
+ # Test
133
+ llama_tokenizer = LlamaTokenizer.from_pretrained(args.base_tokenizer_dir)
134
+ chinese_llama_tokenizer = LlamaTokenizer.from_pretrained(output_hf_dir)
135
+ print(chinese_llama_tokenizer.all_special_tokens)
136
+ print(chinese_llama_tokenizer.all_special_ids)
137
+ print(chinese_llama_tokenizer.special_tokens_map)
138
+ print('old len:', len(llama_tokenizer), ' new len:', len(chinese_llama_tokenizer))
139
+ text = '''this is a test, hello world. thisisatesthelloworld,
140
+ 慕容复来到河边,姑苏慕容氏在外面丢了人。
141
+ 1号店一周岁了,我们一古脑儿买了10斤零食。
142
+ 巴塞罗那足球俱乐部简称巴萨(Barça),是一家位于西班牙加泰罗尼亚巴塞罗那的足球俱乐部,于1899年由瑞士企业家胡安·甘伯所创立,世界球坛顶级足球俱乐部之一。俱乐部主场可容纳接近十万名观众,是全欧洲最大及世界第二大的足球场。
143
+ 白日依山尽,黄河入海流。欲穷千里目,更上一层楼。'''
144
+ print("Test text:\n", text)
145
+ print(f"Tokenized by LLaMA tokenizer:{llama_tokenizer.tokenize(text)}")
146
+ print(f"Tokenized by Chinese-LLaMA tokenizer:{chinese_llama_tokenizer.tokenize(text)}")
147
+
148
+
149
+ if __name__ == '__main__':
150
+ main()
pretraining.py ADDED
@@ -0,0 +1,678 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 XuMing([email protected]) and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
17
+
18
+ part of this code is adapted from https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py
19
+ """
20
+ import math
21
+ import os
22
+ from dataclasses import dataclass, field
23
+ from glob import glob
24
+ from itertools import chain
25
+ from typing import Optional, List, Dict, Any, Mapping
26
+
27
+ import numpy as np
28
+ import torch
29
+ from datasets import load_dataset
30
+ from loguru import logger
31
+ from peft import LoraConfig, TaskType, get_peft_model, PeftModel, prepare_model_for_int8_training
32
+ from sklearn.metrics import accuracy_score
33
+ from transformers import (
34
+ AutoConfig,
35
+ BloomForCausalLM,
36
+ AutoModelForCausalLM,
37
+ AutoModel,
38
+ LlamaTokenizer,
39
+ LlamaForCausalLM,
40
+ BloomTokenizerFast,
41
+ AutoTokenizer,
42
+ HfArgumentParser,
43
+ Trainer,
44
+ TrainingArguments,
45
+ is_torch_tpu_available,
46
+ set_seed,
47
+ )
48
+ from transformers.trainer import TRAINING_ARGS_NAME
49
+ from transformers.utils.versions import require_version
50
+
51
+ MODEL_CLASSES = {
52
+ "bloom": (AutoConfig, BloomForCausalLM, BloomTokenizerFast),
53
+ "chatglm": (AutoConfig, AutoModel, AutoTokenizer),
54
+ "llama": (AutoConfig, LlamaForCausalLM, LlamaTokenizer),
55
+ "baichuan": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
56
+ "auto": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
57
+ }
58
+
59
+
60
+ @dataclass
61
+ class ModelArguments:
62
+ """
63
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
64
+ """
65
+
66
+ model_type: str = field(
67
+ default=None,
68
+ metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())}
69
+ )
70
+ model_name_or_path: Optional[str] = field(
71
+ default=None,
72
+ metadata={
73
+ "help": (
74
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
75
+ )
76
+ },
77
+ )
78
+ tokenizer_name_or_path: Optional[str] = field(
79
+ default=None,
80
+ metadata={
81
+ "help": (
82
+ "The tokenizer for weights initialization.Don't set if you want to train a model from scratch."
83
+ )
84
+ },
85
+ )
86
+ load_in_8bit: bool = field(default=False, metadata={"help": "Whether to load the model in 8bit mode or not."})
87
+ cache_dir: Optional[str] = field(
88
+ default=None,
89
+ metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
90
+ )
91
+ use_fast_tokenizer: bool = field(
92
+ default=False,
93
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
94
+ )
95
+ torch_dtype: Optional[str] = field(
96
+ default=None,
97
+ metadata={
98
+ "help": (
99
+ "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
100
+ "dtype will be automatically derived from the model's weights."
101
+ ),
102
+ "choices": ["auto", "bfloat16", "float16", "float32"],
103
+ },
104
+ )
105
+ device_map: Optional[str] = field(
106
+ default="auto",
107
+ metadata={"help": "Device to map model to. If `auto` is passed, the device will be selected automatically. "},
108
+ )
109
+ trust_remote_code: bool = field(
110
+ default=True,
111
+ metadata={"help": "Whether to trust remote code when loading a model from a remote checkpoint."},
112
+ )
113
+
114
+ def __post_init__(self):
115
+ if self.model_type is None:
116
+ raise ValueError(
117
+ "You must specify a valid model_type to run training. Available model types are " + ", ".join(
118
+ MODEL_CLASSES.keys()))
119
+ if self.model_name_or_path is None:
120
+ raise ValueError("You must specify a valid model_name_or_path to run training.")
121
+
122
+
123
+ @dataclass
124
+ class DataTrainingArguments:
125
+ """
126
+ Arguments pertaining to what data we are going to input our model for training and eval.
127
+ """
128
+
129
+ dataset_name: Optional[str] = field(
130
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
131
+ )
132
+ dataset_config_name: Optional[str] = field(
133
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
134
+ )
135
+ train_file_dir: Optional[str] = field(default=None, metadata={"help": "The train text data file folder."})
136
+ validation_file_dir: Optional[str] = field(
137
+ default=None,
138
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on text file folder."},
139
+ )
140
+ max_train_samples: Optional[int] = field(
141
+ default=None,
142
+ metadata={
143
+ "help": (
144
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
145
+ "value if set."
146
+ )
147
+ },
148
+ )
149
+ max_eval_samples: Optional[int] = field(
150
+ default=None,
151
+ metadata={
152
+ "help": (
153
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
154
+ "value if set."
155
+ )
156
+ },
157
+ )
158
+ streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
159
+ block_size: Optional[int] = field(
160
+ default=1024,
161
+ metadata={
162
+ "help": (
163
+ "Optional input sequence length after tokenization. "
164
+ "The training dataset will be truncated in block of this size for training. "
165
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
166
+ )
167
+ },
168
+ )
169
+ overwrite_cache: bool = field(
170
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
171
+ )
172
+ validation_split_percentage: Optional[int] = field(
173
+ default=1,
174
+ metadata={
175
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
176
+ },
177
+ )
178
+ preprocessing_num_workers: Optional[int] = field(
179
+ default=None,
180
+ metadata={"help": "The number of processes to use for the preprocessing."},
181
+ )
182
+ keep_linebreaks: bool = field(
183
+ default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
184
+ )
185
+
186
+ def __post_init__(self):
187
+ if self.streaming:
188
+ require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")
189
+
190
+
191
+ @dataclass
192
+ class PeftArguments(TrainingArguments):
193
+ use_peft: bool = field(default=True, metadata={"help": "Whether to use peft"})
194
+ target_modules: Optional[str] = field(default="all")
195
+ lora_rank: Optional[int] = field(default=8)
196
+ lora_dropout: Optional[float] = field(default=0.05)
197
+ lora_alpha: Optional[float] = field(default=32.0)
198
+ modules_to_save: Optional[str] = field(default=None)
199
+ peft_path: Optional[str] = field(default=None)
200
+
201
+
202
+ def accuracy(predictions, references, normalize=True, sample_weight=None):
203
+ return {
204
+ "accuracy": float(accuracy_score(references, predictions, normalize=normalize, sample_weight=sample_weight))
205
+ }
206
+
207
+
208
+ def compute_metrics(eval_preds):
209
+ preds, labels = eval_preds
210
+ # preds have the same shape as the labels, after the argmax(-1) has been calculated
211
+ # by preprocess_logits_for_metrics, we need to shift the labels
212
+ labels = labels[:, 1:].reshape(-1)
213
+ preds = preds[:, :-1].reshape(-1)
214
+ return accuracy(predictions=preds, references=labels)
215
+
216
+
217
+ def preprocess_logits_for_metrics(logits, labels):
218
+ if isinstance(logits, tuple):
219
+ # Depending on the model and config, logits may contain extra tensors,
220
+ # like past_key_values, but logits always come first
221
+ logits = logits[0]
222
+ return logits.argmax(dim=-1)
223
+
224
+
225
+ def fault_tolerance_data_collator(features: List) -> Dict[str, Any]:
226
+ if not isinstance(features[0], Mapping):
227
+ features = [vars(f) for f in features]
228
+ first = features[0]
229
+ batch = {}
230
+
231
+ # Special handling for labels.
232
+ # Ensure that tensor is created with the correct type
233
+ if "label" in first and first["label"] is not None:
234
+ label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
235
+ dtype = torch.long if isinstance(label, int) else torch.float
236
+ batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
237
+ elif "label_ids" in first and first["label_ids"] is not None:
238
+ if isinstance(first["label_ids"], torch.Tensor):
239
+ batch["labels"] = torch.stack([f["label_ids"] for f in features])
240
+ else:
241
+ dtype = torch.long if type(first["label_ids"][0]) is int else torch.float
242
+ batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)
243
+
244
+ # Handling of all other possible keys.
245
+ # Again, we will use the first element to figure out which key/values are not None for this model.
246
+ try:
247
+ for k, v in first.items():
248
+ if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
249
+ if isinstance(v, torch.Tensor):
250
+ batch[k] = torch.stack([f[k] for f in features])
251
+ elif isinstance(v, np.ndarray):
252
+ batch[k] = torch.tensor(np.stack([f[k] for f in features]))
253
+ else:
254
+ batch[k] = torch.tensor([f[k] for f in features])
255
+ except ValueError: # quick fix by simply take the first example
256
+ for k, v in first.items():
257
+ if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
258
+ if isinstance(v, torch.Tensor):
259
+ batch[k] = torch.stack([features[0][k]] * len(features))
260
+ elif isinstance(v, np.ndarray):
261
+ batch[k] = torch.tensor(np.stack([features[0][k]] * len(features)))
262
+ else:
263
+ batch[k] = torch.tensor([features[0][k]] * len(features))
264
+
265
+ return batch
266
+
267
+
268
+ class GroupTextsBuilder:
269
+ def __init__(self, max_seq_length):
270
+ self.max_seq_length = max_seq_length
271
+
272
+ def __call__(self, examples):
273
+ # Concatenate all texts.
274
+ firsts = {k: examples[k][0][0] for k in examples.keys()}
275
+ lasts = {k: examples[k][0][-1] for k in examples.keys()}
276
+ contents = {k: sum([vi[1:-1] for vi in v], []) for k, v in examples.items()}
277
+ total_length = len(contents[list(examples.keys())[0]])
278
+
279
+ content_length = self.max_seq_length - 2
280
+ if total_length >= content_length:
281
+ total_length = (total_length // content_length) * content_length
282
+ # Split by chunks of max_len.
283
+ result = {
284
+ k: [[firsts[k]] + t[i: i + content_length] + [lasts[k]] for i in range(0, total_length, content_length)] for
285
+ k, t in contents.items()}
286
+ return result
287
+
288
+
289
+ class SavePeftModelTrainer(Trainer):
290
+ """
291
+ Trainer for lora models
292
+ """
293
+
294
+ def save_model(self, output_dir=None, _internal_call=False):
295
+ """Save the LoRA model."""
296
+ os.makedirs(output_dir, exist_ok=True)
297
+ torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
298
+ self.model.save_pretrained(output_dir)
299
+
300
+
301
+ def save_model(output_dir, model, tokenizer, args):
302
+ """Save the model and the tokenizer."""
303
+ os.makedirs(output_dir, exist_ok=True)
304
+
305
+ # Take care of distributed/parallel training
306
+ model_to_save = model.module if hasattr(model, "module") else model
307
+ model_to_save.save_pretrained(output_dir)
308
+ tokenizer.save_pretrained(output_dir)
309
+ torch.save(args, os.path.join(output_dir, TRAINING_ARGS_NAME))
310
+
311
+
312
+ def print_trainable_parameters(model):
313
+ """
314
+ Prints the number of trainable parameters in the model.
315
+ """
316
+ trainable_params = 0
317
+ all_param = 0
318
+ for _, param in model.named_parameters():
319
+ all_param += param.numel()
320
+ if param.requires_grad:
321
+ trainable_params += param.numel()
322
+ print(
323
+ f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
324
+ )
325
+
326
+
327
+ def find_all_linear_names(peft_model, int4=False, int8=False):
328
+ """Find all linear layer names in the model. reference from qlora paper."""
329
+ cls = torch.nn.Linear
330
+ if int4 or int8:
331
+ import bitsandbytes as bnb
332
+ if int4:
333
+ cls = bnb.nn.Linear4bit
334
+ elif int8:
335
+ cls = bnb.nn.Linear8bitLt
336
+ lora_module_names = set()
337
+ for name, module in peft_model.named_modules():
338
+ if isinstance(module, cls):
339
+ # last layer is not add to lora_module_names
340
+ if 'lm_head' in name:
341
+ continue
342
+ names = name.split('.')
343
+ lora_module_names.add(names[0] if len(names) == 1 else names[-1])
344
+ return sorted(lora_module_names)
345
+
346
+
347
+ def main():
348
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, PeftArguments))
349
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
350
+
351
+ logger.info(f"Model args: {model_args}")
352
+ logger.info(f"Data args: {data_args}")
353
+ logger.info(f"Training args: {training_args}")
354
+ logger.info(
355
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
356
+ + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
357
+ )
358
+
359
+ # Set seed before initializing model.
360
+ set_seed(training_args.seed)
361
+
362
+ # Load tokenizer
363
+ if not model_args.model_type:
364
+ raise ValueError("Please specify a model_type, e.g. llama, chatglm, bloom, etc.")
365
+ config_class, model_class, tokenizer_class = MODEL_CLASSES[model_args.model_type]
366
+
367
+ tokenizer_kwargs = {
368
+ "cache_dir": model_args.cache_dir,
369
+ "use_fast": model_args.use_fast_tokenizer,
370
+ "trust_remote_code": model_args.trust_remote_code,
371
+ }
372
+ tokenizer_name_or_path = model_args.tokenizer_name_or_path
373
+ if not tokenizer_name_or_path:
374
+ tokenizer_name_or_path = model_args.model_name_or_path
375
+ tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
376
+
377
+ # Preprocessing the datasets.
378
+ def tokenize_function(examples):
379
+ return tokenizer(examples["text"])
380
+
381
+ if data_args.block_size is None:
382
+ block_size = tokenizer.model_max_length
383
+ if block_size > 2048:
384
+ logger.warning(
385
+ "The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value"
386
+ " of 2048. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can"
387
+ " override this default with `--block_size xxx`."
388
+ )
389
+ else:
390
+ if data_args.block_size > tokenizer.model_max_length:
391
+ logger.warning(
392
+ f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
393
+ f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
394
+ )
395
+ block_size = min(data_args.block_size, tokenizer.model_max_length)
396
+
397
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
398
+ def group_texts(examples):
399
+ # Concatenate all texts.
400
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
401
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
402
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
403
+ # customize this part to your needs.
404
+ if total_length >= block_size:
405
+ total_length = (total_length // block_size) * block_size
406
+ # Split by chunks of max_len.
407
+ result = {
408
+ k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
409
+ for k, t in concatenated_examples.items()
410
+ }
411
+ result["labels"] = result["input_ids"].copy()
412
+ return result
413
+
414
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
415
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
416
+ # (the dataset will be downloaded automatically from the datasets Hub).
417
+ #
418
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
419
+ # 'text' is found. You can easily tweak this behavior (see below).
420
+ #
421
+ # In distributed training, the load_dataset function guarantee that only one local process can concurrently
422
+ # download the dataset.
423
+ if data_args.dataset_name is not None:
424
+ # Downloading and loading a dataset from the hub.
425
+ raw_datasets = load_dataset(
426
+ data_args.dataset_name,
427
+ data_args.dataset_config_name,
428
+ cache_dir=model_args.cache_dir,
429
+ streaming=data_args.streaming,
430
+ )
431
+ if "validation" not in raw_datasets.keys():
432
+ raw_datasets["validation"] = load_dataset(
433
+ data_args.dataset_name,
434
+ data_args.dataset_config_name,
435
+ split=f"train[:{data_args.validation_split_percentage}%]",
436
+ cache_dir=model_args.cache_dir,
437
+ streaming=data_args.streaming,
438
+ )
439
+ raw_datasets["train"] = load_dataset(
440
+ data_args.dataset_name,
441
+ data_args.dataset_config_name,
442
+ split=f"train[{data_args.validation_split_percentage}%:]",
443
+ cache_dir=model_args.cache_dir,
444
+ streaming=data_args.streaming,
445
+ )
446
+ else:
447
+ data_files = {}
448
+ dataset_args = {}
449
+ if data_args.train_file_dir is not None and os.path.exists(data_args.train_file_dir):
450
+ train_data_files = glob(f'{data_args.train_file_dir}/**/*.txt', recursive=True) + glob(
451
+ f'{data_args.train_file_dir}/**/*.json', recursive=True) + glob(
452
+ f'{data_args.train_file_dir}/**/*.jsonl', recursive=True)
453
+ logger.info(f"train files: {train_data_files}")
454
+ # Train data files must be same type, e.g. all txt or all jsonl
455
+ types = [f.split('.')[-1] for f in train_data_files]
456
+ if len(set(types)) > 1:
457
+ raise ValueError(f"train files must be same type, e.g. all txt or all jsonl, but got {types}")
458
+ data_files["train"] = train_data_files
459
+ if data_args.validation_file_dir is not None and os.path.exists(data_args.validation_file_dir):
460
+ eval_data_files = glob(f'{data_args.validation_file_dir}/**/*.txt', recursive=True) + glob(
461
+ f'{data_args.train_file_dir}/**/*.json', recursive=True) + glob(
462
+ f'{data_args.train_file_dir}/**/*.jsonl', recursive=True)
463
+ logger.info(f"eval files: {eval_data_files}")
464
+ data_files["validation"] = eval_data_files
465
+ # Train data files must be same type, e.g. all txt or all jsonl
466
+ types = [f.split('.')[-1] for f in eval_data_files]
467
+ if len(set(types)) > 1:
468
+ raise ValueError(f"train files must be same type, e.g. all txt or all jsonl, but got {types}")
469
+ extension = "text" if data_files["train"][0].endswith('txt') else 'json'
470
+ if extension == "text":
471
+ dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
472
+ raw_datasets = load_dataset(
473
+ extension,
474
+ data_files=data_files,
475
+ cache_dir=model_args.cache_dir,
476
+ **dataset_args,
477
+ )
478
+
479
+ # If no validation data is there, validation_split_percentage will be used to divide the dataset.
480
+ if "validation" not in raw_datasets.keys():
481
+ raw_datasets["validation"] = load_dataset(
482
+ extension,
483
+ data_files=data_files,
484
+ split=f"train[:{data_args.validation_split_percentage}%]",
485
+ cache_dir=model_args.cache_dir,
486
+ **dataset_args,
487
+ )
488
+ raw_datasets["train"] = load_dataset(
489
+ extension,
490
+ data_files=data_files,
491
+ split=f"train[{data_args.validation_split_percentage}%:]",
492
+ cache_dir=model_args.cache_dir,
493
+ **dataset_args,
494
+ )
495
+ logger.info(f"Raw datasets: {raw_datasets}")
496
+
497
+ # Preprocessing the datasets.
498
+ if training_args.do_train:
499
+ column_names = list(raw_datasets["train"].features)
500
+ else:
501
+ column_names = list(raw_datasets["validation"].features)
502
+
503
+ with training_args.main_process_first(desc="Dataset tokenization and grouping"):
504
+ if not data_args.streaming:
505
+ tokenized_datasets = raw_datasets.map(
506
+ tokenize_function,
507
+ batched=True,
508
+ num_proc=data_args.preprocessing_num_workers,
509
+ remove_columns=column_names,
510
+ load_from_cache_file=not data_args.overwrite_cache,
511
+ desc="Running tokenizer on dataset",
512
+ )
513
+ lm_datasets = tokenized_datasets.map(
514
+ group_texts,
515
+ batched=True,
516
+ num_proc=data_args.preprocessing_num_workers,
517
+ load_from_cache_file=not data_args.overwrite_cache,
518
+ desc=f"Grouping texts in chunks of {block_size}",
519
+ )
520
+ else:
521
+ tokenized_datasets = raw_datasets.map(
522
+ tokenize_function,
523
+ batched=True,
524
+ remove_columns=column_names,
525
+ )
526
+ lm_datasets = tokenized_datasets.map(
527
+ group_texts,
528
+ batched=True,
529
+ )
530
+
531
+ train_dataset = None
532
+ max_train_samples = 0
533
+ if training_args.do_train:
534
+ if "train" not in tokenized_datasets:
535
+ raise ValueError("--do_train requires a train dataset")
536
+ train_dataset = lm_datasets['train']
537
+ max_train_samples = len(train_dataset)
538
+ if data_args.max_train_samples is not None and data_args.max_train_samples > 0:
539
+ max_train_samples = min(len(train_dataset), data_args.max_train_samples)
540
+ train_dataset = train_dataset.select(range(max_train_samples))
541
+ logger.debug(f"Num train_samples: {len(train_dataset)}")
542
+ logger.debug("Tokenized training example:")
543
+ logger.debug(tokenizer.decode(train_dataset[0]['input_ids']))
544
+
545
+ eval_dataset = None
546
+ max_eval_samples = 0
547
+ if training_args.do_eval:
548
+ if "validation" not in tokenized_datasets:
549
+ raise ValueError("--do_eval requires a validation dataset")
550
+ eval_dataset = lm_datasets["validation"]
551
+ max_eval_samples = len(eval_dataset)
552
+ if data_args.max_eval_samples is not None and data_args.max_eval_samples > 0:
553
+ max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
554
+ eval_dataset = eval_dataset.select(range(max_eval_samples))
555
+ logger.debug(f"Num eval_samples: {len(eval_dataset)}")
556
+ logger.debug("Tokenized eval example:")
557
+ logger.debug(tokenizer.decode(eval_dataset[0]['input_ids']))
558
+
559
+ # Load model
560
+ if model_args.model_type and model_args.model_name_or_path:
561
+ torch_dtype = (
562
+ model_args.torch_dtype
563
+ if model_args.torch_dtype in ["auto", None]
564
+ else getattr(torch, model_args.torch_dtype)
565
+ )
566
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
567
+ ddp = world_size != 1
568
+ if ddp:
569
+ model_args.device_map = {"": int(os.environ["LOCAL_RANK"]) or 0}
570
+
571
+ config = config_class.from_pretrained(
572
+ model_args.model_name_or_path,
573
+ torch_dtype=torch_dtype,
574
+ trust_remote_code=model_args.trust_remote_code,
575
+ cache_dir=model_args.cache_dir
576
+ )
577
+ model = model_class.from_pretrained(
578
+ model_args.model_name_or_path,
579
+ config=config,
580
+ load_in_8bit=model_args.load_in_8bit,
581
+ device_map=model_args.device_map,
582
+ trust_remote_code=model_args.trust_remote_code,
583
+ )
584
+ else:
585
+ raise ValueError(f"Error, model_name_or_path is None, Continue PT must be loaded from a pre-trained model")
586
+
587
+ if training_args.use_peft:
588
+ if training_args.peft_path is not None:
589
+ logger.info(f"Peft from pre-trained model: {training_args.peft_path}")
590
+ model = PeftModel.from_pretrained(model, training_args.peft_path, is_trainable=True)
591
+ else:
592
+ logger.info("Init new peft model")
593
+ target_modules = training_args.target_modules.split(',') if training_args.target_modules else None
594
+ if target_modules and 'all' in target_modules:
595
+ target_modules = find_all_linear_names(model, int4=False, int8=model_args.load_in_8bit)
596
+ modules_to_save = training_args.modules_to_save
597
+ if modules_to_save is not None:
598
+ modules_to_save = modules_to_save.split(',')
599
+ logger.info(f"Peft target_modules: {target_modules}")
600
+ logger.info(f"Peft lora_rank: {training_args.lora_rank}")
601
+ peft_config = LoraConfig(
602
+ task_type=TaskType.CAUSAL_LM,
603
+ target_modules=target_modules,
604
+ inference_mode=False,
605
+ r=training_args.lora_rank,
606
+ lora_alpha=training_args.lora_alpha,
607
+ lora_dropout=training_args.lora_dropout,
608
+ modules_to_save=modules_to_save)
609
+ model = get_peft_model(model, peft_config)
610
+ if model_args.load_in_8bit:
611
+ model = prepare_model_for_int8_training(model)
612
+ model.print_trainable_parameters()
613
+ else:
614
+ logger.info("Full parameters training")
615
+ model = model.float()
616
+ print_trainable_parameters(model)
617
+
618
+ # Initialize our Trainer
619
+ if training_args.gradient_checkpointing:
620
+ model.gradient_checkpointing_enable()
621
+ model.config.use_cache = False
622
+ else:
623
+ model.config.use_cache = True
624
+ model.enable_input_require_grads()
625
+ if not ddp and torch.cuda.device_count() > 1:
626
+ # Keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
627
+ model.is_parallelizable = True
628
+ model.model_parallel = True
629
+
630
+ trainer = SavePeftModelTrainer(
631
+ model=model,
632
+ args=training_args,
633
+ train_dataset=train_dataset if training_args.do_train else None,
634
+ eval_dataset=eval_dataset if training_args.do_eval else None,
635
+ tokenizer=tokenizer,
636
+ data_collator=fault_tolerance_data_collator,
637
+ compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
638
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics
639
+ if training_args.do_eval and not is_torch_tpu_available()
640
+ else None,
641
+ )
642
+
643
+ # Training
644
+ if training_args.do_train:
645
+ logger.info("*** Train ***")
646
+ logger.debug(f"Train dataloader example: {next(iter(trainer.get_train_dataloader()))}")
647
+ checkpoint = None
648
+ if training_args.resume_from_checkpoint is not None:
649
+ checkpoint = training_args.resume_from_checkpoint
650
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
651
+
652
+ metrics = train_result.metrics
653
+ metrics["train_samples"] = max_train_samples
654
+ logger.debug(f"Training metrics: {metrics}")
655
+ trainer.log_metrics("train", metrics)
656
+ trainer.save_metrics("train", metrics)
657
+ trainer.save_state()
658
+ logger.info(f"Saving model checkpoint to {training_args.output_dir}")
659
+ save_model(training_args.output_dir, model, tokenizer, training_args)
660
+
661
+ # Evaluation
662
+ if training_args.do_eval and trainer.is_world_process_zero():
663
+ logger.info("*** Evaluate ***")
664
+ metrics = trainer.evaluate()
665
+
666
+ metrics["eval_samples"] = max_eval_samples
667
+ try:
668
+ perplexity = math.exp(metrics["eval_loss"])
669
+ except OverflowError:
670
+ perplexity = float("inf")
671
+ metrics["perplexity"] = perplexity
672
+ logger.debug(f"Eval metrics: {metrics}")
673
+ trainer.log_metrics("eval", metrics)
674
+ trainer.save_metrics("eval", metrics)
675
+
676
+
677
+ if __name__ == "__main__":
678
+ main()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ loguru
2
+ transformers>=4.30.1
3
+ sentencepiece
4
+ datasets
5
+ tqdm
6
+ tensorboard
7
+ tqdm>=4.47.0
8
+ peft>=0.5.0
9
+ accelerate>=0.20.3
10
+ trl>=0.6.0
reward_modeling.py ADDED
@@ -0,0 +1,643 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ @author:XuMing([email protected])
4
+ @description:
5
+ """
6
+
7
+ import math
8
+ import os
9
+ from dataclasses import dataclass, field
10
+ from glob import glob
11
+ from typing import Any, List, Union, Optional, Dict
12
+
13
+ import torch
14
+ from datasets import load_dataset
15
+ from loguru import logger
16
+ from peft import LoraConfig, TaskType, get_peft_model, PeftModel, prepare_model_for_int8_training
17
+ from sklearn.metrics import mean_squared_error, mean_absolute_error
18
+ from torch.utils.data import Dataset
19
+ from transformers import (
20
+ AutoConfig,
21
+ PreTrainedTokenizerBase,
22
+ BloomForSequenceClassification,
23
+ LlamaForSequenceClassification,
24
+ LlamaTokenizer,
25
+ BloomTokenizerFast,
26
+ AlbertForSequenceClassification,
27
+ BertForSequenceClassification,
28
+ BertTokenizer,
29
+ AutoTokenizer,
30
+ RobertaForSequenceClassification,
31
+ AutoModelForSequenceClassification,
32
+ RobertaTokenizer,
33
+ HfArgumentParser,
34
+ Trainer,
35
+ TrainingArguments,
36
+ set_seed,
37
+ )
38
+ from transformers.trainer import TRAINING_ARGS_NAME
39
+
40
+ MODEL_CLASSES = {
41
+ "bert": (AutoConfig, BertForSequenceClassification, BertTokenizer),
42
+ "roberta": (AutoConfig, RobertaForSequenceClassification, RobertaTokenizer),
43
+ "albert": (AutoConfig, AlbertForSequenceClassification, AutoTokenizer),
44
+ "bloom": (AutoConfig, BloomForSequenceClassification, BloomTokenizerFast),
45
+ "llama": (AutoConfig, LlamaForSequenceClassification, LlamaTokenizer),
46
+ "auto": (AutoConfig, AutoModelForSequenceClassification, AutoTokenizer),
47
+ }
48
+
49
+
50
+ @dataclass
51
+ class ModelArguments:
52
+ """
53
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
54
+ """
55
+
56
+ model_type: str = field(
57
+ default=None,
58
+ metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())}
59
+ )
60
+ model_name_or_path: Optional[str] = field(
61
+ default=None,
62
+ metadata={
63
+ "help": (
64
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
65
+ )
66
+ },
67
+ )
68
+ tokenizer_name_or_path: Optional[str] = field(
69
+ default=None,
70
+ metadata={
71
+ "help": (
72
+ "The tokenizer for weights initialization.Don't set if you want to train a model from scratch."
73
+ )
74
+ },
75
+ )
76
+ load_in_8bit: bool = field(default=False, metadata={"help": "Whether to load the model in 8bit mode or not."})
77
+ cache_dir: Optional[str] = field(
78
+ default=None,
79
+ metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
80
+ )
81
+ use_fast_tokenizer: bool = field(
82
+ default=False,
83
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
84
+ )
85
+ torch_dtype: Optional[str] = field(
86
+ default=None,
87
+ metadata={
88
+ "help": (
89
+ "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
90
+ "dtype will be automatically derived from the model's weights."
91
+ ),
92
+ "choices": ["auto", "bfloat16", "float16", "float32"],
93
+ },
94
+ )
95
+ device_map: Optional[str] = field(
96
+ default="auto",
97
+ metadata={"help": "Device to map model to. If `auto` is passed, the device will be selected automatically. "},
98
+ )
99
+ trust_remote_code: bool = field(
100
+ default=True,
101
+ metadata={"help": "Whether to trust remote code when loading a model from a remote checkpoint."},
102
+ )
103
+
104
+ def __post_init__(self):
105
+ if self.model_type is None:
106
+ raise ValueError(
107
+ "You must specify a valid model_type to run training. Available model types are " + ", ".join(
108
+ MODEL_CLASSES.keys()))
109
+ if self.model_name_or_path is None:
110
+ raise ValueError("You must specify a valid model_name_or_path to run training.")
111
+
112
+
113
+ @dataclass
114
+ class DataTrainingArguments:
115
+ """
116
+ Arguments pertaining to what data we are going to input our model for training and eval.
117
+ """
118
+
119
+ dataset_name: Optional[str] = field(
120
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
121
+ )
122
+ dataset_config_name: Optional[str] = field(
123
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
124
+ )
125
+ train_file_dir: Optional[str] = field(default=None, metadata={"help": "The input jsonl data file folder."})
126
+ validation_file_dir: Optional[str] = field(default=None, metadata={"help": "The evaluation jsonl file folder."}, )
127
+ max_source_length: Optional[int] = field(default=256, metadata={"help": "Max length of prompt input text"})
128
+ max_target_length: Optional[int] = field(default=256, metadata={"help": "Max length of output text"})
129
+ max_train_samples: Optional[int] = field(
130
+ default=None,
131
+ metadata={
132
+ "help": (
133
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
134
+ "value if set."
135
+ )
136
+ },
137
+ )
138
+ max_eval_samples: Optional[int] = field(
139
+ default=None,
140
+ metadata={
141
+ "help": (
142
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
143
+ "value if set."
144
+ )
145
+ },
146
+ )
147
+ overwrite_cache: bool = field(
148
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
149
+ )
150
+ validation_split_percentage: Optional[int] = field(
151
+ default=1,
152
+ metadata={
153
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
154
+ },
155
+ )
156
+ preprocessing_num_workers: Optional[int] = field(
157
+ default=4,
158
+ metadata={"help": "The number of processes to use for the preprocessing."},
159
+ )
160
+
161
+
162
+ @dataclass
163
+ class PeftArguments(TrainingArguments):
164
+ use_peft: bool = field(default=True, metadata={"help": "Whether to use peft"})
165
+ target_modules: Optional[str] = field(default="all")
166
+ lora_rank: Optional[int] = field(default=8)
167
+ lora_dropout: Optional[float] = field(default=0.05)
168
+ lora_alpha: Optional[float] = field(default=32.0)
169
+ modules_to_save: Optional[str] = field(default=None)
170
+ peft_path: Optional[str] = field(default=None)
171
+
172
+
173
+ def compute_metrics(eval_preds):
174
+ preds, labels = eval_preds
175
+ # Here, predictions is rewards_chosen and rewards_rejected.
176
+ if isinstance(preds, torch.Tensor):
177
+ preds = preds.detach().cpu().numpy()
178
+ if isinstance(labels, torch.Tensor):
179
+ labels = labels.detach().cpu().numpy()
180
+ # MSE
181
+ mse = mean_squared_error(labels, preds)
182
+ # MAE
183
+ mae = mean_absolute_error(labels, preds)
184
+
185
+ return {"mse": mse, "mae": mae}
186
+
187
+
188
+ @dataclass
189
+ class RewardDataCollatorWithPadding:
190
+ """We need to define a special data collator that batches the data in our chosen vs rejected format"""
191
+ tokenizer: PreTrainedTokenizerBase
192
+ padding: Union[bool, str] = True
193
+ max_length: Optional[int] = None
194
+ pad_to_multiple_of: Optional[int] = None
195
+ return_tensors: str = "pt"
196
+
197
+ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
198
+ features_chosen = []
199
+ features_rejected = []
200
+ for feature in features:
201
+ features_chosen.append(
202
+ {
203
+ "input_ids": feature["input_ids_chosen"],
204
+ "attention_mask": feature["attention_mask_chosen"],
205
+ }
206
+ )
207
+ features_rejected.append(
208
+ {
209
+ "input_ids": feature["input_ids_rejected"],
210
+ "attention_mask": feature["attention_mask_rejected"],
211
+ }
212
+ )
213
+ batch_chosen = self.tokenizer.pad(
214
+ features_chosen,
215
+ padding=self.padding,
216
+ max_length=self.max_length,
217
+ pad_to_multiple_of=self.pad_to_multiple_of,
218
+ return_tensors=self.return_tensors,
219
+ )
220
+ batch_rejected = self.tokenizer.pad(
221
+ features_rejected,
222
+ padding=self.padding,
223
+ max_length=self.max_length,
224
+ pad_to_multiple_of=self.pad_to_multiple_of,
225
+ return_tensors=self.return_tensors,
226
+ )
227
+ batch = {
228
+ "input_ids_chosen": batch_chosen["input_ids"],
229
+ "attention_mask_chosen": batch_chosen["attention_mask"],
230
+ "input_ids_rejected": batch_rejected["input_ids"],
231
+ "attention_mask_rejected": batch_rejected["attention_mask"],
232
+ "return_loss": True,
233
+ }
234
+ return batch
235
+
236
+
237
+ class RewardTrainer(Trainer):
238
+ """
239
+ Trainer for reward models
240
+ Define how to compute the reward loss. Use the InstructGPT pairwise logloss: https://arxiv.org/abs/2203.02155
241
+ """
242
+
243
+ def compute_loss(self, model, inputs, return_outputs=False):
244
+ rewards_chosen = model(input_ids=inputs["input_ids_chosen"],
245
+ attention_mask=inputs["attention_mask_chosen"])[0]
246
+ rewards_rejected = model(input_ids=inputs["input_ids_rejected"],
247
+ attention_mask=inputs["attention_mask_rejected"])[0]
248
+ loss = -torch.nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
249
+ if return_outputs:
250
+ return loss, {"rewards_chosen": rewards_chosen, "rewards_rejected": rewards_rejected}
251
+ return loss
252
+
253
+ def evaluate(
254
+ self,
255
+ eval_dataset: Optional[Dataset] = None,
256
+ ignore_keys: Optional[List[str]] = None,
257
+ metric_key_prefix: str = "eval",
258
+ ) -> Dict[str, float]:
259
+ if eval_dataset is None:
260
+ eval_dataset = self.eval_dataset
261
+ return super().evaluate(eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
262
+
263
+ def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
264
+ # Prepare inputs for chosen and rejected separately
265
+ device = model.device
266
+
267
+ inputs_chosen = {
268
+ "input_ids": inputs["input_ids_chosen"].to(device),
269
+ "attention_mask": inputs["attention_mask_chosen"].to(device),
270
+ }
271
+ outputs_chosen = model(**inputs_chosen)
272
+ rewards_chosen = outputs_chosen.logits.detach()
273
+
274
+ inputs_rejected = {
275
+ "input_ids": inputs["input_ids_rejected"].to(device),
276
+ "attention_mask": inputs["attention_mask_rejected"].to(device),
277
+ }
278
+ outputs_rejected = model(**inputs_rejected)
279
+ rewards_rejected = outputs_rejected.logits.detach()
280
+
281
+ # Keep the compute_loss method
282
+ loss = -torch.nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
283
+ if prediction_loss_only:
284
+ return (loss, None, None)
285
+
286
+ return (loss, rewards_chosen, rewards_rejected)
287
+
288
+ def save_model(self, output_dir=None, _internal_call=False):
289
+ """Save the LoRA model."""
290
+ os.makedirs(output_dir, exist_ok=True)
291
+ torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
292
+ self.model.save_pretrained(output_dir)
293
+
294
+
295
+ def save_model(output_dir, model, tokenizer, args):
296
+ """Save the model and the tokenizer."""
297
+ os.makedirs(output_dir, exist_ok=True)
298
+
299
+ # Take care of distributed/parallel training
300
+ model_to_save = model.module if hasattr(model, "module") else model
301
+ model_to_save.save_pretrained(output_dir)
302
+ tokenizer.save_pretrained(output_dir)
303
+ torch.save(args, os.path.join(output_dir, TRAINING_ARGS_NAME))
304
+
305
+
306
+ class CastOutputToFloat(torch.nn.Sequential):
307
+ """Cast the output of the model to float"""
308
+
309
+ def forward(self, x):
310
+ return super().forward(x).to(torch.float32)
311
+
312
+
313
+ def print_trainable_parameters(model):
314
+ """
315
+ Prints the number of trainable parameters in the model.
316
+ """
317
+ trainable_params = 0
318
+ all_param = 0
319
+ for _, param in model.named_parameters():
320
+ all_param += param.numel()
321
+ if param.requires_grad:
322
+ trainable_params += param.numel()
323
+ print(
324
+ f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
325
+ )
326
+
327
+
328
+ def find_all_linear_names(peft_model, int4=False, int8=False):
329
+ cls = torch.nn.Linear
330
+ if int4 or int8:
331
+ import bitsandbytes as bnb
332
+ if int4:
333
+ cls = bnb.nn.Linear4bit
334
+ elif int8:
335
+ cls = bnb.nn.Linear8bitLt
336
+ lora_module_names = set()
337
+ for name, module in peft_model.named_modules():
338
+ if isinstance(module, cls):
339
+ # last layer is not add to lora_module_names
340
+ if 'lm_head' in name:
341
+ continue
342
+ if 'score' in name:
343
+ continue
344
+ names = name.split('.')
345
+ lora_module_names.add(names[0] if len(names) == 1 else names[-1])
346
+ return sorted(lora_module_names)
347
+
348
+
349
+ def main():
350
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, PeftArguments))
351
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
352
+
353
+ logger.info(f"Model args: {model_args}")
354
+ logger.info(f"Data args: {data_args}")
355
+ logger.info(f"Training args: {training_args}")
356
+ logger.info(
357
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
358
+ + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
359
+ )
360
+
361
+ # Set seed before initializing model.
362
+ set_seed(training_args.seed)
363
+
364
+ # Load model
365
+ if not model_args.model_type:
366
+ raise ValueError("Please specify a model_type, e.g. llama, chatglm, bloom, etc.")
367
+ config_class, model_class, tokenizer_class = MODEL_CLASSES[model_args.model_type]
368
+ if model_args.model_name_or_path:
369
+ torch_dtype = (
370
+ model_args.torch_dtype
371
+ if model_args.torch_dtype in ["auto", None]
372
+ else getattr(torch, model_args.torch_dtype)
373
+ )
374
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
375
+ if world_size > 1:
376
+ model_args.device_map = {"": int(os.environ["LOCAL_RANK"]) or 0}
377
+ config = config_class.from_pretrained(
378
+ model_args.model_name_or_path,
379
+ num_labels=1,
380
+ torch_dtype=torch_dtype,
381
+ trust_remote_code=model_args.trust_remote_code,
382
+ cache_dir=model_args.cache_dir
383
+ )
384
+ if model_args.model_type in ['bloom', 'llama']:
385
+ model = model_class.from_pretrained(
386
+ model_args.model_name_or_path,
387
+ config=config,
388
+ load_in_8bit=model_args.load_in_8bit,
389
+ device_map=model_args.device_map,
390
+ trust_remote_code=model_args.trust_remote_code,
391
+ )
392
+ model.score = CastOutputToFloat(model.score)
393
+ else:
394
+ model = model_class.from_pretrained(
395
+ model_args.model_name_or_path,
396
+ config=config,
397
+ cache_dir=model_args.cache_dir,
398
+ ignore_mismatched_sizes=True
399
+ )
400
+ model.to(training_args.device)
401
+ else:
402
+ raise ValueError(f"Error, model_name_or_path is None, RM must be loaded from a pre-trained model")
403
+
404
+ # Load tokenizer
405
+ if model_args.model_type == "bloom":
406
+ model_args.use_fast_tokenizer = True
407
+ tokenizer_kwargs = {
408
+ "cache_dir": model_args.cache_dir,
409
+ "use_fast": model_args.use_fast_tokenizer,
410
+ "trust_remote_code": model_args.trust_remote_code,
411
+ }
412
+ tokenizer_name_or_path = model_args.tokenizer_name_or_path
413
+ if not tokenizer_name_or_path:
414
+ tokenizer_name_or_path = model_args.model_name_or_path
415
+ tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
416
+ if tokenizer.pad_token_id is None:
417
+ tokenizer.pad_token_id = 0
418
+
419
+ if training_args.use_peft:
420
+ if training_args.peft_path is not None:
421
+ logger.info(f"Peft from pre-trained model: {training_args.peft_path}")
422
+ model = PeftModel.from_pretrained(model, training_args.peft_path, is_trainable=True)
423
+ else:
424
+ logger.info("Init new peft model")
425
+ target_modules = training_args.target_modules.split(',') if training_args.target_modules else None
426
+ if target_modules and 'all' in target_modules:
427
+ target_modules = find_all_linear_names(model, int4=False, int8=model_args.load_in_8bit)
428
+ modules_to_save = training_args.modules_to_save
429
+ if modules_to_save is not None:
430
+ modules_to_save = modules_to_save.split(',')
431
+ logger.info(f"Peft target_modules: {target_modules}")
432
+ logger.info(f"Peft lora_rank: {training_args.lora_rank}")
433
+ peft_config = LoraConfig(
434
+ task_type=TaskType.SEQ_CLS,
435
+ target_modules=target_modules,
436
+ inference_mode=False,
437
+ r=training_args.lora_rank,
438
+ lora_alpha=training_args.lora_alpha,
439
+ lora_dropout=training_args.lora_dropout,
440
+ modules_to_save=modules_to_save)
441
+ model = get_peft_model(model, peft_config)
442
+ if model_args.load_in_8bit:
443
+ model = prepare_model_for_int8_training(model)
444
+ model.print_trainable_parameters()
445
+ else:
446
+ logger.info("Full parameters training")
447
+ print_trainable_parameters(model)
448
+
449
+ # Get reward dataset for tuning the reward model.
450
+ if data_args.dataset_name is not None:
451
+ # Downloading and loading a dataset from the hub.
452
+ raw_datasets = load_dataset(
453
+ data_args.dataset_name,
454
+ data_args.dataset_config_name,
455
+ cache_dir=model_args.cache_dir,
456
+ )
457
+ if "validation" not in raw_datasets.keys():
458
+ raw_datasets["validation"] = load_dataset(
459
+ data_args.dataset_name,
460
+ data_args.dataset_config_name,
461
+ split=f"train[:{data_args.validation_split_percentage}%]",
462
+ cache_dir=model_args.cache_dir,
463
+ )
464
+ raw_datasets["train"] = load_dataset(
465
+ data_args.dataset_name,
466
+ data_args.dataset_config_name,
467
+ split=f"train[{data_args.validation_split_percentage}%:]",
468
+ cache_dir=model_args.cache_dir,
469
+ )
470
+ else:
471
+ data_files = {}
472
+ if data_args.train_file_dir is not None and os.path.exists(data_args.train_file_dir):
473
+ train_data_files = glob(f'{data_args.train_file_dir}/**/*.json', recursive=True) + glob(
474
+ f'{data_args.train_file_dir}/**/*.jsonl', recursive=True)
475
+ logger.info(f"train files: {', '.join(train_data_files)}")
476
+ data_files["train"] = train_data_files
477
+ if data_args.validation_file_dir is not None and os.path.exists(data_args.validation_file_dir):
478
+ eval_data_files = glob(f'{data_args.validation_file_dir}/**/*.json', recursive=True) + glob(
479
+ f'{data_args.validation_file_dir}/**/*.jsonl', recursive=True)
480
+ logger.info(f"eval files: {', '.join(eval_data_files)}")
481
+ data_files["validation"] = eval_data_files
482
+ raw_datasets = load_dataset(
483
+ 'json',
484
+ data_files=data_files,
485
+ cache_dir=model_args.cache_dir,
486
+ )
487
+ # If no validation data is there, validation_split_percentage will be used to divide the dataset.
488
+ if "validation" not in raw_datasets.keys():
489
+ raw_datasets["validation"] = load_dataset(
490
+ 'json',
491
+ data_files=data_files,
492
+ split=f"train[:{data_args.validation_split_percentage}%]",
493
+ cache_dir=model_args.cache_dir,
494
+ )
495
+ raw_datasets["train"] = load_dataset(
496
+ 'json',
497
+ data_files=data_files,
498
+ split=f"train[{data_args.validation_split_percentage}%:]",
499
+ cache_dir=model_args.cache_dir,
500
+ )
501
+ logger.info(f"Raw datasets: {raw_datasets}")
502
+
503
+ # Preprocessing the datasets
504
+ full_max_length = data_args.max_source_length + data_args.max_target_length
505
+
506
+ def preprocess_reward_function(examples):
507
+ """
508
+ Turn the dataset into pairs of Question + Answer, where input_ids_chosen is the preferred question + answer
509
+ and text_rejected is the other.
510
+ """
511
+ new_examples = {
512
+ "input_ids_chosen": [],
513
+ "attention_mask_chosen": [],
514
+ "input_ids_rejected": [],
515
+ "attention_mask_rejected": [],
516
+ }
517
+ for question, chosen, rejected in zip(examples["question"], examples["response_chosen"],
518
+ examples["response_rejected"]):
519
+ tokenized_chosen = tokenizer("Question: " + question + "\n\nAnswer: " + chosen)
520
+ tokenized_rejected = tokenizer("Question: " + question + "\n\nAnswer: " + rejected)
521
+
522
+ new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
523
+ new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
524
+ new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
525
+ new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
526
+
527
+ return new_examples
528
+
529
+ train_dataset = None
530
+ max_train_samples = 0
531
+ if training_args.do_train:
532
+ if "train" not in raw_datasets:
533
+ raise ValueError("--do_train requires a train dataset")
534
+ train_dataset = raw_datasets['train']
535
+ max_train_samples = len(train_dataset)
536
+ if data_args.max_train_samples is not None and data_args.max_train_samples > 0:
537
+ max_train_samples = min(len(train_dataset), data_args.max_train_samples)
538
+ train_dataset = train_dataset.select(range(max_train_samples))
539
+ logger.debug(f"Example train_dataset[0]: {train_dataset[0]}")
540
+ with training_args.main_process_first(desc="Train dataset tokenization"):
541
+ tokenized_dataset = train_dataset.shuffle().map(
542
+ preprocess_reward_function,
543
+ batched=True,
544
+ num_proc=data_args.preprocessing_num_workers,
545
+ remove_columns=train_dataset.column_names,
546
+ load_from_cache_file=not data_args.overwrite_cache,
547
+ desc="Running tokenizer on dataset",
548
+ )
549
+ train_dataset = tokenized_dataset.filter(
550
+ lambda x: 0 < len(x['input_ids_rejected']) <= full_max_length and 0 < len(
551
+ x['input_ids_chosen']) <= full_max_length
552
+ )
553
+ logger.debug(f"Num train_samples: {len(train_dataset)}")
554
+ logger.debug("Tokenized training example:")
555
+ logger.debug(tokenizer.decode(train_dataset[0]['input_ids_chosen']))
556
+
557
+ eval_dataset = None
558
+ max_eval_samples = 0
559
+ if training_args.do_eval:
560
+ with training_args.main_process_first(desc="Eval dataset tokenization"):
561
+ if "validation" not in raw_datasets:
562
+ raise ValueError("--do_eval requires a validation dataset")
563
+ eval_dataset = raw_datasets["validation"]
564
+ max_eval_samples = len(eval_dataset)
565
+ if data_args.max_eval_samples is not None and data_args.max_eval_samples > 0:
566
+ max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
567
+ eval_dataset = eval_dataset.select(range(max_eval_samples))
568
+ logger.debug(f"Example eval_dataset[0]: {eval_dataset[0]}")
569
+ tokenized_dataset = eval_dataset.map(
570
+ preprocess_reward_function,
571
+ batched=True,
572
+ num_proc=data_args.preprocessing_num_workers,
573
+ remove_columns=eval_dataset.column_names,
574
+ load_from_cache_file=not data_args.overwrite_cache,
575
+ desc="Running tokenizer on dataset",
576
+ )
577
+ eval_dataset = tokenized_dataset.filter(
578
+ lambda x: 0 < len(x['input_ids_rejected']) <= full_max_length and 0 < len(
579
+ x['input_ids_chosen']) <= full_max_length
580
+ )
581
+ logger.debug(f"Num eval_samples: {len(eval_dataset)}")
582
+ logger.debug("Tokenized eval example:")
583
+ logger.debug(tokenizer.decode(eval_dataset[0]['input_ids_chosen']))
584
+
585
+ # Initialize our Trainer
586
+ if training_args.gradient_checkpointing:
587
+ model.gradient_checkpointing_enable()
588
+ model.config.use_cache = False
589
+ else:
590
+ model.config.use_cache = True
591
+ model.enable_input_require_grads()
592
+ if torch.cuda.device_count() > 1:
593
+ # Keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
594
+ model.is_parallelizable = True
595
+ model.model_parallel = True
596
+ trainer = RewardTrainer(
597
+ model=model,
598
+ args=training_args,
599
+ train_dataset=train_dataset if training_args.do_train else None,
600
+ eval_dataset=eval_dataset if training_args.do_eval else None,
601
+ tokenizer=tokenizer,
602
+ compute_metrics=compute_metrics,
603
+ data_collator=RewardDataCollatorWithPadding(
604
+ tokenizer=tokenizer, max_length=full_max_length, padding="max_length"
605
+ ),
606
+ )
607
+
608
+ # Training
609
+ if training_args.do_train:
610
+ logger.info("*** Train ***")
611
+ logger.debug(f"Train dataloader example: {next(iter(trainer.get_train_dataloader()))}")
612
+ checkpoint = None
613
+ if training_args.resume_from_checkpoint is not None:
614
+ checkpoint = training_args.resume_from_checkpoint
615
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
616
+
617
+ metrics = train_result.metrics
618
+ metrics["train_samples"] = max_train_samples
619
+ logger.debug(f"Training metrics: {metrics}")
620
+ trainer.log_metrics("train", metrics)
621
+ trainer.save_metrics("train", metrics)
622
+ trainer.save_state()
623
+ logger.info(f"Saving model checkpoint to {training_args.output_dir}")
624
+ save_model(training_args.output_dir, model, tokenizer, training_args)
625
+
626
+ # Evaluation
627
+ if training_args.do_eval and trainer.is_world_process_zero():
628
+ logger.info("*** Evaluate ***")
629
+ metrics = trainer.evaluate()
630
+
631
+ metrics["eval_samples"] = max_eval_samples
632
+ try:
633
+ perplexity = math.exp(metrics["eval_loss"])
634
+ except OverflowError:
635
+ perplexity = float("inf")
636
+ metrics["perplexity"] = perplexity
637
+ logger.debug(f"Eval metrics: {metrics}")
638
+ trainer.log_metrics("eval", metrics)
639
+ trainer.save_metrics("eval", metrics)
640
+
641
+
642
+ if __name__ == "__main__":
643
+ main()
rl_training.py ADDED
@@ -0,0 +1,499 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ @author:XuMing([email protected])
4
+ @description: Train a model from SFT using PPO
5
+ """
6
+
7
+ import os
8
+ from dataclasses import dataclass, field
9
+ from glob import glob
10
+ from typing import Optional
11
+
12
+ import torch
13
+ from datasets import load_dataset
14
+ from loguru import logger
15
+ from peft import LoraConfig, TaskType
16
+ from tqdm import tqdm
17
+ from transformers import (
18
+ AutoConfig,
19
+ AutoModelForSequenceClassification,
20
+ BloomForCausalLM,
21
+ AutoModelForCausalLM,
22
+ AutoModel,
23
+ LlamaTokenizer,
24
+ LlamaForCausalLM,
25
+ BloomTokenizerFast,
26
+ AutoTokenizer,
27
+ HfArgumentParser,
28
+ )
29
+ from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
30
+
31
+ from supervised_finetuning import get_conv_template
32
+
33
+ os.environ["TOKENIZERS_PARALLELISM"] = "FALSE"
34
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
35
+
36
+ MODEL_CLASSES = {
37
+ "bloom": (AutoConfig, BloomForCausalLM, BloomTokenizerFast),
38
+ "chatglm": (AutoConfig, AutoModel, AutoTokenizer),
39
+ "llama": (AutoConfig, LlamaForCausalLM, LlamaTokenizer),
40
+ "baichuan": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
41
+ "auto": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
42
+ }
43
+
44
+
45
+ @dataclass
46
+ class ScriptArguments:
47
+ """
48
+ The name of the Casual LM model we wish to fine with PPO
49
+ """
50
+ # Model arguments
51
+ model_type: str = field(
52
+ default=None,
53
+ metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())}
54
+ )
55
+ model_name_or_path: Optional[str] = field(
56
+ default=None, metadata={"help": "The model checkpoint for weights initialization."}
57
+ )
58
+ reward_model_name_or_path: Optional[str] = field(default=None, metadata={"help": "The reward model name"})
59
+ tokenizer_name_or_path: Optional[str] = field(
60
+ default=None, metadata={"help": "The tokenizer for weights initialization."}
61
+ )
62
+ load_in_8bit: bool = field(default=False, metadata={"help": "Whether to load the model in 8bit mode or not."})
63
+ cache_dir: Optional[str] = field(
64
+ default=None,
65
+ metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
66
+ )
67
+ use_fast_tokenizer: bool = field(
68
+ default=False,
69
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
70
+ )
71
+ torch_dtype: Optional[str] = field(
72
+ default=None,
73
+ metadata={
74
+ "help": (
75
+ "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
76
+ "dtype will be automatically derived from the model's weights."
77
+ ),
78
+ "choices": ["auto", "bfloat16", "float16", "float32"],
79
+ },
80
+ )
81
+ device_map: Optional[str] = field(
82
+ default="auto",
83
+ metadata={"help": "Device to map model to. If `auto` is passed, the device will be selected automatically. "},
84
+ )
85
+ trust_remote_code: bool = field(
86
+ default=True,
87
+ metadata={"help": "Whether to trust remote code when loading a model from a remote checkpoint."},
88
+ )
89
+ # Dataset arguments
90
+ dataset_name: Optional[str] = field(
91
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
92
+ )
93
+ dataset_config_name: Optional[str] = field(
94
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
95
+ )
96
+ train_file_dir: Optional[str] = field(default=None, metadata={"help": "The input jsonl data file folder."})
97
+ validation_file_dir: Optional[str] = field(default=None, metadata={"help": "The evaluation jsonl file folder."}, )
98
+ template_name: Optional[str] = field(default="vicuna", metadata={"help": "The template name."})
99
+ batch_size: Optional[int] = field(default=8, metadata={"help": "Batch size"})
100
+ mini_batch_size: Optional[int] = field(default=1, metadata={"help": "PPO minibatch size"})
101
+ max_source_length: Optional[int] = field(default=256, metadata={"help": "Max length of prompt input text"})
102
+ max_target_length: Optional[int] = field(default=256, metadata={"help": "Max length of output text"})
103
+ min_target_length: Optional[int] = field(default=4, metadata={"help": "Min length of output text"})
104
+ max_train_samples: Optional[int] = field(
105
+ default=None,
106
+ metadata={
107
+ "help": (
108
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
109
+ "value if set."
110
+ )
111
+ },
112
+ )
113
+ max_eval_samples: Optional[int] = field(
114
+ default=None,
115
+ metadata={
116
+ "help": (
117
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
118
+ "value if set."
119
+ )
120
+ },
121
+ )
122
+ overwrite_cache: bool = field(
123
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
124
+ )
125
+ validation_split_percentage: Optional[int] = field(
126
+ default=1,
127
+ metadata={
128
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
129
+ },
130
+ )
131
+ preprocessing_num_workers: Optional[int] = field(
132
+ default=None, metadata={"help": "The number of processes to use for the preprocessing."},
133
+ )
134
+ # Training arguments
135
+ use_peft: bool = field(default=True, metadata={"help": "Whether to use peft"})
136
+ target_modules: Optional[str] = field(default=None)
137
+ lora_rank: Optional[int] = field(default=8)
138
+ lora_dropout: Optional[float] = field(default=0.05)
139
+ lora_alpha: Optional[float] = field(default=32.0)
140
+ modules_to_save: Optional[str] = field(default=None)
141
+ peft_path: Optional[str] = field(default=None)
142
+
143
+ do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
144
+ do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the validation set."})
145
+ early_stopping: Optional[bool] = field(default=False, metadata={"help": "Whether to early stop"})
146
+ target_kl: Optional[float] = field(default=0.1, metadata={"help": "The kl target for early stopping"})
147
+ reward_baseline: Optional[float] = field(
148
+ default=0.0, metadata={"help": "Baseline value that is subtracted from the reward"},
149
+ )
150
+ init_kl_coef: Optional[float] = field(
151
+ default=0.2, metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"},
152
+ )
153
+ adap_kl_ctrl: Optional[bool] = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"})
154
+ learning_rate: Optional[float] = field(default=1.5e-5, metadata={"help": "Learning rate"})
155
+ gradient_accumulation_steps: Optional[int] = field(
156
+ default=1, metadata={"help": "the number of gradient accumulation steps"}
157
+ )
158
+ save_steps: Optional[int] = field(default=50, metadata={"help": "X steps to save the model"})
159
+ output_dir: Optional[str] = field(default="outputs-rl", metadata={"help": "The output directory"})
160
+ seed: Optional[int] = field(default=0, metadata={"help": "Seed"})
161
+ max_steps: Optional[int] = field(default=200, metadata={"help": "Number of steps to train"})
162
+ report_to: Optional[str] = field(default="tensorboard", metadata={"help": "Report to wandb or tensorboard"})
163
+
164
+ def __post_init__(self):
165
+ if self.model_type is None:
166
+ raise ValueError("You must specify a valid model_type to run training.")
167
+ if self.model_name_or_path is None:
168
+ raise ValueError("You must specify a valid model_name_or_path to run training.")
169
+ if self.reward_model_name_or_path is None:
170
+ raise ValueError("You must specify a valid reward_model_name_or_path to run training.")
171
+
172
+
173
+ def print_trainable_parameters(model):
174
+ """
175
+ Prints the number of trainable parameters in the model.
176
+ """
177
+ trainable_params = 0
178
+ all_param = 0
179
+ for _, param in model.named_parameters():
180
+ all_param += param.numel()
181
+ if param.requires_grad:
182
+ trainable_params += param.numel()
183
+ print(
184
+ f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
185
+ )
186
+
187
+
188
+ def get_reward_model_output(reward_model, reward_tokenizer, question, answer, device):
189
+ """
190
+ Get the reward score for a given question and answer pair.
191
+ """
192
+ inputs = reward_tokenizer(question, answer, return_tensors='pt').to(device)
193
+ score = reward_model(**inputs).logits[0].cpu().detach()
194
+
195
+ return score
196
+
197
+
198
+ def calculate_rewards(reward_score_outputs, reward_baseline=0):
199
+ """
200
+ Calculate the reward for a given score output.
201
+ :param reward_score_outputs:
202
+ :param reward_baseline:
203
+ :return:
204
+ """
205
+ rewards = []
206
+ for score in reward_score_outputs:
207
+ if isinstance(score, torch.Tensor) and score.numel() == 1:
208
+ reward_value = score.item() - reward_baseline
209
+ rewards.append(torch.tensor(reward_value))
210
+ else:
211
+ # Use the average of the tensor elements as `score` is multiple elements
212
+ reward_value = torch.mean(score).item() - reward_baseline
213
+ rewards.append(torch.tensor(reward_value))
214
+ return rewards
215
+
216
+
217
+ def main():
218
+ parser = HfArgumentParser(ScriptArguments)
219
+ args = parser.parse_args_into_dataclasses()[0]
220
+ logger.info(f"Parse args: {args}")
221
+
222
+ config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
223
+ if args.model_type == 'bloom':
224
+ args.use_fast_tokenizer = True
225
+ # Load tokenizer
226
+ tokenizer_kwargs = {
227
+ "cache_dir": args.cache_dir,
228
+ "use_fast": args.use_fast_tokenizer,
229
+ "trust_remote_code": args.trust_remote_code,
230
+ }
231
+ tokenizer_name_or_path = args.tokenizer_name_or_path
232
+ if not tokenizer_name_or_path:
233
+ tokenizer_name_or_path = args.model_name_or_path
234
+ tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
235
+ if tokenizer.pad_token_id is None:
236
+ tokenizer.pad_token_id = 0 # set as the <unk> token
237
+
238
+ logger.info("Load model")
239
+ peft_config = LoraConfig(
240
+ task_type=TaskType.CAUSAL_LM,
241
+ target_modules=args.target_modules,
242
+ inference_mode=False,
243
+ r=args.lora_rank,
244
+ lora_alpha=args.lora_alpha,
245
+ lora_dropout=args.lora_dropout,
246
+ )
247
+ torch_dtype = (
248
+ args.torch_dtype
249
+ if args.torch_dtype in ["auto", None]
250
+ else getattr(torch, args.torch_dtype)
251
+ )
252
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
253
+ if world_size > 1:
254
+ args.device_map = {"": int(os.environ["LOCAL_RANK"]) or 0}
255
+ config = config_class.from_pretrained(
256
+ args.model_name_or_path,
257
+ torch_dtype=torch_dtype,
258
+ trust_remote_code=args.trust_remote_code,
259
+ cache_dir=args.cache_dir
260
+ )
261
+ model = AutoModelForCausalLMWithValueHead.from_pretrained(
262
+ args.model_name_or_path,
263
+ config=config,
264
+ load_in_8bit=args.load_in_8bit,
265
+ device_map=args.device_map,
266
+ trust_remote_code=args.trust_remote_code,
267
+ peft_config=peft_config if args.use_peft else None,
268
+ )
269
+ print_trainable_parameters(model)
270
+ # Load reward model
271
+ device = "cuda" if torch.cuda.is_available() else "cpu"
272
+ reward_model = AutoModelForSequenceClassification.from_pretrained(
273
+ args.reward_model_name_or_path,
274
+ config=config,
275
+ load_in_8bit=args.load_in_8bit,
276
+ trust_remote_code=args.trust_remote_code,
277
+ )
278
+ reward_model.to(device)
279
+ reward_tokenizer = AutoTokenizer.from_pretrained(
280
+ args.reward_model_name_or_path, **tokenizer_kwargs
281
+ )
282
+
283
+ # Get datasets
284
+ if args.dataset_name is not None:
285
+ # Downloading and loading a dataset from the hub.
286
+ raw_datasets = load_dataset(
287
+ args.dataset_name,
288
+ args.dataset_config_name,
289
+ cache_dir=args.cache_dir,
290
+ )
291
+ if "validation" not in raw_datasets.keys():
292
+ raw_datasets["validation"] = load_dataset(
293
+ args.dataset_name,
294
+ args.dataset_config_name,
295
+ split=f"train[:{args.validation_split_percentage}%]",
296
+ cache_dir=args.cache_dir,
297
+ )
298
+ raw_datasets["train"] = load_dataset(
299
+ args.dataset_name,
300
+ args.dataset_config_name,
301
+ split=f"train[{args.validation_split_percentage}%:]",
302
+ cache_dir=args.cache_dir,
303
+ )
304
+ else:
305
+ data_files = {}
306
+ if args.train_file_dir is not None and os.path.exists(args.train_file_dir):
307
+ train_data_files = glob(f'{args.train_file_dir}/**/*.json', recursive=True) + glob(
308
+ f'{args.train_file_dir}/**/*.jsonl', recursive=True)
309
+ logger.info(f"train files: {', '.join(train_data_files)}")
310
+ data_files["train"] = train_data_files
311
+ if args.validation_file_dir is not None and os.path.exists(args.validation_file_dir):
312
+ eval_data_files = glob(f'{args.validation_file_dir}/**/*.json', recursive=True) + glob(
313
+ f'{args.validation_file_dir}/**/*.jsonl', recursive=True)
314
+ logger.info(f"eval files: {', '.join(eval_data_files)}")
315
+ data_files["validation"] = eval_data_files
316
+ raw_datasets = load_dataset(
317
+ 'json',
318
+ data_files=data_files,
319
+ cache_dir=args.cache_dir,
320
+ )
321
+ # If no validation data is there, validation_split_percentage will be used to divide the dataset.
322
+ if "validation" not in raw_datasets.keys():
323
+ raw_datasets["validation"] = load_dataset(
324
+ 'json',
325
+ data_files=data_files,
326
+ split=f"train[:{args.validation_split_percentage}%]",
327
+ cache_dir=args.cache_dir,
328
+ )
329
+ raw_datasets["train"] = load_dataset(
330
+ 'json',
331
+ data_files=data_files,
332
+ split=f"train[{args.validation_split_percentage}%:]",
333
+ cache_dir=args.cache_dir,
334
+ )
335
+ logger.info(f"Raw datasets: {raw_datasets}")
336
+
337
+ # Preprocessing the datasets
338
+ max_source_length = args.max_source_length
339
+ max_target_length = args.max_target_length
340
+ prompt_template = get_conv_template(args.template_name)
341
+
342
+ def preprocess_function(examples):
343
+ new_examples = {
344
+ "query": [],
345
+ "input_ids": [],
346
+ }
347
+ roles = ["human", "gpt"]
348
+
349
+ def get_prompt(examples):
350
+ for i, source in enumerate(examples['conversations']):
351
+ if len(source) < 2:
352
+ continue
353
+ data_role = source[0].get("from", "")
354
+ if data_role not in roles or data_role != roles[0]:
355
+ # Skip the first one if it is not from human
356
+ source = source[1:]
357
+ if len(source) < 2:
358
+ continue
359
+ messages = []
360
+ for j, sentence in enumerate(source):
361
+ data_role = sentence.get("from", "")
362
+ if data_role not in roles:
363
+ logger.warning(f"unknown role: {data_role}, {i}. (ignored)")
364
+ break
365
+ if data_role == roles[j % 2]:
366
+ messages.append(sentence["value"])
367
+ if len(messages) < 2 or len(messages) % 2 != 0:
368
+ continue
369
+ # Convert the list to pairs of elements
370
+ history_messages = [[messages[k], messages[k + 1]] for k in range(0, len(messages), 2)]
371
+ yield prompt_template.get_prompt(history_messages)
372
+
373
+ for prompt in get_prompt(examples):
374
+ for i in range(len(prompt) // 2):
375
+ source_txt = prompt[2 * i]
376
+ tokenized_question = tokenizer(
377
+ source_txt, truncation=True, max_length=max_source_length, padding="max_length",
378
+ return_tensors="pt"
379
+ )
380
+ new_examples["query"].append(source_txt)
381
+ new_examples["input_ids"].append(tokenized_question["input_ids"])
382
+
383
+ return new_examples
384
+
385
+ # Preprocess the dataset
386
+ train_dataset = None
387
+ if args.do_train:
388
+ if "train" not in raw_datasets:
389
+ raise ValueError("--do_train requires a train dataset")
390
+ train_dataset = raw_datasets['train']
391
+ if args.max_train_samples is not None and args.max_train_samples > 0:
392
+ max_train_samples = min(len(train_dataset), args.max_train_samples)
393
+ train_dataset = train_dataset.select(range(max_train_samples))
394
+ logger.debug(f"Example train_dataset[0]: {train_dataset[0]}")
395
+ tokenized_dataset = train_dataset.shuffle().map(
396
+ preprocess_function,
397
+ batched=True,
398
+ num_proc=args.preprocessing_num_workers,
399
+ remove_columns=train_dataset.column_names,
400
+ load_from_cache_file=not args.overwrite_cache,
401
+ desc="Running tokenizer on dataset",
402
+ )
403
+ train_dataset = tokenized_dataset.filter(
404
+ lambda x: len(x['input_ids']) > 0
405
+ )
406
+ logger.debug(f"Num train_samples: {len(train_dataset)}")
407
+
408
+ def collator(data):
409
+ return dict((key, [d[key] for d in data]) for key in data[0])
410
+
411
+ output_dir = args.output_dir
412
+ config = PPOConfig(
413
+ steps=args.max_steps,
414
+ model_name=args.model_name_or_path,
415
+ learning_rate=args.learning_rate,
416
+ log_with=args.report_to,
417
+ batch_size=args.batch_size,
418
+ mini_batch_size=args.mini_batch_size,
419
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
420
+ optimize_cuda_cache=True,
421
+ early_stopping=args.early_stopping,
422
+ target_kl=args.target_kl,
423
+ seed=args.seed,
424
+ init_kl_coef=args.init_kl_coef,
425
+ adap_kl_ctrl=args.adap_kl_ctrl,
426
+ project_kwargs={"logging_dir": output_dir},
427
+ )
428
+ # Set seed before initializing value head for deterministic eval
429
+ set_seed(config.seed)
430
+
431
+ # We then build the PPOTrainer, passing the model, the reference model, the tokenizer
432
+ trainer = PPOTrainer(
433
+ config,
434
+ model,
435
+ ref_model=None,
436
+ tokenizer=tokenizer,
437
+ dataset=train_dataset,
438
+ data_collator=collator,
439
+ )
440
+
441
+ # These arguments are passed to the `generate` function of the PPOTrainer
442
+ generation_kwargs = {
443
+ "max_new_tokens": max_target_length,
444
+ "temperature": 1.0,
445
+ "repetition_penalty": 1.0,
446
+ "top_p": 1.0,
447
+ "do_sample": True,
448
+ }
449
+
450
+ def save_model(save_dir):
451
+ trainer.accelerator.unwrap_model(trainer.model).save_pretrained(save_dir)
452
+ trainer.tokenizer.save_pretrained(save_dir)
453
+
454
+ # Training
455
+ if args.do_train:
456
+ logger.info("*** Train ***")
457
+ total_steps = config.total_ppo_epochs
458
+ for step, batch in tqdm(enumerate(trainer.dataloader)):
459
+ if step >= total_steps:
460
+ break
461
+ question_tensors = batch["input_ids"]
462
+ question_tensors = [torch.LongTensor(i).to(device).squeeze(0) for i in question_tensors]
463
+ responses = []
464
+ response_tensors = []
465
+ for q_tensor in question_tensors:
466
+ response_tensor = trainer.generate(
467
+ q_tensor,
468
+ return_prompt=False,
469
+ **generation_kwargs,
470
+ )
471
+ r = tokenizer.batch_decode(response_tensor, skip_special_tokens=True)[0]
472
+ responses.append(r)
473
+ response_tensors.append(response_tensor.squeeze(0))
474
+ batch["response"] = responses
475
+
476
+ # Compute reward score
477
+ score_outputs = [
478
+ get_reward_model_output(reward_model, reward_tokenizer, q, r, device) for q, r in
479
+ zip(batch["query"], batch["response"])
480
+ ]
481
+ rewards = calculate_rewards(score_outputs, args.reward_baseline)
482
+
483
+ # Run PPO step
484
+ try:
485
+ stats = trainer.step(question_tensors, response_tensors, rewards)
486
+ trainer.log_stats(stats, batch, rewards)
487
+ logger.debug(f"Step {step}/{total_steps}: reward score:{score_outputs}")
488
+ except ValueError as e:
489
+ logger.warning(f"Failed to log stats for step {step}, because of {e}")
490
+
491
+ if step and step % args.save_steps == 0:
492
+ save_dir = os.path.join(output_dir, f"checkpoint-{step}")
493
+ save_model(save_dir)
494
+ # Save final model
495
+ save_model(output_dir)
496
+
497
+
498
+ if __name__ == "__main__":
499
+ main()
run_dpo.sh ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CUDA_VISIBLE_DEVICES=0,1 python dpo_training.py \
2
+ --model_type bloom \
3
+ --model_name_or_path bigscience/bloomz-560m \
4
+ --train_file_dir ./data/reward \
5
+ --validation_file_dir ./data/reward \
6
+ --per_device_train_batch_size 4 \
7
+ --per_device_eval_batch_size 1 \
8
+ --do_train \
9
+ --do_eval \
10
+ --use_peft True \
11
+ --max_train_samples 1000 \
12
+ --max_eval_samples 10 \
13
+ --max_steps 100 \
14
+ --eval_steps 20 \
15
+ --save_steps 50 \
16
+ --max_source_length 128 \
17
+ --max_target_length 128 \
18
+ --output_dir outputs-dpo-bloom-v1 \
19
+ --target_modules all \
20
+ --lora_rank 8 \
21
+ --lora_alpha 16 \
22
+ --lora_dropout 0.05 \
23
+ --torch_dtype float16 \
24
+ --fp16 True \
25
+ --device_map auto \
26
+ --report_to tensorboard \
27
+ --remove_unused_columns False \
28
+ --gradient_checkpointing True \
29
+ --cache_dir ./cache
run_pt.sh ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 pretraining.py \
2
+ --model_type bloom \
3
+ --model_name_or_path bigscience/bloomz-560m \
4
+ --train_file_dir ./data/pretrain \
5
+ --validation_file_dir ./data/pretrain \
6
+ --per_device_train_batch_size 4 \
7
+ --per_device_eval_batch_size 4 \
8
+ --do_train \
9
+ --do_eval \
10
+ --use_peft True \
11
+ --seed 42 \
12
+ --fp16 \
13
+ --max_train_samples 10000 \
14
+ --max_eval_samples 10 \
15
+ --num_train_epochs 0.5 \
16
+ --learning_rate 2e-4 \
17
+ --warmup_ratio 0.05 \
18
+ --weight_decay 0.01 \
19
+ --logging_strategy steps \
20
+ --logging_steps 10 \
21
+ --eval_steps 50 \
22
+ --evaluation_strategy steps \
23
+ --save_steps 500 \
24
+ --save_strategy steps \
25
+ --save_total_limit 3 \
26
+ --gradient_accumulation_steps 1 \
27
+ --preprocessing_num_workers 1 \
28
+ --block_size 1024 \
29
+ --output_dir outputs-pt-bloom-v1 \
30
+ --overwrite_output_dir \
31
+ --ddp_timeout 30000 \
32
+ --logging_first_step True \
33
+ --target_modules all \
34
+ --lora_rank 8 \
35
+ --lora_alpha 16 \
36
+ --lora_dropout 0.05 \
37
+ --torch_dtype float16 \
38
+ --device_map auto \
39
+ --report_to tensorboard \
40
+ --ddp_find_unused_parameters False \
41
+ --gradient_checkpointing True \
42
+ --cache_dir ./cache
run_rl.sh ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 rl_training.py \
2
+ --model_type bloom \
3
+ --model_name_or_path bigscience/bloomz-560m \
4
+ --reward_model_name_or_path OpenAssistant/reward-model-deberta-v3-large-v2 \
5
+ --torch_dtype float16 \
6
+ --device_map auto \
7
+ --train_file_dir ./data/finetune \
8
+ --validation_file_dir ./data/finetune \
9
+ --batch_size 8 \
10
+ --max_source_length 256 \
11
+ --max_target_length 256 \
12
+ --max_train_samples 1000 \
13
+ --use_peft True \
14
+ --lora_rank 8 \
15
+ --lora_alpha 16 \
16
+ --lora_dropout 0.05 \
17
+ --do_train \
18
+ --max_steps 100 \
19
+ --learning_rate 1e-5 \
20
+ --save_steps 50 \
21
+ --output_dir outputs-rl-bloom-v1 \
22
+ --early_stopping True \
23
+ --target_kl 0.1 \
24
+ --reward_baseline 0.0
run_rm.sh ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CUDA_VISIBLE_DEVICES=0,1 python reward_modeling.py \
2
+ --model_type bloom \
3
+ --model_name_or_path bigscience/bloomz-560m \
4
+ --train_file_dir ./data/reward \
5
+ --validation_file_dir ./data/reward \
6
+ --per_device_train_batch_size 4 \
7
+ --per_device_eval_batch_size 4 \
8
+ --do_train \
9
+ --use_peft True \
10
+ --seed 42 \
11
+ --max_train_samples 1000 \
12
+ --max_eval_samples 10 \
13
+ --num_train_epochs 1 \
14
+ --learning_rate 2e-5 \
15
+ --warmup_ratio 0.05 \
16
+ --weight_decay 0.001 \
17
+ --logging_strategy steps \
18
+ --logging_steps 10 \
19
+ --eval_steps 50 \
20
+ --evaluation_strategy steps \
21
+ --save_steps 500 \
22
+ --save_strategy steps \
23
+ --save_total_limit 3 \
24
+ --max_source_length 256 \
25
+ --max_target_length 256 \
26
+ --output_dir outputs-rm-bloom-v1 \
27
+ --overwrite_output_dir \
28
+ --ddp_timeout 30000 \
29
+ --logging_first_step True \
30
+ --target_modules all \
31
+ --lora_rank 8 \
32
+ --lora_alpha 16 \
33
+ --lora_dropout 0.05 \
34
+ --torch_dtype float32 \
35
+ --device_map auto \
36
+ --report_to tensorboard \
37
+ --ddp_find_unused_parameters False \
38
+ --remove_unused_columns False \
39
+ --gradient_checkpointing True
run_sft.sh ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 supervised_finetuning.py \
2
+ --model_type bloom \
3
+ --model_name_or_path bigscience/bloomz-560m \
4
+ --train_file_dir ./data/finetune \
5
+ --validation_file_dir ./data/finetune \
6
+ --per_device_train_batch_size 4 \
7
+ --per_device_eval_batch_size 4 \
8
+ --do_train \
9
+ --do_eval \
10
+ --use_peft True \
11
+ --fp16 \
12
+ --max_train_samples 1000 \
13
+ --max_eval_samples 10 \
14
+ --num_train_epochs 1 \
15
+ --learning_rate 2e-5 \
16
+ --warmup_ratio 0.05 \
17
+ --weight_decay 0.05 \
18
+ --logging_strategy steps \
19
+ --logging_steps 10 \
20
+ --eval_steps 50 \
21
+ --evaluation_strategy steps \
22
+ --save_steps 500 \
23
+ --save_strategy steps \
24
+ --save_total_limit 3 \
25
+ --gradient_accumulation_steps 1 \
26
+ --preprocessing_num_workers 4 \
27
+ --output_dir outputs-sft-bloom-v1 \
28
+ --overwrite_output_dir \
29
+ --ddp_timeout 30000 \
30
+ --logging_first_step True \
31
+ --target_modules all \
32
+ --lora_rank 8 \
33
+ --lora_alpha 16 \
34
+ --lora_dropout 0.05 \
35
+ --torch_dtype float16 \
36
+ --device_map auto \
37
+ --report_to tensorboard \
38
+ --ddp_find_unused_parameters False \
39
+ --gradient_checkpointing True \
40
+ --cache_dir ./cache
run_training_dpo_pipeline.ipynb ADDED
@@ -0,0 +1,711 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "source": [
6
+ "# Training Pipeline\n",
7
+ "[run_training_dpo_pipeline.ipynb](https://github.com/shibing624/MedicalGPT/blob/main/run_training_dpo_pipeline.ipynb) | [Open In Colab](https://colab.research.google.com/github/shibing624/MedicalGPT/blob/main/run_training_dpo_pipeline.ipynb)"
8
+ ],
9
+ "metadata": {
10
+ "collapsed": false
11
+ }
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "metadata": {
16
+ "tags": []
17
+ },
18
+ "source": [
19
+ "# Stage 1: Continue Pretraining\n",
20
+ "\n",
21
+ "第一阶段:PT(Continue PreTraining)增量预训练,在海量领域文本数据上二次预训练GPT模型,以注入领域知识\n",
22
+ "\n",
23
+ "| Stage 1: Continue Pretraining | [pretraining.py](https://github.com/shibing624/MedicalGPT/blob/main/pretraining.py) | [run_pt.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_pt.sh) |"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "markdown",
28
+ "metadata": {},
29
+ "source": [
30
+ "#### 说明:\n",
31
+ "以下 notebook/colab 代码为了快速验证训练代码可用,我们使用了小size的生成模型和小样本数据集,实际使用时,需要使用更大的模型和数据集,以获得更好的效果。\n",
32
+ "\n",
33
+ "1. 生成模型:使用的是Bloom的`bigscience/bloomz-560m`\n",
34
+ "2. 数据集:PT阶段使用的是中文天龙八部小说部分文本和英文书籍部分文本,位于`data/pretrain`文件夹"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "markdown",
39
+ "source": [],
40
+ "metadata": {
41
+ "collapsed": false
42
+ }
43
+ },
44
+ {
45
+ "cell_type": "markdown",
46
+ "metadata": {},
47
+ "source": [
48
+ "## 配置运行环境\n",
49
+ "\n",
50
+ "本地执行可注释以下配置环境的命令,colab执行要打开注释,用于配置环境\n",
51
+ "\n",
52
+ "colab建议使用T4 GPU训练,设置方式:`代码执行程序 -> 更改运行时类型 -> 运行时类型:Python3,硬件加速器:GPU,GPU类型:T4 -> 保存`\n",
53
+ "\n",
54
+ "步骤:\n",
55
+ "1. 下载最新代码到本地\n",
56
+ "2. 安装依赖包\n",
57
+ "\n",
58
+ "依赖包如下,保证最新版本:\n",
59
+ "\n",
60
+ "```\n",
61
+ "loguru\n",
62
+ "transformers\n",
63
+ "sentencepiece\n",
64
+ "datasets\n",
65
+ "tensorboard\n",
66
+ "tqdm\n",
67
+ "peft\n",
68
+ "trl\n",
69
+ "```"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": null,
75
+ "metadata": {},
76
+ "outputs": [],
77
+ "source": [
78
+ "!git clone --depth 1 https://github.com/shibing624/MedicalGPT.git\n",
79
+ "%cd MedicalGPT\n",
80
+ "%ls\n",
81
+ "!pip install -r requirements.txt"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "markdown",
86
+ "metadata": {},
87
+ "source": [
88
+ "## Stage1 咱们开始吧\n",
89
+ "\n",
90
+ "训练步骤如下:\n",
91
+ "\n",
92
+ "1. 确认训练集\n",
93
+ "2. 执行训练脚本\n",
94
+ "\n",
95
+ "训练脚本的执行逻辑如下:\n",
96
+ "1. 导入依赖包\n",
97
+ "2. 设置参数\n",
98
+ "3. 定义各函数并加载训练集\n",
99
+ "4. 加载模型和tokenizer\n",
100
+ "5. 开始训练并评估\n",
101
+ "6. 查看训练结果\n",
102
+ "\n",
103
+ "**以下参数可以根据你的GPU实际情况修改,当前参数是根据Colab的T4单卡GPU(16GB显存)配置的**"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": null,
109
+ "metadata": {},
110
+ "outputs": [],
111
+ "source": [
112
+ "%ls ./data/pretrain/"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": null,
118
+ "outputs": [],
119
+ "source": [
120
+ "!python pretraining.py \\\n",
121
+ " --model_type bloom \\\n",
122
+ " --model_name_or_path bigscience/bloomz-560m \\\n",
123
+ " --train_file_dir ./data/pretrain \\\n",
124
+ " --validation_file_dir ./data/pretrain \\\n",
125
+ " --per_device_train_batch_size 3 \\\n",
126
+ " --per_device_eval_batch_size 3 \\\n",
127
+ " --do_train \\\n",
128
+ " --do_eval \\\n",
129
+ " --use_peft True \\\n",
130
+ " --seed 42 \\\n",
131
+ " --fp16 \\\n",
132
+ " --max_train_samples 10000 \\\n",
133
+ " --max_eval_samples 10 \\\n",
134
+ " --num_train_epochs 1 \\\n",
135
+ " --learning_rate 2e-4 \\\n",
136
+ " --warmup_ratio 0.05 \\\n",
137
+ " --weight_decay 0.01 \\\n",
138
+ " --logging_strategy steps \\\n",
139
+ " --logging_steps 10 \\\n",
140
+ " --eval_steps 50 \\\n",
141
+ " --evaluation_strategy steps \\\n",
142
+ " --save_steps 500 \\\n",
143
+ " --save_strategy steps \\\n",
144
+ " --save_total_limit 3 \\\n",
145
+ " --gradient_accumulation_steps 1 \\\n",
146
+ " --preprocessing_num_workers 1 \\\n",
147
+ " --block_size 1024 \\\n",
148
+ " --output_dir outputs-pt-v1 \\\n",
149
+ " --overwrite_output_dir \\\n",
150
+ " --ddp_timeout 30000 \\\n",
151
+ " --logging_first_step True \\\n",
152
+ " --target_modules all \\\n",
153
+ " --lora_rank 8 \\\n",
154
+ " --lora_alpha 16 \\\n",
155
+ " --lora_dropout 0.05 \\\n",
156
+ " --torch_dtype float16 \\\n",
157
+ " --device_map auto \\\n",
158
+ " --report_to tensorboard \\\n",
159
+ " --ddp_find_unused_parameters False \\\n",
160
+ " --gradient_checkpointing True"
161
+ ],
162
+ "metadata": {
163
+ "collapsed": false
164
+ }
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": null,
169
+ "metadata": {},
170
+ "outputs": [],
171
+ "source": [
172
+ "%ls -lh outputs-pt-v1"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "markdown",
177
+ "metadata": {},
178
+ "source": [
179
+ "模型训练结果:\n",
180
+ "- 使用lora训练模型,则保存的lora权重是`adapter_model.bin`, lora配置文件是`adapter_config.json`,合并到base model的方法见`merge_peft_adapter.py`\n",
181
+ "- 日志保存在`output_dir/runs`目录下,可以使用tensorboard查看,启动tensorboard方式如下:`tensorboard --logdir output_dir/runs --host 0.0.0.0 --port 8009`"
182
+ ]
183
+ },
184
+ {
185
+ "cell_type": "markdown",
186
+ "source": [
187
+ "lora模型权重合并到base model,合并后的模型保存在`--output_dir`目录下,合并方法如下:"
188
+ ],
189
+ "metadata": {
190
+ "collapsed": false
191
+ }
192
+ },
193
+ {
194
+ "cell_type": "code",
195
+ "execution_count": null,
196
+ "outputs": [],
197
+ "source": [
198
+ "!python merge_peft_adapter.py --model_type bloom \\\n",
199
+ " --base_model_name_or_path bigscience/bloomz-560m --peft_model_path outputs-pt-v1 --output_dir merged-pt/"
200
+ ],
201
+ "metadata": {
202
+ "collapsed": false
203
+ }
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": null,
208
+ "outputs": [],
209
+ "source": [
210
+ "%ls -lh merged-pt/"
211
+ ],
212
+ "metadata": {
213
+ "collapsed": false
214
+ }
215
+ },
216
+ {
217
+ "cell_type": "code",
218
+ "execution_count": null,
219
+ "outputs": [],
220
+ "source": [
221
+ "%cat merged-pt/config.json"
222
+ ],
223
+ "metadata": {
224
+ "collapsed": false
225
+ }
226
+ },
227
+ {
228
+ "cell_type": "markdown",
229
+ "metadata": {},
230
+ "source": [
231
+ "Stage1 增量预训练完成。"
232
+ ]
233
+ },
234
+ {
235
+ "cell_type": "code",
236
+ "execution_count": null,
237
+ "metadata": {
238
+ "ExecuteTime": {
239
+ "start_time": "2023-06-15T13:56:17.032821Z",
240
+ "end_time": "2023-06-15T13:56:17.081153Z"
241
+ }
242
+ },
243
+ "outputs": [],
244
+ "source": []
245
+ },
246
+ {
247
+ "cell_type": "markdown",
248
+ "source": [
249
+ "# Stage 2: Supervised FineTuning\n",
250
+ "\n",
251
+ "第二阶段:SFT(Supervised Fine-tuning)有监督微调,构造指令微调数据集,在预训练模型基础上做指令精调,以对齐指令意图\n",
252
+ "\n",
253
+ "| Stage 2: Supervised Fine-tuning | [supervised_finetuning.py](https://github.com/shibing624/MedicalGPT/blob/main/supervised_finetuning.py) | [run_sft.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_sft.sh) |"
254
+ ],
255
+ "metadata": {
256
+ "collapsed": false
257
+ }
258
+ },
259
+ {
260
+ "cell_type": "markdown",
261
+ "source": [
262
+ "#### 说明:\n",
263
+ "以下 notebook/colab 代码为了快速验证训练代码可用,我们使用了小size的生成模型和小样本数据集,实际使用时,需要使用更大的模型和数据集,以获得更好的效果。\n",
264
+ "\n",
265
+ "1. 生成模型:使用的是Bloom的`bigscience/bloomz-560m` 或者 Stage1得到的预训练模型\n",
266
+ "2. 数据集:SFT阶段使用的是使用的是Belle的1千条抽样数据,位于`data/finetune`文件夹"
267
+ ],
268
+ "metadata": {
269
+ "collapsed": false
270
+ }
271
+ },
272
+ {
273
+ "cell_type": "markdown",
274
+ "source": [
275
+ "## Stage2 咱们开始吧\n",
276
+ "\n",
277
+ "训练步骤如下:\n",
278
+ "\n",
279
+ "1. 确认训练集\n",
280
+ "2. 执行训练脚本\n",
281
+ "\n",
282
+ "训练脚本的执行逻辑如下:\n",
283
+ "1. 导入依赖包\n",
284
+ "2. 设置参数\n",
285
+ "3. 定义各函数并加载训练集\n",
286
+ "4. 加载模型和tokenizer\n",
287
+ "5. 开始训练并评估\n",
288
+ "6. 查看训练结果"
289
+ ],
290
+ "metadata": {
291
+ "collapsed": false
292
+ }
293
+ },
294
+ {
295
+ "cell_type": "code",
296
+ "execution_count": null,
297
+ "outputs": [],
298
+ "source": [
299
+ "%ls ./data/finetune"
300
+ ],
301
+ "metadata": {
302
+ "collapsed": false,
303
+ "ExecuteTime": {
304
+ "start_time": "2023-06-15T13:58:38.778132Z",
305
+ "end_time": "2023-06-15T13:58:38.966506Z"
306
+ }
307
+ }
308
+ },
309
+ {
310
+ "cell_type": "code",
311
+ "execution_count": null,
312
+ "outputs": [],
313
+ "source": [
314
+ "!python supervised_finetuning.py \\\n",
315
+ " --model_type bloom \\\n",
316
+ " --model_name_or_path merged-pt \\\n",
317
+ " --train_file_dir ./data/finetune \\\n",
318
+ " --validation_file_dir ./data/finetune \\\n",
319
+ " --per_device_train_batch_size 4 \\\n",
320
+ " --per_device_eval_batch_size 4 \\\n",
321
+ " --do_train \\\n",
322
+ " --do_eval \\\n",
323
+ " --use_peft True \\\n",
324
+ " --fp16 \\\n",
325
+ " --max_train_samples 1000 \\\n",
326
+ " --max_eval_samples 10 \\\n",
327
+ " --num_train_epochs 1 \\\n",
328
+ " --learning_rate 2e-5 \\\n",
329
+ " --warmup_ratio 0.05 \\\n",
330
+ " --weight_decay 0.05 \\\n",
331
+ " --logging_strategy steps \\\n",
332
+ " --logging_steps 10 \\\n",
333
+ " --eval_steps 50 \\\n",
334
+ " --evaluation_strategy steps \\\n",
335
+ " --save_steps 500 \\\n",
336
+ " --save_strategy steps \\\n",
337
+ " --save_total_limit 3 \\\n",
338
+ " --gradient_accumulation_steps 1 \\\n",
339
+ " --preprocessing_num_workers 1 \\\n",
340
+ " --output_dir outputs-sft-v1 \\\n",
341
+ " --overwrite_output_dir \\\n",
342
+ " --ddp_timeout 30000 \\\n",
343
+ " --logging_first_step True \\\n",
344
+ " --target_modules all \\\n",
345
+ " --lora_rank 8 \\\n",
346
+ " --lora_alpha 16 \\\n",
347
+ " --lora_dropout 0.05 \\\n",
348
+ " --torch_dtype float16 \\\n",
349
+ " --device_map auto \\\n",
350
+ " --report_to tensorboard \\\n",
351
+ " --ddp_find_unused_parameters False \\\n",
352
+ " --gradient_checkpointing True"
353
+ ],
354
+ "metadata": {
355
+ "collapsed": false
356
+ }
357
+ },
358
+ {
359
+ "cell_type": "code",
360
+ "execution_count": null,
361
+ "outputs": [],
362
+ "source": [
363
+ "%ls -lh outputs-sft-v1"
364
+ ],
365
+ "metadata": {
366
+ "collapsed": false
367
+ }
368
+ },
369
+ {
370
+ "cell_type": "markdown",
371
+ "source": [
372
+ "模型训练结果:\n",
373
+ "- 使用lora训练模型,则保存的lora权重是`adapter_model.bin`, lora配置文件是`adapter_config.json`,合并到base model的方法见`merge_peft_adapter.py`\n",
374
+ "- 日志保存在`output_dir/runs`目录下,可以使用tensorboard查看,启动tensorboard方式如下:`tensorboard --logdir output_dir/runs --host 0.0.0.0 --port 8009`"
375
+ ],
376
+ "metadata": {
377
+ "collapsed": false
378
+ }
379
+ },
380
+ {
381
+ "cell_type": "markdown",
382
+ "source": [
383
+ "lora模型权重合并到base model,合并后的模型保存在`--output_dir`目录下,合并方法如下:"
384
+ ],
385
+ "metadata": {
386
+ "collapsed": false
387
+ }
388
+ },
389
+ {
390
+ "cell_type": "code",
391
+ "execution_count": null,
392
+ "outputs": [],
393
+ "source": [
394
+ "!python merge_peft_adapter.py --model_type bloom \\\n",
395
+ " --base_model_name_or_path merged-pt --peft_model_path outputs-sft-v1 --output_dir merged-sft/"
396
+ ],
397
+ "metadata": {
398
+ "collapsed": false
399
+ }
400
+ },
401
+ {
402
+ "cell_type": "code",
403
+ "execution_count": null,
404
+ "outputs": [],
405
+ "source": [
406
+ "%ls -lh merged-sft/"
407
+ ],
408
+ "metadata": {
409
+ "collapsed": false
410
+ }
411
+ },
412
+ {
413
+ "cell_type": "code",
414
+ "execution_count": null,
415
+ "outputs": [],
416
+ "source": [
417
+ "%cat merged-sft/config.json"
418
+ ],
419
+ "metadata": {
420
+ "collapsed": false
421
+ }
422
+ },
423
+ {
424
+ "cell_type": "markdown",
425
+ "source": [
426
+ "Stage2 SFT训练完成。"
427
+ ],
428
+ "metadata": {
429
+ "collapsed": false
430
+ }
431
+ },
432
+ {
433
+ "cell_type": "code",
434
+ "execution_count": null,
435
+ "outputs": [],
436
+ "source": [],
437
+ "metadata": {
438
+ "collapsed": false,
439
+ "ExecuteTime": {
440
+ "start_time": "2023-06-15T14:07:40.731186Z",
441
+ "end_time": "2023-06-15T14:07:40.752635Z"
442
+ }
443
+ }
444
+ },
445
+ {
446
+ "cell_type": "markdown",
447
+ "source": [
448
+ "# Stage 3: DPO(Direct Preference Optimization)\n",
449
+ "\n",
450
+ "第三阶段:DPO(Direct Preference Optimization)直接偏好优化,DPO通过直接优化语言模型来实现对其行为的精确控制,而无需使用复杂的强化学习,也可以有效学习到人类偏好,DPO相较于RLHF更容易实现且易于训练,效果更好\n",
451
+ "\n",
452
+ "| Stage 3: Direct Preference Optimization | [dpo_training.py](https://github.com/shibing624/MedicalGPT/blob/main/dpo_training.py) | [run_dpo.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_dpo.sh) |"
453
+ ],
454
+ "metadata": {
455
+ "collapsed": false
456
+ }
457
+ },
458
+ {
459
+ "cell_type": "markdown",
460
+ "source": [
461
+ "#### 说明:\n",
462
+ "以下 notebook/colab 代码为了快速验证训练代码可用,我们使用了小size的生成模型和小样本数据集,实际使用时,需要使用更大的模型和数据集,以获得更好的效果。\n",
463
+ "\n",
464
+ "1. 生成模型:使用的是Bloom的`bigscience/bloomz-560m` 或者 Stage2得到的SFT模型\n",
465
+ "2. 数据集:DPO阶段使用的是医疗reward数据,抽样了500条,位于`data/reward`文件夹"
466
+ ],
467
+ "metadata": {
468
+ "collapsed": false
469
+ }
470
+ },
471
+ {
472
+ "cell_type": "markdown",
473
+ "source": [
474
+ "## Stage3 咱们开始吧\n",
475
+ "\n",
476
+ "训练步骤如下:\n",
477
+ "\n",
478
+ "1. 确认训练集\n",
479
+ "2. 执行训练脚本\n",
480
+ "\n",
481
+ "训练脚本的执行逻辑如下:\n",
482
+ "1. 导入依赖包\n",
483
+ "2. 设置参数\n",
484
+ "3. 定义各函数并加载训练集\n",
485
+ "4. 加载模型和tokenizer\n",
486
+ "5. 开始训练并评估\n",
487
+ "6. 查看训练结果"
488
+ ],
489
+ "metadata": {
490
+ "collapsed": false
491
+ }
492
+ },
493
+ {
494
+ "cell_type": "code",
495
+ "execution_count": null,
496
+ "outputs": [],
497
+ "source": [
498
+ "%ls ./data/reward/"
499
+ ],
500
+ "metadata": {
501
+ "collapsed": false
502
+ }
503
+ },
504
+ {
505
+ "cell_type": "code",
506
+ "execution_count": null,
507
+ "outputs": [],
508
+ "source": [
509
+ "!python dpo_training.py \\\n",
510
+ " --model_type bloom \\\n",
511
+ " --model_name_or_path merged-sft \\\n",
512
+ " --train_file_dir ./data/reward \\\n",
513
+ " --validation_file_dir ./data/reward \\\n",
514
+ " --per_device_train_batch_size 3 \\\n",
515
+ " --per_device_eval_batch_size 1 \\\n",
516
+ " --do_train \\\n",
517
+ " --do_eval \\\n",
518
+ " --use_peft True \\\n",
519
+ " --max_train_samples 1000 \\\n",
520
+ " --max_eval_samples 10 \\\n",
521
+ " --max_steps 100 \\\n",
522
+ " --eval_steps 10 \\\n",
523
+ " --save_steps 50 \\\n",
524
+ " --max_source_length 128 \\\n",
525
+ " --max_target_length 128 \\\n",
526
+ " --output_dir outputs-dpo-v1 \\\n",
527
+ " --target_modules all \\\n",
528
+ " --lora_rank 8 \\\n",
529
+ " --lora_alpha 16 \\\n",
530
+ " --lora_dropout 0.05 \\\n",
531
+ " --torch_dtype float16 \\\n",
532
+ " --fp16 True \\\n",
533
+ " --device_map auto \\\n",
534
+ " --report_to tensorboard \\\n",
535
+ " --remove_unused_columns False \\\n",
536
+ " --gradient_checkpointing True \\\n",
537
+ " --cache_dir ./cache"
538
+ ],
539
+ "metadata": {
540
+ "collapsed": false
541
+ }
542
+ },
543
+ {
544
+ "cell_type": "code",
545
+ "execution_count": null,
546
+ "outputs": [],
547
+ "source": [
548
+ "%ls -lh outputs-dpo-v1"
549
+ ],
550
+ "metadata": {
551
+ "collapsed": false
552
+ }
553
+ },
554
+ {
555
+ "cell_type": "markdown",
556
+ "source": [
557
+ "模型训练结果:\n",
558
+ "- 使用lora训练模型,则保存的lora权重是`adapter_model.bin`, lora配置文件是`adapter_config.json`,合并到base model的方法见`merge_peft_adapter.py`\n",
559
+ "- 日志保存在`output_dir/runs`目录下,可以使用tensorboard查看,启动tensorboard方式如下:`tensorboard --logdir output_dir/runs --host 0.0.0.0 --port 8009`"
560
+ ],
561
+ "metadata": {
562
+ "collapsed": false
563
+ }
564
+ },
565
+ {
566
+ "cell_type": "markdown",
567
+ "source": [
568
+ "lora模型权重合并到base model,合并后的模型保存在`--output_dir`目录下,合并方法如下:"
569
+ ],
570
+ "metadata": {
571
+ "collapsed": false
572
+ }
573
+ },
574
+ {
575
+ "cell_type": "code",
576
+ "execution_count": null,
577
+ "outputs": [],
578
+ "source": [
579
+ "!python merge_peft_adapter.py --model_type bloom \\\n",
580
+ " --base_model_name_or_path merged-sft --peft_model_path outputs-dpo-v1 --output_dir merged-dpo/"
581
+ ],
582
+ "metadata": {
583
+ "collapsed": false
584
+ }
585
+ },
586
+ {
587
+ "cell_type": "code",
588
+ "execution_count": null,
589
+ "outputs": [],
590
+ "source": [
591
+ "%ls -lh merged-dpo/"
592
+ ],
593
+ "metadata": {
594
+ "collapsed": false
595
+ }
596
+ },
597
+ {
598
+ "cell_type": "code",
599
+ "execution_count": null,
600
+ "outputs": [],
601
+ "source": [
602
+ "%cat merged-dpo/config.json"
603
+ ],
604
+ "metadata": {
605
+ "collapsed": false
606
+ }
607
+ },
608
+ {
609
+ "cell_type": "markdown",
610
+ "source": [
611
+ "Stage3 偏好建模第一次训练完成。"
612
+ ],
613
+ "metadata": {
614
+ "collapsed": false
615
+ }
616
+ },
617
+ {
618
+ "cell_type": "markdown",
619
+ "source": [
620
+ "**至此一个完整的训练流程演示完成。**"
621
+ ],
622
+ "metadata": {
623
+ "collapsed": false
624
+ }
625
+ },
626
+ {
627
+ "cell_type": "code",
628
+ "execution_count": null,
629
+ "outputs": [],
630
+ "source": [],
631
+ "metadata": {
632
+ "collapsed": false,
633
+ "ExecuteTime": {
634
+ "start_time": "2023-06-26T12:34:29.620609Z",
635
+ "end_time": "2023-06-26T12:34:29.658428Z"
636
+ }
637
+ }
638
+ },
639
+ {
640
+ "cell_type": "markdown",
641
+ "source": [
642
+ "# Test"
643
+ ],
644
+ "metadata": {
645
+ "collapsed": false
646
+ }
647
+ },
648
+ {
649
+ "cell_type": "code",
650
+ "execution_count": null,
651
+ "outputs": [],
652
+ "source": [
653
+ "!python inference.py --model_type bloom --base_model merged-dpo --interactive"
654
+ ],
655
+ "metadata": {
656
+ "collapsed": false,
657
+ "ExecuteTime": {
658
+ "start_time": "2023-06-26T12:34:47.802087Z",
659
+ "end_time": "2023-06-26T12:35:00.864463Z"
660
+ }
661
+ }
662
+ },
663
+ {
664
+ "cell_type": "markdown",
665
+ "source": [
666
+ "Input:介绍下南京\n",
667
+ "Response: 南京市位于江苏省西南部,是全国首批历史文化名城、国家中心城市和自由贸易试验区。\n",
668
+ "\n",
669
+ "完。\n"
670
+ ],
671
+ "metadata": {
672
+ "collapsed": false
673
+ }
674
+ },
675
+ {
676
+ "cell_type": "code",
677
+ "execution_count": null,
678
+ "outputs": [],
679
+ "source": [],
680
+ "metadata": {
681
+ "collapsed": false
682
+ }
683
+ }
684
+ ],
685
+ "metadata": {
686
+ "kernelspec": {
687
+ "name": "python3",
688
+ "language": "python",
689
+ "display_name": "Python 3"
690
+ },
691
+ "language_info": {
692
+ "codemirror_mode": {
693
+ "name": "ipython",
694
+ "version": 3
695
+ },
696
+ "file_extension": ".py",
697
+ "mimetype": "text/x-python",
698
+ "name": "python",
699
+ "nbconvert_exporter": "python",
700
+ "pygments_lexer": "ipython3",
701
+ "version": "3.8.13"
702
+ },
703
+ "vscode": {
704
+ "interpreter": {
705
+ "hash": "f34eed0bebedfc4b6ee51ced43d2c030fe3b92f13c149d072205ca200a67b1ec"
706
+ }
707
+ }
708
+ },
709
+ "nbformat": 4,
710
+ "nbformat_minor": 4
711
+ }
run_training_pipeline.ipynb ADDED
@@ -0,0 +1,917 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "source": [
6
+ "# Training Pipeline\n",
7
+ "[run_training_pipeline.ipynb](https://github.com/shibing624/MedicalGPT/blob/main/run_training_pipeline.ipynb) | [Open In Colab](https://colab.research.google.com/github/shibing624/MedicalGPT/blob/main/run_training_pipeline.ipynb)"
8
+ ],
9
+ "metadata": {
10
+ "collapsed": false
11
+ }
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "metadata": {
16
+ "tags": []
17
+ },
18
+ "source": [
19
+ "# Stage 1: Continue Pretraining\n",
20
+ "\n",
21
+ "第一阶段:PT(Continue PreTraining)增量预训练,在海量领域文本数据上二次预训练GPT模型,以注入领域知识\n",
22
+ "\n",
23
+ "| Stage 1: Continue Pretraining | [pretraining.py](https://github.com/shibing624/MedicalGPT/blob/main/pretraining.py) | [run_pt.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_pt.sh) |"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "markdown",
28
+ "metadata": {},
29
+ "source": [
30
+ "#### 说明:\n",
31
+ "以下 notebook/colab 代码为了快速验证训练代码可用,我们使用了小size的生成模型和小样本数据集,实际使用时,需要使用更大的模型和数据集,以获得更好的效果。\n",
32
+ "\n",
33
+ "1. 生成模型:使用的是Bloom的`bigscience/bloomz-560m`\n",
34
+ "2. 数据集:PT阶段使用的是中文天龙八部小说部分文本和英文书籍部分文本,位于`data/pretrain`文件夹"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "markdown",
39
+ "metadata": {},
40
+ "source": [
41
+ "## 配置运行环境\n",
42
+ "\n",
43
+ "本地执行可注释以下配置环境的命令,colab执行要打开注释,用于配置环境\n",
44
+ "\n",
45
+ "colab建议使用T4 GPU训练,设置方式:`代码执行程序 -> 更改运行时类型 -> 运行时类型:Python3,硬件加速器:GPU,GPU类型:T4 -> 保存`\n",
46
+ "\n",
47
+ "步骤:\n",
48
+ "1. 下载最新代码到本地\n",
49
+ "2. 安装依赖包\n",
50
+ "\n",
51
+ "依赖包如下,保证最新版本:\n",
52
+ "\n",
53
+ "```\n",
54
+ "loguru\n",
55
+ "transformers\n",
56
+ "sentencepiece\n",
57
+ "datasets\n",
58
+ "tensorboard\n",
59
+ "tqdm\n",
60
+ "peft\n",
61
+ "trl\n",
62
+ "```"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": null,
68
+ "metadata": {},
69
+ "outputs": [],
70
+ "source": [
71
+ "!git clone --depth 1 https://github.com/shibing624/MedicalGPT.git\n",
72
+ "%cd MedicalGPT\n",
73
+ "%ls\n",
74
+ "!pip install -r requirements.txt"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "markdown",
79
+ "metadata": {},
80
+ "source": [
81
+ "## Stage1 咱们开始吧\n",
82
+ "\n",
83
+ "训练步骤如下:\n",
84
+ "\n",
85
+ "1. 确认训练集\n",
86
+ "2. 执行训练脚本\n",
87
+ "\n",
88
+ "训练脚本的执行逻辑如下:\n",
89
+ "1. 导入依赖包\n",
90
+ "2. 设置参数\n",
91
+ "3. 定义各函数并加载训练集\n",
92
+ "4. 加载模型和tokenizer\n",
93
+ "5. 开始训练并评估\n",
94
+ "6. 查看训练结果\n",
95
+ "\n",
96
+ "**以下参数可以根据你的GPU实际情况修改,当前参数是根据Colab的T4单卡GPU(16GB显存)配置的**"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": null,
102
+ "metadata": {},
103
+ "outputs": [],
104
+ "source": [
105
+ "%ls ./data/pretrain/"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": null,
111
+ "outputs": [],
112
+ "source": [
113
+ "!python pretraining.py \\\n",
114
+ " --model_type bloom \\\n",
115
+ " --model_name_or_path bigscience/bloomz-560m \\\n",
116
+ " --train_file_dir ./data/pretrain \\\n",
117
+ " --validation_file_dir ./data/pretrain \\\n",
118
+ " --per_device_train_batch_size 3 \\\n",
119
+ " --per_device_eval_batch_size 3 \\\n",
120
+ " --do_train \\\n",
121
+ " --do_eval \\\n",
122
+ " --use_peft True \\\n",
123
+ " --seed 42 \\\n",
124
+ " --fp16 \\\n",
125
+ " --max_train_samples 10000 \\\n",
126
+ " --max_eval_samples 10 \\\n",
127
+ " --num_train_epochs 1 \\\n",
128
+ " --learning_rate 2e-4 \\\n",
129
+ " --warmup_ratio 0.05 \\\n",
130
+ " --weight_decay 0.01 \\\n",
131
+ " --logging_strategy steps \\\n",
132
+ " --logging_steps 10 \\\n",
133
+ " --eval_steps 50 \\\n",
134
+ " --evaluation_strategy steps \\\n",
135
+ " --save_steps 500 \\\n",
136
+ " --save_strategy steps \\\n",
137
+ " --save_total_limit 3 \\\n",
138
+ " --gradient_accumulation_steps 1 \\\n",
139
+ " --preprocessing_num_workers 1 \\\n",
140
+ " --block_size 1024 \\\n",
141
+ " --output_dir outputs-pt-v1 \\\n",
142
+ " --overwrite_output_dir \\\n",
143
+ " --ddp_timeout 30000 \\\n",
144
+ " --logging_first_step True \\\n",
145
+ " --target_modules all \\\n",
146
+ " --lora_rank 8 \\\n",
147
+ " --lora_alpha 16 \\\n",
148
+ " --lora_dropout 0.05 \\\n",
149
+ " --torch_dtype float16 \\\n",
150
+ " --device_map auto \\\n",
151
+ " --report_to tensorboard \\\n",
152
+ " --ddp_find_unused_parameters False \\\n",
153
+ " --gradient_checkpointing True"
154
+ ],
155
+ "metadata": {
156
+ "collapsed": false
157
+ }
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": null,
162
+ "metadata": {},
163
+ "outputs": [],
164
+ "source": [
165
+ "%ls -lh outputs-pt-v1"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "markdown",
170
+ "metadata": {},
171
+ "source": [
172
+ "模型训练结果:\n",
173
+ "- 使用lora训练模型,则保存的lora权重是`adapter_model.bin`, lora配置文件是`adapter_config.json`,合并到base model的方法见`merge_peft_adapter.py`\n",
174
+ "- 日志保存在`output_dir/runs`目录下,可以使用tensorboard查看,启动tensorboard方式如下:`tensorboard --logdir output_dir/runs --host 0.0.0.0 --port 8009`"
175
+ ]
176
+ },
177
+ {
178
+ "cell_type": "markdown",
179
+ "source": [
180
+ "lora模型权重合并到base model,合并后的模型保存在`--output_dir`目录下,合并方法如下:"
181
+ ],
182
+ "metadata": {
183
+ "collapsed": false
184
+ }
185
+ },
186
+ {
187
+ "cell_type": "code",
188
+ "execution_count": null,
189
+ "outputs": [],
190
+ "source": [
191
+ "!python merge_peft_adapter.py --model_type bloom \\\n",
192
+ " --base_model_name_or_path bigscience/bloomz-560m --peft_model_path outputs-pt-v1 --output_dir merged-pt/"
193
+ ],
194
+ "metadata": {
195
+ "collapsed": false
196
+ }
197
+ },
198
+ {
199
+ "cell_type": "code",
200
+ "execution_count": null,
201
+ "outputs": [],
202
+ "source": [
203
+ "%ls -lh merged-pt/"
204
+ ],
205
+ "metadata": {
206
+ "collapsed": false
207
+ }
208
+ },
209
+ {
210
+ "cell_type": "code",
211
+ "execution_count": null,
212
+ "outputs": [],
213
+ "source": [
214
+ "%cat merged-pt/config.json"
215
+ ],
216
+ "metadata": {
217
+ "collapsed": false
218
+ }
219
+ },
220
+ {
221
+ "cell_type": "markdown",
222
+ "metadata": {},
223
+ "source": [
224
+ "Stage1 增量预训练完成。"
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "code",
229
+ "execution_count": null,
230
+ "metadata": {
231
+ "ExecuteTime": {
232
+ "start_time": "2023-06-15T13:56:17.032821Z",
233
+ "end_time": "2023-06-15T13:56:17.081153Z"
234
+ }
235
+ },
236
+ "outputs": [],
237
+ "source": []
238
+ },
239
+ {
240
+ "cell_type": "markdown",
241
+ "source": [
242
+ "# Stage 2: Supervised FineTuning\n",
243
+ "\n",
244
+ "第二阶段:SFT(Supervised Fine-tuning)有监督微调,构造指令微调数据集,在预训练模型基础上做指令精调,以对齐指令意图\n",
245
+ "\n",
246
+ "| Stage 2: Supervised Fine-tuning | [supervised_finetuning.py](https://github.com/shibing624/MedicalGPT/blob/main/supervised_finetuning.py) | [run_sft.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_sft.sh) |"
247
+ ],
248
+ "metadata": {
249
+ "collapsed": false
250
+ }
251
+ },
252
+ {
253
+ "cell_type": "markdown",
254
+ "source": [
255
+ "#### 说明:\n",
256
+ "以下 notebook/colab 代码为了快速验证训练代码可用,我们使用了小size的生成模型和小样本数据集,实际使用时,需要使用更大的模型和数据集,以获得更好的效果。\n",
257
+ "\n",
258
+ "1. 生成模型:使用的是Bloom的`bigscience/bloomz-560m` 或者 Stage1得到的预训练模型\n",
259
+ "2. 数据集:SFT阶段使用的是使用的是Belle的1千条抽样数据,位于`data/finetune`文件夹"
260
+ ],
261
+ "metadata": {
262
+ "collapsed": false
263
+ }
264
+ },
265
+ {
266
+ "cell_type": "markdown",
267
+ "source": [
268
+ "## Stage2 咱们开始吧\n",
269
+ "\n",
270
+ "训练步骤如下:\n",
271
+ "\n",
272
+ "1. 确认训练集\n",
273
+ "2. 执行训练脚本\n",
274
+ "\n",
275
+ "训练脚本的执行逻辑如下:\n",
276
+ "1. 导入依赖包\n",
277
+ "2. 设置参数\n",
278
+ "3. 定义各函数并加载训练集\n",
279
+ "4. 加载模型和tokenizer\n",
280
+ "5. 开始训练并评估\n",
281
+ "6. 查看训练结果"
282
+ ],
283
+ "metadata": {
284
+ "collapsed": false
285
+ }
286
+ },
287
+ {
288
+ "cell_type": "code",
289
+ "execution_count": null,
290
+ "outputs": [],
291
+ "source": [
292
+ "%ls ./data/finetune"
293
+ ],
294
+ "metadata": {
295
+ "collapsed": false,
296
+ "ExecuteTime": {
297
+ "start_time": "2023-06-15T13:58:38.778132Z",
298
+ "end_time": "2023-06-15T13:58:38.966506Z"
299
+ }
300
+ }
301
+ },
302
+ {
303
+ "cell_type": "code",
304
+ "execution_count": null,
305
+ "outputs": [],
306
+ "source": [
307
+ "!python supervised_finetuning.py \\\n",
308
+ " --model_type bloom \\\n",
309
+ " --model_name_or_path merged-pt \\\n",
310
+ " --train_file_dir ./data/finetune \\\n",
311
+ " --validation_file_dir ./data/finetune \\\n",
312
+ " --per_device_train_batch_size 4 \\\n",
313
+ " --per_device_eval_batch_size 4 \\\n",
314
+ " --do_train \\\n",
315
+ " --do_eval \\\n",
316
+ " --use_peft True \\\n",
317
+ " --fp16 \\\n",
318
+ " --max_train_samples 1000 \\\n",
319
+ " --max_eval_samples 10 \\\n",
320
+ " --num_train_epochs 1 \\\n",
321
+ " --learning_rate 2e-5 \\\n",
322
+ " --warmup_ratio 0.05 \\\n",
323
+ " --weight_decay 0.05 \\\n",
324
+ " --logging_strategy steps \\\n",
325
+ " --logging_steps 10 \\\n",
326
+ " --eval_steps 50 \\\n",
327
+ " --evaluation_strategy steps \\\n",
328
+ " --save_steps 500 \\\n",
329
+ " --save_strategy steps \\\n",
330
+ " --save_total_limit 3 \\\n",
331
+ " --gradient_accumulation_steps 1 \\\n",
332
+ " --preprocessing_num_workers 1 \\\n",
333
+ " --output_dir outputs-sft-v1 \\\n",
334
+ " --overwrite_output_dir \\\n",
335
+ " --ddp_timeout 30000 \\\n",
336
+ " --logging_first_step True \\\n",
337
+ " --target_modules all \\\n",
338
+ " --lora_rank 8 \\\n",
339
+ " --lora_alpha 16 \\\n",
340
+ " --lora_dropout 0.05 \\\n",
341
+ " --torch_dtype float16 \\\n",
342
+ " --device_map auto \\\n",
343
+ " --report_to tensorboard \\\n",
344
+ " --ddp_find_unused_parameters False \\\n",
345
+ " --gradient_checkpointing True"
346
+ ],
347
+ "metadata": {
348
+ "collapsed": false
349
+ }
350
+ },
351
+ {
352
+ "cell_type": "code",
353
+ "execution_count": null,
354
+ "outputs": [],
355
+ "source": [
356
+ "%ls -lh outputs-sft-v1"
357
+ ],
358
+ "metadata": {
359
+ "collapsed": false
360
+ }
361
+ },
362
+ {
363
+ "cell_type": "markdown",
364
+ "source": [
365
+ "模型训练结果:\n",
366
+ "- 使用lora训练模型,则保存的lora权重是`adapter_model.bin`, lora配置文件是`adapter_config.json`,合并到base model的方法见`merge_peft_adapter.py`\n",
367
+ "- 日志保存在`output_dir/runs`目录下,可以使用tensorboard查看,启动tensorboard方式如下:`tensorboard --logdir output_dir/runs --host 0.0.0.0 --port 8009`"
368
+ ],
369
+ "metadata": {
370
+ "collapsed": false
371
+ }
372
+ },
373
+ {
374
+ "cell_type": "markdown",
375
+ "source": [
376
+ "lora模型权重合并到base model,合并后的模型保存在`--output_dir`目录下,合并方法如下:"
377
+ ],
378
+ "metadata": {
379
+ "collapsed": false
380
+ }
381
+ },
382
+ {
383
+ "cell_type": "code",
384
+ "execution_count": null,
385
+ "outputs": [],
386
+ "source": [
387
+ "!python merge_peft_adapter.py --model_type bloom \\\n",
388
+ " --base_model_name_or_path merged-pt --peft_model_path outputs-sft-v1 --output_dir merged-sft/"
389
+ ],
390
+ "metadata": {
391
+ "collapsed": false
392
+ }
393
+ },
394
+ {
395
+ "cell_type": "code",
396
+ "execution_count": null,
397
+ "outputs": [],
398
+ "source": [
399
+ "%ls -lh merged-sft/"
400
+ ],
401
+ "metadata": {
402
+ "collapsed": false
403
+ }
404
+ },
405
+ {
406
+ "cell_type": "code",
407
+ "execution_count": null,
408
+ "outputs": [],
409
+ "source": [
410
+ "%cat merged-sft/config.json"
411
+ ],
412
+ "metadata": {
413
+ "collapsed": false
414
+ }
415
+ },
416
+ {
417
+ "cell_type": "markdown",
418
+ "source": [
419
+ "Stage2 SFT训练完成。"
420
+ ],
421
+ "metadata": {
422
+ "collapsed": false
423
+ }
424
+ },
425
+ {
426
+ "cell_type": "code",
427
+ "execution_count": null,
428
+ "outputs": [],
429
+ "source": [],
430
+ "metadata": {
431
+ "collapsed": false,
432
+ "ExecuteTime": {
433
+ "start_time": "2023-06-15T14:07:40.731186Z",
434
+ "end_time": "2023-06-15T14:07:40.752635Z"
435
+ }
436
+ }
437
+ },
438
+ {
439
+ "cell_type": "markdown",
440
+ "source": [
441
+ "# Stage 3: Reward Modeling\n",
442
+ "\n",
443
+ "第三阶段:RM(Reward Model)奖励模型建模,构造人类偏好排序数据集,训练奖励模型,用来对齐人类偏好,主要是\"HHH\"原则,具体是\"helpful, honest, harmless\"\n",
444
+ "\n",
445
+ "| Stage 3: Reward Modeling | [reward_modeling.py](https://github.com/shibing624/MedicalGPT/blob/main/reward_modeling.py) | [run_rm.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_rm.sh) |"
446
+ ],
447
+ "metadata": {
448
+ "collapsed": false
449
+ }
450
+ },
451
+ {
452
+ "cell_type": "markdown",
453
+ "source": [
454
+ "#### 说明:\n",
455
+ "以下 notebook/colab 代码为了快速验证训练代码可用,我们使用了小size的生成模型和小样本数据集,实际使用时,需要使用更大的模型和数据集,以获得更好的效果。\n",
456
+ "\n",
457
+ "1. 生成模型:使用的是Bloom的`bigscience/bloomz-560m` 或者 Stage2得到的SFT模型\n",
458
+ "2. 数据集:RM阶段使用的是医疗reward数据,抽样了500条,位于`data/reward`文件夹"
459
+ ],
460
+ "metadata": {
461
+ "collapsed": false
462
+ }
463
+ },
464
+ {
465
+ "cell_type": "markdown",
466
+ "source": [
467
+ "## Stage3 咱们开始吧\n",
468
+ "\n",
469
+ "训练步骤如下:\n",
470
+ "\n",
471
+ "1. 确认训练集\n",
472
+ "2. 执行训练脚本\n",
473
+ "\n",
474
+ "训练脚本的执行逻辑如下:\n",
475
+ "1. 导入依赖包\n",
476
+ "2. 设置参数\n",
477
+ "3. 定义各函数并加载训练集\n",
478
+ "4. 加载模型和tokenizer\n",
479
+ "5. 开始训练并评估\n",
480
+ "6. 查看训练结果"
481
+ ],
482
+ "metadata": {
483
+ "collapsed": false
484
+ }
485
+ },
486
+ {
487
+ "cell_type": "code",
488
+ "execution_count": null,
489
+ "outputs": [],
490
+ "source": [
491
+ "%ls ./data/reward/"
492
+ ],
493
+ "metadata": {
494
+ "collapsed": false
495
+ }
496
+ },
497
+ {
498
+ "cell_type": "code",
499
+ "execution_count": null,
500
+ "outputs": [],
501
+ "source": [
502
+ "!python reward_modeling.py \\\n",
503
+ " --model_type bloom \\\n",
504
+ " --model_name_or_path merged-sft \\\n",
505
+ " --train_file_dir ./data/reward \\\n",
506
+ " --validation_file_dir ./data/reward \\\n",
507
+ " --per_device_train_batch_size 3 \\\n",
508
+ " --per_device_eval_batch_size 1 \\\n",
509
+ " --do_train \\\n",
510
+ " --use_peft True \\\n",
511
+ " --seed 42 \\\n",
512
+ " --max_train_samples 1000 \\\n",
513
+ " --max_eval_samples 10 \\\n",
514
+ " --num_train_epochs 1 \\\n",
515
+ " --learning_rate 2e-5 \\\n",
516
+ " --warmup_ratio 0.05 \\\n",
517
+ " --weight_decay 0.001 \\\n",
518
+ " --logging_strategy steps \\\n",
519
+ " --logging_steps 10 \\\n",
520
+ " --eval_steps 50 \\\n",
521
+ " --evaluation_strategy steps \\\n",
522
+ " --save_steps 500 \\\n",
523
+ " --save_strategy steps \\\n",
524
+ " --save_total_limit 3 \\\n",
525
+ " --max_source_length 256 \\\n",
526
+ " --max_target_length 256 \\\n",
527
+ " --output_dir outputs-rm-v1 \\\n",
528
+ " --overwrite_output_dir \\\n",
529
+ " --ddp_timeout 30000 \\\n",
530
+ " --logging_first_step True \\\n",
531
+ " --target_modules all \\\n",
532
+ " --lora_rank 8 \\\n",
533
+ " --lora_alpha 16 \\\n",
534
+ " --lora_dropout 0.05 \\\n",
535
+ " --torch_dtype float32 \\\n",
536
+ " --device_map auto \\\n",
537
+ " --report_to tensorboard \\\n",
538
+ " --ddp_find_unused_parameters False \\\n",
539
+ " --remove_unused_columns False \\\n",
540
+ " --gradient_checkpointing True"
541
+ ],
542
+ "metadata": {
543
+ "collapsed": false
544
+ }
545
+ },
546
+ {
547
+ "cell_type": "code",
548
+ "execution_count": null,
549
+ "outputs": [],
550
+ "source": [
551
+ "%ls -lh outputs-rm-v1"
552
+ ],
553
+ "metadata": {
554
+ "collapsed": false
555
+ }
556
+ },
557
+ {
558
+ "cell_type": "markdown",
559
+ "source": [
560
+ "模型训练结果:\n",
561
+ "- 使用lora训练模型,则保存的lora权重是`adapter_model.bin`, lora配置文件是`adapter_config.json`,合并到base model的方法见`merge_peft_adapter.py`\n",
562
+ "- 日志保存在`output_dir/runs`目录下,可以使用tensorboard查看,启动tensorboard方式如下:`tensorboard --logdir output_dir/runs --host 0.0.0.0 --port 8009`"
563
+ ],
564
+ "metadata": {
565
+ "collapsed": false
566
+ }
567
+ },
568
+ {
569
+ "cell_type": "markdown",
570
+ "source": [
571
+ "lora模型权重合并到base model,合并后的模型保存在`--output_dir`目录下,合并方法如下:"
572
+ ],
573
+ "metadata": {
574
+ "collapsed": false
575
+ }
576
+ },
577
+ {
578
+ "cell_type": "code",
579
+ "execution_count": null,
580
+ "outputs": [],
581
+ "source": [
582
+ "!python merge_peft_adapter.py --model_type bloom \\\n",
583
+ " --base_model_name_or_path merged-sft --peft_model_path outputs-rm-v1 --output_dir merged-rm/"
584
+ ],
585
+ "metadata": {
586
+ "collapsed": false
587
+ }
588
+ },
589
+ {
590
+ "cell_type": "code",
591
+ "execution_count": null,
592
+ "outputs": [],
593
+ "source": [
594
+ "%ls -lh merged-rm/"
595
+ ],
596
+ "metadata": {
597
+ "collapsed": false
598
+ }
599
+ },
600
+ {
601
+ "cell_type": "code",
602
+ "execution_count": null,
603
+ "outputs": [],
604
+ "source": [
605
+ "%cat merged-rm/config.json"
606
+ ],
607
+ "metadata": {
608
+ "collapsed": false
609
+ }
610
+ },
611
+ {
612
+ "cell_type": "markdown",
613
+ "source": [
614
+ "Stage3 奖励建模第一次训练完成。"
615
+ ],
616
+ "metadata": {
617
+ "collapsed": false
618
+ }
619
+ },
620
+ {
621
+ "cell_type": "code",
622
+ "execution_count": null,
623
+ "outputs": [],
624
+ "source": [],
625
+ "metadata": {
626
+ "collapsed": false,
627
+ "ExecuteTime": {
628
+ "start_time": "2023-06-15T14:12:09.464881Z",
629
+ "end_time": "2023-06-15T14:12:09.472414Z"
630
+ }
631
+ }
632
+ },
633
+ {
634
+ "cell_type": "markdown",
635
+ "source": [
636
+ "# Stage 4: Reinforcement Learning Training\n",
637
+ "\n",
638
+ "第四阶段:RL(Reinforcement Learning)基于人类反馈的强化学习(RLHF),用奖励模型来训练SFT模型,生成模型使用奖励或惩罚来更新其策略,以便生成更高质量、更符合人类偏好的文本\n",
639
+ "\n",
640
+ "| Stage 4: Reinforcement Learning | [rl_training.py](https://github.com/shibing624/MedicalGPT/blob/main/rl_training.py) | [run_rl.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_rl.sh) |\n"
641
+ ],
642
+ "metadata": {
643
+ "collapsed": false
644
+ }
645
+ },
646
+ {
647
+ "cell_type": "markdown",
648
+ "source": [
649
+ "#### 说明:\n",
650
+ "以下 notebook/colab 代码为了快速验证训练代码可用,我们使用了小size的生成模型、奖励模型和小样本数据集,实际使用时,需要使用更大的模型和数据集,以获得更好的效果。\n",
651
+ "\n",
652
+ "1. 生成模型:使用的是Bloom的`bigscience/bloomz-560m` 或者 Stage2得到的SFT模型\n",
653
+ "2. 奖励模型:使用的是`OpenAssistant/reward-model-deberta-v3-large-v2` 或者 Stage3得到的BERT类或者GPT类奖励模型\n",
654
+ "3. 数据集:RL阶段的数据可以复用SFT的数据集,使用的是Belle的1千条抽样数据,位于`data/finetune`文件夹"
655
+ ],
656
+ "metadata": {
657
+ "collapsed": false
658
+ }
659
+ },
660
+ {
661
+ "cell_type": "markdown",
662
+ "source": [
663
+ "## Stage4 咱们开始吧\n",
664
+ "\n",
665
+ "训练步骤如下:\n",
666
+ "\n",
667
+ "1. 确认训练集\n",
668
+ "2. 执行训练脚本\n",
669
+ "\n",
670
+ "训练脚本的执行逻辑如下:\n",
671
+ "1. 导入依赖包\n",
672
+ "2. 设置参数\n",
673
+ "3. 定义各函数并加载训练集\n",
674
+ "4. 加载生成模型和tokenizer,加载奖励模型和其tokenizer\n",
675
+ "5. 开始训练并评估\n",
676
+ "6. 查看训练结果\n",
677
+ "\n",
678
+ "以下参数可以根据你的GPU实际情况修改,当前参数是根据Colab的T4单卡GPU(16GB显存)配置的。"
679
+ ],
680
+ "metadata": {
681
+ "collapsed": false
682
+ }
683
+ },
684
+ {
685
+ "cell_type": "code",
686
+ "execution_count": null,
687
+ "outputs": [],
688
+ "source": [
689
+ "%ls ./data/finetune/"
690
+ ],
691
+ "metadata": {
692
+ "collapsed": false
693
+ }
694
+ },
695
+ {
696
+ "cell_type": "code",
697
+ "execution_count": null,
698
+ "outputs": [],
699
+ "source": [
700
+ "!python rl_training.py \\\n",
701
+ " --model_type bloom \\\n",
702
+ " --model_name_or_path merged-sft \\\n",
703
+ " --reward_model_name_or_path merged-rm \\\n",
704
+ " --torch_dtype float16 \\\n",
705
+ " --device_map auto \\\n",
706
+ " --train_file_dir ./data/finetune \\\n",
707
+ " --validation_file_dir ./data/finetune \\\n",
708
+ " --batch_size 4 \\\n",
709
+ " --max_source_length 256 \\\n",
710
+ " --max_target_length 256 \\\n",
711
+ " --max_train_samples 1000 \\\n",
712
+ " --use_peft True \\\n",
713
+ " --lora_rank 8 \\\n",
714
+ " --lora_alpha 16 \\\n",
715
+ " --lora_dropout 0.05 \\\n",
716
+ " --do_train \\\n",
717
+ " --max_steps 64 \\\n",
718
+ " --learning_rate 1e-5 \\\n",
719
+ " --save_steps 50 \\\n",
720
+ " --output_dir outputs-rl-v1 \\\n",
721
+ " --early_stopping True \\\n",
722
+ " --target_kl 0.1 \\\n",
723
+ " --reward_baseline 0.0"
724
+ ],
725
+ "metadata": {
726
+ "collapsed": false
727
+ }
728
+ },
729
+ {
730
+ "cell_type": "code",
731
+ "execution_count": null,
732
+ "outputs": [],
733
+ "source": [
734
+ "%ls -lh outputs-rl-v1"
735
+ ],
736
+ "metadata": {
737
+ "collapsed": false
738
+ }
739
+ },
740
+ {
741
+ "cell_type": "markdown",
742
+ "source": [
743
+ "模型训练结果:\n",
744
+ "- 使用lora训练模型,则保存的lora权重是`adapter_model.bin`, lora配置文件是`adapter_config.json`,合并到base model的方法见`merge_peft_adapter.py`\n",
745
+ "- 日志保存在`output_dir/trl`目录下,可以使用tensorboard查看,启动tensorboard方式如下:`tensorboard --logdir output_dir/trl --host 0.0.0.0 --port 8009`"
746
+ ],
747
+ "metadata": {
748
+ "collapsed": false
749
+ }
750
+ },
751
+ {
752
+ "cell_type": "markdown",
753
+ "source": [
754
+ "lora模型权重合并到base model,合并后的模型保存在`--output_dir`目录下,合并方法如下:"
755
+ ],
756
+ "metadata": {
757
+ "collapsed": false
758
+ }
759
+ },
760
+ {
761
+ "cell_type": "code",
762
+ "execution_count": null,
763
+ "outputs": [],
764
+ "source": [
765
+ "!python merge_peft_adapter.py --model_type bloom \\\n",
766
+ " --base_model_name_or_path merged-sft --peft_model_path outputs-rl-v1 --output_dir merged-rl/"
767
+ ],
768
+ "metadata": {
769
+ "collapsed": false
770
+ }
771
+ },
772
+ {
773
+ "cell_type": "code",
774
+ "execution_count": null,
775
+ "outputs": [],
776
+ "source": [
777
+ "%ls -lh merged-rl/"
778
+ ],
779
+ "metadata": {
780
+ "collapsed": false
781
+ }
782
+ },
783
+ {
784
+ "cell_type": "code",
785
+ "execution_count": null,
786
+ "outputs": [],
787
+ "source": [
788
+ "%cat merged-rl/config.json"
789
+ ],
790
+ "metadata": {
791
+ "collapsed": false
792
+ }
793
+ },
794
+ {
795
+ "cell_type": "markdown",
796
+ "source": [
797
+ "Stage4 RL第一次训练完成。\n",
798
+ "\n",
799
+ "**至此一个完整的4阶段训练流程演示完成。**"
800
+ ],
801
+ "metadata": {
802
+ "collapsed": false
803
+ }
804
+ },
805
+ {
806
+ "cell_type": "markdown",
807
+ "source": [
808
+ "实际操作中Stage3和Stage4可以反复多次,直到RL得到的最后模型满足评估要求。\n",
809
+ "\n",
810
+ "RLHF过程可以把SFT模型当成一个初始化模型,RM模型当做指导老师,使用RL(PPO)调教SFT模型生成指导老师最满意的结果,如果小学老师满意了,我们就再训练一个中学老师,继续指导,中学老师满意了,就训练一个大学老师,这样不断迭代,使得生成模型的质量达到甚至超过人工撰写的天花板。\n",
811
+ "\n",
812
+ "RLHF训练不易,此项目提供给大家一种实现的方法和参考,希望抛砖引玉,共同促进中文开源LLM发展。"
813
+ ],
814
+ "metadata": {
815
+ "collapsed": false
816
+ }
817
+ },
818
+ {
819
+ "cell_type": "markdown",
820
+ "source": [],
821
+ "metadata": {
822
+ "collapsed": false
823
+ }
824
+ },
825
+ {
826
+ "cell_type": "code",
827
+ "execution_count": null,
828
+ "outputs": [],
829
+ "source": [],
830
+ "metadata": {
831
+ "collapsed": false,
832
+ "ExecuteTime": {
833
+ "start_time": "2023-06-26T12:34:29.620609Z",
834
+ "end_time": "2023-06-26T12:34:29.658428Z"
835
+ }
836
+ }
837
+ },
838
+ {
839
+ "cell_type": "markdown",
840
+ "source": [
841
+ "# Test"
842
+ ],
843
+ "metadata": {
844
+ "collapsed": false
845
+ }
846
+ },
847
+ {
848
+ "cell_type": "markdown",
849
+ "source": [],
850
+ "metadata": {
851
+ "collapsed": false
852
+ }
853
+ },
854
+ {
855
+ "cell_type": "code",
856
+ "execution_count": null,
857
+ "outputs": [],
858
+ "source": [
859
+ "!python inference.py --model_type bloom --base_model merged-rl --interactive"
860
+ ],
861
+ "metadata": {
862
+ "collapsed": false,
863
+ "ExecuteTime": {
864
+ "start_time": "2023-06-26T12:34:47.802087Z",
865
+ "end_time": "2023-06-26T12:35:00.864463Z"
866
+ }
867
+ }
868
+ },
869
+ {
870
+ "cell_type": "markdown",
871
+ "source": [
872
+ "Input:介绍下南京\n",
873
+ "Response: 南京市位于江苏省西南部,是全国��批历史文化名城、国家中心城市和自由贸易试验区。\n",
874
+ "\n",
875
+ "完。\n"
876
+ ],
877
+ "metadata": {
878
+ "collapsed": false
879
+ }
880
+ },
881
+ {
882
+ "cell_type": "code",
883
+ "execution_count": null,
884
+ "outputs": [],
885
+ "source": [],
886
+ "metadata": {
887
+ "collapsed": false
888
+ }
889
+ }
890
+ ],
891
+ "metadata": {
892
+ "kernelspec": {
893
+ "name": "python3",
894
+ "language": "python",
895
+ "display_name": "Python 3"
896
+ },
897
+ "language_info": {
898
+ "codemirror_mode": {
899
+ "name": "ipython",
900
+ "version": 3
901
+ },
902
+ "file_extension": ".py",
903
+ "mimetype": "text/x-python",
904
+ "name": "python",
905
+ "nbconvert_exporter": "python",
906
+ "pygments_lexer": "ipython3",
907
+ "version": "3.8.13"
908
+ },
909
+ "vscode": {
910
+ "interpreter": {
911
+ "hash": "f34eed0bebedfc4b6ee51ced43d2c030fe3b92f13c149d072205ca200a67b1ec"
912
+ }
913
+ }
914
+ },
915
+ "nbformat": 4,
916
+ "nbformat_minor": 4
917
+ }
supervised_finetuning.py ADDED
@@ -0,0 +1,927 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2023 XuMing([email protected]) and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
17
+
18
+ part of this code is adapted from https://github.com/shibing624/textgen
19
+ """
20
+ import math
21
+ import os
22
+ from dataclasses import dataclass, field
23
+ from glob import glob
24
+ from typing import List, Optional, Dict, Sequence
25
+
26
+ import torch
27
+ from datasets import load_dataset
28
+ from loguru import logger
29
+ from peft import LoraConfig, TaskType, get_peft_model, PeftModel, prepare_model_for_int8_training
30
+ from transformers import (
31
+ AutoConfig,
32
+ BloomForCausalLM,
33
+ AutoModel,
34
+ AutoModelForCausalLM,
35
+ LlamaTokenizer,
36
+ LlamaForCausalLM,
37
+ BloomTokenizerFast,
38
+ AutoTokenizer,
39
+ HfArgumentParser,
40
+ Trainer,
41
+ TrainingArguments,
42
+ set_seed,
43
+ BitsAndBytesConfig,
44
+ DataCollatorForSeq2Seq,
45
+ )
46
+ from transformers.deepspeed import is_deepspeed_zero3_enabled
47
+ from transformers.trainer import TRAINING_ARGS_NAME
48
+ from transformers.trainer_pt_utils import LabelSmoother
49
+
50
+ MODEL_CLASSES = {
51
+ "bloom": (AutoConfig, BloomForCausalLM, BloomTokenizerFast),
52
+ "chatglm": (AutoConfig, AutoModel, AutoTokenizer),
53
+ "llama": (AutoConfig, LlamaForCausalLM, LlamaTokenizer),
54
+ "baichuan": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
55
+ "auto": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
56
+ }
57
+
58
+
59
+ @dataclass
60
+ class ModelArguments:
61
+ """
62
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
63
+ """
64
+
65
+ model_type: str = field(
66
+ default=None,
67
+ metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())}
68
+ )
69
+ model_name_or_path: Optional[str] = field(
70
+ default=None,
71
+ metadata={
72
+ "help": (
73
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
74
+ )
75
+ },
76
+ )
77
+ tokenizer_name_or_path: Optional[str] = field(
78
+ default=None,
79
+ metadata={
80
+ "help": (
81
+ "The tokenizer for weights initialization.Don't set if you want to train a model from scratch."
82
+ )
83
+ },
84
+ )
85
+ load_in_8bit: bool = field(default=False, metadata={"help": "Whether to load the model in 8bit mode or not."})
86
+ cache_dir: Optional[str] = field(
87
+ default=None,
88
+ metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
89
+ )
90
+ use_fast_tokenizer: bool = field(
91
+ default=False,
92
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
93
+ )
94
+ torch_dtype: Optional[str] = field(
95
+ default="float16",
96
+ metadata={
97
+ "help": (
98
+ "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
99
+ "dtype will be automatically derived from the model's weights."
100
+ ),
101
+ "choices": ["auto", "bfloat16", "float16", "float32"],
102
+ },
103
+ )
104
+ device_map: Optional[str] = field(
105
+ default="auto",
106
+ metadata={"help": "Device to map model to. If `auto` is passed, the device will be selected automatically. "},
107
+ )
108
+ trust_remote_code: bool = field(
109
+ default=True,
110
+ metadata={"help": "Whether to trust remote code when loading a model from a remote checkpoint."},
111
+ )
112
+
113
+ def __post_init__(self):
114
+ if self.model_type is None:
115
+ raise ValueError(
116
+ "You must specify a valid model_type to run training. Available model types are " + ", ".join(
117
+ MODEL_CLASSES.keys()))
118
+ if self.model_name_or_path is None:
119
+ raise ValueError("You must specify a valid model_name_or_path to run training.")
120
+
121
+
122
+ @dataclass
123
+ class DataTrainingArguments:
124
+ """
125
+ Arguments pertaining to what data we are going to input our model for training and eval.
126
+ """
127
+
128
+ dataset_name: Optional[str] = field(
129
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
130
+ )
131
+ dataset_config_name: Optional[str] = field(
132
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
133
+ )
134
+ train_file_dir: Optional[str] = field(default=None, metadata={"help": "The train jsonl data file folder."})
135
+ validation_file_dir: Optional[str] = field(default=None, metadata={"help": "The evaluation jsonl file folder."})
136
+ template_name: Optional[str] = field(default="vicuna", metadata={"help": "The prompt template name."})
137
+ max_train_samples: Optional[int] = field(
138
+ default=None,
139
+ metadata={
140
+ "help": (
141
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
142
+ "value if set."
143
+ )
144
+ },
145
+ )
146
+ max_eval_samples: Optional[int] = field(
147
+ default=None,
148
+ metadata={
149
+ "help": (
150
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
151
+ "value if set."
152
+ )
153
+ },
154
+ )
155
+ max_source_length: Optional[int] = field(default=256, metadata={"help": "Max length of prompt input text"})
156
+ max_target_length: Optional[int] = field(default=256, metadata={"help": "Max length of output text"})
157
+ ignore_pad_token_for_loss: bool = field(
158
+ default=True,
159
+ metadata={"help": "If only pad tokens should be ignored. This assumes that `config.pad_token_id` is defined."},
160
+ )
161
+ overwrite_cache: bool = field(
162
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
163
+ )
164
+ validation_split_percentage: Optional[int] = field(
165
+ default=1,
166
+ metadata={
167
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
168
+ },
169
+ )
170
+ preprocessing_num_workers: Optional[int] = field(
171
+ default=None,
172
+ metadata={"help": "The number of processes to use for the preprocessing."},
173
+ )
174
+
175
+ def __post_init__(self):
176
+ if self.max_train_samples is not None and 0 < self.max_train_samples <= 1000:
177
+ logger.warning("You may set max_train_samples = -1 to run all samples in production.")
178
+ if self.max_source_length < 30:
179
+ raise ValueError("You must specify a valid max_source_length >= 30 to run training.")
180
+
181
+
182
+ @dataclass
183
+ class PeftArguments(TrainingArguments):
184
+ use_peft: bool = field(default=True, metadata={"help": "Whether to use peft"})
185
+ target_modules: Optional[str] = field(default="all")
186
+ lora_rank: Optional[int] = field(default=8)
187
+ lora_dropout: Optional[float] = field(default=0.05)
188
+ lora_alpha: Optional[float] = field(default=32.0)
189
+ modules_to_save: Optional[str] = field(default=None)
190
+ peft_path: Optional[str] = field(default=None, metadata={"help": "The path to the peft model"})
191
+ qlora: bool = field(default=False, metadata={"help": "Whether to use qlora"})
192
+
193
+
194
+ class CastOutputToFloat(torch.nn.Sequential):
195
+ """Cast the output of the model to float"""
196
+
197
+ def forward(self, x):
198
+ return super().forward(x).to(torch.float32)
199
+
200
+
201
+ @dataclass
202
+ class Conversation:
203
+ """A class that manages prompt templates and keeps all conversation history."""
204
+
205
+ # The name of this template
206
+ name: str
207
+ # The system prompt
208
+ system_prompt: str
209
+ # All messages. format: list of [question, answer]
210
+ messages: Optional[List[Sequence[str]]]
211
+ # The roles of the speakers
212
+ roles: Optional[Sequence[str]]
213
+ # Conversation prompt
214
+ prompt: str
215
+ # Separator
216
+ sep: str
217
+ # Stop token, default is tokenizer.eos_token
218
+ stop_str: Optional[str] = "</s>"
219
+
220
+ def get_prompt(
221
+ self,
222
+ messages: Optional[List[Sequence[str]]] = None,
223
+ system_prompt: Optional[str] = ""
224
+ ) -> str:
225
+ """
226
+ Returns a string containing prompt without response.
227
+ """
228
+ return "".join(self._format_example(messages, system_prompt))
229
+
230
+ def get_dialog(
231
+ self,
232
+ messages: Optional[List[Sequence[str]]] = None,
233
+ system_prompt: Optional[str] = ""
234
+ ) -> List[str]:
235
+ """
236
+ Returns a list containing 2 * n elements where the 2k-th is a query and the (2k+1)-th is a response.
237
+ """
238
+ return self._format_example(messages, system_prompt)
239
+
240
+ def _format_example(
241
+ self,
242
+ messages: Optional[List[Sequence[str]]] = None,
243
+ system_prompt: Optional[str] = ""
244
+ ) -> List[str]:
245
+ system_prompt = system_prompt or self.system_prompt
246
+ system_prompt = system_prompt + self.sep if system_prompt else "" # add separator for non-empty system prompt
247
+ messages = messages or self.messages
248
+ convs = []
249
+ for turn_idx, [user_query, bot_resp] in enumerate(messages):
250
+ if turn_idx == 0:
251
+ convs.append(system_prompt + self.prompt.format(query=user_query))
252
+ convs.append(bot_resp)
253
+ else:
254
+ convs.append(self.sep + self.prompt.format(query=user_query))
255
+ convs.append(bot_resp)
256
+ return convs
257
+
258
+ def append_message(self, query: str, answer: str):
259
+ """Append a new message."""
260
+ self.messages.append([query, answer])
261
+
262
+
263
+ # A global registry for all conversation templates
264
+ conv_templates: Dict[str, Conversation] = {}
265
+
266
+
267
+ def register_conv_template(template: Conversation):
268
+ """Register a new conversation template."""
269
+ conv_templates[template.name] = template
270
+
271
+
272
+ """Vicuna v1.1 template
273
+ Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
274
+ https://huggingface.co/lmsys/vicuna-13b-delta-v1.1
275
+ """
276
+ register_conv_template(
277
+ Conversation(
278
+ name="vicuna",
279
+ system_prompt="A chat between a curious user and an artificial intelligence assistant. "
280
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
281
+ messages=[],
282
+ roles=("USER", "ASSISTANT"),
283
+ prompt="USER: {query} ASSISTANT: ",
284
+ sep="</s>",
285
+ )
286
+ )
287
+
288
+ """Alpaca template"""
289
+ register_conv_template(
290
+ Conversation(
291
+ name="alpaca",
292
+ system_prompt="Below is an instruction that describes a task. "
293
+ "Write a response that appropriately completes the request.",
294
+ messages=[],
295
+ roles=("### Instruction", "### Response"),
296
+ prompt="### Instruction:\n{query}\n\n### Response:\n",
297
+ sep="\n\n",
298
+ )
299
+ )
300
+
301
+ """Baichuan-13B-Chat template
302
+ source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/f5f47be2adbbdceb784f334d6fa1ca2c73e65097/modeling_baichuan.py#L507
303
+ Support: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
304
+ """
305
+ register_conv_template(
306
+ Conversation(
307
+ name="baichuan-chat",
308
+ system_prompt="",
309
+ messages=[],
310
+ roles=("<reserved_102>", "<reserved_103>"),
311
+ prompt=" <reserved_102> {query} <reserved_103> ",
312
+ sep="</s>",
313
+ )
314
+ )
315
+
316
+ """ziya template"""
317
+ register_conv_template(
318
+ Conversation(
319
+ name="ziya",
320
+ system_prompt="",
321
+ messages=[],
322
+ roles=("<human>", "<bot>"),
323
+ prompt="<human>:{query}\n<bot>:",
324
+ sep="\n",
325
+ )
326
+ )
327
+
328
+ """Linly template"""
329
+ register_conv_template(
330
+ Conversation(
331
+ name="linly",
332
+ system_prompt="",
333
+ messages=[],
334
+ roles=("User", "Bot"),
335
+ prompt="User: {query}\nBot: ",
336
+ sep="\n",
337
+ )
338
+ )
339
+
340
+ """ChatGLM1 template
341
+ source: https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py#L1307
342
+ """
343
+ register_conv_template(
344
+ Conversation(
345
+ name="chatglm",
346
+ system_prompt="",
347
+ messages=[],
348
+ roles=("问", "答"),
349
+ prompt="问:{query}\n答:",
350
+ sep="\n",
351
+ )
352
+ )
353
+
354
+ """ChatGLM2 template
355
+ source: https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py#L1007
356
+ """
357
+ register_conv_template(
358
+ # source:
359
+ Conversation(
360
+ name="chatglm2",
361
+ system_prompt="",
362
+ messages=[],
363
+ roles=("问", "答"),
364
+ prompt="问:{query}\n\n答:",
365
+ sep="\n\n",
366
+ )
367
+ )
368
+
369
+ """Phoenix template"""
370
+ register_conv_template(
371
+ Conversation(
372
+ name="phoenix",
373
+ system_prompt="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
374
+ messages=[],
375
+ roles=("Human", "Assistant"),
376
+ prompt="Human: <s>{query}</s>Assistant: ",
377
+ sep="</s>",
378
+ )
379
+ )
380
+
381
+ """belle template
382
+ Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
383
+ """
384
+ register_conv_template(
385
+ Conversation(
386
+ name="belle",
387
+ system_prompt="",
388
+ messages=[],
389
+ roles=("Human", "Belle"),
390
+ prompt="Human: {query}\n\nBelle: ",
391
+ sep="\n\n",
392
+ )
393
+ )
394
+
395
+ """aquila template
396
+ Supports: https://huggingface.co/qhduan/aquilachat-7b
397
+ """
398
+ register_conv_template(
399
+ Conversation(
400
+ name="aquila",
401
+ system_prompt="A chat between a curious human and an artificial intelligence assistant. "
402
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
403
+ messages=[],
404
+ roles=("Human", "Assistant"),
405
+ prompt="Human: {query}###Assistant: ",
406
+ sep="###",
407
+ )
408
+ )
409
+
410
+ """intern template
411
+ Supports: https://huggingface.co/internlm/internlm-chat-7b
412
+ """
413
+ register_conv_template(
414
+ Conversation(
415
+ name="intern",
416
+ system_prompt="",
417
+ messages=[],
418
+ roles=("<|User|>", "<|Bot|>"),
419
+ prompt="<|User|>:{query}<eoh>\n<|Bot|>:",
420
+ sep="<eoa>\n",
421
+ stop_str="<eoa>",
422
+ )
423
+ )
424
+
425
+ """StarChat template"""
426
+ register_conv_template(
427
+ Conversation(
428
+ name="starchat",
429
+ system_prompt="<system>\n",
430
+ messages=[],
431
+ roles=("<|user|>", "<|assistant|>"),
432
+ prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n",
433
+ sep="<|end|>\n",
434
+ stop_str="<|end|>",
435
+ )
436
+ )
437
+
438
+ """llama2 template
439
+ reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
440
+ """
441
+ register_conv_template(
442
+ Conversation(
443
+ name="llama2",
444
+ system_prompt="<<SYS>>\nYou are a helpful, respectful and honest assistant. "
445
+ "Always answer as helpfully as possible, while being safe. "
446
+ "Your answers should not include any harmful, unethical, racist, sexist, "
447
+ "toxic, dangerous, or illegal content. "
448
+ "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
449
+ "If a question does not make any sense, or is not factually coherent, "
450
+ "explain why instead of answering something not correct. "
451
+ "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
452
+ messages=[],
453
+ roles=("[INST]", "[/INST]"),
454
+ prompt=" [INST] {query} [/INST] ",
455
+ sep="</s>",
456
+ )
457
+ )
458
+
459
+ """llama2-zh template
460
+ Sources: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
461
+ Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
462
+ """
463
+ register_conv_template(
464
+ Conversation(
465
+ name="llama2-zh",
466
+ system_prompt="<<SYS>>\nYou are a helpful assistant. 你是一个乐于助人的助手。\n<</SYS>>\n\n",
467
+ messages=[],
468
+ roles=("[INST]", "[/INST]"),
469
+ prompt=" [INST] {query} [/INST] ",
470
+ sep="</s>",
471
+ )
472
+ )
473
+ """XVERSE template
474
+ Supports: https://huggingface.co/xverse/XVERSE-13B-Chat
475
+ """
476
+ register_conv_template(
477
+ Conversation(
478
+ name="xverse",
479
+ system_prompt="",
480
+ messages=[],
481
+ roles=("Human", "Assistant"),
482
+ prompt="Human: {query}\n\nAssistant: ",
483
+ sep="</s>",
484
+ )
485
+ )
486
+
487
+ """Qwen template
488
+ Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
489
+ chatml: https://xbot123.com/645a461b922f176d7cfdbc2d/
490
+ """
491
+ register_conv_template(
492
+ Conversation(
493
+ name="chatml",
494
+ system_prompt="You are a helpful assistant.",
495
+ messages=[],
496
+ roles=("user", "assistant"),
497
+ prompt="<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n",
498
+ sep="<|im_end|>\n",
499
+ stop_str="<|im_end|>",
500
+ )
501
+ )
502
+
503
+
504
+ def get_conv_template(name: str) -> Conversation:
505
+ """Get a conversation template."""
506
+ return conv_templates[name]
507
+
508
+
509
+ class SavePeftModelTrainer(Trainer):
510
+ """
511
+ Trainer for lora models
512
+ """
513
+
514
+ def save_model(self, output_dir=None, _internal_call=False):
515
+ """Save the LoRA model."""
516
+ os.makedirs(output_dir, exist_ok=True)
517
+ if self.args.local_rank in [-1, 0]:
518
+ torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
519
+ self.model.save_pretrained(output_dir)
520
+
521
+
522
+ def save_model(output_dir, model, tokenizer, args):
523
+ """Save the model and the tokenizer."""
524
+ os.makedirs(output_dir, exist_ok=True)
525
+
526
+ # Take care of distributed/parallel training
527
+ model_to_save = model.module if hasattr(model, "module") else model
528
+ if args.local_rank in [-1, 0]:
529
+ model_to_save.save_pretrained(output_dir)
530
+ tokenizer.save_pretrained(output_dir)
531
+ torch.save(args, os.path.join(output_dir, TRAINING_ARGS_NAME))
532
+
533
+
534
+ def print_trainable_parameters(model):
535
+ """
536
+ Prints the number of trainable parameters in the model.
537
+ """
538
+ trainable_params = 0
539
+ all_param = 0
540
+ for _, param in model.named_parameters():
541
+ all_param += param.numel()
542
+ if param.requires_grad:
543
+ trainable_params += param.numel()
544
+ print(
545
+ f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
546
+ )
547
+
548
+
549
+ def find_all_linear_names(peft_model, int4=False, int8=False):
550
+ """Find all linear layer names in the model. reference from qlora paper."""
551
+ cls = torch.nn.Linear
552
+ if int4 or int8:
553
+ import bitsandbytes as bnb
554
+ if int4:
555
+ cls = bnb.nn.Linear4bit
556
+ elif int8:
557
+ cls = bnb.nn.Linear8bitLt
558
+ lora_module_names = set()
559
+ for name, module in peft_model.named_modules():
560
+ if isinstance(module, cls):
561
+ # last layer is not add to lora_module_names
562
+ if 'lm_head' in name:
563
+ continue
564
+ if 'output_layer' in name:
565
+ continue
566
+ names = name.split('.')
567
+ lora_module_names.add(names[0] if len(names) == 1 else names[-1])
568
+ return sorted(lora_module_names)
569
+
570
+
571
+ def main():
572
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, PeftArguments))
573
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
574
+
575
+ logger.info(f"Model args: {model_args}")
576
+ logger.info(f"Data args: {data_args}")
577
+ logger.info(f"Training args: {training_args}")
578
+ logger.info(
579
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
580
+ + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
581
+ )
582
+
583
+ # Set seed before initializing model.
584
+ set_seed(training_args.seed)
585
+
586
+ if not model_args.model_type:
587
+ raise ValueError("Please specify a model_type, e.g. llama, chatglm, bloom, etc.")
588
+ config_class, model_class, tokenizer_class = MODEL_CLASSES[model_args.model_type]
589
+
590
+ # Load tokenizer
591
+ tokenizer_kwargs = {
592
+ "cache_dir": model_args.cache_dir,
593
+ "use_fast": model_args.use_fast_tokenizer,
594
+ "trust_remote_code": model_args.trust_remote_code,
595
+ }
596
+ tokenizer_name_or_path = model_args.tokenizer_name_or_path
597
+ if not tokenizer_name_or_path:
598
+ tokenizer_name_or_path = model_args.model_name_or_path
599
+ tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
600
+ prompt_template = get_conv_template(data_args.template_name)
601
+ if tokenizer.eos_token_id is None:
602
+ tokenizer.eos_token = prompt_template.stop_str # eos token is required for SFT
603
+ logger.info("Add eos token: {}".format(tokenizer.eos_token))
604
+ if tokenizer.pad_token_id is None:
605
+ if tokenizer.unk_token_id is not None:
606
+ tokenizer.pad_token = tokenizer.unk_token
607
+ else:
608
+ tokenizer.pad_token = tokenizer.eos_token
609
+ logger.info("Add pad token: {}".format(tokenizer.pad_token))
610
+
611
+ logger.debug(f"Tokenizer: {tokenizer}")
612
+ IGNORE_INDEX = LabelSmoother.ignore_index if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
613
+
614
+ # Get datasets
615
+ if data_args.dataset_name is not None:
616
+ # Downloading and loading a dataset from the hub.
617
+ raw_datasets = load_dataset(
618
+ data_args.dataset_name,
619
+ data_args.dataset_config_name,
620
+ cache_dir=model_args.cache_dir,
621
+ )
622
+ if "validation" not in raw_datasets.keys():
623
+ raw_datasets["validation"] = load_dataset(
624
+ data_args.dataset_name,
625
+ data_args.dataset_config_name,
626
+ split=f"train[:{data_args.validation_split_percentage}%]",
627
+ cache_dir=model_args.cache_dir,
628
+ )
629
+ raw_datasets["train"] = load_dataset(
630
+ data_args.dataset_name,
631
+ data_args.dataset_config_name,
632
+ split=f"train[{data_args.validation_split_percentage}%:]",
633
+ cache_dir=model_args.cache_dir,
634
+ )
635
+ else:
636
+ # Loading a dataset from local files.
637
+ data_files = {}
638
+ if data_args.train_file_dir is not None and os.path.exists(data_args.train_file_dir):
639
+ train_data_files = glob(f'{data_args.train_file_dir}/**/*.json', recursive=True) + glob(
640
+ f'{data_args.train_file_dir}/**/*.jsonl', recursive=True)
641
+ logger.info(f"train files: {train_data_files}")
642
+ data_files["train"] = train_data_files
643
+ if data_args.validation_file_dir is not None and os.path.exists(data_args.validation_file_dir):
644
+ eval_data_files = glob(f'{data_args.validation_file_dir}/**/*.json', recursive=True) + glob(
645
+ f'{data_args.validation_file_dir}/**/*.jsonl', recursive=True)
646
+ logger.info(f"eval files: {eval_data_files}")
647
+ data_files["validation"] = eval_data_files
648
+ raw_datasets = load_dataset(
649
+ 'json',
650
+ data_files=data_files,
651
+ cache_dir=model_args.cache_dir,
652
+ )
653
+ # If no validation data is there, validation_split_percentage will be used to divide the dataset.
654
+ if "validation" not in raw_datasets.keys():
655
+ raw_datasets["validation"] = load_dataset(
656
+ 'json',
657
+ data_files=data_files,
658
+ split=f"train[:{data_args.validation_split_percentage}%]",
659
+ cache_dir=model_args.cache_dir,
660
+ )
661
+ raw_datasets["train"] = load_dataset(
662
+ 'json',
663
+ data_files=data_files,
664
+ split=f"train[{data_args.validation_split_percentage}%:]",
665
+ cache_dir=model_args.cache_dir,
666
+ )
667
+ logger.info(f"Raw datasets: {raw_datasets}")
668
+
669
+ # Preprocessing the datasets
670
+ max_source_length = data_args.max_source_length
671
+ max_target_length = data_args.max_target_length
672
+ max_length = max_source_length + max_target_length
673
+
674
+ def preprocess_function(examples):
675
+ """
676
+ Preprocessing the datasets.
677
+ part of code modified from https://github.com/lm-sys/FastChat
678
+ """
679
+ input_ids_list = []
680
+ targets_list = []
681
+ roles = ["human", "gpt"]
682
+
683
+ def get_dialog(examples):
684
+ for i, source in enumerate(examples['conversations']):
685
+ if len(source) < 2:
686
+ continue
687
+ data_role = source[0].get("from", "")
688
+ if data_role not in roles or data_role != roles[0]:
689
+ # Skip the first one if it is not from human
690
+ source = source[1:]
691
+ if len(source) < 2:
692
+ continue
693
+ messages = []
694
+ for j, sentence in enumerate(source):
695
+ data_role = sentence.get("from", "")
696
+ if data_role not in roles:
697
+ logger.warning(f"unknown role: {data_role}, {i}. (ignored)")
698
+ break
699
+ if data_role == roles[j % 2]:
700
+ messages.append(sentence["value"])
701
+ if len(messages) < 2 or len(messages) % 2 != 0:
702
+ continue
703
+ # Convert the list to pairs of elements
704
+ history_messages = [[messages[k], messages[k + 1]] for k in range(0, len(messages), 2)]
705
+ yield prompt_template.get_dialog(history_messages)
706
+
707
+ for dialog in get_dialog(examples):
708
+ input_ids, labels = [], []
709
+
710
+ for i in range(len(dialog) // 2):
711
+ source_ids = tokenizer.encode(text=dialog[2 * i], add_special_tokens=(i == 0))
712
+ target_ids = tokenizer.encode(text=dialog[2 * i + 1], add_special_tokens=False)
713
+
714
+ if len(source_ids) > max_source_length:
715
+ source_ids = source_ids[:max_source_length]
716
+ if len(target_ids) > max_target_length - 1: # eos token
717
+ target_ids = target_ids[:max_target_length - 1]
718
+ if len(source_ids) > 0 and source_ids[0] == tokenizer.eos_token_id:
719
+ source_ids = source_ids[1:]
720
+ if len(target_ids) > 0 and target_ids[-1] == tokenizer.eos_token_id:
721
+ target_ids = target_ids[:-1]
722
+ if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
723
+ break
724
+
725
+ input_ids += source_ids + target_ids + [tokenizer.eos_token_id] # add eos token for each turn
726
+ labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
727
+
728
+ input_ids_list.append(input_ids)
729
+ targets_list.append(labels)
730
+
731
+ return dict(
732
+ input_ids=input_ids_list,
733
+ labels=targets_list,
734
+ )
735
+
736
+ def filter_empty_labels(example):
737
+ """Remove empty labels dataset."""
738
+ return not all(label == IGNORE_INDEX for label in example["labels"])
739
+
740
+ train_dataset = None
741
+ max_train_samples = 0
742
+ if training_args.do_train:
743
+ if "train" not in raw_datasets:
744
+ raise ValueError("--do_train requires a train dataset")
745
+ train_dataset = raw_datasets['train']
746
+ max_train_samples = len(train_dataset)
747
+ if data_args.max_train_samples is not None and data_args.max_train_samples > 0:
748
+ max_train_samples = min(len(train_dataset), data_args.max_train_samples)
749
+ train_dataset = train_dataset.select(range(max_train_samples))
750
+ logger.debug(f"Example train_dataset[0]: {train_dataset[0]}")
751
+ with training_args.main_process_first(desc="Train dataset tokenization"):
752
+ train_dataset = train_dataset.shuffle().map(
753
+ preprocess_function,
754
+ batched=True,
755
+ num_proc=data_args.preprocessing_num_workers,
756
+ remove_columns=train_dataset.column_names,
757
+ load_from_cache_file=not data_args.overwrite_cache,
758
+ desc="Running tokenizer on dataset",
759
+ )
760
+ train_dataset = train_dataset.filter(filter_empty_labels, num_proc=data_args.preprocessing_num_workers)
761
+ logger.debug(f"Num train_samples: {len(train_dataset)}")
762
+ logger.debug("Tokenized training example:")
763
+ logger.debug(f"Decode input_ids[0]: {tokenizer.decode(train_dataset[0]['input_ids'])}")
764
+ replaced_labels = [label if label != IGNORE_INDEX else tokenizer.pad_token_id
765
+ for label in list(train_dataset[0]['labels'])]
766
+ logger.debug(f"Decode labels[0]: {tokenizer.decode(replaced_labels)}")
767
+
768
+ eval_dataset = None
769
+ max_eval_samples = 0
770
+ if training_args.do_eval:
771
+ with training_args.main_process_first(desc="Eval dataset tokenization"):
772
+ if "validation" not in raw_datasets:
773
+ raise ValueError("--do_eval requires a validation dataset")
774
+ eval_dataset = raw_datasets["validation"]
775
+ max_eval_samples = len(eval_dataset)
776
+ if data_args.max_eval_samples is not None and data_args.max_eval_samples > 0:
777
+ max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
778
+ eval_dataset = eval_dataset.select(range(max_eval_samples))
779
+ logger.debug(f"Example eval_dataset[0]: {eval_dataset[0]}")
780
+ eval_dataset = eval_dataset.map(
781
+ preprocess_function,
782
+ batched=True,
783
+ num_proc=data_args.preprocessing_num_workers,
784
+ remove_columns=eval_dataset.column_names,
785
+ load_from_cache_file=not data_args.overwrite_cache,
786
+ desc="Running tokenizer on dataset",
787
+ )
788
+ eval_dataset = eval_dataset.filter(filter_empty_labels, num_proc=data_args.preprocessing_num_workers)
789
+ logger.debug(f"Num eval_samples: {len(eval_dataset)}")
790
+ logger.debug("Tokenized eval example:")
791
+ logger.debug(tokenizer.decode(eval_dataset[0]['input_ids']))
792
+
793
+ # Load model
794
+ if model_args.model_name_or_path:
795
+ torch_dtype = (
796
+ model_args.torch_dtype
797
+ if model_args.torch_dtype in ["auto", None]
798
+ else getattr(torch, model_args.torch_dtype)
799
+ )
800
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
801
+ ddp = world_size != 1
802
+ if ddp:
803
+ model_args.device_map = {"": int(os.environ["LOCAL_RANK"]) or 0}
804
+ if training_args.qlora and (len(training_args.fsdp) > 0 or is_deepspeed_zero3_enabled()):
805
+ logger.warning("FSDP and ZeRO3 are both currently incompatible with QLoRA.")
806
+ config = config_class.from_pretrained(
807
+ model_args.model_name_or_path,
808
+ trust_remote_code=model_args.trust_remote_code,
809
+ torch_dtype=torch_dtype,
810
+ cache_dir=model_args.cache_dir
811
+ )
812
+ model = model_class.from_pretrained(
813
+ model_args.model_name_or_path,
814
+ config=config,
815
+ load_in_8bit=model_args.load_in_8bit,
816
+ low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
817
+ device_map=model_args.device_map,
818
+ trust_remote_code=model_args.trust_remote_code,
819
+ quantization_config=BitsAndBytesConfig(
820
+ load_in_4bit=True,
821
+ bnb_4bit_use_double_quant=True,
822
+ bnb_4bit_quant_type="nf4",
823
+ bnb_4bit_compute_dtype=torch_dtype,
824
+ ) if training_args.qlora else None,
825
+ )
826
+ if hasattr(model, 'lm_head'):
827
+ model.lm_head = CastOutputToFloat(model.lm_head)
828
+ else:
829
+ raise ValueError(f"Error, model_name_or_path is None, SFT must be loaded from a pre-trained model")
830
+
831
+ if training_args.use_peft:
832
+ logger.info("Fine-tuning method: LoRA(PEFT)")
833
+ if training_args.peft_path is not None:
834
+ logger.info(f"Peft from pre-trained model: {training_args.peft_path}")
835
+ model = PeftModel.from_pretrained(model, training_args.peft_path, is_trainable=True)
836
+ else:
837
+ target_modules = training_args.target_modules.split(',') if training_args.target_modules else None
838
+ if target_modules and 'all' in target_modules:
839
+ target_modules = find_all_linear_names(model, int4=False, int8=model_args.load_in_8bit)
840
+ modules_to_save = training_args.modules_to_save
841
+ if modules_to_save is not None:
842
+ modules_to_save = modules_to_save.split(',')
843
+ logger.info(f"Peft target_modules: {target_modules}")
844
+ logger.info(f"Peft lora_rank: {training_args.lora_rank}")
845
+ peft_config = LoraConfig(
846
+ task_type=TaskType.CAUSAL_LM,
847
+ target_modules=target_modules,
848
+ inference_mode=False,
849
+ r=training_args.lora_rank,
850
+ lora_alpha=training_args.lora_alpha,
851
+ lora_dropout=training_args.lora_dropout,
852
+ modules_to_save=modules_to_save)
853
+ model = get_peft_model(model, peft_config)
854
+ if model_args.load_in_8bit:
855
+ model = prepare_model_for_int8_training(model)
856
+ model.print_trainable_parameters()
857
+ else:
858
+ logger.info("Fine-tuning method: Full parameters training")
859
+ model = model.float()
860
+ print_trainable_parameters(model)
861
+ logger.debug(f"Model: {model}")
862
+
863
+ # Initialize our Trainer
864
+ if training_args.gradient_checkpointing:
865
+ model.gradient_checkpointing_enable()
866
+ model.config.use_cache = False
867
+ else:
868
+ model.config.use_cache = True
869
+ model.enable_input_require_grads()
870
+ if not ddp and torch.cuda.device_count() > 1:
871
+ # Keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
872
+ model.is_parallelizable = True
873
+ model.model_parallel = True
874
+
875
+ data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
876
+ # Initialize our Trainer
877
+ trainer = SavePeftModelTrainer(
878
+ model=model,
879
+ args=training_args,
880
+ train_dataset=train_dataset if training_args.do_train else None,
881
+ eval_dataset=eval_dataset if training_args.do_eval else None,
882
+ tokenizer=tokenizer,
883
+ data_collator=data_collator,
884
+ )
885
+
886
+ # Training
887
+ if training_args.do_train:
888
+ logger.info("*** Train ***")
889
+ sample = next(iter(trainer.get_train_dataloader()))
890
+ logger.debug(f"Train dataloader example: {sample}")
891
+ logger.debug(f"Detail input_ids: {list(sample['input_ids'])[:3]}, \nlabels: {list(sample['labels'])[:3]}")
892
+ logger.debug(f"Decode input_ids[0]: {tokenizer.decode(sample['input_ids'][0])}")
893
+ replaced_labels = [label if label != IGNORE_INDEX else tokenizer.pad_token_id for label in sample['labels'][0]]
894
+ logger.debug(f"Decode labels[0]: {tokenizer.decode(replaced_labels)}")
895
+ checkpoint = None
896
+ if training_args.resume_from_checkpoint is not None:
897
+ checkpoint = training_args.resume_from_checkpoint
898
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
899
+
900
+ metrics = train_result.metrics
901
+ metrics["train_samples"] = max_train_samples
902
+ logger.debug(f"Training metrics: {metrics}")
903
+ trainer.log_metrics("train", metrics)
904
+ trainer.save_metrics("train", metrics)
905
+ model.config.use_cache = True # enable cache after training
906
+ trainer.save_state()
907
+ logger.info(f"Saving model checkpoint to {training_args.output_dir}")
908
+ save_model(training_args.output_dir, model, tokenizer, training_args)
909
+
910
+ # Evaluation
911
+ if training_args.do_eval and trainer.is_world_process_zero():
912
+ logger.info("*** Evaluate ***")
913
+ metrics = trainer.evaluate()
914
+
915
+ metrics["eval_samples"] = max_eval_samples
916
+ try:
917
+ perplexity = math.exp(metrics["eval_loss"])
918
+ except OverflowError:
919
+ perplexity = float("inf")
920
+ metrics["perplexity"] = perplexity
921
+ logger.debug(f"Eval metrics: {metrics}")
922
+ trainer.log_metrics("eval", metrics)
923
+ trainer.save_metrics("eval", metrics)
924
+
925
+
926
+ if __name__ == "__main__":
927
+ main()