Spaces:
Configuration error
Configuration error
nengrenjie83
commited on
Commit
•
b78b52f
1
Parent(s):
247f626
Upload 28 files
Browse files- .gitignore +129 -0
- CITATION.cff +9 -0
- CONTRIBUTING.md +9 -0
- DISCLAIMER +23 -0
- LICENSE +201 -0
- README.md +326 -13
- README_EN.md +224 -0
- _config.yml +1 -0
- build_domain_tokenizer.py +59 -0
- convert_dataset.py +47 -0
- deepspeed_config.json +43 -0
- dpo_training.py +495 -0
- gradio_demo.py +215 -0
- inference.py +225 -0
- merge_peft_adapter.py +109 -0
- merge_tokenizers.py +150 -0
- pretraining.py +678 -0
- requirements.txt +10 -0
- reward_modeling.py +643 -0
- rl_training.py +499 -0
- run_dpo.sh +29 -0
- run_pt.sh +42 -0
- run_rl.sh +24 -0
- run_rm.sh +39 -0
- run_sft.sh +40 -0
- run_training_dpo_pipeline.ipynb +711 -0
- run_training_pipeline.ipynb +917 -0
- supervised_finetuning.py +927 -0
.gitignore
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
.python-version
|
86 |
+
|
87 |
+
# pipenv
|
88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
+
# install all needed dependencies.
|
92 |
+
#Pipfile.lock
|
93 |
+
|
94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
+
__pypackages__/
|
96 |
+
|
97 |
+
# Celery stuff
|
98 |
+
celerybeat-schedule
|
99 |
+
celerybeat.pid
|
100 |
+
|
101 |
+
# SageMath parsed files
|
102 |
+
*.sage.py
|
103 |
+
|
104 |
+
# Environments
|
105 |
+
.env
|
106 |
+
.venv
|
107 |
+
env/
|
108 |
+
venv/
|
109 |
+
ENV/
|
110 |
+
env.bak/
|
111 |
+
venv.bak/
|
112 |
+
|
113 |
+
# Spyder project settings
|
114 |
+
.spyderproject
|
115 |
+
.spyproject
|
116 |
+
|
117 |
+
# Rope project settings
|
118 |
+
.ropeproject
|
119 |
+
|
120 |
+
# mkdocs documentation
|
121 |
+
/site
|
122 |
+
|
123 |
+
# mypy
|
124 |
+
.mypy_cache/
|
125 |
+
.dmypy.json
|
126 |
+
dmypy.json
|
127 |
+
|
128 |
+
# Pyre type checker
|
129 |
+
.pyre/
|
CITATION.cff
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
cff-version: 1.2.0
|
2 |
+
message: "If you use this software, please cite it as below."
|
3 |
+
authors:
|
4 |
+
- family-names: "Xu"
|
5 |
+
given-names: "Ming"
|
6 |
+
title: "MedicalGPT: Training Your Own Medical GPT Model with ChatGPT Training Pipeline"
|
7 |
+
url: "https://github.com/shibing624/MedicalGPT"
|
8 |
+
data-released: 2023-06-02
|
9 |
+
version: 0.0.4
|
CONTRIBUTING.md
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Contributing
|
2 |
+
|
3 |
+
We are happy to accept your contributions to make this repo better and more awesome! To avoid unnecessary work on either
|
4 |
+
side, please stick to the following process:
|
5 |
+
|
6 |
+
1. Check if there is already an issue for your concern.
|
7 |
+
2. If there is not, open a new one to start a discussion. We hate to close finished PRs!
|
8 |
+
3. If we decide your concern needs code changes, we would be happy to accept a pull request. Please consider the
|
9 |
+
commit guidelines below.
|
DISCLAIMER
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
The software project, data, and models provided by our GitHub project are provided "as is," without warranty of any kind, express or implied, including but not limited to the warranties of merchantability, fitness for a particular purpose, and non-infringement.
|
2 |
+
|
3 |
+
In no event shall the project owners or contributors be liable for any direct, indirect, incidental, special, exemplary, or consequential damages (including, but not limited to, procurement of substitute goods or services; loss of use, data, or profits; or business interruption) however caused and on any theory of liability, whether in contract, strict liability, or tort (including negligence or otherwise) arising in any way out of the use of this software project, data, or models, even if advised of the possibility of such damage.
|
4 |
+
|
5 |
+
Users of this software project, data, and models are solely responsible for any consequences of their use. The project owners and contributors shall not be held responsible for any subsequent or potential harm caused by the use of this software project, data, or models.
|
6 |
+
|
7 |
+
By using this software project, data, or models, users accept and agree to this disclaimer. If users do not agree to the terms of this disclaimer, they should not use this software project, data, or models.
|
8 |
+
|
9 |
+
It is important to note that this software project, data, and models are still in the research phase and are provided for experimental purposes only. As such, the project owners and contributors do not guarantee the accuracy, completeness, or usefulness of the software project, data, or models.
|
10 |
+
|
11 |
+
Furthermore, due to the experimental nature of this software project, data, and models, it is possible that they may contain or generate inappropriate responses, errors, or inconsistencies. Users should exercise caution when using this software project, data, or models, and should not rely solely on them for any critical or sensitive tasks.
|
12 |
+
|
13 |
+
The project owners and contributors shall not be held responsible for any damages, losses, or liabilities arising from the use of this software project, data, or models, including but not limited to, any inappropriate responses generated by the software project, data, or models.
|
14 |
+
|
15 |
+
By using this software project, data, or models, users acknowledge and accept the experimental nature of the software project, data, and models, and understand the potential risks and limitations associated with their use. If users do not agree to the terms of this disclaimer, they should not use this software project, data, or models.
|
16 |
+
|
17 |
+
The software project, data, and models provided by our GitHub project are intended for research purposes only. They should not be used for any commercial, business, or legal purposes, and should not be relied upon as a substitute for professional advice or judgment.
|
18 |
+
|
19 |
+
Users of this software project, data, and models are strictly prohibited from using them for any commercial purposes, including but not limited to, selling, licensing, or distributing the software project, data, or models to third parties.
|
20 |
+
|
21 |
+
The project owners and contributors shall not be held responsible for any damages, losses, or liabilities arising from the use of this software project, data, or models for any commercial or business purposes.
|
22 |
+
|
23 |
+
By using this software project, data, or models, users agree to use them for research purposes only, and not for any commercial or business purposes. If users do not agree to the terms of this disclaimer, they should not use this software project, data, or models.
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,13 +1,326 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[**🇨🇳中文**](https://github.com/shibing624/MedicalGPT/blob/main/README.md) | [**🌐English**](https://github.com/shibing624/MedicalGPT/blob/main/README_EN.md) | [**📖文档/Docs**](https://github.com/shibing624/MedicalGPT/wiki) | [**🤖模型/Models**](https://huggingface.co/shibing624)
|
2 |
+
|
3 |
+
<div align="center">
|
4 |
+
<a href="https://github.com/shibing624/MedicalGPT">
|
5 |
+
<img src="https://github.com/shibing624/MedicalGPT/blob/main/docs/logo.png" height="100" alt="Logo">
|
6 |
+
</a>
|
7 |
+
</div>
|
8 |
+
|
9 |
+
-----------------
|
10 |
+
|
11 |
+
# MedicalGPT: Training Medical GPT Model
|
12 |
+
[![HF Models](https://img.shields.io/badge/Hugging%20Face-shibing624-green)](https://huggingface.co/shibing624)
|
13 |
+
[![Github Stars](https://img.shields.io/github/stars/shibing624/MedicalGPT?color=yellow)](https://star-history.com/#shibing624/MedicalGPT&Timeline)
|
14 |
+
[![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
|
15 |
+
[![License Apache 2.0](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE)
|
16 |
+
[![python_version](https://img.shields.io/badge/Python-3.8%2B-green.svg)](requirements.txt)
|
17 |
+
[![GitHub issues](https://img.shields.io/github/issues/shibing624/MedicalGPT.svg)](https://github.com/shibing624/MedicalGPT/issues)
|
18 |
+
[![Wechat Group](http://vlog.sfyc.ltd/wechat_everyday/wxgroup_logo.png?imageView2/0/w/60/h/20)](#Contact)
|
19 |
+
|
20 |
+
## 📖 Introduction
|
21 |
+
|
22 |
+
**MedicalGPT** training medical GPT model with ChatGPT training pipeline, implemantation of Pretraining,
|
23 |
+
Supervised Finetuning, RLHF(Reward Modeling and Reinforcement Learning) and DPO(Direct Preference Optimization).
|
24 |
+
|
25 |
+
**MedicalGPT** 训练医疗大模型,实现了包括增量预训练、有监督微调、RLHF(奖励建模、强化学习训练)和DPO(直接偏好优化)。
|
26 |
+
|
27 |
+
<img src="https://github.com/shibing624/MedicalGPT/blob/main/docs/dpo.jpg" width="860" />
|
28 |
+
|
29 |
+
- RLHF training pipeline来自Andrej Karpathy的演讲PDF [State of GPT](https://karpathy.ai/stateofgpt.pdf),视频 [Video](https://build.microsoft.com/en-US/sessions/db3f4859-cd30-4445-a0cd-553c3304f8e2)
|
30 |
+
- DPO方法来自论文[Direct Preference Optimization:Your Language Model is Secretly a Reward Model](https://arxiv.org/pdf/2305.18290.pdf)
|
31 |
+
|
32 |
+
## 🔥 News
|
33 |
+
[2023/08/28] v1.5版本: 新增[DPO(直接偏好优化)](https://arxiv.org/pdf/2305.18290.pdf)方法,DPO通过直接优化语言模型来实现对其行为的精确控制,可以有效学习到人类偏好。详见[Release-v1.5](https://github.com/shibing624/MedicalGPT/releases/tag/1.5.0)
|
34 |
+
|
35 |
+
[2023/08/08] v1.4版本: 发布基于ShareGPT4数据集微调的中英文Vicuna-13B模型[shibing624/vicuna-baichuan-13b-chat](https://huggingface.co/shibing624/vicuna-baichuan-13b-chat),和对应的LoRA模型[shibing624/vicuna-baichuan-13b-chat-lora](https://huggingface.co/shibing624/vicuna-baichuan-13b-chat-lora),详见[Release-v1.4](https://github.com/shibing624/MedicalGPT/releases/tag/1.4.0)
|
36 |
+
|
37 |
+
[2023/08/02] v1.3版本: 新增LLaMA, LLaMA2, Bloom, ChatGLM, ChatGLM2, Baichuan模型的多轮对话微调训练;新增领域词表扩充功能;新增中文预训练数据集和中文ShareGPT微调训练集,详见[Release-v1.3](https://github.com/shibing624/MedicalGPT/releases/tag/1.3.0)
|
38 |
+
|
39 |
+
[2023/07/13] v1.1版本: 发布中文医疗LLaMA-13B模型[shibing624/ziya-llama-13b-medical-merged](https://huggingface.co/shibing624/ziya-llama-13b-medical-merged),基于Ziya-LLaMA-13B-v1模型,SFT微调了一版医疗模型,医疗问答效果有提升,发布微调后的完整模型权重,详见[Release-v1.1](https://github.com/shibing624/MedicalGPT/releases/tag/1.1)
|
40 |
+
|
41 |
+
[2023/06/15] v1.0版本: 发布中文医疗LoRA模型[shibing624/ziya-llama-13b-medical-lora](https://huggingface.co/shibing624/ziya-llama-13b-medical-lora),基于Ziya-LLaMA-13B-v1模型,SFT微调了一版医疗模型,医疗问答效果有提升,发布微调后的LoRA权重,详见[Release-v1.0](https://github.com/shibing624/MedicalGPT/releases/tag/1.0.0)
|
42 |
+
|
43 |
+
[2023/06/05] v0.2版本: 以医疗为例,训练领域大模型,实现了四阶段训练:包括二次预训练、有监督微调、奖励建模、强化学习训练。详见[Release-v0.2](https://github.com/shibing624/MedicalGPT/releases/tag/0.2.0)
|
44 |
+
|
45 |
+
|
46 |
+
## 😊 Features
|
47 |
+
|
48 |
+
|
49 |
+
基于ChatGPT Training Pipeline,本项目实现了领域模型--医疗行业语言大模型的训练:
|
50 |
+
|
51 |
+
|
52 |
+
- 第一阶段:PT(Continue PreTraining)增量预训练,在海量领域文档数据上二次预训练GPT模型,以注入领域知识(可选)
|
53 |
+
- 第二阶段:SFT(Supervised Fine-tuning)有监督微调,构造指令微调数据集,在预训练模型基础上做指令精调,以对齐指令意图
|
54 |
+
- 第三阶段
|
55 |
+
- RLHF(Reinforcement Learning from Human Feedback)基于人类反馈对语言模型进行强化学习,分为两步:
|
56 |
+
- RM(Reward Model)奖励模型建模,构造人类偏好排序数据集,训练奖励模型,用来建模人类偏好,主要是"HHH"原则,具体是"helpful, honest, harmless"
|
57 |
+
- RL(Reinforcement Learning)强化学习,用奖励模型来训练SFT模型,生成模型使用奖励或惩罚来更新其策略,以便生成更高质量、更符合人类偏好的文本
|
58 |
+
- [DPO(Direct Preference Optimization)](https://arxiv.org/pdf/2305.18290.pdf)直接偏好优化方法,DPO通过直接优化语言模型来实现对其行为的精确控制,而无需使用复杂的强化学习,也可以有效学习到人类偏好,DPO相较于RLHF更容易实现且易于训练,效果更好
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
### Release Models
|
63 |
+
|
64 |
+
|
65 |
+
| Model | Base Model | Introduction |
|
66 |
+
|:------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
67 |
+
| [shibing624/ziya-llama-13b-medical-lora](https://huggingface.co/shibing624/ziya-llama-13b-medical-lora) | [IDEA-CCNL/Ziya-LLaMA-13B-v1](https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1) | 在240万条中英文医疗数据集[shibing624/medical](https://huggingface.co/datasets/shibing624/medical)上SFT微调了一版Ziya-LLaMA-13B模型,医疗问答效果有提升,发布微调后的LoRA权重(单轮对话) |
|
68 |
+
| [shibing624/ziya-llama-13b-medical-merged](https://huggingface.co/shibing624/ziya-llama-13b-medical-merged) | [IDEA-CCNL/Ziya-LLaMA-13B-v1](https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1) | 在240万条中英文医疗数据集[shibing624/medical](https://huggingface.co/datasets/shibing624/medical)上SFT微调了一版Ziya-LLaMA-13B模型,医疗问答效果有提升,发布微调后的完整模型权重(单轮对话) |
|
69 |
+
| [shibing624/vicuna-baichuan-13b-chat-lora](https://huggingface.co/shibing624/vicuna-baichuan-13b-chat-lora) | [baichuan-inc/Baichuan-13B-Chat](https://huggingface.co/baichuan-inc/Baichuan-13B-Chat) | 在10万条多语言ShareGPT GPT4多轮对话数据集[shibing624/sharegpt_gpt4](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)上SFT微调了一版baichuan-13b-chat多轮问答模型,日常问答和医疗问答效果有提升,发布微调后的LoRA权重 |
|
70 |
+
| [shibing624/vicuna-baichuan-13b-chat](https://huggingface.co/shibing624/vicuna-baichuan-13b-chat) | [baichuan-inc/Baichuan-13B-Chat](https://huggingface.co/baichuan-inc/Baichuan-13B-Chat) | 在10万条多语言ShareGPT GPT4多轮对话数据集[shibing624/sharegpt_gpt4](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)上SFT微调了一版baichuan-13b-chat多轮问答模型,日常问答和医疗问答效果有提升,发布微调后的完整模型权重 |
|
71 |
+
|
72 |
+
演示[shibing624/vicuna-baichuan-13b-chat](https://huggingface.co/shibing624/vicuna-baichuan-13b-chat)模型效果:
|
73 |
+
<img src="https://github.com/shibing624/MedicalGPT/blob/main/docs/demo-screen.gif" width="860" />
|
74 |
+
具体case见[Inference Examples](#inference-examples)
|
75 |
+
|
76 |
+
## ▶️ Demo
|
77 |
+
|
78 |
+
|
79 |
+
我们提供了一个简洁的基于gradio的交互式web界面,启动服务后,可通过浏览器访问,输入问题,模型会返回答案。
|
80 |
+
|
81 |
+
启动服务,命令如下:
|
82 |
+
```shell
|
83 |
+
CUDA_VISIBLE_DEVICES=0 python gradio_demo.py --model_type base_model_type --base_model path_to_llama_hf_dir --lora_model path_to_lora_dir
|
84 |
+
```
|
85 |
+
|
86 |
+
参数说明:
|
87 |
+
|
88 |
+
- `--model_type {base_model_type}`:预训练模型类型,如llama、bloom、chatglm等
|
89 |
+
- `--base_model {base_model}`:存放HF格式的LLaMA模型权重和配置文件的目录,也可使用HF Model Hub模型调用名称
|
90 |
+
- `--lora_model {lora_model}`:LoRA文件所在目录,也可使用HF Model Hub模型调用名称。若lora权重已经合并到预训练模型,则删除--lora_model参数
|
91 |
+
- `--tokenizer_path {tokenizer_path}`:存放对应tokenizer的目录。若不提供此参数,则其默认值与--base_model相同
|
92 |
+
- `--template_name`:模板名称,如`vicuna`、`alpaca`等。若不提供此参数,则其默认值是vicuna
|
93 |
+
- `--only_cpu`: 仅使用CPU进行推理
|
94 |
+
- `--gpus {gpu_ids}`: 指定使用的GPU设备编号,默认为0。如使用多张GPU,以逗号分隔,如0,1,2
|
95 |
+
- `--resize_emb`:是否调整embedding大小,若不调整,则使用预训练模型的embedding大小,默认不��整
|
96 |
+
|
97 |
+
|
98 |
+
## 💾 Install
|
99 |
+
#### Updating the requirements
|
100 |
+
From time to time, the `requirements.txt` changes. To update, use this command:
|
101 |
+
|
102 |
+
```markdown
|
103 |
+
git clone https://github.com/shibing624/MedicalGPT
|
104 |
+
conda activate gpt
|
105 |
+
cd MedicalGPT
|
106 |
+
pip install -r requirements.txt --upgrade
|
107 |
+
```
|
108 |
+
|
109 |
+
## 🚀 Training Pipeline
|
110 |
+
|
111 |
+
Training Stage:
|
112 |
+
|
113 |
+
| Stage | Introduction | Python script | Shell script |
|
114 |
+
|:--------------------------------|:-------------|:--------------------------------------------------------------------------------------------------------|:----------------------------------------------------------------------------|
|
115 |
+
| Continue Pretraining | 增量预训练 | [pretraining.py](https://github.com/shibing624/MedicalGPT/blob/main/pretraining.py) | [run_pt.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_pt.sh) |
|
116 |
+
| Supervised Fine-tuning | 有监督微调 | [supervised_finetuning.py](https://github.com/shibing624/MedicalGPT/blob/main/supervised_finetuning.py) | [run_sft.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_sft.sh) |
|
117 |
+
| Direct Preference Optimization | 直接偏好优化 | [dpo_training.py](https://github.com/shibing624/MedicalGPT/blob/main/dpo_training.py) | [run_dpo.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_dpo.sh) |
|
118 |
+
| Reward Modeling | 奖励模型建模 | [reward_modeling.py](https://github.com/shibing624/MedicalGPT/blob/main/reward_modeling.py) | [run_rm.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_rm.sh) |
|
119 |
+
| Reinforcement Learning | 强化学习 | [rl_training.py](https://github.com/shibing624/MedicalGPT/blob/main/rl_training.py) | [run_rl.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_rl.sh) |
|
120 |
+
|
121 |
+
- 提供完整PT+SFT+DPO全阶段串起来训练的pipeline:[run_training_dpo_pipeline.ipynb](https://github.com/shibing624/MedicalGPT/blob/main/run_training_dpo_pipeline.ipynb) ,其对应的colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/shibing624/MedicalGPT/blob/main/run_training_dpo_pipeline.ipynb),运行完大概需要15分钟,我运行成功后的副本colab:[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kMIe3pTec2snQvLBA00Br8ND1_zwy3Gr?usp=sharing)
|
122 |
+
- 提供完整PT+SFT+RLHF全阶段串起来训练的pipeline:[run_training_pipeline.ipynb](https://github.com/shibing624/MedicalGPT/blob/main/run_training_pipeline.ipynb) ,其对应的colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/shibing624/MedicalGPT/blob/main/run_training_pipeline.ipynb) ,运行完大概需要20分钟,我运行成功后的副本colab:[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1RGkbev8D85gR33HJYxqNdnEThODvGUsS?usp=sharing)
|
123 |
+
- [训练参数说明wiki](https://github.com/shibing624/MedicalGPT/wiki/%E8%AE%AD%E7%BB%83%E5%8F%82%E6%95%B0%E8%AF%B4%E6%98%8E)
|
124 |
+
- [数据集wiki](https://github.com/shibing624/MedicalGPT/wiki/%E6%95%B0%E6%8D%AE%E9%9B%86)
|
125 |
+
- [扩充词表wiki](https://github.com/shibing624/MedicalGPT/wiki/%E6%89%A9%E5%85%85%E4%B8%AD%E6%96%87%E8%AF%8D%E8%A1%A8)
|
126 |
+
- [FAQ](https://github.com/shibing624/MedicalGPT/wiki/FAQ)
|
127 |
+
|
128 |
+
#### Supported Models
|
129 |
+
|
130 |
+
| 模型名 | 模型大小 | Template |
|
131 |
+
| ------------------------------------------------------- | --------------------------- |---------------|
|
132 |
+
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | vicuna |
|
133 |
+
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | alpaca |
|
134 |
+
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
135 |
+
| [Baichuan](https://github.com/baichuan-inc/baichuan-13B) | 7B/13B | baichuan-chat |
|
136 |
+
| [InternLM](https://github.com/InternLM/InternLM) | 7B | intern |
|
137 |
+
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | chatml |
|
138 |
+
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | xverse |
|
139 |
+
| [ChatGLM](https://github.com/THUDM/ChatGLM-6B) | 6B | chatglm |
|
140 |
+
| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | chatglm2 |
|
141 |
+
|
142 |
+
The following models are tested:
|
143 |
+
|
144 |
+
bloom:
|
145 |
+
- [bigscience/bloomz-560m](https://huggingface.co/bigscience/bloomz-560m)
|
146 |
+
- [bigscience/bloomz-1b7](https://huggingface.co/bigscience/bloomz-1b7)
|
147 |
+
- [bigscience/bloomz-7b1](https://huggingface.co/bigscience/bloomz-7b1)
|
148 |
+
|
149 |
+
llama:
|
150 |
+
- [shibing624/chinese-alpaca-plus-7b-hf](https://huggingface.co/shibing624/chinese-alpaca-plus-7b-hf)
|
151 |
+
- [shibing624/chinese-alpaca-plus-13b-hf](https://huggingface.co/shibing624/chinese-alpaca-plus-13b-hf)
|
152 |
+
- [minlik/chinese-llama-plus-7b-merged](https://huggingface.co/minlik/chinese-llama-plus-7b-merged)
|
153 |
+
- [shibing624/chinese-llama-plus-13b-hf](https://huggingface.co/shibing624/chinese-llama-plus-13b-hf)
|
154 |
+
- [decapoda-research/llama-7b-hf](https://huggingface.co/decapoda-research/llama-7b-hf)
|
155 |
+
- [IDEA-CCNL/Ziya-LLaMA-13B-v1](https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1)
|
156 |
+
|
157 |
+
llama2:
|
158 |
+
- [daryl149/llama-2-7b-chat-hf](https://huggingface.co/daryl149/llama-2-7b-chat-hf)
|
159 |
+
- [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
|
160 |
+
- [ziqingyang/chinese-alpaca-2-7b](https://huggingface.co/ziqingyang/chinese-alpaca-2-7b)
|
161 |
+
|
162 |
+
chatglm:
|
163 |
+
- [THUDM/chatglm-6b](https://huggingface.co/THUDM/chatglm-6b)
|
164 |
+
- [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b)
|
165 |
+
|
166 |
+
baichuan:
|
167 |
+
- [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B)
|
168 |
+
- [baichuan-inc/Baichuan-13B-Base](https://huggingface.co/baichuan-inc/Baichuan-13B-Base)
|
169 |
+
- [baichuan-inc/Baichuan-13B-Chat](https://huggingface.co/baichuan-inc/Baichuan-13B-Chat)
|
170 |
+
|
171 |
+
xverse:
|
172 |
+
- [xverse/XVERSE-13B-Chat](https://huggingface.co/xverse/XVERSE-13B-Chat)
|
173 |
+
|
174 |
+
Qwen:
|
175 |
+
- [Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat)
|
176 |
+
|
177 |
+
## 💻 Inference
|
178 |
+
训练完成后,现在我们加载训练好的模型,验证模型生成文本的效果。
|
179 |
+
|
180 |
+
```shell
|
181 |
+
CUDA_VISIBLE_DEVICES=0 python inference.py \
|
182 |
+
--model_type base_model_type \
|
183 |
+
--base_model path_to_model_hf_dir \
|
184 |
+
--tokenizer_path path_to_model_hf_dir \
|
185 |
+
--lora_model path_to_lora \
|
186 |
+
--interactive
|
187 |
+
```
|
188 |
+
|
189 |
+
参数说明:
|
190 |
+
|
191 |
+
- `--model_type {base_model_type}`:预训练模型类型,如llama、bloom、chatglm等
|
192 |
+
- `--base_model {base_model}`:存放HF格式的LLaMA模型权重和配置文件的目录
|
193 |
+
- `--tokenizer_path {base_model}`:存放HF格式的LLaMA模型权重和配置文件的目录
|
194 |
+
- `--lora_model {lora_model}`:LoRA解压后文件所在目录,也可使用HF Model Hub模型调用名称。如果已经合并了LoRA权重到预训练模型,则可以不提供此参数
|
195 |
+
- `--tokenizer_path {tokenizer_path}`:存放对应tokenizer的目录。若不提供此参数,则其默认值与--base_model相同
|
196 |
+
- `--template_name`:模板名称,如`vicuna`、`alpaca`等。若不提供此参数,则其默认值是vicuna
|
197 |
+
- `--interactive`:以交互方式启动多轮问答,使用流式推理
|
198 |
+
- `--data_file {file_name}`:非交互方式启动下,按行读取file_name中的的内容进行预测
|
199 |
+
- `--predictions_file {file_name}`:非交互式方式下,将预测的结果以json格式写入file_name
|
200 |
+
- `--resize_emb`:是否调整embedding大小,若不调整,则使用预训练模型的embedding大小,默认不调整
|
201 |
+
- `--only_cpu`:仅使用CPU进行推理
|
202 |
+
- `--gpus {gpu_ids}`:指定使用的GPU设备编号,默认为0。如使用多张GPU,以逗号分隔,如0,1,2
|
203 |
+
|
204 |
+
|
205 |
+
#### Inference Examples
|
206 |
+
[shibing624/vicuna-baichuan-13b-chat](https://huggingface.co/shibing624/vicuna-baichuan-13b-chat) inference examples:
|
207 |
+
|
208 |
+
<details><summary>Model output case</summary>
|
209 |
+
|
210 |
+
| Human | Assitant |
|
211 |
+
|------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
212 |
+
| 小孩发烧怎么办 | 发烧是身体对感染或其他疾病的自然反应。以下是一些可以帮助缓解发烧症状的建议:<br/>1. 保持身体水分充足:发烧时,身体会失去水分,因此需要保持水分充足。<br/>2. 使用冰袋:将冰袋放在额头或颈部可以帮助降低体温。<br/>3. 使用退烧药:退烧药可以帮助降低体温。<br/>4. 保持身体温度:保持身体温度可以帮助降低体温。<br/>5. 使用冷水或温水擦拭身体:使用冷水或温水擦拭身体可以帮助降低体温。<br/>6. 使用退烧药:退烧药可以帮助降低体温。<br/>7. 保持身体温度:保持身体温度可以帮助降低体温。<br/>8. 使用冷水或温水擦拭身体:使用冷水或温水擦拭身体可以帮助降低体温。<br/>9. 使用退烧药:退烧药可以帮助降低体温。<br/>10. 保持身体温度:保持身体温度可以帮助降低体温。 |
|
213 |
+
| 一岁宝宝发烧能吃啥药? | 1岁宝宝发烧时,可以考虑以下几种药物:<br/>1. 退烧药:如布洛芬(Ibuprofen)或对乙酰氨基酚(Paracetamol)。这些药物可以帮助降低体温,缓解发烧症状。<br/>2. 抗生素:如果发烧是由细菌感染引起的,可以考虑使用抗生素。但请务必在医生的建议下使用。<br/>3. 抗病毒药物:如果发烧是由病毒感染引起的,可以考虑使用抗病毒药物。但请务必在医生的建议下使用。<br/>4. 解热镇痛药:如阿司匹林(Aspirin)或对乙酰氨基酚(Paracetamol)。这些药物可以帮助缓解发烧症状。<br/>5. 抗组胺药:如苯海拉明(Diphenhydramine)或非索非那定(Naproxen)。这些药物可以帮助缓解发烧症状。<br/>请务必在医生的建议下使用这些药物,并遵循药物说明书上的用量和用法。同时,保持宝宝充足的水分摄入,并确保宝宝有足够的休息。 |
|
214 |
+
| 肛门病变可能是什么疾病的症状 | 肛门病变可能由多种原因引起,包括感染、炎症、肿瘤或其他疾病。一些常见的肛门病变包括:<br/>1. 肛裂:这是一种肛门皮肤撕裂,通常由便秘或其他压力因素引起。<br/>2. 肛瘘:这是一种肛门感染,导致肛门皮肤和直肠之间形成瘘管。<br/>3. 肛门疣:这是一种肛门皮肤上形成的疣状病变,通常由人类乳头状瘤病毒(HPV)引起。<br/>4. 肛门脓肿:这是一种肛门感染,导致肛门皮肤下形成脓肿。<br/>5. 肛门肿瘤:这是一种肛门皮肤或直肠内形成的肿瘤。<br/>6. 肛门溃疡:这是一种肛门皮肤上形成的溃疡,通常由感染或炎症引起。<br/>7. 肛门瘙痒:这是一种肛门皮肤瘙痒的症状,可能由感染、炎症或其他疾病引起。<br/>如果您有肛门病变的症状,请咨询医疗保健专业人员以获得适当的诊断和治疗。 |
|
215 |
+
| 医生,您好白带增多泛黄怎么回事? | 您好,白带增多和发黄可能是由多种原因引起的,包括感染、荷尔蒙变化、阴道干燥或其他妇科问题。以下是一些可能的原因:<br/>1. 感染:白带增多和发黄可能是由阴道或尿道感染引起的。感染可能由细菌、病毒或真菌引起。<br/>2. 荷尔蒙变化:荷尔蒙变化,如月经周期或更年期,可能会导致白带增多和发黄。<br/>3. 阴道干燥:阴道干燥可能会导致白带增多和发黄。<br/>4. 其他妇科问题:其他妇科问题,如子宫颈炎、子宫内膜炎或子宫内膜异位症,也可能导致白带增多和发黄。<br/>如果您担心自己的白带,建议您咨询妇科医生进行检查和诊断。医生可以根据您的症状和检查结果确定原因并提供适当的治疗。 |
|
216 |
+
| 药流后有巧克力白带是没流干净吗 | 药流后出现巧克力白带并不一定意味着没有流干净。巧克力白带是由于子宫内膜脱落和出血引起的,通常在药流后的一段时间内会出现。<br/><br/>药流后出现巧克力白带的原因可能包括:<br/><br/>1. 子宫内膜��落:药流后,子宫内膜会脱落,导致出血。这种出血可能表现为巧克力白带。<br/>2. 感染:药流后,子宫内膜和宫颈口可能存在创伤,容易受到细菌感染。感染可能导致巧克力白带。<br/>3. 激素水平变化:药流后,激素水平可能发生变化,导致巧克力白带。<br/><br/>如果巧克力白带持续时间较长,或者伴有其他症状,如腹痛、发热、分泌物异味等,建议及时就诊,进行相关检查,如妇科检查、B超等,以确定是否存在感染或其他问题。在医生的指导下进行相应的治疗。 |
|
217 |
+
| 维胺酯维E乳膏能治理什么疾病 | 维胺酯维E乳膏是一种外用药物,主要用于治疗痤疮、脂溢性皮炎等皮肤病。它可以帮助减少皮肤油脂分泌,减轻炎症,促进皮肤修复。<br/><br/>需要注意的是,维胺酯维E乳膏并非万能药,对于某些皮肤病可能并不适用。在使用前,建议咨询医生或药师,了解药物的适应症、禁忌症、用法用量等信息。同时,使用药物时应遵循医嘱,避免自行滥用。 |
|
218 |
+
| 痤疮是什么原因导致的? | 痤疮,也称为粉刺,是由皮肤毛孔堵塞引起的。堵塞毛孔的原因有很多,包括:<br/>1. 皮脂腺过度活跃:皮脂腺产生过多油脂,导致毛孔堵塞。<br/>2. 毛孔堵塞:皮肤表面的死皮细胞、油脂和细菌堵塞毛孔。<br/>3. 炎症:痤疮通常伴随着炎症,这可能是由于皮肤感染或免疫系统对痤疮的反应引起的。<br/>4. 遗传因素:某些人可能更容易患痤疮,这可能是由于遗传因素引起的。<br/>5. 荷尔蒙变化:痤疮在青春期和月经期间更为常见,这可能是由于荷尔蒙变化引起的。<br/>6. 生活方式因素:某些生活方式因素,如压力、不良饮食和缺乏运动,也可能导致痤疮。<br/>7. 化妆品和护肤品:某些化妆品和护肤品可能导致痤疮,特别是那些含有刺激性成分的产品。<br/>8. 药物:某些药物,如口服避孕药,也可能导致痤疮。<br/>要预防痤疮,可以采取一些措施,如保持皮肤清洁、使用温和的护肤品、避免过度使用化妆品和护肤品、保持健康的生活方式和饮食习惯、避免压力和荷尔蒙变化。如果痤疮严重,可能需要寻求皮肤科医生的帮助。 |
|
219 |
+
| 脸上长青春痘怎么办,能吃啥药,需要注意啥 | 青春痘,又称痤疮,是一种常见的皮肤病。青春痘的治疗需要综合考虑饮食、生活习惯、药物等多个方面。以下是一些建议:<br/><br/>1. 饮食:<br/>* 避免高糖、高脂、高油的食物,如快餐、甜点、油炸食品等。<br/>* 多吃新鲜蔬菜、水果、全谷类食物,以增加纤维摄入,促进肠道蠕动。<br/>* 避免摄入过多的咖啡因、酒精、辛辣食物等刺激性食物。<br/><br/>2. 生活习惯:<br/>* 保持良好的作息时间,保证充足的睡眠。<br/>* 避免长时间使用电脑、手机等电子产品,以免辐射对皮肤造成刺激。<br/>* 保持皮肤清洁,每天早晚使用温和的洁面产品清洁皮肤。<br/>* 避免挤压青春痘,以免感染和留下疤痕。<br/><br/>3. 药物:<br/>* 外用药物:可以使用含有水杨酸、苯甲酸、壬二酸等成分的护肤品,帮助去除皮肤表面的角质,减少痘痘的形成。<br/>* 口服药物:如果青春痘严重,可以考虑口服抗生素、维生素B族、锌等药物,但需在医生指导下使用。<br/><br/>4. 注意:<br/>* 不要随意使用化妆品、护肤品,以免刺激皮肤。<br/>* 不要过度清洁皮肤,以免破坏皮肤屏障。<br/>* 保持心情舒畅,避免焦虑、抑郁等负面情绪。<br/><br/>总之,青春痘的治疗需要综合考虑饮食、生活习惯、药物等多个方面。在治疗过程中,要保持耐心和信心,遵循医生的建议,逐步改善皮肤状况。 |
|
220 |
+
| 由于输精管堵塞会不会导致男性不育?���医生您好,我同事这几天不开心,我问她怎么了,她说她老公输精管堵塞,所以一直没孩子,不是不想要,请问输精管堵塞会不会导致男性不育? | 输精管堵塞可能会导致男性不育。输精管是连接睾丸和尿道的一条管道,负责将精子从睾丸输送到尿道。如果输精管堵塞,精子就无法通过输精管到达尿道,从而导致不育。<br/><br/>输精管堵塞的原因有很多,包括先天性发育异常、炎症、肿瘤、外伤等。输精管堵塞的治疗方法包括手术疏通、药物治疗等。如果输精管堵塞严重,可能需要进行手术疏通。<br/><br/>需要注意的是,输精管堵塞并不一定导致不育,因为还有其他途径可以将精子输送到女性体内。例如,精子可以通过附睾、前列腺等组织到达女性体内。因此,如果输精管堵塞,建议及时就医,寻求专业医生的帮助。 |
|
221 |
+
| 心肌梗塞能治疗吗? | 心肌梗死是一种严重的疾病,如果及时治疗,可以挽救生命并防止长期残疾。治疗心肌梗死的关键是尽早识别和治疗。<br/><br/>心肌梗死的治疗通常包括以下步骤:<br/><br/>1. 立即就医:如果出现胸痛、呼吸急促或晕厥等症状,立即就医。<br/><br/>2. 药物治疗:医生可能会开处方药物,如硝酸甘油或阿司匹林,以帮助缓解症状并防止进一步的损伤。<br/><br/>3. 溶栓治疗:如果患者在到达医院前出现心肌梗死,医生可能会进行溶栓治疗,通过注射溶解血栓的药物来帮助恢复血流。<br/><br/>4. 手术治疗:如果药物治疗或溶栓治疗无效,医生可能会建议进行手术,如冠状动脉搭桥手术或心脏支架植入术。<br/><br/>5. 康复和预防:心肌梗死的治疗通常包括康复计划,以帮助患者恢复身体功能并预防未来的心脏事件。此外,患者还应遵循医生的建议,以降低患心脏病的风险。<br/><br/>需要注意的是,心肌梗死的治疗取决于患者的具体情况和症状。因此,如果您或您认识的人出现心肌梗死症状,请立即就医。 |
|
222 |
+
| 怎么让自己精力充沛,列5点建议 | 1. 保持充足的睡眠:睡眠对于保持精力充沛至关重要。确保每晚获得7-9小时的高质量睡眠。<br/>2. 均衡饮食:均衡饮食对于保持能量水平至关重要。确保摄入足够的蛋白质、碳水化合物和健康脂肪。<br/>3. 定期锻炼:定期锻炼可以帮助提高能量水平。尝试每天进行至少30分钟的有氧运动,如散步、跑步或骑自行车。<br/>4. 管理压力:压力会消耗能量,因此管理压力对于保持精力充沛至关重要。尝试找到放松和减压的方法,如冥想、瑜伽或深呼吸练习。<br/>5. 保持水分:脱水会降低能量水平。确保每天喝足够的水,并根据需要补充水分。 |
|
223 |
+
| 写个春天的诗 | 春天,<br/>是大地苏醒的季节,<br/>是万物复苏的季节,<br/>是生命绽放的季节。<br/>春天,<br/>是阳光明媚的季节,<br/>是鸟语花香的季节,<br/>是绿意盎然的季节,<br/>是希望无限的季节。<br/>春天,<br/>是人们欢笑的季节,<br/>是孩子们嬉戏的季节,<br/>是恋人们相拥的季节,<br/>是人们追求梦想的季节。 |
|
224 |
+
| 南极洲为啥没有北极熊? | 南极洲没有北极熊,因为北极熊生活在北极地区,而南极洲位于南半球,气候和环境与北极地区有很大不同。南极洲��气候非常寒冷,常年被冰雪覆盖,而北极熊则生活在北极地区,那里有丰富的海洋生物和陆地哺乳动物。此外,南极洲的生态系统也与北极地区不同,没有北极熊所需的生存条件。 |
|
225 |
+
|
226 |
+
</details>
|
227 |
+
|
228 |
+
## 📚 Dataset
|
229 |
+
### 医疗数据集
|
230 |
+
|
231 |
+
- 240万条中文医疗数据集(包括预训练、指令微调和奖励数据集):[shibing624/medical](https://huggingface.co/datasets/shibing624/medical)
|
232 |
+
- 22万条中文医疗对话数据集(华佗项目):[FreedomIntelligence/HuatuoGPT-sft-data-v1](https://huggingface.co/datasets/FreedomIntelligence/HuatuoGPT-sft-data-v1)
|
233 |
+
|
234 |
+
### 通用数据集
|
235 |
+
|
236 |
+
#### Pretraining datasets
|
237 |
+
- 16GB中英文无监督、平行语料[Linly-AI/Chinese-pretraining-dataset](https://huggingface.co/datasets/Linly-AI/Chinese-pretraining-dataset)
|
238 |
+
- 524MB中文维基百科语料[wikipedia-cn-20230720-filtered](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
|
239 |
+
#### SFT datasets
|
240 |
+
- 10万条多语言ShareGPT GPT4多轮对话数据集:[shibing624/sharegpt_gpt4](https://huggingface.co/datasets/shibing624/sharegpt_gpt4) [本项目支持格式]
|
241 |
+
- 9万条英文ShareGPT多轮对话数集:[anon8231489123/ShareGPT_Vicuna_unfiltered](https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered) [本项目支持格式]
|
242 |
+
- 50万条中文ChatGPT指令Belle数据集:[BelleGroup/train_0.5M_CN](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
|
243 |
+
- 100万条中文ChatGPT指令Belle数据集:[BelleGroup/train_1M_CN](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
244 |
+
- 5万条英文ChatGPT指令Alpaca数据集:[50k English Stanford Alpaca dataset](https://github.com/tatsu-lab/stanford_alpaca#data-release)
|
245 |
+
- 2万条中文ChatGPT指令Alpaca数据集:[shibing624/alpaca-zh](https://huggingface.co/datasets/shibing624/alpaca-zh)
|
246 |
+
- 69万条中文指令Guanaco数据集(Belle50万条+Guanaco19万条):[Chinese-Vicuna/guanaco_belle_merge_v1.0](https://huggingface.co/datasets/Chinese-Vicuna/guanaco_belle_merge_v1.0)
|
247 |
+
- 5万条英文ChatGPT多轮对话数据集:[RyokoAI/ShareGPT52K](https://huggingface.co/datasets/RyokoAI/ShareGPT52K)
|
248 |
+
- 80万条中文ChatGPT多轮对话数据集:[BelleGroup/multiturn_chat_0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
|
249 |
+
- 116万条中文ChatGPT多轮对话数据集:[fnlp/moss-002-sft-data](https://huggingface.co/datasets/fnlp/moss-002-sft-data)
|
250 |
+
- 3.8万条中文ShareGPT多轮对话数据集:[FreedomIntelligence/ShareGPT-CN](https://huggingface.co/datasets/FreedomIntelligence/ShareGPT-CN)
|
251 |
+
|
252 |
+
#### Reward Model datasets
|
253 |
+
- 原版的oasst1数据集:[OpenAssistant/oasst1](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
254 |
+
- 2万条多语言oasst1的reward数据集:[tasksource/oasst1_pairwise_rlhf_reward](https://huggingface.co/datasets/tasksource/oasst1_pairwise_rlhf_reward)[本项目支持格式]
|
255 |
+
- 11万条英文hh-rlhf的reward数据集:[Dahoas/full-hh-rlhf](https://huggingface.co/datasets/Dahoas/full-hh-rlhf)
|
256 |
+
- 9万条英文reward数据集(来自Anthropic's Helpful Harmless dataset):[Dahoas/static-hh](https://huggingface.co/datasets/Dahoas/static-hh)
|
257 |
+
- 7万条英文reward数据集(来源同上):[Dahoas/rm-static](https://huggingface.co/datasets/Dahoas/rm-static)
|
258 |
+
- 7万条繁体中文的reward数据集(翻译自rm-static)[liswei/rm-static-m2m100-zh](https://huggingface.co/datasets/liswei/rm-static-m2m100-zh)
|
259 |
+
- 7万条英文Reward数据集:[yitingxie/rlhf-reward-datasets](https://huggingface.co/datasets/yitingxie/rlhf-reward-datasets)
|
260 |
+
- 3千条中文知乎问答偏好数据集:[liyucheng/zhihu_rlhf_3k](https://huggingface.co/datasets/liyucheng/zhihu_rlhf_3k)
|
261 |
+
|
262 |
+
## ✅ Todo
|
263 |
+
|
264 |
+
1. [x] add multi-round dialogue data fine-tuning method
|
265 |
+
2. [x] add reward model fine-tuning
|
266 |
+
3. [x] add rl fine-tuning
|
267 |
+
4. [x] add medical reward dataset
|
268 |
+
5. [x] add llama in8/int4 training
|
269 |
+
6. [x] add all training and predict demo in colab
|
270 |
+
7. [x] add dpo training
|
271 |
+
|
272 |
+
## ☎️ Contact
|
273 |
+
|
274 |
+
- Issue(建议)
|
275 |
+
:[![GitHub issues](https://img.shields.io/github/issues/shibing624/MedicalGPT.svg)](https://github.com/shibing624/MedicalGPT/issues)
|
276 |
+
- 邮件我:xuming: [email protected]
|
277 |
+
- 微信我: 加我*微信号:xuming624, 备注:姓名-公司名-NLP* 进NLP交流群。
|
278 |
+
|
279 |
+
<img src="https://github.com/shibing624/MedicalGPT/blob/main/docs/wechat.jpeg" width="200" />
|
280 |
+
|
281 |
+
## ⚠️ 局限性、使用限制与免责声明
|
282 |
+
|
283 |
+
基于当前数据和基础模型训练得到的SFT模型,在效果上仍存在以下问题:
|
284 |
+
|
285 |
+
1. 在涉及事实性的指令上可能会产生违背事实的错���回答。
|
286 |
+
|
287 |
+
2. 对于具备危害性的指令无法很好的鉴别,由此会产生危害性言论。
|
288 |
+
|
289 |
+
3. 在一些涉及推理、代码、多轮对话等场景下模型的能力仍有待提高。
|
290 |
+
|
291 |
+
基于以上模型局限性,我们要求开发者仅将我们开源的模型权重及后续用此项目生成的衍生物用于研究目的,不得用于商业,以及其他会对社会带来危害的用途。
|
292 |
+
|
293 |
+
本项目仅可应用于研究目的,项目开发者不承担任何因使用本项目(包含但不限于数据、模型、代码等)导致的危害或损失。详细请参考[免责声明](https://github.com/shibing624/MedicalGPT/blob/main/DISCLAIMER)。
|
294 |
+
|
295 |
+
项目代码的授权协议为 [The Apache License 2.0](/LICENSE),代码可免费用做商业用途,模型权重和数据只能用于研究目的。请在产品说明中附加MedicalGPT的链接和授权协议。
|
296 |
+
|
297 |
+
|
298 |
+
## 😇 Citation
|
299 |
+
|
300 |
+
如果你在研究中使用了MedicalGPT,请按如下格式引用:
|
301 |
+
|
302 |
+
```latex
|
303 |
+
@misc{MedicalGPT,
|
304 |
+
title={MedicalGPT: Training Medical GPT Model},
|
305 |
+
author={Ming Xu},
|
306 |
+
year={2023},
|
307 |
+
howpublished={\url{https://github.com/shibing624/MedicalGPT}},
|
308 |
+
}
|
309 |
+
```
|
310 |
+
|
311 |
+
## 😍 Contribute
|
312 |
+
|
313 |
+
项目代码还很粗糙,如果大家对代码有所改进,欢迎提交回本项目,在提交之前,注意以下两点:
|
314 |
+
|
315 |
+
- 在`tests`添加相应的单元测试
|
316 |
+
- 使用`python -m pytest`来运行所有单元测试,确保所有单测都是通过的
|
317 |
+
|
318 |
+
之后即可提交PR。
|
319 |
+
|
320 |
+
## 💕 Acknowledgements
|
321 |
+
|
322 |
+
- [Direct Preference Optimization:Your Language Model is Secretly a Reward Model](https://arxiv.org/pdf/2305.18290.pdf)
|
323 |
+
- [tloen/alpaca-lora](https://github.com/tloen/alpaca-lora/blob/main/finetune.py)
|
324 |
+
- [ymcui/Chinese-LLaMA-Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
325 |
+
|
326 |
+
Thanks for their great work!
|
README_EN.md
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[**🇨🇳中文**](https://github.com/shibing624/MedicalGPT/blob/main/README.md) | [**🌐English**](https://github.com/shibing624/MedicalGPT/blob/main/README_EN.md) | [**📖文档/Docs**](https://github.com/shibing624/MedicalGPT/wiki) | [**🤖模型/Models**](https://huggingface.co/shibing624)
|
2 |
+
|
3 |
+
<div align="center">
|
4 |
+
<a href="https://github.com/shibing624/MedicalGPT">
|
5 |
+
<img src="https://github.com/shibing624/MedicalGPT/blob/main/docs/logo.png" width="120" alt="Logo">
|
6 |
+
</a>
|
7 |
+
</div>
|
8 |
+
|
9 |
+
-----------------
|
10 |
+
|
11 |
+
# MedicalGPT: Training Medical GPT Model
|
12 |
+
[![HF Models](https://img.shields.io/badge/Hugging%20Face-shibing624-green)](https://huggingface.co/shibing624)
|
13 |
+
[![Github Stars](https://img.shields.io/github/stars/shibing624/MedicalGPT?color=yellow)](https://star-history.com/#shibing624/MedicalGPT&Timeline)
|
14 |
+
[![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
|
15 |
+
[![License Apache 2.0](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE)
|
16 |
+
[![python_version](https://img.shields.io/badge/Python-3.8%2B-green.svg)](requirements.txt)
|
17 |
+
[![GitHub issues](https://img.shields.io/github/issues/shibing624/MedicalGPT.svg)](https://github.com/shibing624/MedicalGPT/issues)
|
18 |
+
[![Wechat Group](http://vlog.sfyc.ltd/wechat_everyday/wxgroup_logo.png?imageView2/0/w/60/h/20)](#Contact)
|
19 |
+
|
20 |
+
## 📖 Introduction
|
21 |
+
|
22 |
+
**MedicalGPT** training medical GPT model with ChatGPT training pipeline, implemantation of Pretraining,
|
23 |
+
Supervised Finetuning, Reward Modeling and Reinforcement Learning.
|
24 |
+
|
25 |
+
|
26 |
+
<img src="https://github.com/shibing624/MedicalGPT/blob/main/docs/GPT_Training.jpg" width="860" />
|
27 |
+
|
28 |
+
Training MedicalGPT model:
|
29 |
+
|
30 |
+
- Stage 1:PT(Continue PreTraining), Pre-training the LLaMA model on massive domain document data to inject domain knowledge
|
31 |
+
- Stage 2: SFT (Supervised Fine-tuning) has supervised fine-tuning, constructs instruction fine-tuning data sets, and performs instruction fine-tuning on the basis of pre-trained models to align instruction intentions
|
32 |
+
- Stage 3: RM (Reward Model) reward model modeling, constructing a human preference ranking data set, training the reward model to align human preferences, mainly the "HHH" principle, specifically "helpful, honest, harmless"
|
33 |
+
- Stage 4: RL (Reinforcement Learning) is based on human feedback reinforcement learning (RLHF), using the reward model to train the SFT model, and the generation model uses rewards or penalties to update its strategy in order to generate higher quality, more in line with human preferences text
|
34 |
+
|
35 |
+
## ▶️ Demo
|
36 |
+
|
37 |
+
- Hugging Face Demo: doing
|
38 |
+
|
39 |
+
We provide a simple Gradio-based interactive web interface. After the service is started, it can be accessed through a browser, enter a question, and the model will return an answer. The command is as follows:
|
40 |
+
```shell
|
41 |
+
python scripts/gradio_demo.py --base_model path_to_llama_hf_dir --lora_model path_to_lora_dir
|
42 |
+
```
|
43 |
+
|
44 |
+
Parameter Description:
|
45 |
+
|
46 |
+
- `--base_model {base_model}`: directory to store LLaMA model weights and configuration files in HF format, or use the HF Model Hub model call name
|
47 |
+
- `--lora_model {lora_model}`: The directory where the LoRA file is located, and the name of the HF Model Hub model can also be used. If the lora weights have been merged into the pre-trained model, delete the --lora_model parameter
|
48 |
+
- `--tokenizer_path {tokenizer_path}`: Store the directory corresponding to the tokenizer. If this parameter is not provided, its default value is the same as --lora_model; if the --lora_model parameter is not provided, its default value is the same as --base_model
|
49 |
+
- `--use_cpu`: use only CPU for inference
|
50 |
+
- `--gpus {gpu_ids}`: Specifies the number of GPU devices used, the default is 0. If using multiple GPUs, separate them with commas, such as 0,1,2
|
51 |
+
|
52 |
+
|
53 |
+
## 🚀 Training Pipeline
|
54 |
+
|
55 |
+
### Stage 1: Continue Pretraining
|
56 |
+
|
57 |
+
Based on the llama-7b model, use medical encyclopedia data to continue pre-training, and expect to inject medical knowledge into the pre-training model to obtain the llama-7b-pt model. This step is optional
|
58 |
+
|
59 |
+
|
60 |
+
```shell
|
61 |
+
cd scripts
|
62 |
+
sh run_pt.sh
|
63 |
+
```
|
64 |
+
|
65 |
+
[Training Detail wiki](https://github.com/shibing624/MedicalGPT/wiki/Training-Details)
|
66 |
+
|
67 |
+
### Stage 2: Supervised FineTuning
|
68 |
+
Based on the llama-7b-pt model, the llama-7b-sft model is obtained by using medical question-and-answer data for supervised fine-tuning. This step is required
|
69 |
+
|
70 |
+
Supervised fine-tuning of the base llama-7b-pt model to create llama-7b-sft
|
71 |
+
|
72 |
+
```shell
|
73 |
+
cd scripts
|
74 |
+
sh run_sft.sh
|
75 |
+
```
|
76 |
+
|
77 |
+
[Training Detail wiki](https://github.com/shibing624/MedicalGPT/wiki/Training-Details)
|
78 |
+
|
79 |
+
### Stage 3: Reward Modeling
|
80 |
+
RM(Reward Model): reward model modeling
|
81 |
+
|
82 |
+
In principle, we can directly use human annotations to fine-tune the model with RLHF.
|
83 |
+
|
84 |
+
However, this will require us to send some samples to humans to be scored after each round of optimization. This is expensive and slow due to the large number of training samples required for convergence and the limited speed at which humans can read and annotate them.
|
85 |
+
A better strategy than direct feedback is to train a reward model RM on the human annotated set before entering the RL loop. The purpose of the reward model is to simulate human scoring of text.
|
86 |
+
|
87 |
+
The best practice for building a reward model is to rank the prediction results, that is, for each prompt (input text) corresponding to two results (yk, yj), the model predicts which score the human annotation is higher.
|
88 |
+
The RM model is trained by manually marking the scoring results of the SFT model. The purpose is to replace manual scoring. It is essentially a regression model used to align human preferences, mainly based on the "HHH" principle, specifically "helpful, honest, harmless".
|
89 |
+
|
90 |
+
|
91 |
+
Based on the llama-7b-sft model, the reward preference model is trained using medical question and answer preference data, and the llama-7b-reward model is obtained after training. This step is required
|
92 |
+
|
93 |
+
Reward modeling using dialog pairs from the reward dataset using the llama-7b-sft to create llama-7b-reward:
|
94 |
+
|
95 |
+
```shell
|
96 |
+
cd scripts
|
97 |
+
sh run_rm.sh
|
98 |
+
```
|
99 |
+
[Training Detail wiki](https://github.com/shibing624/MedicalGPT/wiki/Training-Details)
|
100 |
+
|
101 |
+
### Stage 4: Reinforcement Learning
|
102 |
+
The purpose of the RL (Reinforcement Learning) model is to maximize the output of the reward model. Based on the above steps, we have a fine-tuned language model (llama-7b-sft) and reward model (llama-7b-reward).
|
103 |
+
The RL loop is ready to execute.
|
104 |
+
|
105 |
+
This process is roughly divided into three steps:
|
106 |
+
|
107 |
+
1. Enter prompt, the model generates a reply
|
108 |
+
2. Use a reward model to score responses
|
109 |
+
3. Based on the score, a round of reinforcement learning for policy optimization (PPO)
|
110 |
+
|
111 |
+
<img src=https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/blog/stackllama/trl_loop.png height=400 />
|
112 |
+
|
113 |
+
Reinforcement Learning fine-tuning of llama-7b-sft with the llama-7b-reward reward model to create llama-7b-rl
|
114 |
+
|
115 |
+
```shell
|
116 |
+
cd scripts
|
117 |
+
sh run_rl.sh
|
118 |
+
```
|
119 |
+
[Training Detail wiki](https://github.com/shibing624/MedicalGPT/wiki/Training-Details)
|
120 |
+
|
121 |
+
## 🔥 Inference
|
122 |
+
After the training is complete, now we load the trained model to verify the effect of the model generating text.
|
123 |
+
|
124 |
+
```shell
|
125 |
+
python scripts/inference.py \
|
126 |
+
--base_model path_to_llama_hf_dir \
|
127 |
+
--lora_model path_to_lora \
|
128 |
+
--with_prompt \
|
129 |
+
--interactive
|
130 |
+
```
|
131 |
+
|
132 |
+
Parameter Description:
|
133 |
+
|
134 |
+
- `--base_model {base_model}`: Directory to store LLaMA model weights and configuration files in HF format
|
135 |
+
- `--lora_model {lora_model}`: The directory where the LoRA file is decompressed, and the name of the HF Model Hub model can also be used. If you have incorporated LoRA weights into the pre-trained model, you can not provide this parameter
|
136 |
+
- `--tokenizer_path {tokenizer_path}`: Store the directory corresponding to the tokenizer. If this parameter is not provided, its default value is the same as --lora_model; if the --lora_model parameter is not provided, its default value is the same as --base_model
|
137 |
+
- `--with_prompt`: Whether to merge the input with the prompt template. Be sure to enable this option if loading an Alpaca model!
|
138 |
+
- `--interactive`: start interactively for multiple single rounds of question and answer
|
139 |
+
- `--data_file {file_name}`: Start in non-interactive mode, read the contents of file_name line by line for prediction
|
140 |
+
- `--predictions_file {file_name}`: In non-interactive mode, write the predicted results to file_name in json format
|
141 |
+
- `--use_cpu`: use only CPU for inference
|
142 |
+
- `--gpus {gpu_ids}`: Specifies the number of GPU devices used, the default is 0. If using multiple GPUs, separate them with commas, such as 0,1,2
|
143 |
+
|
144 |
+
|
145 |
+
#### Inference Examples
|
146 |
+
|
147 |
+
<details>
|
148 |
+
| Human | Assistant |
|
149 |
+
| ----- | ------ |
|
150 |
+
| How to make yourself energetic, list 5 suggestions | Here are some suggestions to keep you energetic: <br/>1. Get enough sleep. Make sure to get enough sleep every night. <br/>2. Exercise more. Doing moderate aerobic exercise, such as jogging and swimming, can help improve your energy and stamina. 3. Eat healthy food. Choose protein-rich foods such as chicken, fish, and eggs, as well as fresh fruits and vegetables. 4. Reduce caffeine intake. Try to avoid drinking caffeinated beverages or eating caffeinated foods during the day. <br/>5. Relax. Try meditation, yoga, or other forms of relaxation to reduce stress and anxiety. |
|
151 |
+
</details>
|
152 |
+
<br/>
|
153 |
+
|
154 |
+
|
155 |
+
## 📚 Dataset
|
156 |
+
|
157 |
+
- 2.4 million Chinese medical datasets (including pre-training, instruction fine-tuning and reward datasets): [shibing624/medical](https://huggingface.co/datasets/shibing624/medical)
|
158 |
+
|
159 |
+
**Attach links to some general datasets and medical datasets**
|
160 |
+
|
161 |
+
- Belle dataset of 500,000 Chinese ChatGPT commands: [BelleGroup/train_0.5M_CN](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
|
162 |
+
- Belle dataset of 1 million Chinese ChatGPT commands: [BelleGroup/train_1M_CN](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
163 |
+
- Alpaca dataset of 50,000 English ChatGPT commands: [50k English Stanford Alpaca dataset](https://github.com/tatsu-lab/stanford_alpaca#data-release)
|
164 |
+
- Alpaca dataset of 20,000 Chinese GPT-4 instructions: [shibing624/alpaca-zh](https://huggingface.co/datasets/shibing624/alpaca-zh)
|
165 |
+
- Guanaco dataset with 690,000 Chinese instructions (500,000 Belle + 190,000 Guanaco): [Chinese-Vicuna/guanaco_belle_merge_v1.0](https://huggingface.co/datasets/Chinese-Vicuna/guanaco_belle_merge_v1.0)
|
166 |
+
- 220,000 Chinese medical dialogue datasets (HuatuoGPT project): [FreedomIntelligence/HuatuoGPT-sft-data-v1](https://huggingface.co/datasets/FreedomIntelligence/HuatuoGPT-sft-data-v1)
|
167 |
+
|
168 |
+
## ✅ Todo
|
169 |
+
|
170 |
+
1. [ ] Added multi-round dialogue data fine-tuning method
|
171 |
+
2. [x] add reward model finetuning
|
172 |
+
3. [x] add rl finetuning
|
173 |
+
4. [x] add medical reward dataset
|
174 |
+
5. [x] add llama in8/int4 training
|
175 |
+
6. [ ] add all training and predict demo in colab
|
176 |
+
## ☎️ Contact
|
177 |
+
|
178 |
+
- Issue (suggestion)
|
179 |
+
: [![GitHub issues](https://img.shields.io/github/issues/shibing624/MedicalGPT.svg)](https://github.com/shibing624/MedicalGPT/issues)
|
180 |
+
- Email me: xuming: [email protected]
|
181 |
+
- WeChat Me: Add me* WeChat ID: xuming624, Remarks: Name-Company Name-NLP* Enter the NLP exchange group.
|
182 |
+
|
183 |
+
<img src="https://github.com/shibing624/MedicalGPT/blob/main/docs/wechat.jpeg" width="200" />
|
184 |
+
|
185 |
+
## ⚠️ Limitations, Restrictions of Use and Disclaimer
|
186 |
+
|
187 |
+
The SFT model trained based on the current data and the basic model still has the following problems in terms of effect:
|
188 |
+
|
189 |
+
1. Wrong answers that contradict the facts may be generated on the factual instructions.
|
190 |
+
2. Unable to identify harmful instructions well, resulting in harmful speech.
|
191 |
+
3. The ability of the model still needs to be improved in some scenarios involving reasoning, code, and multiple rounds of dialogue.
|
192 |
+
|
193 |
+
Based on the limitations of the above models, we require developers to only use our open source model weights and subsequent derivatives generated by this project for research purposes, and not for commercial use, and other purposes that will cause harm to society.
|
194 |
+
This project can only be used for research purposes, and the project developer is not responsible for any harm or loss caused by the use of this project (including but not limited to data, models, codes, etc.). For details, please refer to [Disclaimer](https://github.com/shibing624/MedicalGPT/blob/main/DISCLAIMER).
|
195 |
+
The license agreement for the project code is [The Apache License 2.0](/LICENSE), the code is free for commercial use, and the model weights and data can only be used for research purposes. Please attach MedicalGPT's link and license agreement in the product description.
|
196 |
+
|
197 |
+
## 😇 Citation
|
198 |
+
|
199 |
+
If you used MedicalGPT in your research, please cite as follows:
|
200 |
+
|
201 |
+
```latex
|
202 |
+
@misc{MedicalGPT,
|
203 |
+
title={MedicalGPT: Training Medical GPT Model},
|
204 |
+
author={Ming Xu},
|
205 |
+
year={2023},
|
206 |
+
howpublished={\url{https://github.com/shibing624/MedicalGPT}},
|
207 |
+
}
|
208 |
+
```
|
209 |
+
|
210 |
+
## 😍 Contribute
|
211 |
+
|
212 |
+
The project code is still very rough. If you have improved the code, you are welcome to submit it back to this project. Before submitting, please pay attention to the following two points:
|
213 |
+
|
214 |
+
- Add corresponding unit tests in `tests`
|
215 |
+
- Use `python -m pytest` to run all unit tests to ensure that all unit tests are passed
|
216 |
+
|
217 |
+
Then you can submit a PR.
|
218 |
+
|
219 |
+
## 💕 Acknowledgements
|
220 |
+
|
221 |
+
- [tloen/alpaca-lora](https://github.com/tloen/alpaca-lora/blob/main/finetune.py)
|
222 |
+
- [ymcui/Chinese-LLaMA-Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
223 |
+
|
224 |
+
Thanks for their great work!
|
_config.yml
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
theme: jekyll-theme-cayman
|
build_domain_tokenizer.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
@author:XuMing([email protected])
|
4 |
+
@description: Build chinese tokenizer from corpus txt
|
5 |
+
|
6 |
+
# train sentencepiece model from `corpus.txt` and makes `m.model` and `m.vocab`
|
7 |
+
# `m.vocab` is just a reference. not used in the segmentation.
|
8 |
+
# spm.SentencePieceTrainer.train('--input=data/pretrain/tianlongbabu.txt --model_prefix=m --vocab_size=20000')
|
9 |
+
"""
|
10 |
+
import argparse
|
11 |
+
|
12 |
+
import sentencepiece as spm
|
13 |
+
|
14 |
+
|
15 |
+
def main():
|
16 |
+
parser = argparse.ArgumentParser()
|
17 |
+
parser.add_argument('--in_file', default='data/pretrain/fever.txt', type=str)
|
18 |
+
parser.add_argument('--domain_sp_model_name', default='domain_sp', type=str)
|
19 |
+
parser.add_argument('--max_sentence_length', default=16384, type=int)
|
20 |
+
parser.add_argument('--pad_id', default=3, type=int)
|
21 |
+
parser.add_argument('--vocab_size', default=2236, type=int)
|
22 |
+
parser.add_argument('--model_type', default="BPE", type=str)
|
23 |
+
|
24 |
+
args = parser.parse_args()
|
25 |
+
print(args)
|
26 |
+
|
27 |
+
spm.SentencePieceTrainer.train(
|
28 |
+
input=args.in_file,
|
29 |
+
model_prefix=args.domain_sp_model_name,
|
30 |
+
shuffle_input_sentence=False,
|
31 |
+
train_extremely_large_corpus=True,
|
32 |
+
max_sentence_length=args.max_sentence_length,
|
33 |
+
pad_id=args.pad_id,
|
34 |
+
model_type=args.model_type,
|
35 |
+
vocab_size=args.vocab_size,
|
36 |
+
split_digits=True,
|
37 |
+
split_by_unicode_script=True,
|
38 |
+
byte_fallback=True,
|
39 |
+
allow_whitespace_only_pieces=True,
|
40 |
+
remove_extra_whitespaces=False,
|
41 |
+
normalization_rule_name="nfkc",
|
42 |
+
)
|
43 |
+
|
44 |
+
# makes segmenter instance and loads the model file (m.model)
|
45 |
+
sp = spm.SentencePieceProcessor()
|
46 |
+
model_file = args.domain_sp_model_name + '.model'
|
47 |
+
sp.load(model_file)
|
48 |
+
|
49 |
+
# encode: text => id
|
50 |
+
print(sp.encode_as_pieces('潜伏性感染又称潜在性感染。慕容复来到河边,this is a test'))
|
51 |
+
print(sp.encode_as_ids('this is a test'))
|
52 |
+
|
53 |
+
# decode: id => text
|
54 |
+
print(sp.decode_pieces(['▁This', '▁is', '▁a', '▁t', 'est']))
|
55 |
+
# print(sp.decode_ids([209, 31, 9, 375, 586]))
|
56 |
+
|
57 |
+
|
58 |
+
if __name__ == '__main__':
|
59 |
+
main()
|
convert_dataset.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Convert alpaca dataset into sharegpt format.
|
3 |
+
|
4 |
+
Usage: python convert_alpaca.py --in_file alpaca_data.json --out_file alpaca_data_sharegpt.json
|
5 |
+
"""
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
|
9 |
+
from datasets import load_dataset
|
10 |
+
|
11 |
+
if __name__ == "__main__":
|
12 |
+
parser = argparse.ArgumentParser()
|
13 |
+
parser.add_argument("--in_file", type=str)
|
14 |
+
parser.add_argument("--out_file", type=str)
|
15 |
+
parser.add_argument("--data_type", type=str, default='alpaca')
|
16 |
+
args = parser.parse_args()
|
17 |
+
print(args)
|
18 |
+
data_files = {"train": args.in_file}
|
19 |
+
raw_datasets = load_dataset('json', data_files=data_files)
|
20 |
+
ds = raw_datasets['train']
|
21 |
+
|
22 |
+
|
23 |
+
def process_alpaca(examples):
|
24 |
+
convs = []
|
25 |
+
for instruction, inp, output in zip(examples['instruction'], examples['input'], examples['output']):
|
26 |
+
if len(inp.strip()) > 1:
|
27 |
+
instruction = instruction + '\n\n' + inp
|
28 |
+
q = instruction
|
29 |
+
a = output
|
30 |
+
convs.append([
|
31 |
+
{"from": "human", "value": q},
|
32 |
+
{"from": "gpt", "value": a}
|
33 |
+
])
|
34 |
+
return {"conversations": convs}
|
35 |
+
|
36 |
+
|
37 |
+
if args.data_type in ['alpaca']:
|
38 |
+
ds = ds.map(process_alpaca, batched=True, remove_columns=ds.column_names, desc="Running process")
|
39 |
+
else:
|
40 |
+
# Other sharegpt dataset, need rename to conversations and remove unused columns
|
41 |
+
if "items" in ds.column_names:
|
42 |
+
ds = ds.rename(columns={"items": "conversations"})
|
43 |
+
columns_to_remove = ds.column_names.copy()
|
44 |
+
columns_to_remove.remove('conversations')
|
45 |
+
ds = ds.remove_columns(columns_to_remove)
|
46 |
+
|
47 |
+
ds.to_json(f"{args.out_file}", lines=True, force_ascii=False)
|
deepspeed_config.json
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"optimizer": {
|
3 |
+
"type": "AdamW",
|
4 |
+
"params": {
|
5 |
+
"lr": "auto",
|
6 |
+
"weight_decay": "auto",
|
7 |
+
"torch_adam": true,
|
8 |
+
"adam_w_mode": true
|
9 |
+
}
|
10 |
+
},
|
11 |
+
"scheduler": {
|
12 |
+
"type": "WarmupDecayLR",
|
13 |
+
"params": {
|
14 |
+
"warmup_min_lr": "auto",
|
15 |
+
"warmup_max_lr": "auto",
|
16 |
+
"warmup_num_steps": "auto",
|
17 |
+
"total_num_steps": "auto"
|
18 |
+
}
|
19 |
+
},
|
20 |
+
"fp16": {
|
21 |
+
"enabled": true,
|
22 |
+
"loss_scale": 0,
|
23 |
+
"loss_scale_window": 1000,
|
24 |
+
"initial_scale_power": 16,
|
25 |
+
"hysteresis": 2,
|
26 |
+
"min_loss_scale": 1
|
27 |
+
},
|
28 |
+
"zero_optimization": {
|
29 |
+
"stage": 2,
|
30 |
+
"allgather_partitions": true,
|
31 |
+
"allgather_bucket_size": 2e8,
|
32 |
+
"reduce_scatter": true,
|
33 |
+
"reduce_bucket_size": "auto",
|
34 |
+
"overlap_comm": true,
|
35 |
+
"contiguous_gradients": true
|
36 |
+
},
|
37 |
+
"gradient_accumulation_steps": "auto",
|
38 |
+
"gradient_clipping": "auto",
|
39 |
+
"steps_per_print": 1000,
|
40 |
+
"train_batch_size": "auto",
|
41 |
+
"train_micro_batch_size_per_gpu": "auto",
|
42 |
+
"wall_clock_breakdown": false
|
43 |
+
}
|
dpo_training.py
ADDED
@@ -0,0 +1,495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
@author:XuMing([email protected])
|
4 |
+
@description: Train a model from SFT using DPO
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os
|
8 |
+
from dataclasses import dataclass, field
|
9 |
+
from glob import glob
|
10 |
+
from typing import Dict, Optional
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from datasets import load_dataset
|
14 |
+
from loguru import logger
|
15 |
+
from peft import LoraConfig, TaskType
|
16 |
+
from transformers import (
|
17 |
+
AutoConfig,
|
18 |
+
BloomForCausalLM,
|
19 |
+
AutoModelForCausalLM,
|
20 |
+
AutoModel,
|
21 |
+
LlamaTokenizer,
|
22 |
+
LlamaForCausalLM,
|
23 |
+
BloomTokenizerFast,
|
24 |
+
AutoTokenizer,
|
25 |
+
HfArgumentParser,
|
26 |
+
TrainingArguments,
|
27 |
+
BitsAndBytesConfig,
|
28 |
+
)
|
29 |
+
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
30 |
+
from trl import DPOTrainer
|
31 |
+
|
32 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "FALSE"
|
33 |
+
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
34 |
+
|
35 |
+
MODEL_CLASSES = {
|
36 |
+
"bloom": (AutoConfig, BloomForCausalLM, BloomTokenizerFast),
|
37 |
+
"chatglm": (AutoConfig, AutoModel, AutoTokenizer),
|
38 |
+
"llama": (AutoConfig, LlamaForCausalLM, LlamaTokenizer),
|
39 |
+
"baichuan": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
|
40 |
+
"auto": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
|
41 |
+
}
|
42 |
+
|
43 |
+
|
44 |
+
@dataclass
|
45 |
+
class ScriptArguments:
|
46 |
+
"""
|
47 |
+
The name of the Casual LM model we wish to fine with DPO
|
48 |
+
"""
|
49 |
+
# Model arguments
|
50 |
+
model_type: str = field(
|
51 |
+
default=None,
|
52 |
+
metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())}
|
53 |
+
)
|
54 |
+
model_name_or_path: Optional[str] = field(
|
55 |
+
default=None, metadata={"help": "The model checkpoint for weights initialization."}
|
56 |
+
)
|
57 |
+
tokenizer_name_or_path: Optional[str] = field(
|
58 |
+
default=None, metadata={"help": "The tokenizer for weights initialization."}
|
59 |
+
)
|
60 |
+
load_in_8bit: bool = field(default=False, metadata={"help": "Whether to load the model in 8bit mode or not."})
|
61 |
+
load_in_4bit: bool = field(default=False, metadata={"help": "Whether to load the model in 4bit mode or not."})
|
62 |
+
cache_dir: Optional[str] = field(
|
63 |
+
default=None,
|
64 |
+
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
65 |
+
)
|
66 |
+
use_fast_tokenizer: bool = field(
|
67 |
+
default=False,
|
68 |
+
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
69 |
+
)
|
70 |
+
torch_dtype: Optional[str] = field(
|
71 |
+
default=None,
|
72 |
+
metadata={
|
73 |
+
"help": (
|
74 |
+
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
|
75 |
+
"dtype will be automatically derived from the model's weights."
|
76 |
+
),
|
77 |
+
"choices": ["auto", "bfloat16", "float16", "float32"],
|
78 |
+
},
|
79 |
+
)
|
80 |
+
device_map: Optional[str] = field(
|
81 |
+
default="auto",
|
82 |
+
metadata={"help": "Device to map model to. If `auto` is passed, the device will be selected automatically. "},
|
83 |
+
)
|
84 |
+
trust_remote_code: bool = field(
|
85 |
+
default=True,
|
86 |
+
metadata={"help": "Whether to trust remote code when loading a model from a remote checkpoint."},
|
87 |
+
)
|
88 |
+
# Dataset arguments
|
89 |
+
dataset_name: Optional[str] = field(
|
90 |
+
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
91 |
+
)
|
92 |
+
dataset_config_name: Optional[str] = field(
|
93 |
+
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
94 |
+
)
|
95 |
+
train_file_dir: Optional[str] = field(default=None, metadata={"help": "The input jsonl data file folder."})
|
96 |
+
validation_file_dir: Optional[str] = field(default=None, metadata={"help": "The evaluation jsonl file folder."}, )
|
97 |
+
template_name: Optional[str] = field(default="vicuna", metadata={"help": "The prompt template name."})
|
98 |
+
per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "Train batch size per device"})
|
99 |
+
per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "Eval batch size per device"})
|
100 |
+
max_source_length: Optional[int] = field(default=256, metadata={"help": "Max length of prompt input text"})
|
101 |
+
max_target_length: Optional[int] = field(default=256, metadata={"help": "Max length of output text"})
|
102 |
+
min_target_length: Optional[int] = field(default=4, metadata={"help": "Min length of output text"})
|
103 |
+
max_train_samples: Optional[int] = field(
|
104 |
+
default=None,
|
105 |
+
metadata={
|
106 |
+
"help": (
|
107 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
108 |
+
"value if set."
|
109 |
+
)
|
110 |
+
},
|
111 |
+
)
|
112 |
+
max_eval_samples: Optional[int] = field(
|
113 |
+
default=None,
|
114 |
+
metadata={
|
115 |
+
"help": (
|
116 |
+
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
117 |
+
"value if set."
|
118 |
+
)
|
119 |
+
},
|
120 |
+
)
|
121 |
+
overwrite_cache: bool = field(
|
122 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
123 |
+
)
|
124 |
+
validation_split_percentage: Optional[int] = field(
|
125 |
+
default=1,
|
126 |
+
metadata={
|
127 |
+
"help": "The percentage of the train set used as validation set in case there's no validation split"
|
128 |
+
},
|
129 |
+
)
|
130 |
+
preprocessing_num_workers: Optional[int] = field(
|
131 |
+
default=4, metadata={"help": "The number of processes to use for the preprocessing."},
|
132 |
+
)
|
133 |
+
# Training arguments
|
134 |
+
use_peft: bool = field(default=True, metadata={"help": "Whether to use peft"})
|
135 |
+
qlora: bool = field(default=False, metadata={"help": "Whether to use qlora"})
|
136 |
+
target_modules: Optional[str] = field(default=None)
|
137 |
+
lora_rank: Optional[int] = field(default=8)
|
138 |
+
lora_dropout: Optional[float] = field(default=0.05)
|
139 |
+
lora_alpha: Optional[float] = field(default=16.0)
|
140 |
+
peft_path: Optional[str] = field(default=None)
|
141 |
+
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
|
142 |
+
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the validation set."})
|
143 |
+
beta: Optional[float] = field(default=0.1, metadata={"help": "The beta parameter for DPO loss"})
|
144 |
+
learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "Learning rate"})
|
145 |
+
lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "The lr scheduler type"})
|
146 |
+
warmup_steps: Optional[int] = field(default=100, metadata={"help": "The number of warmup steps"})
|
147 |
+
weight_decay: Optional[float] = field(default=0.05, metadata={"help": "The weight decay"})
|
148 |
+
optim: Optional[str] = field(default="adamw_hf", metadata={"help": "The optimizer type"})
|
149 |
+
fp16: Optional[bool] = field(default=True, metadata={"help": "Whether to use fp16"})
|
150 |
+
bf16: Optional[bool] = field(default=False, metadata={"help": "Whether to use bf16"})
|
151 |
+
gradient_checkpointing: Optional[bool] = field(
|
152 |
+
default=True, metadata={"help": "Whether to use gradient checkpointing"}
|
153 |
+
)
|
154 |
+
gradient_accumulation_steps: Optional[int] = field(
|
155 |
+
default=4, metadata={"help": "The number of gradient accumulation steps"}
|
156 |
+
)
|
157 |
+
save_steps: Optional[int] = field(default=50, metadata={"help": "X steps to save the model"})
|
158 |
+
eval_steps: Optional[int] = field(default=50, metadata={"help": "X steps to evaluate the model"})
|
159 |
+
logging_steps: Optional[int] = field(default=1, metadata={"help": "X steps to log the model"})
|
160 |
+
output_dir: Optional[str] = field(default="outputs-dpo", metadata={"help": "The output directory"})
|
161 |
+
max_steps: Optional[int] = field(default=200, metadata={"help": "Number of steps to train"})
|
162 |
+
eval_strategy: Optional[str] = field(default="steps", metadata={"help": "Evaluation strategy"})
|
163 |
+
remove_unused_columns: Optional[bool] = field(
|
164 |
+
default=False,
|
165 |
+
metadata={"help": "Remove unused columns from the dataset if `datasets.Dataset` is used"},
|
166 |
+
)
|
167 |
+
report_to: Optional[str] = field(default="tensorboard", metadata={"help": "Report to wandb or tensorboard"})
|
168 |
+
|
169 |
+
def __post_init__(self):
|
170 |
+
if self.model_type is None:
|
171 |
+
raise ValueError("You must specify a valid model_type to run training.")
|
172 |
+
if self.model_name_or_path is None:
|
173 |
+
raise ValueError("You must specify a valid model_name_or_path to run training.")
|
174 |
+
|
175 |
+
|
176 |
+
def print_trainable_parameters(model):
|
177 |
+
"""
|
178 |
+
Prints the number of trainable parameters in the model.
|
179 |
+
"""
|
180 |
+
trainable_params = 0
|
181 |
+
all_param = 0
|
182 |
+
for _, param in model.named_parameters():
|
183 |
+
all_param += param.numel()
|
184 |
+
if param.requires_grad:
|
185 |
+
trainable_params += param.numel()
|
186 |
+
print(
|
187 |
+
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
|
188 |
+
)
|
189 |
+
|
190 |
+
|
191 |
+
def find_all_linear_names(peft_model, int4=False, int8=False):
|
192 |
+
"""Find all linear layer names in the model. reference from qlora paper."""
|
193 |
+
cls = torch.nn.Linear
|
194 |
+
if int4 or int8:
|
195 |
+
import bitsandbytes as bnb
|
196 |
+
if int4:
|
197 |
+
cls = bnb.nn.Linear4bit
|
198 |
+
elif int8:
|
199 |
+
cls = bnb.nn.Linear8bitLt
|
200 |
+
lora_module_names = set()
|
201 |
+
for name, module in peft_model.named_modules():
|
202 |
+
if isinstance(module, cls):
|
203 |
+
# last layer is not add to lora_module_names
|
204 |
+
if 'lm_head' in name:
|
205 |
+
continue
|
206 |
+
if 'output_layer' in name:
|
207 |
+
continue
|
208 |
+
names = name.split('.')
|
209 |
+
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
210 |
+
return sorted(lora_module_names)
|
211 |
+
|
212 |
+
|
213 |
+
def return_prompt_and_responses(examples) -> Dict[str, str]:
|
214 |
+
"""Load the paired dataset and convert it to the necessary format.
|
215 |
+
|
216 |
+
The dataset is converted to a dictionary with the following structure:
|
217 |
+
{
|
218 |
+
'prompt': List[str],
|
219 |
+
'chosen': List[str],
|
220 |
+
'rejected': List[str],
|
221 |
+
}
|
222 |
+
|
223 |
+
Prompts are structured as follows:
|
224 |
+
"Question: " + <prompt> + "\n\nAnswer: "
|
225 |
+
"""
|
226 |
+
return {
|
227 |
+
"prompt": ["Question: " + question + "\n\nAnswer: " for question in examples["question"]],
|
228 |
+
"chosen": examples["response_chosen"],
|
229 |
+
"rejected": examples["response_rejected"],
|
230 |
+
}
|
231 |
+
|
232 |
+
|
233 |
+
def main():
|
234 |
+
parser = HfArgumentParser(ScriptArguments)
|
235 |
+
args = parser.parse_args_into_dataclasses()[0]
|
236 |
+
logger.info(f"Parse args: {args}")
|
237 |
+
|
238 |
+
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
239 |
+
if args.model_type == 'bloom':
|
240 |
+
args.use_fast_tokenizer = True
|
241 |
+
# Load tokenizer
|
242 |
+
tokenizer_kwargs = {
|
243 |
+
"cache_dir": args.cache_dir,
|
244 |
+
"use_fast": args.use_fast_tokenizer,
|
245 |
+
"trust_remote_code": args.trust_remote_code,
|
246 |
+
}
|
247 |
+
tokenizer_name_or_path = args.tokenizer_name_or_path
|
248 |
+
if not tokenizer_name_or_path:
|
249 |
+
tokenizer_name_or_path = args.model_name_or_path
|
250 |
+
tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
|
251 |
+
if tokenizer.pad_token_id is None:
|
252 |
+
tokenizer.pad_token_id = 0 # set as the <unk> token
|
253 |
+
|
254 |
+
# Get datasets
|
255 |
+
if args.dataset_name is not None:
|
256 |
+
# Downloading and loading a dataset from the hub.
|
257 |
+
raw_datasets = load_dataset(
|
258 |
+
args.dataset_name,
|
259 |
+
args.dataset_config_name,
|
260 |
+
cache_dir=args.cache_dir,
|
261 |
+
)
|
262 |
+
if "validation" not in raw_datasets.keys():
|
263 |
+
raw_datasets["validation"] = load_dataset(
|
264 |
+
args.dataset_name,
|
265 |
+
args.dataset_config_name,
|
266 |
+
split=f"train[:{args.validation_split_percentage}%]",
|
267 |
+
cache_dir=args.cache_dir,
|
268 |
+
)
|
269 |
+
raw_datasets["train"] = load_dataset(
|
270 |
+
args.dataset_name,
|
271 |
+
args.dataset_config_name,
|
272 |
+
split=f"train[{args.validation_split_percentage}%:]",
|
273 |
+
cache_dir=args.cache_dir,
|
274 |
+
)
|
275 |
+
else:
|
276 |
+
data_files = {}
|
277 |
+
if args.train_file_dir is not None and os.path.exists(args.train_file_dir):
|
278 |
+
train_data_files = glob(f'{args.train_file_dir}/**/*.json', recursive=True) + glob(
|
279 |
+
f'{args.train_file_dir}/**/*.jsonl', recursive=True)
|
280 |
+
logger.info(f"train files: {', '.join(train_data_files)}")
|
281 |
+
data_files["train"] = train_data_files
|
282 |
+
if args.validation_file_dir is not None and os.path.exists(args.validation_file_dir):
|
283 |
+
eval_data_files = glob(f'{args.validation_file_dir}/**/*.json', recursive=True) + glob(
|
284 |
+
f'{args.validation_file_dir}/**/*.jsonl', recursive=True)
|
285 |
+
logger.info(f"eval files: {', '.join(eval_data_files)}")
|
286 |
+
data_files["validation"] = eval_data_files
|
287 |
+
raw_datasets = load_dataset(
|
288 |
+
'json',
|
289 |
+
data_files=data_files,
|
290 |
+
cache_dir=args.cache_dir,
|
291 |
+
)
|
292 |
+
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
|
293 |
+
if "validation" not in raw_datasets.keys():
|
294 |
+
raw_datasets["validation"] = load_dataset(
|
295 |
+
'json',
|
296 |
+
data_files=data_files,
|
297 |
+
split=f"train[:{args.validation_split_percentage}%]",
|
298 |
+
cache_dir=args.cache_dir,
|
299 |
+
)
|
300 |
+
raw_datasets["train"] = load_dataset(
|
301 |
+
'json',
|
302 |
+
data_files=data_files,
|
303 |
+
split=f"train[{args.validation_split_percentage}%:]",
|
304 |
+
cache_dir=args.cache_dir,
|
305 |
+
)
|
306 |
+
logger.info(f"Raw datasets: {raw_datasets}")
|
307 |
+
|
308 |
+
# Preprocessing the datasets
|
309 |
+
max_source_length = args.max_source_length
|
310 |
+
max_target_length = args.max_target_length
|
311 |
+
full_max_length = max_source_length + max_target_length
|
312 |
+
|
313 |
+
# Preprocess the dataset
|
314 |
+
train_dataset = None
|
315 |
+
max_train_samples = 0
|
316 |
+
if args.do_train:
|
317 |
+
if "train" not in raw_datasets:
|
318 |
+
raise ValueError("--do_train requires a train dataset")
|
319 |
+
train_dataset = raw_datasets['train']
|
320 |
+
max_train_samples = len(train_dataset)
|
321 |
+
if args.max_train_samples is not None and args.max_train_samples > 0:
|
322 |
+
max_train_samples = min(len(train_dataset), args.max_train_samples)
|
323 |
+
train_dataset = train_dataset.select(range(max_train_samples))
|
324 |
+
logger.debug(f"Example train_dataset[0]: {train_dataset[0]}")
|
325 |
+
tokenized_dataset = train_dataset.shuffle().map(
|
326 |
+
return_prompt_and_responses,
|
327 |
+
batched=True,
|
328 |
+
num_proc=args.preprocessing_num_workers,
|
329 |
+
remove_columns=train_dataset.column_names,
|
330 |
+
load_from_cache_file=not args.overwrite_cache,
|
331 |
+
desc="Running tokenizer on dataset",
|
332 |
+
)
|
333 |
+
train_dataset = tokenized_dataset.filter(
|
334 |
+
lambda x: 0 < len(x['prompt'] + x['chosen']) <= full_max_length
|
335 |
+
and 0 < len(x['prompt'] + x['rejected']) <= full_max_length
|
336 |
+
)
|
337 |
+
logger.debug(f"Num train_samples: {len(train_dataset)}")
|
338 |
+
logger.debug("First train example:")
|
339 |
+
logger.debug(train_dataset[0]['prompt'] + train_dataset[0]['chosen'])
|
340 |
+
|
341 |
+
eval_dataset = None
|
342 |
+
max_eval_samples = 0
|
343 |
+
if args.do_eval:
|
344 |
+
if "validation" not in raw_datasets:
|
345 |
+
raise ValueError("--do_eval requires a validation dataset")
|
346 |
+
eval_dataset = raw_datasets["validation"]
|
347 |
+
max_eval_samples = len(eval_dataset)
|
348 |
+
if args.max_eval_samples is not None and args.max_eval_samples > 0:
|
349 |
+
max_eval_samples = min(len(eval_dataset), args.max_eval_samples)
|
350 |
+
eval_dataset = eval_dataset.select(range(max_eval_samples))
|
351 |
+
logger.debug(f"Example eval_dataset[0]: {eval_dataset[0]}")
|
352 |
+
eval_dataset = eval_dataset.map(
|
353 |
+
return_prompt_and_responses,
|
354 |
+
batched=True,
|
355 |
+
num_proc=args.preprocessing_num_workers,
|
356 |
+
remove_columns=eval_dataset.column_names,
|
357 |
+
load_from_cache_file=not args.overwrite_cache,
|
358 |
+
desc="Running tokenizer on dataset",
|
359 |
+
)
|
360 |
+
eval_dataset = eval_dataset.filter(
|
361 |
+
lambda x: 0 < len(x['prompt'] + x['chosen']) <= full_max_length
|
362 |
+
and 0 < len(x['prompt'] + x['rejected']) <= full_max_length
|
363 |
+
)
|
364 |
+
logger.debug(f"Num eval_samples: {len(eval_dataset)}")
|
365 |
+
logger.debug("First eval example:")
|
366 |
+
logger.debug(eval_dataset[0]['prompt'] + eval_dataset[0]['chosen'])
|
367 |
+
|
368 |
+
logger.info("Loading model")
|
369 |
+
torch_dtype = (
|
370 |
+
args.torch_dtype
|
371 |
+
if args.torch_dtype in ["auto", None]
|
372 |
+
else getattr(torch, args.torch_dtype)
|
373 |
+
)
|
374 |
+
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
375 |
+
ddp = world_size != 1
|
376 |
+
if ddp:
|
377 |
+
args.device_map = {"": int(os.environ["LOCAL_RANK"]) or 0}
|
378 |
+
if args.qlora and is_deepspeed_zero3_enabled():
|
379 |
+
logger.warning("ZeRO3 are both currently incompatible with QLoRA.")
|
380 |
+
config = config_class.from_pretrained(
|
381 |
+
args.model_name_or_path,
|
382 |
+
trust_remote_code=args.trust_remote_code,
|
383 |
+
torch_dtype=torch_dtype,
|
384 |
+
cache_dir=args.cache_dir
|
385 |
+
)
|
386 |
+
model = model_class.from_pretrained(
|
387 |
+
args.model_name_or_path,
|
388 |
+
config=config,
|
389 |
+
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
|
390 |
+
device_map=args.device_map,
|
391 |
+
trust_remote_code=args.trust_remote_code,
|
392 |
+
quantization_config=BitsAndBytesConfig(
|
393 |
+
load_in_4bit=True,
|
394 |
+
bnb_4bit_use_double_quant=True,
|
395 |
+
bnb_4bit_quant_type="nf4",
|
396 |
+
bnb_4bit_compute_dtype=torch_dtype,
|
397 |
+
) if args.qlora else None,
|
398 |
+
)
|
399 |
+
model_ref = model_class.from_pretrained(
|
400 |
+
args.model_name_or_path,
|
401 |
+
config=config,
|
402 |
+
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
|
403 |
+
device_map=args.device_map,
|
404 |
+
trust_remote_code=args.trust_remote_code,
|
405 |
+
quantization_config=BitsAndBytesConfig(
|
406 |
+
load_in_4bit=True,
|
407 |
+
bnb_4bit_use_double_quant=True,
|
408 |
+
bnb_4bit_quant_type="nf4",
|
409 |
+
bnb_4bit_compute_dtype=torch_dtype,
|
410 |
+
) if args.qlora else None,
|
411 |
+
)
|
412 |
+
|
413 |
+
# Initialize our Trainer
|
414 |
+
if args.gradient_checkpointing:
|
415 |
+
model.gradient_checkpointing_enable()
|
416 |
+
model.config.use_cache = False
|
417 |
+
else:
|
418 |
+
model.config.use_cache = True
|
419 |
+
|
420 |
+
training_args = TrainingArguments(
|
421 |
+
per_device_train_batch_size=args.per_device_train_batch_size,
|
422 |
+
per_device_eval_batch_size=args.per_device_eval_batch_size,
|
423 |
+
max_steps=args.max_steps,
|
424 |
+
logging_steps=args.logging_steps,
|
425 |
+
save_steps=args.save_steps,
|
426 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
427 |
+
gradient_checkpointing=args.gradient_checkpointing,
|
428 |
+
learning_rate=args.learning_rate,
|
429 |
+
evaluation_strategy=args.eval_strategy,
|
430 |
+
eval_steps=args.eval_steps,
|
431 |
+
output_dir=args.output_dir,
|
432 |
+
report_to=args.report_to,
|
433 |
+
lr_scheduler_type=args.lr_scheduler_type,
|
434 |
+
warmup_steps=args.warmup_steps,
|
435 |
+
optim=args.optim,
|
436 |
+
bf16=args.bf16,
|
437 |
+
fp16=args.fp16,
|
438 |
+
remove_unused_columns=args.remove_unused_columns,
|
439 |
+
run_name=f"dpo_{args.model_type}",
|
440 |
+
)
|
441 |
+
|
442 |
+
# Initialize DPO trainer
|
443 |
+
target_modules = args.target_modules.split(',') if args.target_modules else None
|
444 |
+
if target_modules and 'all' in target_modules:
|
445 |
+
target_modules = find_all_linear_names(model, int4=args.load_in_4bit, int8=args.load_in_8bit)
|
446 |
+
logger.info(f"Peft target_modules: {target_modules}")
|
447 |
+
peft_config = LoraConfig(
|
448 |
+
task_type=TaskType.CAUSAL_LM,
|
449 |
+
target_modules=target_modules,
|
450 |
+
inference_mode=False,
|
451 |
+
r=args.lora_rank,
|
452 |
+
lora_alpha=args.lora_alpha,
|
453 |
+
lora_dropout=args.lora_dropout,
|
454 |
+
)
|
455 |
+
trainer = DPOTrainer(
|
456 |
+
model,
|
457 |
+
model_ref,
|
458 |
+
args=training_args,
|
459 |
+
beta=args.beta,
|
460 |
+
train_dataset=train_dataset,
|
461 |
+
eval_dataset=eval_dataset,
|
462 |
+
tokenizer=tokenizer,
|
463 |
+
peft_config=peft_config if args.use_peft else None,
|
464 |
+
max_prompt_length=args.max_source_length,
|
465 |
+
max_length=full_max_length,
|
466 |
+
)
|
467 |
+
print_trainable_parameters(trainer.model)
|
468 |
+
|
469 |
+
# Training
|
470 |
+
if args.do_train:
|
471 |
+
logger.info("*** Train ***")
|
472 |
+
train_result = trainer.train()
|
473 |
+
metrics = train_result.metrics
|
474 |
+
metrics["train_samples"] = max_train_samples
|
475 |
+
logger.debug(f"Training metrics: {metrics}")
|
476 |
+
trainer.log_metrics("train", metrics)
|
477 |
+
trainer.save_metrics("train", metrics)
|
478 |
+
trainer.save_state()
|
479 |
+
logger.info(f"Saving model checkpoint to {args.output_dir}")
|
480 |
+
trainer.save_model(args.output_dir)
|
481 |
+
tokenizer.save_pretrained(args.output_dir)
|
482 |
+
trainer.model.save_pretrained(args.output_dir)
|
483 |
+
|
484 |
+
# Evaluation
|
485 |
+
if args.do_eval and trainer.is_world_process_zero():
|
486 |
+
logger.info("*** Evaluate ***")
|
487 |
+
metrics = trainer.evaluate()
|
488 |
+
metrics["eval_samples"] = max_eval_samples
|
489 |
+
logger.debug(f"Eval metrics: {metrics}")
|
490 |
+
trainer.log_metrics("eval", metrics)
|
491 |
+
trainer.save_metrics("eval", metrics)
|
492 |
+
|
493 |
+
|
494 |
+
if __name__ == "__main__":
|
495 |
+
main()
|
gradio_demo.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
@author:XuMing([email protected])
|
4 |
+
@description:
|
5 |
+
|
6 |
+
pip install gradio
|
7 |
+
pip install mdtex2html
|
8 |
+
"""
|
9 |
+
import argparse
|
10 |
+
import os
|
11 |
+
from threading import Thread
|
12 |
+
|
13 |
+
import gradio as gr
|
14 |
+
import mdtex2html
|
15 |
+
import torch
|
16 |
+
from peft import PeftModel
|
17 |
+
from transformers import (
|
18 |
+
AutoModel,
|
19 |
+
AutoTokenizer,
|
20 |
+
AutoModelForCausalLM,
|
21 |
+
BloomForCausalLM,
|
22 |
+
BloomTokenizerFast,
|
23 |
+
LlamaTokenizer,
|
24 |
+
LlamaForCausalLM,
|
25 |
+
GenerationConfig,
|
26 |
+
TextIteratorStreamer,
|
27 |
+
)
|
28 |
+
|
29 |
+
from supervised_finetuning import get_conv_template
|
30 |
+
|
31 |
+
MODEL_CLASSES = {
|
32 |
+
"bloom": (BloomForCausalLM, BloomTokenizerFast),
|
33 |
+
"chatglm": (AutoModel, AutoTokenizer),
|
34 |
+
"llama": (LlamaForCausalLM, LlamaTokenizer),
|
35 |
+
"baichuan": (AutoModelForCausalLM, AutoTokenizer),
|
36 |
+
"auto": (AutoModelForCausalLM, AutoTokenizer),
|
37 |
+
}
|
38 |
+
|
39 |
+
|
40 |
+
@torch.inference_mode()
|
41 |
+
def stream_generate_answer(
|
42 |
+
model,
|
43 |
+
tokenizer,
|
44 |
+
prompt,
|
45 |
+
device,
|
46 |
+
max_new_tokens=512,
|
47 |
+
temperature=0.7,
|
48 |
+
top_p=0.8,
|
49 |
+
repetition_penalty=1.0,
|
50 |
+
context_len=2048,
|
51 |
+
):
|
52 |
+
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=False)
|
53 |
+
input_ids = tokenizer(prompt).input_ids
|
54 |
+
max_src_len = context_len - max_new_tokens - 8
|
55 |
+
input_ids = input_ids[-max_src_len:]
|
56 |
+
generation_kwargs = dict(
|
57 |
+
input_ids=torch.as_tensor([input_ids]).to(device),
|
58 |
+
max_new_tokens=max_new_tokens,
|
59 |
+
temperature=temperature,
|
60 |
+
top_p=top_p,
|
61 |
+
repetition_penalty=repetition_penalty,
|
62 |
+
streamer=streamer,
|
63 |
+
)
|
64 |
+
|
65 |
+
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
66 |
+
thread.start()
|
67 |
+
|
68 |
+
yield from streamer
|
69 |
+
|
70 |
+
|
71 |
+
def main():
|
72 |
+
parser = argparse.ArgumentParser()
|
73 |
+
parser.add_argument('--model_type', default=None, type=str, required=True)
|
74 |
+
parser.add_argument('--base_model', default=None, type=str, required=True)
|
75 |
+
parser.add_argument('--lora_model', default="", type=str, help="If None, perform inference on the base model")
|
76 |
+
parser.add_argument('--tokenizer_path', default=None, type=str)
|
77 |
+
parser.add_argument('--template_name', default="vicuna", type=str,
|
78 |
+
help="Prompt template name, eg: alpaca, vicuna, baichuan-chat, chatglm2 etc.")
|
79 |
+
parser.add_argument('--gpus', default="0", type=str)
|
80 |
+
parser.add_argument('--only_cpu', action='store_true', help='only use CPU for inference')
|
81 |
+
parser.add_argument('--resize_emb', action='store_true', help='Whether to resize model token embeddings')
|
82 |
+
args = parser.parse_args()
|
83 |
+
if args.only_cpu is True:
|
84 |
+
args.gpus = ""
|
85 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
|
86 |
+
|
87 |
+
def postprocess(self, y):
|
88 |
+
if y is None:
|
89 |
+
return []
|
90 |
+
for i, (message, response) in enumerate(y):
|
91 |
+
y[i] = (
|
92 |
+
None if message is None else mdtex2html.convert((message)),
|
93 |
+
None if response is None else mdtex2html.convert(response),
|
94 |
+
)
|
95 |
+
return y
|
96 |
+
|
97 |
+
gr.Chatbot.postprocess = postprocess
|
98 |
+
|
99 |
+
load_type = torch.float16
|
100 |
+
if torch.cuda.is_available():
|
101 |
+
device = torch.device(0)
|
102 |
+
else:
|
103 |
+
device = torch.device('cpu')
|
104 |
+
|
105 |
+
if args.tokenizer_path is None:
|
106 |
+
args.tokenizer_path = args.base_model
|
107 |
+
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
108 |
+
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_path, trust_remote_code=True)
|
109 |
+
base_model = model_class.from_pretrained(
|
110 |
+
args.base_model,
|
111 |
+
load_in_8bit=False,
|
112 |
+
torch_dtype=load_type,
|
113 |
+
low_cpu_mem_usage=True,
|
114 |
+
device_map='auto',
|
115 |
+
trust_remote_code=True,
|
116 |
+
)
|
117 |
+
try:
|
118 |
+
base_model.generation_config = GenerationConfig.from_pretrained(args.base_model, trust_remote_code=True)
|
119 |
+
except OSError:
|
120 |
+
print("Failed to load generation config, use default.")
|
121 |
+
if args.resize_emb:
|
122 |
+
model_vocab_size = base_model.get_input_embeddings().weight.size(0)
|
123 |
+
tokenzier_vocab_size = len(tokenizer)
|
124 |
+
print(f"Vocab of the base model: {model_vocab_size}")
|
125 |
+
print(f"Vocab of the tokenizer: {tokenzier_vocab_size}")
|
126 |
+
if model_vocab_size != tokenzier_vocab_size:
|
127 |
+
print("Resize model embeddings to fit tokenizer")
|
128 |
+
base_model.resize_token_embeddings(tokenzier_vocab_size)
|
129 |
+
if args.lora_model:
|
130 |
+
model = PeftModel.from_pretrained(base_model, args.lora_model, torch_dtype=load_type, device_map='auto')
|
131 |
+
print("loaded lora model")
|
132 |
+
else:
|
133 |
+
model = base_model
|
134 |
+
if device == torch.device('cpu'):
|
135 |
+
model.float()
|
136 |
+
|
137 |
+
model.eval()
|
138 |
+
|
139 |
+
def reset_user_input():
|
140 |
+
return gr.update(value='')
|
141 |
+
|
142 |
+
def reset_state():
|
143 |
+
return [], []
|
144 |
+
|
145 |
+
prompt_template = get_conv_template(args.template_name)
|
146 |
+
stop_str = tokenizer.eos_token if tokenizer.eos_token else prompt_template.stop_str
|
147 |
+
history = []
|
148 |
+
|
149 |
+
def predict(
|
150 |
+
input,
|
151 |
+
chatbot,
|
152 |
+
history,
|
153 |
+
max_new_tokens,
|
154 |
+
temperature,
|
155 |
+
top_p
|
156 |
+
):
|
157 |
+
now_input = input
|
158 |
+
chatbot.append((input, ""))
|
159 |
+
history = history or []
|
160 |
+
history.append([now_input, ''])
|
161 |
+
|
162 |
+
prompt = prompt_template.get_prompt(messages=history)
|
163 |
+
response = ""
|
164 |
+
|
165 |
+
for new_text in stream_generate_answer(
|
166 |
+
model,
|
167 |
+
tokenizer,
|
168 |
+
prompt,
|
169 |
+
device,
|
170 |
+
max_new_tokens=max_new_tokens,
|
171 |
+
temperature=temperature,
|
172 |
+
top_p=top_p,
|
173 |
+
):
|
174 |
+
stop = False
|
175 |
+
pos = new_text.find(stop_str)
|
176 |
+
if pos != -1:
|
177 |
+
new_text = new_text[:pos]
|
178 |
+
stop = True
|
179 |
+
response += new_text
|
180 |
+
new_history = history + [(now_input, response)]
|
181 |
+
chatbot[-1] = (now_input, response)
|
182 |
+
yield chatbot, new_history
|
183 |
+
if stop:
|
184 |
+
break
|
185 |
+
|
186 |
+
with gr.Blocks() as demo:
|
187 |
+
gr.HTML("""<h1 align="center">MedicalGPT</h1>""")
|
188 |
+
gr.Markdown(
|
189 |
+
"> 为了促进医疗行业大模型的开放研究,本项目开源了MedicalGPT医疗大模型")
|
190 |
+
chatbot = gr.Chatbot()
|
191 |
+
with gr.Row():
|
192 |
+
with gr.Column(scale=4):
|
193 |
+
with gr.Column(scale=12):
|
194 |
+
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
|
195 |
+
container=False)
|
196 |
+
with gr.Column(min_width=32, scale=1):
|
197 |
+
submitBtn = gr.Button("Submit", variant="primary")
|
198 |
+
with gr.Column(scale=1):
|
199 |
+
emptyBtn = gr.Button("Clear History")
|
200 |
+
max_length = gr.Slider(
|
201 |
+
0, 4096, value=512, step=1.0, label="Maximum length", interactive=True)
|
202 |
+
top_p = gr.Slider(0, 1, value=0.8, step=0.01,
|
203 |
+
label="Top P", interactive=True)
|
204 |
+
temperature = gr.Slider(
|
205 |
+
0, 1, value=0.7, step=0.01, label="Temperature", interactive=True)
|
206 |
+
history = gr.State([])
|
207 |
+
submitBtn.click(predict, [user_input, chatbot, history, max_length, temperature, top_p], [chatbot, history],
|
208 |
+
show_progress=True)
|
209 |
+
submitBtn.click(reset_user_input, [], [user_input])
|
210 |
+
emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
|
211 |
+
demo.queue().launch(share=False, inbrowser=True, server_name='0.0.0.0', server_port=8082)
|
212 |
+
|
213 |
+
|
214 |
+
if __name__ == '__main__':
|
215 |
+
main()
|
inference.py
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
@author:XuMing([email protected])
|
4 |
+
@description:
|
5 |
+
"""
|
6 |
+
import argparse
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
from threading import Thread
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from peft import PeftModel
|
13 |
+
from transformers import (
|
14 |
+
AutoModel,
|
15 |
+
AutoModelForCausalLM,
|
16 |
+
AutoTokenizer,
|
17 |
+
BloomForCausalLM,
|
18 |
+
BloomTokenizerFast,
|
19 |
+
LlamaTokenizer,
|
20 |
+
LlamaForCausalLM,
|
21 |
+
TextIteratorStreamer,
|
22 |
+
GenerationConfig,
|
23 |
+
)
|
24 |
+
|
25 |
+
from supervised_finetuning import get_conv_template
|
26 |
+
|
27 |
+
MODEL_CLASSES = {
|
28 |
+
"bloom": (BloomForCausalLM, BloomTokenizerFast),
|
29 |
+
"chatglm": (AutoModel, AutoTokenizer),
|
30 |
+
"llama": (LlamaForCausalLM, LlamaTokenizer),
|
31 |
+
"baichuan": (AutoModelForCausalLM, AutoTokenizer),
|
32 |
+
"auto": (AutoModelForCausalLM, AutoTokenizer),
|
33 |
+
}
|
34 |
+
|
35 |
+
|
36 |
+
@torch.inference_mode()
|
37 |
+
def stream_generate_answer(
|
38 |
+
model,
|
39 |
+
tokenizer,
|
40 |
+
prompt,
|
41 |
+
device,
|
42 |
+
do_print=True,
|
43 |
+
max_new_tokens=512,
|
44 |
+
temperature=0.7,
|
45 |
+
repetition_penalty=1.0,
|
46 |
+
context_len=2048,
|
47 |
+
stop_str="</s>",
|
48 |
+
):
|
49 |
+
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=False)
|
50 |
+
input_ids = tokenizer(prompt).input_ids
|
51 |
+
max_src_len = context_len - max_new_tokens - 8
|
52 |
+
input_ids = input_ids[-max_src_len:]
|
53 |
+
generation_kwargs = dict(
|
54 |
+
input_ids=torch.as_tensor([input_ids]).to(device),
|
55 |
+
max_new_tokens=max_new_tokens,
|
56 |
+
temperature=temperature,
|
57 |
+
repetition_penalty=repetition_penalty,
|
58 |
+
streamer=streamer,
|
59 |
+
)
|
60 |
+
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
61 |
+
thread.start()
|
62 |
+
|
63 |
+
generated_text = ""
|
64 |
+
for new_text in streamer:
|
65 |
+
stop = False
|
66 |
+
pos = new_text.find(stop_str)
|
67 |
+
if pos != -1:
|
68 |
+
new_text = new_text[:pos]
|
69 |
+
stop = True
|
70 |
+
generated_text += new_text
|
71 |
+
if do_print:
|
72 |
+
print(new_text, end="", flush=True)
|
73 |
+
if stop:
|
74 |
+
break
|
75 |
+
if do_print:
|
76 |
+
print()
|
77 |
+
return generated_text
|
78 |
+
|
79 |
+
|
80 |
+
def main():
|
81 |
+
parser = argparse.ArgumentParser()
|
82 |
+
parser.add_argument('--model_type', default=None, type=str, required=True)
|
83 |
+
parser.add_argument('--base_model', default=None, type=str, required=True)
|
84 |
+
parser.add_argument('--lora_model', default="", type=str, help="If None, perform inference on the base model")
|
85 |
+
parser.add_argument('--tokenizer_path', default=None, type=str)
|
86 |
+
parser.add_argument('--template_name', default="vicuna", type=str,
|
87 |
+
help="Prompt template name, eg: alpaca, vicuna, baichuan-chat, chatglm2 etc.")
|
88 |
+
parser.add_argument("--temperature", type=float, default=0.7)
|
89 |
+
parser.add_argument("--repetition_penalty", type=float, default=1.0)
|
90 |
+
parser.add_argument("--max_new_tokens", type=int, default=512)
|
91 |
+
parser.add_argument('--data_file', default=None, type=str,
|
92 |
+
help="A file that contains instructions (one instruction per line)")
|
93 |
+
parser.add_argument('--interactive', action='store_true', help="run in the instruction mode (single-turn)")
|
94 |
+
parser.add_argument('--predictions_file', default='./predictions.json', type=str)
|
95 |
+
parser.add_argument('--resize_emb', action='store_true', help='Whether to resize model token embeddings')
|
96 |
+
parser.add_argument('--gpus', default="0", type=str)
|
97 |
+
parser.add_argument('--only_cpu', action='store_true', help='only use CPU for inference')
|
98 |
+
args = parser.parse_args()
|
99 |
+
print(args)
|
100 |
+
if args.only_cpu is True:
|
101 |
+
args.gpus = ""
|
102 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
|
103 |
+
load_type = torch.float16
|
104 |
+
if torch.cuda.is_available():
|
105 |
+
device = torch.device(0)
|
106 |
+
else:
|
107 |
+
device = torch.device('cpu')
|
108 |
+
if args.tokenizer_path is None:
|
109 |
+
args.tokenizer_path = args.base_model
|
110 |
+
|
111 |
+
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
112 |
+
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_path, trust_remote_code=True)
|
113 |
+
base_model = model_class.from_pretrained(
|
114 |
+
args.base_model,
|
115 |
+
load_in_8bit=False,
|
116 |
+
torch_dtype=load_type,
|
117 |
+
low_cpu_mem_usage=True,
|
118 |
+
device_map='auto',
|
119 |
+
trust_remote_code=True,
|
120 |
+
)
|
121 |
+
try:
|
122 |
+
base_model.generation_config = GenerationConfig.from_pretrained(args.base_model, trust_remote_code=True)
|
123 |
+
except OSError:
|
124 |
+
print("Failed to load generation config, use default.")
|
125 |
+
if args.resize_emb:
|
126 |
+
model_vocab_size = base_model.get_input_embeddings().weight.size(0)
|
127 |
+
tokenzier_vocab_size = len(tokenizer)
|
128 |
+
print(f"Vocab of the base model: {model_vocab_size}")
|
129 |
+
print(f"Vocab of the tokenizer: {tokenzier_vocab_size}")
|
130 |
+
if model_vocab_size != tokenzier_vocab_size:
|
131 |
+
print("Resize model embeddings to fit tokenizer")
|
132 |
+
base_model.resize_token_embeddings(tokenzier_vocab_size)
|
133 |
+
|
134 |
+
if args.lora_model:
|
135 |
+
model = PeftModel.from_pretrained(base_model, args.lora_model, torch_dtype=load_type, device_map='auto')
|
136 |
+
print("Loaded lora model")
|
137 |
+
else:
|
138 |
+
model = base_model
|
139 |
+
if device == torch.device('cpu'):
|
140 |
+
model.float()
|
141 |
+
model.eval()
|
142 |
+
print(tokenizer)
|
143 |
+
# test data
|
144 |
+
if args.data_file is None:
|
145 |
+
examples = ["介绍下北京", "乙肝和丙肝的区别?"]
|
146 |
+
else:
|
147 |
+
with open(args.data_file, 'r') as f:
|
148 |
+
examples = [l.strip() for l in f.readlines()]
|
149 |
+
print("first 10 examples:")
|
150 |
+
for example in examples[:10]:
|
151 |
+
print(example)
|
152 |
+
|
153 |
+
# Chat
|
154 |
+
prompt_template = get_conv_template(args.template_name)
|
155 |
+
stop_str = tokenizer.eos_token if tokenizer.eos_token else prompt_template.stop_str
|
156 |
+
|
157 |
+
if args.interactive:
|
158 |
+
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
|
159 |
+
history = []
|
160 |
+
while True:
|
161 |
+
try:
|
162 |
+
query = input(f"{prompt_template.roles[0]}: ")
|
163 |
+
except UnicodeDecodeError:
|
164 |
+
print("Detected decoding error at the inputs, please try again.")
|
165 |
+
continue
|
166 |
+
except Exception:
|
167 |
+
raise
|
168 |
+
if query == "":
|
169 |
+
print("Please input text, try again.")
|
170 |
+
continue
|
171 |
+
if query.strip() == "exit":
|
172 |
+
print("exit...")
|
173 |
+
break
|
174 |
+
if query.strip() == "clear":
|
175 |
+
history = []
|
176 |
+
print("history cleared.")
|
177 |
+
continue
|
178 |
+
|
179 |
+
print(f"{prompt_template.roles[1]}: ", end="", flush=True)
|
180 |
+
|
181 |
+
history.append([query, ''])
|
182 |
+
prompt = prompt_template.get_prompt(messages=history)
|
183 |
+
response = stream_generate_answer(
|
184 |
+
model,
|
185 |
+
tokenizer,
|
186 |
+
prompt,
|
187 |
+
device,
|
188 |
+
do_print=True,
|
189 |
+
max_new_tokens=args.max_new_tokens,
|
190 |
+
temperature=args.temperature,
|
191 |
+
repetition_penalty=args.repetition_penalty,
|
192 |
+
stop_str=stop_str,
|
193 |
+
)
|
194 |
+
if history:
|
195 |
+
history[-1][-1] = response.strip()
|
196 |
+
else:
|
197 |
+
print("Start inference.")
|
198 |
+
results = []
|
199 |
+
for index, example in enumerate(examples):
|
200 |
+
# Single turn inference
|
201 |
+
history = [[example, '']]
|
202 |
+
prompt = prompt_template.get_prompt(messages=history)
|
203 |
+
response = stream_generate_answer(
|
204 |
+
model,
|
205 |
+
tokenizer,
|
206 |
+
prompt,
|
207 |
+
device,
|
208 |
+
do_print=False,
|
209 |
+
max_new_tokens=args.max_new_tokens,
|
210 |
+
temperature=args.temperature,
|
211 |
+
repetition_penalty=args.repetition_penalty,
|
212 |
+
stop_str=stop_str,
|
213 |
+
)
|
214 |
+
response = response.strip()
|
215 |
+
print(f"======={index}=======")
|
216 |
+
print(f"Input: {example}\n")
|
217 |
+
print(f"Output: {response}\n")
|
218 |
+
results.append({"Input": prompt, "Output": response})
|
219 |
+
|
220 |
+
with open(args.predictions_file, 'w', encoding='utf-8') as f:
|
221 |
+
json.dump(results, f, ensure_ascii=False, indent=2)
|
222 |
+
|
223 |
+
|
224 |
+
if __name__ == '__main__':
|
225 |
+
main()
|
merge_peft_adapter.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
@author:XuMing([email protected])
|
4 |
+
@description:
|
5 |
+
|
6 |
+
Usage:
|
7 |
+
python merge_peft_adapter.py \
|
8 |
+
--base_model_name_or_path path/to/llama/model \
|
9 |
+
--tokenizer_path path/to/llama/tokenizer \
|
10 |
+
--peft_model_path path/to/lora/model \
|
11 |
+
--output_dir path/to/output/dir
|
12 |
+
|
13 |
+
after merged, chatglm and baichuan model need copy python script to output dir.
|
14 |
+
"""
|
15 |
+
|
16 |
+
import argparse
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from peft import PeftModel, PeftConfig
|
20 |
+
from transformers import (
|
21 |
+
AutoModel,
|
22 |
+
AutoTokenizer,
|
23 |
+
BloomForCausalLM,
|
24 |
+
BloomTokenizerFast,
|
25 |
+
AutoModelForCausalLM,
|
26 |
+
LlamaTokenizer,
|
27 |
+
LlamaForCausalLM,
|
28 |
+
AutoModelForSequenceClassification,
|
29 |
+
)
|
30 |
+
|
31 |
+
MODEL_CLASSES = {
|
32 |
+
"bloom": (BloomForCausalLM, BloomTokenizerFast),
|
33 |
+
"chatglm": (AutoModel, AutoTokenizer),
|
34 |
+
"llama": (LlamaForCausalLM, LlamaTokenizer),
|
35 |
+
"baichuan": (AutoModelForCausalLM, AutoTokenizer),
|
36 |
+
"auto": (AutoModelForCausalLM, AutoTokenizer),
|
37 |
+
}
|
38 |
+
|
39 |
+
|
40 |
+
def main():
|
41 |
+
parser = argparse.ArgumentParser()
|
42 |
+
parser.add_argument('--model_type', default=None, type=str, required=True)
|
43 |
+
parser.add_argument('--base_model_name_or_path', default=None, required=True, type=str,
|
44 |
+
help="Base model name or path")
|
45 |
+
parser.add_argument('--tokenizer_path', default=None, type=str,
|
46 |
+
help="Please specify tokenization path.")
|
47 |
+
parser.add_argument('--peft_model_path', default=None, required=True, type=str,
|
48 |
+
help="Please specify LoRA model to be merged.")
|
49 |
+
parser.add_argument('--resize_emb', action='store_true', help='Whether to resize model token embeddings')
|
50 |
+
parser.add_argument('--output_dir', default='./merged', type=str)
|
51 |
+
args = parser.parse_args()
|
52 |
+
print(args)
|
53 |
+
|
54 |
+
base_model_path = args.base_model_name_or_path
|
55 |
+
peft_model_path = args.peft_model_path
|
56 |
+
output_dir = args.output_dir
|
57 |
+
print(f"Base model: {base_model_path}")
|
58 |
+
print(f"LoRA model: {peft_model_path}")
|
59 |
+
peft_config = PeftConfig.from_pretrained(peft_model_path)
|
60 |
+
|
61 |
+
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
62 |
+
if peft_config.task_type == "SEQ_CLS":
|
63 |
+
print("Loading LoRA for sequence classification model")
|
64 |
+
if args.model_type == "chatglm":
|
65 |
+
raise ValueError("chatglm does not support sequence classification")
|
66 |
+
base_model = AutoModelForSequenceClassification.from_pretrained(
|
67 |
+
base_model_path,
|
68 |
+
load_in_8bit=False,
|
69 |
+
torch_dtype=torch.float16,
|
70 |
+
trust_remote_code=True,
|
71 |
+
device_map="auto",
|
72 |
+
)
|
73 |
+
else:
|
74 |
+
print("Loading LoRA for causal language model")
|
75 |
+
base_model = model_class.from_pretrained(
|
76 |
+
base_model_path,
|
77 |
+
load_in_8bit=False,
|
78 |
+
torch_dtype=torch.float16,
|
79 |
+
trust_remote_code=True,
|
80 |
+
device_map="auto",
|
81 |
+
)
|
82 |
+
if args.tokenizer_path:
|
83 |
+
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_path, trust_remote_code=True)
|
84 |
+
else:
|
85 |
+
tokenizer = tokenizer_class.from_pretrained(peft_model_path, trust_remote_code=True)
|
86 |
+
if args.resize_emb:
|
87 |
+
base_model_token_size = base_model.get_input_embeddings().weight.size(0)
|
88 |
+
if base_model_token_size != len(tokenizer):
|
89 |
+
base_model.resize_token_embeddings(len(tokenizer))
|
90 |
+
print(f"Resize vocabulary size {base_model_token_size} to {len(tokenizer)}")
|
91 |
+
|
92 |
+
lora_model = PeftModel.from_pretrained(
|
93 |
+
base_model,
|
94 |
+
peft_model_path,
|
95 |
+
device_map="auto",
|
96 |
+
torch_dtype=torch.float16,
|
97 |
+
)
|
98 |
+
lora_model.eval()
|
99 |
+
print(f"Merging with merge_and_unload...")
|
100 |
+
base_model = lora_model.merge_and_unload()
|
101 |
+
|
102 |
+
print("Saving to Hugging Face format...")
|
103 |
+
tokenizer.save_pretrained(output_dir)
|
104 |
+
base_model.save_pretrained(output_dir)
|
105 |
+
print(f"Done! model saved to {output_dir}")
|
106 |
+
|
107 |
+
|
108 |
+
if __name__ == '__main__':
|
109 |
+
main()
|
merge_tokenizers.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
@author:XuMing([email protected])
|
4 |
+
@description:
|
5 |
+
"""
|
6 |
+
import os
|
7 |
+
|
8 |
+
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
9 |
+
from transformers import LlamaTokenizer
|
10 |
+
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
|
11 |
+
import sentencepiece as spm
|
12 |
+
import argparse
|
13 |
+
|
14 |
+
|
15 |
+
def is_chinese(uchar):
|
16 |
+
"""判断一个unicode是否是汉字"""
|
17 |
+
return '\u4e00' <= uchar <= '\u9fa5'
|
18 |
+
|
19 |
+
|
20 |
+
def is_chinese_string(string):
|
21 |
+
"""判断是否全为汉字"""
|
22 |
+
return all(is_chinese(c) for c in string)
|
23 |
+
|
24 |
+
|
25 |
+
def load_baichuan_vocab(vocab_file):
|
26 |
+
words = set()
|
27 |
+
with open(vocab_file, "r", encoding="utf-8") as f:
|
28 |
+
for line in f:
|
29 |
+
if line.strip():
|
30 |
+
words.add(line.strip().split()[0])
|
31 |
+
return words
|
32 |
+
|
33 |
+
|
34 |
+
def load_jieba_vocab(jieba_vocab_file):
|
35 |
+
# Read jieba vocab and sort by freq
|
36 |
+
with open(jieba_vocab_file, "r", encoding="utf-8") as f:
|
37 |
+
lines = f.readlines()
|
38 |
+
word_freqs = [line.strip().split() for line in lines]
|
39 |
+
word_freqs.sort(key=lambda x: int(x[1]), reverse=True)
|
40 |
+
return word_freqs
|
41 |
+
|
42 |
+
|
43 |
+
def main():
|
44 |
+
parser = argparse.ArgumentParser()
|
45 |
+
parser.add_argument('--base_tokenizer_dir', default=None, type=str, required=True)
|
46 |
+
parser.add_argument('--domain_sp_model_file', default='./domain_sp.model', type=str)
|
47 |
+
parser.add_argument('--baichuan_vocab_file', default="data/vocab/baichuan_vocab.txt", type=str)
|
48 |
+
parser.add_argument('--add_jieba', action='store_true', help='Whether to add jieba vocab.')
|
49 |
+
parser.add_argument('--jieba_word_freq_file', default='data/vocab/word_freq.txt', type=str)
|
50 |
+
parser.add_argument('--jieba_word_size', default=20000, type=int)
|
51 |
+
|
52 |
+
args = parser.parse_args()
|
53 |
+
print(args)
|
54 |
+
|
55 |
+
# load
|
56 |
+
llama_tokenizer = LlamaTokenizer.from_pretrained(args.base_tokenizer_dir)
|
57 |
+
chinese_sp_model = spm.SentencePieceProcessor()
|
58 |
+
chinese_sp_model.Load(args.domain_sp_model_file)
|
59 |
+
|
60 |
+
llama_spm = sp_pb2_model.ModelProto()
|
61 |
+
llama_spm.ParseFromString(llama_tokenizer.sp_model.serialized_model_proto())
|
62 |
+
chinese_spm = sp_pb2_model.ModelProto()
|
63 |
+
chinese_spm.ParseFromString(chinese_sp_model.serialized_model_proto())
|
64 |
+
|
65 |
+
# print number of tokens
|
66 |
+
print(len(llama_tokenizer), len(chinese_sp_model))
|
67 |
+
print(llama_tokenizer.all_special_tokens)
|
68 |
+
print(llama_tokenizer.all_special_ids)
|
69 |
+
print(llama_tokenizer.special_tokens_map)
|
70 |
+
|
71 |
+
# Add Chinese tokens to LLaMA tokenizer
|
72 |
+
llama_spm_tokens_set = set(p.piece for p in llama_spm.pieces)
|
73 |
+
|
74 |
+
print(len(llama_spm_tokens_set))
|
75 |
+
print(f"Before:{len(llama_spm_tokens_set)}")
|
76 |
+
added_set = set()
|
77 |
+
for p in chinese_spm.pieces:
|
78 |
+
piece = p.piece
|
79 |
+
if piece not in llama_spm_tokens_set:
|
80 |
+
# print('picec', piece)
|
81 |
+
new_p = sp_pb2_model.ModelProto().SentencePiece()
|
82 |
+
new_p.piece = piece
|
83 |
+
new_p.score = 0
|
84 |
+
llama_spm.pieces.append(new_p)
|
85 |
+
added_set.add(piece)
|
86 |
+
print(f"[add domain tokens]New model pieces: {len(llama_spm.pieces)}")
|
87 |
+
|
88 |
+
vocab = load_baichuan_vocab(args.baichuan_vocab_file)
|
89 |
+
print('baichuan vocab len:', len(vocab))
|
90 |
+
baichuan_vocab_set = set([i for i in vocab if is_chinese_string(i)])
|
91 |
+
print('baichuan chinese vocab size:', len(baichuan_vocab_set))
|
92 |
+
print('baichuan vocab head:', list(baichuan_vocab_set)[:10])
|
93 |
+
for p in baichuan_vocab_set:
|
94 |
+
piece = p
|
95 |
+
if piece not in llama_spm_tokens_set and piece not in added_set:
|
96 |
+
# print('baichuan picec', piece)
|
97 |
+
new_p = sp_pb2_model.ModelProto().SentencePiece()
|
98 |
+
new_p.piece = piece
|
99 |
+
new_p.score = 0
|
100 |
+
llama_spm.pieces.append(new_p)
|
101 |
+
added_set.add(piece)
|
102 |
+
print(f"[add baichuan tokens]New model pieces: {len(llama_spm.pieces)}")
|
103 |
+
|
104 |
+
if args.add_jieba:
|
105 |
+
word_freqs = load_jieba_vocab(args.jieba_word_freq_file)
|
106 |
+
top_words = word_freqs[:args.jieba_word_size]
|
107 |
+
print('jieba top10 freq words:', top_words[:10])
|
108 |
+
jieba_vocab_set = set([i[0] for i in top_words if i])
|
109 |
+
print('jieba_vocab_set size:', len(jieba_vocab_set))
|
110 |
+
print('jieba_vocab head:', list(jieba_vocab_set)[:3])
|
111 |
+
for p in jieba_vocab_set:
|
112 |
+
piece = p
|
113 |
+
if piece not in llama_spm_tokens_set and piece not in added_set:
|
114 |
+
# print('jieba picec', piece)
|
115 |
+
new_p = sp_pb2_model.ModelProto().SentencePiece()
|
116 |
+
new_p.piece = piece
|
117 |
+
new_p.score = 0
|
118 |
+
llama_spm.pieces.append(new_p)
|
119 |
+
print(f"[add jieba tokens]New model pieces: {len(llama_spm.pieces)}")
|
120 |
+
|
121 |
+
# Save
|
122 |
+
output_sp_dir = 'merged_tokenizer_sp'
|
123 |
+
output_hf_dir = 'merged_tokenizer_hf' # the path to save Chinese-LLaMA tokenizer
|
124 |
+
os.makedirs(output_sp_dir, exist_ok=True)
|
125 |
+
with open(output_sp_dir + '/chinese_llama.model', 'wb') as f:
|
126 |
+
f.write(llama_spm.SerializeToString())
|
127 |
+
tokenizer = LlamaTokenizer(vocab_file=output_sp_dir + '/chinese_llama.model')
|
128 |
+
|
129 |
+
tokenizer.save_pretrained(output_hf_dir)
|
130 |
+
print(f"Chinese-LLaMA tokenizer has been saved to {output_hf_dir}")
|
131 |
+
|
132 |
+
# Test
|
133 |
+
llama_tokenizer = LlamaTokenizer.from_pretrained(args.base_tokenizer_dir)
|
134 |
+
chinese_llama_tokenizer = LlamaTokenizer.from_pretrained(output_hf_dir)
|
135 |
+
print(chinese_llama_tokenizer.all_special_tokens)
|
136 |
+
print(chinese_llama_tokenizer.all_special_ids)
|
137 |
+
print(chinese_llama_tokenizer.special_tokens_map)
|
138 |
+
print('old len:', len(llama_tokenizer), ' new len:', len(chinese_llama_tokenizer))
|
139 |
+
text = '''this is a test, hello world. thisisatesthelloworld,
|
140 |
+
慕容复来到河边,姑苏慕容氏在外面丢了人。
|
141 |
+
1号店一周岁了,我们一古脑儿买了10斤零食。
|
142 |
+
巴塞罗那足球俱乐部简称巴萨(Barça),是一家位于西班牙加泰罗尼亚巴塞罗那的足球俱乐部,于1899年由瑞士企业家胡安·甘伯所创立,世界球坛顶级足球俱乐部之一。俱乐部主场可容纳接近十万名观众,是全欧洲最大及世界第二大的足球场。
|
143 |
+
白日依山尽,黄河入海流。欲穷千里目,更上一层楼。'''
|
144 |
+
print("Test text:\n", text)
|
145 |
+
print(f"Tokenized by LLaMA tokenizer:{llama_tokenizer.tokenize(text)}")
|
146 |
+
print(f"Tokenized by Chinese-LLaMA tokenizer:{chinese_llama_tokenizer.tokenize(text)}")
|
147 |
+
|
148 |
+
|
149 |
+
if __name__ == '__main__':
|
150 |
+
main()
|
pretraining.py
ADDED
@@ -0,0 +1,678 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright 2023 XuMing([email protected]) and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""
|
16 |
+
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
|
17 |
+
|
18 |
+
part of this code is adapted from https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py
|
19 |
+
"""
|
20 |
+
import math
|
21 |
+
import os
|
22 |
+
from dataclasses import dataclass, field
|
23 |
+
from glob import glob
|
24 |
+
from itertools import chain
|
25 |
+
from typing import Optional, List, Dict, Any, Mapping
|
26 |
+
|
27 |
+
import numpy as np
|
28 |
+
import torch
|
29 |
+
from datasets import load_dataset
|
30 |
+
from loguru import logger
|
31 |
+
from peft import LoraConfig, TaskType, get_peft_model, PeftModel, prepare_model_for_int8_training
|
32 |
+
from sklearn.metrics import accuracy_score
|
33 |
+
from transformers import (
|
34 |
+
AutoConfig,
|
35 |
+
BloomForCausalLM,
|
36 |
+
AutoModelForCausalLM,
|
37 |
+
AutoModel,
|
38 |
+
LlamaTokenizer,
|
39 |
+
LlamaForCausalLM,
|
40 |
+
BloomTokenizerFast,
|
41 |
+
AutoTokenizer,
|
42 |
+
HfArgumentParser,
|
43 |
+
Trainer,
|
44 |
+
TrainingArguments,
|
45 |
+
is_torch_tpu_available,
|
46 |
+
set_seed,
|
47 |
+
)
|
48 |
+
from transformers.trainer import TRAINING_ARGS_NAME
|
49 |
+
from transformers.utils.versions import require_version
|
50 |
+
|
51 |
+
MODEL_CLASSES = {
|
52 |
+
"bloom": (AutoConfig, BloomForCausalLM, BloomTokenizerFast),
|
53 |
+
"chatglm": (AutoConfig, AutoModel, AutoTokenizer),
|
54 |
+
"llama": (AutoConfig, LlamaForCausalLM, LlamaTokenizer),
|
55 |
+
"baichuan": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
|
56 |
+
"auto": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
|
57 |
+
}
|
58 |
+
|
59 |
+
|
60 |
+
@dataclass
|
61 |
+
class ModelArguments:
|
62 |
+
"""
|
63 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
64 |
+
"""
|
65 |
+
|
66 |
+
model_type: str = field(
|
67 |
+
default=None,
|
68 |
+
metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())}
|
69 |
+
)
|
70 |
+
model_name_or_path: Optional[str] = field(
|
71 |
+
default=None,
|
72 |
+
metadata={
|
73 |
+
"help": (
|
74 |
+
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
|
75 |
+
)
|
76 |
+
},
|
77 |
+
)
|
78 |
+
tokenizer_name_or_path: Optional[str] = field(
|
79 |
+
default=None,
|
80 |
+
metadata={
|
81 |
+
"help": (
|
82 |
+
"The tokenizer for weights initialization.Don't set if you want to train a model from scratch."
|
83 |
+
)
|
84 |
+
},
|
85 |
+
)
|
86 |
+
load_in_8bit: bool = field(default=False, metadata={"help": "Whether to load the model in 8bit mode or not."})
|
87 |
+
cache_dir: Optional[str] = field(
|
88 |
+
default=None,
|
89 |
+
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
90 |
+
)
|
91 |
+
use_fast_tokenizer: bool = field(
|
92 |
+
default=False,
|
93 |
+
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
94 |
+
)
|
95 |
+
torch_dtype: Optional[str] = field(
|
96 |
+
default=None,
|
97 |
+
metadata={
|
98 |
+
"help": (
|
99 |
+
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
|
100 |
+
"dtype will be automatically derived from the model's weights."
|
101 |
+
),
|
102 |
+
"choices": ["auto", "bfloat16", "float16", "float32"],
|
103 |
+
},
|
104 |
+
)
|
105 |
+
device_map: Optional[str] = field(
|
106 |
+
default="auto",
|
107 |
+
metadata={"help": "Device to map model to. If `auto` is passed, the device will be selected automatically. "},
|
108 |
+
)
|
109 |
+
trust_remote_code: bool = field(
|
110 |
+
default=True,
|
111 |
+
metadata={"help": "Whether to trust remote code when loading a model from a remote checkpoint."},
|
112 |
+
)
|
113 |
+
|
114 |
+
def __post_init__(self):
|
115 |
+
if self.model_type is None:
|
116 |
+
raise ValueError(
|
117 |
+
"You must specify a valid model_type to run training. Available model types are " + ", ".join(
|
118 |
+
MODEL_CLASSES.keys()))
|
119 |
+
if self.model_name_or_path is None:
|
120 |
+
raise ValueError("You must specify a valid model_name_or_path to run training.")
|
121 |
+
|
122 |
+
|
123 |
+
@dataclass
|
124 |
+
class DataTrainingArguments:
|
125 |
+
"""
|
126 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
127 |
+
"""
|
128 |
+
|
129 |
+
dataset_name: Optional[str] = field(
|
130 |
+
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
131 |
+
)
|
132 |
+
dataset_config_name: Optional[str] = field(
|
133 |
+
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
134 |
+
)
|
135 |
+
train_file_dir: Optional[str] = field(default=None, metadata={"help": "The train text data file folder."})
|
136 |
+
validation_file_dir: Optional[str] = field(
|
137 |
+
default=None,
|
138 |
+
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on text file folder."},
|
139 |
+
)
|
140 |
+
max_train_samples: Optional[int] = field(
|
141 |
+
default=None,
|
142 |
+
metadata={
|
143 |
+
"help": (
|
144 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
145 |
+
"value if set."
|
146 |
+
)
|
147 |
+
},
|
148 |
+
)
|
149 |
+
max_eval_samples: Optional[int] = field(
|
150 |
+
default=None,
|
151 |
+
metadata={
|
152 |
+
"help": (
|
153 |
+
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
154 |
+
"value if set."
|
155 |
+
)
|
156 |
+
},
|
157 |
+
)
|
158 |
+
streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
|
159 |
+
block_size: Optional[int] = field(
|
160 |
+
default=1024,
|
161 |
+
metadata={
|
162 |
+
"help": (
|
163 |
+
"Optional input sequence length after tokenization. "
|
164 |
+
"The training dataset will be truncated in block of this size for training. "
|
165 |
+
"Default to the model max input length for single sentence inputs (take into account special tokens)."
|
166 |
+
)
|
167 |
+
},
|
168 |
+
)
|
169 |
+
overwrite_cache: bool = field(
|
170 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
171 |
+
)
|
172 |
+
validation_split_percentage: Optional[int] = field(
|
173 |
+
default=1,
|
174 |
+
metadata={
|
175 |
+
"help": "The percentage of the train set used as validation set in case there's no validation split"
|
176 |
+
},
|
177 |
+
)
|
178 |
+
preprocessing_num_workers: Optional[int] = field(
|
179 |
+
default=None,
|
180 |
+
metadata={"help": "The number of processes to use for the preprocessing."},
|
181 |
+
)
|
182 |
+
keep_linebreaks: bool = field(
|
183 |
+
default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
|
184 |
+
)
|
185 |
+
|
186 |
+
def __post_init__(self):
|
187 |
+
if self.streaming:
|
188 |
+
require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")
|
189 |
+
|
190 |
+
|
191 |
+
@dataclass
|
192 |
+
class PeftArguments(TrainingArguments):
|
193 |
+
use_peft: bool = field(default=True, metadata={"help": "Whether to use peft"})
|
194 |
+
target_modules: Optional[str] = field(default="all")
|
195 |
+
lora_rank: Optional[int] = field(default=8)
|
196 |
+
lora_dropout: Optional[float] = field(default=0.05)
|
197 |
+
lora_alpha: Optional[float] = field(default=32.0)
|
198 |
+
modules_to_save: Optional[str] = field(default=None)
|
199 |
+
peft_path: Optional[str] = field(default=None)
|
200 |
+
|
201 |
+
|
202 |
+
def accuracy(predictions, references, normalize=True, sample_weight=None):
|
203 |
+
return {
|
204 |
+
"accuracy": float(accuracy_score(references, predictions, normalize=normalize, sample_weight=sample_weight))
|
205 |
+
}
|
206 |
+
|
207 |
+
|
208 |
+
def compute_metrics(eval_preds):
|
209 |
+
preds, labels = eval_preds
|
210 |
+
# preds have the same shape as the labels, after the argmax(-1) has been calculated
|
211 |
+
# by preprocess_logits_for_metrics, we need to shift the labels
|
212 |
+
labels = labels[:, 1:].reshape(-1)
|
213 |
+
preds = preds[:, :-1].reshape(-1)
|
214 |
+
return accuracy(predictions=preds, references=labels)
|
215 |
+
|
216 |
+
|
217 |
+
def preprocess_logits_for_metrics(logits, labels):
|
218 |
+
if isinstance(logits, tuple):
|
219 |
+
# Depending on the model and config, logits may contain extra tensors,
|
220 |
+
# like past_key_values, but logits always come first
|
221 |
+
logits = logits[0]
|
222 |
+
return logits.argmax(dim=-1)
|
223 |
+
|
224 |
+
|
225 |
+
def fault_tolerance_data_collator(features: List) -> Dict[str, Any]:
|
226 |
+
if not isinstance(features[0], Mapping):
|
227 |
+
features = [vars(f) for f in features]
|
228 |
+
first = features[0]
|
229 |
+
batch = {}
|
230 |
+
|
231 |
+
# Special handling for labels.
|
232 |
+
# Ensure that tensor is created with the correct type
|
233 |
+
if "label" in first and first["label"] is not None:
|
234 |
+
label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
|
235 |
+
dtype = torch.long if isinstance(label, int) else torch.float
|
236 |
+
batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
|
237 |
+
elif "label_ids" in first and first["label_ids"] is not None:
|
238 |
+
if isinstance(first["label_ids"], torch.Tensor):
|
239 |
+
batch["labels"] = torch.stack([f["label_ids"] for f in features])
|
240 |
+
else:
|
241 |
+
dtype = torch.long if type(first["label_ids"][0]) is int else torch.float
|
242 |
+
batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)
|
243 |
+
|
244 |
+
# Handling of all other possible keys.
|
245 |
+
# Again, we will use the first element to figure out which key/values are not None for this model.
|
246 |
+
try:
|
247 |
+
for k, v in first.items():
|
248 |
+
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
|
249 |
+
if isinstance(v, torch.Tensor):
|
250 |
+
batch[k] = torch.stack([f[k] for f in features])
|
251 |
+
elif isinstance(v, np.ndarray):
|
252 |
+
batch[k] = torch.tensor(np.stack([f[k] for f in features]))
|
253 |
+
else:
|
254 |
+
batch[k] = torch.tensor([f[k] for f in features])
|
255 |
+
except ValueError: # quick fix by simply take the first example
|
256 |
+
for k, v in first.items():
|
257 |
+
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
|
258 |
+
if isinstance(v, torch.Tensor):
|
259 |
+
batch[k] = torch.stack([features[0][k]] * len(features))
|
260 |
+
elif isinstance(v, np.ndarray):
|
261 |
+
batch[k] = torch.tensor(np.stack([features[0][k]] * len(features)))
|
262 |
+
else:
|
263 |
+
batch[k] = torch.tensor([features[0][k]] * len(features))
|
264 |
+
|
265 |
+
return batch
|
266 |
+
|
267 |
+
|
268 |
+
class GroupTextsBuilder:
|
269 |
+
def __init__(self, max_seq_length):
|
270 |
+
self.max_seq_length = max_seq_length
|
271 |
+
|
272 |
+
def __call__(self, examples):
|
273 |
+
# Concatenate all texts.
|
274 |
+
firsts = {k: examples[k][0][0] for k in examples.keys()}
|
275 |
+
lasts = {k: examples[k][0][-1] for k in examples.keys()}
|
276 |
+
contents = {k: sum([vi[1:-1] for vi in v], []) for k, v in examples.items()}
|
277 |
+
total_length = len(contents[list(examples.keys())[0]])
|
278 |
+
|
279 |
+
content_length = self.max_seq_length - 2
|
280 |
+
if total_length >= content_length:
|
281 |
+
total_length = (total_length // content_length) * content_length
|
282 |
+
# Split by chunks of max_len.
|
283 |
+
result = {
|
284 |
+
k: [[firsts[k]] + t[i: i + content_length] + [lasts[k]] for i in range(0, total_length, content_length)] for
|
285 |
+
k, t in contents.items()}
|
286 |
+
return result
|
287 |
+
|
288 |
+
|
289 |
+
class SavePeftModelTrainer(Trainer):
|
290 |
+
"""
|
291 |
+
Trainer for lora models
|
292 |
+
"""
|
293 |
+
|
294 |
+
def save_model(self, output_dir=None, _internal_call=False):
|
295 |
+
"""Save the LoRA model."""
|
296 |
+
os.makedirs(output_dir, exist_ok=True)
|
297 |
+
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
298 |
+
self.model.save_pretrained(output_dir)
|
299 |
+
|
300 |
+
|
301 |
+
def save_model(output_dir, model, tokenizer, args):
|
302 |
+
"""Save the model and the tokenizer."""
|
303 |
+
os.makedirs(output_dir, exist_ok=True)
|
304 |
+
|
305 |
+
# Take care of distributed/parallel training
|
306 |
+
model_to_save = model.module if hasattr(model, "module") else model
|
307 |
+
model_to_save.save_pretrained(output_dir)
|
308 |
+
tokenizer.save_pretrained(output_dir)
|
309 |
+
torch.save(args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
310 |
+
|
311 |
+
|
312 |
+
def print_trainable_parameters(model):
|
313 |
+
"""
|
314 |
+
Prints the number of trainable parameters in the model.
|
315 |
+
"""
|
316 |
+
trainable_params = 0
|
317 |
+
all_param = 0
|
318 |
+
for _, param in model.named_parameters():
|
319 |
+
all_param += param.numel()
|
320 |
+
if param.requires_grad:
|
321 |
+
trainable_params += param.numel()
|
322 |
+
print(
|
323 |
+
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
|
324 |
+
)
|
325 |
+
|
326 |
+
|
327 |
+
def find_all_linear_names(peft_model, int4=False, int8=False):
|
328 |
+
"""Find all linear layer names in the model. reference from qlora paper."""
|
329 |
+
cls = torch.nn.Linear
|
330 |
+
if int4 or int8:
|
331 |
+
import bitsandbytes as bnb
|
332 |
+
if int4:
|
333 |
+
cls = bnb.nn.Linear4bit
|
334 |
+
elif int8:
|
335 |
+
cls = bnb.nn.Linear8bitLt
|
336 |
+
lora_module_names = set()
|
337 |
+
for name, module in peft_model.named_modules():
|
338 |
+
if isinstance(module, cls):
|
339 |
+
# last layer is not add to lora_module_names
|
340 |
+
if 'lm_head' in name:
|
341 |
+
continue
|
342 |
+
names = name.split('.')
|
343 |
+
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
344 |
+
return sorted(lora_module_names)
|
345 |
+
|
346 |
+
|
347 |
+
def main():
|
348 |
+
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, PeftArguments))
|
349 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
350 |
+
|
351 |
+
logger.info(f"Model args: {model_args}")
|
352 |
+
logger.info(f"Data args: {data_args}")
|
353 |
+
logger.info(f"Training args: {training_args}")
|
354 |
+
logger.info(
|
355 |
+
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
356 |
+
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
357 |
+
)
|
358 |
+
|
359 |
+
# Set seed before initializing model.
|
360 |
+
set_seed(training_args.seed)
|
361 |
+
|
362 |
+
# Load tokenizer
|
363 |
+
if not model_args.model_type:
|
364 |
+
raise ValueError("Please specify a model_type, e.g. llama, chatglm, bloom, etc.")
|
365 |
+
config_class, model_class, tokenizer_class = MODEL_CLASSES[model_args.model_type]
|
366 |
+
|
367 |
+
tokenizer_kwargs = {
|
368 |
+
"cache_dir": model_args.cache_dir,
|
369 |
+
"use_fast": model_args.use_fast_tokenizer,
|
370 |
+
"trust_remote_code": model_args.trust_remote_code,
|
371 |
+
}
|
372 |
+
tokenizer_name_or_path = model_args.tokenizer_name_or_path
|
373 |
+
if not tokenizer_name_or_path:
|
374 |
+
tokenizer_name_or_path = model_args.model_name_or_path
|
375 |
+
tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
|
376 |
+
|
377 |
+
# Preprocessing the datasets.
|
378 |
+
def tokenize_function(examples):
|
379 |
+
return tokenizer(examples["text"])
|
380 |
+
|
381 |
+
if data_args.block_size is None:
|
382 |
+
block_size = tokenizer.model_max_length
|
383 |
+
if block_size > 2048:
|
384 |
+
logger.warning(
|
385 |
+
"The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value"
|
386 |
+
" of 2048. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can"
|
387 |
+
" override this default with `--block_size xxx`."
|
388 |
+
)
|
389 |
+
else:
|
390 |
+
if data_args.block_size > tokenizer.model_max_length:
|
391 |
+
logger.warning(
|
392 |
+
f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
|
393 |
+
f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
|
394 |
+
)
|
395 |
+
block_size = min(data_args.block_size, tokenizer.model_max_length)
|
396 |
+
|
397 |
+
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
|
398 |
+
def group_texts(examples):
|
399 |
+
# Concatenate all texts.
|
400 |
+
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
|
401 |
+
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
402 |
+
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
403 |
+
# customize this part to your needs.
|
404 |
+
if total_length >= block_size:
|
405 |
+
total_length = (total_length // block_size) * block_size
|
406 |
+
# Split by chunks of max_len.
|
407 |
+
result = {
|
408 |
+
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
|
409 |
+
for k, t in concatenated_examples.items()
|
410 |
+
}
|
411 |
+
result["labels"] = result["input_ids"].copy()
|
412 |
+
return result
|
413 |
+
|
414 |
+
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
415 |
+
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
416 |
+
# (the dataset will be downloaded automatically from the datasets Hub).
|
417 |
+
#
|
418 |
+
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
|
419 |
+
# 'text' is found. You can easily tweak this behavior (see below).
|
420 |
+
#
|
421 |
+
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
|
422 |
+
# download the dataset.
|
423 |
+
if data_args.dataset_name is not None:
|
424 |
+
# Downloading and loading a dataset from the hub.
|
425 |
+
raw_datasets = load_dataset(
|
426 |
+
data_args.dataset_name,
|
427 |
+
data_args.dataset_config_name,
|
428 |
+
cache_dir=model_args.cache_dir,
|
429 |
+
streaming=data_args.streaming,
|
430 |
+
)
|
431 |
+
if "validation" not in raw_datasets.keys():
|
432 |
+
raw_datasets["validation"] = load_dataset(
|
433 |
+
data_args.dataset_name,
|
434 |
+
data_args.dataset_config_name,
|
435 |
+
split=f"train[:{data_args.validation_split_percentage}%]",
|
436 |
+
cache_dir=model_args.cache_dir,
|
437 |
+
streaming=data_args.streaming,
|
438 |
+
)
|
439 |
+
raw_datasets["train"] = load_dataset(
|
440 |
+
data_args.dataset_name,
|
441 |
+
data_args.dataset_config_name,
|
442 |
+
split=f"train[{data_args.validation_split_percentage}%:]",
|
443 |
+
cache_dir=model_args.cache_dir,
|
444 |
+
streaming=data_args.streaming,
|
445 |
+
)
|
446 |
+
else:
|
447 |
+
data_files = {}
|
448 |
+
dataset_args = {}
|
449 |
+
if data_args.train_file_dir is not None and os.path.exists(data_args.train_file_dir):
|
450 |
+
train_data_files = glob(f'{data_args.train_file_dir}/**/*.txt', recursive=True) + glob(
|
451 |
+
f'{data_args.train_file_dir}/**/*.json', recursive=True) + glob(
|
452 |
+
f'{data_args.train_file_dir}/**/*.jsonl', recursive=True)
|
453 |
+
logger.info(f"train files: {train_data_files}")
|
454 |
+
# Train data files must be same type, e.g. all txt or all jsonl
|
455 |
+
types = [f.split('.')[-1] for f in train_data_files]
|
456 |
+
if len(set(types)) > 1:
|
457 |
+
raise ValueError(f"train files must be same type, e.g. all txt or all jsonl, but got {types}")
|
458 |
+
data_files["train"] = train_data_files
|
459 |
+
if data_args.validation_file_dir is not None and os.path.exists(data_args.validation_file_dir):
|
460 |
+
eval_data_files = glob(f'{data_args.validation_file_dir}/**/*.txt', recursive=True) + glob(
|
461 |
+
f'{data_args.train_file_dir}/**/*.json', recursive=True) + glob(
|
462 |
+
f'{data_args.train_file_dir}/**/*.jsonl', recursive=True)
|
463 |
+
logger.info(f"eval files: {eval_data_files}")
|
464 |
+
data_files["validation"] = eval_data_files
|
465 |
+
# Train data files must be same type, e.g. all txt or all jsonl
|
466 |
+
types = [f.split('.')[-1] for f in eval_data_files]
|
467 |
+
if len(set(types)) > 1:
|
468 |
+
raise ValueError(f"train files must be same type, e.g. all txt or all jsonl, but got {types}")
|
469 |
+
extension = "text" if data_files["train"][0].endswith('txt') else 'json'
|
470 |
+
if extension == "text":
|
471 |
+
dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
|
472 |
+
raw_datasets = load_dataset(
|
473 |
+
extension,
|
474 |
+
data_files=data_files,
|
475 |
+
cache_dir=model_args.cache_dir,
|
476 |
+
**dataset_args,
|
477 |
+
)
|
478 |
+
|
479 |
+
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
|
480 |
+
if "validation" not in raw_datasets.keys():
|
481 |
+
raw_datasets["validation"] = load_dataset(
|
482 |
+
extension,
|
483 |
+
data_files=data_files,
|
484 |
+
split=f"train[:{data_args.validation_split_percentage}%]",
|
485 |
+
cache_dir=model_args.cache_dir,
|
486 |
+
**dataset_args,
|
487 |
+
)
|
488 |
+
raw_datasets["train"] = load_dataset(
|
489 |
+
extension,
|
490 |
+
data_files=data_files,
|
491 |
+
split=f"train[{data_args.validation_split_percentage}%:]",
|
492 |
+
cache_dir=model_args.cache_dir,
|
493 |
+
**dataset_args,
|
494 |
+
)
|
495 |
+
logger.info(f"Raw datasets: {raw_datasets}")
|
496 |
+
|
497 |
+
# Preprocessing the datasets.
|
498 |
+
if training_args.do_train:
|
499 |
+
column_names = list(raw_datasets["train"].features)
|
500 |
+
else:
|
501 |
+
column_names = list(raw_datasets["validation"].features)
|
502 |
+
|
503 |
+
with training_args.main_process_first(desc="Dataset tokenization and grouping"):
|
504 |
+
if not data_args.streaming:
|
505 |
+
tokenized_datasets = raw_datasets.map(
|
506 |
+
tokenize_function,
|
507 |
+
batched=True,
|
508 |
+
num_proc=data_args.preprocessing_num_workers,
|
509 |
+
remove_columns=column_names,
|
510 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
511 |
+
desc="Running tokenizer on dataset",
|
512 |
+
)
|
513 |
+
lm_datasets = tokenized_datasets.map(
|
514 |
+
group_texts,
|
515 |
+
batched=True,
|
516 |
+
num_proc=data_args.preprocessing_num_workers,
|
517 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
518 |
+
desc=f"Grouping texts in chunks of {block_size}",
|
519 |
+
)
|
520 |
+
else:
|
521 |
+
tokenized_datasets = raw_datasets.map(
|
522 |
+
tokenize_function,
|
523 |
+
batched=True,
|
524 |
+
remove_columns=column_names,
|
525 |
+
)
|
526 |
+
lm_datasets = tokenized_datasets.map(
|
527 |
+
group_texts,
|
528 |
+
batched=True,
|
529 |
+
)
|
530 |
+
|
531 |
+
train_dataset = None
|
532 |
+
max_train_samples = 0
|
533 |
+
if training_args.do_train:
|
534 |
+
if "train" not in tokenized_datasets:
|
535 |
+
raise ValueError("--do_train requires a train dataset")
|
536 |
+
train_dataset = lm_datasets['train']
|
537 |
+
max_train_samples = len(train_dataset)
|
538 |
+
if data_args.max_train_samples is not None and data_args.max_train_samples > 0:
|
539 |
+
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
540 |
+
train_dataset = train_dataset.select(range(max_train_samples))
|
541 |
+
logger.debug(f"Num train_samples: {len(train_dataset)}")
|
542 |
+
logger.debug("Tokenized training example:")
|
543 |
+
logger.debug(tokenizer.decode(train_dataset[0]['input_ids']))
|
544 |
+
|
545 |
+
eval_dataset = None
|
546 |
+
max_eval_samples = 0
|
547 |
+
if training_args.do_eval:
|
548 |
+
if "validation" not in tokenized_datasets:
|
549 |
+
raise ValueError("--do_eval requires a validation dataset")
|
550 |
+
eval_dataset = lm_datasets["validation"]
|
551 |
+
max_eval_samples = len(eval_dataset)
|
552 |
+
if data_args.max_eval_samples is not None and data_args.max_eval_samples > 0:
|
553 |
+
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
554 |
+
eval_dataset = eval_dataset.select(range(max_eval_samples))
|
555 |
+
logger.debug(f"Num eval_samples: {len(eval_dataset)}")
|
556 |
+
logger.debug("Tokenized eval example:")
|
557 |
+
logger.debug(tokenizer.decode(eval_dataset[0]['input_ids']))
|
558 |
+
|
559 |
+
# Load model
|
560 |
+
if model_args.model_type and model_args.model_name_or_path:
|
561 |
+
torch_dtype = (
|
562 |
+
model_args.torch_dtype
|
563 |
+
if model_args.torch_dtype in ["auto", None]
|
564 |
+
else getattr(torch, model_args.torch_dtype)
|
565 |
+
)
|
566 |
+
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
567 |
+
ddp = world_size != 1
|
568 |
+
if ddp:
|
569 |
+
model_args.device_map = {"": int(os.environ["LOCAL_RANK"]) or 0}
|
570 |
+
|
571 |
+
config = config_class.from_pretrained(
|
572 |
+
model_args.model_name_or_path,
|
573 |
+
torch_dtype=torch_dtype,
|
574 |
+
trust_remote_code=model_args.trust_remote_code,
|
575 |
+
cache_dir=model_args.cache_dir
|
576 |
+
)
|
577 |
+
model = model_class.from_pretrained(
|
578 |
+
model_args.model_name_or_path,
|
579 |
+
config=config,
|
580 |
+
load_in_8bit=model_args.load_in_8bit,
|
581 |
+
device_map=model_args.device_map,
|
582 |
+
trust_remote_code=model_args.trust_remote_code,
|
583 |
+
)
|
584 |
+
else:
|
585 |
+
raise ValueError(f"Error, model_name_or_path is None, Continue PT must be loaded from a pre-trained model")
|
586 |
+
|
587 |
+
if training_args.use_peft:
|
588 |
+
if training_args.peft_path is not None:
|
589 |
+
logger.info(f"Peft from pre-trained model: {training_args.peft_path}")
|
590 |
+
model = PeftModel.from_pretrained(model, training_args.peft_path, is_trainable=True)
|
591 |
+
else:
|
592 |
+
logger.info("Init new peft model")
|
593 |
+
target_modules = training_args.target_modules.split(',') if training_args.target_modules else None
|
594 |
+
if target_modules and 'all' in target_modules:
|
595 |
+
target_modules = find_all_linear_names(model, int4=False, int8=model_args.load_in_8bit)
|
596 |
+
modules_to_save = training_args.modules_to_save
|
597 |
+
if modules_to_save is not None:
|
598 |
+
modules_to_save = modules_to_save.split(',')
|
599 |
+
logger.info(f"Peft target_modules: {target_modules}")
|
600 |
+
logger.info(f"Peft lora_rank: {training_args.lora_rank}")
|
601 |
+
peft_config = LoraConfig(
|
602 |
+
task_type=TaskType.CAUSAL_LM,
|
603 |
+
target_modules=target_modules,
|
604 |
+
inference_mode=False,
|
605 |
+
r=training_args.lora_rank,
|
606 |
+
lora_alpha=training_args.lora_alpha,
|
607 |
+
lora_dropout=training_args.lora_dropout,
|
608 |
+
modules_to_save=modules_to_save)
|
609 |
+
model = get_peft_model(model, peft_config)
|
610 |
+
if model_args.load_in_8bit:
|
611 |
+
model = prepare_model_for_int8_training(model)
|
612 |
+
model.print_trainable_parameters()
|
613 |
+
else:
|
614 |
+
logger.info("Full parameters training")
|
615 |
+
model = model.float()
|
616 |
+
print_trainable_parameters(model)
|
617 |
+
|
618 |
+
# Initialize our Trainer
|
619 |
+
if training_args.gradient_checkpointing:
|
620 |
+
model.gradient_checkpointing_enable()
|
621 |
+
model.config.use_cache = False
|
622 |
+
else:
|
623 |
+
model.config.use_cache = True
|
624 |
+
model.enable_input_require_grads()
|
625 |
+
if not ddp and torch.cuda.device_count() > 1:
|
626 |
+
# Keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
|
627 |
+
model.is_parallelizable = True
|
628 |
+
model.model_parallel = True
|
629 |
+
|
630 |
+
trainer = SavePeftModelTrainer(
|
631 |
+
model=model,
|
632 |
+
args=training_args,
|
633 |
+
train_dataset=train_dataset if training_args.do_train else None,
|
634 |
+
eval_dataset=eval_dataset if training_args.do_eval else None,
|
635 |
+
tokenizer=tokenizer,
|
636 |
+
data_collator=fault_tolerance_data_collator,
|
637 |
+
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
|
638 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics
|
639 |
+
if training_args.do_eval and not is_torch_tpu_available()
|
640 |
+
else None,
|
641 |
+
)
|
642 |
+
|
643 |
+
# Training
|
644 |
+
if training_args.do_train:
|
645 |
+
logger.info("*** Train ***")
|
646 |
+
logger.debug(f"Train dataloader example: {next(iter(trainer.get_train_dataloader()))}")
|
647 |
+
checkpoint = None
|
648 |
+
if training_args.resume_from_checkpoint is not None:
|
649 |
+
checkpoint = training_args.resume_from_checkpoint
|
650 |
+
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
651 |
+
|
652 |
+
metrics = train_result.metrics
|
653 |
+
metrics["train_samples"] = max_train_samples
|
654 |
+
logger.debug(f"Training metrics: {metrics}")
|
655 |
+
trainer.log_metrics("train", metrics)
|
656 |
+
trainer.save_metrics("train", metrics)
|
657 |
+
trainer.save_state()
|
658 |
+
logger.info(f"Saving model checkpoint to {training_args.output_dir}")
|
659 |
+
save_model(training_args.output_dir, model, tokenizer, training_args)
|
660 |
+
|
661 |
+
# Evaluation
|
662 |
+
if training_args.do_eval and trainer.is_world_process_zero():
|
663 |
+
logger.info("*** Evaluate ***")
|
664 |
+
metrics = trainer.evaluate()
|
665 |
+
|
666 |
+
metrics["eval_samples"] = max_eval_samples
|
667 |
+
try:
|
668 |
+
perplexity = math.exp(metrics["eval_loss"])
|
669 |
+
except OverflowError:
|
670 |
+
perplexity = float("inf")
|
671 |
+
metrics["perplexity"] = perplexity
|
672 |
+
logger.debug(f"Eval metrics: {metrics}")
|
673 |
+
trainer.log_metrics("eval", metrics)
|
674 |
+
trainer.save_metrics("eval", metrics)
|
675 |
+
|
676 |
+
|
677 |
+
if __name__ == "__main__":
|
678 |
+
main()
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
loguru
|
2 |
+
transformers>=4.30.1
|
3 |
+
sentencepiece
|
4 |
+
datasets
|
5 |
+
tqdm
|
6 |
+
tensorboard
|
7 |
+
tqdm>=4.47.0
|
8 |
+
peft>=0.5.0
|
9 |
+
accelerate>=0.20.3
|
10 |
+
trl>=0.6.0
|
reward_modeling.py
ADDED
@@ -0,0 +1,643 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
@author:XuMing([email protected])
|
4 |
+
@description:
|
5 |
+
"""
|
6 |
+
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
from dataclasses import dataclass, field
|
10 |
+
from glob import glob
|
11 |
+
from typing import Any, List, Union, Optional, Dict
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from datasets import load_dataset
|
15 |
+
from loguru import logger
|
16 |
+
from peft import LoraConfig, TaskType, get_peft_model, PeftModel, prepare_model_for_int8_training
|
17 |
+
from sklearn.metrics import mean_squared_error, mean_absolute_error
|
18 |
+
from torch.utils.data import Dataset
|
19 |
+
from transformers import (
|
20 |
+
AutoConfig,
|
21 |
+
PreTrainedTokenizerBase,
|
22 |
+
BloomForSequenceClassification,
|
23 |
+
LlamaForSequenceClassification,
|
24 |
+
LlamaTokenizer,
|
25 |
+
BloomTokenizerFast,
|
26 |
+
AlbertForSequenceClassification,
|
27 |
+
BertForSequenceClassification,
|
28 |
+
BertTokenizer,
|
29 |
+
AutoTokenizer,
|
30 |
+
RobertaForSequenceClassification,
|
31 |
+
AutoModelForSequenceClassification,
|
32 |
+
RobertaTokenizer,
|
33 |
+
HfArgumentParser,
|
34 |
+
Trainer,
|
35 |
+
TrainingArguments,
|
36 |
+
set_seed,
|
37 |
+
)
|
38 |
+
from transformers.trainer import TRAINING_ARGS_NAME
|
39 |
+
|
40 |
+
MODEL_CLASSES = {
|
41 |
+
"bert": (AutoConfig, BertForSequenceClassification, BertTokenizer),
|
42 |
+
"roberta": (AutoConfig, RobertaForSequenceClassification, RobertaTokenizer),
|
43 |
+
"albert": (AutoConfig, AlbertForSequenceClassification, AutoTokenizer),
|
44 |
+
"bloom": (AutoConfig, BloomForSequenceClassification, BloomTokenizerFast),
|
45 |
+
"llama": (AutoConfig, LlamaForSequenceClassification, LlamaTokenizer),
|
46 |
+
"auto": (AutoConfig, AutoModelForSequenceClassification, AutoTokenizer),
|
47 |
+
}
|
48 |
+
|
49 |
+
|
50 |
+
@dataclass
|
51 |
+
class ModelArguments:
|
52 |
+
"""
|
53 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
54 |
+
"""
|
55 |
+
|
56 |
+
model_type: str = field(
|
57 |
+
default=None,
|
58 |
+
metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())}
|
59 |
+
)
|
60 |
+
model_name_or_path: Optional[str] = field(
|
61 |
+
default=None,
|
62 |
+
metadata={
|
63 |
+
"help": (
|
64 |
+
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
|
65 |
+
)
|
66 |
+
},
|
67 |
+
)
|
68 |
+
tokenizer_name_or_path: Optional[str] = field(
|
69 |
+
default=None,
|
70 |
+
metadata={
|
71 |
+
"help": (
|
72 |
+
"The tokenizer for weights initialization.Don't set if you want to train a model from scratch."
|
73 |
+
)
|
74 |
+
},
|
75 |
+
)
|
76 |
+
load_in_8bit: bool = field(default=False, metadata={"help": "Whether to load the model in 8bit mode or not."})
|
77 |
+
cache_dir: Optional[str] = field(
|
78 |
+
default=None,
|
79 |
+
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
80 |
+
)
|
81 |
+
use_fast_tokenizer: bool = field(
|
82 |
+
default=False,
|
83 |
+
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
84 |
+
)
|
85 |
+
torch_dtype: Optional[str] = field(
|
86 |
+
default=None,
|
87 |
+
metadata={
|
88 |
+
"help": (
|
89 |
+
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
|
90 |
+
"dtype will be automatically derived from the model's weights."
|
91 |
+
),
|
92 |
+
"choices": ["auto", "bfloat16", "float16", "float32"],
|
93 |
+
},
|
94 |
+
)
|
95 |
+
device_map: Optional[str] = field(
|
96 |
+
default="auto",
|
97 |
+
metadata={"help": "Device to map model to. If `auto` is passed, the device will be selected automatically. "},
|
98 |
+
)
|
99 |
+
trust_remote_code: bool = field(
|
100 |
+
default=True,
|
101 |
+
metadata={"help": "Whether to trust remote code when loading a model from a remote checkpoint."},
|
102 |
+
)
|
103 |
+
|
104 |
+
def __post_init__(self):
|
105 |
+
if self.model_type is None:
|
106 |
+
raise ValueError(
|
107 |
+
"You must specify a valid model_type to run training. Available model types are " + ", ".join(
|
108 |
+
MODEL_CLASSES.keys()))
|
109 |
+
if self.model_name_or_path is None:
|
110 |
+
raise ValueError("You must specify a valid model_name_or_path to run training.")
|
111 |
+
|
112 |
+
|
113 |
+
@dataclass
|
114 |
+
class DataTrainingArguments:
|
115 |
+
"""
|
116 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
117 |
+
"""
|
118 |
+
|
119 |
+
dataset_name: Optional[str] = field(
|
120 |
+
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
121 |
+
)
|
122 |
+
dataset_config_name: Optional[str] = field(
|
123 |
+
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
124 |
+
)
|
125 |
+
train_file_dir: Optional[str] = field(default=None, metadata={"help": "The input jsonl data file folder."})
|
126 |
+
validation_file_dir: Optional[str] = field(default=None, metadata={"help": "The evaluation jsonl file folder."}, )
|
127 |
+
max_source_length: Optional[int] = field(default=256, metadata={"help": "Max length of prompt input text"})
|
128 |
+
max_target_length: Optional[int] = field(default=256, metadata={"help": "Max length of output text"})
|
129 |
+
max_train_samples: Optional[int] = field(
|
130 |
+
default=None,
|
131 |
+
metadata={
|
132 |
+
"help": (
|
133 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
134 |
+
"value if set."
|
135 |
+
)
|
136 |
+
},
|
137 |
+
)
|
138 |
+
max_eval_samples: Optional[int] = field(
|
139 |
+
default=None,
|
140 |
+
metadata={
|
141 |
+
"help": (
|
142 |
+
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
143 |
+
"value if set."
|
144 |
+
)
|
145 |
+
},
|
146 |
+
)
|
147 |
+
overwrite_cache: bool = field(
|
148 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
149 |
+
)
|
150 |
+
validation_split_percentage: Optional[int] = field(
|
151 |
+
default=1,
|
152 |
+
metadata={
|
153 |
+
"help": "The percentage of the train set used as validation set in case there's no validation split"
|
154 |
+
},
|
155 |
+
)
|
156 |
+
preprocessing_num_workers: Optional[int] = field(
|
157 |
+
default=4,
|
158 |
+
metadata={"help": "The number of processes to use for the preprocessing."},
|
159 |
+
)
|
160 |
+
|
161 |
+
|
162 |
+
@dataclass
|
163 |
+
class PeftArguments(TrainingArguments):
|
164 |
+
use_peft: bool = field(default=True, metadata={"help": "Whether to use peft"})
|
165 |
+
target_modules: Optional[str] = field(default="all")
|
166 |
+
lora_rank: Optional[int] = field(default=8)
|
167 |
+
lora_dropout: Optional[float] = field(default=0.05)
|
168 |
+
lora_alpha: Optional[float] = field(default=32.0)
|
169 |
+
modules_to_save: Optional[str] = field(default=None)
|
170 |
+
peft_path: Optional[str] = field(default=None)
|
171 |
+
|
172 |
+
|
173 |
+
def compute_metrics(eval_preds):
|
174 |
+
preds, labels = eval_preds
|
175 |
+
# Here, predictions is rewards_chosen and rewards_rejected.
|
176 |
+
if isinstance(preds, torch.Tensor):
|
177 |
+
preds = preds.detach().cpu().numpy()
|
178 |
+
if isinstance(labels, torch.Tensor):
|
179 |
+
labels = labels.detach().cpu().numpy()
|
180 |
+
# MSE
|
181 |
+
mse = mean_squared_error(labels, preds)
|
182 |
+
# MAE
|
183 |
+
mae = mean_absolute_error(labels, preds)
|
184 |
+
|
185 |
+
return {"mse": mse, "mae": mae}
|
186 |
+
|
187 |
+
|
188 |
+
@dataclass
|
189 |
+
class RewardDataCollatorWithPadding:
|
190 |
+
"""We need to define a special data collator that batches the data in our chosen vs rejected format"""
|
191 |
+
tokenizer: PreTrainedTokenizerBase
|
192 |
+
padding: Union[bool, str] = True
|
193 |
+
max_length: Optional[int] = None
|
194 |
+
pad_to_multiple_of: Optional[int] = None
|
195 |
+
return_tensors: str = "pt"
|
196 |
+
|
197 |
+
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
|
198 |
+
features_chosen = []
|
199 |
+
features_rejected = []
|
200 |
+
for feature in features:
|
201 |
+
features_chosen.append(
|
202 |
+
{
|
203 |
+
"input_ids": feature["input_ids_chosen"],
|
204 |
+
"attention_mask": feature["attention_mask_chosen"],
|
205 |
+
}
|
206 |
+
)
|
207 |
+
features_rejected.append(
|
208 |
+
{
|
209 |
+
"input_ids": feature["input_ids_rejected"],
|
210 |
+
"attention_mask": feature["attention_mask_rejected"],
|
211 |
+
}
|
212 |
+
)
|
213 |
+
batch_chosen = self.tokenizer.pad(
|
214 |
+
features_chosen,
|
215 |
+
padding=self.padding,
|
216 |
+
max_length=self.max_length,
|
217 |
+
pad_to_multiple_of=self.pad_to_multiple_of,
|
218 |
+
return_tensors=self.return_tensors,
|
219 |
+
)
|
220 |
+
batch_rejected = self.tokenizer.pad(
|
221 |
+
features_rejected,
|
222 |
+
padding=self.padding,
|
223 |
+
max_length=self.max_length,
|
224 |
+
pad_to_multiple_of=self.pad_to_multiple_of,
|
225 |
+
return_tensors=self.return_tensors,
|
226 |
+
)
|
227 |
+
batch = {
|
228 |
+
"input_ids_chosen": batch_chosen["input_ids"],
|
229 |
+
"attention_mask_chosen": batch_chosen["attention_mask"],
|
230 |
+
"input_ids_rejected": batch_rejected["input_ids"],
|
231 |
+
"attention_mask_rejected": batch_rejected["attention_mask"],
|
232 |
+
"return_loss": True,
|
233 |
+
}
|
234 |
+
return batch
|
235 |
+
|
236 |
+
|
237 |
+
class RewardTrainer(Trainer):
|
238 |
+
"""
|
239 |
+
Trainer for reward models
|
240 |
+
Define how to compute the reward loss. Use the InstructGPT pairwise logloss: https://arxiv.org/abs/2203.02155
|
241 |
+
"""
|
242 |
+
|
243 |
+
def compute_loss(self, model, inputs, return_outputs=False):
|
244 |
+
rewards_chosen = model(input_ids=inputs["input_ids_chosen"],
|
245 |
+
attention_mask=inputs["attention_mask_chosen"])[0]
|
246 |
+
rewards_rejected = model(input_ids=inputs["input_ids_rejected"],
|
247 |
+
attention_mask=inputs["attention_mask_rejected"])[0]
|
248 |
+
loss = -torch.nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
|
249 |
+
if return_outputs:
|
250 |
+
return loss, {"rewards_chosen": rewards_chosen, "rewards_rejected": rewards_rejected}
|
251 |
+
return loss
|
252 |
+
|
253 |
+
def evaluate(
|
254 |
+
self,
|
255 |
+
eval_dataset: Optional[Dataset] = None,
|
256 |
+
ignore_keys: Optional[List[str]] = None,
|
257 |
+
metric_key_prefix: str = "eval",
|
258 |
+
) -> Dict[str, float]:
|
259 |
+
if eval_dataset is None:
|
260 |
+
eval_dataset = self.eval_dataset
|
261 |
+
return super().evaluate(eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
262 |
+
|
263 |
+
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
|
264 |
+
# Prepare inputs for chosen and rejected separately
|
265 |
+
device = model.device
|
266 |
+
|
267 |
+
inputs_chosen = {
|
268 |
+
"input_ids": inputs["input_ids_chosen"].to(device),
|
269 |
+
"attention_mask": inputs["attention_mask_chosen"].to(device),
|
270 |
+
}
|
271 |
+
outputs_chosen = model(**inputs_chosen)
|
272 |
+
rewards_chosen = outputs_chosen.logits.detach()
|
273 |
+
|
274 |
+
inputs_rejected = {
|
275 |
+
"input_ids": inputs["input_ids_rejected"].to(device),
|
276 |
+
"attention_mask": inputs["attention_mask_rejected"].to(device),
|
277 |
+
}
|
278 |
+
outputs_rejected = model(**inputs_rejected)
|
279 |
+
rewards_rejected = outputs_rejected.logits.detach()
|
280 |
+
|
281 |
+
# Keep the compute_loss method
|
282 |
+
loss = -torch.nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
|
283 |
+
if prediction_loss_only:
|
284 |
+
return (loss, None, None)
|
285 |
+
|
286 |
+
return (loss, rewards_chosen, rewards_rejected)
|
287 |
+
|
288 |
+
def save_model(self, output_dir=None, _internal_call=False):
|
289 |
+
"""Save the LoRA model."""
|
290 |
+
os.makedirs(output_dir, exist_ok=True)
|
291 |
+
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
292 |
+
self.model.save_pretrained(output_dir)
|
293 |
+
|
294 |
+
|
295 |
+
def save_model(output_dir, model, tokenizer, args):
|
296 |
+
"""Save the model and the tokenizer."""
|
297 |
+
os.makedirs(output_dir, exist_ok=True)
|
298 |
+
|
299 |
+
# Take care of distributed/parallel training
|
300 |
+
model_to_save = model.module if hasattr(model, "module") else model
|
301 |
+
model_to_save.save_pretrained(output_dir)
|
302 |
+
tokenizer.save_pretrained(output_dir)
|
303 |
+
torch.save(args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
304 |
+
|
305 |
+
|
306 |
+
class CastOutputToFloat(torch.nn.Sequential):
|
307 |
+
"""Cast the output of the model to float"""
|
308 |
+
|
309 |
+
def forward(self, x):
|
310 |
+
return super().forward(x).to(torch.float32)
|
311 |
+
|
312 |
+
|
313 |
+
def print_trainable_parameters(model):
|
314 |
+
"""
|
315 |
+
Prints the number of trainable parameters in the model.
|
316 |
+
"""
|
317 |
+
trainable_params = 0
|
318 |
+
all_param = 0
|
319 |
+
for _, param in model.named_parameters():
|
320 |
+
all_param += param.numel()
|
321 |
+
if param.requires_grad:
|
322 |
+
trainable_params += param.numel()
|
323 |
+
print(
|
324 |
+
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
|
325 |
+
)
|
326 |
+
|
327 |
+
|
328 |
+
def find_all_linear_names(peft_model, int4=False, int8=False):
|
329 |
+
cls = torch.nn.Linear
|
330 |
+
if int4 or int8:
|
331 |
+
import bitsandbytes as bnb
|
332 |
+
if int4:
|
333 |
+
cls = bnb.nn.Linear4bit
|
334 |
+
elif int8:
|
335 |
+
cls = bnb.nn.Linear8bitLt
|
336 |
+
lora_module_names = set()
|
337 |
+
for name, module in peft_model.named_modules():
|
338 |
+
if isinstance(module, cls):
|
339 |
+
# last layer is not add to lora_module_names
|
340 |
+
if 'lm_head' in name:
|
341 |
+
continue
|
342 |
+
if 'score' in name:
|
343 |
+
continue
|
344 |
+
names = name.split('.')
|
345 |
+
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
346 |
+
return sorted(lora_module_names)
|
347 |
+
|
348 |
+
|
349 |
+
def main():
|
350 |
+
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, PeftArguments))
|
351 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
352 |
+
|
353 |
+
logger.info(f"Model args: {model_args}")
|
354 |
+
logger.info(f"Data args: {data_args}")
|
355 |
+
logger.info(f"Training args: {training_args}")
|
356 |
+
logger.info(
|
357 |
+
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
358 |
+
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
359 |
+
)
|
360 |
+
|
361 |
+
# Set seed before initializing model.
|
362 |
+
set_seed(training_args.seed)
|
363 |
+
|
364 |
+
# Load model
|
365 |
+
if not model_args.model_type:
|
366 |
+
raise ValueError("Please specify a model_type, e.g. llama, chatglm, bloom, etc.")
|
367 |
+
config_class, model_class, tokenizer_class = MODEL_CLASSES[model_args.model_type]
|
368 |
+
if model_args.model_name_or_path:
|
369 |
+
torch_dtype = (
|
370 |
+
model_args.torch_dtype
|
371 |
+
if model_args.torch_dtype in ["auto", None]
|
372 |
+
else getattr(torch, model_args.torch_dtype)
|
373 |
+
)
|
374 |
+
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
375 |
+
if world_size > 1:
|
376 |
+
model_args.device_map = {"": int(os.environ["LOCAL_RANK"]) or 0}
|
377 |
+
config = config_class.from_pretrained(
|
378 |
+
model_args.model_name_or_path,
|
379 |
+
num_labels=1,
|
380 |
+
torch_dtype=torch_dtype,
|
381 |
+
trust_remote_code=model_args.trust_remote_code,
|
382 |
+
cache_dir=model_args.cache_dir
|
383 |
+
)
|
384 |
+
if model_args.model_type in ['bloom', 'llama']:
|
385 |
+
model = model_class.from_pretrained(
|
386 |
+
model_args.model_name_or_path,
|
387 |
+
config=config,
|
388 |
+
load_in_8bit=model_args.load_in_8bit,
|
389 |
+
device_map=model_args.device_map,
|
390 |
+
trust_remote_code=model_args.trust_remote_code,
|
391 |
+
)
|
392 |
+
model.score = CastOutputToFloat(model.score)
|
393 |
+
else:
|
394 |
+
model = model_class.from_pretrained(
|
395 |
+
model_args.model_name_or_path,
|
396 |
+
config=config,
|
397 |
+
cache_dir=model_args.cache_dir,
|
398 |
+
ignore_mismatched_sizes=True
|
399 |
+
)
|
400 |
+
model.to(training_args.device)
|
401 |
+
else:
|
402 |
+
raise ValueError(f"Error, model_name_or_path is None, RM must be loaded from a pre-trained model")
|
403 |
+
|
404 |
+
# Load tokenizer
|
405 |
+
if model_args.model_type == "bloom":
|
406 |
+
model_args.use_fast_tokenizer = True
|
407 |
+
tokenizer_kwargs = {
|
408 |
+
"cache_dir": model_args.cache_dir,
|
409 |
+
"use_fast": model_args.use_fast_tokenizer,
|
410 |
+
"trust_remote_code": model_args.trust_remote_code,
|
411 |
+
}
|
412 |
+
tokenizer_name_or_path = model_args.tokenizer_name_or_path
|
413 |
+
if not tokenizer_name_or_path:
|
414 |
+
tokenizer_name_or_path = model_args.model_name_or_path
|
415 |
+
tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
|
416 |
+
if tokenizer.pad_token_id is None:
|
417 |
+
tokenizer.pad_token_id = 0
|
418 |
+
|
419 |
+
if training_args.use_peft:
|
420 |
+
if training_args.peft_path is not None:
|
421 |
+
logger.info(f"Peft from pre-trained model: {training_args.peft_path}")
|
422 |
+
model = PeftModel.from_pretrained(model, training_args.peft_path, is_trainable=True)
|
423 |
+
else:
|
424 |
+
logger.info("Init new peft model")
|
425 |
+
target_modules = training_args.target_modules.split(',') if training_args.target_modules else None
|
426 |
+
if target_modules and 'all' in target_modules:
|
427 |
+
target_modules = find_all_linear_names(model, int4=False, int8=model_args.load_in_8bit)
|
428 |
+
modules_to_save = training_args.modules_to_save
|
429 |
+
if modules_to_save is not None:
|
430 |
+
modules_to_save = modules_to_save.split(',')
|
431 |
+
logger.info(f"Peft target_modules: {target_modules}")
|
432 |
+
logger.info(f"Peft lora_rank: {training_args.lora_rank}")
|
433 |
+
peft_config = LoraConfig(
|
434 |
+
task_type=TaskType.SEQ_CLS,
|
435 |
+
target_modules=target_modules,
|
436 |
+
inference_mode=False,
|
437 |
+
r=training_args.lora_rank,
|
438 |
+
lora_alpha=training_args.lora_alpha,
|
439 |
+
lora_dropout=training_args.lora_dropout,
|
440 |
+
modules_to_save=modules_to_save)
|
441 |
+
model = get_peft_model(model, peft_config)
|
442 |
+
if model_args.load_in_8bit:
|
443 |
+
model = prepare_model_for_int8_training(model)
|
444 |
+
model.print_trainable_parameters()
|
445 |
+
else:
|
446 |
+
logger.info("Full parameters training")
|
447 |
+
print_trainable_parameters(model)
|
448 |
+
|
449 |
+
# Get reward dataset for tuning the reward model.
|
450 |
+
if data_args.dataset_name is not None:
|
451 |
+
# Downloading and loading a dataset from the hub.
|
452 |
+
raw_datasets = load_dataset(
|
453 |
+
data_args.dataset_name,
|
454 |
+
data_args.dataset_config_name,
|
455 |
+
cache_dir=model_args.cache_dir,
|
456 |
+
)
|
457 |
+
if "validation" not in raw_datasets.keys():
|
458 |
+
raw_datasets["validation"] = load_dataset(
|
459 |
+
data_args.dataset_name,
|
460 |
+
data_args.dataset_config_name,
|
461 |
+
split=f"train[:{data_args.validation_split_percentage}%]",
|
462 |
+
cache_dir=model_args.cache_dir,
|
463 |
+
)
|
464 |
+
raw_datasets["train"] = load_dataset(
|
465 |
+
data_args.dataset_name,
|
466 |
+
data_args.dataset_config_name,
|
467 |
+
split=f"train[{data_args.validation_split_percentage}%:]",
|
468 |
+
cache_dir=model_args.cache_dir,
|
469 |
+
)
|
470 |
+
else:
|
471 |
+
data_files = {}
|
472 |
+
if data_args.train_file_dir is not None and os.path.exists(data_args.train_file_dir):
|
473 |
+
train_data_files = glob(f'{data_args.train_file_dir}/**/*.json', recursive=True) + glob(
|
474 |
+
f'{data_args.train_file_dir}/**/*.jsonl', recursive=True)
|
475 |
+
logger.info(f"train files: {', '.join(train_data_files)}")
|
476 |
+
data_files["train"] = train_data_files
|
477 |
+
if data_args.validation_file_dir is not None and os.path.exists(data_args.validation_file_dir):
|
478 |
+
eval_data_files = glob(f'{data_args.validation_file_dir}/**/*.json', recursive=True) + glob(
|
479 |
+
f'{data_args.validation_file_dir}/**/*.jsonl', recursive=True)
|
480 |
+
logger.info(f"eval files: {', '.join(eval_data_files)}")
|
481 |
+
data_files["validation"] = eval_data_files
|
482 |
+
raw_datasets = load_dataset(
|
483 |
+
'json',
|
484 |
+
data_files=data_files,
|
485 |
+
cache_dir=model_args.cache_dir,
|
486 |
+
)
|
487 |
+
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
|
488 |
+
if "validation" not in raw_datasets.keys():
|
489 |
+
raw_datasets["validation"] = load_dataset(
|
490 |
+
'json',
|
491 |
+
data_files=data_files,
|
492 |
+
split=f"train[:{data_args.validation_split_percentage}%]",
|
493 |
+
cache_dir=model_args.cache_dir,
|
494 |
+
)
|
495 |
+
raw_datasets["train"] = load_dataset(
|
496 |
+
'json',
|
497 |
+
data_files=data_files,
|
498 |
+
split=f"train[{data_args.validation_split_percentage}%:]",
|
499 |
+
cache_dir=model_args.cache_dir,
|
500 |
+
)
|
501 |
+
logger.info(f"Raw datasets: {raw_datasets}")
|
502 |
+
|
503 |
+
# Preprocessing the datasets
|
504 |
+
full_max_length = data_args.max_source_length + data_args.max_target_length
|
505 |
+
|
506 |
+
def preprocess_reward_function(examples):
|
507 |
+
"""
|
508 |
+
Turn the dataset into pairs of Question + Answer, where input_ids_chosen is the preferred question + answer
|
509 |
+
and text_rejected is the other.
|
510 |
+
"""
|
511 |
+
new_examples = {
|
512 |
+
"input_ids_chosen": [],
|
513 |
+
"attention_mask_chosen": [],
|
514 |
+
"input_ids_rejected": [],
|
515 |
+
"attention_mask_rejected": [],
|
516 |
+
}
|
517 |
+
for question, chosen, rejected in zip(examples["question"], examples["response_chosen"],
|
518 |
+
examples["response_rejected"]):
|
519 |
+
tokenized_chosen = tokenizer("Question: " + question + "\n\nAnswer: " + chosen)
|
520 |
+
tokenized_rejected = tokenizer("Question: " + question + "\n\nAnswer: " + rejected)
|
521 |
+
|
522 |
+
new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
|
523 |
+
new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
|
524 |
+
new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
|
525 |
+
new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
|
526 |
+
|
527 |
+
return new_examples
|
528 |
+
|
529 |
+
train_dataset = None
|
530 |
+
max_train_samples = 0
|
531 |
+
if training_args.do_train:
|
532 |
+
if "train" not in raw_datasets:
|
533 |
+
raise ValueError("--do_train requires a train dataset")
|
534 |
+
train_dataset = raw_datasets['train']
|
535 |
+
max_train_samples = len(train_dataset)
|
536 |
+
if data_args.max_train_samples is not None and data_args.max_train_samples > 0:
|
537 |
+
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
538 |
+
train_dataset = train_dataset.select(range(max_train_samples))
|
539 |
+
logger.debug(f"Example train_dataset[0]: {train_dataset[0]}")
|
540 |
+
with training_args.main_process_first(desc="Train dataset tokenization"):
|
541 |
+
tokenized_dataset = train_dataset.shuffle().map(
|
542 |
+
preprocess_reward_function,
|
543 |
+
batched=True,
|
544 |
+
num_proc=data_args.preprocessing_num_workers,
|
545 |
+
remove_columns=train_dataset.column_names,
|
546 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
547 |
+
desc="Running tokenizer on dataset",
|
548 |
+
)
|
549 |
+
train_dataset = tokenized_dataset.filter(
|
550 |
+
lambda x: 0 < len(x['input_ids_rejected']) <= full_max_length and 0 < len(
|
551 |
+
x['input_ids_chosen']) <= full_max_length
|
552 |
+
)
|
553 |
+
logger.debug(f"Num train_samples: {len(train_dataset)}")
|
554 |
+
logger.debug("Tokenized training example:")
|
555 |
+
logger.debug(tokenizer.decode(train_dataset[0]['input_ids_chosen']))
|
556 |
+
|
557 |
+
eval_dataset = None
|
558 |
+
max_eval_samples = 0
|
559 |
+
if training_args.do_eval:
|
560 |
+
with training_args.main_process_first(desc="Eval dataset tokenization"):
|
561 |
+
if "validation" not in raw_datasets:
|
562 |
+
raise ValueError("--do_eval requires a validation dataset")
|
563 |
+
eval_dataset = raw_datasets["validation"]
|
564 |
+
max_eval_samples = len(eval_dataset)
|
565 |
+
if data_args.max_eval_samples is not None and data_args.max_eval_samples > 0:
|
566 |
+
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
567 |
+
eval_dataset = eval_dataset.select(range(max_eval_samples))
|
568 |
+
logger.debug(f"Example eval_dataset[0]: {eval_dataset[0]}")
|
569 |
+
tokenized_dataset = eval_dataset.map(
|
570 |
+
preprocess_reward_function,
|
571 |
+
batched=True,
|
572 |
+
num_proc=data_args.preprocessing_num_workers,
|
573 |
+
remove_columns=eval_dataset.column_names,
|
574 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
575 |
+
desc="Running tokenizer on dataset",
|
576 |
+
)
|
577 |
+
eval_dataset = tokenized_dataset.filter(
|
578 |
+
lambda x: 0 < len(x['input_ids_rejected']) <= full_max_length and 0 < len(
|
579 |
+
x['input_ids_chosen']) <= full_max_length
|
580 |
+
)
|
581 |
+
logger.debug(f"Num eval_samples: {len(eval_dataset)}")
|
582 |
+
logger.debug("Tokenized eval example:")
|
583 |
+
logger.debug(tokenizer.decode(eval_dataset[0]['input_ids_chosen']))
|
584 |
+
|
585 |
+
# Initialize our Trainer
|
586 |
+
if training_args.gradient_checkpointing:
|
587 |
+
model.gradient_checkpointing_enable()
|
588 |
+
model.config.use_cache = False
|
589 |
+
else:
|
590 |
+
model.config.use_cache = True
|
591 |
+
model.enable_input_require_grads()
|
592 |
+
if torch.cuda.device_count() > 1:
|
593 |
+
# Keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
|
594 |
+
model.is_parallelizable = True
|
595 |
+
model.model_parallel = True
|
596 |
+
trainer = RewardTrainer(
|
597 |
+
model=model,
|
598 |
+
args=training_args,
|
599 |
+
train_dataset=train_dataset if training_args.do_train else None,
|
600 |
+
eval_dataset=eval_dataset if training_args.do_eval else None,
|
601 |
+
tokenizer=tokenizer,
|
602 |
+
compute_metrics=compute_metrics,
|
603 |
+
data_collator=RewardDataCollatorWithPadding(
|
604 |
+
tokenizer=tokenizer, max_length=full_max_length, padding="max_length"
|
605 |
+
),
|
606 |
+
)
|
607 |
+
|
608 |
+
# Training
|
609 |
+
if training_args.do_train:
|
610 |
+
logger.info("*** Train ***")
|
611 |
+
logger.debug(f"Train dataloader example: {next(iter(trainer.get_train_dataloader()))}")
|
612 |
+
checkpoint = None
|
613 |
+
if training_args.resume_from_checkpoint is not None:
|
614 |
+
checkpoint = training_args.resume_from_checkpoint
|
615 |
+
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
616 |
+
|
617 |
+
metrics = train_result.metrics
|
618 |
+
metrics["train_samples"] = max_train_samples
|
619 |
+
logger.debug(f"Training metrics: {metrics}")
|
620 |
+
trainer.log_metrics("train", metrics)
|
621 |
+
trainer.save_metrics("train", metrics)
|
622 |
+
trainer.save_state()
|
623 |
+
logger.info(f"Saving model checkpoint to {training_args.output_dir}")
|
624 |
+
save_model(training_args.output_dir, model, tokenizer, training_args)
|
625 |
+
|
626 |
+
# Evaluation
|
627 |
+
if training_args.do_eval and trainer.is_world_process_zero():
|
628 |
+
logger.info("*** Evaluate ***")
|
629 |
+
metrics = trainer.evaluate()
|
630 |
+
|
631 |
+
metrics["eval_samples"] = max_eval_samples
|
632 |
+
try:
|
633 |
+
perplexity = math.exp(metrics["eval_loss"])
|
634 |
+
except OverflowError:
|
635 |
+
perplexity = float("inf")
|
636 |
+
metrics["perplexity"] = perplexity
|
637 |
+
logger.debug(f"Eval metrics: {metrics}")
|
638 |
+
trainer.log_metrics("eval", metrics)
|
639 |
+
trainer.save_metrics("eval", metrics)
|
640 |
+
|
641 |
+
|
642 |
+
if __name__ == "__main__":
|
643 |
+
main()
|
rl_training.py
ADDED
@@ -0,0 +1,499 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
@author:XuMing([email protected])
|
4 |
+
@description: Train a model from SFT using PPO
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os
|
8 |
+
from dataclasses import dataclass, field
|
9 |
+
from glob import glob
|
10 |
+
from typing import Optional
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from datasets import load_dataset
|
14 |
+
from loguru import logger
|
15 |
+
from peft import LoraConfig, TaskType
|
16 |
+
from tqdm import tqdm
|
17 |
+
from transformers import (
|
18 |
+
AutoConfig,
|
19 |
+
AutoModelForSequenceClassification,
|
20 |
+
BloomForCausalLM,
|
21 |
+
AutoModelForCausalLM,
|
22 |
+
AutoModel,
|
23 |
+
LlamaTokenizer,
|
24 |
+
LlamaForCausalLM,
|
25 |
+
BloomTokenizerFast,
|
26 |
+
AutoTokenizer,
|
27 |
+
HfArgumentParser,
|
28 |
+
)
|
29 |
+
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
|
30 |
+
|
31 |
+
from supervised_finetuning import get_conv_template
|
32 |
+
|
33 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "FALSE"
|
34 |
+
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
35 |
+
|
36 |
+
MODEL_CLASSES = {
|
37 |
+
"bloom": (AutoConfig, BloomForCausalLM, BloomTokenizerFast),
|
38 |
+
"chatglm": (AutoConfig, AutoModel, AutoTokenizer),
|
39 |
+
"llama": (AutoConfig, LlamaForCausalLM, LlamaTokenizer),
|
40 |
+
"baichuan": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
|
41 |
+
"auto": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
|
42 |
+
}
|
43 |
+
|
44 |
+
|
45 |
+
@dataclass
|
46 |
+
class ScriptArguments:
|
47 |
+
"""
|
48 |
+
The name of the Casual LM model we wish to fine with PPO
|
49 |
+
"""
|
50 |
+
# Model arguments
|
51 |
+
model_type: str = field(
|
52 |
+
default=None,
|
53 |
+
metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())}
|
54 |
+
)
|
55 |
+
model_name_or_path: Optional[str] = field(
|
56 |
+
default=None, metadata={"help": "The model checkpoint for weights initialization."}
|
57 |
+
)
|
58 |
+
reward_model_name_or_path: Optional[str] = field(default=None, metadata={"help": "The reward model name"})
|
59 |
+
tokenizer_name_or_path: Optional[str] = field(
|
60 |
+
default=None, metadata={"help": "The tokenizer for weights initialization."}
|
61 |
+
)
|
62 |
+
load_in_8bit: bool = field(default=False, metadata={"help": "Whether to load the model in 8bit mode or not."})
|
63 |
+
cache_dir: Optional[str] = field(
|
64 |
+
default=None,
|
65 |
+
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
66 |
+
)
|
67 |
+
use_fast_tokenizer: bool = field(
|
68 |
+
default=False,
|
69 |
+
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
70 |
+
)
|
71 |
+
torch_dtype: Optional[str] = field(
|
72 |
+
default=None,
|
73 |
+
metadata={
|
74 |
+
"help": (
|
75 |
+
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
|
76 |
+
"dtype will be automatically derived from the model's weights."
|
77 |
+
),
|
78 |
+
"choices": ["auto", "bfloat16", "float16", "float32"],
|
79 |
+
},
|
80 |
+
)
|
81 |
+
device_map: Optional[str] = field(
|
82 |
+
default="auto",
|
83 |
+
metadata={"help": "Device to map model to. If `auto` is passed, the device will be selected automatically. "},
|
84 |
+
)
|
85 |
+
trust_remote_code: bool = field(
|
86 |
+
default=True,
|
87 |
+
metadata={"help": "Whether to trust remote code when loading a model from a remote checkpoint."},
|
88 |
+
)
|
89 |
+
# Dataset arguments
|
90 |
+
dataset_name: Optional[str] = field(
|
91 |
+
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
92 |
+
)
|
93 |
+
dataset_config_name: Optional[str] = field(
|
94 |
+
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
95 |
+
)
|
96 |
+
train_file_dir: Optional[str] = field(default=None, metadata={"help": "The input jsonl data file folder."})
|
97 |
+
validation_file_dir: Optional[str] = field(default=None, metadata={"help": "The evaluation jsonl file folder."}, )
|
98 |
+
template_name: Optional[str] = field(default="vicuna", metadata={"help": "The template name."})
|
99 |
+
batch_size: Optional[int] = field(default=8, metadata={"help": "Batch size"})
|
100 |
+
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "PPO minibatch size"})
|
101 |
+
max_source_length: Optional[int] = field(default=256, metadata={"help": "Max length of prompt input text"})
|
102 |
+
max_target_length: Optional[int] = field(default=256, metadata={"help": "Max length of output text"})
|
103 |
+
min_target_length: Optional[int] = field(default=4, metadata={"help": "Min length of output text"})
|
104 |
+
max_train_samples: Optional[int] = field(
|
105 |
+
default=None,
|
106 |
+
metadata={
|
107 |
+
"help": (
|
108 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
109 |
+
"value if set."
|
110 |
+
)
|
111 |
+
},
|
112 |
+
)
|
113 |
+
max_eval_samples: Optional[int] = field(
|
114 |
+
default=None,
|
115 |
+
metadata={
|
116 |
+
"help": (
|
117 |
+
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
118 |
+
"value if set."
|
119 |
+
)
|
120 |
+
},
|
121 |
+
)
|
122 |
+
overwrite_cache: bool = field(
|
123 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
124 |
+
)
|
125 |
+
validation_split_percentage: Optional[int] = field(
|
126 |
+
default=1,
|
127 |
+
metadata={
|
128 |
+
"help": "The percentage of the train set used as validation set in case there's no validation split"
|
129 |
+
},
|
130 |
+
)
|
131 |
+
preprocessing_num_workers: Optional[int] = field(
|
132 |
+
default=None, metadata={"help": "The number of processes to use for the preprocessing."},
|
133 |
+
)
|
134 |
+
# Training arguments
|
135 |
+
use_peft: bool = field(default=True, metadata={"help": "Whether to use peft"})
|
136 |
+
target_modules: Optional[str] = field(default=None)
|
137 |
+
lora_rank: Optional[int] = field(default=8)
|
138 |
+
lora_dropout: Optional[float] = field(default=0.05)
|
139 |
+
lora_alpha: Optional[float] = field(default=32.0)
|
140 |
+
modules_to_save: Optional[str] = field(default=None)
|
141 |
+
peft_path: Optional[str] = field(default=None)
|
142 |
+
|
143 |
+
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
|
144 |
+
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the validation set."})
|
145 |
+
early_stopping: Optional[bool] = field(default=False, metadata={"help": "Whether to early stop"})
|
146 |
+
target_kl: Optional[float] = field(default=0.1, metadata={"help": "The kl target for early stopping"})
|
147 |
+
reward_baseline: Optional[float] = field(
|
148 |
+
default=0.0, metadata={"help": "Baseline value that is subtracted from the reward"},
|
149 |
+
)
|
150 |
+
init_kl_coef: Optional[float] = field(
|
151 |
+
default=0.2, metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"},
|
152 |
+
)
|
153 |
+
adap_kl_ctrl: Optional[bool] = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"})
|
154 |
+
learning_rate: Optional[float] = field(default=1.5e-5, metadata={"help": "Learning rate"})
|
155 |
+
gradient_accumulation_steps: Optional[int] = field(
|
156 |
+
default=1, metadata={"help": "the number of gradient accumulation steps"}
|
157 |
+
)
|
158 |
+
save_steps: Optional[int] = field(default=50, metadata={"help": "X steps to save the model"})
|
159 |
+
output_dir: Optional[str] = field(default="outputs-rl", metadata={"help": "The output directory"})
|
160 |
+
seed: Optional[int] = field(default=0, metadata={"help": "Seed"})
|
161 |
+
max_steps: Optional[int] = field(default=200, metadata={"help": "Number of steps to train"})
|
162 |
+
report_to: Optional[str] = field(default="tensorboard", metadata={"help": "Report to wandb or tensorboard"})
|
163 |
+
|
164 |
+
def __post_init__(self):
|
165 |
+
if self.model_type is None:
|
166 |
+
raise ValueError("You must specify a valid model_type to run training.")
|
167 |
+
if self.model_name_or_path is None:
|
168 |
+
raise ValueError("You must specify a valid model_name_or_path to run training.")
|
169 |
+
if self.reward_model_name_or_path is None:
|
170 |
+
raise ValueError("You must specify a valid reward_model_name_or_path to run training.")
|
171 |
+
|
172 |
+
|
173 |
+
def print_trainable_parameters(model):
|
174 |
+
"""
|
175 |
+
Prints the number of trainable parameters in the model.
|
176 |
+
"""
|
177 |
+
trainable_params = 0
|
178 |
+
all_param = 0
|
179 |
+
for _, param in model.named_parameters():
|
180 |
+
all_param += param.numel()
|
181 |
+
if param.requires_grad:
|
182 |
+
trainable_params += param.numel()
|
183 |
+
print(
|
184 |
+
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
|
185 |
+
)
|
186 |
+
|
187 |
+
|
188 |
+
def get_reward_model_output(reward_model, reward_tokenizer, question, answer, device):
|
189 |
+
"""
|
190 |
+
Get the reward score for a given question and answer pair.
|
191 |
+
"""
|
192 |
+
inputs = reward_tokenizer(question, answer, return_tensors='pt').to(device)
|
193 |
+
score = reward_model(**inputs).logits[0].cpu().detach()
|
194 |
+
|
195 |
+
return score
|
196 |
+
|
197 |
+
|
198 |
+
def calculate_rewards(reward_score_outputs, reward_baseline=0):
|
199 |
+
"""
|
200 |
+
Calculate the reward for a given score output.
|
201 |
+
:param reward_score_outputs:
|
202 |
+
:param reward_baseline:
|
203 |
+
:return:
|
204 |
+
"""
|
205 |
+
rewards = []
|
206 |
+
for score in reward_score_outputs:
|
207 |
+
if isinstance(score, torch.Tensor) and score.numel() == 1:
|
208 |
+
reward_value = score.item() - reward_baseline
|
209 |
+
rewards.append(torch.tensor(reward_value))
|
210 |
+
else:
|
211 |
+
# Use the average of the tensor elements as `score` is multiple elements
|
212 |
+
reward_value = torch.mean(score).item() - reward_baseline
|
213 |
+
rewards.append(torch.tensor(reward_value))
|
214 |
+
return rewards
|
215 |
+
|
216 |
+
|
217 |
+
def main():
|
218 |
+
parser = HfArgumentParser(ScriptArguments)
|
219 |
+
args = parser.parse_args_into_dataclasses()[0]
|
220 |
+
logger.info(f"Parse args: {args}")
|
221 |
+
|
222 |
+
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
223 |
+
if args.model_type == 'bloom':
|
224 |
+
args.use_fast_tokenizer = True
|
225 |
+
# Load tokenizer
|
226 |
+
tokenizer_kwargs = {
|
227 |
+
"cache_dir": args.cache_dir,
|
228 |
+
"use_fast": args.use_fast_tokenizer,
|
229 |
+
"trust_remote_code": args.trust_remote_code,
|
230 |
+
}
|
231 |
+
tokenizer_name_or_path = args.tokenizer_name_or_path
|
232 |
+
if not tokenizer_name_or_path:
|
233 |
+
tokenizer_name_or_path = args.model_name_or_path
|
234 |
+
tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
|
235 |
+
if tokenizer.pad_token_id is None:
|
236 |
+
tokenizer.pad_token_id = 0 # set as the <unk> token
|
237 |
+
|
238 |
+
logger.info("Load model")
|
239 |
+
peft_config = LoraConfig(
|
240 |
+
task_type=TaskType.CAUSAL_LM,
|
241 |
+
target_modules=args.target_modules,
|
242 |
+
inference_mode=False,
|
243 |
+
r=args.lora_rank,
|
244 |
+
lora_alpha=args.lora_alpha,
|
245 |
+
lora_dropout=args.lora_dropout,
|
246 |
+
)
|
247 |
+
torch_dtype = (
|
248 |
+
args.torch_dtype
|
249 |
+
if args.torch_dtype in ["auto", None]
|
250 |
+
else getattr(torch, args.torch_dtype)
|
251 |
+
)
|
252 |
+
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
253 |
+
if world_size > 1:
|
254 |
+
args.device_map = {"": int(os.environ["LOCAL_RANK"]) or 0}
|
255 |
+
config = config_class.from_pretrained(
|
256 |
+
args.model_name_or_path,
|
257 |
+
torch_dtype=torch_dtype,
|
258 |
+
trust_remote_code=args.trust_remote_code,
|
259 |
+
cache_dir=args.cache_dir
|
260 |
+
)
|
261 |
+
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
262 |
+
args.model_name_or_path,
|
263 |
+
config=config,
|
264 |
+
load_in_8bit=args.load_in_8bit,
|
265 |
+
device_map=args.device_map,
|
266 |
+
trust_remote_code=args.trust_remote_code,
|
267 |
+
peft_config=peft_config if args.use_peft else None,
|
268 |
+
)
|
269 |
+
print_trainable_parameters(model)
|
270 |
+
# Load reward model
|
271 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
272 |
+
reward_model = AutoModelForSequenceClassification.from_pretrained(
|
273 |
+
args.reward_model_name_or_path,
|
274 |
+
config=config,
|
275 |
+
load_in_8bit=args.load_in_8bit,
|
276 |
+
trust_remote_code=args.trust_remote_code,
|
277 |
+
)
|
278 |
+
reward_model.to(device)
|
279 |
+
reward_tokenizer = AutoTokenizer.from_pretrained(
|
280 |
+
args.reward_model_name_or_path, **tokenizer_kwargs
|
281 |
+
)
|
282 |
+
|
283 |
+
# Get datasets
|
284 |
+
if args.dataset_name is not None:
|
285 |
+
# Downloading and loading a dataset from the hub.
|
286 |
+
raw_datasets = load_dataset(
|
287 |
+
args.dataset_name,
|
288 |
+
args.dataset_config_name,
|
289 |
+
cache_dir=args.cache_dir,
|
290 |
+
)
|
291 |
+
if "validation" not in raw_datasets.keys():
|
292 |
+
raw_datasets["validation"] = load_dataset(
|
293 |
+
args.dataset_name,
|
294 |
+
args.dataset_config_name,
|
295 |
+
split=f"train[:{args.validation_split_percentage}%]",
|
296 |
+
cache_dir=args.cache_dir,
|
297 |
+
)
|
298 |
+
raw_datasets["train"] = load_dataset(
|
299 |
+
args.dataset_name,
|
300 |
+
args.dataset_config_name,
|
301 |
+
split=f"train[{args.validation_split_percentage}%:]",
|
302 |
+
cache_dir=args.cache_dir,
|
303 |
+
)
|
304 |
+
else:
|
305 |
+
data_files = {}
|
306 |
+
if args.train_file_dir is not None and os.path.exists(args.train_file_dir):
|
307 |
+
train_data_files = glob(f'{args.train_file_dir}/**/*.json', recursive=True) + glob(
|
308 |
+
f'{args.train_file_dir}/**/*.jsonl', recursive=True)
|
309 |
+
logger.info(f"train files: {', '.join(train_data_files)}")
|
310 |
+
data_files["train"] = train_data_files
|
311 |
+
if args.validation_file_dir is not None and os.path.exists(args.validation_file_dir):
|
312 |
+
eval_data_files = glob(f'{args.validation_file_dir}/**/*.json', recursive=True) + glob(
|
313 |
+
f'{args.validation_file_dir}/**/*.jsonl', recursive=True)
|
314 |
+
logger.info(f"eval files: {', '.join(eval_data_files)}")
|
315 |
+
data_files["validation"] = eval_data_files
|
316 |
+
raw_datasets = load_dataset(
|
317 |
+
'json',
|
318 |
+
data_files=data_files,
|
319 |
+
cache_dir=args.cache_dir,
|
320 |
+
)
|
321 |
+
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
|
322 |
+
if "validation" not in raw_datasets.keys():
|
323 |
+
raw_datasets["validation"] = load_dataset(
|
324 |
+
'json',
|
325 |
+
data_files=data_files,
|
326 |
+
split=f"train[:{args.validation_split_percentage}%]",
|
327 |
+
cache_dir=args.cache_dir,
|
328 |
+
)
|
329 |
+
raw_datasets["train"] = load_dataset(
|
330 |
+
'json',
|
331 |
+
data_files=data_files,
|
332 |
+
split=f"train[{args.validation_split_percentage}%:]",
|
333 |
+
cache_dir=args.cache_dir,
|
334 |
+
)
|
335 |
+
logger.info(f"Raw datasets: {raw_datasets}")
|
336 |
+
|
337 |
+
# Preprocessing the datasets
|
338 |
+
max_source_length = args.max_source_length
|
339 |
+
max_target_length = args.max_target_length
|
340 |
+
prompt_template = get_conv_template(args.template_name)
|
341 |
+
|
342 |
+
def preprocess_function(examples):
|
343 |
+
new_examples = {
|
344 |
+
"query": [],
|
345 |
+
"input_ids": [],
|
346 |
+
}
|
347 |
+
roles = ["human", "gpt"]
|
348 |
+
|
349 |
+
def get_prompt(examples):
|
350 |
+
for i, source in enumerate(examples['conversations']):
|
351 |
+
if len(source) < 2:
|
352 |
+
continue
|
353 |
+
data_role = source[0].get("from", "")
|
354 |
+
if data_role not in roles or data_role != roles[0]:
|
355 |
+
# Skip the first one if it is not from human
|
356 |
+
source = source[1:]
|
357 |
+
if len(source) < 2:
|
358 |
+
continue
|
359 |
+
messages = []
|
360 |
+
for j, sentence in enumerate(source):
|
361 |
+
data_role = sentence.get("from", "")
|
362 |
+
if data_role not in roles:
|
363 |
+
logger.warning(f"unknown role: {data_role}, {i}. (ignored)")
|
364 |
+
break
|
365 |
+
if data_role == roles[j % 2]:
|
366 |
+
messages.append(sentence["value"])
|
367 |
+
if len(messages) < 2 or len(messages) % 2 != 0:
|
368 |
+
continue
|
369 |
+
# Convert the list to pairs of elements
|
370 |
+
history_messages = [[messages[k], messages[k + 1]] for k in range(0, len(messages), 2)]
|
371 |
+
yield prompt_template.get_prompt(history_messages)
|
372 |
+
|
373 |
+
for prompt in get_prompt(examples):
|
374 |
+
for i in range(len(prompt) // 2):
|
375 |
+
source_txt = prompt[2 * i]
|
376 |
+
tokenized_question = tokenizer(
|
377 |
+
source_txt, truncation=True, max_length=max_source_length, padding="max_length",
|
378 |
+
return_tensors="pt"
|
379 |
+
)
|
380 |
+
new_examples["query"].append(source_txt)
|
381 |
+
new_examples["input_ids"].append(tokenized_question["input_ids"])
|
382 |
+
|
383 |
+
return new_examples
|
384 |
+
|
385 |
+
# Preprocess the dataset
|
386 |
+
train_dataset = None
|
387 |
+
if args.do_train:
|
388 |
+
if "train" not in raw_datasets:
|
389 |
+
raise ValueError("--do_train requires a train dataset")
|
390 |
+
train_dataset = raw_datasets['train']
|
391 |
+
if args.max_train_samples is not None and args.max_train_samples > 0:
|
392 |
+
max_train_samples = min(len(train_dataset), args.max_train_samples)
|
393 |
+
train_dataset = train_dataset.select(range(max_train_samples))
|
394 |
+
logger.debug(f"Example train_dataset[0]: {train_dataset[0]}")
|
395 |
+
tokenized_dataset = train_dataset.shuffle().map(
|
396 |
+
preprocess_function,
|
397 |
+
batched=True,
|
398 |
+
num_proc=args.preprocessing_num_workers,
|
399 |
+
remove_columns=train_dataset.column_names,
|
400 |
+
load_from_cache_file=not args.overwrite_cache,
|
401 |
+
desc="Running tokenizer on dataset",
|
402 |
+
)
|
403 |
+
train_dataset = tokenized_dataset.filter(
|
404 |
+
lambda x: len(x['input_ids']) > 0
|
405 |
+
)
|
406 |
+
logger.debug(f"Num train_samples: {len(train_dataset)}")
|
407 |
+
|
408 |
+
def collator(data):
|
409 |
+
return dict((key, [d[key] for d in data]) for key in data[0])
|
410 |
+
|
411 |
+
output_dir = args.output_dir
|
412 |
+
config = PPOConfig(
|
413 |
+
steps=args.max_steps,
|
414 |
+
model_name=args.model_name_or_path,
|
415 |
+
learning_rate=args.learning_rate,
|
416 |
+
log_with=args.report_to,
|
417 |
+
batch_size=args.batch_size,
|
418 |
+
mini_batch_size=args.mini_batch_size,
|
419 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
420 |
+
optimize_cuda_cache=True,
|
421 |
+
early_stopping=args.early_stopping,
|
422 |
+
target_kl=args.target_kl,
|
423 |
+
seed=args.seed,
|
424 |
+
init_kl_coef=args.init_kl_coef,
|
425 |
+
adap_kl_ctrl=args.adap_kl_ctrl,
|
426 |
+
project_kwargs={"logging_dir": output_dir},
|
427 |
+
)
|
428 |
+
# Set seed before initializing value head for deterministic eval
|
429 |
+
set_seed(config.seed)
|
430 |
+
|
431 |
+
# We then build the PPOTrainer, passing the model, the reference model, the tokenizer
|
432 |
+
trainer = PPOTrainer(
|
433 |
+
config,
|
434 |
+
model,
|
435 |
+
ref_model=None,
|
436 |
+
tokenizer=tokenizer,
|
437 |
+
dataset=train_dataset,
|
438 |
+
data_collator=collator,
|
439 |
+
)
|
440 |
+
|
441 |
+
# These arguments are passed to the `generate` function of the PPOTrainer
|
442 |
+
generation_kwargs = {
|
443 |
+
"max_new_tokens": max_target_length,
|
444 |
+
"temperature": 1.0,
|
445 |
+
"repetition_penalty": 1.0,
|
446 |
+
"top_p": 1.0,
|
447 |
+
"do_sample": True,
|
448 |
+
}
|
449 |
+
|
450 |
+
def save_model(save_dir):
|
451 |
+
trainer.accelerator.unwrap_model(trainer.model).save_pretrained(save_dir)
|
452 |
+
trainer.tokenizer.save_pretrained(save_dir)
|
453 |
+
|
454 |
+
# Training
|
455 |
+
if args.do_train:
|
456 |
+
logger.info("*** Train ***")
|
457 |
+
total_steps = config.total_ppo_epochs
|
458 |
+
for step, batch in tqdm(enumerate(trainer.dataloader)):
|
459 |
+
if step >= total_steps:
|
460 |
+
break
|
461 |
+
question_tensors = batch["input_ids"]
|
462 |
+
question_tensors = [torch.LongTensor(i).to(device).squeeze(0) for i in question_tensors]
|
463 |
+
responses = []
|
464 |
+
response_tensors = []
|
465 |
+
for q_tensor in question_tensors:
|
466 |
+
response_tensor = trainer.generate(
|
467 |
+
q_tensor,
|
468 |
+
return_prompt=False,
|
469 |
+
**generation_kwargs,
|
470 |
+
)
|
471 |
+
r = tokenizer.batch_decode(response_tensor, skip_special_tokens=True)[0]
|
472 |
+
responses.append(r)
|
473 |
+
response_tensors.append(response_tensor.squeeze(0))
|
474 |
+
batch["response"] = responses
|
475 |
+
|
476 |
+
# Compute reward score
|
477 |
+
score_outputs = [
|
478 |
+
get_reward_model_output(reward_model, reward_tokenizer, q, r, device) for q, r in
|
479 |
+
zip(batch["query"], batch["response"])
|
480 |
+
]
|
481 |
+
rewards = calculate_rewards(score_outputs, args.reward_baseline)
|
482 |
+
|
483 |
+
# Run PPO step
|
484 |
+
try:
|
485 |
+
stats = trainer.step(question_tensors, response_tensors, rewards)
|
486 |
+
trainer.log_stats(stats, batch, rewards)
|
487 |
+
logger.debug(f"Step {step}/{total_steps}: reward score:{score_outputs}")
|
488 |
+
except ValueError as e:
|
489 |
+
logger.warning(f"Failed to log stats for step {step}, because of {e}")
|
490 |
+
|
491 |
+
if step and step % args.save_steps == 0:
|
492 |
+
save_dir = os.path.join(output_dir, f"checkpoint-{step}")
|
493 |
+
save_model(save_dir)
|
494 |
+
# Save final model
|
495 |
+
save_model(output_dir)
|
496 |
+
|
497 |
+
|
498 |
+
if __name__ == "__main__":
|
499 |
+
main()
|
run_dpo.sh
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CUDA_VISIBLE_DEVICES=0,1 python dpo_training.py \
|
2 |
+
--model_type bloom \
|
3 |
+
--model_name_or_path bigscience/bloomz-560m \
|
4 |
+
--train_file_dir ./data/reward \
|
5 |
+
--validation_file_dir ./data/reward \
|
6 |
+
--per_device_train_batch_size 4 \
|
7 |
+
--per_device_eval_batch_size 1 \
|
8 |
+
--do_train \
|
9 |
+
--do_eval \
|
10 |
+
--use_peft True \
|
11 |
+
--max_train_samples 1000 \
|
12 |
+
--max_eval_samples 10 \
|
13 |
+
--max_steps 100 \
|
14 |
+
--eval_steps 20 \
|
15 |
+
--save_steps 50 \
|
16 |
+
--max_source_length 128 \
|
17 |
+
--max_target_length 128 \
|
18 |
+
--output_dir outputs-dpo-bloom-v1 \
|
19 |
+
--target_modules all \
|
20 |
+
--lora_rank 8 \
|
21 |
+
--lora_alpha 16 \
|
22 |
+
--lora_dropout 0.05 \
|
23 |
+
--torch_dtype float16 \
|
24 |
+
--fp16 True \
|
25 |
+
--device_map auto \
|
26 |
+
--report_to tensorboard \
|
27 |
+
--remove_unused_columns False \
|
28 |
+
--gradient_checkpointing True \
|
29 |
+
--cache_dir ./cache
|
run_pt.sh
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 pretraining.py \
|
2 |
+
--model_type bloom \
|
3 |
+
--model_name_or_path bigscience/bloomz-560m \
|
4 |
+
--train_file_dir ./data/pretrain \
|
5 |
+
--validation_file_dir ./data/pretrain \
|
6 |
+
--per_device_train_batch_size 4 \
|
7 |
+
--per_device_eval_batch_size 4 \
|
8 |
+
--do_train \
|
9 |
+
--do_eval \
|
10 |
+
--use_peft True \
|
11 |
+
--seed 42 \
|
12 |
+
--fp16 \
|
13 |
+
--max_train_samples 10000 \
|
14 |
+
--max_eval_samples 10 \
|
15 |
+
--num_train_epochs 0.5 \
|
16 |
+
--learning_rate 2e-4 \
|
17 |
+
--warmup_ratio 0.05 \
|
18 |
+
--weight_decay 0.01 \
|
19 |
+
--logging_strategy steps \
|
20 |
+
--logging_steps 10 \
|
21 |
+
--eval_steps 50 \
|
22 |
+
--evaluation_strategy steps \
|
23 |
+
--save_steps 500 \
|
24 |
+
--save_strategy steps \
|
25 |
+
--save_total_limit 3 \
|
26 |
+
--gradient_accumulation_steps 1 \
|
27 |
+
--preprocessing_num_workers 1 \
|
28 |
+
--block_size 1024 \
|
29 |
+
--output_dir outputs-pt-bloom-v1 \
|
30 |
+
--overwrite_output_dir \
|
31 |
+
--ddp_timeout 30000 \
|
32 |
+
--logging_first_step True \
|
33 |
+
--target_modules all \
|
34 |
+
--lora_rank 8 \
|
35 |
+
--lora_alpha 16 \
|
36 |
+
--lora_dropout 0.05 \
|
37 |
+
--torch_dtype float16 \
|
38 |
+
--device_map auto \
|
39 |
+
--report_to tensorboard \
|
40 |
+
--ddp_find_unused_parameters False \
|
41 |
+
--gradient_checkpointing True \
|
42 |
+
--cache_dir ./cache
|
run_rl.sh
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 rl_training.py \
|
2 |
+
--model_type bloom \
|
3 |
+
--model_name_or_path bigscience/bloomz-560m \
|
4 |
+
--reward_model_name_or_path OpenAssistant/reward-model-deberta-v3-large-v2 \
|
5 |
+
--torch_dtype float16 \
|
6 |
+
--device_map auto \
|
7 |
+
--train_file_dir ./data/finetune \
|
8 |
+
--validation_file_dir ./data/finetune \
|
9 |
+
--batch_size 8 \
|
10 |
+
--max_source_length 256 \
|
11 |
+
--max_target_length 256 \
|
12 |
+
--max_train_samples 1000 \
|
13 |
+
--use_peft True \
|
14 |
+
--lora_rank 8 \
|
15 |
+
--lora_alpha 16 \
|
16 |
+
--lora_dropout 0.05 \
|
17 |
+
--do_train \
|
18 |
+
--max_steps 100 \
|
19 |
+
--learning_rate 1e-5 \
|
20 |
+
--save_steps 50 \
|
21 |
+
--output_dir outputs-rl-bloom-v1 \
|
22 |
+
--early_stopping True \
|
23 |
+
--target_kl 0.1 \
|
24 |
+
--reward_baseline 0.0
|
run_rm.sh
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CUDA_VISIBLE_DEVICES=0,1 python reward_modeling.py \
|
2 |
+
--model_type bloom \
|
3 |
+
--model_name_or_path bigscience/bloomz-560m \
|
4 |
+
--train_file_dir ./data/reward \
|
5 |
+
--validation_file_dir ./data/reward \
|
6 |
+
--per_device_train_batch_size 4 \
|
7 |
+
--per_device_eval_batch_size 4 \
|
8 |
+
--do_train \
|
9 |
+
--use_peft True \
|
10 |
+
--seed 42 \
|
11 |
+
--max_train_samples 1000 \
|
12 |
+
--max_eval_samples 10 \
|
13 |
+
--num_train_epochs 1 \
|
14 |
+
--learning_rate 2e-5 \
|
15 |
+
--warmup_ratio 0.05 \
|
16 |
+
--weight_decay 0.001 \
|
17 |
+
--logging_strategy steps \
|
18 |
+
--logging_steps 10 \
|
19 |
+
--eval_steps 50 \
|
20 |
+
--evaluation_strategy steps \
|
21 |
+
--save_steps 500 \
|
22 |
+
--save_strategy steps \
|
23 |
+
--save_total_limit 3 \
|
24 |
+
--max_source_length 256 \
|
25 |
+
--max_target_length 256 \
|
26 |
+
--output_dir outputs-rm-bloom-v1 \
|
27 |
+
--overwrite_output_dir \
|
28 |
+
--ddp_timeout 30000 \
|
29 |
+
--logging_first_step True \
|
30 |
+
--target_modules all \
|
31 |
+
--lora_rank 8 \
|
32 |
+
--lora_alpha 16 \
|
33 |
+
--lora_dropout 0.05 \
|
34 |
+
--torch_dtype float32 \
|
35 |
+
--device_map auto \
|
36 |
+
--report_to tensorboard \
|
37 |
+
--ddp_find_unused_parameters False \
|
38 |
+
--remove_unused_columns False \
|
39 |
+
--gradient_checkpointing True
|
run_sft.sh
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 supervised_finetuning.py \
|
2 |
+
--model_type bloom \
|
3 |
+
--model_name_or_path bigscience/bloomz-560m \
|
4 |
+
--train_file_dir ./data/finetune \
|
5 |
+
--validation_file_dir ./data/finetune \
|
6 |
+
--per_device_train_batch_size 4 \
|
7 |
+
--per_device_eval_batch_size 4 \
|
8 |
+
--do_train \
|
9 |
+
--do_eval \
|
10 |
+
--use_peft True \
|
11 |
+
--fp16 \
|
12 |
+
--max_train_samples 1000 \
|
13 |
+
--max_eval_samples 10 \
|
14 |
+
--num_train_epochs 1 \
|
15 |
+
--learning_rate 2e-5 \
|
16 |
+
--warmup_ratio 0.05 \
|
17 |
+
--weight_decay 0.05 \
|
18 |
+
--logging_strategy steps \
|
19 |
+
--logging_steps 10 \
|
20 |
+
--eval_steps 50 \
|
21 |
+
--evaluation_strategy steps \
|
22 |
+
--save_steps 500 \
|
23 |
+
--save_strategy steps \
|
24 |
+
--save_total_limit 3 \
|
25 |
+
--gradient_accumulation_steps 1 \
|
26 |
+
--preprocessing_num_workers 4 \
|
27 |
+
--output_dir outputs-sft-bloom-v1 \
|
28 |
+
--overwrite_output_dir \
|
29 |
+
--ddp_timeout 30000 \
|
30 |
+
--logging_first_step True \
|
31 |
+
--target_modules all \
|
32 |
+
--lora_rank 8 \
|
33 |
+
--lora_alpha 16 \
|
34 |
+
--lora_dropout 0.05 \
|
35 |
+
--torch_dtype float16 \
|
36 |
+
--device_map auto \
|
37 |
+
--report_to tensorboard \
|
38 |
+
--ddp_find_unused_parameters False \
|
39 |
+
--gradient_checkpointing True \
|
40 |
+
--cache_dir ./cache
|
run_training_dpo_pipeline.ipynb
ADDED
@@ -0,0 +1,711 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"source": [
|
6 |
+
"# Training Pipeline\n",
|
7 |
+
"[run_training_dpo_pipeline.ipynb](https://github.com/shibing624/MedicalGPT/blob/main/run_training_dpo_pipeline.ipynb) | [Open In Colab](https://colab.research.google.com/github/shibing624/MedicalGPT/blob/main/run_training_dpo_pipeline.ipynb)"
|
8 |
+
],
|
9 |
+
"metadata": {
|
10 |
+
"collapsed": false
|
11 |
+
}
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "markdown",
|
15 |
+
"metadata": {
|
16 |
+
"tags": []
|
17 |
+
},
|
18 |
+
"source": [
|
19 |
+
"# Stage 1: Continue Pretraining\n",
|
20 |
+
"\n",
|
21 |
+
"第一阶段:PT(Continue PreTraining)增量预训练,在海量领域文本数据上二次预训练GPT模型,以注入领域知识\n",
|
22 |
+
"\n",
|
23 |
+
"| Stage 1: Continue Pretraining | [pretraining.py](https://github.com/shibing624/MedicalGPT/blob/main/pretraining.py) | [run_pt.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_pt.sh) |"
|
24 |
+
]
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"cell_type": "markdown",
|
28 |
+
"metadata": {},
|
29 |
+
"source": [
|
30 |
+
"#### 说明:\n",
|
31 |
+
"以下 notebook/colab 代码为了快速验证训练代码可用,我们使用了小size的生成模型和小样本数据集,实际使用时,需要使用更大的模型和数据集,以获得更好的效果。\n",
|
32 |
+
"\n",
|
33 |
+
"1. 生成模型:使用的是Bloom的`bigscience/bloomz-560m`\n",
|
34 |
+
"2. 数据集:PT阶段使用的是中文天龙八部小说部分文本和英文书籍部分文本,位于`data/pretrain`文件夹"
|
35 |
+
]
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "markdown",
|
39 |
+
"source": [],
|
40 |
+
"metadata": {
|
41 |
+
"collapsed": false
|
42 |
+
}
|
43 |
+
},
|
44 |
+
{
|
45 |
+
"cell_type": "markdown",
|
46 |
+
"metadata": {},
|
47 |
+
"source": [
|
48 |
+
"## 配置运行环境\n",
|
49 |
+
"\n",
|
50 |
+
"本地执行可注释以下配置环境的命令,colab执行要打开注释,用于配置环境\n",
|
51 |
+
"\n",
|
52 |
+
"colab建议使用T4 GPU训练,设置方式:`代码执行程序 -> 更改运行时类型 -> 运行时类型:Python3,硬件加速器:GPU,GPU类型:T4 -> 保存`\n",
|
53 |
+
"\n",
|
54 |
+
"步骤:\n",
|
55 |
+
"1. 下载最新代码到本地\n",
|
56 |
+
"2. 安装依赖包\n",
|
57 |
+
"\n",
|
58 |
+
"依赖包如下,保证最新版本:\n",
|
59 |
+
"\n",
|
60 |
+
"```\n",
|
61 |
+
"loguru\n",
|
62 |
+
"transformers\n",
|
63 |
+
"sentencepiece\n",
|
64 |
+
"datasets\n",
|
65 |
+
"tensorboard\n",
|
66 |
+
"tqdm\n",
|
67 |
+
"peft\n",
|
68 |
+
"trl\n",
|
69 |
+
"```"
|
70 |
+
]
|
71 |
+
},
|
72 |
+
{
|
73 |
+
"cell_type": "code",
|
74 |
+
"execution_count": null,
|
75 |
+
"metadata": {},
|
76 |
+
"outputs": [],
|
77 |
+
"source": [
|
78 |
+
"!git clone --depth 1 https://github.com/shibing624/MedicalGPT.git\n",
|
79 |
+
"%cd MedicalGPT\n",
|
80 |
+
"%ls\n",
|
81 |
+
"!pip install -r requirements.txt"
|
82 |
+
]
|
83 |
+
},
|
84 |
+
{
|
85 |
+
"cell_type": "markdown",
|
86 |
+
"metadata": {},
|
87 |
+
"source": [
|
88 |
+
"## Stage1 咱们开始吧\n",
|
89 |
+
"\n",
|
90 |
+
"训练步骤如下:\n",
|
91 |
+
"\n",
|
92 |
+
"1. 确认训练集\n",
|
93 |
+
"2. 执行训练脚本\n",
|
94 |
+
"\n",
|
95 |
+
"训练脚本的执行逻辑如下:\n",
|
96 |
+
"1. 导入依赖包\n",
|
97 |
+
"2. 设置参数\n",
|
98 |
+
"3. 定义各函数并加载训练集\n",
|
99 |
+
"4. 加载模型和tokenizer\n",
|
100 |
+
"5. 开始训练并评估\n",
|
101 |
+
"6. 查看训练结果\n",
|
102 |
+
"\n",
|
103 |
+
"**以下参数可以根据你的GPU实际情况修改,当前参数是根据Colab的T4单卡GPU(16GB显存)配置的**"
|
104 |
+
]
|
105 |
+
},
|
106 |
+
{
|
107 |
+
"cell_type": "code",
|
108 |
+
"execution_count": null,
|
109 |
+
"metadata": {},
|
110 |
+
"outputs": [],
|
111 |
+
"source": [
|
112 |
+
"%ls ./data/pretrain/"
|
113 |
+
]
|
114 |
+
},
|
115 |
+
{
|
116 |
+
"cell_type": "code",
|
117 |
+
"execution_count": null,
|
118 |
+
"outputs": [],
|
119 |
+
"source": [
|
120 |
+
"!python pretraining.py \\\n",
|
121 |
+
" --model_type bloom \\\n",
|
122 |
+
" --model_name_or_path bigscience/bloomz-560m \\\n",
|
123 |
+
" --train_file_dir ./data/pretrain \\\n",
|
124 |
+
" --validation_file_dir ./data/pretrain \\\n",
|
125 |
+
" --per_device_train_batch_size 3 \\\n",
|
126 |
+
" --per_device_eval_batch_size 3 \\\n",
|
127 |
+
" --do_train \\\n",
|
128 |
+
" --do_eval \\\n",
|
129 |
+
" --use_peft True \\\n",
|
130 |
+
" --seed 42 \\\n",
|
131 |
+
" --fp16 \\\n",
|
132 |
+
" --max_train_samples 10000 \\\n",
|
133 |
+
" --max_eval_samples 10 \\\n",
|
134 |
+
" --num_train_epochs 1 \\\n",
|
135 |
+
" --learning_rate 2e-4 \\\n",
|
136 |
+
" --warmup_ratio 0.05 \\\n",
|
137 |
+
" --weight_decay 0.01 \\\n",
|
138 |
+
" --logging_strategy steps \\\n",
|
139 |
+
" --logging_steps 10 \\\n",
|
140 |
+
" --eval_steps 50 \\\n",
|
141 |
+
" --evaluation_strategy steps \\\n",
|
142 |
+
" --save_steps 500 \\\n",
|
143 |
+
" --save_strategy steps \\\n",
|
144 |
+
" --save_total_limit 3 \\\n",
|
145 |
+
" --gradient_accumulation_steps 1 \\\n",
|
146 |
+
" --preprocessing_num_workers 1 \\\n",
|
147 |
+
" --block_size 1024 \\\n",
|
148 |
+
" --output_dir outputs-pt-v1 \\\n",
|
149 |
+
" --overwrite_output_dir \\\n",
|
150 |
+
" --ddp_timeout 30000 \\\n",
|
151 |
+
" --logging_first_step True \\\n",
|
152 |
+
" --target_modules all \\\n",
|
153 |
+
" --lora_rank 8 \\\n",
|
154 |
+
" --lora_alpha 16 \\\n",
|
155 |
+
" --lora_dropout 0.05 \\\n",
|
156 |
+
" --torch_dtype float16 \\\n",
|
157 |
+
" --device_map auto \\\n",
|
158 |
+
" --report_to tensorboard \\\n",
|
159 |
+
" --ddp_find_unused_parameters False \\\n",
|
160 |
+
" --gradient_checkpointing True"
|
161 |
+
],
|
162 |
+
"metadata": {
|
163 |
+
"collapsed": false
|
164 |
+
}
|
165 |
+
},
|
166 |
+
{
|
167 |
+
"cell_type": "code",
|
168 |
+
"execution_count": null,
|
169 |
+
"metadata": {},
|
170 |
+
"outputs": [],
|
171 |
+
"source": [
|
172 |
+
"%ls -lh outputs-pt-v1"
|
173 |
+
]
|
174 |
+
},
|
175 |
+
{
|
176 |
+
"cell_type": "markdown",
|
177 |
+
"metadata": {},
|
178 |
+
"source": [
|
179 |
+
"模型训练结果:\n",
|
180 |
+
"- 使用lora训练模型,则保存的lora权重是`adapter_model.bin`, lora配置文件是`adapter_config.json`,合并到base model的方法见`merge_peft_adapter.py`\n",
|
181 |
+
"- 日志保存在`output_dir/runs`目录下,可以使用tensorboard查看,启动tensorboard方式如下:`tensorboard --logdir output_dir/runs --host 0.0.0.0 --port 8009`"
|
182 |
+
]
|
183 |
+
},
|
184 |
+
{
|
185 |
+
"cell_type": "markdown",
|
186 |
+
"source": [
|
187 |
+
"lora模型权重合并到base model,合并后的模型保存在`--output_dir`目录下,合并方法如下:"
|
188 |
+
],
|
189 |
+
"metadata": {
|
190 |
+
"collapsed": false
|
191 |
+
}
|
192 |
+
},
|
193 |
+
{
|
194 |
+
"cell_type": "code",
|
195 |
+
"execution_count": null,
|
196 |
+
"outputs": [],
|
197 |
+
"source": [
|
198 |
+
"!python merge_peft_adapter.py --model_type bloom \\\n",
|
199 |
+
" --base_model_name_or_path bigscience/bloomz-560m --peft_model_path outputs-pt-v1 --output_dir merged-pt/"
|
200 |
+
],
|
201 |
+
"metadata": {
|
202 |
+
"collapsed": false
|
203 |
+
}
|
204 |
+
},
|
205 |
+
{
|
206 |
+
"cell_type": "code",
|
207 |
+
"execution_count": null,
|
208 |
+
"outputs": [],
|
209 |
+
"source": [
|
210 |
+
"%ls -lh merged-pt/"
|
211 |
+
],
|
212 |
+
"metadata": {
|
213 |
+
"collapsed": false
|
214 |
+
}
|
215 |
+
},
|
216 |
+
{
|
217 |
+
"cell_type": "code",
|
218 |
+
"execution_count": null,
|
219 |
+
"outputs": [],
|
220 |
+
"source": [
|
221 |
+
"%cat merged-pt/config.json"
|
222 |
+
],
|
223 |
+
"metadata": {
|
224 |
+
"collapsed": false
|
225 |
+
}
|
226 |
+
},
|
227 |
+
{
|
228 |
+
"cell_type": "markdown",
|
229 |
+
"metadata": {},
|
230 |
+
"source": [
|
231 |
+
"Stage1 增量预训练完成。"
|
232 |
+
]
|
233 |
+
},
|
234 |
+
{
|
235 |
+
"cell_type": "code",
|
236 |
+
"execution_count": null,
|
237 |
+
"metadata": {
|
238 |
+
"ExecuteTime": {
|
239 |
+
"start_time": "2023-06-15T13:56:17.032821Z",
|
240 |
+
"end_time": "2023-06-15T13:56:17.081153Z"
|
241 |
+
}
|
242 |
+
},
|
243 |
+
"outputs": [],
|
244 |
+
"source": []
|
245 |
+
},
|
246 |
+
{
|
247 |
+
"cell_type": "markdown",
|
248 |
+
"source": [
|
249 |
+
"# Stage 2: Supervised FineTuning\n",
|
250 |
+
"\n",
|
251 |
+
"第二阶段:SFT(Supervised Fine-tuning)有监督微调,构造指令微调数据集,在预训练模型基础上做指令精调,以对齐指令意图\n",
|
252 |
+
"\n",
|
253 |
+
"| Stage 2: Supervised Fine-tuning | [supervised_finetuning.py](https://github.com/shibing624/MedicalGPT/blob/main/supervised_finetuning.py) | [run_sft.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_sft.sh) |"
|
254 |
+
],
|
255 |
+
"metadata": {
|
256 |
+
"collapsed": false
|
257 |
+
}
|
258 |
+
},
|
259 |
+
{
|
260 |
+
"cell_type": "markdown",
|
261 |
+
"source": [
|
262 |
+
"#### 说明:\n",
|
263 |
+
"以下 notebook/colab 代码为了快速验证训练代码可用,我们使用了小size的生成模型和小样本数据集,实际使用时,需要使用更大的模型和数据集,以获得更好的效果。\n",
|
264 |
+
"\n",
|
265 |
+
"1. 生成模型:使用的是Bloom的`bigscience/bloomz-560m` 或者 Stage1得到的预训练模型\n",
|
266 |
+
"2. 数据集:SFT阶段使用的是使用的是Belle的1千条抽样数据,位于`data/finetune`文件夹"
|
267 |
+
],
|
268 |
+
"metadata": {
|
269 |
+
"collapsed": false
|
270 |
+
}
|
271 |
+
},
|
272 |
+
{
|
273 |
+
"cell_type": "markdown",
|
274 |
+
"source": [
|
275 |
+
"## Stage2 咱们开始吧\n",
|
276 |
+
"\n",
|
277 |
+
"训练步骤如下:\n",
|
278 |
+
"\n",
|
279 |
+
"1. 确认训练集\n",
|
280 |
+
"2. 执行训练脚本\n",
|
281 |
+
"\n",
|
282 |
+
"训练脚本的执行逻辑如下:\n",
|
283 |
+
"1. 导入依赖包\n",
|
284 |
+
"2. 设置参数\n",
|
285 |
+
"3. 定义各函数并加载训练集\n",
|
286 |
+
"4. 加载模型和tokenizer\n",
|
287 |
+
"5. 开始训练并评估\n",
|
288 |
+
"6. 查看训练结果"
|
289 |
+
],
|
290 |
+
"metadata": {
|
291 |
+
"collapsed": false
|
292 |
+
}
|
293 |
+
},
|
294 |
+
{
|
295 |
+
"cell_type": "code",
|
296 |
+
"execution_count": null,
|
297 |
+
"outputs": [],
|
298 |
+
"source": [
|
299 |
+
"%ls ./data/finetune"
|
300 |
+
],
|
301 |
+
"metadata": {
|
302 |
+
"collapsed": false,
|
303 |
+
"ExecuteTime": {
|
304 |
+
"start_time": "2023-06-15T13:58:38.778132Z",
|
305 |
+
"end_time": "2023-06-15T13:58:38.966506Z"
|
306 |
+
}
|
307 |
+
}
|
308 |
+
},
|
309 |
+
{
|
310 |
+
"cell_type": "code",
|
311 |
+
"execution_count": null,
|
312 |
+
"outputs": [],
|
313 |
+
"source": [
|
314 |
+
"!python supervised_finetuning.py \\\n",
|
315 |
+
" --model_type bloom \\\n",
|
316 |
+
" --model_name_or_path merged-pt \\\n",
|
317 |
+
" --train_file_dir ./data/finetune \\\n",
|
318 |
+
" --validation_file_dir ./data/finetune \\\n",
|
319 |
+
" --per_device_train_batch_size 4 \\\n",
|
320 |
+
" --per_device_eval_batch_size 4 \\\n",
|
321 |
+
" --do_train \\\n",
|
322 |
+
" --do_eval \\\n",
|
323 |
+
" --use_peft True \\\n",
|
324 |
+
" --fp16 \\\n",
|
325 |
+
" --max_train_samples 1000 \\\n",
|
326 |
+
" --max_eval_samples 10 \\\n",
|
327 |
+
" --num_train_epochs 1 \\\n",
|
328 |
+
" --learning_rate 2e-5 \\\n",
|
329 |
+
" --warmup_ratio 0.05 \\\n",
|
330 |
+
" --weight_decay 0.05 \\\n",
|
331 |
+
" --logging_strategy steps \\\n",
|
332 |
+
" --logging_steps 10 \\\n",
|
333 |
+
" --eval_steps 50 \\\n",
|
334 |
+
" --evaluation_strategy steps \\\n",
|
335 |
+
" --save_steps 500 \\\n",
|
336 |
+
" --save_strategy steps \\\n",
|
337 |
+
" --save_total_limit 3 \\\n",
|
338 |
+
" --gradient_accumulation_steps 1 \\\n",
|
339 |
+
" --preprocessing_num_workers 1 \\\n",
|
340 |
+
" --output_dir outputs-sft-v1 \\\n",
|
341 |
+
" --overwrite_output_dir \\\n",
|
342 |
+
" --ddp_timeout 30000 \\\n",
|
343 |
+
" --logging_first_step True \\\n",
|
344 |
+
" --target_modules all \\\n",
|
345 |
+
" --lora_rank 8 \\\n",
|
346 |
+
" --lora_alpha 16 \\\n",
|
347 |
+
" --lora_dropout 0.05 \\\n",
|
348 |
+
" --torch_dtype float16 \\\n",
|
349 |
+
" --device_map auto \\\n",
|
350 |
+
" --report_to tensorboard \\\n",
|
351 |
+
" --ddp_find_unused_parameters False \\\n",
|
352 |
+
" --gradient_checkpointing True"
|
353 |
+
],
|
354 |
+
"metadata": {
|
355 |
+
"collapsed": false
|
356 |
+
}
|
357 |
+
},
|
358 |
+
{
|
359 |
+
"cell_type": "code",
|
360 |
+
"execution_count": null,
|
361 |
+
"outputs": [],
|
362 |
+
"source": [
|
363 |
+
"%ls -lh outputs-sft-v1"
|
364 |
+
],
|
365 |
+
"metadata": {
|
366 |
+
"collapsed": false
|
367 |
+
}
|
368 |
+
},
|
369 |
+
{
|
370 |
+
"cell_type": "markdown",
|
371 |
+
"source": [
|
372 |
+
"模型训练结果:\n",
|
373 |
+
"- 使用lora训练模型,则保存的lora权重是`adapter_model.bin`, lora配置文件是`adapter_config.json`,合并到base model的方法见`merge_peft_adapter.py`\n",
|
374 |
+
"- 日志保存在`output_dir/runs`目录下,可以使用tensorboard查看,启动tensorboard方式如下:`tensorboard --logdir output_dir/runs --host 0.0.0.0 --port 8009`"
|
375 |
+
],
|
376 |
+
"metadata": {
|
377 |
+
"collapsed": false
|
378 |
+
}
|
379 |
+
},
|
380 |
+
{
|
381 |
+
"cell_type": "markdown",
|
382 |
+
"source": [
|
383 |
+
"lora模型权重合并到base model,合并后的模型保存在`--output_dir`目录下,合并方法如下:"
|
384 |
+
],
|
385 |
+
"metadata": {
|
386 |
+
"collapsed": false
|
387 |
+
}
|
388 |
+
},
|
389 |
+
{
|
390 |
+
"cell_type": "code",
|
391 |
+
"execution_count": null,
|
392 |
+
"outputs": [],
|
393 |
+
"source": [
|
394 |
+
"!python merge_peft_adapter.py --model_type bloom \\\n",
|
395 |
+
" --base_model_name_or_path merged-pt --peft_model_path outputs-sft-v1 --output_dir merged-sft/"
|
396 |
+
],
|
397 |
+
"metadata": {
|
398 |
+
"collapsed": false
|
399 |
+
}
|
400 |
+
},
|
401 |
+
{
|
402 |
+
"cell_type": "code",
|
403 |
+
"execution_count": null,
|
404 |
+
"outputs": [],
|
405 |
+
"source": [
|
406 |
+
"%ls -lh merged-sft/"
|
407 |
+
],
|
408 |
+
"metadata": {
|
409 |
+
"collapsed": false
|
410 |
+
}
|
411 |
+
},
|
412 |
+
{
|
413 |
+
"cell_type": "code",
|
414 |
+
"execution_count": null,
|
415 |
+
"outputs": [],
|
416 |
+
"source": [
|
417 |
+
"%cat merged-sft/config.json"
|
418 |
+
],
|
419 |
+
"metadata": {
|
420 |
+
"collapsed": false
|
421 |
+
}
|
422 |
+
},
|
423 |
+
{
|
424 |
+
"cell_type": "markdown",
|
425 |
+
"source": [
|
426 |
+
"Stage2 SFT训练完成。"
|
427 |
+
],
|
428 |
+
"metadata": {
|
429 |
+
"collapsed": false
|
430 |
+
}
|
431 |
+
},
|
432 |
+
{
|
433 |
+
"cell_type": "code",
|
434 |
+
"execution_count": null,
|
435 |
+
"outputs": [],
|
436 |
+
"source": [],
|
437 |
+
"metadata": {
|
438 |
+
"collapsed": false,
|
439 |
+
"ExecuteTime": {
|
440 |
+
"start_time": "2023-06-15T14:07:40.731186Z",
|
441 |
+
"end_time": "2023-06-15T14:07:40.752635Z"
|
442 |
+
}
|
443 |
+
}
|
444 |
+
},
|
445 |
+
{
|
446 |
+
"cell_type": "markdown",
|
447 |
+
"source": [
|
448 |
+
"# Stage 3: DPO(Direct Preference Optimization)\n",
|
449 |
+
"\n",
|
450 |
+
"第三阶段:DPO(Direct Preference Optimization)直接偏好优化,DPO通过直接优化语言模型来实现对其行为的精确控制,而无需使用复杂的强化学习,也可以有效学习到人类偏好,DPO相较于RLHF更容易实现且易于训练,效果更好\n",
|
451 |
+
"\n",
|
452 |
+
"| Stage 3: Direct Preference Optimization | [dpo_training.py](https://github.com/shibing624/MedicalGPT/blob/main/dpo_training.py) | [run_dpo.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_dpo.sh) |"
|
453 |
+
],
|
454 |
+
"metadata": {
|
455 |
+
"collapsed": false
|
456 |
+
}
|
457 |
+
},
|
458 |
+
{
|
459 |
+
"cell_type": "markdown",
|
460 |
+
"source": [
|
461 |
+
"#### 说明:\n",
|
462 |
+
"以下 notebook/colab 代码为了快速验证训练代码可用,我们使用了小size的生成模型和小样本数据集,实际使用时,需要使用更大的模型和数据集,以获得更好的效果。\n",
|
463 |
+
"\n",
|
464 |
+
"1. 生成模型:使用的是Bloom的`bigscience/bloomz-560m` 或者 Stage2得到的SFT模型\n",
|
465 |
+
"2. 数据集:DPO阶段使用的是医疗reward数据,抽样了500条,位于`data/reward`文件夹"
|
466 |
+
],
|
467 |
+
"metadata": {
|
468 |
+
"collapsed": false
|
469 |
+
}
|
470 |
+
},
|
471 |
+
{
|
472 |
+
"cell_type": "markdown",
|
473 |
+
"source": [
|
474 |
+
"## Stage3 咱们开始吧\n",
|
475 |
+
"\n",
|
476 |
+
"训练步骤如下:\n",
|
477 |
+
"\n",
|
478 |
+
"1. 确认训练集\n",
|
479 |
+
"2. 执行训练脚本\n",
|
480 |
+
"\n",
|
481 |
+
"训练脚本的执行逻辑如下:\n",
|
482 |
+
"1. 导入依赖包\n",
|
483 |
+
"2. 设置参数\n",
|
484 |
+
"3. 定义各函数并加载训练集\n",
|
485 |
+
"4. 加载模型和tokenizer\n",
|
486 |
+
"5. 开始训练并评估\n",
|
487 |
+
"6. 查看训练结果"
|
488 |
+
],
|
489 |
+
"metadata": {
|
490 |
+
"collapsed": false
|
491 |
+
}
|
492 |
+
},
|
493 |
+
{
|
494 |
+
"cell_type": "code",
|
495 |
+
"execution_count": null,
|
496 |
+
"outputs": [],
|
497 |
+
"source": [
|
498 |
+
"%ls ./data/reward/"
|
499 |
+
],
|
500 |
+
"metadata": {
|
501 |
+
"collapsed": false
|
502 |
+
}
|
503 |
+
},
|
504 |
+
{
|
505 |
+
"cell_type": "code",
|
506 |
+
"execution_count": null,
|
507 |
+
"outputs": [],
|
508 |
+
"source": [
|
509 |
+
"!python dpo_training.py \\\n",
|
510 |
+
" --model_type bloom \\\n",
|
511 |
+
" --model_name_or_path merged-sft \\\n",
|
512 |
+
" --train_file_dir ./data/reward \\\n",
|
513 |
+
" --validation_file_dir ./data/reward \\\n",
|
514 |
+
" --per_device_train_batch_size 3 \\\n",
|
515 |
+
" --per_device_eval_batch_size 1 \\\n",
|
516 |
+
" --do_train \\\n",
|
517 |
+
" --do_eval \\\n",
|
518 |
+
" --use_peft True \\\n",
|
519 |
+
" --max_train_samples 1000 \\\n",
|
520 |
+
" --max_eval_samples 10 \\\n",
|
521 |
+
" --max_steps 100 \\\n",
|
522 |
+
" --eval_steps 10 \\\n",
|
523 |
+
" --save_steps 50 \\\n",
|
524 |
+
" --max_source_length 128 \\\n",
|
525 |
+
" --max_target_length 128 \\\n",
|
526 |
+
" --output_dir outputs-dpo-v1 \\\n",
|
527 |
+
" --target_modules all \\\n",
|
528 |
+
" --lora_rank 8 \\\n",
|
529 |
+
" --lora_alpha 16 \\\n",
|
530 |
+
" --lora_dropout 0.05 \\\n",
|
531 |
+
" --torch_dtype float16 \\\n",
|
532 |
+
" --fp16 True \\\n",
|
533 |
+
" --device_map auto \\\n",
|
534 |
+
" --report_to tensorboard \\\n",
|
535 |
+
" --remove_unused_columns False \\\n",
|
536 |
+
" --gradient_checkpointing True \\\n",
|
537 |
+
" --cache_dir ./cache"
|
538 |
+
],
|
539 |
+
"metadata": {
|
540 |
+
"collapsed": false
|
541 |
+
}
|
542 |
+
},
|
543 |
+
{
|
544 |
+
"cell_type": "code",
|
545 |
+
"execution_count": null,
|
546 |
+
"outputs": [],
|
547 |
+
"source": [
|
548 |
+
"%ls -lh outputs-dpo-v1"
|
549 |
+
],
|
550 |
+
"metadata": {
|
551 |
+
"collapsed": false
|
552 |
+
}
|
553 |
+
},
|
554 |
+
{
|
555 |
+
"cell_type": "markdown",
|
556 |
+
"source": [
|
557 |
+
"模型训练结果:\n",
|
558 |
+
"- 使用lora训练模型,则保存的lora权重是`adapter_model.bin`, lora配置文件是`adapter_config.json`,合并到base model的方法见`merge_peft_adapter.py`\n",
|
559 |
+
"- 日志保存在`output_dir/runs`目录下,可以使用tensorboard查看,启动tensorboard方式如下:`tensorboard --logdir output_dir/runs --host 0.0.0.0 --port 8009`"
|
560 |
+
],
|
561 |
+
"metadata": {
|
562 |
+
"collapsed": false
|
563 |
+
}
|
564 |
+
},
|
565 |
+
{
|
566 |
+
"cell_type": "markdown",
|
567 |
+
"source": [
|
568 |
+
"lora模型权重合并到base model,合并后的模型保存在`--output_dir`目录下,合并方法如下:"
|
569 |
+
],
|
570 |
+
"metadata": {
|
571 |
+
"collapsed": false
|
572 |
+
}
|
573 |
+
},
|
574 |
+
{
|
575 |
+
"cell_type": "code",
|
576 |
+
"execution_count": null,
|
577 |
+
"outputs": [],
|
578 |
+
"source": [
|
579 |
+
"!python merge_peft_adapter.py --model_type bloom \\\n",
|
580 |
+
" --base_model_name_or_path merged-sft --peft_model_path outputs-dpo-v1 --output_dir merged-dpo/"
|
581 |
+
],
|
582 |
+
"metadata": {
|
583 |
+
"collapsed": false
|
584 |
+
}
|
585 |
+
},
|
586 |
+
{
|
587 |
+
"cell_type": "code",
|
588 |
+
"execution_count": null,
|
589 |
+
"outputs": [],
|
590 |
+
"source": [
|
591 |
+
"%ls -lh merged-dpo/"
|
592 |
+
],
|
593 |
+
"metadata": {
|
594 |
+
"collapsed": false
|
595 |
+
}
|
596 |
+
},
|
597 |
+
{
|
598 |
+
"cell_type": "code",
|
599 |
+
"execution_count": null,
|
600 |
+
"outputs": [],
|
601 |
+
"source": [
|
602 |
+
"%cat merged-dpo/config.json"
|
603 |
+
],
|
604 |
+
"metadata": {
|
605 |
+
"collapsed": false
|
606 |
+
}
|
607 |
+
},
|
608 |
+
{
|
609 |
+
"cell_type": "markdown",
|
610 |
+
"source": [
|
611 |
+
"Stage3 偏好建模第一次训练完成。"
|
612 |
+
],
|
613 |
+
"metadata": {
|
614 |
+
"collapsed": false
|
615 |
+
}
|
616 |
+
},
|
617 |
+
{
|
618 |
+
"cell_type": "markdown",
|
619 |
+
"source": [
|
620 |
+
"**至此一个完整的训练流程演示完成。**"
|
621 |
+
],
|
622 |
+
"metadata": {
|
623 |
+
"collapsed": false
|
624 |
+
}
|
625 |
+
},
|
626 |
+
{
|
627 |
+
"cell_type": "code",
|
628 |
+
"execution_count": null,
|
629 |
+
"outputs": [],
|
630 |
+
"source": [],
|
631 |
+
"metadata": {
|
632 |
+
"collapsed": false,
|
633 |
+
"ExecuteTime": {
|
634 |
+
"start_time": "2023-06-26T12:34:29.620609Z",
|
635 |
+
"end_time": "2023-06-26T12:34:29.658428Z"
|
636 |
+
}
|
637 |
+
}
|
638 |
+
},
|
639 |
+
{
|
640 |
+
"cell_type": "markdown",
|
641 |
+
"source": [
|
642 |
+
"# Test"
|
643 |
+
],
|
644 |
+
"metadata": {
|
645 |
+
"collapsed": false
|
646 |
+
}
|
647 |
+
},
|
648 |
+
{
|
649 |
+
"cell_type": "code",
|
650 |
+
"execution_count": null,
|
651 |
+
"outputs": [],
|
652 |
+
"source": [
|
653 |
+
"!python inference.py --model_type bloom --base_model merged-dpo --interactive"
|
654 |
+
],
|
655 |
+
"metadata": {
|
656 |
+
"collapsed": false,
|
657 |
+
"ExecuteTime": {
|
658 |
+
"start_time": "2023-06-26T12:34:47.802087Z",
|
659 |
+
"end_time": "2023-06-26T12:35:00.864463Z"
|
660 |
+
}
|
661 |
+
}
|
662 |
+
},
|
663 |
+
{
|
664 |
+
"cell_type": "markdown",
|
665 |
+
"source": [
|
666 |
+
"Input:介绍下南京\n",
|
667 |
+
"Response: 南京市位于江苏省西南部,是全国首批历史文化名城、国家中心城市和自由贸易试验区。\n",
|
668 |
+
"\n",
|
669 |
+
"完。\n"
|
670 |
+
],
|
671 |
+
"metadata": {
|
672 |
+
"collapsed": false
|
673 |
+
}
|
674 |
+
},
|
675 |
+
{
|
676 |
+
"cell_type": "code",
|
677 |
+
"execution_count": null,
|
678 |
+
"outputs": [],
|
679 |
+
"source": [],
|
680 |
+
"metadata": {
|
681 |
+
"collapsed": false
|
682 |
+
}
|
683 |
+
}
|
684 |
+
],
|
685 |
+
"metadata": {
|
686 |
+
"kernelspec": {
|
687 |
+
"name": "python3",
|
688 |
+
"language": "python",
|
689 |
+
"display_name": "Python 3"
|
690 |
+
},
|
691 |
+
"language_info": {
|
692 |
+
"codemirror_mode": {
|
693 |
+
"name": "ipython",
|
694 |
+
"version": 3
|
695 |
+
},
|
696 |
+
"file_extension": ".py",
|
697 |
+
"mimetype": "text/x-python",
|
698 |
+
"name": "python",
|
699 |
+
"nbconvert_exporter": "python",
|
700 |
+
"pygments_lexer": "ipython3",
|
701 |
+
"version": "3.8.13"
|
702 |
+
},
|
703 |
+
"vscode": {
|
704 |
+
"interpreter": {
|
705 |
+
"hash": "f34eed0bebedfc4b6ee51ced43d2c030fe3b92f13c149d072205ca200a67b1ec"
|
706 |
+
}
|
707 |
+
}
|
708 |
+
},
|
709 |
+
"nbformat": 4,
|
710 |
+
"nbformat_minor": 4
|
711 |
+
}
|
run_training_pipeline.ipynb
ADDED
@@ -0,0 +1,917 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"source": [
|
6 |
+
"# Training Pipeline\n",
|
7 |
+
"[run_training_pipeline.ipynb](https://github.com/shibing624/MedicalGPT/blob/main/run_training_pipeline.ipynb) | [Open In Colab](https://colab.research.google.com/github/shibing624/MedicalGPT/blob/main/run_training_pipeline.ipynb)"
|
8 |
+
],
|
9 |
+
"metadata": {
|
10 |
+
"collapsed": false
|
11 |
+
}
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "markdown",
|
15 |
+
"metadata": {
|
16 |
+
"tags": []
|
17 |
+
},
|
18 |
+
"source": [
|
19 |
+
"# Stage 1: Continue Pretraining\n",
|
20 |
+
"\n",
|
21 |
+
"第一阶段:PT(Continue PreTraining)增量预训练,在海量领域文本数据上二次预训练GPT模型,以注入领域知识\n",
|
22 |
+
"\n",
|
23 |
+
"| Stage 1: Continue Pretraining | [pretraining.py](https://github.com/shibing624/MedicalGPT/blob/main/pretraining.py) | [run_pt.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_pt.sh) |"
|
24 |
+
]
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"cell_type": "markdown",
|
28 |
+
"metadata": {},
|
29 |
+
"source": [
|
30 |
+
"#### 说明:\n",
|
31 |
+
"以下 notebook/colab 代码为了快速验证训练代码可用,我们使用了小size的生成模型和小样本数据集,实际使用时,需要使用更大的模型和数据集,以获得更好的效果。\n",
|
32 |
+
"\n",
|
33 |
+
"1. 生成模型:使用的是Bloom的`bigscience/bloomz-560m`\n",
|
34 |
+
"2. 数据集:PT阶段使用的是中文天龙八部小说部分文本和英文书籍部分文本,位于`data/pretrain`文件夹"
|
35 |
+
]
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "markdown",
|
39 |
+
"metadata": {},
|
40 |
+
"source": [
|
41 |
+
"## 配置运行环境\n",
|
42 |
+
"\n",
|
43 |
+
"本地执行可注释以下配置环境的命令,colab执行要打开注释,用于配置环境\n",
|
44 |
+
"\n",
|
45 |
+
"colab建议使用T4 GPU训练,设置方式:`代码执行程序 -> 更改运行时类型 -> 运行时类型:Python3,硬件加速器:GPU,GPU类型:T4 -> 保存`\n",
|
46 |
+
"\n",
|
47 |
+
"步骤:\n",
|
48 |
+
"1. 下载最新代码到本地\n",
|
49 |
+
"2. 安装依赖包\n",
|
50 |
+
"\n",
|
51 |
+
"依赖包如下,保证最新版本:\n",
|
52 |
+
"\n",
|
53 |
+
"```\n",
|
54 |
+
"loguru\n",
|
55 |
+
"transformers\n",
|
56 |
+
"sentencepiece\n",
|
57 |
+
"datasets\n",
|
58 |
+
"tensorboard\n",
|
59 |
+
"tqdm\n",
|
60 |
+
"peft\n",
|
61 |
+
"trl\n",
|
62 |
+
"```"
|
63 |
+
]
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"cell_type": "code",
|
67 |
+
"execution_count": null,
|
68 |
+
"metadata": {},
|
69 |
+
"outputs": [],
|
70 |
+
"source": [
|
71 |
+
"!git clone --depth 1 https://github.com/shibing624/MedicalGPT.git\n",
|
72 |
+
"%cd MedicalGPT\n",
|
73 |
+
"%ls\n",
|
74 |
+
"!pip install -r requirements.txt"
|
75 |
+
]
|
76 |
+
},
|
77 |
+
{
|
78 |
+
"cell_type": "markdown",
|
79 |
+
"metadata": {},
|
80 |
+
"source": [
|
81 |
+
"## Stage1 咱们开始吧\n",
|
82 |
+
"\n",
|
83 |
+
"训练步骤如下:\n",
|
84 |
+
"\n",
|
85 |
+
"1. 确认训练集\n",
|
86 |
+
"2. 执行训练脚本\n",
|
87 |
+
"\n",
|
88 |
+
"训练脚本的执行逻辑如下:\n",
|
89 |
+
"1. 导入依赖包\n",
|
90 |
+
"2. 设置参数\n",
|
91 |
+
"3. 定义各函数并加载训练集\n",
|
92 |
+
"4. 加载模型和tokenizer\n",
|
93 |
+
"5. 开始训练并评估\n",
|
94 |
+
"6. 查看训练结果\n",
|
95 |
+
"\n",
|
96 |
+
"**以下参数可以根据你的GPU实际情况修改,当前参数是根据Colab的T4单卡GPU(16GB显存)配置的**"
|
97 |
+
]
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"cell_type": "code",
|
101 |
+
"execution_count": null,
|
102 |
+
"metadata": {},
|
103 |
+
"outputs": [],
|
104 |
+
"source": [
|
105 |
+
"%ls ./data/pretrain/"
|
106 |
+
]
|
107 |
+
},
|
108 |
+
{
|
109 |
+
"cell_type": "code",
|
110 |
+
"execution_count": null,
|
111 |
+
"outputs": [],
|
112 |
+
"source": [
|
113 |
+
"!python pretraining.py \\\n",
|
114 |
+
" --model_type bloom \\\n",
|
115 |
+
" --model_name_or_path bigscience/bloomz-560m \\\n",
|
116 |
+
" --train_file_dir ./data/pretrain \\\n",
|
117 |
+
" --validation_file_dir ./data/pretrain \\\n",
|
118 |
+
" --per_device_train_batch_size 3 \\\n",
|
119 |
+
" --per_device_eval_batch_size 3 \\\n",
|
120 |
+
" --do_train \\\n",
|
121 |
+
" --do_eval \\\n",
|
122 |
+
" --use_peft True \\\n",
|
123 |
+
" --seed 42 \\\n",
|
124 |
+
" --fp16 \\\n",
|
125 |
+
" --max_train_samples 10000 \\\n",
|
126 |
+
" --max_eval_samples 10 \\\n",
|
127 |
+
" --num_train_epochs 1 \\\n",
|
128 |
+
" --learning_rate 2e-4 \\\n",
|
129 |
+
" --warmup_ratio 0.05 \\\n",
|
130 |
+
" --weight_decay 0.01 \\\n",
|
131 |
+
" --logging_strategy steps \\\n",
|
132 |
+
" --logging_steps 10 \\\n",
|
133 |
+
" --eval_steps 50 \\\n",
|
134 |
+
" --evaluation_strategy steps \\\n",
|
135 |
+
" --save_steps 500 \\\n",
|
136 |
+
" --save_strategy steps \\\n",
|
137 |
+
" --save_total_limit 3 \\\n",
|
138 |
+
" --gradient_accumulation_steps 1 \\\n",
|
139 |
+
" --preprocessing_num_workers 1 \\\n",
|
140 |
+
" --block_size 1024 \\\n",
|
141 |
+
" --output_dir outputs-pt-v1 \\\n",
|
142 |
+
" --overwrite_output_dir \\\n",
|
143 |
+
" --ddp_timeout 30000 \\\n",
|
144 |
+
" --logging_first_step True \\\n",
|
145 |
+
" --target_modules all \\\n",
|
146 |
+
" --lora_rank 8 \\\n",
|
147 |
+
" --lora_alpha 16 \\\n",
|
148 |
+
" --lora_dropout 0.05 \\\n",
|
149 |
+
" --torch_dtype float16 \\\n",
|
150 |
+
" --device_map auto \\\n",
|
151 |
+
" --report_to tensorboard \\\n",
|
152 |
+
" --ddp_find_unused_parameters False \\\n",
|
153 |
+
" --gradient_checkpointing True"
|
154 |
+
],
|
155 |
+
"metadata": {
|
156 |
+
"collapsed": false
|
157 |
+
}
|
158 |
+
},
|
159 |
+
{
|
160 |
+
"cell_type": "code",
|
161 |
+
"execution_count": null,
|
162 |
+
"metadata": {},
|
163 |
+
"outputs": [],
|
164 |
+
"source": [
|
165 |
+
"%ls -lh outputs-pt-v1"
|
166 |
+
]
|
167 |
+
},
|
168 |
+
{
|
169 |
+
"cell_type": "markdown",
|
170 |
+
"metadata": {},
|
171 |
+
"source": [
|
172 |
+
"模型训练结果:\n",
|
173 |
+
"- 使用lora训练模型,则保存的lora权重是`adapter_model.bin`, lora配置文件是`adapter_config.json`,合并到base model的方法见`merge_peft_adapter.py`\n",
|
174 |
+
"- 日志保存在`output_dir/runs`目录下,可以使用tensorboard查看,启动tensorboard方式如下:`tensorboard --logdir output_dir/runs --host 0.0.0.0 --port 8009`"
|
175 |
+
]
|
176 |
+
},
|
177 |
+
{
|
178 |
+
"cell_type": "markdown",
|
179 |
+
"source": [
|
180 |
+
"lora模型权重合并到base model,合并后的模型保存在`--output_dir`目录下,合并方法如下:"
|
181 |
+
],
|
182 |
+
"metadata": {
|
183 |
+
"collapsed": false
|
184 |
+
}
|
185 |
+
},
|
186 |
+
{
|
187 |
+
"cell_type": "code",
|
188 |
+
"execution_count": null,
|
189 |
+
"outputs": [],
|
190 |
+
"source": [
|
191 |
+
"!python merge_peft_adapter.py --model_type bloom \\\n",
|
192 |
+
" --base_model_name_or_path bigscience/bloomz-560m --peft_model_path outputs-pt-v1 --output_dir merged-pt/"
|
193 |
+
],
|
194 |
+
"metadata": {
|
195 |
+
"collapsed": false
|
196 |
+
}
|
197 |
+
},
|
198 |
+
{
|
199 |
+
"cell_type": "code",
|
200 |
+
"execution_count": null,
|
201 |
+
"outputs": [],
|
202 |
+
"source": [
|
203 |
+
"%ls -lh merged-pt/"
|
204 |
+
],
|
205 |
+
"metadata": {
|
206 |
+
"collapsed": false
|
207 |
+
}
|
208 |
+
},
|
209 |
+
{
|
210 |
+
"cell_type": "code",
|
211 |
+
"execution_count": null,
|
212 |
+
"outputs": [],
|
213 |
+
"source": [
|
214 |
+
"%cat merged-pt/config.json"
|
215 |
+
],
|
216 |
+
"metadata": {
|
217 |
+
"collapsed": false
|
218 |
+
}
|
219 |
+
},
|
220 |
+
{
|
221 |
+
"cell_type": "markdown",
|
222 |
+
"metadata": {},
|
223 |
+
"source": [
|
224 |
+
"Stage1 增量预训练完成。"
|
225 |
+
]
|
226 |
+
},
|
227 |
+
{
|
228 |
+
"cell_type": "code",
|
229 |
+
"execution_count": null,
|
230 |
+
"metadata": {
|
231 |
+
"ExecuteTime": {
|
232 |
+
"start_time": "2023-06-15T13:56:17.032821Z",
|
233 |
+
"end_time": "2023-06-15T13:56:17.081153Z"
|
234 |
+
}
|
235 |
+
},
|
236 |
+
"outputs": [],
|
237 |
+
"source": []
|
238 |
+
},
|
239 |
+
{
|
240 |
+
"cell_type": "markdown",
|
241 |
+
"source": [
|
242 |
+
"# Stage 2: Supervised FineTuning\n",
|
243 |
+
"\n",
|
244 |
+
"第二阶段:SFT(Supervised Fine-tuning)有监督微调,构造指令微调数据集,在预训练模型基础上做指令精调,以对齐指令意图\n",
|
245 |
+
"\n",
|
246 |
+
"| Stage 2: Supervised Fine-tuning | [supervised_finetuning.py](https://github.com/shibing624/MedicalGPT/blob/main/supervised_finetuning.py) | [run_sft.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_sft.sh) |"
|
247 |
+
],
|
248 |
+
"metadata": {
|
249 |
+
"collapsed": false
|
250 |
+
}
|
251 |
+
},
|
252 |
+
{
|
253 |
+
"cell_type": "markdown",
|
254 |
+
"source": [
|
255 |
+
"#### 说明:\n",
|
256 |
+
"以下 notebook/colab 代码为了快速验证训练代码可用,我们使用了小size的生成模型和小样本数据集,实际使用时,需要使用更大的模型和数据集,以获得更好的效果。\n",
|
257 |
+
"\n",
|
258 |
+
"1. 生成模型:使用的是Bloom的`bigscience/bloomz-560m` 或者 Stage1得到的预训练模型\n",
|
259 |
+
"2. 数据集:SFT阶段使用的是使用的是Belle的1千条抽样数据,位于`data/finetune`文件夹"
|
260 |
+
],
|
261 |
+
"metadata": {
|
262 |
+
"collapsed": false
|
263 |
+
}
|
264 |
+
},
|
265 |
+
{
|
266 |
+
"cell_type": "markdown",
|
267 |
+
"source": [
|
268 |
+
"## Stage2 咱们开始吧\n",
|
269 |
+
"\n",
|
270 |
+
"训练步骤如下:\n",
|
271 |
+
"\n",
|
272 |
+
"1. 确认训练集\n",
|
273 |
+
"2. 执行训练脚本\n",
|
274 |
+
"\n",
|
275 |
+
"训练脚本的执行逻辑如下:\n",
|
276 |
+
"1. 导入依赖包\n",
|
277 |
+
"2. 设置参数\n",
|
278 |
+
"3. 定义各函数并加载训练集\n",
|
279 |
+
"4. 加载模型和tokenizer\n",
|
280 |
+
"5. 开始训练并评估\n",
|
281 |
+
"6. 查看训练结果"
|
282 |
+
],
|
283 |
+
"metadata": {
|
284 |
+
"collapsed": false
|
285 |
+
}
|
286 |
+
},
|
287 |
+
{
|
288 |
+
"cell_type": "code",
|
289 |
+
"execution_count": null,
|
290 |
+
"outputs": [],
|
291 |
+
"source": [
|
292 |
+
"%ls ./data/finetune"
|
293 |
+
],
|
294 |
+
"metadata": {
|
295 |
+
"collapsed": false,
|
296 |
+
"ExecuteTime": {
|
297 |
+
"start_time": "2023-06-15T13:58:38.778132Z",
|
298 |
+
"end_time": "2023-06-15T13:58:38.966506Z"
|
299 |
+
}
|
300 |
+
}
|
301 |
+
},
|
302 |
+
{
|
303 |
+
"cell_type": "code",
|
304 |
+
"execution_count": null,
|
305 |
+
"outputs": [],
|
306 |
+
"source": [
|
307 |
+
"!python supervised_finetuning.py \\\n",
|
308 |
+
" --model_type bloom \\\n",
|
309 |
+
" --model_name_or_path merged-pt \\\n",
|
310 |
+
" --train_file_dir ./data/finetune \\\n",
|
311 |
+
" --validation_file_dir ./data/finetune \\\n",
|
312 |
+
" --per_device_train_batch_size 4 \\\n",
|
313 |
+
" --per_device_eval_batch_size 4 \\\n",
|
314 |
+
" --do_train \\\n",
|
315 |
+
" --do_eval \\\n",
|
316 |
+
" --use_peft True \\\n",
|
317 |
+
" --fp16 \\\n",
|
318 |
+
" --max_train_samples 1000 \\\n",
|
319 |
+
" --max_eval_samples 10 \\\n",
|
320 |
+
" --num_train_epochs 1 \\\n",
|
321 |
+
" --learning_rate 2e-5 \\\n",
|
322 |
+
" --warmup_ratio 0.05 \\\n",
|
323 |
+
" --weight_decay 0.05 \\\n",
|
324 |
+
" --logging_strategy steps \\\n",
|
325 |
+
" --logging_steps 10 \\\n",
|
326 |
+
" --eval_steps 50 \\\n",
|
327 |
+
" --evaluation_strategy steps \\\n",
|
328 |
+
" --save_steps 500 \\\n",
|
329 |
+
" --save_strategy steps \\\n",
|
330 |
+
" --save_total_limit 3 \\\n",
|
331 |
+
" --gradient_accumulation_steps 1 \\\n",
|
332 |
+
" --preprocessing_num_workers 1 \\\n",
|
333 |
+
" --output_dir outputs-sft-v1 \\\n",
|
334 |
+
" --overwrite_output_dir \\\n",
|
335 |
+
" --ddp_timeout 30000 \\\n",
|
336 |
+
" --logging_first_step True \\\n",
|
337 |
+
" --target_modules all \\\n",
|
338 |
+
" --lora_rank 8 \\\n",
|
339 |
+
" --lora_alpha 16 \\\n",
|
340 |
+
" --lora_dropout 0.05 \\\n",
|
341 |
+
" --torch_dtype float16 \\\n",
|
342 |
+
" --device_map auto \\\n",
|
343 |
+
" --report_to tensorboard \\\n",
|
344 |
+
" --ddp_find_unused_parameters False \\\n",
|
345 |
+
" --gradient_checkpointing True"
|
346 |
+
],
|
347 |
+
"metadata": {
|
348 |
+
"collapsed": false
|
349 |
+
}
|
350 |
+
},
|
351 |
+
{
|
352 |
+
"cell_type": "code",
|
353 |
+
"execution_count": null,
|
354 |
+
"outputs": [],
|
355 |
+
"source": [
|
356 |
+
"%ls -lh outputs-sft-v1"
|
357 |
+
],
|
358 |
+
"metadata": {
|
359 |
+
"collapsed": false
|
360 |
+
}
|
361 |
+
},
|
362 |
+
{
|
363 |
+
"cell_type": "markdown",
|
364 |
+
"source": [
|
365 |
+
"模型训练结果:\n",
|
366 |
+
"- 使用lora训练模型,则保存的lora权重是`adapter_model.bin`, lora配置文件是`adapter_config.json`,合并到base model的方法见`merge_peft_adapter.py`\n",
|
367 |
+
"- 日志保存在`output_dir/runs`目录下,可以使用tensorboard查看,启动tensorboard方式如下:`tensorboard --logdir output_dir/runs --host 0.0.0.0 --port 8009`"
|
368 |
+
],
|
369 |
+
"metadata": {
|
370 |
+
"collapsed": false
|
371 |
+
}
|
372 |
+
},
|
373 |
+
{
|
374 |
+
"cell_type": "markdown",
|
375 |
+
"source": [
|
376 |
+
"lora模型权重合并到base model,合并后的模型保存在`--output_dir`目录下,合并方法如下:"
|
377 |
+
],
|
378 |
+
"metadata": {
|
379 |
+
"collapsed": false
|
380 |
+
}
|
381 |
+
},
|
382 |
+
{
|
383 |
+
"cell_type": "code",
|
384 |
+
"execution_count": null,
|
385 |
+
"outputs": [],
|
386 |
+
"source": [
|
387 |
+
"!python merge_peft_adapter.py --model_type bloom \\\n",
|
388 |
+
" --base_model_name_or_path merged-pt --peft_model_path outputs-sft-v1 --output_dir merged-sft/"
|
389 |
+
],
|
390 |
+
"metadata": {
|
391 |
+
"collapsed": false
|
392 |
+
}
|
393 |
+
},
|
394 |
+
{
|
395 |
+
"cell_type": "code",
|
396 |
+
"execution_count": null,
|
397 |
+
"outputs": [],
|
398 |
+
"source": [
|
399 |
+
"%ls -lh merged-sft/"
|
400 |
+
],
|
401 |
+
"metadata": {
|
402 |
+
"collapsed": false
|
403 |
+
}
|
404 |
+
},
|
405 |
+
{
|
406 |
+
"cell_type": "code",
|
407 |
+
"execution_count": null,
|
408 |
+
"outputs": [],
|
409 |
+
"source": [
|
410 |
+
"%cat merged-sft/config.json"
|
411 |
+
],
|
412 |
+
"metadata": {
|
413 |
+
"collapsed": false
|
414 |
+
}
|
415 |
+
},
|
416 |
+
{
|
417 |
+
"cell_type": "markdown",
|
418 |
+
"source": [
|
419 |
+
"Stage2 SFT训练完成。"
|
420 |
+
],
|
421 |
+
"metadata": {
|
422 |
+
"collapsed": false
|
423 |
+
}
|
424 |
+
},
|
425 |
+
{
|
426 |
+
"cell_type": "code",
|
427 |
+
"execution_count": null,
|
428 |
+
"outputs": [],
|
429 |
+
"source": [],
|
430 |
+
"metadata": {
|
431 |
+
"collapsed": false,
|
432 |
+
"ExecuteTime": {
|
433 |
+
"start_time": "2023-06-15T14:07:40.731186Z",
|
434 |
+
"end_time": "2023-06-15T14:07:40.752635Z"
|
435 |
+
}
|
436 |
+
}
|
437 |
+
},
|
438 |
+
{
|
439 |
+
"cell_type": "markdown",
|
440 |
+
"source": [
|
441 |
+
"# Stage 3: Reward Modeling\n",
|
442 |
+
"\n",
|
443 |
+
"第三阶段:RM(Reward Model)奖励模型建模,构造人类偏好排序数据集,训练奖励模型,用来对齐人类偏好,主要是\"HHH\"原则,具体是\"helpful, honest, harmless\"\n",
|
444 |
+
"\n",
|
445 |
+
"| Stage 3: Reward Modeling | [reward_modeling.py](https://github.com/shibing624/MedicalGPT/blob/main/reward_modeling.py) | [run_rm.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_rm.sh) |"
|
446 |
+
],
|
447 |
+
"metadata": {
|
448 |
+
"collapsed": false
|
449 |
+
}
|
450 |
+
},
|
451 |
+
{
|
452 |
+
"cell_type": "markdown",
|
453 |
+
"source": [
|
454 |
+
"#### 说明:\n",
|
455 |
+
"以下 notebook/colab 代码为了快速验证训练代码可用,我们使用了小size的生成模型和小样本数据集,实际使用时,需要使用更大的模型和数据集,以获得更好的效果。\n",
|
456 |
+
"\n",
|
457 |
+
"1. 生成模型:使用的是Bloom的`bigscience/bloomz-560m` 或者 Stage2得到的SFT模型\n",
|
458 |
+
"2. 数据集:RM阶段使用的是医疗reward数据,抽样了500条,位于`data/reward`文件夹"
|
459 |
+
],
|
460 |
+
"metadata": {
|
461 |
+
"collapsed": false
|
462 |
+
}
|
463 |
+
},
|
464 |
+
{
|
465 |
+
"cell_type": "markdown",
|
466 |
+
"source": [
|
467 |
+
"## Stage3 咱们开始吧\n",
|
468 |
+
"\n",
|
469 |
+
"训练步骤如下:\n",
|
470 |
+
"\n",
|
471 |
+
"1. 确认训练集\n",
|
472 |
+
"2. 执行训练脚本\n",
|
473 |
+
"\n",
|
474 |
+
"训练脚本的执行逻辑如下:\n",
|
475 |
+
"1. 导入依赖包\n",
|
476 |
+
"2. 设置参数\n",
|
477 |
+
"3. 定义各函数并加载训练集\n",
|
478 |
+
"4. 加载模型和tokenizer\n",
|
479 |
+
"5. 开始训练并评估\n",
|
480 |
+
"6. 查看训练结果"
|
481 |
+
],
|
482 |
+
"metadata": {
|
483 |
+
"collapsed": false
|
484 |
+
}
|
485 |
+
},
|
486 |
+
{
|
487 |
+
"cell_type": "code",
|
488 |
+
"execution_count": null,
|
489 |
+
"outputs": [],
|
490 |
+
"source": [
|
491 |
+
"%ls ./data/reward/"
|
492 |
+
],
|
493 |
+
"metadata": {
|
494 |
+
"collapsed": false
|
495 |
+
}
|
496 |
+
},
|
497 |
+
{
|
498 |
+
"cell_type": "code",
|
499 |
+
"execution_count": null,
|
500 |
+
"outputs": [],
|
501 |
+
"source": [
|
502 |
+
"!python reward_modeling.py \\\n",
|
503 |
+
" --model_type bloom \\\n",
|
504 |
+
" --model_name_or_path merged-sft \\\n",
|
505 |
+
" --train_file_dir ./data/reward \\\n",
|
506 |
+
" --validation_file_dir ./data/reward \\\n",
|
507 |
+
" --per_device_train_batch_size 3 \\\n",
|
508 |
+
" --per_device_eval_batch_size 1 \\\n",
|
509 |
+
" --do_train \\\n",
|
510 |
+
" --use_peft True \\\n",
|
511 |
+
" --seed 42 \\\n",
|
512 |
+
" --max_train_samples 1000 \\\n",
|
513 |
+
" --max_eval_samples 10 \\\n",
|
514 |
+
" --num_train_epochs 1 \\\n",
|
515 |
+
" --learning_rate 2e-5 \\\n",
|
516 |
+
" --warmup_ratio 0.05 \\\n",
|
517 |
+
" --weight_decay 0.001 \\\n",
|
518 |
+
" --logging_strategy steps \\\n",
|
519 |
+
" --logging_steps 10 \\\n",
|
520 |
+
" --eval_steps 50 \\\n",
|
521 |
+
" --evaluation_strategy steps \\\n",
|
522 |
+
" --save_steps 500 \\\n",
|
523 |
+
" --save_strategy steps \\\n",
|
524 |
+
" --save_total_limit 3 \\\n",
|
525 |
+
" --max_source_length 256 \\\n",
|
526 |
+
" --max_target_length 256 \\\n",
|
527 |
+
" --output_dir outputs-rm-v1 \\\n",
|
528 |
+
" --overwrite_output_dir \\\n",
|
529 |
+
" --ddp_timeout 30000 \\\n",
|
530 |
+
" --logging_first_step True \\\n",
|
531 |
+
" --target_modules all \\\n",
|
532 |
+
" --lora_rank 8 \\\n",
|
533 |
+
" --lora_alpha 16 \\\n",
|
534 |
+
" --lora_dropout 0.05 \\\n",
|
535 |
+
" --torch_dtype float32 \\\n",
|
536 |
+
" --device_map auto \\\n",
|
537 |
+
" --report_to tensorboard \\\n",
|
538 |
+
" --ddp_find_unused_parameters False \\\n",
|
539 |
+
" --remove_unused_columns False \\\n",
|
540 |
+
" --gradient_checkpointing True"
|
541 |
+
],
|
542 |
+
"metadata": {
|
543 |
+
"collapsed": false
|
544 |
+
}
|
545 |
+
},
|
546 |
+
{
|
547 |
+
"cell_type": "code",
|
548 |
+
"execution_count": null,
|
549 |
+
"outputs": [],
|
550 |
+
"source": [
|
551 |
+
"%ls -lh outputs-rm-v1"
|
552 |
+
],
|
553 |
+
"metadata": {
|
554 |
+
"collapsed": false
|
555 |
+
}
|
556 |
+
},
|
557 |
+
{
|
558 |
+
"cell_type": "markdown",
|
559 |
+
"source": [
|
560 |
+
"模型训练结果:\n",
|
561 |
+
"- 使用lora训练模型,则保存的lora权重是`adapter_model.bin`, lora配置文件是`adapter_config.json`,合并到base model的方法见`merge_peft_adapter.py`\n",
|
562 |
+
"- 日志保存在`output_dir/runs`目录下,可以使用tensorboard查看,启动tensorboard方式如下:`tensorboard --logdir output_dir/runs --host 0.0.0.0 --port 8009`"
|
563 |
+
],
|
564 |
+
"metadata": {
|
565 |
+
"collapsed": false
|
566 |
+
}
|
567 |
+
},
|
568 |
+
{
|
569 |
+
"cell_type": "markdown",
|
570 |
+
"source": [
|
571 |
+
"lora模型权重合并到base model,合并后的模型保存在`--output_dir`目录下,合并方法如下:"
|
572 |
+
],
|
573 |
+
"metadata": {
|
574 |
+
"collapsed": false
|
575 |
+
}
|
576 |
+
},
|
577 |
+
{
|
578 |
+
"cell_type": "code",
|
579 |
+
"execution_count": null,
|
580 |
+
"outputs": [],
|
581 |
+
"source": [
|
582 |
+
"!python merge_peft_adapter.py --model_type bloom \\\n",
|
583 |
+
" --base_model_name_or_path merged-sft --peft_model_path outputs-rm-v1 --output_dir merged-rm/"
|
584 |
+
],
|
585 |
+
"metadata": {
|
586 |
+
"collapsed": false
|
587 |
+
}
|
588 |
+
},
|
589 |
+
{
|
590 |
+
"cell_type": "code",
|
591 |
+
"execution_count": null,
|
592 |
+
"outputs": [],
|
593 |
+
"source": [
|
594 |
+
"%ls -lh merged-rm/"
|
595 |
+
],
|
596 |
+
"metadata": {
|
597 |
+
"collapsed": false
|
598 |
+
}
|
599 |
+
},
|
600 |
+
{
|
601 |
+
"cell_type": "code",
|
602 |
+
"execution_count": null,
|
603 |
+
"outputs": [],
|
604 |
+
"source": [
|
605 |
+
"%cat merged-rm/config.json"
|
606 |
+
],
|
607 |
+
"metadata": {
|
608 |
+
"collapsed": false
|
609 |
+
}
|
610 |
+
},
|
611 |
+
{
|
612 |
+
"cell_type": "markdown",
|
613 |
+
"source": [
|
614 |
+
"Stage3 奖励建模第一次训练完成。"
|
615 |
+
],
|
616 |
+
"metadata": {
|
617 |
+
"collapsed": false
|
618 |
+
}
|
619 |
+
},
|
620 |
+
{
|
621 |
+
"cell_type": "code",
|
622 |
+
"execution_count": null,
|
623 |
+
"outputs": [],
|
624 |
+
"source": [],
|
625 |
+
"metadata": {
|
626 |
+
"collapsed": false,
|
627 |
+
"ExecuteTime": {
|
628 |
+
"start_time": "2023-06-15T14:12:09.464881Z",
|
629 |
+
"end_time": "2023-06-15T14:12:09.472414Z"
|
630 |
+
}
|
631 |
+
}
|
632 |
+
},
|
633 |
+
{
|
634 |
+
"cell_type": "markdown",
|
635 |
+
"source": [
|
636 |
+
"# Stage 4: Reinforcement Learning Training\n",
|
637 |
+
"\n",
|
638 |
+
"第四阶段:RL(Reinforcement Learning)基于人类反馈的强化学习(RLHF),用奖励模型来训练SFT模型,生成模型使用奖励或惩罚来更新其策略,以便生成更高质量、更符合人类偏好的文本\n",
|
639 |
+
"\n",
|
640 |
+
"| Stage 4: Reinforcement Learning | [rl_training.py](https://github.com/shibing624/MedicalGPT/blob/main/rl_training.py) | [run_rl.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_rl.sh) |\n"
|
641 |
+
],
|
642 |
+
"metadata": {
|
643 |
+
"collapsed": false
|
644 |
+
}
|
645 |
+
},
|
646 |
+
{
|
647 |
+
"cell_type": "markdown",
|
648 |
+
"source": [
|
649 |
+
"#### 说明:\n",
|
650 |
+
"以下 notebook/colab 代码为了快速验证训练代码可用,我们使用了小size的生成模型、奖励模型和小样本数据集,实际使用时,需要使用更大的模型和数据集,以获得更好的效果。\n",
|
651 |
+
"\n",
|
652 |
+
"1. 生成模型:使用的是Bloom的`bigscience/bloomz-560m` 或者 Stage2得到的SFT模型\n",
|
653 |
+
"2. 奖励模型:使用的是`OpenAssistant/reward-model-deberta-v3-large-v2` 或者 Stage3得到的BERT类或者GPT类奖励模型\n",
|
654 |
+
"3. 数据集:RL阶段的数据可以复用SFT的数据集,使用的是Belle的1千条抽样数据,位于`data/finetune`文件夹"
|
655 |
+
],
|
656 |
+
"metadata": {
|
657 |
+
"collapsed": false
|
658 |
+
}
|
659 |
+
},
|
660 |
+
{
|
661 |
+
"cell_type": "markdown",
|
662 |
+
"source": [
|
663 |
+
"## Stage4 咱们开始吧\n",
|
664 |
+
"\n",
|
665 |
+
"训练步骤如下:\n",
|
666 |
+
"\n",
|
667 |
+
"1. 确认训练集\n",
|
668 |
+
"2. 执行训练脚本\n",
|
669 |
+
"\n",
|
670 |
+
"训练脚本的执行逻辑如下:\n",
|
671 |
+
"1. 导入依赖包\n",
|
672 |
+
"2. 设置参数\n",
|
673 |
+
"3. 定义各函数并加载训练集\n",
|
674 |
+
"4. 加载生成模型和tokenizer,加载奖励模型和其tokenizer\n",
|
675 |
+
"5. 开始训练并评估\n",
|
676 |
+
"6. 查看训练结果\n",
|
677 |
+
"\n",
|
678 |
+
"以下参数可以根据你的GPU实际情况修改,当前参数是根据Colab的T4单卡GPU(16GB显存)配置的。"
|
679 |
+
],
|
680 |
+
"metadata": {
|
681 |
+
"collapsed": false
|
682 |
+
}
|
683 |
+
},
|
684 |
+
{
|
685 |
+
"cell_type": "code",
|
686 |
+
"execution_count": null,
|
687 |
+
"outputs": [],
|
688 |
+
"source": [
|
689 |
+
"%ls ./data/finetune/"
|
690 |
+
],
|
691 |
+
"metadata": {
|
692 |
+
"collapsed": false
|
693 |
+
}
|
694 |
+
},
|
695 |
+
{
|
696 |
+
"cell_type": "code",
|
697 |
+
"execution_count": null,
|
698 |
+
"outputs": [],
|
699 |
+
"source": [
|
700 |
+
"!python rl_training.py \\\n",
|
701 |
+
" --model_type bloom \\\n",
|
702 |
+
" --model_name_or_path merged-sft \\\n",
|
703 |
+
" --reward_model_name_or_path merged-rm \\\n",
|
704 |
+
" --torch_dtype float16 \\\n",
|
705 |
+
" --device_map auto \\\n",
|
706 |
+
" --train_file_dir ./data/finetune \\\n",
|
707 |
+
" --validation_file_dir ./data/finetune \\\n",
|
708 |
+
" --batch_size 4 \\\n",
|
709 |
+
" --max_source_length 256 \\\n",
|
710 |
+
" --max_target_length 256 \\\n",
|
711 |
+
" --max_train_samples 1000 \\\n",
|
712 |
+
" --use_peft True \\\n",
|
713 |
+
" --lora_rank 8 \\\n",
|
714 |
+
" --lora_alpha 16 \\\n",
|
715 |
+
" --lora_dropout 0.05 \\\n",
|
716 |
+
" --do_train \\\n",
|
717 |
+
" --max_steps 64 \\\n",
|
718 |
+
" --learning_rate 1e-5 \\\n",
|
719 |
+
" --save_steps 50 \\\n",
|
720 |
+
" --output_dir outputs-rl-v1 \\\n",
|
721 |
+
" --early_stopping True \\\n",
|
722 |
+
" --target_kl 0.1 \\\n",
|
723 |
+
" --reward_baseline 0.0"
|
724 |
+
],
|
725 |
+
"metadata": {
|
726 |
+
"collapsed": false
|
727 |
+
}
|
728 |
+
},
|
729 |
+
{
|
730 |
+
"cell_type": "code",
|
731 |
+
"execution_count": null,
|
732 |
+
"outputs": [],
|
733 |
+
"source": [
|
734 |
+
"%ls -lh outputs-rl-v1"
|
735 |
+
],
|
736 |
+
"metadata": {
|
737 |
+
"collapsed": false
|
738 |
+
}
|
739 |
+
},
|
740 |
+
{
|
741 |
+
"cell_type": "markdown",
|
742 |
+
"source": [
|
743 |
+
"模型训练结果:\n",
|
744 |
+
"- 使用lora训练模型,则保存的lora权重是`adapter_model.bin`, lora配置文件是`adapter_config.json`,合并到base model的方法见`merge_peft_adapter.py`\n",
|
745 |
+
"- 日志保存在`output_dir/trl`目录下,可以使用tensorboard查看,启动tensorboard方式如下:`tensorboard --logdir output_dir/trl --host 0.0.0.0 --port 8009`"
|
746 |
+
],
|
747 |
+
"metadata": {
|
748 |
+
"collapsed": false
|
749 |
+
}
|
750 |
+
},
|
751 |
+
{
|
752 |
+
"cell_type": "markdown",
|
753 |
+
"source": [
|
754 |
+
"lora模型权重合并到base model,合并后的模型保存在`--output_dir`目录下,合并方法如下:"
|
755 |
+
],
|
756 |
+
"metadata": {
|
757 |
+
"collapsed": false
|
758 |
+
}
|
759 |
+
},
|
760 |
+
{
|
761 |
+
"cell_type": "code",
|
762 |
+
"execution_count": null,
|
763 |
+
"outputs": [],
|
764 |
+
"source": [
|
765 |
+
"!python merge_peft_adapter.py --model_type bloom \\\n",
|
766 |
+
" --base_model_name_or_path merged-sft --peft_model_path outputs-rl-v1 --output_dir merged-rl/"
|
767 |
+
],
|
768 |
+
"metadata": {
|
769 |
+
"collapsed": false
|
770 |
+
}
|
771 |
+
},
|
772 |
+
{
|
773 |
+
"cell_type": "code",
|
774 |
+
"execution_count": null,
|
775 |
+
"outputs": [],
|
776 |
+
"source": [
|
777 |
+
"%ls -lh merged-rl/"
|
778 |
+
],
|
779 |
+
"metadata": {
|
780 |
+
"collapsed": false
|
781 |
+
}
|
782 |
+
},
|
783 |
+
{
|
784 |
+
"cell_type": "code",
|
785 |
+
"execution_count": null,
|
786 |
+
"outputs": [],
|
787 |
+
"source": [
|
788 |
+
"%cat merged-rl/config.json"
|
789 |
+
],
|
790 |
+
"metadata": {
|
791 |
+
"collapsed": false
|
792 |
+
}
|
793 |
+
},
|
794 |
+
{
|
795 |
+
"cell_type": "markdown",
|
796 |
+
"source": [
|
797 |
+
"Stage4 RL第一次训练完成。\n",
|
798 |
+
"\n",
|
799 |
+
"**至此一个完整的4阶段训练流程演示完成。**"
|
800 |
+
],
|
801 |
+
"metadata": {
|
802 |
+
"collapsed": false
|
803 |
+
}
|
804 |
+
},
|
805 |
+
{
|
806 |
+
"cell_type": "markdown",
|
807 |
+
"source": [
|
808 |
+
"实际操作中Stage3和Stage4可以反复多次,直到RL得到的最后模型满足评估要求。\n",
|
809 |
+
"\n",
|
810 |
+
"RLHF过程可以把SFT模型当成一个初始化模型,RM模型当做指导老师,使用RL(PPO)调教SFT模型生成指导老师最满意的结果,如果小学老师满意了,我们就再训练一个中学老师,继续指导,中学老师满意了,就训练一个大学老师,这样不断迭代,使得生成模型的质量达到甚至超过人工撰写的天花板。\n",
|
811 |
+
"\n",
|
812 |
+
"RLHF训练不易,此项目提供给大家一种实现的方法和参考,希望抛砖引玉,共同促进中文开源LLM发展。"
|
813 |
+
],
|
814 |
+
"metadata": {
|
815 |
+
"collapsed": false
|
816 |
+
}
|
817 |
+
},
|
818 |
+
{
|
819 |
+
"cell_type": "markdown",
|
820 |
+
"source": [],
|
821 |
+
"metadata": {
|
822 |
+
"collapsed": false
|
823 |
+
}
|
824 |
+
},
|
825 |
+
{
|
826 |
+
"cell_type": "code",
|
827 |
+
"execution_count": null,
|
828 |
+
"outputs": [],
|
829 |
+
"source": [],
|
830 |
+
"metadata": {
|
831 |
+
"collapsed": false,
|
832 |
+
"ExecuteTime": {
|
833 |
+
"start_time": "2023-06-26T12:34:29.620609Z",
|
834 |
+
"end_time": "2023-06-26T12:34:29.658428Z"
|
835 |
+
}
|
836 |
+
}
|
837 |
+
},
|
838 |
+
{
|
839 |
+
"cell_type": "markdown",
|
840 |
+
"source": [
|
841 |
+
"# Test"
|
842 |
+
],
|
843 |
+
"metadata": {
|
844 |
+
"collapsed": false
|
845 |
+
}
|
846 |
+
},
|
847 |
+
{
|
848 |
+
"cell_type": "markdown",
|
849 |
+
"source": [],
|
850 |
+
"metadata": {
|
851 |
+
"collapsed": false
|
852 |
+
}
|
853 |
+
},
|
854 |
+
{
|
855 |
+
"cell_type": "code",
|
856 |
+
"execution_count": null,
|
857 |
+
"outputs": [],
|
858 |
+
"source": [
|
859 |
+
"!python inference.py --model_type bloom --base_model merged-rl --interactive"
|
860 |
+
],
|
861 |
+
"metadata": {
|
862 |
+
"collapsed": false,
|
863 |
+
"ExecuteTime": {
|
864 |
+
"start_time": "2023-06-26T12:34:47.802087Z",
|
865 |
+
"end_time": "2023-06-26T12:35:00.864463Z"
|
866 |
+
}
|
867 |
+
}
|
868 |
+
},
|
869 |
+
{
|
870 |
+
"cell_type": "markdown",
|
871 |
+
"source": [
|
872 |
+
"Input:介绍下南京\n",
|
873 |
+
"Response: 南京市位于江苏省西南部,是全国��批历史文化名城、国家中心城市和自由贸易试验区。\n",
|
874 |
+
"\n",
|
875 |
+
"完。\n"
|
876 |
+
],
|
877 |
+
"metadata": {
|
878 |
+
"collapsed": false
|
879 |
+
}
|
880 |
+
},
|
881 |
+
{
|
882 |
+
"cell_type": "code",
|
883 |
+
"execution_count": null,
|
884 |
+
"outputs": [],
|
885 |
+
"source": [],
|
886 |
+
"metadata": {
|
887 |
+
"collapsed": false
|
888 |
+
}
|
889 |
+
}
|
890 |
+
],
|
891 |
+
"metadata": {
|
892 |
+
"kernelspec": {
|
893 |
+
"name": "python3",
|
894 |
+
"language": "python",
|
895 |
+
"display_name": "Python 3"
|
896 |
+
},
|
897 |
+
"language_info": {
|
898 |
+
"codemirror_mode": {
|
899 |
+
"name": "ipython",
|
900 |
+
"version": 3
|
901 |
+
},
|
902 |
+
"file_extension": ".py",
|
903 |
+
"mimetype": "text/x-python",
|
904 |
+
"name": "python",
|
905 |
+
"nbconvert_exporter": "python",
|
906 |
+
"pygments_lexer": "ipython3",
|
907 |
+
"version": "3.8.13"
|
908 |
+
},
|
909 |
+
"vscode": {
|
910 |
+
"interpreter": {
|
911 |
+
"hash": "f34eed0bebedfc4b6ee51ced43d2c030fe3b92f13c149d072205ca200a67b1ec"
|
912 |
+
}
|
913 |
+
}
|
914 |
+
},
|
915 |
+
"nbformat": 4,
|
916 |
+
"nbformat_minor": 4
|
917 |
+
}
|
supervised_finetuning.py
ADDED
@@ -0,0 +1,927 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright 2023 XuMing([email protected]) and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""
|
16 |
+
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
|
17 |
+
|
18 |
+
part of this code is adapted from https://github.com/shibing624/textgen
|
19 |
+
"""
|
20 |
+
import math
|
21 |
+
import os
|
22 |
+
from dataclasses import dataclass, field
|
23 |
+
from glob import glob
|
24 |
+
from typing import List, Optional, Dict, Sequence
|
25 |
+
|
26 |
+
import torch
|
27 |
+
from datasets import load_dataset
|
28 |
+
from loguru import logger
|
29 |
+
from peft import LoraConfig, TaskType, get_peft_model, PeftModel, prepare_model_for_int8_training
|
30 |
+
from transformers import (
|
31 |
+
AutoConfig,
|
32 |
+
BloomForCausalLM,
|
33 |
+
AutoModel,
|
34 |
+
AutoModelForCausalLM,
|
35 |
+
LlamaTokenizer,
|
36 |
+
LlamaForCausalLM,
|
37 |
+
BloomTokenizerFast,
|
38 |
+
AutoTokenizer,
|
39 |
+
HfArgumentParser,
|
40 |
+
Trainer,
|
41 |
+
TrainingArguments,
|
42 |
+
set_seed,
|
43 |
+
BitsAndBytesConfig,
|
44 |
+
DataCollatorForSeq2Seq,
|
45 |
+
)
|
46 |
+
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
47 |
+
from transformers.trainer import TRAINING_ARGS_NAME
|
48 |
+
from transformers.trainer_pt_utils import LabelSmoother
|
49 |
+
|
50 |
+
MODEL_CLASSES = {
|
51 |
+
"bloom": (AutoConfig, BloomForCausalLM, BloomTokenizerFast),
|
52 |
+
"chatglm": (AutoConfig, AutoModel, AutoTokenizer),
|
53 |
+
"llama": (AutoConfig, LlamaForCausalLM, LlamaTokenizer),
|
54 |
+
"baichuan": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
|
55 |
+
"auto": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
|
56 |
+
}
|
57 |
+
|
58 |
+
|
59 |
+
@dataclass
|
60 |
+
class ModelArguments:
|
61 |
+
"""
|
62 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
63 |
+
"""
|
64 |
+
|
65 |
+
model_type: str = field(
|
66 |
+
default=None,
|
67 |
+
metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())}
|
68 |
+
)
|
69 |
+
model_name_or_path: Optional[str] = field(
|
70 |
+
default=None,
|
71 |
+
metadata={
|
72 |
+
"help": (
|
73 |
+
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
|
74 |
+
)
|
75 |
+
},
|
76 |
+
)
|
77 |
+
tokenizer_name_or_path: Optional[str] = field(
|
78 |
+
default=None,
|
79 |
+
metadata={
|
80 |
+
"help": (
|
81 |
+
"The tokenizer for weights initialization.Don't set if you want to train a model from scratch."
|
82 |
+
)
|
83 |
+
},
|
84 |
+
)
|
85 |
+
load_in_8bit: bool = field(default=False, metadata={"help": "Whether to load the model in 8bit mode or not."})
|
86 |
+
cache_dir: Optional[str] = field(
|
87 |
+
default=None,
|
88 |
+
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
89 |
+
)
|
90 |
+
use_fast_tokenizer: bool = field(
|
91 |
+
default=False,
|
92 |
+
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
93 |
+
)
|
94 |
+
torch_dtype: Optional[str] = field(
|
95 |
+
default="float16",
|
96 |
+
metadata={
|
97 |
+
"help": (
|
98 |
+
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
|
99 |
+
"dtype will be automatically derived from the model's weights."
|
100 |
+
),
|
101 |
+
"choices": ["auto", "bfloat16", "float16", "float32"],
|
102 |
+
},
|
103 |
+
)
|
104 |
+
device_map: Optional[str] = field(
|
105 |
+
default="auto",
|
106 |
+
metadata={"help": "Device to map model to. If `auto` is passed, the device will be selected automatically. "},
|
107 |
+
)
|
108 |
+
trust_remote_code: bool = field(
|
109 |
+
default=True,
|
110 |
+
metadata={"help": "Whether to trust remote code when loading a model from a remote checkpoint."},
|
111 |
+
)
|
112 |
+
|
113 |
+
def __post_init__(self):
|
114 |
+
if self.model_type is None:
|
115 |
+
raise ValueError(
|
116 |
+
"You must specify a valid model_type to run training. Available model types are " + ", ".join(
|
117 |
+
MODEL_CLASSES.keys()))
|
118 |
+
if self.model_name_or_path is None:
|
119 |
+
raise ValueError("You must specify a valid model_name_or_path to run training.")
|
120 |
+
|
121 |
+
|
122 |
+
@dataclass
|
123 |
+
class DataTrainingArguments:
|
124 |
+
"""
|
125 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
126 |
+
"""
|
127 |
+
|
128 |
+
dataset_name: Optional[str] = field(
|
129 |
+
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
130 |
+
)
|
131 |
+
dataset_config_name: Optional[str] = field(
|
132 |
+
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
133 |
+
)
|
134 |
+
train_file_dir: Optional[str] = field(default=None, metadata={"help": "The train jsonl data file folder."})
|
135 |
+
validation_file_dir: Optional[str] = field(default=None, metadata={"help": "The evaluation jsonl file folder."})
|
136 |
+
template_name: Optional[str] = field(default="vicuna", metadata={"help": "The prompt template name."})
|
137 |
+
max_train_samples: Optional[int] = field(
|
138 |
+
default=None,
|
139 |
+
metadata={
|
140 |
+
"help": (
|
141 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
142 |
+
"value if set."
|
143 |
+
)
|
144 |
+
},
|
145 |
+
)
|
146 |
+
max_eval_samples: Optional[int] = field(
|
147 |
+
default=None,
|
148 |
+
metadata={
|
149 |
+
"help": (
|
150 |
+
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
151 |
+
"value if set."
|
152 |
+
)
|
153 |
+
},
|
154 |
+
)
|
155 |
+
max_source_length: Optional[int] = field(default=256, metadata={"help": "Max length of prompt input text"})
|
156 |
+
max_target_length: Optional[int] = field(default=256, metadata={"help": "Max length of output text"})
|
157 |
+
ignore_pad_token_for_loss: bool = field(
|
158 |
+
default=True,
|
159 |
+
metadata={"help": "If only pad tokens should be ignored. This assumes that `config.pad_token_id` is defined."},
|
160 |
+
)
|
161 |
+
overwrite_cache: bool = field(
|
162 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
163 |
+
)
|
164 |
+
validation_split_percentage: Optional[int] = field(
|
165 |
+
default=1,
|
166 |
+
metadata={
|
167 |
+
"help": "The percentage of the train set used as validation set in case there's no validation split"
|
168 |
+
},
|
169 |
+
)
|
170 |
+
preprocessing_num_workers: Optional[int] = field(
|
171 |
+
default=None,
|
172 |
+
metadata={"help": "The number of processes to use for the preprocessing."},
|
173 |
+
)
|
174 |
+
|
175 |
+
def __post_init__(self):
|
176 |
+
if self.max_train_samples is not None and 0 < self.max_train_samples <= 1000:
|
177 |
+
logger.warning("You may set max_train_samples = -1 to run all samples in production.")
|
178 |
+
if self.max_source_length < 30:
|
179 |
+
raise ValueError("You must specify a valid max_source_length >= 30 to run training.")
|
180 |
+
|
181 |
+
|
182 |
+
@dataclass
|
183 |
+
class PeftArguments(TrainingArguments):
|
184 |
+
use_peft: bool = field(default=True, metadata={"help": "Whether to use peft"})
|
185 |
+
target_modules: Optional[str] = field(default="all")
|
186 |
+
lora_rank: Optional[int] = field(default=8)
|
187 |
+
lora_dropout: Optional[float] = field(default=0.05)
|
188 |
+
lora_alpha: Optional[float] = field(default=32.0)
|
189 |
+
modules_to_save: Optional[str] = field(default=None)
|
190 |
+
peft_path: Optional[str] = field(default=None, metadata={"help": "The path to the peft model"})
|
191 |
+
qlora: bool = field(default=False, metadata={"help": "Whether to use qlora"})
|
192 |
+
|
193 |
+
|
194 |
+
class CastOutputToFloat(torch.nn.Sequential):
|
195 |
+
"""Cast the output of the model to float"""
|
196 |
+
|
197 |
+
def forward(self, x):
|
198 |
+
return super().forward(x).to(torch.float32)
|
199 |
+
|
200 |
+
|
201 |
+
@dataclass
|
202 |
+
class Conversation:
|
203 |
+
"""A class that manages prompt templates and keeps all conversation history."""
|
204 |
+
|
205 |
+
# The name of this template
|
206 |
+
name: str
|
207 |
+
# The system prompt
|
208 |
+
system_prompt: str
|
209 |
+
# All messages. format: list of [question, answer]
|
210 |
+
messages: Optional[List[Sequence[str]]]
|
211 |
+
# The roles of the speakers
|
212 |
+
roles: Optional[Sequence[str]]
|
213 |
+
# Conversation prompt
|
214 |
+
prompt: str
|
215 |
+
# Separator
|
216 |
+
sep: str
|
217 |
+
# Stop token, default is tokenizer.eos_token
|
218 |
+
stop_str: Optional[str] = "</s>"
|
219 |
+
|
220 |
+
def get_prompt(
|
221 |
+
self,
|
222 |
+
messages: Optional[List[Sequence[str]]] = None,
|
223 |
+
system_prompt: Optional[str] = ""
|
224 |
+
) -> str:
|
225 |
+
"""
|
226 |
+
Returns a string containing prompt without response.
|
227 |
+
"""
|
228 |
+
return "".join(self._format_example(messages, system_prompt))
|
229 |
+
|
230 |
+
def get_dialog(
|
231 |
+
self,
|
232 |
+
messages: Optional[List[Sequence[str]]] = None,
|
233 |
+
system_prompt: Optional[str] = ""
|
234 |
+
) -> List[str]:
|
235 |
+
"""
|
236 |
+
Returns a list containing 2 * n elements where the 2k-th is a query and the (2k+1)-th is a response.
|
237 |
+
"""
|
238 |
+
return self._format_example(messages, system_prompt)
|
239 |
+
|
240 |
+
def _format_example(
|
241 |
+
self,
|
242 |
+
messages: Optional[List[Sequence[str]]] = None,
|
243 |
+
system_prompt: Optional[str] = ""
|
244 |
+
) -> List[str]:
|
245 |
+
system_prompt = system_prompt or self.system_prompt
|
246 |
+
system_prompt = system_prompt + self.sep if system_prompt else "" # add separator for non-empty system prompt
|
247 |
+
messages = messages or self.messages
|
248 |
+
convs = []
|
249 |
+
for turn_idx, [user_query, bot_resp] in enumerate(messages):
|
250 |
+
if turn_idx == 0:
|
251 |
+
convs.append(system_prompt + self.prompt.format(query=user_query))
|
252 |
+
convs.append(bot_resp)
|
253 |
+
else:
|
254 |
+
convs.append(self.sep + self.prompt.format(query=user_query))
|
255 |
+
convs.append(bot_resp)
|
256 |
+
return convs
|
257 |
+
|
258 |
+
def append_message(self, query: str, answer: str):
|
259 |
+
"""Append a new message."""
|
260 |
+
self.messages.append([query, answer])
|
261 |
+
|
262 |
+
|
263 |
+
# A global registry for all conversation templates
|
264 |
+
conv_templates: Dict[str, Conversation] = {}
|
265 |
+
|
266 |
+
|
267 |
+
def register_conv_template(template: Conversation):
|
268 |
+
"""Register a new conversation template."""
|
269 |
+
conv_templates[template.name] = template
|
270 |
+
|
271 |
+
|
272 |
+
"""Vicuna v1.1 template
|
273 |
+
Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
|
274 |
+
https://huggingface.co/lmsys/vicuna-13b-delta-v1.1
|
275 |
+
"""
|
276 |
+
register_conv_template(
|
277 |
+
Conversation(
|
278 |
+
name="vicuna",
|
279 |
+
system_prompt="A chat between a curious user and an artificial intelligence assistant. "
|
280 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
281 |
+
messages=[],
|
282 |
+
roles=("USER", "ASSISTANT"),
|
283 |
+
prompt="USER: {query} ASSISTANT: ",
|
284 |
+
sep="</s>",
|
285 |
+
)
|
286 |
+
)
|
287 |
+
|
288 |
+
"""Alpaca template"""
|
289 |
+
register_conv_template(
|
290 |
+
Conversation(
|
291 |
+
name="alpaca",
|
292 |
+
system_prompt="Below is an instruction that describes a task. "
|
293 |
+
"Write a response that appropriately completes the request.",
|
294 |
+
messages=[],
|
295 |
+
roles=("### Instruction", "### Response"),
|
296 |
+
prompt="### Instruction:\n{query}\n\n### Response:\n",
|
297 |
+
sep="\n\n",
|
298 |
+
)
|
299 |
+
)
|
300 |
+
|
301 |
+
"""Baichuan-13B-Chat template
|
302 |
+
source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/f5f47be2adbbdceb784f334d6fa1ca2c73e65097/modeling_baichuan.py#L507
|
303 |
+
Support: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
|
304 |
+
"""
|
305 |
+
register_conv_template(
|
306 |
+
Conversation(
|
307 |
+
name="baichuan-chat",
|
308 |
+
system_prompt="",
|
309 |
+
messages=[],
|
310 |
+
roles=("<reserved_102>", "<reserved_103>"),
|
311 |
+
prompt=" <reserved_102> {query} <reserved_103> ",
|
312 |
+
sep="</s>",
|
313 |
+
)
|
314 |
+
)
|
315 |
+
|
316 |
+
"""ziya template"""
|
317 |
+
register_conv_template(
|
318 |
+
Conversation(
|
319 |
+
name="ziya",
|
320 |
+
system_prompt="",
|
321 |
+
messages=[],
|
322 |
+
roles=("<human>", "<bot>"),
|
323 |
+
prompt="<human>:{query}\n<bot>:",
|
324 |
+
sep="\n",
|
325 |
+
)
|
326 |
+
)
|
327 |
+
|
328 |
+
"""Linly template"""
|
329 |
+
register_conv_template(
|
330 |
+
Conversation(
|
331 |
+
name="linly",
|
332 |
+
system_prompt="",
|
333 |
+
messages=[],
|
334 |
+
roles=("User", "Bot"),
|
335 |
+
prompt="User: {query}\nBot: ",
|
336 |
+
sep="\n",
|
337 |
+
)
|
338 |
+
)
|
339 |
+
|
340 |
+
"""ChatGLM1 template
|
341 |
+
source: https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py#L1307
|
342 |
+
"""
|
343 |
+
register_conv_template(
|
344 |
+
Conversation(
|
345 |
+
name="chatglm",
|
346 |
+
system_prompt="",
|
347 |
+
messages=[],
|
348 |
+
roles=("问", "答"),
|
349 |
+
prompt="问:{query}\n答:",
|
350 |
+
sep="\n",
|
351 |
+
)
|
352 |
+
)
|
353 |
+
|
354 |
+
"""ChatGLM2 template
|
355 |
+
source: https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py#L1007
|
356 |
+
"""
|
357 |
+
register_conv_template(
|
358 |
+
# source:
|
359 |
+
Conversation(
|
360 |
+
name="chatglm2",
|
361 |
+
system_prompt="",
|
362 |
+
messages=[],
|
363 |
+
roles=("问", "答"),
|
364 |
+
prompt="问:{query}\n\n答:",
|
365 |
+
sep="\n\n",
|
366 |
+
)
|
367 |
+
)
|
368 |
+
|
369 |
+
"""Phoenix template"""
|
370 |
+
register_conv_template(
|
371 |
+
Conversation(
|
372 |
+
name="phoenix",
|
373 |
+
system_prompt="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
374 |
+
messages=[],
|
375 |
+
roles=("Human", "Assistant"),
|
376 |
+
prompt="Human: <s>{query}</s>Assistant: ",
|
377 |
+
sep="</s>",
|
378 |
+
)
|
379 |
+
)
|
380 |
+
|
381 |
+
"""belle template
|
382 |
+
Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
|
383 |
+
"""
|
384 |
+
register_conv_template(
|
385 |
+
Conversation(
|
386 |
+
name="belle",
|
387 |
+
system_prompt="",
|
388 |
+
messages=[],
|
389 |
+
roles=("Human", "Belle"),
|
390 |
+
prompt="Human: {query}\n\nBelle: ",
|
391 |
+
sep="\n\n",
|
392 |
+
)
|
393 |
+
)
|
394 |
+
|
395 |
+
"""aquila template
|
396 |
+
Supports: https://huggingface.co/qhduan/aquilachat-7b
|
397 |
+
"""
|
398 |
+
register_conv_template(
|
399 |
+
Conversation(
|
400 |
+
name="aquila",
|
401 |
+
system_prompt="A chat between a curious human and an artificial intelligence assistant. "
|
402 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
403 |
+
messages=[],
|
404 |
+
roles=("Human", "Assistant"),
|
405 |
+
prompt="Human: {query}###Assistant: ",
|
406 |
+
sep="###",
|
407 |
+
)
|
408 |
+
)
|
409 |
+
|
410 |
+
"""intern template
|
411 |
+
Supports: https://huggingface.co/internlm/internlm-chat-7b
|
412 |
+
"""
|
413 |
+
register_conv_template(
|
414 |
+
Conversation(
|
415 |
+
name="intern",
|
416 |
+
system_prompt="",
|
417 |
+
messages=[],
|
418 |
+
roles=("<|User|>", "<|Bot|>"),
|
419 |
+
prompt="<|User|>:{query}<eoh>\n<|Bot|>:",
|
420 |
+
sep="<eoa>\n",
|
421 |
+
stop_str="<eoa>",
|
422 |
+
)
|
423 |
+
)
|
424 |
+
|
425 |
+
"""StarChat template"""
|
426 |
+
register_conv_template(
|
427 |
+
Conversation(
|
428 |
+
name="starchat",
|
429 |
+
system_prompt="<system>\n",
|
430 |
+
messages=[],
|
431 |
+
roles=("<|user|>", "<|assistant|>"),
|
432 |
+
prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n",
|
433 |
+
sep="<|end|>\n",
|
434 |
+
stop_str="<|end|>",
|
435 |
+
)
|
436 |
+
)
|
437 |
+
|
438 |
+
"""llama2 template
|
439 |
+
reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
|
440 |
+
"""
|
441 |
+
register_conv_template(
|
442 |
+
Conversation(
|
443 |
+
name="llama2",
|
444 |
+
system_prompt="<<SYS>>\nYou are a helpful, respectful and honest assistant. "
|
445 |
+
"Always answer as helpfully as possible, while being safe. "
|
446 |
+
"Your answers should not include any harmful, unethical, racist, sexist, "
|
447 |
+
"toxic, dangerous, or illegal content. "
|
448 |
+
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
449 |
+
"If a question does not make any sense, or is not factually coherent, "
|
450 |
+
"explain why instead of answering something not correct. "
|
451 |
+
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
|
452 |
+
messages=[],
|
453 |
+
roles=("[INST]", "[/INST]"),
|
454 |
+
prompt=" [INST] {query} [/INST] ",
|
455 |
+
sep="</s>",
|
456 |
+
)
|
457 |
+
)
|
458 |
+
|
459 |
+
"""llama2-zh template
|
460 |
+
Sources: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
|
461 |
+
Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
|
462 |
+
"""
|
463 |
+
register_conv_template(
|
464 |
+
Conversation(
|
465 |
+
name="llama2-zh",
|
466 |
+
system_prompt="<<SYS>>\nYou are a helpful assistant. 你是一个乐于助人的助手。\n<</SYS>>\n\n",
|
467 |
+
messages=[],
|
468 |
+
roles=("[INST]", "[/INST]"),
|
469 |
+
prompt=" [INST] {query} [/INST] ",
|
470 |
+
sep="</s>",
|
471 |
+
)
|
472 |
+
)
|
473 |
+
"""XVERSE template
|
474 |
+
Supports: https://huggingface.co/xverse/XVERSE-13B-Chat
|
475 |
+
"""
|
476 |
+
register_conv_template(
|
477 |
+
Conversation(
|
478 |
+
name="xverse",
|
479 |
+
system_prompt="",
|
480 |
+
messages=[],
|
481 |
+
roles=("Human", "Assistant"),
|
482 |
+
prompt="Human: {query}\n\nAssistant: ",
|
483 |
+
sep="</s>",
|
484 |
+
)
|
485 |
+
)
|
486 |
+
|
487 |
+
"""Qwen template
|
488 |
+
Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
|
489 |
+
chatml: https://xbot123.com/645a461b922f176d7cfdbc2d/
|
490 |
+
"""
|
491 |
+
register_conv_template(
|
492 |
+
Conversation(
|
493 |
+
name="chatml",
|
494 |
+
system_prompt="You are a helpful assistant.",
|
495 |
+
messages=[],
|
496 |
+
roles=("user", "assistant"),
|
497 |
+
prompt="<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n",
|
498 |
+
sep="<|im_end|>\n",
|
499 |
+
stop_str="<|im_end|>",
|
500 |
+
)
|
501 |
+
)
|
502 |
+
|
503 |
+
|
504 |
+
def get_conv_template(name: str) -> Conversation:
|
505 |
+
"""Get a conversation template."""
|
506 |
+
return conv_templates[name]
|
507 |
+
|
508 |
+
|
509 |
+
class SavePeftModelTrainer(Trainer):
|
510 |
+
"""
|
511 |
+
Trainer for lora models
|
512 |
+
"""
|
513 |
+
|
514 |
+
def save_model(self, output_dir=None, _internal_call=False):
|
515 |
+
"""Save the LoRA model."""
|
516 |
+
os.makedirs(output_dir, exist_ok=True)
|
517 |
+
if self.args.local_rank in [-1, 0]:
|
518 |
+
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
519 |
+
self.model.save_pretrained(output_dir)
|
520 |
+
|
521 |
+
|
522 |
+
def save_model(output_dir, model, tokenizer, args):
|
523 |
+
"""Save the model and the tokenizer."""
|
524 |
+
os.makedirs(output_dir, exist_ok=True)
|
525 |
+
|
526 |
+
# Take care of distributed/parallel training
|
527 |
+
model_to_save = model.module if hasattr(model, "module") else model
|
528 |
+
if args.local_rank in [-1, 0]:
|
529 |
+
model_to_save.save_pretrained(output_dir)
|
530 |
+
tokenizer.save_pretrained(output_dir)
|
531 |
+
torch.save(args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
532 |
+
|
533 |
+
|
534 |
+
def print_trainable_parameters(model):
|
535 |
+
"""
|
536 |
+
Prints the number of trainable parameters in the model.
|
537 |
+
"""
|
538 |
+
trainable_params = 0
|
539 |
+
all_param = 0
|
540 |
+
for _, param in model.named_parameters():
|
541 |
+
all_param += param.numel()
|
542 |
+
if param.requires_grad:
|
543 |
+
trainable_params += param.numel()
|
544 |
+
print(
|
545 |
+
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
|
546 |
+
)
|
547 |
+
|
548 |
+
|
549 |
+
def find_all_linear_names(peft_model, int4=False, int8=False):
|
550 |
+
"""Find all linear layer names in the model. reference from qlora paper."""
|
551 |
+
cls = torch.nn.Linear
|
552 |
+
if int4 or int8:
|
553 |
+
import bitsandbytes as bnb
|
554 |
+
if int4:
|
555 |
+
cls = bnb.nn.Linear4bit
|
556 |
+
elif int8:
|
557 |
+
cls = bnb.nn.Linear8bitLt
|
558 |
+
lora_module_names = set()
|
559 |
+
for name, module in peft_model.named_modules():
|
560 |
+
if isinstance(module, cls):
|
561 |
+
# last layer is not add to lora_module_names
|
562 |
+
if 'lm_head' in name:
|
563 |
+
continue
|
564 |
+
if 'output_layer' in name:
|
565 |
+
continue
|
566 |
+
names = name.split('.')
|
567 |
+
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
568 |
+
return sorted(lora_module_names)
|
569 |
+
|
570 |
+
|
571 |
+
def main():
|
572 |
+
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, PeftArguments))
|
573 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
574 |
+
|
575 |
+
logger.info(f"Model args: {model_args}")
|
576 |
+
logger.info(f"Data args: {data_args}")
|
577 |
+
logger.info(f"Training args: {training_args}")
|
578 |
+
logger.info(
|
579 |
+
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
580 |
+
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
581 |
+
)
|
582 |
+
|
583 |
+
# Set seed before initializing model.
|
584 |
+
set_seed(training_args.seed)
|
585 |
+
|
586 |
+
if not model_args.model_type:
|
587 |
+
raise ValueError("Please specify a model_type, e.g. llama, chatglm, bloom, etc.")
|
588 |
+
config_class, model_class, tokenizer_class = MODEL_CLASSES[model_args.model_type]
|
589 |
+
|
590 |
+
# Load tokenizer
|
591 |
+
tokenizer_kwargs = {
|
592 |
+
"cache_dir": model_args.cache_dir,
|
593 |
+
"use_fast": model_args.use_fast_tokenizer,
|
594 |
+
"trust_remote_code": model_args.trust_remote_code,
|
595 |
+
}
|
596 |
+
tokenizer_name_or_path = model_args.tokenizer_name_or_path
|
597 |
+
if not tokenizer_name_or_path:
|
598 |
+
tokenizer_name_or_path = model_args.model_name_or_path
|
599 |
+
tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
|
600 |
+
prompt_template = get_conv_template(data_args.template_name)
|
601 |
+
if tokenizer.eos_token_id is None:
|
602 |
+
tokenizer.eos_token = prompt_template.stop_str # eos token is required for SFT
|
603 |
+
logger.info("Add eos token: {}".format(tokenizer.eos_token))
|
604 |
+
if tokenizer.pad_token_id is None:
|
605 |
+
if tokenizer.unk_token_id is not None:
|
606 |
+
tokenizer.pad_token = tokenizer.unk_token
|
607 |
+
else:
|
608 |
+
tokenizer.pad_token = tokenizer.eos_token
|
609 |
+
logger.info("Add pad token: {}".format(tokenizer.pad_token))
|
610 |
+
|
611 |
+
logger.debug(f"Tokenizer: {tokenizer}")
|
612 |
+
IGNORE_INDEX = LabelSmoother.ignore_index if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
613 |
+
|
614 |
+
# Get datasets
|
615 |
+
if data_args.dataset_name is not None:
|
616 |
+
# Downloading and loading a dataset from the hub.
|
617 |
+
raw_datasets = load_dataset(
|
618 |
+
data_args.dataset_name,
|
619 |
+
data_args.dataset_config_name,
|
620 |
+
cache_dir=model_args.cache_dir,
|
621 |
+
)
|
622 |
+
if "validation" not in raw_datasets.keys():
|
623 |
+
raw_datasets["validation"] = load_dataset(
|
624 |
+
data_args.dataset_name,
|
625 |
+
data_args.dataset_config_name,
|
626 |
+
split=f"train[:{data_args.validation_split_percentage}%]",
|
627 |
+
cache_dir=model_args.cache_dir,
|
628 |
+
)
|
629 |
+
raw_datasets["train"] = load_dataset(
|
630 |
+
data_args.dataset_name,
|
631 |
+
data_args.dataset_config_name,
|
632 |
+
split=f"train[{data_args.validation_split_percentage}%:]",
|
633 |
+
cache_dir=model_args.cache_dir,
|
634 |
+
)
|
635 |
+
else:
|
636 |
+
# Loading a dataset from local files.
|
637 |
+
data_files = {}
|
638 |
+
if data_args.train_file_dir is not None and os.path.exists(data_args.train_file_dir):
|
639 |
+
train_data_files = glob(f'{data_args.train_file_dir}/**/*.json', recursive=True) + glob(
|
640 |
+
f'{data_args.train_file_dir}/**/*.jsonl', recursive=True)
|
641 |
+
logger.info(f"train files: {train_data_files}")
|
642 |
+
data_files["train"] = train_data_files
|
643 |
+
if data_args.validation_file_dir is not None and os.path.exists(data_args.validation_file_dir):
|
644 |
+
eval_data_files = glob(f'{data_args.validation_file_dir}/**/*.json', recursive=True) + glob(
|
645 |
+
f'{data_args.validation_file_dir}/**/*.jsonl', recursive=True)
|
646 |
+
logger.info(f"eval files: {eval_data_files}")
|
647 |
+
data_files["validation"] = eval_data_files
|
648 |
+
raw_datasets = load_dataset(
|
649 |
+
'json',
|
650 |
+
data_files=data_files,
|
651 |
+
cache_dir=model_args.cache_dir,
|
652 |
+
)
|
653 |
+
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
|
654 |
+
if "validation" not in raw_datasets.keys():
|
655 |
+
raw_datasets["validation"] = load_dataset(
|
656 |
+
'json',
|
657 |
+
data_files=data_files,
|
658 |
+
split=f"train[:{data_args.validation_split_percentage}%]",
|
659 |
+
cache_dir=model_args.cache_dir,
|
660 |
+
)
|
661 |
+
raw_datasets["train"] = load_dataset(
|
662 |
+
'json',
|
663 |
+
data_files=data_files,
|
664 |
+
split=f"train[{data_args.validation_split_percentage}%:]",
|
665 |
+
cache_dir=model_args.cache_dir,
|
666 |
+
)
|
667 |
+
logger.info(f"Raw datasets: {raw_datasets}")
|
668 |
+
|
669 |
+
# Preprocessing the datasets
|
670 |
+
max_source_length = data_args.max_source_length
|
671 |
+
max_target_length = data_args.max_target_length
|
672 |
+
max_length = max_source_length + max_target_length
|
673 |
+
|
674 |
+
def preprocess_function(examples):
|
675 |
+
"""
|
676 |
+
Preprocessing the datasets.
|
677 |
+
part of code modified from https://github.com/lm-sys/FastChat
|
678 |
+
"""
|
679 |
+
input_ids_list = []
|
680 |
+
targets_list = []
|
681 |
+
roles = ["human", "gpt"]
|
682 |
+
|
683 |
+
def get_dialog(examples):
|
684 |
+
for i, source in enumerate(examples['conversations']):
|
685 |
+
if len(source) < 2:
|
686 |
+
continue
|
687 |
+
data_role = source[0].get("from", "")
|
688 |
+
if data_role not in roles or data_role != roles[0]:
|
689 |
+
# Skip the first one if it is not from human
|
690 |
+
source = source[1:]
|
691 |
+
if len(source) < 2:
|
692 |
+
continue
|
693 |
+
messages = []
|
694 |
+
for j, sentence in enumerate(source):
|
695 |
+
data_role = sentence.get("from", "")
|
696 |
+
if data_role not in roles:
|
697 |
+
logger.warning(f"unknown role: {data_role}, {i}. (ignored)")
|
698 |
+
break
|
699 |
+
if data_role == roles[j % 2]:
|
700 |
+
messages.append(sentence["value"])
|
701 |
+
if len(messages) < 2 or len(messages) % 2 != 0:
|
702 |
+
continue
|
703 |
+
# Convert the list to pairs of elements
|
704 |
+
history_messages = [[messages[k], messages[k + 1]] for k in range(0, len(messages), 2)]
|
705 |
+
yield prompt_template.get_dialog(history_messages)
|
706 |
+
|
707 |
+
for dialog in get_dialog(examples):
|
708 |
+
input_ids, labels = [], []
|
709 |
+
|
710 |
+
for i in range(len(dialog) // 2):
|
711 |
+
source_ids = tokenizer.encode(text=dialog[2 * i], add_special_tokens=(i == 0))
|
712 |
+
target_ids = tokenizer.encode(text=dialog[2 * i + 1], add_special_tokens=False)
|
713 |
+
|
714 |
+
if len(source_ids) > max_source_length:
|
715 |
+
source_ids = source_ids[:max_source_length]
|
716 |
+
if len(target_ids) > max_target_length - 1: # eos token
|
717 |
+
target_ids = target_ids[:max_target_length - 1]
|
718 |
+
if len(source_ids) > 0 and source_ids[0] == tokenizer.eos_token_id:
|
719 |
+
source_ids = source_ids[1:]
|
720 |
+
if len(target_ids) > 0 and target_ids[-1] == tokenizer.eos_token_id:
|
721 |
+
target_ids = target_ids[:-1]
|
722 |
+
if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
|
723 |
+
break
|
724 |
+
|
725 |
+
input_ids += source_ids + target_ids + [tokenizer.eos_token_id] # add eos token for each turn
|
726 |
+
labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
|
727 |
+
|
728 |
+
input_ids_list.append(input_ids)
|
729 |
+
targets_list.append(labels)
|
730 |
+
|
731 |
+
return dict(
|
732 |
+
input_ids=input_ids_list,
|
733 |
+
labels=targets_list,
|
734 |
+
)
|
735 |
+
|
736 |
+
def filter_empty_labels(example):
|
737 |
+
"""Remove empty labels dataset."""
|
738 |
+
return not all(label == IGNORE_INDEX for label in example["labels"])
|
739 |
+
|
740 |
+
train_dataset = None
|
741 |
+
max_train_samples = 0
|
742 |
+
if training_args.do_train:
|
743 |
+
if "train" not in raw_datasets:
|
744 |
+
raise ValueError("--do_train requires a train dataset")
|
745 |
+
train_dataset = raw_datasets['train']
|
746 |
+
max_train_samples = len(train_dataset)
|
747 |
+
if data_args.max_train_samples is not None and data_args.max_train_samples > 0:
|
748 |
+
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
749 |
+
train_dataset = train_dataset.select(range(max_train_samples))
|
750 |
+
logger.debug(f"Example train_dataset[0]: {train_dataset[0]}")
|
751 |
+
with training_args.main_process_first(desc="Train dataset tokenization"):
|
752 |
+
train_dataset = train_dataset.shuffle().map(
|
753 |
+
preprocess_function,
|
754 |
+
batched=True,
|
755 |
+
num_proc=data_args.preprocessing_num_workers,
|
756 |
+
remove_columns=train_dataset.column_names,
|
757 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
758 |
+
desc="Running tokenizer on dataset",
|
759 |
+
)
|
760 |
+
train_dataset = train_dataset.filter(filter_empty_labels, num_proc=data_args.preprocessing_num_workers)
|
761 |
+
logger.debug(f"Num train_samples: {len(train_dataset)}")
|
762 |
+
logger.debug("Tokenized training example:")
|
763 |
+
logger.debug(f"Decode input_ids[0]: {tokenizer.decode(train_dataset[0]['input_ids'])}")
|
764 |
+
replaced_labels = [label if label != IGNORE_INDEX else tokenizer.pad_token_id
|
765 |
+
for label in list(train_dataset[0]['labels'])]
|
766 |
+
logger.debug(f"Decode labels[0]: {tokenizer.decode(replaced_labels)}")
|
767 |
+
|
768 |
+
eval_dataset = None
|
769 |
+
max_eval_samples = 0
|
770 |
+
if training_args.do_eval:
|
771 |
+
with training_args.main_process_first(desc="Eval dataset tokenization"):
|
772 |
+
if "validation" not in raw_datasets:
|
773 |
+
raise ValueError("--do_eval requires a validation dataset")
|
774 |
+
eval_dataset = raw_datasets["validation"]
|
775 |
+
max_eval_samples = len(eval_dataset)
|
776 |
+
if data_args.max_eval_samples is not None and data_args.max_eval_samples > 0:
|
777 |
+
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
778 |
+
eval_dataset = eval_dataset.select(range(max_eval_samples))
|
779 |
+
logger.debug(f"Example eval_dataset[0]: {eval_dataset[0]}")
|
780 |
+
eval_dataset = eval_dataset.map(
|
781 |
+
preprocess_function,
|
782 |
+
batched=True,
|
783 |
+
num_proc=data_args.preprocessing_num_workers,
|
784 |
+
remove_columns=eval_dataset.column_names,
|
785 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
786 |
+
desc="Running tokenizer on dataset",
|
787 |
+
)
|
788 |
+
eval_dataset = eval_dataset.filter(filter_empty_labels, num_proc=data_args.preprocessing_num_workers)
|
789 |
+
logger.debug(f"Num eval_samples: {len(eval_dataset)}")
|
790 |
+
logger.debug("Tokenized eval example:")
|
791 |
+
logger.debug(tokenizer.decode(eval_dataset[0]['input_ids']))
|
792 |
+
|
793 |
+
# Load model
|
794 |
+
if model_args.model_name_or_path:
|
795 |
+
torch_dtype = (
|
796 |
+
model_args.torch_dtype
|
797 |
+
if model_args.torch_dtype in ["auto", None]
|
798 |
+
else getattr(torch, model_args.torch_dtype)
|
799 |
+
)
|
800 |
+
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
801 |
+
ddp = world_size != 1
|
802 |
+
if ddp:
|
803 |
+
model_args.device_map = {"": int(os.environ["LOCAL_RANK"]) or 0}
|
804 |
+
if training_args.qlora and (len(training_args.fsdp) > 0 or is_deepspeed_zero3_enabled()):
|
805 |
+
logger.warning("FSDP and ZeRO3 are both currently incompatible with QLoRA.")
|
806 |
+
config = config_class.from_pretrained(
|
807 |
+
model_args.model_name_or_path,
|
808 |
+
trust_remote_code=model_args.trust_remote_code,
|
809 |
+
torch_dtype=torch_dtype,
|
810 |
+
cache_dir=model_args.cache_dir
|
811 |
+
)
|
812 |
+
model = model_class.from_pretrained(
|
813 |
+
model_args.model_name_or_path,
|
814 |
+
config=config,
|
815 |
+
load_in_8bit=model_args.load_in_8bit,
|
816 |
+
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
|
817 |
+
device_map=model_args.device_map,
|
818 |
+
trust_remote_code=model_args.trust_remote_code,
|
819 |
+
quantization_config=BitsAndBytesConfig(
|
820 |
+
load_in_4bit=True,
|
821 |
+
bnb_4bit_use_double_quant=True,
|
822 |
+
bnb_4bit_quant_type="nf4",
|
823 |
+
bnb_4bit_compute_dtype=torch_dtype,
|
824 |
+
) if training_args.qlora else None,
|
825 |
+
)
|
826 |
+
if hasattr(model, 'lm_head'):
|
827 |
+
model.lm_head = CastOutputToFloat(model.lm_head)
|
828 |
+
else:
|
829 |
+
raise ValueError(f"Error, model_name_or_path is None, SFT must be loaded from a pre-trained model")
|
830 |
+
|
831 |
+
if training_args.use_peft:
|
832 |
+
logger.info("Fine-tuning method: LoRA(PEFT)")
|
833 |
+
if training_args.peft_path is not None:
|
834 |
+
logger.info(f"Peft from pre-trained model: {training_args.peft_path}")
|
835 |
+
model = PeftModel.from_pretrained(model, training_args.peft_path, is_trainable=True)
|
836 |
+
else:
|
837 |
+
target_modules = training_args.target_modules.split(',') if training_args.target_modules else None
|
838 |
+
if target_modules and 'all' in target_modules:
|
839 |
+
target_modules = find_all_linear_names(model, int4=False, int8=model_args.load_in_8bit)
|
840 |
+
modules_to_save = training_args.modules_to_save
|
841 |
+
if modules_to_save is not None:
|
842 |
+
modules_to_save = modules_to_save.split(',')
|
843 |
+
logger.info(f"Peft target_modules: {target_modules}")
|
844 |
+
logger.info(f"Peft lora_rank: {training_args.lora_rank}")
|
845 |
+
peft_config = LoraConfig(
|
846 |
+
task_type=TaskType.CAUSAL_LM,
|
847 |
+
target_modules=target_modules,
|
848 |
+
inference_mode=False,
|
849 |
+
r=training_args.lora_rank,
|
850 |
+
lora_alpha=training_args.lora_alpha,
|
851 |
+
lora_dropout=training_args.lora_dropout,
|
852 |
+
modules_to_save=modules_to_save)
|
853 |
+
model = get_peft_model(model, peft_config)
|
854 |
+
if model_args.load_in_8bit:
|
855 |
+
model = prepare_model_for_int8_training(model)
|
856 |
+
model.print_trainable_parameters()
|
857 |
+
else:
|
858 |
+
logger.info("Fine-tuning method: Full parameters training")
|
859 |
+
model = model.float()
|
860 |
+
print_trainable_parameters(model)
|
861 |
+
logger.debug(f"Model: {model}")
|
862 |
+
|
863 |
+
# Initialize our Trainer
|
864 |
+
if training_args.gradient_checkpointing:
|
865 |
+
model.gradient_checkpointing_enable()
|
866 |
+
model.config.use_cache = False
|
867 |
+
else:
|
868 |
+
model.config.use_cache = True
|
869 |
+
model.enable_input_require_grads()
|
870 |
+
if not ddp and torch.cuda.device_count() > 1:
|
871 |
+
# Keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
|
872 |
+
model.is_parallelizable = True
|
873 |
+
model.model_parallel = True
|
874 |
+
|
875 |
+
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
|
876 |
+
# Initialize our Trainer
|
877 |
+
trainer = SavePeftModelTrainer(
|
878 |
+
model=model,
|
879 |
+
args=training_args,
|
880 |
+
train_dataset=train_dataset if training_args.do_train else None,
|
881 |
+
eval_dataset=eval_dataset if training_args.do_eval else None,
|
882 |
+
tokenizer=tokenizer,
|
883 |
+
data_collator=data_collator,
|
884 |
+
)
|
885 |
+
|
886 |
+
# Training
|
887 |
+
if training_args.do_train:
|
888 |
+
logger.info("*** Train ***")
|
889 |
+
sample = next(iter(trainer.get_train_dataloader()))
|
890 |
+
logger.debug(f"Train dataloader example: {sample}")
|
891 |
+
logger.debug(f"Detail input_ids: {list(sample['input_ids'])[:3]}, \nlabels: {list(sample['labels'])[:3]}")
|
892 |
+
logger.debug(f"Decode input_ids[0]: {tokenizer.decode(sample['input_ids'][0])}")
|
893 |
+
replaced_labels = [label if label != IGNORE_INDEX else tokenizer.pad_token_id for label in sample['labels'][0]]
|
894 |
+
logger.debug(f"Decode labels[0]: {tokenizer.decode(replaced_labels)}")
|
895 |
+
checkpoint = None
|
896 |
+
if training_args.resume_from_checkpoint is not None:
|
897 |
+
checkpoint = training_args.resume_from_checkpoint
|
898 |
+
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
899 |
+
|
900 |
+
metrics = train_result.metrics
|
901 |
+
metrics["train_samples"] = max_train_samples
|
902 |
+
logger.debug(f"Training metrics: {metrics}")
|
903 |
+
trainer.log_metrics("train", metrics)
|
904 |
+
trainer.save_metrics("train", metrics)
|
905 |
+
model.config.use_cache = True # enable cache after training
|
906 |
+
trainer.save_state()
|
907 |
+
logger.info(f"Saving model checkpoint to {training_args.output_dir}")
|
908 |
+
save_model(training_args.output_dir, model, tokenizer, training_args)
|
909 |
+
|
910 |
+
# Evaluation
|
911 |
+
if training_args.do_eval and trainer.is_world_process_zero():
|
912 |
+
logger.info("*** Evaluate ***")
|
913 |
+
metrics = trainer.evaluate()
|
914 |
+
|
915 |
+
metrics["eval_samples"] = max_eval_samples
|
916 |
+
try:
|
917 |
+
perplexity = math.exp(metrics["eval_loss"])
|
918 |
+
except OverflowError:
|
919 |
+
perplexity = float("inf")
|
920 |
+
metrics["perplexity"] = perplexity
|
921 |
+
logger.debug(f"Eval metrics: {metrics}")
|
922 |
+
trainer.log_metrics("eval", metrics)
|
923 |
+
trainer.save_metrics("eval", metrics)
|
924 |
+
|
925 |
+
|
926 |
+
if __name__ == "__main__":
|
927 |
+
main()
|