kisejin
commited on
Commit
·
795c49e
1
Parent(s):
02e64cb
initial: create version skipbert for mates
Browse files- template_FL/.gitignore +175 -0
- template_FL/LICENSE +21 -0
- template_FL/README.md +1 -0
- template_FL/requirements.txt +36 -0
- template_FL/src/environment.yml +565 -0
- template_FL/src/ex.env.example +3 -0
- template_FL/src/fedllm/Untitled.ipynb +861 -0
- template_FL/src/fedllm/__init__.py +1 -0
- template_FL/src/fedllm/client_app.py +335 -0
- template_FL/src/fedllm/data_domains.py +281 -0
- template_FL/src/fedllm/dataset.py +122 -0
- template_FL/src/fedllm/flwr_mods.py +49 -0
- template_FL/src/fedllm/make_data.py +101 -0
- template_FL/src/fedllm/metrics.py +73 -0
- template_FL/src/fedllm/models.py +200 -0
- template_FL/src/fedllm/myaggregation.py +416 -0
- template_FL/src/fedllm/myfedavg.py +295 -0
- template_FL/src/fedllm/server_app.py +309 -0
- template_FL/src/fedllm/skipbert/__init__.py +0 -0
- template_FL/src/fedllm/skipbert/modeling.py +922 -0
- template_FL/src/fedllm/skipbert/plot.py +173 -0
- template_FL/src/fedllm/templates/alpaca.json +6 -0
- template_FL/src/fedllm/trainer.py +476 -0
- template_FL/src/fedllm/utils.py +54 -0
- template_FL/src/pyproject.toml +180 -0
- template_FL/src/tesst.ipynb +40 -0
template_FL/.gitignore
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# UV
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
#uv.lock
|
102 |
+
|
103 |
+
# poetry
|
104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
106 |
+
# commonly ignored for libraries.
|
107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
108 |
+
#poetry.lock
|
109 |
+
|
110 |
+
# pdm
|
111 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
112 |
+
#pdm.lock
|
113 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
114 |
+
# in version control.
|
115 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
116 |
+
.pdm.toml
|
117 |
+
.pdm-python
|
118 |
+
.pdm-build/
|
119 |
+
|
120 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
121 |
+
__pypackages__/
|
122 |
+
|
123 |
+
# Celery stuff
|
124 |
+
celerybeat-schedule
|
125 |
+
celerybeat.pid
|
126 |
+
|
127 |
+
# SageMath parsed files
|
128 |
+
*.sage.py
|
129 |
+
|
130 |
+
# Environments
|
131 |
+
.env
|
132 |
+
.venv
|
133 |
+
env/
|
134 |
+
venv/
|
135 |
+
ENV/
|
136 |
+
env.bak/
|
137 |
+
venv.bak/
|
138 |
+
|
139 |
+
# Spyder project settings
|
140 |
+
.spyderproject
|
141 |
+
.spyproject
|
142 |
+
|
143 |
+
# Rope project settings
|
144 |
+
.ropeproject
|
145 |
+
|
146 |
+
# mkdocs documentation
|
147 |
+
/site
|
148 |
+
|
149 |
+
# mypy
|
150 |
+
.mypy_cache/
|
151 |
+
.dmypy.json
|
152 |
+
dmypy.json
|
153 |
+
|
154 |
+
# Pyre type checker
|
155 |
+
.pyre/
|
156 |
+
|
157 |
+
# pytype static type analyzer
|
158 |
+
.pytype/
|
159 |
+
|
160 |
+
# Cython debug symbols
|
161 |
+
cython_debug/
|
162 |
+
|
163 |
+
# PyCharm
|
164 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
165 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
166 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
167 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
168 |
+
#.idea/
|
169 |
+
|
170 |
+
# PyPI configuration file
|
171 |
+
.pypirc
|
172 |
+
|
173 |
+
# Remove uncenssary output files
|
174 |
+
results/
|
175 |
+
wandb/
|
template_FL/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2025 KiseJin
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
template_FL/README.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# template_FL
|
template_FL/requirements.txt
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
bitsandbytes==0.45.0
|
2 |
+
cryptography==42.0.8
|
3 |
+
cupy-cuda12x==13.3.0
|
4 |
+
docker-pycreds==0.4.0
|
5 |
+
faiss-cpu==1.9.0.post1
|
6 |
+
fastrlock==0.8.3
|
7 |
+
# flash-attn==2.6.3
|
8 |
+
flwr==1.14.0
|
9 |
+
flwr-datasets==0.5.0
|
10 |
+
fsspec<=2024.10.0,>=2023.1.0
|
11 |
+
gitdb==4.0.12
|
12 |
+
gitpython==3.1.44
|
13 |
+
grpcio==1.64.3
|
14 |
+
iterators==0.0.2
|
15 |
+
multiprocess==0.70.16
|
16 |
+
pathspec==0.12.1
|
17 |
+
protobuf==4.25.5
|
18 |
+
prv-accountant==0.2.0
|
19 |
+
pycryptodome==3.21.0
|
20 |
+
ray==2.10.0
|
21 |
+
rouge-score==0.1.2
|
22 |
+
sentry-sdk==2.19.2
|
23 |
+
setproctitle==1.3.4
|
24 |
+
shellingham==1.5.4
|
25 |
+
smmap==5.0.2
|
26 |
+
sympy==1.13.1
|
27 |
+
thop==0.1.1-2209072238
|
28 |
+
tomli-w==1.1.0
|
29 |
+
typer==0.12.5
|
30 |
+
wandb==0.19.3
|
31 |
+
python-dotenv
|
32 |
+
omegaconf
|
33 |
+
trl
|
34 |
+
evaluate
|
35 |
+
google
|
36 |
+
deepspeed
|
template_FL/src/environment.yml
ADDED
@@ -0,0 +1,565 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: fedllm
|
2 |
+
channels:
|
3 |
+
- xformers
|
4 |
+
- pytorch
|
5 |
+
- nvidia
|
6 |
+
- defaults
|
7 |
+
- conda-forge
|
8 |
+
- https://repo.anaconda.com/pkgs/main
|
9 |
+
- https://repo.anaconda.com/pkgs/r
|
10 |
+
dependencies:
|
11 |
+
- _libgcc_mutex=0.1=conda_forge
|
12 |
+
- _openmp_mutex=4.5=2_gnu
|
13 |
+
- about-time=4.2.1=pyhd8ed1ab_1
|
14 |
+
- absl-py=2.1.0=pyhd8ed1ab_1
|
15 |
+
- accelerate=1.2.1=pyhd8ed1ab_1
|
16 |
+
- aiohappyeyeballs=2.4.4=pyhd8ed1ab_1
|
17 |
+
- aiohttp=3.11.11=py311h2dc5d0c_0
|
18 |
+
- aiosignal=1.3.2=pyhd8ed1ab_0
|
19 |
+
- alive-progress=3.2.0=pyhd8ed1ab_0
|
20 |
+
- alsa-lib=1.2.8=h166bdaf_0
|
21 |
+
- annotated-types=0.7.0=pyhd8ed1ab_1
|
22 |
+
- antlr-python-runtime=4.9.3=pyhd8ed1ab_1
|
23 |
+
- anyio=4.8.0=pyhd8ed1ab_0
|
24 |
+
- aom=3.5.0=h27087fc_0
|
25 |
+
- argon2-cffi=23.1.0=pyhd8ed1ab_1
|
26 |
+
- argon2-cffi-bindings=21.2.0=py311h9ecbd09_5
|
27 |
+
- arrow=1.3.0=pyhd8ed1ab_1
|
28 |
+
- asttokens=3.0.0=pyhd8ed1ab_1
|
29 |
+
- async-lru=2.0.4=pyhd8ed1ab_1
|
30 |
+
- async-timeout=4.0.3=pyhd8ed1ab_0
|
31 |
+
- attr=2.5.1=h166bdaf_1
|
32 |
+
- attrs=24.3.0=pyh71513ae_0
|
33 |
+
- autograd=1.7.0=pyhd8ed1ab_1
|
34 |
+
- aws-c-auth=0.7.3=he2921ad_3
|
35 |
+
- aws-c-cal=0.6.2=hc309b26_1
|
36 |
+
- aws-c-common=0.9.0=hd590300_0
|
37 |
+
- aws-c-compression=0.2.17=h4d4d85c_2
|
38 |
+
- aws-c-event-stream=0.3.2=h2e3709c_0
|
39 |
+
- aws-c-http=0.7.12=hc865f51_1
|
40 |
+
- aws-c-io=0.13.32=h1a03231_3
|
41 |
+
- aws-c-mqtt=0.9.5=h3a0376c_1
|
42 |
+
- aws-c-s3=0.3.14=h1678ad6_3
|
43 |
+
- aws-c-sdkutils=0.1.12=h4d4d85c_1
|
44 |
+
- aws-checksums=0.1.17=h4d4d85c_1
|
45 |
+
- aws-crt-cpp=0.23.0=h40cdbb9_5
|
46 |
+
- aws-sdk-cpp=1.10.57=h6f6b8fa_21
|
47 |
+
- babel=2.16.0=pyhd8ed1ab_1
|
48 |
+
- beautifulsoup4=4.12.3=pyha770c72_1
|
49 |
+
- binaryornot=0.4.4=pyhd8ed1ab_2
|
50 |
+
- blas=1.0=mkl
|
51 |
+
- bleach=6.2.0=pyhd8ed1ab_3
|
52 |
+
- bleach-with-css=6.2.0=hd8ed1ab_3
|
53 |
+
- blosc=1.21.5=h0f2a231_0
|
54 |
+
- boltons=24.0.0=pyhd8ed1ab_1
|
55 |
+
- brotli=1.0.9=h166bdaf_9
|
56 |
+
- brotli-bin=1.0.9=h166bdaf_9
|
57 |
+
- brotli-python=1.0.9=py311ha362b79_9
|
58 |
+
- brunsli=0.1=h9c3ff4c_0
|
59 |
+
- bzip2=1.0.8=h5eee18b_6
|
60 |
+
- c-ares=1.34.4=hb9d3cd8_0
|
61 |
+
- c-blosc2=2.12.0=hb4ffafa_0
|
62 |
+
- ca-certificates=2024.12.31=h06a4308_0
|
63 |
+
- cached-property=1.5.2=hd8ed1ab_1
|
64 |
+
- cached_property=1.5.2=pyha770c72_1
|
65 |
+
- cachetools=5.5.0=pyhd8ed1ab_1
|
66 |
+
- cairo=1.16.0=ha61ee94_1014
|
67 |
+
- certifi=2024.12.14=pyhd8ed1ab_0
|
68 |
+
- cffi=1.17.1=py311hf29c0ef_0
|
69 |
+
- cfitsio=4.2.0=hd9d235c_0
|
70 |
+
- chardet=5.2.0=py311h38be061_2
|
71 |
+
- charls=2.4.2=h59595ed_0
|
72 |
+
- charset-normalizer=3.4.1=pyhd8ed1ab_0
|
73 |
+
- click=8.1.8=pyh707e725_0
|
74 |
+
- cma=3.2.2=pyh050c7b8_0
|
75 |
+
- colorama=0.4.6=pyhd8ed1ab_1
|
76 |
+
- comm=0.2.2=pyhd8ed1ab_1
|
77 |
+
- conda=23.7.4=py311h38be061_0
|
78 |
+
- conda-package-handling=2.4.0=pyh7900ff3_2
|
79 |
+
- conda-package-streaming=0.11.0=pyhd8ed1ab_0
|
80 |
+
- contourpy=1.3.1=py311hd18a35c_0
|
81 |
+
- cookiecutter=2.6.0=pyhd8ed1ab_1
|
82 |
+
- cpp-expected=1.1.0=hf52228f_0
|
83 |
+
- cuda-cudart=12.4.127=0
|
84 |
+
- cuda-cupti=12.4.127=0
|
85 |
+
- cuda-libraries=12.4.1=0
|
86 |
+
- cuda-nvrtc=12.4.127=0
|
87 |
+
- cuda-nvtx=12.4.127=0
|
88 |
+
- cuda-opencl=12.6.77=0
|
89 |
+
- cuda-runtime=12.4.1=0
|
90 |
+
- cuda-version=12.6=3
|
91 |
+
- cycler=0.12.1=pyhd8ed1ab_1
|
92 |
+
- dataclasses=0.8=pyhc8e2a94_3
|
93 |
+
- dataclasses-json=0.6.7=pyhd8ed1ab_1
|
94 |
+
- datasets=2.19.2=pyhd8ed1ab_0
|
95 |
+
- dav1d=1.2.1=hd590300_0
|
96 |
+
- dbus=1.13.6=h5008d03_3
|
97 |
+
- debugpy=1.8.11=py311hfdbb021_0
|
98 |
+
- decorator=5.1.1=pyhd8ed1ab_1
|
99 |
+
- deepspeed=0.16.2=cpu_py311hd0a74d0_0
|
100 |
+
- defusedxml=0.7.1=pyhd8ed1ab_0
|
101 |
+
- deprecated=1.2.15=pyhd8ed1ab_1
|
102 |
+
- dill=0.3.8=pyhd8ed1ab_0
|
103 |
+
- distro=1.9.0=pyhd8ed1ab_1
|
104 |
+
- docstring_parser=0.16=pyhd8ed1ab_0
|
105 |
+
- einops=0.8.0=pyhd8ed1ab_1
|
106 |
+
- entrypoints=0.4=pyhd8ed1ab_1
|
107 |
+
- eval_type_backport=0.2.2=pyha770c72_0
|
108 |
+
- evaluate=0.4.1=pyhd8ed1ab_0
|
109 |
+
- exceptiongroup=1.2.2=pyhd8ed1ab_1
|
110 |
+
- executing=2.1.0=pyhd8ed1ab_1
|
111 |
+
- expat=2.6.4=h5888daf_0
|
112 |
+
- ffmpeg=4.3=hf484d3e_0
|
113 |
+
- fftw=3.3.10=nompi_hf1063bd_110
|
114 |
+
- filelock=3.16.1=pyhd8ed1ab_1
|
115 |
+
- fmt=10.2.1=h00ab1b0_0
|
116 |
+
- font-ttf-dejavu-sans-mono=2.37=hab24e00_0
|
117 |
+
- font-ttf-inconsolata=3.000=h77eed37_0
|
118 |
+
- font-ttf-source-code-pro=2.038=h77eed37_0
|
119 |
+
- font-ttf-ubuntu=0.83=h77eed37_3
|
120 |
+
- fontconfig=2.14.2=h14ed4e7_0
|
121 |
+
- fonts-conda-ecosystem=1=0
|
122 |
+
- fonts-conda-forge=1=0
|
123 |
+
- fonttools=4.55.3=py311h2dc5d0c_1
|
124 |
+
- fqdn=1.5.1=pyhd8ed1ab_1
|
125 |
+
- freetype=2.12.1=h267a509_2
|
126 |
+
- frozendict=2.4.6=py311h9ecbd09_0
|
127 |
+
- frozenlist=1.5.0=py311h9ecbd09_0
|
128 |
+
- functorch=2.0.0=pyhd8ed1ab_0
|
129 |
+
- fvcore=0.1.5.post20221221=pyhd8ed1ab_0
|
130 |
+
- gdown=5.2.0=pyhd8ed1ab_1
|
131 |
+
- gettext=0.22.5=he02047a_3
|
132 |
+
- gettext-tools=0.22.5=he02047a_3
|
133 |
+
- gflags=2.2.2=h5888daf_1005
|
134 |
+
- giflib=5.2.2=hd590300_0
|
135 |
+
- glib=2.78.4=hfc55251_0
|
136 |
+
- glib-tools=2.78.4=hfc55251_0
|
137 |
+
- glog=0.6.0=h6f12383_0
|
138 |
+
- gmp=6.3.0=hac33072_2
|
139 |
+
- gmpy2=2.1.5=py311h0f6cedb_3
|
140 |
+
- gnutls=3.6.13=h85f3911_1
|
141 |
+
- grapheme=0.6.0=pyhd8ed1ab_1
|
142 |
+
- graphite2=1.3.13=h59595ed_1003
|
143 |
+
- greenlet=3.1.1=py311hfdbb021_1
|
144 |
+
- gst-plugins-base=1.22.0=h4243ec0_2
|
145 |
+
- gstreamer=1.22.0=h25f0c4b_2
|
146 |
+
- gstreamer-orc=0.4.40=hb9d3cd8_0
|
147 |
+
- h11=0.14.0=pyhd8ed1ab_1
|
148 |
+
- h2=4.1.0=pyhd8ed1ab_1
|
149 |
+
- harfbuzz=6.0.0=h8e241bc_0
|
150 |
+
- hjson-py=3.1.0=pyhd8ed1ab_1
|
151 |
+
- hpack=4.0.0=pyhd8ed1ab_1
|
152 |
+
- httpcore=1.0.7=pyh29332c3_1
|
153 |
+
- httpx=0.28.1=pyhd8ed1ab_0
|
154 |
+
- huggingface_hub=0.27.1=pyhd8ed1ab_0
|
155 |
+
- hyperframe=6.0.1=pyhd8ed1ab_1
|
156 |
+
- icu=70.1=h27087fc_0
|
157 |
+
- idna=3.10=pyhd8ed1ab_1
|
158 |
+
- imagecodecs=2023.1.23=py311ha5a3c35_0
|
159 |
+
- imageio=2.36.1=pyh12aca89_1
|
160 |
+
- importlib-metadata=8.5.0=pyha770c72_1
|
161 |
+
- importlib_metadata=8.5.0=hd8ed1ab_1
|
162 |
+
- importlib_resources=6.5.2=pyhd8ed1ab_0
|
163 |
+
- intel-openmp=2023.1.0=hdb19cb5_46306
|
164 |
+
- ipykernel=6.29.5=pyh3099207_0
|
165 |
+
- ipython=8.31.0=pyh707e725_0
|
166 |
+
- ipywidgets=8.1.5=pyhd8ed1ab_1
|
167 |
+
- isoduration=20.11.0=pyhd8ed1ab_1
|
168 |
+
- jack=1.9.22=h11f4161_0
|
169 |
+
- jedi=0.19.2=pyhd8ed1ab_1
|
170 |
+
- jinja2=3.1.5=pyhd8ed1ab_0
|
171 |
+
- jiter=0.8.2=py311h9e33e62_0
|
172 |
+
- joblib=1.4.2=pyhd8ed1ab_1
|
173 |
+
- jpeg=9e=h166bdaf_2
|
174 |
+
- json5=0.10.0=pyhd8ed1ab_1
|
175 |
+
- jsonpatch=1.33=pyhd8ed1ab_1
|
176 |
+
- jsonpointer=3.0.0=py311h38be061_1
|
177 |
+
- jsonschema=4.23.0=pyhd8ed1ab_1
|
178 |
+
- jsonschema-specifications=2024.10.1=pyhd8ed1ab_1
|
179 |
+
- jsonschema-with-format-nongpl=4.23.0=hd8ed1ab_1
|
180 |
+
- jupyter=1.1.1=pyhd8ed1ab_1
|
181 |
+
- jupyter-lsp=2.2.5=pyhd8ed1ab_1
|
182 |
+
- jupyter_client=8.6.3=pyhd8ed1ab_1
|
183 |
+
- jupyter_console=6.6.3=pyhd8ed1ab_1
|
184 |
+
- jupyter_core=5.7.2=pyh31011fe_1
|
185 |
+
- jupyter_events=0.11.0=pyhd8ed1ab_0
|
186 |
+
- jupyter_server=2.15.0=pyhd8ed1ab_0
|
187 |
+
- jupyter_server_terminals=0.5.3=pyhd8ed1ab_1
|
188 |
+
- jupyterlab=4.3.4=pyhd8ed1ab_0
|
189 |
+
- jupyterlab_pygments=0.3.0=pyhd8ed1ab_2
|
190 |
+
- jupyterlab_server=2.27.3=pyhd8ed1ab_1
|
191 |
+
- jupyterlab_widgets=3.0.13=pyhd8ed1ab_1
|
192 |
+
- jxrlib=1.1=hd590300_3
|
193 |
+
- keyutils=1.6.1=h166bdaf_0
|
194 |
+
- kiwisolver=1.4.7=py311hd18a35c_0
|
195 |
+
- krb5=1.20.1=h81ceb04_0
|
196 |
+
- lame=3.100=h166bdaf_1003
|
197 |
+
- langchain=0.2.5=pyhd8ed1ab_0
|
198 |
+
- langchain-core=0.2.43=pyhd8ed1ab_0
|
199 |
+
- langchain-text-splitters=0.2.4=pyhd8ed1ab_0
|
200 |
+
- langsmith=0.1.147=pyhd8ed1ab_0
|
201 |
+
- lazy-loader=0.4=pyhd8ed1ab_2
|
202 |
+
- lazy_loader=0.4=pyhd8ed1ab_2
|
203 |
+
- lcms2=2.15=hfd0df8a_0
|
204 |
+
- ld_impl_linux-64=2.40=h12ee557_0
|
205 |
+
- lerc=4.0.0=h27087fc_0
|
206 |
+
- libabseil=20230125.3=cxx17_h59595ed_0
|
207 |
+
- libaec=1.1.3=h59595ed_0
|
208 |
+
- libaio=0.3.113=h166bdaf_0
|
209 |
+
- libarchive=3.6.2=h3d51595_0
|
210 |
+
- libarrow=13.0.0=hb9dc469_0_cpu
|
211 |
+
- libasprintf=0.22.5=he8f35ee_3
|
212 |
+
- libasprintf-devel=0.22.5=he8f35ee_3
|
213 |
+
- libavif=0.11.1=h8182462_2
|
214 |
+
- libblas=3.9.0=1_h86c2bf4_netlib
|
215 |
+
- libbrotlicommon=1.0.9=h166bdaf_9
|
216 |
+
- libbrotlidec=1.0.9=h166bdaf_9
|
217 |
+
- libbrotlienc=1.0.9=h166bdaf_9
|
218 |
+
- libcap=2.67=he9d0100_0
|
219 |
+
- libcblas=3.9.0=8_h3b12eaf_netlib
|
220 |
+
- libclang=15.0.7=default_h127d8a8_5
|
221 |
+
- libclang13=15.0.7=default_h5d6823c_5
|
222 |
+
- libcrc32c=1.1.2=h9c3ff4c_0
|
223 |
+
- libcublas=12.4.5.8=0
|
224 |
+
- libcufft=11.2.1.3=0
|
225 |
+
- libcufile=1.11.1.6=0
|
226 |
+
- libcups=2.3.3=h36d4200_3
|
227 |
+
- libcurand=10.3.7.77=0
|
228 |
+
- libcurl=8.11.1=hc9e6f67_0
|
229 |
+
- libcusolver=11.6.1.9=0
|
230 |
+
- libcusparse=12.3.1.170=0
|
231 |
+
- libdb=6.2.32=h9c3ff4c_0
|
232 |
+
- libdeflate=1.17=h0b41bf4_0
|
233 |
+
- libedit=3.1.20191231=he28a2e2_2
|
234 |
+
- libev=4.33=hd590300_2
|
235 |
+
- libevent=2.1.10=h28343ad_4
|
236 |
+
- libexpat=2.6.4=h5888daf_0
|
237 |
+
- libffi=3.4.4=h6a678d5_1
|
238 |
+
- libflac=1.4.3=h59595ed_0
|
239 |
+
- libgcc=14.2.0=h77fa898_1
|
240 |
+
- libgcc-ng=14.2.0=h69a702a_1
|
241 |
+
- libgcrypt=1.11.0=ha770c72_2
|
242 |
+
- libgcrypt-devel=1.11.0=hb9d3cd8_2
|
243 |
+
- libgcrypt-lib=1.11.0=hb9d3cd8_2
|
244 |
+
- libgcrypt-tools=1.11.0=hb9d3cd8_2
|
245 |
+
- libgettextpo=0.22.5=he02047a_3
|
246 |
+
- libgettextpo-devel=0.22.5=he02047a_3
|
247 |
+
- libgfortran=14.2.0=h69a702a_1
|
248 |
+
- libgfortran-ng=14.2.0=h69a702a_1
|
249 |
+
- libgfortran5=14.2.0=hd5240d6_1
|
250 |
+
- libglib=2.78.4=h783c2da_0
|
251 |
+
- libgomp=14.2.0=h77fa898_1
|
252 |
+
- libgoogle-cloud=2.12.0=h840a212_1
|
253 |
+
- libgpg-error=1.51=hbd13f7d_1
|
254 |
+
- libgrpc=1.56.2=h3905398_1
|
255 |
+
- libhwloc=2.9.1=hd6dc26d_0
|
256 |
+
- libiconv=1.17=hd590300_2
|
257 |
+
- libjpeg-turbo=2.0.0=h9bf148f_0
|
258 |
+
- liblapack=3.9.0=8_h3b12eaf_netlib
|
259 |
+
- libllvm15=15.0.7=hadd5161_1
|
260 |
+
- libmamba=1.5.1=h744094f_0
|
261 |
+
- libmambapy=1.5.1=py311hf2555c7_0
|
262 |
+
- libnghttp2=1.57.0=h2d74bed_0
|
263 |
+
- libnpp=12.2.5.30=0
|
264 |
+
- libnsl=2.0.1=hd590300_0
|
265 |
+
- libnuma=2.0.18=h4ab18f5_2
|
266 |
+
- libnvfatbin=12.6.77=0
|
267 |
+
- libnvjitlink=12.4.127=0
|
268 |
+
- libnvjpeg=12.3.1.117=0
|
269 |
+
- libogg=1.3.5=h4ab18f5_0
|
270 |
+
- libopus=1.3.1=h7f98852_1
|
271 |
+
- libpng=1.6.39=h5eee18b_0
|
272 |
+
- libpq=15.3=hbcd7760_1
|
273 |
+
- libprotobuf=4.23.3=hd1fb520_1
|
274 |
+
- libsentencepiece=0.1.99=h28b9611_1
|
275 |
+
- libsndfile=1.2.2=hc60ed4a_1
|
276 |
+
- libsodium=1.0.18=h36c2ea0_1
|
277 |
+
- libsolv=0.7.30=he621ea3_1
|
278 |
+
- libsqlite=3.46.0=hde9e2c9_0
|
279 |
+
- libssh2=1.11.1=h251f7ec_0
|
280 |
+
- libstdcxx=14.2.0=hc0a3c3a_1
|
281 |
+
- libstdcxx-ng=14.2.0=h4852527_1
|
282 |
+
- libsystemd0=253=h8c4010b_1
|
283 |
+
- libthrift=0.18.1=h5e4af38_0
|
284 |
+
- libtiff=4.5.0=h6adf6a1_2
|
285 |
+
- libtool=2.4.7=he02047a_1
|
286 |
+
- libudev1=253=h0b41bf4_1
|
287 |
+
- libutf8proc=2.8.0=hf23e847_1
|
288 |
+
- libuuid=2.38.1=h0b41bf4_0
|
289 |
+
- libvorbis=1.3.7=h9c3ff4c_0
|
290 |
+
- libwebp=1.2.4=h1daa5a0_1
|
291 |
+
- libwebp-base=1.2.4=h166bdaf_0
|
292 |
+
- libxcb=1.13=h7f98852_1004
|
293 |
+
- libxkbcommon=1.5.0=h79f4944_1
|
294 |
+
- libxml2=2.10.3=hca2bb57_4
|
295 |
+
- libzlib=1.2.13=h4ab18f5_6
|
296 |
+
- libzopfli=1.0.3=h9c3ff4c_0
|
297 |
+
- lightning=2.5.0.post0=pyhd8ed1ab_0
|
298 |
+
- lightning-utilities=0.11.9=pyhd8ed1ab_1
|
299 |
+
- llvm-openmp=12.0.1=h4bd325d_1
|
300 |
+
- lz4-c=1.9.4=hcb278e6_0
|
301 |
+
- lzo=2.10=hd590300_1001
|
302 |
+
- mamba=1.5.1=py311h3072747_0
|
303 |
+
- markdown=3.6=pyhd8ed1ab_0
|
304 |
+
- markdown-it-py=3.0.0=pyhd8ed1ab_1
|
305 |
+
- markupsafe=3.0.2=py311h2dc5d0c_1
|
306 |
+
- marshmallow=3.25.1=pyhd8ed1ab_0
|
307 |
+
- matplotlib=3.9.1=py311h38be061_1
|
308 |
+
- matplotlib-base=3.9.1=py311h74b4f7c_2
|
309 |
+
- matplotlib-inline=0.1.7=pyhd8ed1ab_1
|
310 |
+
- mdurl=0.1.2=pyhd8ed1ab_1
|
311 |
+
- mistune=3.1.0=pyhd8ed1ab_0
|
312 |
+
- mkl=2023.1.0=h213fc3f_46344
|
313 |
+
- mpc=1.3.1=h24ddda3_1
|
314 |
+
- mpfr=4.2.1=h90cbb55_3
|
315 |
+
- mpg123=1.32.9=hc50e24c_0
|
316 |
+
- mpmath=1.3.0=pyhd8ed1ab_1
|
317 |
+
- msgpack-python=1.1.0=py311hd18a35c_0
|
318 |
+
- multidict=6.1.0=py311h2dc5d0c_2
|
319 |
+
- munkres=1.1.4=pyh9f0ad1d_0
|
320 |
+
- mypy_extensions=1.0.0=pyha770c72_1
|
321 |
+
- mysql-common=8.0.33=hf1915f5_6
|
322 |
+
- mysql-libs=8.0.33=hca2cd23_6
|
323 |
+
- nbclient=0.10.2=pyhd8ed1ab_0
|
324 |
+
- nbconvert-core=7.16.5=pyhd8ed1ab_1
|
325 |
+
- nbformat=5.10.4=pyhd8ed1ab_1
|
326 |
+
- ncurses=6.4=h6a678d5_0
|
327 |
+
- nest-asyncio=1.6.0=pyhd8ed1ab_1
|
328 |
+
- nettle=3.6=he412f7d_0
|
329 |
+
- networkx=3.4.2=pyh267e887_2
|
330 |
+
- ninja=1.12.1=h297d8ca_0
|
331 |
+
- nlohmann_json=3.11.3=he02047a_1
|
332 |
+
- nltk=3.9.1=pyhd8ed1ab_1
|
333 |
+
- notebook=7.3.2=pyhd8ed1ab_0
|
334 |
+
- notebook-shim=0.2.4=pyhd8ed1ab_1
|
335 |
+
- nspr=4.36=h5888daf_0
|
336 |
+
- nss=3.100=hca3bf56_0
|
337 |
+
- numpy=1.26.4=py311h64a7726_0
|
338 |
+
- nvidia-ml-py=12.560.30=pyhd8ed1ab_1
|
339 |
+
- nvitop=1.4.1=pyh707e725_1
|
340 |
+
- omegaconf=2.3.0=pyhd8ed1ab_0
|
341 |
+
- opacus=1.5.2=pyhd8ed1ab_1
|
342 |
+
- openai=1.59.7=pyhd8ed1ab_0
|
343 |
+
- openh264=2.1.1=h780b84a_0
|
344 |
+
- openjpeg=2.5.0=hfec8fc6_2
|
345 |
+
- openssl=3.1.7=hb9d3cd8_0
|
346 |
+
- opt-einsum=3.4.0=hd8ed1ab_1
|
347 |
+
- opt_einsum=3.4.0=pyhd8ed1ab_1
|
348 |
+
- orc=1.9.0=h385abfd_1
|
349 |
+
- orjson=3.10.14=py311h9e33e62_0
|
350 |
+
- overrides=7.7.0=pyhd8ed1ab_1
|
351 |
+
- packaging=24.2=pyhd8ed1ab_2
|
352 |
+
- pandas=2.2.3=py311h7db5c69_1
|
353 |
+
- pandocfilters=1.5.0=pyhd8ed1ab_0
|
354 |
+
- parso=0.8.4=pyhd8ed1ab_1
|
355 |
+
- patsy=1.0.1=pyhd8ed1ab_1
|
356 |
+
- pcre2=10.42=hebb0a14_1
|
357 |
+
- peft=0.14.0=pyhd8ed1ab_0
|
358 |
+
- pexpect=4.9.0=pyhd8ed1ab_1
|
359 |
+
- pickleshare=0.7.5=pyhd8ed1ab_1004
|
360 |
+
- pillow=9.4.0=py311h6a678d5_0
|
361 |
+
- pip=24.2=py311h06a4308_0
|
362 |
+
- pixman=0.44.2=h29eaf8c_0
|
363 |
+
- pkgutil-resolve-name=1.3.10=pyhd8ed1ab_2
|
364 |
+
- platformdirs=4.3.6=pyhd8ed1ab_1
|
365 |
+
- pluggy=1.5.0=pyhd8ed1ab_1
|
366 |
+
- ply=3.11=pyhd8ed1ab_3
|
367 |
+
- portalocker=3.0.0=py311h38be061_0
|
368 |
+
- prometheus_client=0.21.1=pyhd8ed1ab_0
|
369 |
+
- prompt-toolkit=3.0.48=pyha770c72_1
|
370 |
+
- prompt_toolkit=3.0.48=hd8ed1ab_1
|
371 |
+
- propcache=0.2.1=py311h9ecbd09_0
|
372 |
+
- psutil=6.1.1=py311h9ecbd09_0
|
373 |
+
- pthread-stubs=0.4=hb9d3cd8_1002
|
374 |
+
- ptyprocess=0.7.0=pyhd8ed1ab_1
|
375 |
+
- pulseaudio=16.1=hcb278e6_3
|
376 |
+
- pulseaudio-client=16.1=h5195f5e_3
|
377 |
+
- pulseaudio-daemon=16.1=ha8d29e2_3
|
378 |
+
- pure_eval=0.2.3=pyhd8ed1ab_1
|
379 |
+
- py-cpuinfo=9.0.0=pyhd8ed1ab_1
|
380 |
+
- pyarrow=13.0.0=py311h39c9aba_0_cpu
|
381 |
+
- pyarrow-hotfix=0.6=pyhd8ed1ab_1
|
382 |
+
- pybind11-abi=4=hd8ed1ab_3
|
383 |
+
- pycosat=0.6.6=py311h9ecbd09_2
|
384 |
+
- pycparser=2.22=pyh29332c3_1
|
385 |
+
- pydantic=2.10.5=pyh3cfb1c2_0
|
386 |
+
- pydantic-core=2.27.2=py311h9e33e62_0
|
387 |
+
- pygments=2.19.1=pyhd8ed1ab_0
|
388 |
+
- pymoo=0.6.1.3=py311h7db5c69_0
|
389 |
+
- pynvml=12.0.0=pyhd8ed1ab_0
|
390 |
+
- pyopenssl=24.3.0=pyhd8ed1ab_0
|
391 |
+
- pyparsing=3.2.1=pyhd8ed1ab_0
|
392 |
+
- pypdf=3.17.4=pyhd8ed1ab_0
|
393 |
+
- pyqt=5.15.9=py311hf0fb5b6_5
|
394 |
+
- pyqt5-sip=12.12.2=py311hb755f60_5
|
395 |
+
- pysocks=1.7.1=pyha55dd90_7
|
396 |
+
- python=3.11.6=hab00c5b_0_cpython
|
397 |
+
- python-dateutil=2.9.0.post0=pyhff2d567_1
|
398 |
+
- python-dotenv=1.0.1=pyhd8ed1ab_1
|
399 |
+
- python-fastjsonschema=2.21.1=pyhd8ed1ab_0
|
400 |
+
- python-json-logger=2.0.7=pyhd8ed1ab_0
|
401 |
+
- python-slugify=8.0.4=pyhd8ed1ab_1
|
402 |
+
- python-tzdata=2024.2=pyhd8ed1ab_1
|
403 |
+
- python-xxhash=3.5.0=py311h9ecbd09_1
|
404 |
+
- python_abi=3.11=2_cp311
|
405 |
+
- pytorch=2.5.1=py3.11_cuda12.4_cudnn9.1.0_0
|
406 |
+
- pytorch-cuda=12.4=hc786d27_7
|
407 |
+
- pytorch-lightning=2.5.0.post0=pyh101cb37_0
|
408 |
+
- pytorch-mutex=1.0=cuda
|
409 |
+
- pytz=2024.1=pyhd8ed1ab_0
|
410 |
+
- pywavelets=1.8.0=py311h9f3472d_0
|
411 |
+
- pyyaml=6.0.2=py311h9ecbd09_1
|
412 |
+
- pyzmq=26.2.0=py311h7deb3e3_0
|
413 |
+
- qhull=2020.2=h434a139_5
|
414 |
+
- qt-main=5.15.8=h5d23da1_6
|
415 |
+
- rdma-core=28.9=h59595ed_1
|
416 |
+
- re2=2023.03.02=h8c504da_0
|
417 |
+
- readline=8.2=h5eee18b_0
|
418 |
+
- referencing=0.35.1=pyhd8ed1ab_1
|
419 |
+
- regex=2024.11.6=py311h9ecbd09_0
|
420 |
+
- reproc=14.2.5.post0=hb9d3cd8_0
|
421 |
+
- reproc-cpp=14.2.5.post0=h5888daf_0
|
422 |
+
- requests=2.32.3=pyhd8ed1ab_1
|
423 |
+
- requests-toolbelt=1.0.0=pyhd8ed1ab_1
|
424 |
+
- responses=0.18.0=pyhd8ed1ab_0
|
425 |
+
- rfc3339-validator=0.1.4=pyhd8ed1ab_1
|
426 |
+
- rfc3986-validator=0.1.1=pyh9f0ad1d_0
|
427 |
+
- rich=13.9.4=pyhd8ed1ab_1
|
428 |
+
- rpds-py=0.22.3=py311h9e33e62_0
|
429 |
+
- ruamel.yaml=0.17.40=py311h459d7ec_0
|
430 |
+
- ruamel.yaml.clib=0.2.8=py311h9ecbd09_1
|
431 |
+
- s2n=1.3.51=h06160fa_0
|
432 |
+
- sacrebleu=2.1.0=pyhd8ed1ab_0
|
433 |
+
- sacremoses=0.0.53=pyhd8ed1ab_0
|
434 |
+
- safetensors=0.5.2=py311h9e33e62_0
|
435 |
+
- scikit-image=0.25.0=py311h7db5c69_0
|
436 |
+
- scikit-learn=1.6.1=py311h57cc02b_0
|
437 |
+
- scipy=1.15.1=py311hc1ac118_0
|
438 |
+
- seaborn=0.13.2=hd8ed1ab_3
|
439 |
+
- seaborn-base=0.13.2=pyhd8ed1ab_3
|
440 |
+
- send2trash=1.8.3=pyh0d859eb_1
|
441 |
+
- sentence-transformers=2.7.0=pyhd8ed1ab_0
|
442 |
+
- sentencepiece=0.1.99=h38be061_1
|
443 |
+
- sentencepiece-python=0.1.99=py311hf03188e_1
|
444 |
+
- sentencepiece-spm=0.1.99=h28b9611_1
|
445 |
+
- setuptools=75.1.0=py311h06a4308_0
|
446 |
+
- shtab=1.7.1=pyhd8ed1ab_1
|
447 |
+
- simdjson=3.11.5=h84d6215_0
|
448 |
+
- sip=6.7.12=py311hb755f60_0
|
449 |
+
- six=1.17.0=pyhd8ed1ab_0
|
450 |
+
- snappy=1.1.10=hdb0a2a9_1
|
451 |
+
- sniffio=1.3.1=pyhd8ed1ab_1
|
452 |
+
- soupsieve=2.5=pyhd8ed1ab_1
|
453 |
+
- spdlog=1.14.1=h597fd29_0
|
454 |
+
- sqlalchemy=2.0.37=py311h9ecbd09_0
|
455 |
+
- sqlite=3.45.3=h5eee18b_0
|
456 |
+
- stack_data=0.6.3=pyhd8ed1ab_1
|
457 |
+
- statsmodels=0.14.4=py311h9f3472d_0
|
458 |
+
- tabulate=0.9.0=pyhd8ed1ab_2
|
459 |
+
- tbb=2021.9.0=hf52228f_0
|
460 |
+
- tenacity=8.5.0=pyhd8ed1ab_0
|
461 |
+
- tensorboard=2.18.0=pyhd8ed1ab_1
|
462 |
+
- tensorboard-data-server=0.7.0=py311h63ff55d_1
|
463 |
+
- termcolor=2.5.0=pyhd8ed1ab_1
|
464 |
+
- terminado=0.18.1=pyh0d859eb_0
|
465 |
+
- text-unidecode=1.3=pyhd8ed1ab_2
|
466 |
+
- threadpoolctl=3.5.0=pyhc1e730c_0
|
467 |
+
- tifffile=2023.8.12=pyhd8ed1ab_0
|
468 |
+
- tinycss2=1.4.0=pyhd8ed1ab_0
|
469 |
+
- tk=8.6.14=h39e8969_0
|
470 |
+
- tokenizers=0.13.3=py311h1b04a43_0
|
471 |
+
- toml=0.10.2=pyhd8ed1ab_1
|
472 |
+
- tomli=2.2.1=pyhd8ed1ab_1
|
473 |
+
- toolz=1.0.0=pyhd8ed1ab_1
|
474 |
+
- torchaudio=2.5.1=py311_cu124
|
475 |
+
- torchmetrics=1.6.1=pyhd8ed1ab_0
|
476 |
+
- torchtriton=3.1.0=py311
|
477 |
+
- torchvision=0.20.1=py311_cu124
|
478 |
+
- tornado=6.4.2=py311h9ecbd09_0
|
479 |
+
- tqdm=4.67.1=pyhd8ed1ab_1
|
480 |
+
- traitlets=5.14.3=pyhd8ed1ab_1
|
481 |
+
- transformers=4.33.3=pyhd8ed1ab_0
|
482 |
+
- transforms3d=0.4.2=pyhd8ed1ab_1
|
483 |
+
- trl=0.10.1=pyhd8ed1ab_0
|
484 |
+
- types-python-dateutil=2.9.0.20241206=pyhd8ed1ab_0
|
485 |
+
- typing=3.10.0.0=pyhd8ed1ab_2
|
486 |
+
- typing-extensions=4.12.2=hd8ed1ab_1
|
487 |
+
- typing_extensions=4.12.2=pyha770c72_1
|
488 |
+
- typing_inspect=0.9.0=pyhd8ed1ab_1
|
489 |
+
- typing_utils=0.1.0=pyhd8ed1ab_1
|
490 |
+
- tyro=0.9.1=pyhff2d567_0
|
491 |
+
- tzdata=2024b=h04d1e81_0
|
492 |
+
- ucx=1.14.1=h195a15c_5
|
493 |
+
- unicodedata2=16.0.0=py311h9ecbd09_0
|
494 |
+
- uri-template=1.3.0=pyhd8ed1ab_1
|
495 |
+
- urllib3=2.3.0=pyhd8ed1ab_0
|
496 |
+
- wcwidth=0.2.13=pyhd8ed1ab_1
|
497 |
+
- webcolors=24.11.1=pyhd8ed1ab_0
|
498 |
+
- webencodings=0.5.1=pyhd8ed1ab_3
|
499 |
+
- websocket-client=1.8.0=pyhd8ed1ab_1
|
500 |
+
- werkzeug=3.1.3=pyhd8ed1ab_1
|
501 |
+
- wheel=0.44.0=py311h06a4308_0
|
502 |
+
- widgetsnbextension=4.0.13=pyhd8ed1ab_1
|
503 |
+
- wrapt=1.17.1=py311h9ecbd09_0
|
504 |
+
- xcb-util=0.4.0=h516909a_0
|
505 |
+
- xcb-util-image=0.4.0=h166bdaf_0
|
506 |
+
- xcb-util-keysyms=0.4.0=h516909a_0
|
507 |
+
- xcb-util-renderutil=0.3.9=h166bdaf_0
|
508 |
+
- xcb-util-wm=0.4.1=h516909a_0
|
509 |
+
- xformers=0.0.28.post3=py311_cu12.1.0_pyt2.5.1
|
510 |
+
- xkeyboard-config=2.38=h0b41bf4_0
|
511 |
+
- xorg-kbproto=1.0.7=hb9d3cd8_1003
|
512 |
+
- xorg-libice=1.1.2=hb9d3cd8_0
|
513 |
+
- xorg-libsm=1.2.5=he73a12e_0
|
514 |
+
- xorg-libx11=1.8.4=h0b41bf4_0
|
515 |
+
- xorg-libxau=1.0.12=hb9d3cd8_0
|
516 |
+
- xorg-libxdmcp=1.1.5=hb9d3cd8_0
|
517 |
+
- xorg-libxext=1.3.4=h0b41bf4_2
|
518 |
+
- xorg-libxrender=0.9.10=h7f98852_1003
|
519 |
+
- xorg-renderproto=0.11.1=hb9d3cd8_1003
|
520 |
+
- xorg-xextproto=7.3.0=hb9d3cd8_1004
|
521 |
+
- xorg-xproto=7.0.31=hb9d3cd8_1008
|
522 |
+
- xxhash=0.8.2=hd590300_0
|
523 |
+
- xz=5.4.6=h5eee18b_1
|
524 |
+
- yacs=0.1.8=pyhd8ed1ab_1
|
525 |
+
- yaml=0.2.5=h7f98852_2
|
526 |
+
- yaml-cpp=0.7.0=h59595ed_3
|
527 |
+
- yarl=1.18.3=py311h9ecbd09_0
|
528 |
+
- zeromq=4.3.5=h59595ed_1
|
529 |
+
- zfp=1.0.1=h5888daf_2
|
530 |
+
- zipp=3.21.0=pyhd8ed1ab_1
|
531 |
+
- zlib=1.2.13=h4ab18f5_6
|
532 |
+
- zlib-ng=2.0.7=h0b41bf4_0
|
533 |
+
- zstandard=0.23.0=py311hbc35293_1
|
534 |
+
- zstd=1.5.6=hc292b87_0
|
535 |
+
- pip:
|
536 |
+
- bitsandbytes==0.45.0
|
537 |
+
- cryptography==42.0.8
|
538 |
+
- cupy-cuda12x==13.3.0
|
539 |
+
- docker-pycreds==0.4.0
|
540 |
+
- faiss-cpu==1.9.0.post1
|
541 |
+
- fastrlock==0.8.3
|
542 |
+
- flash-attn==2.6.3
|
543 |
+
- flwr==1.14.0
|
544 |
+
- flwr-datasets==0.5.0
|
545 |
+
- fsspec==2024.12.0
|
546 |
+
- gitdb==4.0.12
|
547 |
+
- gitpython==3.1.44
|
548 |
+
- grpcio==1.64.3
|
549 |
+
- iterators==0.0.2
|
550 |
+
- multiprocess==0.70.16
|
551 |
+
- pathspec==0.12.1
|
552 |
+
- protobuf==4.25.5
|
553 |
+
- prv-accountant==0.2.0
|
554 |
+
- pycryptodome==3.21.0
|
555 |
+
- ray==2.10.0
|
556 |
+
- rouge-score==0.1.2
|
557 |
+
- sentry-sdk==2.19.2
|
558 |
+
- setproctitle==1.3.4
|
559 |
+
- shellingham==1.5.4
|
560 |
+
- smmap==5.0.2
|
561 |
+
- sympy==1.13.1
|
562 |
+
- thop==0.1.1-2209072238
|
563 |
+
- tomli-w==1.1.0
|
564 |
+
- typer==0.12.5
|
565 |
+
- wandb==0.19.3
|
template_FL/src/ex.env.example
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
WANDB_API_KEY = ""
|
2 |
+
WANDB_NAME = "FL@CSS25"
|
3 |
+
HF_TOKEN = ""
|
template_FL/src/fedllm/Untitled.ipynb
ADDED
@@ -0,0 +1,861 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "5f1cdd6d-0f6c-447c-8b11-665d0201215e",
|
7 |
+
"metadata": {
|
8 |
+
"tags": []
|
9 |
+
},
|
10 |
+
"outputs": [
|
11 |
+
{
|
12 |
+
"ename": "ModuleNotFoundError",
|
13 |
+
"evalue": "No module named 'flwr_datasets'",
|
14 |
+
"output_type": "error",
|
15 |
+
"traceback": [
|
16 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
17 |
+
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
|
18 |
+
"Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mflwr_datasets\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpartitioner\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m IidPartitioner\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mflwr_datasets\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m FederatedDataset\n\u001b[1;32m 5\u001b[0m partitioner \u001b[38;5;241m=\u001b[39m IidPartitioner(num_partitions\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m)\n",
|
19 |
+
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'flwr_datasets'"
|
20 |
+
]
|
21 |
+
}
|
22 |
+
],
|
23 |
+
"source": [
|
24 |
+
"from flwr_datasets.partitioner import IidPartitioner\n",
|
25 |
+
"from flwr_datasets import FederatedDataset\n",
|
26 |
+
"\n",
|
27 |
+
"\n",
|
28 |
+
"partitioner = IidPartitioner(num_partitions=10)\n",
|
29 |
+
"FDS = FederatedDataset(\n",
|
30 |
+
" dataset=\"vicgalle/alpaca-gpt4\",\n",
|
31 |
+
" partitioners={\"train\": partitioner},\n",
|
32 |
+
")\n",
|
33 |
+
"client_trainset = FDS.load_partition(1, \"train\")\n",
|
34 |
+
"client_trainset = client_trainset.rename_column(\"output\", \"response\")\n",
|
35 |
+
"client_trainset"
|
36 |
+
]
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"cell_type": "code",
|
40 |
+
"execution_count": 12,
|
41 |
+
"id": "fecf675d-1279-4fbd-a4e3-15d245582df4",
|
42 |
+
"metadata": {
|
43 |
+
"tags": []
|
44 |
+
},
|
45 |
+
"outputs": [],
|
46 |
+
"source": [
|
47 |
+
"import datasets\n",
|
48 |
+
"from datasets import load_dataset, DatasetDict\n",
|
49 |
+
"import pandas as pd\n",
|
50 |
+
"from functools import partial\n",
|
51 |
+
"from sklearn.model_selection import train_test_split\n",
|
52 |
+
"\n",
|
53 |
+
"\n",
|
54 |
+
"def get_dataset(dataset_name, local_data_dir=None):\n",
|
55 |
+
"\n",
|
56 |
+
" if dataset_name in [\"gsm8k\"]:\n",
|
57 |
+
" dataset_name = local_data_dir + dataset_name if local_data_dir is not None else dataset_name\n",
|
58 |
+
" dataset = load_dataset(dataset_name, name=\"main\")\n",
|
59 |
+
" elif dataset_name in [\"lighteval/MATH\"]:\n",
|
60 |
+
" dataset_name = local_data_dir + dataset_name if local_data_dir is not None else dataset_name\n",
|
61 |
+
" dataset = load_dataset(dataset_name, name=\"all\")\n",
|
62 |
+
" else:\n",
|
63 |
+
" dataset_name = local_data_dir + dataset_name if local_data_dir is not None else dataset_name\n",
|
64 |
+
" dataset = load_dataset(dataset_name)\n",
|
65 |
+
"\n",
|
66 |
+
" return dataset\n",
|
67 |
+
"\n",
|
68 |
+
"# Function to split a dataset dictionary into two 50/50 parts\n",
|
69 |
+
"def split_dataset_50_50(dataset_dict):\n",
|
70 |
+
" split_datasets = {\n",
|
71 |
+
" 'ds_1': None,\n",
|
72 |
+
" 'ds_2': None\n",
|
73 |
+
" }\n",
|
74 |
+
" for split in ['train', 'valid', 'test']:\n",
|
75 |
+
" if split in dataset_dict:\n",
|
76 |
+
" dataset_split_1, dataset_split_2 = train_test_split(\n",
|
77 |
+
" dataset_dict[split], test_size=0.5, shuffle=True, seed=42\n",
|
78 |
+
" )\n",
|
79 |
+
" print(f\">> ===== After split, Dataset1 {split} has {len(dataset_split_1)} examples. =====\")\n",
|
80 |
+
" print(f\">> ===== After split, Dataset2 {split} has {len(dataset_split_2)} examples. =====\")\n",
|
81 |
+
" split_datasets['ds_1'][split] = dataset_split_1\n",
|
82 |
+
" split_datasets['ds_2'][split] = dataset_split_2\n",
|
83 |
+
" return DatasetDict(split_datasets['ds_1']), DatasetDict(split_datasets['ds_2'])\n",
|
84 |
+
"\n",
|
85 |
+
"\n",
|
86 |
+
"def process_sft_dataset(dataset_name, dataset, dataset_sample):\n",
|
87 |
+
" if dataset_name in [\"lucasmccabe-lmi/CodeAlpaca-20k\", \"yahma/alpaca-cleaned\", \"FinGPT/fingpt-sentiment-train\"]:\n",
|
88 |
+
" dataset = dataset.map(alpaca_format, remove_columns=['input', 'output'], desc=f\"Preprocessing {dataset_name} for unified format.\")\n",
|
89 |
+
" elif dataset_name in [\"WizardLM/WizardLM_evol_instruct_70k\"]:\n",
|
90 |
+
" dataset = dataset.rename_column(\"output\", \"response\")\n",
|
91 |
+
" elif dataset_name in [\"tatsu-lab/alpaca\", \"vicgalle/alpaca-gpt4\", \"gbharti/finance-alpaca\"]:\n",
|
92 |
+
" dataset = dataset.map(alpaca_format, remove_columns=['input', 'output', 'text'], desc=f\"Preprocessing {dataset_name} for unified format.\")\n",
|
93 |
+
" elif dataset_name in [\"TIGER-Lab/MathInstruct\"]:\n",
|
94 |
+
" df = pd.DataFrame(dataset)\n",
|
95 |
+
" df = df.drop_duplicates(subset=['instruction'])\n",
|
96 |
+
" dataset = datasets.Dataset.from_pandas(df)\n",
|
97 |
+
" dataset = dataset.rename_column(\"output\", \"response\")\n",
|
98 |
+
" dataset = dataset.remove_columns(['source'])\n",
|
99 |
+
" elif dataset_name in [\"lighteval/MATH\"]:\n",
|
100 |
+
" dataset = dataset.rename_column(\"solution\", \"response\")\n",
|
101 |
+
" dataset = dataset.rename_column(\"problem\", \"instruction\")\n",
|
102 |
+
" dataset = dataset.remove_columns(['level', 'type'])\n",
|
103 |
+
" elif dataset_name in ['gsm8k']:\n",
|
104 |
+
" dataset = dataset.rename_column(\"question\", \"instruction\")\n",
|
105 |
+
" dataset = dataset.rename_column(\"answer\", \"response\")\n",
|
106 |
+
" elif dataset_name in ['medalpaca/medical_meadow_medical_flashcards']: # TODO: 'lavita/ChatDoctor-HealthCareMagic-100k'. not sure whether to discard the instruction.\n",
|
107 |
+
" dataset = dataset.remove_columns(['instruction'])\n",
|
108 |
+
" dataset = dataset.rename_column(\"input\", \"instruction\")\n",
|
109 |
+
" dataset = dataset.rename_column(\"output\", \"response\")\n",
|
110 |
+
" else:\n",
|
111 |
+
" raise NotImplementedError(f\"Dataset {dataset_name} is not supported.\")\n",
|
112 |
+
" dataset = dataset.shuffle(seed=2023)\n",
|
113 |
+
" if dataset_sample:\n",
|
114 |
+
" num_sample = min(len(dataset), dataset_sample)\n",
|
115 |
+
" dataset = dataset.select(range(num_sample))\n",
|
116 |
+
" print(f\">> ===== After processing, Dataset {dataset_name} has {len(dataset)} examples. =====\")\n",
|
117 |
+
" print(f\">> ===== Spliting two parts datasets =====\")\n",
|
118 |
+
" \n",
|
119 |
+
" if len(dataset['train']) > 10000 and len(dataset['test']) >= 2000:\n",
|
120 |
+
" dataset = split_dataset_50_50(dataset)\n",
|
121 |
+
" return dataset\n",
|
122 |
+
"\n",
|
123 |
+
"def alpaca_format(example):\n",
|
124 |
+
" if example['input'] == \"\":\n",
|
125 |
+
" example[\"instruction\"] = example[\"instruction\"]\n",
|
126 |
+
" else:\n",
|
127 |
+
" example[\"instruction\"] = example[\"instruction\"] + \" \" + example['input']\n",
|
128 |
+
" example[\"response\"] = example['output']\n",
|
129 |
+
" return example\n",
|
130 |
+
"\n",
|
131 |
+
"\n"
|
132 |
+
]
|
133 |
+
},
|
134 |
+
{
|
135 |
+
"cell_type": "code",
|
136 |
+
"execution_count": 25,
|
137 |
+
"id": "73d16824-ec97-4e5f-87bc-c432253b93c5",
|
138 |
+
"metadata": {
|
139 |
+
"tags": []
|
140 |
+
},
|
141 |
+
"outputs": [],
|
142 |
+
"source": [
|
143 |
+
"import datasets\n",
|
144 |
+
"import pandas as pd\n",
|
145 |
+
"from datasets import Dataset, DatasetDict, load_dataset\n",
|
146 |
+
"from sklearn.model_selection import train_test_split\n",
|
147 |
+
"from functools import partial\n",
|
148 |
+
"\n",
|
149 |
+
"\n",
|
150 |
+
"class DatasetAbstract:\n",
|
151 |
+
" def __init__(self, dataset_name: list[str], category: str):\n",
|
152 |
+
" self.dataset_name = dataset_name\n",
|
153 |
+
" self.metadata = {\n",
|
154 |
+
" 'domain': category\n",
|
155 |
+
" }\n",
|
156 |
+
" \n",
|
157 |
+
" def _processing_data(self):\n",
|
158 |
+
" pass\n",
|
159 |
+
" \n",
|
160 |
+
" @classmethod\n",
|
161 |
+
" def get_dataset(cls, dataset_name, local_data_dir=None):\n",
|
162 |
+
" if dataset_name in [\"gsm8k\"]:\n",
|
163 |
+
" dataset_name = local_data_dir + dataset_name if local_data_dir is not None else dataset_name\n",
|
164 |
+
" dataset = load_dataset(dataset_name, name=\"main\")\n",
|
165 |
+
" else:\n",
|
166 |
+
" dataset_name = local_data_dir + dataset_name if local_data_dir is not None else dataset_name\n",
|
167 |
+
" dataset = load_dataset(dataset_name)\n",
|
168 |
+
" \n",
|
169 |
+
" return dataset\n",
|
170 |
+
" \n",
|
171 |
+
" def get_split_dataset(self, dataset):\n",
|
172 |
+
" print(f\">> ===== After processing, Dataset has {len(dataset)} examples. =====\")\n",
|
173 |
+
" if len(dataset) > 10000:\n",
|
174 |
+
" ds_part1, ds_part2 = train_test_split(\n",
|
175 |
+
" dataset, test_size=0.5, shuffle=True, random_state=42\n",
|
176 |
+
" )\n",
|
177 |
+
" print(f\">> ===== After split, Dataset1 has {len(ds_part1)} examples and Dataset2 has {len(ds_part2)} examples. =====\")\n",
|
178 |
+
" list_dataset = []\n",
|
179 |
+
" for subset in [ds_part1, ds_part2]:\n",
|
180 |
+
" train, test = train_test_split(\n",
|
181 |
+
" subset, test_size=0.2, shuffle=True, random_state=42\n",
|
182 |
+
" )\n",
|
183 |
+
" ds = DatasetDict({\n",
|
184 |
+
" \"train\": Dataset.from_pandas(train).remove_columns(['__index_level_0__']),\n",
|
185 |
+
" \"test\": Dataset.from_pandas(test).remove_columns(['__index_level_0__'])\n",
|
186 |
+
" })\n",
|
187 |
+
" list_dataset.append(ds)\n",
|
188 |
+
" return list_dataset\n",
|
189 |
+
" \n",
|
190 |
+
" else:\n",
|
191 |
+
" train, test = train_test_split(\n",
|
192 |
+
" dataset , test_size=0.2, shuffle=True, random_state=42\n",
|
193 |
+
" )\n",
|
194 |
+
" ds = DatasetDict(\n",
|
195 |
+
" {\n",
|
196 |
+
" \"train\": Dataset.from_pandas(train).remove_columns(['__index_level_0__']),\n",
|
197 |
+
" \"test\": Dataset.from_pandas(test).remove_columns(['__index_level_0__'])\n",
|
198 |
+
" }\n",
|
199 |
+
" )\n",
|
200 |
+
" return [ds]\n",
|
201 |
+
"\n",
|
202 |
+
" \n",
|
203 |
+
"class GeneralDataset(DatasetAbstract):\n",
|
204 |
+
" \n",
|
205 |
+
" def __init__(self):\n",
|
206 |
+
" list_dataset = [\"tatsu-lab/alpaca\", \"vicgalle/alpaca-gpt4\"]\n",
|
207 |
+
" super().__init__(list_dataset, 'general')\n",
|
208 |
+
" self._processing_data()\n",
|
209 |
+
" \n",
|
210 |
+
" def _processing_data(self):\n",
|
211 |
+
" datasets = []\n",
|
212 |
+
" for dataset_name in self.dataset_name:\n",
|
213 |
+
" datasets.append(\n",
|
214 |
+
" pd.DataFrame(super().get_dataset(dataset_name=dataset_name, local_data_dir=None)['train'])\n",
|
215 |
+
" )\n",
|
216 |
+
" dataset = pd.concat(datasets, ignore_index=True)\n",
|
217 |
+
" self.list_dataset = self.get_split_dataset(dataset)\n",
|
218 |
+
" \n",
|
219 |
+
" \n",
|
220 |
+
"\n",
|
221 |
+
"class FinanceDataset(DatasetAbstract):\n",
|
222 |
+
" \n",
|
223 |
+
" def __init__(self):\n",
|
224 |
+
" list_dataset = [\"gbharti/finance-alpaca\", \"FinGPT/fingpt-sentiment-train\"]\n",
|
225 |
+
" super().__init__(list_dataset, 'finance')\n",
|
226 |
+
" \n",
|
227 |
+
" self._processing_data()\n",
|
228 |
+
" \n",
|
229 |
+
" def _processing_data(self):\n",
|
230 |
+
" datasets = []\n",
|
231 |
+
" for dataset_name in self.dataset_name:\n",
|
232 |
+
" ds = super().get_dataset(dataset_name=dataset_name, local_data_dir=None)['train']\n",
|
233 |
+
" if dataset_name == 'gbharti/finance-alpaca':\n",
|
234 |
+
" ds = ds.remove_columns(['text'])\n",
|
235 |
+
" df = pd.DataFrame(ds)\n",
|
236 |
+
" datasets.append(df)\n",
|
237 |
+
" dataset = pd.concat(datasets, ignore_index=True)\n",
|
238 |
+
" self.list_dataset = self.get_split_dataset(dataset)\n",
|
239 |
+
" \n",
|
240 |
+
"\n",
|
241 |
+
"class MathDataset(DatasetAbstract):\n",
|
242 |
+
" \n",
|
243 |
+
" def __init__(self):\n",
|
244 |
+
" list_dataset = [\"TIGER-Lab/MathInstruct\", \"xDAN2099/lighteval-MATH\", \"gsm8k\"]\n",
|
245 |
+
" super().__init__(list_dataset, 'math')\n",
|
246 |
+
" self._processing_data()\n",
|
247 |
+
" \n",
|
248 |
+
" \n",
|
249 |
+
" def get_split_dataset(self, dataset):\n",
|
250 |
+
" dataset_train, dataset_test = dataset[0], dataset[1]\n",
|
251 |
+
" print(f\">> ===== After processing, Dataset has {len(dataset_train)} examples. =====\")\n",
|
252 |
+
" if len(dataset_train) > 10000:\n",
|
253 |
+
" ds_train_part1, ds_train_part2 = train_test_split(\n",
|
254 |
+
" dataset_train, test_size=0.5, shuffle=True, random_state=42\n",
|
255 |
+
" )\n",
|
256 |
+
" ds_test_part1, ds_test_part2 = train_test_split(\n",
|
257 |
+
" dataset_test, test_size=0.5, shuffle=True, random_state=42\n",
|
258 |
+
" )\n",
|
259 |
+
" print(f\">> ===== After split, Dataset1 has {len(ds_train_part1)} examples and Dataset2 has {len(ds_train_part2)} examples. =====\")\n",
|
260 |
+
" list_dataset = []\n",
|
261 |
+
" for i in range(2):\n",
|
262 |
+
" ds = DatasetDict({\n",
|
263 |
+
" \"train\": Dataset.from_pandas(eval(f'ds_train_part{i+1}')).remove_columns(['__index_level_0__']), \n",
|
264 |
+
" \"test\": Dataset.from_pandas(eval(f'ds_test_part{i+1}')).remove_columns(['__index_level_0__'])\n",
|
265 |
+
" })\n",
|
266 |
+
" list_dataset.append(ds)\n",
|
267 |
+
" return list_dataset\n",
|
268 |
+
" \n",
|
269 |
+
" else:\n",
|
270 |
+
" ds = DatasetDict(\n",
|
271 |
+
" {\n",
|
272 |
+
" \"train\": Dataset.from_pandas(dataset_train).remove_columns(['__index_level_0__']),\n",
|
273 |
+
" \"test\": Dataset.from_pandas(dataset_test).remove_columns(['__index_level_0__'])\n",
|
274 |
+
" }\n",
|
275 |
+
" )\n",
|
276 |
+
" \n",
|
277 |
+
" def _processing_data(self):\n",
|
278 |
+
" datasets_train, datasets_test = [], []\n",
|
279 |
+
" for dataset_name in self.dataset_name:\n",
|
280 |
+
" ds_tmp = super().get_dataset(dataset_name=dataset_name, local_data_dir=None)\n",
|
281 |
+
" if dataset_name == 'TIGER-Lab/MathInstruct':\n",
|
282 |
+
" df = pd.DataFrame(ds_tmp['train'])\n",
|
283 |
+
" df = df.drop_duplicates(subset=['instruction'])\n",
|
284 |
+
" df = df.drop(['source'], axis=1)\n",
|
285 |
+
" df_train, df_test = train_test_split(df, test_size=0.3, shuffle=True, random_state=42)\n",
|
286 |
+
" \n",
|
287 |
+
" elif dataset_name == \"xDAN2099/lighteval-MATH\":\n",
|
288 |
+
" ds_tmp = ds_tmp.remove_columns(['level', 'type'])\n",
|
289 |
+
" ds_tmp = ds_tmp.rename_column(\"solution\", \"output\")\n",
|
290 |
+
" ds_tmp = ds_tmp.rename_column(\"problem\", \"instruction\")\n",
|
291 |
+
" df_train, df_test = pd.DataFrame(ds_tmp['train']), pd.DataFrame(ds_tmp['test'])\n",
|
292 |
+
" \n",
|
293 |
+
" elif dataset_name == 'gsm8k':\n",
|
294 |
+
" ds_tmp = ds_tmp.rename_column(\"answer\", \"output\")\n",
|
295 |
+
" ds_tmp = ds_tmp.rename_column(\"question\", \"instruction\")\n",
|
296 |
+
" df_train, df_test = pd.DataFrame(ds_tmp['train']), pd.DataFrame(ds_tmp['test'])\n",
|
297 |
+
" \n",
|
298 |
+
" df_train['input'] = [''] * len(df_train)\n",
|
299 |
+
" df_test['input'] = [''] * len(df_test)\n",
|
300 |
+
" datasets_train.append(df_train)\n",
|
301 |
+
" datasets_test.append(df_test)\n",
|
302 |
+
" \n",
|
303 |
+
" dataset_train = pd.concat(datasets_train, ignore_index=True)\n",
|
304 |
+
" dataset_test = pd.concat(datasets_test, ignore_index=True)\n",
|
305 |
+
" dataset = [dataset_train, dataset_test]\n",
|
306 |
+
" self.list_dataset = self.get_split_dataset(dataset)\n",
|
307 |
+
" \n",
|
308 |
+
"\n",
|
309 |
+
"class MedicalDataset(DatasetAbstract):\n",
|
310 |
+
" \n",
|
311 |
+
" def __init__(self):\n",
|
312 |
+
" list_dataset = [\"medalpaca/medical_meadow_medical_flashcards\"]\n",
|
313 |
+
" super().__init__(list_dataset, 'medical')\n",
|
314 |
+
" self._processing_data()\n",
|
315 |
+
" \n",
|
316 |
+
" def _processing_data(self):\n",
|
317 |
+
" datasets = []\n",
|
318 |
+
" for dataset_name in self.dataset_name:\n",
|
319 |
+
" ds = super().get_dataset(dataset_name=dataset_name, local_data_dir=None)['train']\n",
|
320 |
+
" if dataset_name == 'medalpaca/medical_meadow_medical_flashcards':\n",
|
321 |
+
" ds = ds.remove_columns(['instruction'])\n",
|
322 |
+
" ds = ds.rename_column(\"input\", \"instruction\")\n",
|
323 |
+
" \n",
|
324 |
+
" df = pd.DataFrame(ds)\n",
|
325 |
+
" df['input'] = [''] * len(df)\n",
|
326 |
+
" datasets.append(df)\n",
|
327 |
+
" dataset = pd.concat(datasets, ignore_index=True)\n",
|
328 |
+
" self.list_dataset = self.get_split_dataset(dataset)\n",
|
329 |
+
" \n",
|
330 |
+
"class CodeDataset(DatasetAbstract):\n",
|
331 |
+
" \n",
|
332 |
+
" def __init__(self):\n",
|
333 |
+
" list_dataset = [\"lucasmccabe-lmi/CodeAlpaca-20k\", \"WizardLMTeam/WizardLM_evol_instruct_70k\"]\n",
|
334 |
+
" super().__init__(list_dataset, 'code')\n",
|
335 |
+
" self._processing_data()\n",
|
336 |
+
" \n",
|
337 |
+
" def _processing_data(self):\n",
|
338 |
+
" datasets = []\n",
|
339 |
+
" for dataset_name in self.dataset_name:\n",
|
340 |
+
" ds = super().get_dataset(dataset_name=dataset_name, local_data_dir=None)['train']\n",
|
341 |
+
" df = pd.DataFrame(ds)\n",
|
342 |
+
" if dataset_name == 'WizardLMTeam/WizardLM_evol_instruct_70k':\n",
|
343 |
+
" df['input'] = [''] * len(df)\n",
|
344 |
+
" datasets.append(df)\n",
|
345 |
+
" dataset = pd.concat(datasets, ignore_index=True)\n",
|
346 |
+
" self.list_dataset = self.get_split_dataset(dataset)\n",
|
347 |
+
" \n",
|
348 |
+
"client_id_dataset = {\n",
|
349 |
+
" '1': GeneralDataset().list_dataset[0],\n",
|
350 |
+
" '2': GeneralDataset().list_dataset[1],\n",
|
351 |
+
" '3': FinanceDataset().list_dataset[0],\n",
|
352 |
+
" '4': FinanceDataset().list_dataset[1],\n",
|
353 |
+
" '5': MathDataset().list_dataset[0],\n",
|
354 |
+
" '6': MathDataset().list_dataset[1],\n",
|
355 |
+
" '7': MedicalDataset().list_dataset[0],\n",
|
356 |
+
" '8': MedicalDataset().list_dataset[1],\n",
|
357 |
+
" '9': CodeDataset().list_dataset[0],\n",
|
358 |
+
" '10': CodeDataset().list_dataset[1],\n",
|
359 |
+
"}"
|
360 |
+
]
|
361 |
+
},
|
362 |
+
{
|
363 |
+
"cell_type": "code",
|
364 |
+
"execution_count": 27,
|
365 |
+
"id": "85a86706-fd9b-4618-b6e0-b6b2743d1246",
|
366 |
+
"metadata": {
|
367 |
+
"tags": []
|
368 |
+
},
|
369 |
+
"outputs": [
|
370 |
+
{
|
371 |
+
"data": {
|
372 |
+
"application/vnd.jupyter.widget-view+json": {
|
373 |
+
"model_id": "a5ca5da771024ac5acffe25a7e625d6e",
|
374 |
+
"version_major": 2,
|
375 |
+
"version_minor": 0
|
376 |
+
},
|
377 |
+
"text/plain": [
|
378 |
+
"README.md: 0%| | 0.00/709 [00:00<?, ?B/s]"
|
379 |
+
]
|
380 |
+
},
|
381 |
+
"metadata": {},
|
382 |
+
"output_type": "display_data"
|
383 |
+
},
|
384 |
+
{
|
385 |
+
"data": {
|
386 |
+
"application/vnd.jupyter.widget-view+json": {
|
387 |
+
"model_id": "33deb82850224a00912678f33da2f5c7",
|
388 |
+
"version_major": 2,
|
389 |
+
"version_minor": 0
|
390 |
+
},
|
391 |
+
"text/plain": [
|
392 |
+
"Cleaned_date.json: 0%| | 0.00/42.9M [00:00<?, ?B/s]"
|
393 |
+
]
|
394 |
+
},
|
395 |
+
"metadata": {},
|
396 |
+
"output_type": "display_data"
|
397 |
+
},
|
398 |
+
{
|
399 |
+
"data": {
|
400 |
+
"application/vnd.jupyter.widget-view+json": {
|
401 |
+
"model_id": "2161bc54748c4eaf990d85c8dfb1d616",
|
402 |
+
"version_major": 2,
|
403 |
+
"version_minor": 0
|
404 |
+
},
|
405 |
+
"text/plain": [
|
406 |
+
"Generating train split: 0%| | 0/68912 [00:00<?, ? examples/s]"
|
407 |
+
]
|
408 |
+
},
|
409 |
+
"metadata": {},
|
410 |
+
"output_type": "display_data"
|
411 |
+
},
|
412 |
+
{
|
413 |
+
"data": {
|
414 |
+
"application/vnd.jupyter.widget-view+json": {
|
415 |
+
"model_id": "f0abe4e86e9b40e7a64dbd4ba80a1b3d",
|
416 |
+
"version_major": 2,
|
417 |
+
"version_minor": 0
|
418 |
+
},
|
419 |
+
"text/plain": [
|
420 |
+
"README.md: 0%| | 0.00/529 [00:00<?, ?B/s]"
|
421 |
+
]
|
422 |
+
},
|
423 |
+
"metadata": {},
|
424 |
+
"output_type": "display_data"
|
425 |
+
},
|
426 |
+
{
|
427 |
+
"data": {
|
428 |
+
"application/vnd.jupyter.widget-view+json": {
|
429 |
+
"model_id": "4b662fd05288467b90064ce383f5ecea",
|
430 |
+
"version_major": 2,
|
431 |
+
"version_minor": 0
|
432 |
+
},
|
433 |
+
"text/plain": [
|
434 |
+
"(…)-00000-of-00001-dabab110260ac909.parquet: 0%| | 0.00/6.42M [00:00<?, ?B/s]"
|
435 |
+
]
|
436 |
+
},
|
437 |
+
"metadata": {},
|
438 |
+
"output_type": "display_data"
|
439 |
+
},
|
440 |
+
{
|
441 |
+
"data": {
|
442 |
+
"application/vnd.jupyter.widget-view+json": {
|
443 |
+
"model_id": "5ee6a13f60124db4af266207872fc6fb",
|
444 |
+
"version_major": 2,
|
445 |
+
"version_minor": 0
|
446 |
+
},
|
447 |
+
"text/plain": [
|
448 |
+
"Generating train split: 0%| | 0/76772 [00:00<?, ? examples/s]"
|
449 |
+
]
|
450 |
+
},
|
451 |
+
"metadata": {},
|
452 |
+
"output_type": "display_data"
|
453 |
+
},
|
454 |
+
{
|
455 |
+
"name": "stdout",
|
456 |
+
"output_type": "stream",
|
457 |
+
"text": [
|
458 |
+
">> ===== After processing, Dataset has 145684 examples. =====\n",
|
459 |
+
">> ===== After split, Dataset1 has 72842 examples and Dataset2 has 72842 examples. =====\n"
|
460 |
+
]
|
461 |
+
},
|
462 |
+
{
|
463 |
+
"data": {
|
464 |
+
"text/plain": [
|
465 |
+
"[DatasetDict({\n",
|
466 |
+
" train: Dataset({\n",
|
467 |
+
" features: ['instruction', 'input', 'output'],\n",
|
468 |
+
" num_rows: 58273\n",
|
469 |
+
" })\n",
|
470 |
+
" test: Dataset({\n",
|
471 |
+
" features: ['instruction', 'input', 'output'],\n",
|
472 |
+
" num_rows: 14569\n",
|
473 |
+
" })\n",
|
474 |
+
" }),\n",
|
475 |
+
" DatasetDict({\n",
|
476 |
+
" train: Dataset({\n",
|
477 |
+
" features: ['instruction', 'input', 'output'],\n",
|
478 |
+
" num_rows: 58273\n",
|
479 |
+
" })\n",
|
480 |
+
" test: Dataset({\n",
|
481 |
+
" features: ['instruction', 'input', 'output'],\n",
|
482 |
+
" num_rows: 14569\n",
|
483 |
+
" })\n",
|
484 |
+
" })]"
|
485 |
+
]
|
486 |
+
},
|
487 |
+
"execution_count": 27,
|
488 |
+
"metadata": {},
|
489 |
+
"output_type": "execute_result"
|
490 |
+
}
|
491 |
+
],
|
492 |
+
"source": [
|
493 |
+
"FinanceDataset().list_dataset"
|
494 |
+
]
|
495 |
+
},
|
496 |
+
{
|
497 |
+
"cell_type": "code",
|
498 |
+
"execution_count": 24,
|
499 |
+
"id": "78c926d4-9462-4a8a-b50a-ef5ed1e5c25e",
|
500 |
+
"metadata": {
|
501 |
+
"tags": []
|
502 |
+
},
|
503 |
+
"outputs": [
|
504 |
+
{
|
505 |
+
"data": {
|
506 |
+
"text/html": [
|
507 |
+
"<div>\n",
|
508 |
+
"<style scoped>\n",
|
509 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
510 |
+
" vertical-align: middle;\n",
|
511 |
+
" }\n",
|
512 |
+
"\n",
|
513 |
+
" .dataframe tbody tr th {\n",
|
514 |
+
" vertical-align: top;\n",
|
515 |
+
" }\n",
|
516 |
+
"\n",
|
517 |
+
" .dataframe thead th {\n",
|
518 |
+
" text-align: right;\n",
|
519 |
+
" }\n",
|
520 |
+
"</style>\n",
|
521 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
522 |
+
" <thead>\n",
|
523 |
+
" <tr style=\"text-align: right;\">\n",
|
524 |
+
" <th></th>\n",
|
525 |
+
" <th>letter</th>\n",
|
526 |
+
" <th>number</th>\n",
|
527 |
+
" </tr>\n",
|
528 |
+
" </thead>\n",
|
529 |
+
" <tbody>\n",
|
530 |
+
" <tr>\n",
|
531 |
+
" <th>0</th>\n",
|
532 |
+
" <td>a</td>\n",
|
533 |
+
" <td>1</td>\n",
|
534 |
+
" </tr>\n",
|
535 |
+
" <tr>\n",
|
536 |
+
" <th>1</th>\n",
|
537 |
+
" <td>b</td>\n",
|
538 |
+
" <td>2</td>\n",
|
539 |
+
" </tr>\n",
|
540 |
+
" </tbody>\n",
|
541 |
+
"</table>\n",
|
542 |
+
"</div>"
|
543 |
+
],
|
544 |
+
"text/plain": [
|
545 |
+
" letter number\n",
|
546 |
+
"0 a 1\n",
|
547 |
+
"1 b 2"
|
548 |
+
]
|
549 |
+
},
|
550 |
+
"execution_count": 24,
|
551 |
+
"metadata": {},
|
552 |
+
"output_type": "execute_result"
|
553 |
+
}
|
554 |
+
],
|
555 |
+
"source": [
|
556 |
+
"client_id_dataset = {\n",
|
557 |
+
" '1':\n",
|
558 |
+
"}"
|
559 |
+
]
|
560 |
+
},
|
561 |
+
{
|
562 |
+
"cell_type": "code",
|
563 |
+
"execution_count": 25,
|
564 |
+
"id": "9b9babb9-01fe-4b01-be23-d1f334dd450c",
|
565 |
+
"metadata": {
|
566 |
+
"tags": []
|
567 |
+
},
|
568 |
+
"outputs": [
|
569 |
+
{
|
570 |
+
"data": {
|
571 |
+
"text/html": [
|
572 |
+
"<div>\n",
|
573 |
+
"<style scoped>\n",
|
574 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
575 |
+
" vertical-align: middle;\n",
|
576 |
+
" }\n",
|
577 |
+
"\n",
|
578 |
+
" .dataframe tbody tr th {\n",
|
579 |
+
" vertical-align: top;\n",
|
580 |
+
" }\n",
|
581 |
+
"\n",
|
582 |
+
" .dataframe thead th {\n",
|
583 |
+
" text-align: right;\n",
|
584 |
+
" }\n",
|
585 |
+
"</style>\n",
|
586 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
587 |
+
" <thead>\n",
|
588 |
+
" <tr style=\"text-align: right;\">\n",
|
589 |
+
" <th></th>\n",
|
590 |
+
" <th>letter</th>\n",
|
591 |
+
" <th>number</th>\n",
|
592 |
+
" </tr>\n",
|
593 |
+
" </thead>\n",
|
594 |
+
" <tbody>\n",
|
595 |
+
" <tr>\n",
|
596 |
+
" <th>0</th>\n",
|
597 |
+
" <td>a</td>\n",
|
598 |
+
" <td>1</td>\n",
|
599 |
+
" </tr>\n",
|
600 |
+
" <tr>\n",
|
601 |
+
" <th>1</th>\n",
|
602 |
+
" <td>b</td>\n",
|
603 |
+
" <td>2</td>\n",
|
604 |
+
" </tr>\n",
|
605 |
+
" </tbody>\n",
|
606 |
+
"</table>\n",
|
607 |
+
"</div>"
|
608 |
+
],
|
609 |
+
"text/plain": [
|
610 |
+
" letter number\n",
|
611 |
+
"0 a 1\n",
|
612 |
+
"1 b 2"
|
613 |
+
]
|
614 |
+
},
|
615 |
+
"execution_count": 25,
|
616 |
+
"metadata": {},
|
617 |
+
"output_type": "execute_result"
|
618 |
+
}
|
619 |
+
],
|
620 |
+
"source": [
|
621 |
+
"import pandas as pd\n",
|
622 |
+
"df1 = pd.DataFrame([['a', 1], ['b', 2]],\n",
|
623 |
+
" columns=['letter', 'number'])\n",
|
624 |
+
"\n",
|
625 |
+
"df2 = pd.DataFrame([['c', 3], ['d', 4]],\n",
|
626 |
+
" columns=['letter', 'number'])\n",
|
627 |
+
"\n",
|
628 |
+
"df3 = pd.DataFrame([['e', 5], ['f', 6]],\n",
|
629 |
+
" columns=['letter', 'number'])\n",
|
630 |
+
"\n",
|
631 |
+
"pd.concat([df1], ignore_index=True)"
|
632 |
+
]
|
633 |
+
},
|
634 |
+
{
|
635 |
+
"cell_type": "code",
|
636 |
+
"execution_count": 19,
|
637 |
+
"id": "fa021e4a-5e8b-470f-bf6f-db7b9e1f9eb6",
|
638 |
+
"metadata": {
|
639 |
+
"tags": []
|
640 |
+
},
|
641 |
+
"outputs": [
|
642 |
+
{
|
643 |
+
"name": "stdout",
|
644 |
+
"output_type": "stream",
|
645 |
+
"text": [
|
646 |
+
">> ===== After processing, Dataset gsm8k has 2 examples. =====\n",
|
647 |
+
">> ===== Spliting two parts datasets =====\n"
|
648 |
+
]
|
649 |
+
},
|
650 |
+
{
|
651 |
+
"data": {
|
652 |
+
"text/plain": [
|
653 |
+
"DatasetDict({\n",
|
654 |
+
" train: Dataset({\n",
|
655 |
+
" features: ['instruction', 'output'],\n",
|
656 |
+
" num_rows: 7473\n",
|
657 |
+
" })\n",
|
658 |
+
" test: Dataset({\n",
|
659 |
+
" features: ['instruction', 'output'],\n",
|
660 |
+
" num_rows: 1319\n",
|
661 |
+
" })\n",
|
662 |
+
"})"
|
663 |
+
]
|
664 |
+
},
|
665 |
+
"execution_count": 19,
|
666 |
+
"metadata": {},
|
667 |
+
"output_type": "execute_result"
|
668 |
+
}
|
669 |
+
],
|
670 |
+
"source": [
|
671 |
+
"dataset = get_dataset(dataset_name=\"gsm8k\", local_data_dir=None)\n",
|
672 |
+
"datasets = process_sft_dataset(dataset_name=\"gsm8k\", dataset=dataset, dataset_sample=None)\n",
|
673 |
+
"datasets = datasets.rename_column(\"response\", \"output\")\n",
|
674 |
+
"datasets"
|
675 |
+
]
|
676 |
+
},
|
677 |
+
{
|
678 |
+
"cell_type": "code",
|
679 |
+
"execution_count": 18,
|
680 |
+
"id": "c8e02304-5b90-4651-ad36-7feb440d7ea4",
|
681 |
+
"metadata": {
|
682 |
+
"tags": []
|
683 |
+
},
|
684 |
+
"outputs": [],
|
685 |
+
"source": [
|
686 |
+
"def clean_llm_text(text):\n",
|
687 |
+
" \"\"\"\n",
|
688 |
+
" Clean and normalize text from LLM outputs by removing noise and repetitions.\n",
|
689 |
+
" \n",
|
690 |
+
" Args:\n",
|
691 |
+
" text (str): Raw text from LLM prediction\n",
|
692 |
+
" \n",
|
693 |
+
" Returns:\n",
|
694 |
+
" str: Cleaned and normalized text\n",
|
695 |
+
" \"\"\"\n",
|
696 |
+
" import re\n",
|
697 |
+
" \n",
|
698 |
+
" # Remove repetitive patterns (like 'cor cor cor' or 'asesases')\n",
|
699 |
+
" def remove_repetitions(text):\n",
|
700 |
+
" # Split into words\n",
|
701 |
+
" words = text.split()\n",
|
702 |
+
" cleaned_words = []\n",
|
703 |
+
" prev_word = None\n",
|
704 |
+
" repetition_count = 0\n",
|
705 |
+
" \n",
|
706 |
+
" for word in words:\n",
|
707 |
+
" if word == prev_word:\n",
|
708 |
+
" repetition_count += 1\n",
|
709 |
+
" if repetition_count < 2: # Allow up to 2 repetitions for legitimate cases\n",
|
710 |
+
" cleaned_words.append(word)\n",
|
711 |
+
" else:\n",
|
712 |
+
" repetition_count = 0\n",
|
713 |
+
" cleaned_words.append(word)\n",
|
714 |
+
" prev_word = word\n",
|
715 |
+
" \n",
|
716 |
+
" return ' '.join(cleaned_words)\n",
|
717 |
+
" \n",
|
718 |
+
" def remove_repeats(text):\n",
|
719 |
+
" # Remove repeated words\n",
|
720 |
+
" pattern_words = r'\\b(\\w+)(?:\\s+\\1\\b)+'\n",
|
721 |
+
" text = re.sub(pattern_words, r'\\1', text)\n",
|
722 |
+
"\n",
|
723 |
+
" # Remove repeated character patterns (like 'asasas')\n",
|
724 |
+
" pattern_chars = r'(\\w+?)\\1+'\n",
|
725 |
+
" text = re.sub(pattern_chars, r'\\1', text)\n",
|
726 |
+
"\n",
|
727 |
+
" return text\n",
|
728 |
+
" \n",
|
729 |
+
" # Remove excessive punctuation\n",
|
730 |
+
" def normalize_punctuation(text):\n",
|
731 |
+
" # Replace multiple exclamation/question marks with single ones\n",
|
732 |
+
" text = re.sub(r'!+', '!', text)\n",
|
733 |
+
" text = re.sub(r'\\?+', '?', text)\n",
|
734 |
+
" # Remove multiple periods (except for ellipsis)\n",
|
735 |
+
" text = re.sub(r'\\.{4,}', '...', text)\n",
|
736 |
+
" text = text.replace('cor', '').replace('asesa', '')\n",
|
737 |
+
" return text\n",
|
738 |
+
" \n",
|
739 |
+
" # Main cleaning pipeline\n",
|
740 |
+
" cleaned_text = text.strip()\n",
|
741 |
+
" \n",
|
742 |
+
" # Remove common noise patterns\n",
|
743 |
+
" noise_patterns = [\n",
|
744 |
+
" r'\\n+', # Multiple newlines\n",
|
745 |
+
" r'\\s+', # Multiple spaces\n",
|
746 |
+
" r'\\\\n', # Literal \\n\n",
|
747 |
+
" r'\\\\t', # Literal \\t\n",
|
748 |
+
" ]\n",
|
749 |
+
" \n",
|
750 |
+
" for pattern in noise_patterns:\n",
|
751 |
+
" cleaned_text = re.sub(pattern, ' ', cleaned_text)\n",
|
752 |
+
" \n",
|
753 |
+
" # Apply cleaning functions\n",
|
754 |
+
" # cleaned_text = remove_repetitions(cleaned_text)\n",
|
755 |
+
" cleaned_text = remove_repeats(cleaned_text)\n",
|
756 |
+
" cleaned_text = normalize_punctuation(cleaned_text)\n",
|
757 |
+
" cleaned_text = ' '.join(cleaned_text.split()) # Normalize spacing\n",
|
758 |
+
" \n",
|
759 |
+
" return cleaned_text"
|
760 |
+
]
|
761 |
+
},
|
762 |
+
{
|
763 |
+
"cell_type": "code",
|
764 |
+
"execution_count": 20,
|
765 |
+
"id": "9d95fbfe-d824-4247-8628-3bb1cffe9065",
|
766 |
+
"metadata": {
|
767 |
+
"tags": []
|
768 |
+
},
|
769 |
+
"outputs": [
|
770 |
+
{
|
771 |
+
"data": {
|
772 |
+
"text/plain": [
|
773 |
+
"'folowing into their based mamals animals, water animals. 1 Response: animals: , Elephant,Sea Animals: Dolphin, Dolphin'"
|
774 |
+
]
|
775 |
+
},
|
776 |
+
"execution_count": 20,
|
777 |
+
"metadata": {},
|
778 |
+
"output_type": "execute_result"
|
779 |
+
}
|
780 |
+
],
|
781 |
+
"source": [
|
782 |
+
"text_str = f\"\"\"at\\n\\nSea Animals: Whale, Fish cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor', \" cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor One example of a technology that uses artificial intelligence is a virtual personal assistant such as Amazon's Alexa, Apple's Siri, or Google Assistant. These devices use natural language processing and machine learning to understand and respond to user's voice commands, providing assistance in tasks such as setting reminders, playing music, or answering questions. cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor\"\"\"\n",
|
783 |
+
"text_str1 = f\"\"\"following into their based mammals animals, water animals.\\n\\n1 Response:\\n \\n animals: \\n, Elephant,Sea Animals: Dolphin, Dolphin\\n\\nasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesa\"\"\"\n",
|
784 |
+
"clean_llm_text(text_str1)"
|
785 |
+
]
|
786 |
+
},
|
787 |
+
{
|
788 |
+
"cell_type": "code",
|
789 |
+
"execution_count": 10,
|
790 |
+
"id": "c2bd139b-55f9-4848-a3d3-0db55d601e67",
|
791 |
+
"metadata": {
|
792 |
+
"tags": []
|
793 |
+
},
|
794 |
+
"outputs": [
|
795 |
+
{
|
796 |
+
"data": {
|
797 |
+
"text/plain": [
|
798 |
+
"tensor([[1, 2, 3]])"
|
799 |
+
]
|
800 |
+
},
|
801 |
+
"execution_count": 10,
|
802 |
+
"metadata": {},
|
803 |
+
"output_type": "execute_result"
|
804 |
+
}
|
805 |
+
],
|
806 |
+
"source": [
|
807 |
+
"import torch\n",
|
808 |
+
"\n",
|
809 |
+
"x = torch.tensor([1, 2, 3])\n",
|
810 |
+
"torch.unsqueeze(x,dim=0)"
|
811 |
+
]
|
812 |
+
},
|
813 |
+
{
|
814 |
+
"cell_type": "code",
|
815 |
+
"execution_count": 16,
|
816 |
+
"id": "2eb4071a-3d82-4d86-bb0e-a1403a2b1c1e",
|
817 |
+
"metadata": {
|
818 |
+
"tags": []
|
819 |
+
},
|
820 |
+
"outputs": [
|
821 |
+
{
|
822 |
+
"name": "stdout",
|
823 |
+
"output_type": "stream",
|
824 |
+
"text": [
|
825 |
+
"TF-IDF Matrix Shape: (136, 25)\n"
|
826 |
+
]
|
827 |
+
}
|
828 |
+
],
|
829 |
+
"source": []
|
830 |
+
},
|
831 |
+
{
|
832 |
+
"cell_type": "code",
|
833 |
+
"execution_count": null,
|
834 |
+
"id": "4f75512f-13e8-4373-8630-8b5269729c32",
|
835 |
+
"metadata": {},
|
836 |
+
"outputs": [],
|
837 |
+
"source": []
|
838 |
+
}
|
839 |
+
],
|
840 |
+
"metadata": {
|
841 |
+
"kernelspec": {
|
842 |
+
"display_name": "py11torch",
|
843 |
+
"language": "python",
|
844 |
+
"name": "py11torch"
|
845 |
+
},
|
846 |
+
"language_info": {
|
847 |
+
"codemirror_mode": {
|
848 |
+
"name": "ipython",
|
849 |
+
"version": 3
|
850 |
+
},
|
851 |
+
"file_extension": ".py",
|
852 |
+
"mimetype": "text/x-python",
|
853 |
+
"name": "python",
|
854 |
+
"nbconvert_exporter": "python",
|
855 |
+
"pygments_lexer": "ipython3",
|
856 |
+
"version": "3.11.8"
|
857 |
+
}
|
858 |
+
},
|
859 |
+
"nbformat": 4,
|
860 |
+
"nbformat_minor": 5
|
861 |
+
}
|
template_FL/src/fedllm/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
"""flowertune_llm."""
|
template_FL/src/fedllm/client_app.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""flowertune-llm: A Flower / FlowerTune app."""
|
2 |
+
|
3 |
+
import os
|
4 |
+
import warnings
|
5 |
+
from typing import Dict, Tuple
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import wandb
|
9 |
+
import numpy as np
|
10 |
+
from flwr.client import ClientApp, NumPyClient
|
11 |
+
from flwr.common import Context
|
12 |
+
from flwr.common.config import unflatten_dict
|
13 |
+
from flwr.common.typing import NDArrays, Scalar
|
14 |
+
from omegaconf import DictConfig
|
15 |
+
|
16 |
+
from transformers import TrainingArguments, DataCollatorForSeq2Seq, Trainer, EarlyStoppingCallback, BertForSequenceClassification, GenerationConfig
|
17 |
+
|
18 |
+
from trl import SFTTrainer, SFTConfig
|
19 |
+
from deepspeed.profiling.flops_profiler import get_model_profile
|
20 |
+
from deepspeed.accelerator import get_accelerator
|
21 |
+
|
22 |
+
from .trainer import ManualTrainer
|
23 |
+
|
24 |
+
from .dataset import (
|
25 |
+
get_data_collator_and_propt_formatting,
|
26 |
+
load_data,
|
27 |
+
load_data_homo,
|
28 |
+
load_data_hete,
|
29 |
+
replace_keys,
|
30 |
+
)
|
31 |
+
from .models import *
|
32 |
+
|
33 |
+
from .flwr_mods import get_wandb_mod
|
34 |
+
from .metrics import exact_match, f1, get_rouge_score
|
35 |
+
from .utils import clean_output_text
|
36 |
+
from .make_data import Prompter, generate_and_tokenize_prompt
|
37 |
+
|
38 |
+
# Avoid warnings
|
39 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
40 |
+
os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"
|
41 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
42 |
+
|
43 |
+
|
44 |
+
def input_constructor(batch_size, seq_len, tokenizer):
|
45 |
+
fake_seq = ""
|
46 |
+
for _ in range(seq_len - 2): # ignore the two special tokens [CLS] and [SEP]
|
47 |
+
fake_seq += tokenizer.pad_token
|
48 |
+
inputs = tokenizer([fake_seq] * batch_size,
|
49 |
+
padding=True,
|
50 |
+
truncation=True,
|
51 |
+
max_length=seq_len,
|
52 |
+
return_tensors="pt")
|
53 |
+
labels = torch.tensor([1] * batch_size)
|
54 |
+
inputs = dict(inputs)
|
55 |
+
# inputs.update({"labels": torch.unsqueeze(labels,dim=0)})
|
56 |
+
|
57 |
+
# To device
|
58 |
+
inputs = {k: v.to('cuda:0') for k, v in inputs.items()}
|
59 |
+
return inputs
|
60 |
+
|
61 |
+
def convert_to_float(value_str):
|
62 |
+
value, unit = value_str.split()
|
63 |
+
value = float(value)
|
64 |
+
if unit == 'T' or 'T' in unit:
|
65 |
+
return value * 1e12
|
66 |
+
elif unit == 'G' or 'G' in unit:
|
67 |
+
return value * 1e9
|
68 |
+
elif unit == 'M' or 'M' in unit:
|
69 |
+
return value * 1e6
|
70 |
+
elif unit == 'K' or 'K' in unit:
|
71 |
+
return value * 1e3
|
72 |
+
return value
|
73 |
+
|
74 |
+
# pylint: disable=too-many-arguments
|
75 |
+
# pylint: disable=too-many-instance-attributes
|
76 |
+
class FlowerClient(NumPyClient):
|
77 |
+
"""Standard Flower client for CNN training."""
|
78 |
+
|
79 |
+
def __init__(
|
80 |
+
self,
|
81 |
+
model_cfg: DictConfig,
|
82 |
+
train_cfg: DictConfig,
|
83 |
+
mates_args: DictConfig,
|
84 |
+
trainset,
|
85 |
+
valset,
|
86 |
+
num_rounds,
|
87 |
+
): # pylint: disable=too-many-arguments
|
88 |
+
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
89 |
+
self.train_cfg = train_cfg
|
90 |
+
|
91 |
+
self.training_arguments = TrainingArguments(**train_cfg.training_arguments)
|
92 |
+
# self.training_arguments = SFTConfig(**train_cfg.training_arguments, max_seq_length=train_cfg.seq_length)
|
93 |
+
|
94 |
+
self.num_rounds = num_rounds
|
95 |
+
self.trainset = trainset
|
96 |
+
self.valset = valset
|
97 |
+
self.mates_args = mates_args
|
98 |
+
self.holdoutset = None
|
99 |
+
self.refset = None
|
100 |
+
self.data_influence_model = None
|
101 |
+
self.data_influence_tokenizer = None
|
102 |
+
|
103 |
+
# instantiate model
|
104 |
+
self.model, self.tokenizer = get_model(model_cfg)
|
105 |
+
|
106 |
+
if self.mates_args.state:
|
107 |
+
self.data_influence_model, self.data_influence_tokenizer = get_data_influence_model(model_cfg)
|
108 |
+
|
109 |
+
# (
|
110 |
+
# self.data_collator,
|
111 |
+
# self.formatting_prompts_func
|
112 |
+
# ) = get_data_collator_and_propt_formatting(self.tokenizer)
|
113 |
+
|
114 |
+
self.data_collator = DataCollatorForSeq2Seq(
|
115 |
+
self.tokenizer,
|
116 |
+
pad_to_multiple_of=8,
|
117 |
+
return_tensors="pt",
|
118 |
+
padding=True,
|
119 |
+
)
|
120 |
+
|
121 |
+
self.train_on_inputs = self.train_cfg.train_on_inputs
|
122 |
+
|
123 |
+
self._make_dataset()
|
124 |
+
|
125 |
+
def compute_metrics(self, pred):
|
126 |
+
labels_ids = pred['label_ids']
|
127 |
+
pred_ids = pred['predictions']
|
128 |
+
|
129 |
+
# Replace -100 with pad token id in labels
|
130 |
+
labels_ids[labels_ids == -100] = self.tokenizer.pad_token_id
|
131 |
+
|
132 |
+
print(f"Shape of predictions: {np.shape(pred_ids)}")
|
133 |
+
print(f"Shape of labels: {np.shape(labels_ids)}")
|
134 |
+
|
135 |
+
# Decode predictions and labels
|
136 |
+
pred_str = self.tokenizer.batch_decode(
|
137 |
+
pred_ids, skip_special_tokens=True
|
138 |
+
)
|
139 |
+
label_str = self.tokenizer.batch_decode(
|
140 |
+
labels_ids, skip_special_tokens=True
|
141 |
+
)
|
142 |
+
|
143 |
+
# Remove any extra whitespace from the decoded strings
|
144 |
+
pred_str = [s.strip() for s in pred_str]
|
145 |
+
label_str = [s.strip() for s in label_str]
|
146 |
+
|
147 |
+
return {
|
148 |
+
**get_rouge_score(predictions=pred_str, targets=label_str),
|
149 |
+
**f1(predictions=pred_str, targets=label_str),
|
150 |
+
}
|
151 |
+
|
152 |
+
def _make_dataset(self):
|
153 |
+
prompter = Prompter(self.train_cfg.prompt_template_name, self.train_cfg.verbose)
|
154 |
+
tmp_dict = {
|
155 |
+
"prompter": prompter,
|
156 |
+
"seq_length": self.train_cfg.seq_length,
|
157 |
+
"train_on_inputs": self.train_on_inputs,
|
158 |
+
"tokenizer": self.tokenizer,
|
159 |
+
}
|
160 |
+
|
161 |
+
# Process trainset
|
162 |
+
self.trainset = (
|
163 |
+
self.trainset
|
164 |
+
.shuffle()
|
165 |
+
.map(
|
166 |
+
lambda x: generate_and_tokenize_prompt(x, **tmp_dict),
|
167 |
+
num_proc=8,
|
168 |
+
)
|
169 |
+
)
|
170 |
+
|
171 |
+
# Process valset
|
172 |
+
self.valset = (
|
173 |
+
self.valset
|
174 |
+
.shuffle()
|
175 |
+
.map(
|
176 |
+
lambda x: generate_and_tokenize_prompt(x, **tmp_dict),
|
177 |
+
num_proc=8,
|
178 |
+
)
|
179 |
+
)
|
180 |
+
|
181 |
+
# Create holdoutset and refset if state is True
|
182 |
+
if self.mates_args.state:
|
183 |
+
trainset_size = len(self.trainset)
|
184 |
+
|
185 |
+
# Calculate sizes for holdout and reference sets
|
186 |
+
holdout_size = int(trainset_size * self.mates_args.holdout_ratio)
|
187 |
+
ref_size = int(trainset_size * self.mates_args.reference_ratio)
|
188 |
+
|
189 |
+
# Shuffle the trainset to ensure randomness
|
190 |
+
shuffled_indices = list(range(trainset_size))
|
191 |
+
self.trainset = self.trainset.shuffle()
|
192 |
+
|
193 |
+
# Split the dataset
|
194 |
+
holdout_indices = shuffled_indices[:holdout_size]
|
195 |
+
ref_indices = shuffled_indices[holdout_size:holdout_size + ref_size]
|
196 |
+
|
197 |
+
# Create holdoutset and refset
|
198 |
+
self.holdoutset = self.trainset.select(holdout_indices)
|
199 |
+
self.refset = self.trainset.select(ref_indices)
|
200 |
+
|
201 |
+
print(f"Holdoutset size: {len(self.holdoutset)}, Refset size: {len(self.refset)}")
|
202 |
+
|
203 |
+
|
204 |
+
def fit(
|
205 |
+
self, parameters: NDArrays, config: Dict[str, Scalar]
|
206 |
+
) -> Tuple[NDArrays, int, Dict]:
|
207 |
+
"""Implement distributed fit function for a given client."""
|
208 |
+
if self.mates_args.state and int(config["current_round"]) != 1:
|
209 |
+
main_model_params, data_influence_model_params = split_models(parameters)
|
210 |
+
set_parameters(self.model, main_model_params)
|
211 |
+
set_parameters_bert(self.data_influence_model, data_influence_model_params)
|
212 |
+
else:
|
213 |
+
set_parameters(self.model, parameters)
|
214 |
+
|
215 |
+
new_lr = cosine_annealing(
|
216 |
+
int(config["current_round"]),
|
217 |
+
self.num_rounds,
|
218 |
+
self.train_cfg.learning_rate_max,
|
219 |
+
self.train_cfg.learning_rate_min,
|
220 |
+
)
|
221 |
+
|
222 |
+
self.training_arguments.learning_rate = new_lr
|
223 |
+
self.training_arguments.output_dir = config["save_path"]
|
224 |
+
|
225 |
+
# Initialize callback
|
226 |
+
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=5)
|
227 |
+
|
228 |
+
# Construct supervised trainer
|
229 |
+
# trainer = SFTTrainer(
|
230 |
+
# model=self.model,
|
231 |
+
# tokenizer=self.tokenizer,
|
232 |
+
# args=self.training_arguments,
|
233 |
+
# train_dataset=self.trainset,
|
234 |
+
# eval_dataset=self.valset,
|
235 |
+
# formatting_func=self.formatting_prompts_func,
|
236 |
+
# data_collator=self.data_collator,
|
237 |
+
# compute_metrics=self.compute_metrics,
|
238 |
+
# callbacks=[flops_callback, early_stopping_callback]
|
239 |
+
# )
|
240 |
+
|
241 |
+
# # Constuct baseline Trainer
|
242 |
+
# trainer = Trainer(
|
243 |
+
# model=self.model,
|
244 |
+
# train_dataset=self.trainset,
|
245 |
+
# eval_dataset=self.valset.select(range(10)),
|
246 |
+
# args=self.training_arguments,
|
247 |
+
# data_collator=self.data_collator,
|
248 |
+
# compute_metrics=self.compute_metrics,
|
249 |
+
# callbacks=[early_stopping_callback]
|
250 |
+
# )
|
251 |
+
|
252 |
+
trainer = ManualTrainer(
|
253 |
+
model=self.model,
|
254 |
+
tokenizer = self.tokenizer,
|
255 |
+
train_dataset=self.trainset,
|
256 |
+
val_dataset=self.valset.select(range(10)),
|
257 |
+
holdout_dataset=self.holdoutset,
|
258 |
+
reference_dataset=self.refset,
|
259 |
+
args=self.training_arguments,
|
260 |
+
data_collator=self.data_collator,
|
261 |
+
compute_metrics=self.compute_metrics,
|
262 |
+
mates_args=self.mates_args,
|
263 |
+
data_influence_model=self.data_influence_model,
|
264 |
+
data_influence_tokenizer=self.data_influence_tokenizer,
|
265 |
+
)
|
266 |
+
|
267 |
+
# Train the model
|
268 |
+
results = trainer.train()
|
269 |
+
|
270 |
+
if self.mates_args.state:
|
271 |
+
# After training
|
272 |
+
main_model_params = get_parameters(self.model)
|
273 |
+
data_influence_model_params = model_parameters_to_ndarrays(self.data_influence_model)
|
274 |
+
final_model_params = concatenate_models_with_marker(main_model_params, data_influence_model_params)
|
275 |
+
else:
|
276 |
+
final_model_params = get_parameters(self.model)
|
277 |
+
|
278 |
+
# Calculate FLOPs
|
279 |
+
with get_accelerator().device('cuda:0'):
|
280 |
+
batch_size = self.training_arguments.per_device_eval_batch_size
|
281 |
+
seq_len = self.train_cfg.seq_length
|
282 |
+
flops1, macs1, params1 = get_model_profile(
|
283 |
+
self.model,
|
284 |
+
kwargs=input_constructor(batch_size, seq_len, self.tokenizer),
|
285 |
+
print_profile=True,
|
286 |
+
detailed=False,
|
287 |
+
)
|
288 |
+
flops2, macs2, params2 = get_model_profile(
|
289 |
+
self.data_influence_model,
|
290 |
+
kwargs=input_constructor(batch_size, seq_len, self.data_influence_tokenizer),
|
291 |
+
print_profile=True,
|
292 |
+
detailed=False,
|
293 |
+
)
|
294 |
+
flops1_value, flops2_value = convert_to_float(flops1), convert_to_float(flops2)
|
295 |
+
macs1_value, macs2_value = convert_to_float(macs1), convert_to_float(macs2)
|
296 |
+
params1_value, params2_value = convert_to_float(params1), convert_to_float(params2)
|
297 |
+
wandb.log({"total_flops": flops1_value + flops2_value, "macs": macs1_value + macs2_value, "params": params1_value + params2_value}) # wa
|
298 |
+
|
299 |
+
return (
|
300 |
+
final_model_params,
|
301 |
+
len(self.trainset),
|
302 |
+
{"train_loss": results['training_loss'], "flops": flops1_value + flops2_value},
|
303 |
+
)
|
304 |
+
|
305 |
+
|
306 |
+
def client_fn(context: Context) -> FlowerClient:
|
307 |
+
"""Create a Flower client representing a single organization."""
|
308 |
+
partition_id = context.node_config["partition-id"]
|
309 |
+
num_partitions = context.node_config["num-partitions"]
|
310 |
+
num_rounds = context.run_config["num-server-rounds"]
|
311 |
+
cfg = DictConfig(replace_keys(unflatten_dict(context.run_config)))
|
312 |
+
|
313 |
+
# Let's get the client partition
|
314 |
+
if cfg.dataset.type == 'homo':
|
315 |
+
client_set = load_data_homo(partition_id, num_partitions, cfg.dataset.name)
|
316 |
+
else:
|
317 |
+
client_set = load_data_hete(partition_id)
|
318 |
+
|
319 |
+
return FlowerClient(
|
320 |
+
cfg.model,
|
321 |
+
cfg.train,
|
322 |
+
cfg.mates,
|
323 |
+
client_set['train'],
|
324 |
+
client_set['test'],
|
325 |
+
num_rounds,
|
326 |
+
).to_client()
|
327 |
+
|
328 |
+
|
329 |
+
# Flower ClientApp
|
330 |
+
app = ClientApp(
|
331 |
+
client_fn,
|
332 |
+
mods=[
|
333 |
+
get_wandb_mod("FL@CSS25"),
|
334 |
+
],
|
335 |
+
)
|
template_FL/src/fedllm/data_domains.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datasets
|
2 |
+
import pandas as pd
|
3 |
+
from datasets import Dataset, DatasetDict, load_dataset
|
4 |
+
from sklearn.model_selection import train_test_split
|
5 |
+
from functools import partial
|
6 |
+
|
7 |
+
global_test_set_hete = {}
|
8 |
+
|
9 |
+
class DatasetAbstract:
|
10 |
+
def __init__(self, dataset_name: list[str], category: str):
|
11 |
+
self.dataset_name = dataset_name
|
12 |
+
self.metadata = {
|
13 |
+
'domain': category
|
14 |
+
}
|
15 |
+
|
16 |
+
def _processing_data(self):
|
17 |
+
pass
|
18 |
+
|
19 |
+
@classmethod
|
20 |
+
def get_dataset(cls, dataset_name, local_data_dir=None):
|
21 |
+
if dataset_name in ["gsm8k"]:
|
22 |
+
dataset_name = local_data_dir + dataset_name if local_data_dir is not None else dataset_name
|
23 |
+
dataset = load_dataset(dataset_name, name="main")
|
24 |
+
else:
|
25 |
+
dataset_name = local_data_dir + dataset_name if local_data_dir is not None else dataset_name
|
26 |
+
dataset = load_dataset(dataset_name)
|
27 |
+
|
28 |
+
return dataset
|
29 |
+
|
30 |
+
def get_split_dataset(self, dataset):
|
31 |
+
print(f">> ===== After processing, Dataset has {len(dataset)} examples. =====")
|
32 |
+
if len(dataset) > 10000:
|
33 |
+
ds_part1, ds_part2 = train_test_split(
|
34 |
+
dataset, test_size=0.5, shuffle=True, random_state=42
|
35 |
+
)
|
36 |
+
print(f">> ===== After split, Dataset1 has {len(ds_part1)} examples and Dataset2 has {len(ds_part2)} examples. =====")
|
37 |
+
list_dataset = []
|
38 |
+
list_global_set = []
|
39 |
+
for subset in [ds_part1, ds_part2]:
|
40 |
+
train, test = train_test_split(
|
41 |
+
subset, test_size=0.2, shuffle=True, random_state=42
|
42 |
+
)
|
43 |
+
test, global_test = train_test_split(
|
44 |
+
subset, test_size=0.1, shuffle=True, random_state=42
|
45 |
+
)
|
46 |
+
ds = DatasetDict({
|
47 |
+
"train": Dataset.from_pandas(train).remove_columns(['__index_level_0__']),
|
48 |
+
"test": Dataset.from_pandas(test).remove_columns(['__index_level_0__'])
|
49 |
+
})
|
50 |
+
list_dataset.append(ds)
|
51 |
+
list_global_set.append(global_test)
|
52 |
+
|
53 |
+
list_global_set = pd.concat(list_global_set, ignore_index=True)
|
54 |
+
list_global_set = Dataset.from_pandas(list_global_set)
|
55 |
+
return list_dataset, list_global_set
|
56 |
+
|
57 |
+
else:
|
58 |
+
train, test = train_test_split(
|
59 |
+
dataset , test_size=0.2, shuffle=True, random_state=42
|
60 |
+
)
|
61 |
+
test, global_test = train_test_split(
|
62 |
+
subset, test_size=0.1, shuffle=True, random_state=42
|
63 |
+
)
|
64 |
+
ds = DatasetDict(
|
65 |
+
{
|
66 |
+
"train": Dataset.from_pandas(train).remove_columns(['__index_level_0__']),
|
67 |
+
"test": Dataset.from_pandas(test).remove_columns(['__index_level_0__'])
|
68 |
+
}
|
69 |
+
)
|
70 |
+
global_set = Dataset.from_pandas(global_test).remove_columns(['__index_level_0__'])
|
71 |
+
return [ds], global_set
|
72 |
+
|
73 |
+
|
74 |
+
class GeneralDataset(DatasetAbstract):
|
75 |
+
|
76 |
+
def __init__(self):
|
77 |
+
list_dataset = ["tatsu-lab/alpaca", "vicgalle/alpaca-gpt4"]
|
78 |
+
super().__init__(list_dataset, 'general')
|
79 |
+
self._processing_data()
|
80 |
+
|
81 |
+
def _processing_data(self):
|
82 |
+
datasets = []
|
83 |
+
for dataset_name in self.dataset_name:
|
84 |
+
datasets.append(
|
85 |
+
pd.DataFrame(super().get_dataset(dataset_name=dataset_name, local_data_dir=None)['train'])
|
86 |
+
)
|
87 |
+
dataset = pd.concat(datasets, ignore_index=True)
|
88 |
+
self.list_dataset, global_test = self.get_split_dataset(dataset)
|
89 |
+
global global_test_set_hete
|
90 |
+
global_test_set_hete.update(
|
91 |
+
{self.metadata['domain']: global_test}
|
92 |
+
)
|
93 |
+
|
94 |
+
|
95 |
+
class FinanceDataset(DatasetAbstract):
|
96 |
+
|
97 |
+
def __init__(self):
|
98 |
+
list_dataset = ["gbharti/finance-alpaca", "FinGPT/fingpt-sentiment-train"]
|
99 |
+
super().__init__(list_dataset, 'finance')
|
100 |
+
|
101 |
+
self._processing_data()
|
102 |
+
|
103 |
+
def _processing_data(self):
|
104 |
+
datasets = []
|
105 |
+
for dataset_name in self.dataset_name:
|
106 |
+
ds = super().get_dataset(dataset_name=dataset_name, local_data_dir=None)['train']
|
107 |
+
if dataset_name == 'gbharti/finance-alpaca':
|
108 |
+
ds = ds.remove_columns(['text'])
|
109 |
+
df = pd.DataFrame(ds)
|
110 |
+
datasets.append(df)
|
111 |
+
dataset = pd.concat(datasets, ignore_index=True)
|
112 |
+
self.list_dataset, global_test = self.get_split_dataset(dataset)
|
113 |
+
global global_test_set_hete
|
114 |
+
global_test_set_hete.update(
|
115 |
+
{self.metadata['domain']: global_test}
|
116 |
+
)
|
117 |
+
|
118 |
+
|
119 |
+
class MathDataset(DatasetAbstract):
|
120 |
+
|
121 |
+
def __init__(self):
|
122 |
+
list_dataset = ["TIGER-Lab/MathInstruct", "xDAN2099/lighteval-MATH", "gsm8k"]
|
123 |
+
super().__init__(list_dataset, 'math')
|
124 |
+
self._processing_data()
|
125 |
+
|
126 |
+
|
127 |
+
def get_split_dataset(self, dataset):
|
128 |
+
dataset_train, dataset_test = dataset[0], dataset[1]
|
129 |
+
dataset_test, global_test = train_test_split(
|
130 |
+
dataset_test, test_size=0.1, shuffle=True, random_state=42
|
131 |
+
)
|
132 |
+
global_test = Dataset.from_pandas(global_test)
|
133 |
+
print(f">> ===== After processing, Dataset has {len(dataset_train)} examples. =====")
|
134 |
+
if len(dataset_train) > 10000:
|
135 |
+
ds_train_part1, ds_train_part2 = train_test_split(
|
136 |
+
dataset_train, test_size=0.5, shuffle=True, random_state=42
|
137 |
+
)
|
138 |
+
ds_test_part1, ds_test_part2 = train_test_split(
|
139 |
+
dataset_test, test_size=0.5, shuffle=True, random_state=42
|
140 |
+
)
|
141 |
+
print(f">> ===== After split, Dataset1 has {len(ds_train_part1)} examples and Dataset2 has {len(ds_train_part2)} examples. =====")
|
142 |
+
list_dataset = []
|
143 |
+
for i in range(2):
|
144 |
+
ds = DatasetDict({
|
145 |
+
"train": Dataset.from_pandas(eval(f'ds_train_part{i+1}')).remove_columns(['__index_level_0__']),
|
146 |
+
"test": Dataset.from_pandas(eval(f'ds_test_part{i+1}')).remove_columns(['__index_level_0__'])
|
147 |
+
})
|
148 |
+
list_dataset.append(ds)
|
149 |
+
return list_dataset, global_test
|
150 |
+
|
151 |
+
else:
|
152 |
+
ds = DatasetDict(
|
153 |
+
{
|
154 |
+
"train": Dataset.from_pandas(dataset_train).remove_columns(['__index_level_0__']),
|
155 |
+
"test": Dataset.from_pandas(dataset_test).remove_columns(['__index_level_0__'])
|
156 |
+
}
|
157 |
+
)
|
158 |
+
return [ds], global_test
|
159 |
+
|
160 |
+
def _processing_data(self):
|
161 |
+
datasets_train, datasets_test = [], []
|
162 |
+
for dataset_name in self.dataset_name:
|
163 |
+
ds_tmp = super().get_dataset(dataset_name=dataset_name, local_data_dir=None)
|
164 |
+
if dataset_name == 'TIGER-Lab/MathInstruct':
|
165 |
+
df = pd.DataFrame(ds_tmp['train'])
|
166 |
+
df = df.drop_duplicates(subset=['instruction'])
|
167 |
+
df = df.drop(['source'], axis=1)
|
168 |
+
df_train, df_test = train_test_split(df, test_size=0.3, shuffle=True, random_state=42)
|
169 |
+
|
170 |
+
elif dataset_name == "xDAN2099/lighteval-MATH":
|
171 |
+
ds_tmp = ds_tmp.remove_columns(['level', 'type'])
|
172 |
+
ds_tmp = ds_tmp.rename_column("solution", "output")
|
173 |
+
ds_tmp = ds_tmp.rename_column("problem", "instruction")
|
174 |
+
df_train, df_test = pd.DataFrame(ds_tmp['train']), pd.DataFrame(ds_tmp['test'])
|
175 |
+
|
176 |
+
elif dataset_name == 'gsm8k':
|
177 |
+
ds_tmp = ds_tmp.rename_column("answer", "output")
|
178 |
+
ds_tmp = ds_tmp.rename_column("question", "instruction")
|
179 |
+
df_train, df_test = pd.DataFrame(ds_tmp['train']), pd.DataFrame(ds_tmp['test'])
|
180 |
+
|
181 |
+
df_train['input'] = [''] * len(df_train)
|
182 |
+
df_test['input'] = [''] * len(df_test)
|
183 |
+
datasets_train.append(df_train)
|
184 |
+
datasets_test.append(df_test)
|
185 |
+
|
186 |
+
dataset_train = pd.concat(datasets_train, ignore_index=True)
|
187 |
+
dataset_test = pd.concat(datasets_test, ignore_index=True)
|
188 |
+
dataset = [dataset_train, dataset_test]
|
189 |
+
self.list_dataset, global_test = self.get_split_dataset(dataset)
|
190 |
+
global global_test_set_hete
|
191 |
+
global_test_set_hete.update(
|
192 |
+
{self.metadata['domain']: global_test}
|
193 |
+
)
|
194 |
+
|
195 |
+
|
196 |
+
class MedicalDataset(DatasetAbstract):
|
197 |
+
|
198 |
+
def __init__(self):
|
199 |
+
list_dataset = ["medalpaca/medical_meadow_medical_flashcards"]
|
200 |
+
super().__init__(list_dataset, 'medical')
|
201 |
+
self._processing_data()
|
202 |
+
|
203 |
+
def _processing_data(self):
|
204 |
+
datasets = []
|
205 |
+
for dataset_name in self.dataset_name:
|
206 |
+
ds = super().get_dataset(dataset_name=dataset_name, local_data_dir=None)['train']
|
207 |
+
if dataset_name == 'medalpaca/medical_meadow_medical_flashcards':
|
208 |
+
ds = ds.remove_columns(['instruction'])
|
209 |
+
ds = ds.rename_column("input", "instruction")
|
210 |
+
|
211 |
+
df = pd.DataFrame(ds)
|
212 |
+
df['input'] = [''] * len(df)
|
213 |
+
datasets.append(df)
|
214 |
+
dataset = pd.concat(datasets, ignore_index=True)
|
215 |
+
self.list_dataset, global_test = self.get_split_dataset(dataset)
|
216 |
+
global global_test_set_hete
|
217 |
+
global_test_set_hete.update(
|
218 |
+
{self.metadata['domain']: global_test}
|
219 |
+
)
|
220 |
+
|
221 |
+
class CodeDataset(DatasetAbstract):
|
222 |
+
|
223 |
+
def __init__(self):
|
224 |
+
list_dataset = ["lucasmccabe-lmi/CodeAlpaca-20k", "WizardLMTeam/WizardLM_evol_instruct_70k"]
|
225 |
+
super().__init__(list_dataset, 'code')
|
226 |
+
self._processing_data()
|
227 |
+
|
228 |
+
def _processing_data(self):
|
229 |
+
datasets = []
|
230 |
+
for dataset_name in self.dataset_name:
|
231 |
+
ds = super().get_dataset(dataset_name=dataset_name, local_data_dir=None)['train']
|
232 |
+
df = pd.DataFrame(ds)
|
233 |
+
if dataset_name == 'WizardLMTeam/WizardLM_evol_instruct_70k':
|
234 |
+
df['input'] = [''] * len(df)
|
235 |
+
datasets.append(df)
|
236 |
+
dataset = pd.concat(datasets, ignore_index=True)
|
237 |
+
self.list_dataset, global_test = self.get_split_dataset(dataset)
|
238 |
+
global global_test_set_hete
|
239 |
+
global_test_set_hete.update(
|
240 |
+
{self.metadata['domain']: global_test}
|
241 |
+
)
|
242 |
+
|
243 |
+
def release_ds():
|
244 |
+
data_domain = {
|
245 |
+
'general': GeneralDataset().list_dataset,
|
246 |
+
'finance': FinanceDataset().list_dataset,
|
247 |
+
'math': MathDataset().list_dataset,
|
248 |
+
'medical': MedicalDataset().list_dataset,
|
249 |
+
'code': CodeDataset().list_dataset
|
250 |
+
}
|
251 |
+
tmp_dataset = {}
|
252 |
+
k = 0
|
253 |
+
for task in data_domain.keys():
|
254 |
+
tmp_dataset[str(k)] = data_domain[task][0]
|
255 |
+
tmp_dataset[str(k+1)] = data_domain[task][1]
|
256 |
+
k += 2
|
257 |
+
|
258 |
+
return tmp_dataset
|
259 |
+
|
260 |
+
# data_domain = {
|
261 |
+
# 'general': GeneralDataset().list_dataset,
|
262 |
+
# 'finance': FinanceDataset().list_dataset,
|
263 |
+
# 'math': MathDataset().list_dataset,
|
264 |
+
# 'medical': MedicalDataset().list_dataset,
|
265 |
+
# 'code': CodeDataset().list_dataset
|
266 |
+
# }
|
267 |
+
|
268 |
+
# client_id_dataset = {
|
269 |
+
# '0': data_domain['general'][0],
|
270 |
+
# '1': data_domain['general'][1],
|
271 |
+
# '2': data_domain['finance'][0],
|
272 |
+
# '3': data_domain['finance'][1],
|
273 |
+
# '4': data_domain['math'][0],
|
274 |
+
# '5': data_domain['math'][1],
|
275 |
+
# '6': data_domain['medical'][0],
|
276 |
+
# '7': data_domain['medical'][1],
|
277 |
+
# '8': data_domain['code'][0],
|
278 |
+
# '9': data_domain['code'][1],
|
279 |
+
# }
|
280 |
+
|
281 |
+
client_id_dataset = release_ds()
|
template_FL/src/fedllm/dataset.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from trl import DataCollatorForCompletionOnlyLM
|
2 |
+
|
3 |
+
from flwr_datasets.partitioner import IidPartitioner
|
4 |
+
from flwr_datasets import FederatedDataset
|
5 |
+
from datasets import Dataset, DatasetDict
|
6 |
+
from sklearn.model_selection import train_test_split
|
7 |
+
import pandas as pd
|
8 |
+
|
9 |
+
|
10 |
+
FDS = None # Cache FederatedDataset
|
11 |
+
client_id_ds = None
|
12 |
+
global_test_set_homo = None
|
13 |
+
|
14 |
+
def split_train_test(dataset, test_size):
|
15 |
+
# Split the dataset into train and test sets
|
16 |
+
train_data, test_data = train_test_split(dataset.to_pandas(), test_size=test_size, shuffle=True, random_state=42)
|
17 |
+
test_data, global_test = train_test_split(test_data, test_size=0.1, shuffle=True, random_state=42)
|
18 |
+
|
19 |
+
# Convert to Dataset objects
|
20 |
+
train_dataset = Dataset.from_pandas(train_data)
|
21 |
+
test_dataset = Dataset.from_pandas(test_data)
|
22 |
+
global_test = Dataset.from_pandas(global_test)
|
23 |
+
|
24 |
+
# Combine into a DatasetDict
|
25 |
+
datasets_dict = DatasetDict({
|
26 |
+
'train': train_dataset,
|
27 |
+
'test': test_dataset
|
28 |
+
})
|
29 |
+
return datasets_dict
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
def formatting_prompts_func(example):
|
34 |
+
output_texts = []
|
35 |
+
# Constructing a standard Alpaca (https://github.com/tatsu-lab/stanford_alpaca#data-release) prompt
|
36 |
+
mssg = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
|
37 |
+
for i in range(len(example["instruction"])):
|
38 |
+
text = f"{mssg}\n### Instruction:\n{example['instruction'][i]}\n### Response: {example['response'][i]}"
|
39 |
+
output_texts.append(text)
|
40 |
+
return output_texts
|
41 |
+
|
42 |
+
|
43 |
+
def get_data_collator_and_propt_formatting(tokenizer):
|
44 |
+
# From: https://huggingface.co/docs/trl/en/sft_trainer
|
45 |
+
response_template_with_context = "\n### Response:" # alpaca response tag
|
46 |
+
response_template_ids = tokenizer.encode(
|
47 |
+
response_template_with_context, add_special_tokens=False
|
48 |
+
)[2:]
|
49 |
+
data_collator = DataCollatorForCompletionOnlyLM(
|
50 |
+
response_template_ids, tokenizer=tokenizer
|
51 |
+
)
|
52 |
+
|
53 |
+
return data_collator, formatting_prompts_func
|
54 |
+
|
55 |
+
|
56 |
+
def load_data(partition_id: int, num_partitions: int, dataset_name: str):
|
57 |
+
"""Load partition data."""
|
58 |
+
# Only initialize `FederatedDataset` once
|
59 |
+
global FDS
|
60 |
+
if FDS is None:
|
61 |
+
partitioner = IidPartitioner(num_partitions=num_partitions)
|
62 |
+
FDS = FederatedDataset(
|
63 |
+
dataset=dataset_name,
|
64 |
+
partitioners={"train": partitioner},
|
65 |
+
)
|
66 |
+
print(f"<---- Load client {partition_id} --->")
|
67 |
+
client_trainset = FDS.load_partition(partition_id, "train")
|
68 |
+
client_trainset = client_trainset.rename_column("output", "response")
|
69 |
+
print(client_trainset)
|
70 |
+
return client_trainset
|
71 |
+
|
72 |
+
def load_data_homo(partition_id: int, num_partitions: int, dataset_name: str):
|
73 |
+
"""Load partition data."""
|
74 |
+
# Only initialize `FederatedDataset` once
|
75 |
+
global FDS
|
76 |
+
global global_test_set_homo
|
77 |
+
if FDS is None:
|
78 |
+
partitioner = IidPartitioner(num_partitions=num_partitions)
|
79 |
+
FDS = FederatedDataset(
|
80 |
+
dataset=dataset_name,
|
81 |
+
partitioners={"train": partitioner},
|
82 |
+
)
|
83 |
+
# list_ds = []
|
84 |
+
# for cid in range(0,num_partitions):
|
85 |
+
# tmp_set = FDS.load_partition(cid, "train")
|
86 |
+
# list_ds.append(
|
87 |
+
# pd.DataFrame(tmp_set)
|
88 |
+
# )
|
89 |
+
# list_ds = pd.concat(list_ds, ignore_index=True)
|
90 |
+
# _, global_test_set_homo = train_test_split(
|
91 |
+
# list_ds, test_size=0.1, shuffle=True, random_state=42
|
92 |
+
# )
|
93 |
+
# global_test_set_homo = Dataset.from_pandas(global_test_set_homo).remove_columns(['__index_level_0__'])
|
94 |
+
|
95 |
+
print(f"<---- Load client {partition_id} --->")
|
96 |
+
client_trainset = FDS.load_partition(partition_id, "train")
|
97 |
+
# client_trainset = client_trainset.rename_column("output", "response")
|
98 |
+
client_set = split_train_test(client_trainset, test_size=0.2)
|
99 |
+
return client_set
|
100 |
+
|
101 |
+
|
102 |
+
def load_data_hete(partition_id: int):
|
103 |
+
"""Load partition data heterogeneous"""
|
104 |
+
global client_id_ds
|
105 |
+
if client_id_ds is None:
|
106 |
+
from .data_domains import client_id_dataset
|
107 |
+
client_id_ds = client_id_dataset
|
108 |
+
print(f"<---- Load client {partition_id} --->")
|
109 |
+
client_set = client_id_ds[str(partition_id)]
|
110 |
+
return client_set
|
111 |
+
|
112 |
+
|
113 |
+
def replace_keys(input_dict, match="-", target="_"):
|
114 |
+
"""Recursively replace match string with target string in dictionary keys."""
|
115 |
+
new_dict = {}
|
116 |
+
for key, value in input_dict.items():
|
117 |
+
new_key = key.replace(match, target)
|
118 |
+
if isinstance(value, dict):
|
119 |
+
new_dict[new_key] = replace_keys(value, match, target)
|
120 |
+
else:
|
121 |
+
new_dict[new_key] = value
|
122 |
+
return new_dict
|
template_FL/src/fedllm/flwr_mods.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from flwr.common import Context, Message, MessageType, ConfigsRecord
|
2 |
+
from flwr.client.typing import ClientAppCallable
|
3 |
+
from typing import Callable
|
4 |
+
import wandb
|
5 |
+
import time
|
6 |
+
from .myfedavg import client_id_idx
|
7 |
+
|
8 |
+
|
9 |
+
# Define type alias for Mod
|
10 |
+
Mod = Callable[[Message, Context, ClientAppCallable], Message]
|
11 |
+
|
12 |
+
def get_wandb_mod(name: str) -> Mod:
|
13 |
+
# Keep track of active runs
|
14 |
+
active_run: Optional[wandb.Run] = None
|
15 |
+
|
16 |
+
def wandb_mod(msg: Message, context: Context, app: ClientAppCallable) -> Message:
|
17 |
+
nonlocal active_run
|
18 |
+
server_round = int(msg.metadata.group_id)
|
19 |
+
run_id = msg.metadata.run_id
|
20 |
+
group_name = f"Run ID: {run_id}"
|
21 |
+
node_id = str(msg.metadata.dst_node_id)
|
22 |
+
run_name = f"Client ID: {client_id_idx[node_id]}"
|
23 |
+
|
24 |
+
wandb.init(
|
25 |
+
project=name,
|
26 |
+
group=group_name,
|
27 |
+
name=run_name,
|
28 |
+
id=f"{run_id}_{client_id_idx[node_id]}",
|
29 |
+
resume="allow",
|
30 |
+
reinit=True,
|
31 |
+
# settings=wandb.Settings(start_method="thread")
|
32 |
+
)
|
33 |
+
|
34 |
+
start_time = time.time()
|
35 |
+
reply = app(msg, context)
|
36 |
+
|
37 |
+
if reply.metadata.message_type == MessageType.TRAIN and reply.has_content():
|
38 |
+
|
39 |
+
time_diff = time.time() - start_time
|
40 |
+
metrics = reply.content.configs_records
|
41 |
+
results_to_log = dict(metrics.get("fitres.metrics", ConfigsRecord()))
|
42 |
+
results_to_log["fit_time"] = time_diff
|
43 |
+
|
44 |
+
wandb.log(results_to_log, step=int(server_round), commit=True)
|
45 |
+
|
46 |
+
return reply
|
47 |
+
|
48 |
+
return wandb_mod
|
49 |
+
|
template_FL/src/fedllm/make_data.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import os.path as osp
|
4 |
+
from typing import Union
|
5 |
+
|
6 |
+
class Prompter(object):
|
7 |
+
__slots__ = ("template", "_verbose")
|
8 |
+
|
9 |
+
def __init__(self, template_name: str = "", verbose: bool = False):
|
10 |
+
self._verbose = verbose
|
11 |
+
if not template_name:
|
12 |
+
# Enforece the default here, so the constructor can be called with '' and will not break.
|
13 |
+
template_name = "alpaca"
|
14 |
+
file_name = osp.join(
|
15 |
+
os.getcwd(), "fedllm/templates", f"{template_name}.json"
|
16 |
+
)
|
17 |
+
|
18 |
+
if not osp.exists(file_name):
|
19 |
+
raise ValueError(f"Can't read {file_name}")
|
20 |
+
with open(file=file_name) as fp:
|
21 |
+
self.template = json.load(fp)
|
22 |
+
if self._verbose:
|
23 |
+
print(
|
24 |
+
f"Using prompt template {template_name}: {self.template['description']}"
|
25 |
+
)
|
26 |
+
|
27 |
+
def generate_prompt(
|
28 |
+
self,
|
29 |
+
instruction: str,
|
30 |
+
input: Union[None, str] = None,
|
31 |
+
label: Union[None, str] = None,
|
32 |
+
) -> str:
|
33 |
+
# returns the full prompt from instruction and optional input
|
34 |
+
# if a label (=response, =output) is provided, it's also appended.
|
35 |
+
if input:
|
36 |
+
res = self.template["prompt_input"].format(
|
37 |
+
instruction=instruction,
|
38 |
+
input=input,
|
39 |
+
)
|
40 |
+
else:
|
41 |
+
res = self.template["prompt_no_input"].format(
|
42 |
+
instruction=instruction,
|
43 |
+
)
|
44 |
+
if label:
|
45 |
+
res = f"{res}{label}"
|
46 |
+
if self._verbose:
|
47 |
+
print(res)
|
48 |
+
return res
|
49 |
+
|
50 |
+
def get_reponse(self, output: str) -> str:
|
51 |
+
return output.split(self.template["response_split"])[1].strip()
|
52 |
+
|
53 |
+
|
54 |
+
def tokenize(tokenizer, prompt, cutoff_len=512, add_eos_token=True):
|
55 |
+
result = tokenizer(
|
56 |
+
prompt,
|
57 |
+
truncation=True,
|
58 |
+
max_length=cutoff_len,
|
59 |
+
padding=False,
|
60 |
+
return_tensors=None,
|
61 |
+
)
|
62 |
+
if (
|
63 |
+
result["input_ids"][-1] != tokenizer.eos_token_id
|
64 |
+
and len(result["input_ids"]) < cutoff_len
|
65 |
+
and add_eos_token
|
66 |
+
):
|
67 |
+
result["input_ids"].append(tokenizer.eos_token_id)
|
68 |
+
result["attention_mask"].append(1)
|
69 |
+
|
70 |
+
result["labels"] = result["input_ids"].copy()
|
71 |
+
|
72 |
+
return result
|
73 |
+
|
74 |
+
|
75 |
+
def generate_and_tokenize_prompt(data_point, **kwargs):
|
76 |
+
full_prompt = kwargs["prompter"].generate_prompt(
|
77 |
+
data_point["instruction"],
|
78 |
+
data_point["input"],
|
79 |
+
data_point["output"],
|
80 |
+
)
|
81 |
+
tokenized_full_prompt = tokenize(
|
82 |
+
kwargs["tokenizer"],
|
83 |
+
full_prompt,
|
84 |
+
cutoff_len=kwargs["seq_length"],
|
85 |
+
add_eos_token=True,
|
86 |
+
)
|
87 |
+
if not kwargs["train_on_inputs"]:
|
88 |
+
user_prompt = kwargs["prompter"].generate_prompt(
|
89 |
+
data_point["instruction"], data_point["input"]
|
90 |
+
)
|
91 |
+
tokenized_user_prompt = kwargs["tokenizer"](
|
92 |
+
user_prompt, add_eos_token=False
|
93 |
+
)
|
94 |
+
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
95 |
+
|
96 |
+
tokenized_full_prompt["labels"] = [
|
97 |
+
-100
|
98 |
+
] * user_prompt_len + tokenized_full_prompt["labels"][
|
99 |
+
user_prompt_len:
|
100 |
+
] # could be sped up, probably
|
101 |
+
return tokenized_full_prompt
|
template_FL/src/fedllm/metrics.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import evaluate
|
3 |
+
from rouge_score import rouge_scorer
|
4 |
+
import numpy as np
|
5 |
+
import copy
|
6 |
+
from collections import OrderedDict, Counter
|
7 |
+
from .utils import clean_output_text
|
8 |
+
|
9 |
+
def get_answer(text):
|
10 |
+
# text = text.lower()
|
11 |
+
label = text.split('Response:')[-1].strip()
|
12 |
+
return label
|
13 |
+
|
14 |
+
def check_data_state(preds, targets):
|
15 |
+
assert len(preds) == len(targets)
|
16 |
+
|
17 |
+
def get_rouge_score(predictions, targets):
|
18 |
+
check_data_state(predictions, targets)
|
19 |
+
rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL', 'rougeLsum'], use_stemmer=True)
|
20 |
+
scores = {
|
21 |
+
'rouge1': 0.0,
|
22 |
+
'rouge2': 0.0,
|
23 |
+
'rougeL': 0.0,
|
24 |
+
'rougeLsum': 0.0
|
25 |
+
}
|
26 |
+
for prediction, target in zip(predictions, targets):
|
27 |
+
prediction = get_answer(clean_output_text(prediction))
|
28 |
+
target = get_answer(clean_output_text(target))
|
29 |
+
rouge_output = rouge.score(prediction=prediction, target=target)
|
30 |
+
scores['rouge1'] += round(rouge_output["rouge1"].fmeasure, 4)
|
31 |
+
scores['rouge2'] += round(rouge_output["rouge2"].fmeasure, 4)
|
32 |
+
scores['rougeL'] += round(rouge_output["rougeL"].fmeasure, 4)
|
33 |
+
scores['rougeLsum'] += round(rouge_output["rougeLsum"].fmeasure, 4)
|
34 |
+
|
35 |
+
|
36 |
+
scores['rouge1'] /= len(predictions)
|
37 |
+
scores['rouge2'] /= len(predictions)
|
38 |
+
scores['rougeL'] /= len(predictions)
|
39 |
+
scores['rougeLsum'] /= len(predictions)
|
40 |
+
return scores
|
41 |
+
|
42 |
+
def exact_match(predictions, targets):
|
43 |
+
check_data_state(predictions, targets)
|
44 |
+
predictions = [get_answer(clean_output_text(prediction)) for prediction in predictions]
|
45 |
+
targets = [get_answer(clean_output_text(target)) for target in targets]
|
46 |
+
|
47 |
+
preds, targets = np.asarray(predictions, dtype="<U16"), np.asarray(targets, dtype="<U16")
|
48 |
+
|
49 |
+
# print(preds, targets)
|
50 |
+
return {"exact_match": np.sum(preds == targets) / preds.size}
|
51 |
+
|
52 |
+
def _f1_score(prediction, target):
|
53 |
+
prediction_tokens = prediction.split()
|
54 |
+
target_tokens = target.split()
|
55 |
+
common = Counter(prediction_tokens) & Counter(target_tokens)
|
56 |
+
num_same = sum(common.values())
|
57 |
+
if num_same == 0:
|
58 |
+
return 0
|
59 |
+
precision = 1.0 * num_same / len(prediction_tokens)
|
60 |
+
recall = 1.0 * num_same / len(target_tokens)
|
61 |
+
f1 = (2 * precision * recall) / (precision + recall)
|
62 |
+
return f1
|
63 |
+
|
64 |
+
def f1(predictions, targets):
|
65 |
+
check_data_state(predictions, targets)
|
66 |
+
f1_score = 0.0
|
67 |
+
for prediction, target in zip(predictions, targets):
|
68 |
+
prediction = get_answer(clean_output_text(prediction))
|
69 |
+
target = get_answer(clean_output_text(target))
|
70 |
+
f1_score += _f1_score(prediction=prediction, target=target)
|
71 |
+
|
72 |
+
f1_score /= len(predictions)
|
73 |
+
return {'f1': f1_score}
|
template_FL/src/fedllm/models.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from omegaconf import DictConfig
|
6 |
+
from collections import OrderedDict
|
7 |
+
from peft import (
|
8 |
+
LoraConfig,
|
9 |
+
get_peft_model,
|
10 |
+
get_peft_model_state_dict,
|
11 |
+
set_peft_model_state_dict,
|
12 |
+
)
|
13 |
+
from peft.utils import prepare_model_for_kbit_training
|
14 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainerCallback, BertForSequenceClassification
|
15 |
+
|
16 |
+
from flwr.common.typing import NDArrays
|
17 |
+
from transformers.trainer_callback import TrainerControl, TrainerState
|
18 |
+
from transformers.training_args import TrainingArguments
|
19 |
+
from thop import profile
|
20 |
+
import wandb
|
21 |
+
from typing import Dict, List
|
22 |
+
import copy
|
23 |
+
import time
|
24 |
+
import numpy as np
|
25 |
+
|
26 |
+
|
27 |
+
def cosine_annealing(
|
28 |
+
current_round: int,
|
29 |
+
total_round: int,
|
30 |
+
lrate_max: float = 0.001,
|
31 |
+
lrate_min: float = 0.0,
|
32 |
+
) -> float:
|
33 |
+
"""Implement cosine annealing learning rate schedule."""
|
34 |
+
|
35 |
+
cos_inner = math.pi * current_round / total_round
|
36 |
+
return lrate_min + 0.5 * (lrate_max - lrate_min) * (1 + math.cos(cos_inner))
|
37 |
+
|
38 |
+
|
39 |
+
def get_model(model_cfg: DictConfig):
|
40 |
+
"""Load model with appropriate quantization config and other optimizations.
|
41 |
+
|
42 |
+
Please refer to this example for `peft + BitsAndBytes`:
|
43 |
+
https://github.com/huggingface/peft/blob/main/examples/fp4_finetuning/finetune_fp4_opt_bnb_peft.py
|
44 |
+
"""
|
45 |
+
use_cuda = torch.cuda.is_available()
|
46 |
+
device_map = torch.device("cuda:0" if use_cuda else "cpu")
|
47 |
+
if model_cfg.quantization == 4:
|
48 |
+
quantization_config = BitsAndBytesConfig(
|
49 |
+
load_in_4bit=True,
|
50 |
+
bnb_4bit_use_double_quant=True,
|
51 |
+
bnb_4bit_quant_type="nf4",
|
52 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
53 |
+
)
|
54 |
+
elif model_cfg.quantization == 8:
|
55 |
+
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
|
56 |
+
else:
|
57 |
+
raise ValueError(
|
58 |
+
f"Use 4-bit or 8-bit quantization. You passed: {model_cfg.quantization}/"
|
59 |
+
)
|
60 |
+
|
61 |
+
model = AutoModelForCausalLM.from_pretrained(
|
62 |
+
model_cfg.name,
|
63 |
+
quantization_config=quantization_config,
|
64 |
+
# torch_dtype=torch.bfloat16,
|
65 |
+
attn_implementation=(
|
66 |
+
"flash_attention_2" if model_cfg.flash_attention else "eager"
|
67 |
+
),
|
68 |
+
).to(device_map)
|
69 |
+
|
70 |
+
if use_cuda:
|
71 |
+
model = prepare_model_for_kbit_training(
|
72 |
+
model, use_gradient_checkpointing=model_cfg.gradient_checkpointing
|
73 |
+
)
|
74 |
+
|
75 |
+
|
76 |
+
# Get tokenizer
|
77 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
78 |
+
model_cfg.name, use_fast=True, padding_side="right"
|
79 |
+
)
|
80 |
+
tokenizer.pad_token = tokenizer.eos_token
|
81 |
+
|
82 |
+
peft_config = LoraConfig(
|
83 |
+
r=model_cfg.lora.lora_r,
|
84 |
+
lora_alpha=model_cfg.lora.lora_alpha,
|
85 |
+
lora_dropout=model_cfg.lora.lora_dropout,
|
86 |
+
target_modules=model_cfg.lora.lora_target_modules.split(", "),
|
87 |
+
bias="none",
|
88 |
+
task_type="CAUSAL_LM",
|
89 |
+
)
|
90 |
+
|
91 |
+
return get_peft_model(model, peft_config), tokenizer
|
92 |
+
|
93 |
+
def get_data_influence_model(model_cfg: DictConfig):
|
94 |
+
use_cuda = torch.cuda.is_available()
|
95 |
+
device_map = torch.device("cuda:0" if use_cuda else "cpu")
|
96 |
+
|
97 |
+
# Load model with num_labels=1
|
98 |
+
model = BertForSequenceClassification.from_pretrained(
|
99 |
+
"bert-base-uncased",
|
100 |
+
num_labels=1, # Set number of labels to 1 for regression or single-class tasks
|
101 |
+
).to(device_map)
|
102 |
+
|
103 |
+
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
|
104 |
+
|
105 |
+
if use_cuda:
|
106 |
+
model = prepare_model_for_kbit_training(
|
107 |
+
model, use_gradient_checkpointing=model_cfg.gradient_checkpointing
|
108 |
+
)
|
109 |
+
|
110 |
+
return model, tokenizer
|
111 |
+
|
112 |
+
|
113 |
+
def set_parameters(model, parameters: NDArrays) -> None:
|
114 |
+
"""Change the parameters of the model using the given ones."""
|
115 |
+
peft_state_dict_keys = get_peft_model_state_dict(model).keys()
|
116 |
+
params_dict = zip(peft_state_dict_keys, parameters)
|
117 |
+
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
|
118 |
+
set_peft_model_state_dict(model, state_dict)
|
119 |
+
|
120 |
+
|
121 |
+
def get_parameters(model) -> NDArrays:
|
122 |
+
"""Return the parameters of the current net."""
|
123 |
+
state_dict = get_peft_model_state_dict(model)
|
124 |
+
return [val.cpu().numpy() for _, val in state_dict.items()]
|
125 |
+
|
126 |
+
def model_parameters_to_ndarrays(model):
|
127 |
+
"""
|
128 |
+
Convert the parameters of a HuggingFace model into a list of NDArrays.
|
129 |
+
|
130 |
+
Args:
|
131 |
+
model (torch.nn.Module): The HuggingFace model.
|
132 |
+
|
133 |
+
Returns:
|
134 |
+
list[NDArrays]: A list of NumPy arrays representing the model's parameters.
|
135 |
+
"""
|
136 |
+
ndarrays = []
|
137 |
+
for param_tensor in model.state_dict().values():
|
138 |
+
# Convert PyTorch tensor to NumPy array
|
139 |
+
ndarrays.append(param_tensor.cpu().numpy())
|
140 |
+
return ndarrays
|
141 |
+
|
142 |
+
|
143 |
+
def concatenate_models_with_marker(main_model_params: list[NDArrays],
|
144 |
+
data_influence_model_params: list[NDArrays],
|
145 |
+
marker_value: float = np.nan) -> list[NDArrays]:
|
146 |
+
"""
|
147 |
+
Concatenate two models' parameters with a unique marker.
|
148 |
+
|
149 |
+
Args:
|
150 |
+
main_model_params (list[NDArrays]): Parameters of the main model as NDArrays.
|
151 |
+
data_influence_model_params (list[NDArrays]): Parameters of the data influence model as NDArrays.
|
152 |
+
marker_value (float): A unique marker value to separate the two models.
|
153 |
+
|
154 |
+
Returns:
|
155 |
+
list[NDArrays]: A single list of NDArrays with the unique marker separating the models.
|
156 |
+
"""
|
157 |
+
marker = np.array([marker_value]) # Unique marker
|
158 |
+
concatenated_params = main_model_params + [marker] + data_influence_model_params
|
159 |
+
return concatenated_params
|
160 |
+
|
161 |
+
|
162 |
+
def split_models(concatenated_model: list[NDArrays]) -> tuple[list[NDArrays], list[NDArrays]]:
|
163 |
+
"""Split the concatenated model back into main and data influence models."""
|
164 |
+
# Find the marker's index
|
165 |
+
marker_index = next(
|
166 |
+
(i for i, param in enumerate(concatenated_model) if np.isnan(param).all()),
|
167 |
+
-1,
|
168 |
+
)
|
169 |
+
if marker_index == -1:
|
170 |
+
raise ValueError("Marker not found in the concatenated model parameters.")
|
171 |
+
|
172 |
+
main_model = concatenated_model[:marker_index]
|
173 |
+
data_influence_model = concatenated_model[marker_index + 1 :]
|
174 |
+
return main_model, data_influence_model
|
175 |
+
|
176 |
+
|
177 |
+
def set_parameters_bert(model: BertForSequenceClassification, parameters: list[NDArrays]) -> None:
|
178 |
+
"""
|
179 |
+
Set the parameters of a BertForSequenceClassification model using the given ones.
|
180 |
+
|
181 |
+
Args:
|
182 |
+
model (BertForSequenceClassification): The model whose parameters need to be updated.
|
183 |
+
parameters (list[NDArrays]): A list of NumPy arrays representing the parameters.
|
184 |
+
"""
|
185 |
+
# Get the state_dict keys from the model
|
186 |
+
state_dict_keys = model.state_dict().keys()
|
187 |
+
|
188 |
+
# Ensure the number of parameters matches the model's state_dict
|
189 |
+
if len(parameters) != len(state_dict_keys):
|
190 |
+
raise ValueError(
|
191 |
+
f"Number of parameters ({len(parameters)}) does not match "
|
192 |
+
f"the number of state_dict keys ({len(state_dict_keys)})."
|
193 |
+
)
|
194 |
+
|
195 |
+
# Create an OrderedDict to update the model
|
196 |
+
params_dict = zip(state_dict_keys, parameters)
|
197 |
+
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
|
198 |
+
|
199 |
+
# Load the updated state_dict into the model
|
200 |
+
model.load_state_dict(state_dict)
|
template_FL/src/fedllm/myaggregation.py
ADDED
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Aggregation functions for strategy implementations."""
|
16 |
+
# mypy: disallow_untyped_calls=False
|
17 |
+
|
18 |
+
from functools import partial, reduce
|
19 |
+
from typing import Any, Callable, Union
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
|
23 |
+
from flwr.common import FitRes, NDArray, NDArrays, parameters_to_ndarrays
|
24 |
+
from flwr.server.client_proxy import ClientProxy
|
25 |
+
|
26 |
+
from .models import split_models
|
27 |
+
|
28 |
+
|
29 |
+
def aggregate(results: list[tuple[NDArrays, int]]) -> NDArrays:
|
30 |
+
"""Compute weighted average."""
|
31 |
+
# Calculate the total number of examples used during training
|
32 |
+
num_examples_total = sum(num_examples for (_, num_examples) in results)
|
33 |
+
|
34 |
+
# Create a list of weights, each multiplied by the related number of examples
|
35 |
+
weighted_weights = [
|
36 |
+
[layer * num_examples for layer in weights] for weights, num_examples in results
|
37 |
+
]
|
38 |
+
|
39 |
+
# Compute average weights of each layer
|
40 |
+
weights_prime: NDArrays = [
|
41 |
+
reduce(np.add, layer_updates) / num_examples_total
|
42 |
+
for layer_updates in zip(*weighted_weights)
|
43 |
+
]
|
44 |
+
return weights_prime
|
45 |
+
|
46 |
+
|
47 |
+
def aggregate_inplace(results: list[tuple[ClientProxy, FitRes]]) -> NDArrays:
|
48 |
+
"""Compute in-place weighted average."""
|
49 |
+
# Count total examples
|
50 |
+
num_examples_total = sum(fit_res.num_examples for (_, fit_res) in results)
|
51 |
+
|
52 |
+
# Compute scaling factors for each result
|
53 |
+
scaling_factors = np.asarray(
|
54 |
+
[fit_res.num_examples / num_examples_total for _, fit_res in results]
|
55 |
+
)
|
56 |
+
|
57 |
+
def _try_inplace(
|
58 |
+
x: NDArray, y: Union[NDArray, np.float64], np_binary_op: np.ufunc
|
59 |
+
) -> NDArray:
|
60 |
+
return ( # type: ignore[no-any-return]
|
61 |
+
np_binary_op(x, y, out=x)
|
62 |
+
if np.can_cast(y, x.dtype, casting="same_kind")
|
63 |
+
else np_binary_op(x, np.array(y, x.dtype), out=x)
|
64 |
+
)
|
65 |
+
|
66 |
+
# Let's do in-place aggregation
|
67 |
+
# Get first result, then add up each other
|
68 |
+
# print(f"Param : {results[0][1].parameters}. Type: {type(results[0][1].parameters)}")
|
69 |
+
params = [
|
70 |
+
_try_inplace(x, scaling_factors[0], np_binary_op=np.multiply)
|
71 |
+
for x in parameters_to_ndarrays(results[0][1].parameters)
|
72 |
+
]
|
73 |
+
|
74 |
+
for i, (_, fit_res) in enumerate(results[1:], start=1):
|
75 |
+
res = (
|
76 |
+
_try_inplace(x, scaling_factors[i], np_binary_op=np.multiply)
|
77 |
+
for x in parameters_to_ndarrays(fit_res.parameters)
|
78 |
+
)
|
79 |
+
params = [
|
80 |
+
reduce(partial(_try_inplace, np_binary_op=np.add), layer_updates)
|
81 |
+
for layer_updates in zip(params, res)
|
82 |
+
]
|
83 |
+
|
84 |
+
return params
|
85 |
+
|
86 |
+
def aggregate_inplace_mates(results: list[tuple[ClientProxy, FitRes]]) -> NDArrays:
|
87 |
+
"""Aggregate main model and data influence model separately."""
|
88 |
+
num_examples_total = sum(fit_res.num_examples for _, fit_res in results)
|
89 |
+
scaling_factors = [
|
90 |
+
fit_res.num_examples / num_examples_total for _, fit_res in results
|
91 |
+
]
|
92 |
+
|
93 |
+
aggregated_main_model = None
|
94 |
+
aggregated_data_influence_model = None
|
95 |
+
|
96 |
+
for i, (_, fit_res) in enumerate(results):
|
97 |
+
# Convert parameters to NDArrays and split into main and data influence models
|
98 |
+
concatenated_params = parameters_to_ndarrays(fit_res.parameters)
|
99 |
+
main_model, data_influence_model = split_models(concatenated_params)
|
100 |
+
|
101 |
+
# Scale the models by the scaling factor
|
102 |
+
scaled_main_model = [x * scaling_factors[i] for x in main_model]
|
103 |
+
scaled_data_influence_model = [x * scaling_factors[i] for x in data_influence_model]
|
104 |
+
|
105 |
+
# Aggregate in-place
|
106 |
+
if aggregated_main_model is None:
|
107 |
+
aggregated_main_model = scaled_main_model
|
108 |
+
aggregated_data_influence_model = scaled_data_influence_model
|
109 |
+
else:
|
110 |
+
aggregated_main_model = [
|
111 |
+
x + y for x, y in zip(aggregated_main_model, scaled_main_model)
|
112 |
+
]
|
113 |
+
aggregated_data_influence_model = [
|
114 |
+
x + y for x, y in zip(aggregated_data_influence_model, scaled_data_influence_model)
|
115 |
+
]
|
116 |
+
|
117 |
+
return concatenate_models_with_marker(aggregated_main_model, aggregated_data_influence_model)
|
118 |
+
|
119 |
+
|
120 |
+
def aggregate_median(results: list[tuple[NDArrays, int]]) -> NDArrays:
|
121 |
+
"""Compute median."""
|
122 |
+
# Create a list of weights and ignore the number of examples
|
123 |
+
weights = [weights for weights, _ in results]
|
124 |
+
|
125 |
+
# Compute median weight of each layer
|
126 |
+
median_w: NDArrays = [
|
127 |
+
np.median(np.asarray(layer), axis=0) for layer in zip(*weights)
|
128 |
+
]
|
129 |
+
return median_w
|
130 |
+
|
131 |
+
|
132 |
+
def aggregate_krum(
|
133 |
+
results: list[tuple[NDArrays, int]], num_malicious: int, to_keep: int
|
134 |
+
) -> NDArrays:
|
135 |
+
"""Choose one parameter vector according to the Krum function.
|
136 |
+
|
137 |
+
If to_keep is not None, then MultiKrum is applied.
|
138 |
+
"""
|
139 |
+
# Create a list of weights and ignore the number of examples
|
140 |
+
weights = [weights for weights, _ in results]
|
141 |
+
|
142 |
+
# Compute distances between vectors
|
143 |
+
distance_matrix = _compute_distances(weights)
|
144 |
+
|
145 |
+
# For each client, take the n-f-2 closest parameters vectors
|
146 |
+
num_closest = max(1, len(weights) - num_malicious - 2)
|
147 |
+
closest_indices = []
|
148 |
+
for distance in distance_matrix:
|
149 |
+
closest_indices.append(
|
150 |
+
np.argsort(distance)[1 : num_closest + 1].tolist() # noqa: E203
|
151 |
+
)
|
152 |
+
|
153 |
+
# Compute the score for each client, that is the sum of the distances
|
154 |
+
# of the n-f-2 closest parameters vectors
|
155 |
+
scores = [
|
156 |
+
np.sum(distance_matrix[i, closest_indices[i]])
|
157 |
+
for i in range(len(distance_matrix))
|
158 |
+
]
|
159 |
+
|
160 |
+
if to_keep > 0:
|
161 |
+
# Choose to_keep clients and return their average (MultiKrum)
|
162 |
+
best_indices = np.argsort(scores)[::-1][len(scores) - to_keep :] # noqa: E203
|
163 |
+
best_results = [results[i] for i in best_indices]
|
164 |
+
return aggregate(best_results)
|
165 |
+
|
166 |
+
# Return the model parameters that minimize the score (Krum)
|
167 |
+
return weights[np.argmin(scores)]
|
168 |
+
|
169 |
+
|
170 |
+
# pylint: disable=too-many-locals
|
171 |
+
def aggregate_bulyan(
|
172 |
+
results: list[tuple[NDArrays, int]],
|
173 |
+
num_malicious: int,
|
174 |
+
aggregation_rule: Callable, # type: ignore
|
175 |
+
**aggregation_rule_kwargs: Any,
|
176 |
+
) -> NDArrays:
|
177 |
+
"""Perform Bulyan aggregation.
|
178 |
+
|
179 |
+
Parameters
|
180 |
+
----------
|
181 |
+
results: list[tuple[NDArrays, int]]
|
182 |
+
Weights and number of samples for each of the client.
|
183 |
+
num_malicious: int
|
184 |
+
The maximum number of malicious clients.
|
185 |
+
aggregation_rule: Callable
|
186 |
+
Byzantine resilient aggregation rule used as the first step of the Bulyan
|
187 |
+
aggregation_rule_kwargs: Any
|
188 |
+
The arguments to the aggregation rule.
|
189 |
+
|
190 |
+
Returns
|
191 |
+
-------
|
192 |
+
aggregated_parameters: NDArrays
|
193 |
+
Aggregated parameters according to the Bulyan strategy.
|
194 |
+
"""
|
195 |
+
byzantine_resilient_single_ret_model_aggregation = [aggregate_krum]
|
196 |
+
# also GeoMed (but not implemented yet)
|
197 |
+
byzantine_resilient_many_return_models_aggregation = [] # type: ignore
|
198 |
+
# Brute, Medoid (but not implemented yet)
|
199 |
+
|
200 |
+
num_clients = len(results)
|
201 |
+
if num_clients < 4 * num_malicious + 3:
|
202 |
+
raise ValueError(
|
203 |
+
"The Bulyan aggregation requires then number of clients to be greater or "
|
204 |
+
"equal to the 4 * num_malicious + 3. This is the assumption of this method."
|
205 |
+
"It is needed to ensure that the method reduces the attacker's leeway to "
|
206 |
+
"the one proved in the paper."
|
207 |
+
)
|
208 |
+
selected_models_set: list[tuple[NDArrays, int]] = []
|
209 |
+
|
210 |
+
theta = len(results) - 2 * num_malicious
|
211 |
+
beta = theta - 2 * num_malicious
|
212 |
+
|
213 |
+
for _ in range(theta):
|
214 |
+
best_model = aggregation_rule(
|
215 |
+
results=results, num_malicious=num_malicious, **aggregation_rule_kwargs
|
216 |
+
)
|
217 |
+
list_of_weights = [weights for weights, num_samples in results]
|
218 |
+
# This group gives exact result
|
219 |
+
if aggregation_rule in byzantine_resilient_single_ret_model_aggregation:
|
220 |
+
best_idx = _find_reference_weights(best_model, list_of_weights)
|
221 |
+
# This group requires finding the closest model to the returned one
|
222 |
+
# (weights distance wise)
|
223 |
+
elif aggregation_rule in byzantine_resilient_many_return_models_aggregation:
|
224 |
+
# when different aggregation strategies available
|
225 |
+
# write a function to find the closest model
|
226 |
+
raise NotImplementedError(
|
227 |
+
"aggregate_bulyan currently does not support the aggregation rules that"
|
228 |
+
" return many models as results. "
|
229 |
+
"Such aggregation rules are currently not available in Flower."
|
230 |
+
)
|
231 |
+
else:
|
232 |
+
raise ValueError(
|
233 |
+
"The given aggregation rule is not added as Byzantine resilient. "
|
234 |
+
"Please choose from Byzantine resilient rules."
|
235 |
+
)
|
236 |
+
|
237 |
+
selected_models_set.append(results[best_idx])
|
238 |
+
|
239 |
+
# remove idx from tracker and weights_results
|
240 |
+
results.pop(best_idx)
|
241 |
+
|
242 |
+
# Compute median parameter vector across selected_models_set
|
243 |
+
median_vect = aggregate_median(selected_models_set)
|
244 |
+
|
245 |
+
# Take the averaged beta parameters of the closest distance to the median
|
246 |
+
# (coordinate-wise)
|
247 |
+
parameters_aggregated = _aggregate_n_closest_weights(
|
248 |
+
median_vect, selected_models_set, beta_closest=beta
|
249 |
+
)
|
250 |
+
return parameters_aggregated
|
251 |
+
|
252 |
+
|
253 |
+
def weighted_loss_avg(results: list[tuple[int, float]]) -> float:
|
254 |
+
"""Aggregate evaluation results obtained from multiple clients."""
|
255 |
+
num_total_evaluation_examples = sum(num_examples for (num_examples, _) in results)
|
256 |
+
weighted_losses = [num_examples * loss for num_examples, loss in results]
|
257 |
+
return sum(weighted_losses) / num_total_evaluation_examples
|
258 |
+
|
259 |
+
|
260 |
+
def aggregate_qffl(
|
261 |
+
parameters: NDArrays, deltas: list[NDArrays], hs_fll: list[NDArrays]
|
262 |
+
) -> NDArrays:
|
263 |
+
"""Compute weighted average based on Q-FFL paper."""
|
264 |
+
demominator: float = np.sum(np.asarray(hs_fll))
|
265 |
+
scaled_deltas = []
|
266 |
+
for client_delta in deltas:
|
267 |
+
scaled_deltas.append([layer * 1.0 / demominator for layer in client_delta])
|
268 |
+
updates = []
|
269 |
+
for i in range(len(deltas[0])):
|
270 |
+
tmp = scaled_deltas[0][i]
|
271 |
+
for j in range(1, len(deltas)):
|
272 |
+
tmp += scaled_deltas[j][i]
|
273 |
+
updates.append(tmp)
|
274 |
+
new_parameters = [(u - v) * 1.0 for u, v in zip(parameters, updates)]
|
275 |
+
return new_parameters
|
276 |
+
|
277 |
+
|
278 |
+
def _compute_distances(weights: list[NDArrays]) -> NDArray:
|
279 |
+
"""Compute distances between vectors.
|
280 |
+
|
281 |
+
Input: weights - list of weights vectors
|
282 |
+
Output: distances - matrix distance_matrix of squared distances between the vectors
|
283 |
+
"""
|
284 |
+
flat_w = np.array([np.concatenate(p, axis=None).ravel() for p in weights])
|
285 |
+
distance_matrix = np.zeros((len(weights), len(weights)))
|
286 |
+
for i, flat_w_i in enumerate(flat_w):
|
287 |
+
for j, flat_w_j in enumerate(flat_w):
|
288 |
+
delta = flat_w_i - flat_w_j
|
289 |
+
norm = np.linalg.norm(delta)
|
290 |
+
distance_matrix[i, j] = norm**2
|
291 |
+
return distance_matrix
|
292 |
+
|
293 |
+
|
294 |
+
def _trim_mean(array: NDArray, proportiontocut: float) -> NDArray:
|
295 |
+
"""Compute trimmed mean along axis=0.
|
296 |
+
|
297 |
+
It is based on the scipy implementation.
|
298 |
+
|
299 |
+
https://docs.scipy.org/doc/scipy/reference/generated/
|
300 |
+
scipy.stats.trim_mean.html.
|
301 |
+
"""
|
302 |
+
axis = 0
|
303 |
+
nobs = array.shape[axis]
|
304 |
+
lowercut = int(proportiontocut * nobs)
|
305 |
+
uppercut = nobs - lowercut
|
306 |
+
if lowercut > uppercut:
|
307 |
+
raise ValueError("Proportion too big.")
|
308 |
+
|
309 |
+
atmp = np.partition(array, (lowercut, uppercut - 1), axis)
|
310 |
+
|
311 |
+
slice_list = [slice(None)] * atmp.ndim
|
312 |
+
slice_list[axis] = slice(lowercut, uppercut)
|
313 |
+
result: NDArray = np.mean(atmp[tuple(slice_list)], axis=axis)
|
314 |
+
return result
|
315 |
+
|
316 |
+
|
317 |
+
def aggregate_trimmed_avg(
|
318 |
+
results: list[tuple[NDArrays, int]], proportiontocut: float
|
319 |
+
) -> NDArrays:
|
320 |
+
"""Compute trimmed average."""
|
321 |
+
# Create a list of weights and ignore the number of examples
|
322 |
+
weights = [weights for weights, _ in results]
|
323 |
+
|
324 |
+
trimmed_w: NDArrays = [
|
325 |
+
_trim_mean(np.asarray(layer), proportiontocut=proportiontocut)
|
326 |
+
for layer in zip(*weights)
|
327 |
+
]
|
328 |
+
|
329 |
+
return trimmed_w
|
330 |
+
|
331 |
+
|
332 |
+
def _check_weights_equality(weights1: NDArrays, weights2: NDArrays) -> bool:
|
333 |
+
"""Check if weights are the same."""
|
334 |
+
if len(weights1) != len(weights2):
|
335 |
+
return False
|
336 |
+
return all(
|
337 |
+
np.array_equal(layer_weights1, layer_weights2)
|
338 |
+
for layer_weights1, layer_weights2 in zip(weights1, weights2)
|
339 |
+
)
|
340 |
+
|
341 |
+
|
342 |
+
def _find_reference_weights(
|
343 |
+
reference_weights: NDArrays, list_of_weights: list[NDArrays]
|
344 |
+
) -> int:
|
345 |
+
"""Find the reference weights by looping through the `list_of_weights`.
|
346 |
+
|
347 |
+
Raise Error if the reference weights is not found.
|
348 |
+
|
349 |
+
Parameters
|
350 |
+
----------
|
351 |
+
reference_weights: NDArrays
|
352 |
+
Weights that will be searched for.
|
353 |
+
list_of_weights: list[NDArrays]
|
354 |
+
list of weights that will be searched through.
|
355 |
+
|
356 |
+
Returns
|
357 |
+
-------
|
358 |
+
index: int
|
359 |
+
The index of `reference_weights` in the `list_of_weights`.
|
360 |
+
|
361 |
+
Raises
|
362 |
+
------
|
363 |
+
ValueError
|
364 |
+
If `reference_weights` is not found in `list_of_weights`.
|
365 |
+
"""
|
366 |
+
for idx, weights in enumerate(list_of_weights):
|
367 |
+
if _check_weights_equality(reference_weights, weights):
|
368 |
+
return idx
|
369 |
+
raise ValueError("The reference weights not found in list_of_weights.")
|
370 |
+
|
371 |
+
|
372 |
+
def _aggregate_n_closest_weights(
|
373 |
+
reference_weights: NDArrays, results: list[tuple[NDArrays, int]], beta_closest: int
|
374 |
+
) -> NDArrays:
|
375 |
+
"""Calculate element-wise mean of the `N` closest values.
|
376 |
+
|
377 |
+
Note, each i-th coordinate of the result weight is the average of the beta_closest
|
378 |
+
-ith coordinates to the reference weights
|
379 |
+
|
380 |
+
|
381 |
+
Parameters
|
382 |
+
----------
|
383 |
+
reference_weights: NDArrays
|
384 |
+
The weights from which the distances will be computed
|
385 |
+
results: list[tuple[NDArrays, int]]
|
386 |
+
The weights from models
|
387 |
+
beta_closest: int
|
388 |
+
The number of the closest distance weights that will be averaged
|
389 |
+
|
390 |
+
Returns
|
391 |
+
-------
|
392 |
+
aggregated_weights: NDArrays
|
393 |
+
Averaged (element-wise) beta weights that have the closest distance to
|
394 |
+
reference weights
|
395 |
+
"""
|
396 |
+
list_of_weights = [weights for weights, num_examples in results]
|
397 |
+
aggregated_weights = []
|
398 |
+
|
399 |
+
for layer_id, layer_weights in enumerate(reference_weights):
|
400 |
+
other_weights_layer_list = []
|
401 |
+
for other_w in list_of_weights:
|
402 |
+
other_weights_layer = other_w[layer_id]
|
403 |
+
other_weights_layer_list.append(other_weights_layer)
|
404 |
+
other_weights_layer_np = np.array(other_weights_layer_list)
|
405 |
+
diff_np = np.abs(layer_weights - other_weights_layer_np)
|
406 |
+
# Create indices of the smallest differences
|
407 |
+
# We do not need the exact order but just the beta closest weights
|
408 |
+
# therefore np.argpartition is used instead of np.argsort
|
409 |
+
indices = np.argpartition(diff_np, kth=beta_closest - 1, axis=0)
|
410 |
+
# Take the weights (coordinate-wise) corresponding to the beta of the
|
411 |
+
# closest distances
|
412 |
+
beta_closest_weights = np.take_along_axis(
|
413 |
+
other_weights_layer_np, indices=indices, axis=0
|
414 |
+
)[:beta_closest]
|
415 |
+
aggregated_weights.append(np.mean(beta_closest_weights, axis=0))
|
416 |
+
return aggregated_weights
|
template_FL/src/fedllm/myfedavg.py
ADDED
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Federated Averaging (FedAvg) [McMahan et al., 2016] strategy.
|
16 |
+
|
17 |
+
Paper: arxiv.org/abs/1602.05629
|
18 |
+
"""
|
19 |
+
|
20 |
+
|
21 |
+
from logging import WARNING
|
22 |
+
from typing import Callable, Optional, Union
|
23 |
+
|
24 |
+
from flwr.common import (
|
25 |
+
EvaluateIns,
|
26 |
+
EvaluateRes,
|
27 |
+
FitIns,
|
28 |
+
FitRes,
|
29 |
+
MetricsAggregationFn,
|
30 |
+
NDArrays,
|
31 |
+
Parameters,
|
32 |
+
Scalar,
|
33 |
+
ndarrays_to_parameters,
|
34 |
+
parameters_to_ndarrays,
|
35 |
+
)
|
36 |
+
from flwr.common.logger import log
|
37 |
+
from flwr.server.client_manager import ClientManager
|
38 |
+
from flwr.server.client_proxy import ClientProxy
|
39 |
+
|
40 |
+
from .myaggregation import aggregate, aggregate_inplace, weighted_loss_avg
|
41 |
+
from flwr.server.strategy import Strategy
|
42 |
+
|
43 |
+
WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """
|
44 |
+
Setting `min_available_clients` lower than `min_fit_clients` or
|
45 |
+
`min_evaluate_clients` can cause the server to fail when there are too few clients
|
46 |
+
connected to the server. `min_available_clients` must be set to a value larger
|
47 |
+
than or equal to the values of `min_fit_clients` and `min_evaluate_clients`.
|
48 |
+
"""
|
49 |
+
|
50 |
+
client_id_idx = {}
|
51 |
+
|
52 |
+
# pylint: disable=line-too-long
|
53 |
+
class FedAvg(Strategy):
|
54 |
+
"""Federated Averaging strategy.
|
55 |
+
|
56 |
+
Implementation based on https://arxiv.org/abs/1602.05629
|
57 |
+
|
58 |
+
Parameters
|
59 |
+
----------
|
60 |
+
fraction_fit : float, optional
|
61 |
+
Fraction of clients used during training. In case `min_fit_clients`
|
62 |
+
is larger than `fraction_fit * available_clients`, `min_fit_clients`
|
63 |
+
will still be sampled. Defaults to 1.0.
|
64 |
+
fraction_evaluate : float, optional
|
65 |
+
Fraction of clients used during validation. In case `min_evaluate_clients`
|
66 |
+
is larger than `fraction_evaluate * available_clients`,
|
67 |
+
`min_evaluate_clients` will still be sampled. Defaults to 1.0.
|
68 |
+
min_fit_clients : int, optional
|
69 |
+
Minimum number of clients used during training. Defaults to 2.
|
70 |
+
min_evaluate_clients : int, optional
|
71 |
+
Minimum number of clients used during validation. Defaults to 2.
|
72 |
+
min_available_clients : int, optional
|
73 |
+
Minimum number of total clients in the system. Defaults to 2.
|
74 |
+
evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]],Optional[Tuple[float, Dict[str, Scalar]]]]]
|
75 |
+
Optional function used for validation. Defaults to None.
|
76 |
+
on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional
|
77 |
+
Function used to configure training. Defaults to None.
|
78 |
+
on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional
|
79 |
+
Function used to configure validation. Defaults to None.
|
80 |
+
accept_failures : bool, optional
|
81 |
+
Whether or not accept rounds containing failures. Defaults to True.
|
82 |
+
initial_parameters : Parameters, optional
|
83 |
+
Initial global model parameters.
|
84 |
+
fit_metrics_aggregation_fn : Optional[MetricsAggregationFn]
|
85 |
+
Metrics aggregation function, optional.
|
86 |
+
evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn]
|
87 |
+
Metrics aggregation function, optional.
|
88 |
+
inplace : bool (default: True)
|
89 |
+
Enable (True) or disable (False) in-place aggregation of model updates.
|
90 |
+
"""
|
91 |
+
|
92 |
+
# pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
|
93 |
+
def __init__(
|
94 |
+
self,
|
95 |
+
*,
|
96 |
+
fraction_fit: float = 1.0,
|
97 |
+
fraction_evaluate: float = 1.0,
|
98 |
+
min_fit_clients: int = 2,
|
99 |
+
min_evaluate_clients: int = 2,
|
100 |
+
min_available_clients: int = 2,
|
101 |
+
evaluate_fn: Optional[
|
102 |
+
Callable[
|
103 |
+
[int, NDArrays, dict[str, Scalar]],
|
104 |
+
Optional[tuple[float, dict[str, Scalar]]],
|
105 |
+
]
|
106 |
+
] = None,
|
107 |
+
on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
|
108 |
+
on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
|
109 |
+
accept_failures: bool = True,
|
110 |
+
initial_parameters: Optional[Parameters] = None,
|
111 |
+
fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
|
112 |
+
evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
|
113 |
+
inplace: bool = True,
|
114 |
+
use_mates: bool = False,
|
115 |
+
) -> None:
|
116 |
+
super().__init__()
|
117 |
+
|
118 |
+
if (
|
119 |
+
min_fit_clients > min_available_clients
|
120 |
+
or min_evaluate_clients > min_available_clients
|
121 |
+
):
|
122 |
+
log(WARNING, WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW)
|
123 |
+
|
124 |
+
self.fraction_fit = fraction_fit
|
125 |
+
self.fraction_evaluate = fraction_evaluate
|
126 |
+
self.min_fit_clients = min_fit_clients
|
127 |
+
self.min_evaluate_clients = min_evaluate_clients
|
128 |
+
self.min_available_clients = min_available_clients
|
129 |
+
self.evaluate_fn = evaluate_fn
|
130 |
+
self.on_fit_config_fn = on_fit_config_fn
|
131 |
+
self.on_evaluate_config_fn = on_evaluate_config_fn
|
132 |
+
self.accept_failures = accept_failures
|
133 |
+
self.initial_parameters = initial_parameters
|
134 |
+
self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn
|
135 |
+
self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn
|
136 |
+
self.inplace = inplace
|
137 |
+
self.use_mates = use_mates
|
138 |
+
|
139 |
+
def __repr__(self) -> str:
|
140 |
+
"""Compute a string representation of the strategy."""
|
141 |
+
rep = f"FedAvg(accept_failures={self.accept_failures})"
|
142 |
+
return rep
|
143 |
+
|
144 |
+
def num_fit_clients(self, num_available_clients: int) -> tuple[int, int]:
|
145 |
+
"""Return the sample size and the required number of available clients."""
|
146 |
+
num_clients = int(num_available_clients * self.fraction_fit)
|
147 |
+
return max(num_clients, self.min_fit_clients), self.min_available_clients
|
148 |
+
|
149 |
+
def num_evaluation_clients(self, num_available_clients: int) -> tuple[int, int]:
|
150 |
+
"""Use a fraction of available clients for evaluation."""
|
151 |
+
num_clients = int(num_available_clients * self.fraction_evaluate)
|
152 |
+
return max(num_clients, self.min_evaluate_clients), self.min_available_clients
|
153 |
+
|
154 |
+
def initialize_parameters(
|
155 |
+
self, client_manager: ClientManager
|
156 |
+
) -> Optional[Parameters]:
|
157 |
+
"""Initialize global model parameters."""
|
158 |
+
initial_parameters = self.initial_parameters
|
159 |
+
self.initial_parameters = None # Don't keep initial parameters in memory
|
160 |
+
return initial_parameters
|
161 |
+
|
162 |
+
def evaluate(
|
163 |
+
self, server_round: int, parameters: Parameters
|
164 |
+
) -> Optional[tuple[float, dict[str, Scalar]]]:
|
165 |
+
"""Evaluate model parameters using an evaluation function."""
|
166 |
+
if self.evaluate_fn is None:
|
167 |
+
# No evaluation function provided
|
168 |
+
return None
|
169 |
+
parameters_ndarrays = parameters_to_ndarrays(parameters)
|
170 |
+
eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})
|
171 |
+
if eval_res is None:
|
172 |
+
return None
|
173 |
+
loss, metrics = eval_res
|
174 |
+
return loss, metrics
|
175 |
+
|
176 |
+
def configure_fit(
|
177 |
+
self, server_round: int, parameters: Parameters, client_manager: ClientManager
|
178 |
+
) -> list[tuple[ClientProxy, FitIns]]:
|
179 |
+
"""Configure the next round of training."""
|
180 |
+
config = {}
|
181 |
+
if self.on_fit_config_fn is not None:
|
182 |
+
# Custom fit config function provided
|
183 |
+
config = self.on_fit_config_fn(server_round)
|
184 |
+
fit_ins = FitIns(parameters, config)
|
185 |
+
|
186 |
+
if not client_id_idx:
|
187 |
+
for i, (client_id, _) in enumerate(client_manager.clients.items()):
|
188 |
+
client_id_idx[client_id] = i
|
189 |
+
|
190 |
+
|
191 |
+
# Sample clients
|
192 |
+
sample_size, min_num_clients = self.num_fit_clients(
|
193 |
+
client_manager.num_available()
|
194 |
+
)
|
195 |
+
clients = client_manager.sample(
|
196 |
+
num_clients=sample_size, min_num_clients=min_num_clients
|
197 |
+
)
|
198 |
+
|
199 |
+
# Return client/config pairs
|
200 |
+
return [(client, fit_ins) for client in clients]
|
201 |
+
|
202 |
+
def configure_evaluate(
|
203 |
+
self, server_round: int, parameters: Parameters, client_manager: ClientManager
|
204 |
+
) -> list[tuple[ClientProxy, EvaluateIns]]:
|
205 |
+
"""Configure the next round of evaluation."""
|
206 |
+
# Do not configure federated evaluation if fraction eval is 0.
|
207 |
+
if self.fraction_evaluate == 0.0:
|
208 |
+
return []
|
209 |
+
|
210 |
+
# Parameters and config
|
211 |
+
config = {}
|
212 |
+
if self.on_evaluate_config_fn is not None:
|
213 |
+
# Custom evaluation config function provided
|
214 |
+
config = self.on_evaluate_config_fn(server_round)
|
215 |
+
evaluate_ins = EvaluateIns(parameters, config)
|
216 |
+
|
217 |
+
# Sample clients
|
218 |
+
sample_size, min_num_clients = self.num_evaluation_clients(
|
219 |
+
client_manager.num_available()
|
220 |
+
)
|
221 |
+
clients = client_manager.sample(
|
222 |
+
num_clients=sample_size, min_num_clients=min_num_clients
|
223 |
+
)
|
224 |
+
|
225 |
+
# Return client/config pairs
|
226 |
+
return [(client, evaluate_ins) for client in clients]
|
227 |
+
|
228 |
+
def aggregate_fit(
|
229 |
+
self,
|
230 |
+
server_round: int,
|
231 |
+
results: list[tuple[ClientProxy, FitRes]],
|
232 |
+
failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
|
233 |
+
) -> tuple[Optional[Parameters], dict[str, Scalar]]:
|
234 |
+
"""Aggregate fit results using weighted average."""
|
235 |
+
if not results:
|
236 |
+
return None, {}
|
237 |
+
# Do not aggregate if there are failures and failures are not accepted
|
238 |
+
if not self.accept_failures and failures:
|
239 |
+
return None, {}
|
240 |
+
|
241 |
+
if self.inplace:
|
242 |
+
# Does in-place weighted average of results
|
243 |
+
aggregated_ndarrays = aggregate_inplace(results)
|
244 |
+
elif self.use_mates:
|
245 |
+
aggregated_ndarrays = aggregate_inplace_mates(results)
|
246 |
+
else:
|
247 |
+
# Convert results
|
248 |
+
weights_results = [
|
249 |
+
(parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
|
250 |
+
for _, fit_res in results
|
251 |
+
]
|
252 |
+
aggregated_ndarrays = aggregate(weights_results)
|
253 |
+
|
254 |
+
parameters_aggregated = ndarrays_to_parameters(aggregated_ndarrays)
|
255 |
+
|
256 |
+
# Aggregate custom metrics if aggregation fn was provided
|
257 |
+
metrics_aggregated = {}
|
258 |
+
if self.fit_metrics_aggregation_fn:
|
259 |
+
fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
|
260 |
+
metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
|
261 |
+
elif server_round == 1: # Only log this warning once
|
262 |
+
log(WARNING, "No fit_metrics_aggregation_fn provided")
|
263 |
+
|
264 |
+
return parameters_aggregated, metrics_aggregated
|
265 |
+
|
266 |
+
def aggregate_evaluate(
|
267 |
+
self,
|
268 |
+
server_round: int,
|
269 |
+
results: list[tuple[ClientProxy, EvaluateRes]],
|
270 |
+
failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
|
271 |
+
) -> tuple[Optional[float], dict[str, Scalar]]:
|
272 |
+
"""Aggregate evaluation losses using weighted average."""
|
273 |
+
if not results:
|
274 |
+
return None, {}
|
275 |
+
# Do not aggregate if there are failures and failures are not accepted
|
276 |
+
if not self.accept_failures and failures:
|
277 |
+
return None, {}
|
278 |
+
|
279 |
+
# Aggregate loss
|
280 |
+
loss_aggregated = weighted_loss_avg(
|
281 |
+
[
|
282 |
+
(evaluate_res.num_examples, evaluate_res.loss)
|
283 |
+
for _, evaluate_res in results
|
284 |
+
]
|
285 |
+
)
|
286 |
+
|
287 |
+
# Aggregate custom metrics if aggregation fn was provided
|
288 |
+
metrics_aggregated = {}
|
289 |
+
if self.evaluate_metrics_aggregation_fn:
|
290 |
+
eval_metrics = [(res.num_examples, res.metrics) for _, res in results]
|
291 |
+
metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics)
|
292 |
+
elif server_round == 1: # Only log this warning once
|
293 |
+
log(WARNING, "No evaluate_metrics_aggregation_fn provided")
|
294 |
+
|
295 |
+
return loss_aggregated, metrics_aggregated
|
template_FL/src/fedllm/server_app.py
ADDED
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""flowertune-llm: A Flower / FlowerTune app."""
|
2 |
+
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
import wandb
|
6 |
+
import numpy as np
|
7 |
+
from dotenv import load_dotenv
|
8 |
+
from datetime import datetime
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorWithPadding, TrainingArguments, Trainer, GenerationConfig
|
12 |
+
from .trainer import ManualTrainer
|
13 |
+
from transformers.integrations import WandbCallback
|
14 |
+
from torch.utils.data import DataLoader
|
15 |
+
from flwr.common import Context, ndarrays_to_parameters
|
16 |
+
from flwr.common.config import unflatten_dict
|
17 |
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
18 |
+
# from flwr.server.strategy import FedAvg
|
19 |
+
from omegaconf import DictConfig
|
20 |
+
|
21 |
+
from .models import *
|
22 |
+
from .dataset import replace_keys
|
23 |
+
from .myfedavg import FedAvg
|
24 |
+
from .data_domains import global_test_set_hete
|
25 |
+
from .make_data import Prompter, generate_and_tokenize_prompt
|
26 |
+
from .metrics import exact_match, f1, get_rouge_score
|
27 |
+
|
28 |
+
from datasets import load_dataset, Dataset
|
29 |
+
from sklearn.model_selection import train_test_split
|
30 |
+
|
31 |
+
|
32 |
+
load_dotenv(".env")
|
33 |
+
|
34 |
+
os.environ["WANDB_API_KEY"] = os.getenv("WANDB_API_KEY")
|
35 |
+
os.environ["WANDB_NAME"] = os.getenv("WANDB_NAME")
|
36 |
+
os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN")
|
37 |
+
# os.environ["WANDB_LOG_MODEL"] = "checkpoint"
|
38 |
+
|
39 |
+
class LLMSampleCB(WandbCallback):
|
40 |
+
def __init__(self, trainer, test_dataset, task, num_samples=10, max_new_tokens=256, log_model="checkpoint"):
|
41 |
+
"A CallBack to log samples a wandb.Table during training"
|
42 |
+
super().__init__()
|
43 |
+
# self._log_model = log_model
|
44 |
+
self.task = task
|
45 |
+
self.sample_dataset = test_dataset.shuffle().select(range(num_samples))
|
46 |
+
self.model, self.tokenizer = trainer.model, trainer.tokenizer
|
47 |
+
self.max_new_tokens = max_new_tokens
|
48 |
+
self.gen_config = GenerationConfig.from_pretrained(trainer.model.name_or_path,
|
49 |
+
max_new_tokens=max_new_tokens)
|
50 |
+
def generate(self, prompt):
|
51 |
+
tokenized_prompt = self.tokenizer(
|
52 |
+
prompt,
|
53 |
+
# padding='max_length', max_length=self.max_new_tokens,
|
54 |
+
return_tensors='pt'
|
55 |
+
)
|
56 |
+
input_ids = tokenized_prompt['input_ids'].to('cuda:0')
|
57 |
+
|
58 |
+
with torch.inference_mode():
|
59 |
+
output = self.model.generate(input_ids, generation_config=self.gen_config)
|
60 |
+
return self.tokenizer.decode(output[0][len(tokenized_prompt[0]):], skip_special_tokens=True)
|
61 |
+
|
62 |
+
def samples_table(self, examples):
|
63 |
+
"Create a wandb.Table to store the generations"
|
64 |
+
records_table = wandb.Table(columns=["input", "prediction", "label", "task"] + list(self.gen_config.to_dict().keys()))
|
65 |
+
for example in tqdm(examples, leave=False):
|
66 |
+
instruction = example["instruction"]
|
67 |
+
inputt = example["input"]
|
68 |
+
output = example['output']
|
69 |
+
prompt = ''
|
70 |
+
if inputt == '':
|
71 |
+
prompt = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: {instruction} ### Response: """
|
72 |
+
else:
|
73 |
+
prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. ### Instruction: {instruction} ### Input: {inputt} ### Response:"""
|
74 |
+
|
75 |
+
generation = self.generate(prompt=prompt)
|
76 |
+
records_table.add_data(prompt, generation, output, self.task, *list(self.gen_config.to_dict().values()))
|
77 |
+
return records_table
|
78 |
+
|
79 |
+
def on_evaluate(self, args, state, control, **kwargs):
|
80 |
+
"Log the wandb.Table after calling trainer.evaluate"
|
81 |
+
super().on_evaluate(args, state, control, **kwargs)
|
82 |
+
records_table = self.samples_table(self.sample_dataset)
|
83 |
+
self._wandb.log({"sample_predictions":records_table})
|
84 |
+
|
85 |
+
|
86 |
+
|
87 |
+
def test_model(dataset, model, tokenizer, train_cfg, tmp_dict, sround, mates_args, task):
|
88 |
+
|
89 |
+
wandb.init(
|
90 |
+
project='FL@CSS25',
|
91 |
+
name=f'global_eval_round_{sround}',
|
92 |
+
id=f"round_{sround}",
|
93 |
+
resume="allow",
|
94 |
+
reinit=True,
|
95 |
+
# settings=wandb.Settings(start_method="thread")
|
96 |
+
)
|
97 |
+
|
98 |
+
def compute_metrics(pred):
|
99 |
+
labels_ids = pred['label_ids']
|
100 |
+
labels_ids[labels_ids == -100] = 1829
|
101 |
+
pred_ids = pred['predictions']
|
102 |
+
|
103 |
+
# all unnecessary tokens are removed
|
104 |
+
pred_str = tokenizer.batch_decode(
|
105 |
+
pred_ids, skip_special_tokens=True
|
106 |
+
)
|
107 |
+
label_str = tokenizer.batch_decode(
|
108 |
+
labels_ids, skip_special_tokens=True
|
109 |
+
)
|
110 |
+
return {
|
111 |
+
**get_rouge_score(predictions=pred_str, targets=label_str),
|
112 |
+
**f1(predictions=pred_str, targets=label_str),
|
113 |
+
}
|
114 |
+
|
115 |
+
data_collator = DataCollatorForSeq2Seq(
|
116 |
+
tokenizer,
|
117 |
+
pad_to_multiple_of=8,
|
118 |
+
return_tensors="pt",
|
119 |
+
padding=True,
|
120 |
+
)
|
121 |
+
|
122 |
+
testset = (
|
123 |
+
dataset
|
124 |
+
.shuffle()
|
125 |
+
.map(
|
126 |
+
lambda x: generate_and_tokenize_prompt(x, **tmp_dict),
|
127 |
+
num_proc=8,
|
128 |
+
)
|
129 |
+
)
|
130 |
+
|
131 |
+
training_arguments = TrainingArguments(**train_cfg.training_arguments)
|
132 |
+
training_arguments.output_dir = './global_results'
|
133 |
+
training_arguments.logging_dir='./global_logs'
|
134 |
+
# training_arguments.run_name = f'global_eval_round_{sround}'
|
135 |
+
|
136 |
+
|
137 |
+
|
138 |
+
# # Constuct baseline Trainer
|
139 |
+
# trainer = Trainer(
|
140 |
+
# model=model,
|
141 |
+
# eval_dataset=testset.select(range(10)),
|
142 |
+
# args=training_arguments,
|
143 |
+
# data_collator=data_collator,
|
144 |
+
# compute_metrics=compute_metrics,
|
145 |
+
# tokenizer=tokenizer
|
146 |
+
# )
|
147 |
+
|
148 |
+
mates_args.state = False
|
149 |
+
|
150 |
+
trainer = ManualTrainer(
|
151 |
+
model= model,
|
152 |
+
tokenizer = tokenizer,
|
153 |
+
train_dataset=None,
|
154 |
+
val_dataset=testset.select(range(10)),
|
155 |
+
holdout_dataset=None,
|
156 |
+
reference_dataset=None,
|
157 |
+
args=training_arguments,
|
158 |
+
data_collator=data_collator,
|
159 |
+
compute_metrics=compute_metrics,
|
160 |
+
mates_args=mates_args,
|
161 |
+
data_influence_model=None,
|
162 |
+
data_influence_tokenizer=None,
|
163 |
+
)
|
164 |
+
|
165 |
+
# Do local training
|
166 |
+
results = trainer.evaluate(wandb_sample=True)
|
167 |
+
|
168 |
+
# Extract loss, predictions, and labels
|
169 |
+
|
170 |
+
eval_loss = results[f"eval_loss"]
|
171 |
+
eval_metrics = {
|
172 |
+
f'{task}_f1': results["f1"],
|
173 |
+
f'{task}_rouge1': results["rouge1"],
|
174 |
+
f'{task}_rouge2': results['rouge2'],
|
175 |
+
f'{task}_rougeL': results['rougeL'],
|
176 |
+
f'{task}_rougeLsum': results['rougeLsum'],
|
177 |
+
}
|
178 |
+
|
179 |
+
wandb.finish()
|
180 |
+
|
181 |
+
return eval_loss, eval_metrics
|
182 |
+
|
183 |
+
|
184 |
+
|
185 |
+
# Get function that will be executed by the strategy's evaluate() method
|
186 |
+
# Here we use it to save global model checkpoints
|
187 |
+
|
188 |
+
def get_evaluate_fn(train_cfg, model_cfg, dataset_cfg, save_every_round, total_round, total_nodes, save_path, mates_args):
|
189 |
+
"""Return an evaluation function for saving global model."""
|
190 |
+
|
191 |
+
def evaluate(server_round: int, parameters, config):
|
192 |
+
# Save model
|
193 |
+
total_loss, result_metric = 0, {}
|
194 |
+
prompter = Prompter(train_cfg.prompt_template_name, train_cfg.verbose)
|
195 |
+
if server_round != 0 and (
|
196 |
+
server_round == total_round or server_round % save_every_round == 0
|
197 |
+
):
|
198 |
+
# Init model
|
199 |
+
main_model_params, _ = split_models(parameters)
|
200 |
+
model, tokenizer = get_model(model_cfg)
|
201 |
+
set_parameters(model, main_model_params)
|
202 |
+
|
203 |
+
tmp_dict = {
|
204 |
+
"prompter": prompter,
|
205 |
+
"seq_length": train_cfg.seq_length,
|
206 |
+
"train_on_inputs": train_cfg.train_on_inputs,
|
207 |
+
"tokenizer": tokenizer,
|
208 |
+
}
|
209 |
+
if dataset_cfg.type == 'homo':
|
210 |
+
ds = load_dataset(dataset_cfg.name)
|
211 |
+
_, test = train_test_split(
|
212 |
+
ds, test_size=0.09, shuffle=True, random_state=42
|
213 |
+
)
|
214 |
+
global_test_set_homo = Dataset.from_pandas(test).remove_columns(['__index_level_0__'])
|
215 |
+
|
216 |
+
loss, metrics = test_model(global_test_set_homo, model, tokenizer, train_cfg, tmp_dict, server_round, mates_args, 'homo')
|
217 |
+
total_loss = loss
|
218 |
+
result_metric = {'homo_f1': metrics['homo_f1']}
|
219 |
+
else:
|
220 |
+
(
|
221 |
+
list_loss, list_f1,
|
222 |
+
list_rouge1, list_rouge2,
|
223 |
+
list_rougeL, list_rougeLsum
|
224 |
+
) = [], {}, {}, {}, {}, {}
|
225 |
+
|
226 |
+
for task in ['general', 'finance', 'math', 'medical', 'code']:
|
227 |
+
ds = global_test_set_hete[task]
|
228 |
+
loss, metrics = test_model(ds, model, tokenizer, train_cfg, tmp_dict, server_round, mates_args, task)
|
229 |
+
list_loss.append(loss)
|
230 |
+
|
231 |
+
list_f1[f'{task}_f1'] = metrics[f'{task}_f1']
|
232 |
+
# list_rouge1[f'{task}_rouge1'] = metrics['rouge1']
|
233 |
+
# list_rouge2[f'{task}_rouge2'] = metrics['rouge2']
|
234 |
+
# list_rougeL[f'{task}_rougeL'] = metrics['rougeL']
|
235 |
+
# list_rougeLsum[f'{task}_rougeLsum'] = metrics['rougeLsum']
|
236 |
+
|
237 |
+
|
238 |
+
total_loss = sum(list_loss) / len(list_loss)
|
239 |
+
avg_f1 = sum([v for k, v in list_f1.items()]) / len(list_f1)
|
240 |
+
result_metric = {**list_f1, 'avg_hete_f1': avg_f1}
|
241 |
+
|
242 |
+
model.save_pretrained(f"{save_path}/peft_{server_round}")
|
243 |
+
|
244 |
+
return total_loss, result_metric
|
245 |
+
|
246 |
+
return evaluate
|
247 |
+
|
248 |
+
|
249 |
+
def get_on_fit_config(save_path):
|
250 |
+
"""Return a function that will be used to construct the config that the client's
|
251 |
+
fit() method will receive."""
|
252 |
+
|
253 |
+
def fit_config_fn(server_round: int):
|
254 |
+
fit_config = {}
|
255 |
+
fit_config["current_round"] = server_round
|
256 |
+
fit_config["save_path"] = save_path
|
257 |
+
return fit_config
|
258 |
+
|
259 |
+
return fit_config_fn
|
260 |
+
|
261 |
+
|
262 |
+
def fit_weighted_average(metrics):
|
263 |
+
"""Aggregate (federated) evaluation metrics."""
|
264 |
+
# Multiply accuracy of each client by number of examples used
|
265 |
+
losses = [num_examples * m["train_loss"] for num_examples, m in metrics]
|
266 |
+
total_flops = [m["flops"] for num_examples, m in metrics]
|
267 |
+
examples = [num_examples for num_examples, _ in metrics]
|
268 |
+
|
269 |
+
# Aggregate and return custom metric (weighted average)
|
270 |
+
return {"train_loss": round(sum(losses) / sum(examples), 3), "total_flops": f"{sum(total_flops)/1e12:.2f}T"}
|
271 |
+
|
272 |
+
|
273 |
+
def server_fn(context: Context):
|
274 |
+
"""Construct components that set the ServerApp behaviour."""
|
275 |
+
# Create output directory given current timestamp
|
276 |
+
current_time = datetime.now()
|
277 |
+
folder_name = current_time.strftime("%Y-%m-%d_%H-%M-%S")
|
278 |
+
save_path = os.path.join(os.getcwd(), f"results/{folder_name}")
|
279 |
+
os.makedirs(save_path, exist_ok=True)
|
280 |
+
|
281 |
+
# Read from config
|
282 |
+
num_rounds = context.run_config["num-server-rounds"]
|
283 |
+
num_nodes = context.run_config['num-supernodes']
|
284 |
+
cfg = DictConfig(replace_keys(unflatten_dict(context.run_config)))
|
285 |
+
|
286 |
+
# Get initial model weights
|
287 |
+
init_model, tokenizer = get_model(cfg.model)
|
288 |
+
init_model_parameters = get_parameters(init_model)
|
289 |
+
init_model_parameters = ndarrays_to_parameters(init_model_parameters)
|
290 |
+
|
291 |
+
# Define strategy
|
292 |
+
strategy = FedAvg(
|
293 |
+
fraction_fit=cfg.train.strategy.fraction_fit,
|
294 |
+
fraction_evaluate=cfg.train.strategy.fraction_evaluate,
|
295 |
+
on_fit_config_fn=get_on_fit_config(save_path),
|
296 |
+
fit_metrics_aggregation_fn=fit_weighted_average,
|
297 |
+
initial_parameters=init_model_parameters,
|
298 |
+
evaluate_fn=get_evaluate_fn(
|
299 |
+
cfg.train, cfg.model, cfg.dataset, cfg.train.save_every_round, num_rounds, num_nodes, save_path, cfg.mates
|
300 |
+
),
|
301 |
+
use_mates=cfg.mates.state
|
302 |
+
)
|
303 |
+
config = ServerConfig(num_rounds=num_rounds)
|
304 |
+
|
305 |
+
return ServerAppComponents(strategy=strategy, config=config)
|
306 |
+
|
307 |
+
|
308 |
+
# Flower ServerApp
|
309 |
+
app = ServerApp(server_fn=server_fn)
|
template_FL/src/fedllm/skipbert/__init__.py
ADDED
File without changes
|
template_FL/src/fedllm/skipbert/modeling.py
ADDED
@@ -0,0 +1,922 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""SkipBERT modeling"""
|
2 |
+
|
3 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
4 |
+
|
5 |
+
import copy
|
6 |
+
import json
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
import time
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from torch import nn
|
16 |
+
from typing import List, Optional, Tuple, Union
|
17 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
18 |
+
|
19 |
+
import transformers
|
20 |
+
from transformers import BertPreTrainedModel, BertModel
|
21 |
+
from transformers.models.bert.modeling_bert import BertEmbeddings, BertEncoder, BertPooler, BertLayer
|
22 |
+
from transformers.models.bert.modeling_bert import BertPreTrainingHeads
|
23 |
+
from transformers.modeling_outputs import SequenceClassifierOutput
|
24 |
+
from . import plot
|
25 |
+
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, add_code_sample_docstrings
|
26 |
+
|
27 |
+
import logging
|
28 |
+
logger = logging.getLogger(__name__)
|
29 |
+
|
30 |
+
logger.warn('Hacking BertSelfAttention! Now it returns attention scores rather than probabilities.')
|
31 |
+
|
32 |
+
class BertSelfAttention(transformers.models.bert.modeling_bert.BertSelfAttention):
|
33 |
+
|
34 |
+
def forward(
|
35 |
+
self,
|
36 |
+
hidden_states,
|
37 |
+
attention_mask=None,
|
38 |
+
head_mask=None,
|
39 |
+
encoder_hidden_states=None,
|
40 |
+
encoder_attention_mask=None,
|
41 |
+
past_key_value=None,
|
42 |
+
output_attentions=False,
|
43 |
+
):
|
44 |
+
|
45 |
+
device = hidden_states.device
|
46 |
+
mixed_query_layer = self.query(hidden_states)
|
47 |
+
|
48 |
+
# most codes are copied from transformers v4.3.3
|
49 |
+
|
50 |
+
is_cross_attention = encoder_hidden_states is not None
|
51 |
+
|
52 |
+
if is_cross_attention and past_key_value is not None:
|
53 |
+
# reuse k,v, cross_attentions
|
54 |
+
key_layer = past_key_value[0]
|
55 |
+
value_layer = past_key_value[1]
|
56 |
+
attention_mask = encoder_attention_mask
|
57 |
+
elif is_cross_attention:
|
58 |
+
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
59 |
+
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
60 |
+
attention_mask = encoder_attention_mask
|
61 |
+
elif past_key_value is not None:
|
62 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
63 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
64 |
+
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
65 |
+
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
66 |
+
else:
|
67 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
68 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
69 |
+
|
70 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
71 |
+
|
72 |
+
if self.is_decoder:
|
73 |
+
past_key_value = (key_layer, value_layer)
|
74 |
+
|
75 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
76 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
77 |
+
|
78 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
79 |
+
seq_length = hidden_states.size()[1]
|
80 |
+
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
81 |
+
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
82 |
+
distance = position_ids_l - position_ids_r
|
83 |
+
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
84 |
+
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
85 |
+
|
86 |
+
if self.position_embedding_type == "relative_key":
|
87 |
+
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
88 |
+
attention_scores = attention_scores + relative_position_scores
|
89 |
+
elif self.position_embedding_type == "relative_key_query":
|
90 |
+
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
91 |
+
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
92 |
+
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
93 |
+
|
94 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
95 |
+
if attention_mask is not None:
|
96 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
97 |
+
attention_scores = attention_scores.to(device) + attention_mask.to(device)
|
98 |
+
|
99 |
+
# Normalize the attention scores to probabilities.
|
100 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
101 |
+
|
102 |
+
# This is actually dropping out entire tokens to attend to, which might
|
103 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
104 |
+
attention_probs = self.dropout(attention_probs)
|
105 |
+
|
106 |
+
# Mask heads if we want to
|
107 |
+
if head_mask is not None:
|
108 |
+
attention_probs = attention_probs * head_mask
|
109 |
+
#attention_scores = attention_scores * head_mask
|
110 |
+
|
111 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
112 |
+
|
113 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
114 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
115 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
116 |
+
|
117 |
+
outputs = (context_layer, attention_scores) if output_attentions else (context_layer,) # hacked: replace attention_probs with attention_scores
|
118 |
+
|
119 |
+
if self.is_decoder:
|
120 |
+
outputs = outputs + (past_key_value,)
|
121 |
+
return outputs
|
122 |
+
|
123 |
+
transformers.models.bert.modeling_bert.BertSelfAttention = BertSelfAttention
|
124 |
+
|
125 |
+
|
126 |
+
|
127 |
+
class BertForPreTraining(BertPreTrainedModel):
|
128 |
+
def __init__(self, config):
|
129 |
+
super().__init__(config)
|
130 |
+
fit_size = getattr(config, 'fit_size', 768)
|
131 |
+
self.bert = BertModel(config)
|
132 |
+
self.cls = BertPreTrainingHeads(config)
|
133 |
+
self.fit_denses = nn.ModuleList(
|
134 |
+
[nn.Linear(config.hidden_size, fit_size) for _ in range(config.num_hidden_layers+1)]
|
135 |
+
)
|
136 |
+
|
137 |
+
def forward(self, input_ids, token_type_ids=None,
|
138 |
+
attention_mask=None, masked_lm_labels=None,
|
139 |
+
next_sentence_label=None, labels=None,
|
140 |
+
output_attentions=True, output_hidden_states=True,):
|
141 |
+
outputs = self.bert(
|
142 |
+
input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states)
|
143 |
+
sequence_output, att_output, pooled_output = outputs.hidden_states, outputs.attentions, outputs.pooler_output
|
144 |
+
tmp = []
|
145 |
+
for s_id, sequence_layer in enumerate(sequence_output):
|
146 |
+
tmp.append(self.fit_denses[s_id](sequence_layer))
|
147 |
+
sequence_output = tmp
|
148 |
+
|
149 |
+
return att_output, sequence_output
|
150 |
+
|
151 |
+
|
152 |
+
|
153 |
+
|
154 |
+
# class BertForSequenceClassification(BertPreTrainedModel):
|
155 |
+
# def __init__(self, config, do_fit=False, share_param=True):
|
156 |
+
# super().__init__(config)
|
157 |
+
# num_labels = config.num_labels
|
158 |
+
# self.hidden_size = config.hidden_size
|
159 |
+
# self.num_labels = num_labels
|
160 |
+
# self.bert = BertModel(config)
|
161 |
+
# self.dropout = nn.Dropout(
|
162 |
+
# config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob)
|
163 |
+
# self.classifier = nn.Linear(config.hidden_size, num_labels)
|
164 |
+
|
165 |
+
# self.do_fit, self.share_param = do_fit, share_param
|
166 |
+
# if self.do_fit:
|
167 |
+
# fit_size = getattr(config, 'fit_size', 768)
|
168 |
+
# self.fit_size = fit_size
|
169 |
+
# if self.share_param:
|
170 |
+
# self.fit_dense = nn.Linear(config.hidden_size, fit_size)
|
171 |
+
# else:
|
172 |
+
# self.fit_denses = nn.ModuleList(
|
173 |
+
# [nn.Linear(config.hidden_size, fit_size) for _ in range(config.num_hidden_layers + 1)]
|
174 |
+
# )
|
175 |
+
|
176 |
+
# def do_fit_dense(self, sequence_output):
|
177 |
+
|
178 |
+
# tmp = []
|
179 |
+
# if self.do_fit:
|
180 |
+
# for s_id, sequence_layer in enumerate(sequence_output):
|
181 |
+
# if self.share_param:
|
182 |
+
# tmp.append(self.fit_dense(sequence_layer))
|
183 |
+
# else:
|
184 |
+
# tmp.append(self.fit_denses[s_id](sequence_layer))
|
185 |
+
# sequence_output = tmp
|
186 |
+
|
187 |
+
# return sequence_output
|
188 |
+
|
189 |
+
# def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
|
190 |
+
|
191 |
+
# outputs = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
|
192 |
+
# output_hidden_states=True, output_attentions=True)
|
193 |
+
# sequence_output, att_output, pooled_output = outputs.hidden_states, outputs.attentions, outputs.pooler_output
|
194 |
+
|
195 |
+
# logits = self.classifier(pooled_output)
|
196 |
+
|
197 |
+
# sequence_output = self.do_fit_dense(sequence_output)
|
198 |
+
|
199 |
+
# return logits, att_output, sequence_output
|
200 |
+
|
201 |
+
|
202 |
+
|
203 |
+
|
204 |
+
# SequenceClassification docstring
|
205 |
+
_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-polarity"
|
206 |
+
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
|
207 |
+
_SEQ_CLASS_EXPECTED_LOSS = 0.01
|
208 |
+
_CONFIG_FOR_DOC = "BertConfig"
|
209 |
+
|
210 |
+
|
211 |
+
|
212 |
+
BERT_START_DOCSTRING = r"""
|
213 |
+
|
214 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
215 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
216 |
+
etc.)
|
217 |
+
|
218 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
219 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
220 |
+
and behavior.
|
221 |
+
|
222 |
+
Parameters:
|
223 |
+
config ([`BertConfig`]): Model configuration class with all the parameters of the model.
|
224 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
225 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
226 |
+
"""
|
227 |
+
|
228 |
+
|
229 |
+
BERT_INPUTS_DOCSTRING = r"""
|
230 |
+
Args:
|
231 |
+
input_ids (`torch.LongTensor` of shape `({0})`):
|
232 |
+
Indices of input sequence tokens in the vocabulary.
|
233 |
+
|
234 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
235 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
236 |
+
|
237 |
+
[What are input IDs?](../glossary#input-ids)
|
238 |
+
attention_mask (`torch.FloatTensor` of shape `({0})`or `(batch_size, sequence_length, target_length)`, *optional*):
|
239 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
240 |
+
|
241 |
+
- 1 for tokens that are **not masked**,
|
242 |
+
- 0 for tokens that are **masked**.
|
243 |
+
|
244 |
+
[What are attention masks?](../glossary#attention-mask)
|
245 |
+
token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
246 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
247 |
+
1]`:
|
248 |
+
|
249 |
+
- 0 corresponds to a *sentence A* token,
|
250 |
+
- 1 corresponds to a *sentence B* token.
|
251 |
+
|
252 |
+
[What are token type IDs?](../glossary#token-type-ids)
|
253 |
+
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
254 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
255 |
+
config.max_position_embeddings - 1]`.
|
256 |
+
|
257 |
+
[What are position IDs?](../glossary#position-ids)
|
258 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
259 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
260 |
+
|
261 |
+
- 1 indicates the head is **not masked**,
|
262 |
+
- 0 indicates the head is **masked**.
|
263 |
+
|
264 |
+
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
|
265 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
266 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
267 |
+
model's internal embedding lookup matrix.
|
268 |
+
output_attentions (`bool`, *optional*):
|
269 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
270 |
+
tensors for more detail.
|
271 |
+
output_hidden_states (`bool`, *optional*):
|
272 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
273 |
+
more detail.
|
274 |
+
return_dict (`bool`, *optional*):
|
275 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
276 |
+
"""
|
277 |
+
|
278 |
+
@add_start_docstrings(
|
279 |
+
"""
|
280 |
+
Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
|
281 |
+
output) e.g. for GLUE tasks.
|
282 |
+
""",
|
283 |
+
BERT_START_DOCSTRING,
|
284 |
+
)
|
285 |
+
class BertForSequenceClassification(BertPreTrainedModel):
|
286 |
+
def __init__(self, config, do_fit=False, share_param=True):
|
287 |
+
super().__init__(config)
|
288 |
+
self.num_labels = config.num_labels
|
289 |
+
self.config = config
|
290 |
+
self.do_fit = do_fit
|
291 |
+
self.share_param = share_param
|
292 |
+
|
293 |
+
self.bert = BertModel(config)
|
294 |
+
classifier_dropout = (
|
295 |
+
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|
296 |
+
)
|
297 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
298 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
299 |
+
|
300 |
+
# Add fit layers if enabled
|
301 |
+
if self.do_fit:
|
302 |
+
fit_size = getattr(config, 'fit_size', 768)
|
303 |
+
if self.share_param:
|
304 |
+
self.fit_dense = nn.Linear(config.hidden_size, fit_size)
|
305 |
+
else:
|
306 |
+
self.fit_denses = nn.ModuleList(
|
307 |
+
[nn.Linear(config.hidden_size, fit_size) for _ in range(config.num_hidden_layers + 1)]
|
308 |
+
)
|
309 |
+
|
310 |
+
self.post_init()
|
311 |
+
|
312 |
+
def do_fit_dense(self, hidden_states):
|
313 |
+
"""Process hidden states through fit layers if enabled"""
|
314 |
+
if not self.do_fit:
|
315 |
+
return hidden_states
|
316 |
+
|
317 |
+
processed_states = []
|
318 |
+
for layer_idx, state in enumerate(hidden_states):
|
319 |
+
if self.share_param:
|
320 |
+
processed_states.append(self.fit_dense(state))
|
321 |
+
else:
|
322 |
+
processed_states.append(self.fit_denses[layer_idx](state))
|
323 |
+
return processed_states
|
324 |
+
|
325 |
+
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
326 |
+
@add_code_sample_docstrings(
|
327 |
+
checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
|
328 |
+
output_type=SequenceClassifierOutput,
|
329 |
+
config_class=_CONFIG_FOR_DOC,
|
330 |
+
expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
|
331 |
+
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
|
332 |
+
)
|
333 |
+
def forward(
|
334 |
+
self,
|
335 |
+
input_ids: Optional[torch.Tensor] = None,
|
336 |
+
attention_mask: Optional[torch.Tensor] = None,
|
337 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
338 |
+
position_ids: Optional[torch.Tensor] = None,
|
339 |
+
head_mask: Optional[torch.Tensor] = None,
|
340 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
341 |
+
labels: Optional[torch.Tensor] = None,
|
342 |
+
output_attentions: Optional[bool] = None,
|
343 |
+
output_hidden_states: Optional[bool] = None,
|
344 |
+
return_dict: Optional[bool] = None,
|
345 |
+
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
346 |
+
r"""
|
347 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
348 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
349 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
350 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
351 |
+
"""
|
352 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
353 |
+
|
354 |
+
# Force output hidden states if fit layers are enabled
|
355 |
+
if self.do_fit:
|
356 |
+
output_hidden_states = True
|
357 |
+
|
358 |
+
outputs = self.bert(
|
359 |
+
input_ids,
|
360 |
+
attention_mask=attention_mask,
|
361 |
+
token_type_ids=token_type_ids,
|
362 |
+
position_ids=position_ids,
|
363 |
+
head_mask=head_mask,
|
364 |
+
inputs_embeds=inputs_embeds,
|
365 |
+
output_attentions=output_attentions,
|
366 |
+
output_hidden_states=output_hidden_states,
|
367 |
+
return_dict=return_dict,
|
368 |
+
)
|
369 |
+
|
370 |
+
pooled_output = outputs[1]
|
371 |
+
pooled_output = self.dropout(pooled_output)
|
372 |
+
logits = self.classifier(pooled_output)
|
373 |
+
|
374 |
+
# Process hidden states through fit layers
|
375 |
+
hidden_states = outputs.hidden_states
|
376 |
+
if self.do_fit and hidden_states is not None:
|
377 |
+
hidden_states = self.do_fit_dense(hidden_states)
|
378 |
+
|
379 |
+
loss = None
|
380 |
+
if labels is not None:
|
381 |
+
if self.config.problem_type is None:
|
382 |
+
if self.num_labels == 1:
|
383 |
+
self.config.problem_type = "regression"
|
384 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
385 |
+
self.config.problem_type = "single_label_classification"
|
386 |
+
else:
|
387 |
+
self.config.problem_type = "multi_label_classification"
|
388 |
+
|
389 |
+
if self.config.problem_type == "regression":
|
390 |
+
loss_fct = MSELoss()
|
391 |
+
if self.num_labels == 1:
|
392 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
393 |
+
else:
|
394 |
+
loss = loss_fct(logits, labels)
|
395 |
+
elif self.config.problem_type == "single_label_classification":
|
396 |
+
loss_fct = CrossEntropyLoss()
|
397 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
398 |
+
elif self.config.problem_type == "multi_label_classification":
|
399 |
+
loss_fct = BCEWithLogitsLoss()
|
400 |
+
loss = loss_fct(logits, labels)
|
401 |
+
|
402 |
+
if not return_dict:
|
403 |
+
output = (logits,) + outputs[2:]
|
404 |
+
# Replace original hidden states with processed ones
|
405 |
+
if self.do_fit and hidden_states is not None:
|
406 |
+
output = (logits,) + (hidden_states,) + outputs[3:]
|
407 |
+
return ((loss,) + output) if loss is not None else output
|
408 |
+
|
409 |
+
return SequenceClassifierOutput(
|
410 |
+
loss=loss,
|
411 |
+
logits=logits,
|
412 |
+
hidden_states=hidden_states if output_hidden_states else None,
|
413 |
+
attentions=outputs.attentions,
|
414 |
+
)
|
415 |
+
|
416 |
+
|
417 |
+
|
418 |
+
class BertForSequenceClassificationPrediction(BertForSequenceClassification):
|
419 |
+
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
|
420 |
+
|
421 |
+
assert not self.training
|
422 |
+
|
423 |
+
_, pooled_output, sequence_output, att_output = self.bert(
|
424 |
+
input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
|
425 |
+
output_hidden_states=True, output_attentions=True)
|
426 |
+
|
427 |
+
logits = self.classifier(pooled_output)
|
428 |
+
|
429 |
+
loss = None
|
430 |
+
if labels is not None:
|
431 |
+
loss = torch.tensor(0.)
|
432 |
+
|
433 |
+
return SequenceClassifierOutput(
|
434 |
+
loss=loss,
|
435 |
+
logits=logits,
|
436 |
+
)
|
437 |
+
|
438 |
+
|
439 |
+
class ShallowSkipping(nn.Module):
|
440 |
+
|
441 |
+
def __init__(self, model):
|
442 |
+
super().__init__()
|
443 |
+
# self.model = model # do not register
|
444 |
+
self.config = model.config
|
445 |
+
self.shallow_config = model.shallow_config
|
446 |
+
# current only support trigram
|
447 |
+
self.ngram = 3
|
448 |
+
|
449 |
+
if self.shallow_config.hidden_size != self.config.hidden_size:
|
450 |
+
self.linear = nn.Linear(self.shallow_config.hidden_size, self.config.hidden_size)
|
451 |
+
|
452 |
+
self.plot = plot.Plot(self.config.max_num_entries, self.config.hidden_size)
|
453 |
+
|
454 |
+
def _build_tri_gram_ids(self, input_ids:torch.Tensor) -> torch.Tensor:
|
455 |
+
return torch.from_numpy(
|
456 |
+
self.plot.input_ids_to_tri_grams(input_ids.cpu().numpy())
|
457 |
+
).to(input_ids.device)
|
458 |
+
|
459 |
+
def build_input_ngrams(self, input_ids:torch.Tensor, token_type_ids:torch.Tensor):
|
460 |
+
|
461 |
+
input_ngram_ids = self._build_tri_gram_ids(input_ids)
|
462 |
+
|
463 |
+
token_ngram_type_ids = None #
|
464 |
+
|
465 |
+
attention_mask = (input_ngram_ids > 0).float()
|
466 |
+
|
467 |
+
if self.training:
|
468 |
+
_mask = torch.rand(attention_mask.shape).to(attention_mask.device)
|
469 |
+
_mask = (_mask > self.config.ngram_masking)
|
470 |
+
attention_mask *= _mask
|
471 |
+
|
472 |
+
attention_mask[:, self.ngram//2] = 1 # avoid masking all tokens in a tri-gram
|
473 |
+
return input_ngram_ids, token_ngram_type_ids, attention_mask
|
474 |
+
|
475 |
+
@torch.jit.script
|
476 |
+
def merge_ngrams(input_ids, ngram_hidden_states, aux_embeddings):
|
477 |
+
batch_size, seq_length = input_ids.shape
|
478 |
+
lens = (input_ids!=0).sum(1)
|
479 |
+
hidden_state = torch.zeros([batch_size, seq_length, ngram_hidden_states.size(-1)], dtype=ngram_hidden_states.dtype, device=ngram_hidden_states.device)
|
480 |
+
|
481 |
+
# assert to be trigrams
|
482 |
+
flat_hidden_state = ngram_hidden_states[:, 1]
|
483 |
+
flat_hidden_state[:-1] = flat_hidden_state[:-1] + ngram_hidden_states[1:, 0]
|
484 |
+
flat_hidden_state[1:] = flat_hidden_state[1:] + ngram_hidden_states[:-1, 2]
|
485 |
+
k = 0
|
486 |
+
for i in range(batch_size):
|
487 |
+
hidden_state[i, :lens[i]] = flat_hidden_state[k: k+lens[i]]
|
488 |
+
k += 1 + lens[i] # 1 for skipping one padding tri-gram
|
489 |
+
hidden_state = hidden_state + aux_embeddings
|
490 |
+
return hidden_state
|
491 |
+
|
492 |
+
def forward_shallow_layers(
|
493 |
+
self,
|
494 |
+
input_ids,
|
495 |
+
token_type_ids,
|
496 |
+
attention_mask,
|
497 |
+
ngram_mask_position=None,
|
498 |
+
head_mask=None,
|
499 |
+
encoder_hidden_states=None,
|
500 |
+
encoder_attention_mask=None,
|
501 |
+
past_key_value=None,
|
502 |
+
output_attentions=True,
|
503 |
+
output_hidden_states=True,
|
504 |
+
model=None,
|
505 |
+
):
|
506 |
+
device = model.device
|
507 |
+
|
508 |
+
input_ngram_ids, token_ngram_type_ids, attention_mask = self.build_input_ngrams(input_ids, token_type_ids)
|
509 |
+
ngram_attention_mask = attention_mask.clone()
|
510 |
+
|
511 |
+
if ngram_mask_position is not None:
|
512 |
+
input_ngram_ids[:, ngram_mask_position] = 0
|
513 |
+
ngram_attention_mask[:, ngram_mask_position] = 0
|
514 |
+
|
515 |
+
extended_attention_mask = model.get_extended_attention_mask(attention_mask, input_ngram_ids.shape, device)
|
516 |
+
|
517 |
+
ngram_index=(input_ngram_ids[:, self.ngram//2] > 0)
|
518 |
+
|
519 |
+
embedding_output = model.embeddings(input_ids=input_ngram_ids, token_type_ids=token_ngram_type_ids)
|
520 |
+
|
521 |
+
hidden_states = embedding_output
|
522 |
+
attention_mask = extended_attention_mask
|
523 |
+
|
524 |
+
for i, layer_module in enumerate(
|
525 |
+
model.encoder.layer[:self.config.num_hidden_layers - self.config.num_full_hidden_layers]):
|
526 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
527 |
+
|
528 |
+
layer_outputs = layer_module(
|
529 |
+
hidden_states=hidden_states,
|
530 |
+
attention_mask=attention_mask,
|
531 |
+
head_mask=layer_head_mask,
|
532 |
+
encoder_hidden_states=encoder_hidden_states,
|
533 |
+
encoder_attention_mask=encoder_attention_mask,
|
534 |
+
past_key_value=past_key_value,
|
535 |
+
output_attentions=output_attentions,
|
536 |
+
)
|
537 |
+
|
538 |
+
hidden_states = layer_outputs[0]
|
539 |
+
|
540 |
+
if self.shallow_config.hidden_size != self.config.hidden_size:
|
541 |
+
hidden_states = self.linear(hidden_states)
|
542 |
+
|
543 |
+
# Set zero the padding ngrams: (..., [PAD], ...)
|
544 |
+
hidden_states = hidden_states * ngram_index[:, None, None]
|
545 |
+
|
546 |
+
hidden_states = hidden_states * model.attn(hidden_states).sigmoid() * ngram_attention_mask.unsqueeze(-1)
|
547 |
+
|
548 |
+
return input_ngram_ids, hidden_states
|
549 |
+
|
550 |
+
def forward(
|
551 |
+
self,
|
552 |
+
input_ids,
|
553 |
+
token_type_ids,
|
554 |
+
attention_mask,
|
555 |
+
head_mask=None,
|
556 |
+
encoder_hidden_states=None,
|
557 |
+
encoder_attention_mask=None,
|
558 |
+
past_key_value=None,
|
559 |
+
output_attentions=True,
|
560 |
+
output_hidden_states=True,
|
561 |
+
model=None,
|
562 |
+
):
|
563 |
+
|
564 |
+
device = model.device
|
565 |
+
|
566 |
+
batch_size, seq_length = input_ids.shape
|
567 |
+
aux_embeddings = model.embeddings.position_embeddings2.weight[:seq_length].unsqueeze(0)
|
568 |
+
aux_embeddings = aux_embeddings + model.embeddings.token_type_embeddings2(token_type_ids)
|
569 |
+
|
570 |
+
if self.config.plot_mode == 'force_compute':
|
571 |
+
'''
|
572 |
+
compute only, ignore PLOT
|
573 |
+
'''
|
574 |
+
input_ngram_ids, hidden_states = self.forward_shallow_layers(
|
575 |
+
input_ids=input_ids,
|
576 |
+
token_type_ids=token_type_ids,
|
577 |
+
attention_mask=attention_mask,
|
578 |
+
head_mask=head_mask,
|
579 |
+
encoder_hidden_states=encoder_hidden_states,
|
580 |
+
encoder_attention_mask=encoder_attention_mask,
|
581 |
+
past_key_value=past_key_value,
|
582 |
+
output_attentions=output_attentions,
|
583 |
+
output_hidden_states=output_hidden_states,
|
584 |
+
ngram_mask_position=None,
|
585 |
+
model=model,
|
586 |
+
)
|
587 |
+
|
588 |
+
elif self.config.plot_mode == 'update_all':
|
589 |
+
'''
|
590 |
+
build PLOT
|
591 |
+
'''
|
592 |
+
# uni-grams
|
593 |
+
input_ngram_ids, hidden_states = self.forward_shallow_layers(
|
594 |
+
input_ids=input_ids,
|
595 |
+
token_type_ids=token_type_ids,
|
596 |
+
attention_mask=attention_mask,
|
597 |
+
head_mask=head_mask,
|
598 |
+
encoder_hidden_states=encoder_hidden_states,
|
599 |
+
encoder_attention_mask=encoder_attention_mask,
|
600 |
+
past_key_value=past_key_value,
|
601 |
+
output_attentions=output_attentions,
|
602 |
+
output_hidden_states=output_hidden_states,
|
603 |
+
ngram_mask_position=(0,2),
|
604 |
+
model=model,
|
605 |
+
)
|
606 |
+
self.plot.update_data(input_ngram_ids, hidden_states)
|
607 |
+
|
608 |
+
# bi-grams
|
609 |
+
input_ngram_ids, hidden_states = self.forward_shallow_layers(
|
610 |
+
input_ids=input_ids,
|
611 |
+
token_type_ids=token_type_ids,
|
612 |
+
attention_mask=attention_mask,
|
613 |
+
head_mask=head_mask,
|
614 |
+
encoder_hidden_states=encoder_hidden_states,
|
615 |
+
encoder_attention_mask=encoder_attention_mask,
|
616 |
+
past_key_value=past_key_value,
|
617 |
+
output_attentions=output_attentions,
|
618 |
+
output_hidden_states=output_hidden_states,
|
619 |
+
ngram_mask_position=0,
|
620 |
+
model=model,
|
621 |
+
)
|
622 |
+
self.plot.update_data(input_ngram_ids, hidden_states)
|
623 |
+
|
624 |
+
# tri-grams
|
625 |
+
input_ngram_ids, hidden_states = self.forward_shallow_layers(
|
626 |
+
input_ids=input_ids,
|
627 |
+
token_type_ids=token_type_ids,
|
628 |
+
attention_mask=attention_mask,
|
629 |
+
head_mask=head_mask,
|
630 |
+
encoder_hidden_states=encoder_hidden_states,
|
631 |
+
encoder_attention_mask=encoder_attention_mask,
|
632 |
+
past_key_value=past_key_value,
|
633 |
+
output_attentions=output_attentions,
|
634 |
+
output_hidden_states=output_hidden_states,
|
635 |
+
ngram_mask_position=None,
|
636 |
+
model=model,
|
637 |
+
)
|
638 |
+
self.plot.update_data(input_ngram_ids, hidden_states)
|
639 |
+
|
640 |
+
elif self.config.plot_mode == 'plot_passive':
|
641 |
+
'''
|
642 |
+
use plot if no oov
|
643 |
+
'''
|
644 |
+
|
645 |
+
if input_ids.is_cuda:
|
646 |
+
input_ids = input_ids.cpu()
|
647 |
+
if not self.plot.has_oov(input_ids):
|
648 |
+
hidden_states = self.plot.retrieve_data(input_ids)
|
649 |
+
hidden_states = hidden_states.to(device)
|
650 |
+
else:
|
651 |
+
input_ids = input_ids.to(device)
|
652 |
+
input_ngram_ids, hidden_states = self.forward_shallow_layers(
|
653 |
+
input_ids=input_ids,
|
654 |
+
token_type_ids=token_type_ids,
|
655 |
+
attention_mask=attention_mask,
|
656 |
+
head_mask=head_mask,
|
657 |
+
encoder_hidden_states=encoder_hidden_states,
|
658 |
+
encoder_attention_mask=encoder_attention_mask,
|
659 |
+
past_key_value=past_key_value,
|
660 |
+
output_attentions=output_attentions,
|
661 |
+
output_hidden_states=output_hidden_states,
|
662 |
+
ngram_mask_position=None,
|
663 |
+
model=model,
|
664 |
+
)
|
665 |
+
self.plot.update_data(input_ngram_ids, hidden_states)
|
666 |
+
|
667 |
+
elif self.config.plot_mode == 'plot_only':
|
668 |
+
'''
|
669 |
+
plot only
|
670 |
+
looking up order: trigram -> bigram -> unigram -> 0
|
671 |
+
'''
|
672 |
+
if input_ids.is_cuda:
|
673 |
+
logger.warn("'input_ids' is better to placed in CPU.")
|
674 |
+
input_ids = input_ids.cpu()
|
675 |
+
hidden_states = self.plot.retrieve_data(input_ids)
|
676 |
+
hidden_states = hidden_states.to(device)
|
677 |
+
|
678 |
+
hidden_states = F.dropout(hidden_states, self.config.hidden_dropout_prob, self.training)
|
679 |
+
hidden_states = self.merge_ngrams(input_ids, hidden_states, aux_embeddings)
|
680 |
+
hidden_states = model.norm(hidden_states)
|
681 |
+
|
682 |
+
return hidden_states
|
683 |
+
|
684 |
+
|
685 |
+
class SkipBertEncoder(BertEncoder):
|
686 |
+
def __init__(self, shallow_config, config):
|
687 |
+
super(BertEncoder, self).__init__()
|
688 |
+
self.config = config
|
689 |
+
self.shallow_config = shallow_config
|
690 |
+
self.layer = nn.ModuleList(
|
691 |
+
[
|
692 |
+
BertLayer(shallow_config) for _ in range(config.num_hidden_layers - config.num_full_hidden_layers)
|
693 |
+
] + [
|
694 |
+
BertLayer(config) for _ in range(config.num_full_hidden_layers)
|
695 |
+
])
|
696 |
+
|
697 |
+
class SkipBertModel(BertModel):
|
698 |
+
def __init__(self, config, add_pooling_layer=True):
|
699 |
+
super().__init__(config)
|
700 |
+
self.config = config
|
701 |
+
self.shallow_config = copy.deepcopy(config)
|
702 |
+
|
703 |
+
self.shallow_config.hidden_size = getattr(config, 'shallow_hidden_size', 768)
|
704 |
+
self.shallow_config.intermediate_size = getattr(config, 'shallow_intermediate_size', 3072)
|
705 |
+
|
706 |
+
self.embeddings = BertEmbeddings(self.shallow_config)
|
707 |
+
self.encoder = SkipBertEncoder(self.shallow_config, config)
|
708 |
+
|
709 |
+
self.pooler = BertPooler(config) if add_pooling_layer else None
|
710 |
+
|
711 |
+
self.embeddings.position_embeddings2 = nn.Embedding(self.config.max_position_embeddings, self.config.hidden_size)
|
712 |
+
self.embeddings.token_type_embeddings2 = nn.Embedding(self.config.type_vocab_size, self.config.hidden_size)
|
713 |
+
|
714 |
+
self.norm = nn.LayerNorm(self.config.hidden_size)
|
715 |
+
self.attn = nn.Linear(self.config.hidden_size, 1)
|
716 |
+
self.shallow_skipping = ShallowSkipping(self)
|
717 |
+
|
718 |
+
self.init_weights()
|
719 |
+
|
720 |
+
def forward(
|
721 |
+
self,
|
722 |
+
input_ids=None,
|
723 |
+
attention_mask=None,
|
724 |
+
token_type_ids=None,
|
725 |
+
position_ids=None,
|
726 |
+
head_mask=None,
|
727 |
+
encoder_hidden_states=None,
|
728 |
+
encoder_attention_mask=None,
|
729 |
+
output_attentions=True,
|
730 |
+
output_hidden_states=True,
|
731 |
+
):
|
732 |
+
|
733 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
734 |
+
output_hidden_states = (
|
735 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
736 |
+
)
|
737 |
+
|
738 |
+
input_shape = input_ids.size()
|
739 |
+
device = self.device
|
740 |
+
|
741 |
+
if attention_mask is None:
|
742 |
+
attention_mask = (input_ids != 0).float()
|
743 |
+
if token_type_ids is None:
|
744 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
745 |
+
|
746 |
+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
747 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
748 |
+
|
749 |
+
hidden_states = self.shallow_skipping(
|
750 |
+
input_ids=input_ids,
|
751 |
+
token_type_ids=token_type_ids,
|
752 |
+
attention_mask=attention_mask,
|
753 |
+
head_mask=head_mask,
|
754 |
+
encoder_hidden_states=encoder_hidden_states,
|
755 |
+
encoder_attention_mask=encoder_attention_mask,
|
756 |
+
model=self,
|
757 |
+
)
|
758 |
+
|
759 |
+
# Global transformer layers
|
760 |
+
attention_mask = extended_attention_mask
|
761 |
+
|
762 |
+
all_hidden_states = ()
|
763 |
+
all_self_attentions = ()
|
764 |
+
|
765 |
+
for i, layer_module in enumerate(self.encoder.layer[-self.config.num_full_hidden_layers:]):
|
766 |
+
|
767 |
+
if output_hidden_states:
|
768 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
769 |
+
|
770 |
+
layer_head_mask = head_mask[i + self.config.num_hidden_layers - self.config.num_full_hidden_layers] if head_mask is not None else None
|
771 |
+
|
772 |
+
layer_outputs = layer_module(
|
773 |
+
hidden_states=hidden_states,
|
774 |
+
attention_mask=attention_mask,
|
775 |
+
head_mask=layer_head_mask,
|
776 |
+
encoder_hidden_states=encoder_hidden_states,
|
777 |
+
encoder_attention_mask=encoder_attention_mask,
|
778 |
+
past_key_value=None,
|
779 |
+
output_attentions=output_attentions,
|
780 |
+
)
|
781 |
+
|
782 |
+
hidden_states = layer_outputs[0]
|
783 |
+
|
784 |
+
if output_attentions:
|
785 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
786 |
+
|
787 |
+
if output_hidden_states:
|
788 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
789 |
+
|
790 |
+
sequence_output = hidden_states
|
791 |
+
pooled_output = self.pooler(sequence_output)
|
792 |
+
|
793 |
+
return (sequence_output, pooled_output, all_hidden_states, all_self_attentions)
|
794 |
+
|
795 |
+
def freeze_shallow_layers(self):
|
796 |
+
for p in self.embeddings.parameters():
|
797 |
+
p.requires_grad = False
|
798 |
+
for layer in self.encoder.layer[:self.config.num_hidden_layers - self.config.num_full_hidden_layers]:
|
799 |
+
for p in layer.parameters():
|
800 |
+
p.requires_grad = False
|
801 |
+
try:
|
802 |
+
for p in self.shallow_skipping.linear.parameters():
|
803 |
+
p.requires_grad = False
|
804 |
+
except Exception as e:
|
805 |
+
pass
|
806 |
+
try:
|
807 |
+
for p in self.attn.parameters():
|
808 |
+
p.requires_grad = False
|
809 |
+
except Exception as e:
|
810 |
+
pass
|
811 |
+
|
812 |
+
self.embeddings.dropout.p = 0.
|
813 |
+
for layer in self.encoder.layer[:self.config.num_hidden_layers - self.config.num_full_hidden_layers]:
|
814 |
+
for m in layer.modules():
|
815 |
+
if isinstance(m, torch.nn.Dropout):
|
816 |
+
m.p = 0.
|
817 |
+
|
818 |
+
|
819 |
+
class SkipBertForPreTraining(BertPreTrainedModel):
|
820 |
+
def __init__(self, config):
|
821 |
+
super().__init__(config)
|
822 |
+
fit_size = getattr(config, 'fit_size', 768)
|
823 |
+
self.bert = SkipBertModel(config)
|
824 |
+
self.cls = BertPreTrainingHeads(config)
|
825 |
+
|
826 |
+
if self.fit_size != config.hidden_size:
|
827 |
+
self.fit_denses = nn.ModuleList(
|
828 |
+
[nn.Linear(config.hidden_size, self.fit_size) for _ in range(config.num_hidden_layers + 1)]
|
829 |
+
)
|
830 |
+
|
831 |
+
def forward(self, input_ids, token_type_ids=None,
|
832 |
+
attention_mask=None, masked_lm_labels=None,
|
833 |
+
next_sentence_label=None, labels=None,
|
834 |
+
output_attentions=True, output_hidden_states=True,):
|
835 |
+
_, pooled_output, sequence_output, att_output = self.bert(
|
836 |
+
input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
|
837 |
+
output_attentions=output_attentions, output_hidden_states=output_hidden_states)
|
838 |
+
|
839 |
+
if self.fit_size != self.config.hidden_size:
|
840 |
+
tmp = []
|
841 |
+
for s_id, sequence_layer in enumerate(sequence_output):
|
842 |
+
tmp.append(self.fit_denses[s_id](sequence_layer))
|
843 |
+
sequence_output = tmp
|
844 |
+
|
845 |
+
return att_output, sequence_output
|
846 |
+
|
847 |
+
|
848 |
+
def freeze_shallow_layers(self):
|
849 |
+
self.bert.freeze_shallow_layers()
|
850 |
+
|
851 |
+
|
852 |
+
class SkipBertForSequenceClassification(BertPreTrainedModel):
|
853 |
+
def __init__(self, config, do_fit=False, share_param=True):
|
854 |
+
super().__init__(config)
|
855 |
+
num_labels = config.num_labels
|
856 |
+
self.hidden_size = config.hidden_size
|
857 |
+
self.num_labels = num_labels
|
858 |
+
self.bert = SkipBertModel(config)
|
859 |
+
self.dropout = nn.Dropout(
|
860 |
+
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob)
|
861 |
+
self.classifier = nn.Linear(config.hidden_size, num_labels)
|
862 |
+
|
863 |
+
self.do_fit, self.share_param = do_fit, share_param
|
864 |
+
if self.do_fit:
|
865 |
+
fit_size = getattr(config, 'fit_size', 768)
|
866 |
+
self.fit_size = fit_size
|
867 |
+
if self.share_param:
|
868 |
+
self.share_fit_dense = nn.Linear(config.hidden_size, fit_size)
|
869 |
+
else:
|
870 |
+
self.fit_denses = nn.ModuleList(
|
871 |
+
[nn.Linear(config.hidden_size, fit_size) for _ in range(config.num_hidden_layers + 1)]
|
872 |
+
)
|
873 |
+
|
874 |
+
def do_fit_dense(self, sequence_output):
|
875 |
+
|
876 |
+
tmp = []
|
877 |
+
if self.do_fit:
|
878 |
+
for s_id, sequence_layer in enumerate(sequence_output):
|
879 |
+
if self.share_param:
|
880 |
+
tmp.append(self.share_fit_dense(sequence_layer))
|
881 |
+
else:
|
882 |
+
tmp.append(self.fit_denses[s_id](sequence_layer))
|
883 |
+
sequence_output = tmp
|
884 |
+
|
885 |
+
return sequence_output
|
886 |
+
|
887 |
+
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
|
888 |
+
|
889 |
+
_, pooled_output, sequence_output, att_output = self.bert(
|
890 |
+
input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
|
891 |
+
output_hidden_states=True, output_attentions=True)
|
892 |
+
|
893 |
+
sequence_output = self.do_fit_dense(sequence_output)
|
894 |
+
|
895 |
+
pooled_output = self.dropout(pooled_output)
|
896 |
+
logits = self.classifier(pooled_output)
|
897 |
+
|
898 |
+
return logits, att_output, sequence_output
|
899 |
+
|
900 |
+
def freeze_shallow_layers(self):
|
901 |
+
self.bert.freeze_shallow_layers()
|
902 |
+
|
903 |
+
|
904 |
+
class SkipBertForSequenceClassificationPrediction(SkipBertForSequenceClassification):
|
905 |
+
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
|
906 |
+
|
907 |
+
assert not self.training
|
908 |
+
|
909 |
+
_, pooled_output, sequence_output, att_output = self.bert(
|
910 |
+
input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
|
911 |
+
output_hidden_states=True, output_attentions=True)
|
912 |
+
|
913 |
+
logits = self.classifier(pooled_output)
|
914 |
+
|
915 |
+
loss = None
|
916 |
+
if labels is not None:
|
917 |
+
loss = torch.tensor(0.)
|
918 |
+
|
919 |
+
return SequenceClassifierOutput(
|
920 |
+
loss=loss,
|
921 |
+
logits=logits,
|
922 |
+
)
|
template_FL/src/fedllm/skipbert/plot.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import sys,os
|
8 |
+
import numba
|
9 |
+
|
10 |
+
def _set_madvise(large_data, advise=1):
|
11 |
+
'''
|
12 |
+
0: MADV_NORMAL
|
13 |
+
1: MADV_RANDOM
|
14 |
+
2: MADV_SEQUENTIAL
|
15 |
+
3: MADV_WILLNEED
|
16 |
+
4: MADV_DONTNEED
|
17 |
+
'''
|
18 |
+
import ctypes
|
19 |
+
madvise = ctypes.CDLL("libc.so.6").madvise
|
20 |
+
madvise.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int]
|
21 |
+
madvise.restype = ctypes.c_int
|
22 |
+
assert madvise(large_data.ctypes.data, large_data.size * large_data.dtype.itemsize, advise) == 0, "MADVISE FAILED" # 1 means MADV_RANDOM
|
23 |
+
|
24 |
+
def _read_or_create_memmap(path, return_tensor=True, *args, **kargs):
|
25 |
+
if os.path.exists(path):
|
26 |
+
a = np.memmap(path, mode='r+', *args, **kargs)
|
27 |
+
else:
|
28 |
+
a = np.memmap(path, mode='w+', *args, **kargs)
|
29 |
+
# first row is reserved for oovs
|
30 |
+
a[0] = 0
|
31 |
+
_set_madvise(a, advise=1)
|
32 |
+
if return_tensor:
|
33 |
+
a = torch.from_numpy(a) # zero-copy
|
34 |
+
return a
|
35 |
+
|
36 |
+
def _to_key(k):
|
37 |
+
return tuple(k.tolist())
|
38 |
+
|
39 |
+
@numba.njit()
|
40 |
+
def _input_ids_to_tri_grams(x: np.array):
|
41 |
+
bs, seq_len = x.shape
|
42 |
+
ret = np.zeros((bs*(seq_len+1), 3), dtype=np.int64)
|
43 |
+
i_ret = 0
|
44 |
+
for i_bs in range(bs):
|
45 |
+
for i_token in range(seq_len):
|
46 |
+
if x[i_bs, i_token] == 0:
|
47 |
+
break
|
48 |
+
if i_token == 0:
|
49 |
+
ret[i_ret][1] = x[i_bs, i_token]
|
50 |
+
ret[i_ret][2] = x[i_bs, i_token+1]
|
51 |
+
elif i_token == seq_len - 1:
|
52 |
+
ret[i_ret][0] = x[i_bs, i_token-1]
|
53 |
+
ret[i_ret][1] = x[i_bs, i_token]
|
54 |
+
else:
|
55 |
+
ret[i_ret] = x[i_bs, i_token-1:i_token+2]
|
56 |
+
i_ret += 1
|
57 |
+
i_ret += 1 # add a pad trigram between seqs
|
58 |
+
return ret[:i_ret]
|
59 |
+
|
60 |
+
|
61 |
+
@numba.njit()
|
62 |
+
def _input_ids_to_ngram_ids(d: dict, x: np.array):
|
63 |
+
'''
|
64 |
+
input_ids tp ngram_ids.
|
65 |
+
try match (x0, x1, x2) -> id;
|
66 |
+
if not possible, match (0, x1, x2) -> id;
|
67 |
+
if also not possible, match (0, x1, 0) -> id.
|
68 |
+
'''
|
69 |
+
bs, seq_len = x.shape
|
70 |
+
ret = np.zeros(bs*(seq_len+1), dtype=np.int64)
|
71 |
+
i_ret = 0
|
72 |
+
for i_bs in range(bs):
|
73 |
+
for i_token in range(seq_len):
|
74 |
+
if x[i_bs, i_token] == 0:
|
75 |
+
break
|
76 |
+
if i_token == 0:
|
77 |
+
k = (0, x[i_bs, i_token], x[i_bs, i_token+1])
|
78 |
+
elif i_token == seq_len - 1:
|
79 |
+
k = (x[i_bs, i_token-1], x[i_bs, i_token], 0)
|
80 |
+
else:
|
81 |
+
k = (x[i_bs, i_token-1], x[i_bs, i_token], x[i_bs, i_token+1])
|
82 |
+
if k in d: # tri-gram
|
83 |
+
ret[i_ret] = d[k]
|
84 |
+
else:
|
85 |
+
k = (0, k[1], k[2])
|
86 |
+
if k in d: # bi-gram
|
87 |
+
ret[i_ret] = d[k]
|
88 |
+
else:
|
89 |
+
k = (0, k[1], 0)
|
90 |
+
if k in d: # uni-gram
|
91 |
+
ret[i_ret] = d[k]
|
92 |
+
i_ret += 1
|
93 |
+
i_ret += 1 # add a pad trigram between seqs
|
94 |
+
return ret[:i_ret]
|
95 |
+
|
96 |
+
@numba.njit()
|
97 |
+
def _has_oov(d: dict, x: np.array):
|
98 |
+
bs, seq_len = x.shape
|
99 |
+
for i_bs in range(bs):
|
100 |
+
for i_token in range(seq_len):
|
101 |
+
if x[i_bs, i_token] == 0:
|
102 |
+
break
|
103 |
+
if i_token == 0:
|
104 |
+
k = (0, x[i_bs, i_token], x[i_bs, i_token+1])
|
105 |
+
elif i_token == seq_len - 1:
|
106 |
+
k = (x[i_bs, i_token-1], x[i_bs, i_token], 0)
|
107 |
+
else:
|
108 |
+
k = (x[i_bs, i_token-1], x[i_bs, i_token], x[i_bs, i_token+1])
|
109 |
+
if k not in d:
|
110 |
+
return True
|
111 |
+
return False
|
112 |
+
|
113 |
+
|
114 |
+
class Plot:
|
115 |
+
def __init__(self, max_num_entries=100000, hidden_size=768):
|
116 |
+
|
117 |
+
self.max_num_entries = max_num_entries
|
118 |
+
self.hidden_size = hidden_size
|
119 |
+
|
120 |
+
self.trigram_to_id, self.id_to_trigram = self.build_hash_table('input_ids_tri_gram.memmap', max_num_entries)
|
121 |
+
self.orig_trigram_hidden_states = _read_or_create_memmap("plot_hidden_states_tri_gram.memmap", dtype='float16', shape=(max_num_entries, 3, hidden_size))
|
122 |
+
|
123 |
+
def build_hash_table(self, path, max_num_entries):
|
124 |
+
n_gram = 3
|
125 |
+
hash_table1 = numba.typed.Dict()
|
126 |
+
hash_table1[tuple([0]*n_gram)] = 0 # dummy entry
|
127 |
+
orig_ngram_ids_mmap = _read_or_create_memmap(
|
128 |
+
path, return_tensor=False, dtype='int32', shape=(max_num_entries, n_gram))
|
129 |
+
|
130 |
+
for i in range(1, self.max_num_entries):
|
131 |
+
_tmp = orig_ngram_ids_mmap[i]
|
132 |
+
# break when meet all 0 ngram
|
133 |
+
if (_tmp==0).all():
|
134 |
+
break
|
135 |
+
tmp_hash = _to_key(_tmp)
|
136 |
+
if tmp_hash not in hash_table1:
|
137 |
+
hash_table1[tmp_hash] = i
|
138 |
+
|
139 |
+
return hash_table1, orig_ngram_ids_mmap
|
140 |
+
|
141 |
+
def input_ids_to_tri_grams(self, input_ids):
|
142 |
+
return _input_ids_to_tri_grams(input_ids)
|
143 |
+
|
144 |
+
def update_data(self, ngram_input_ids, ngram_hidden_states):
|
145 |
+
ngram_input_ids = ngram_input_ids.cpu().numpy()
|
146 |
+
ngram_hidden_states = ngram_hidden_states.detach().half().cpu() # FP16
|
147 |
+
bs, ngram = ngram_input_ids.shape
|
148 |
+
ngram_to_id, id_to_ngram, id_to_hidden_state = \
|
149 |
+
self.trigram_to_id, self.id_to_trigram, self.orig_trigram_hidden_states
|
150 |
+
# TODO: optimize the for-loop later
|
151 |
+
id_to_save = []
|
152 |
+
for i in range(bs):
|
153 |
+
ngram = _to_key(ngram_input_ids[i])
|
154 |
+
# TODO: handle ngram_id > max_size
|
155 |
+
ngram_id = ngram_to_id.get(ngram, len(ngram_to_id))
|
156 |
+
if ngram_id >= self.max_num_entries:
|
157 |
+
print('Exceed max number of entries...')
|
158 |
+
print('Skip current entry...')
|
159 |
+
continue
|
160 |
+
ngram_to_id[ngram] = ngram_id
|
161 |
+
id_to_ngram[ngram_id] = ngram
|
162 |
+
id_to_save.append(ngram_id)
|
163 |
+
id_to_hidden_state[id_to_save] = ngram_hidden_states
|
164 |
+
|
165 |
+
def retrieve_data(self, input_ids):
|
166 |
+
input_ids = input_ids.numpy()
|
167 |
+
id_to_get = _input_ids_to_ngram_ids(self.trigram_to_id, input_ids)
|
168 |
+
hidden_states = self.orig_trigram_hidden_states[id_to_get]
|
169 |
+
return hidden_states
|
170 |
+
|
171 |
+
def has_oov(self, input_ids):
|
172 |
+
return _has_oov(self.trigram_to_id, input_ids.numpy())
|
173 |
+
|
template_FL/src/fedllm/templates/alpaca.json
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"description": "Template used by Alpaca-LoRA.",
|
3 |
+
"prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
|
4 |
+
"prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
|
5 |
+
"response_split": "### Response:"
|
6 |
+
}
|
template_FL/src/fedllm/trainer.py
ADDED
@@ -0,0 +1,476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from accelerate import Accelerator
|
2 |
+
from torch.utils.data import DataLoader
|
3 |
+
import torch
|
4 |
+
import copy
|
5 |
+
import numpy as np
|
6 |
+
from transformers import BertForSequenceClassification, GenerationConfig, AutoTokenizer
|
7 |
+
import inspect
|
8 |
+
import logging
|
9 |
+
import wandb
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
class ManualLLMSampleCB:
|
15 |
+
def __init__(self, model, tokenizer, task, num_samples=10, max_new_tokens=256):
|
16 |
+
self.model = model
|
17 |
+
self.concat_model = None
|
18 |
+
self.tokenizer = tokenizer
|
19 |
+
self.task = task
|
20 |
+
self.num_samples = num_samples
|
21 |
+
self.max_new_tokens = max_new_tokens
|
22 |
+
self.gen_config = GenerationConfig.from_pretrained(
|
23 |
+
model.config.name_or_path, max_new_tokens=max_new_tokens
|
24 |
+
)
|
25 |
+
|
26 |
+
def generate(self, prompt):
|
27 |
+
# Tokenize the input prompt and include the attention mask
|
28 |
+
tokenized_prompt = self.tokenizer(prompt, return_tensors='pt').to(self.model.device)
|
29 |
+
input_ids = tokenized_prompt['input_ids']
|
30 |
+
attention_mask = tokenized_prompt['attention_mask'] # Extract attention mask
|
31 |
+
|
32 |
+
with torch.no_grad():
|
33 |
+
output = self.model.generate(
|
34 |
+
input_ids=input_ids,
|
35 |
+
attention_mask=attention_mask,
|
36 |
+
max_new_tokens=self.max_new_tokens,
|
37 |
+
generation_config=self.gen_config
|
38 |
+
)
|
39 |
+
return self.tokenizer.decode(output[0], skip_special_tokens=True)
|
40 |
+
|
41 |
+
|
42 |
+
def create_samples_table(self, dataset):
|
43 |
+
table = wandb.Table(columns=["input", "prediction", "label", "task"])
|
44 |
+
sampled_dataset = dataset.shuffle(seed=42).select(range(self.num_samples))
|
45 |
+
|
46 |
+
for example in tqdm(sampled_dataset, desc="Generating Samples"):
|
47 |
+
instruction = example.get("instruction", "")
|
48 |
+
input_text = example.get("input", "")
|
49 |
+
label = example.get("output", "")
|
50 |
+
|
51 |
+
if input_text:
|
52 |
+
prompt = f"Instruction: {instruction} Input: {input_text} Response:"
|
53 |
+
else:
|
54 |
+
prompt = f"Instruction: {instruction} Response:"
|
55 |
+
|
56 |
+
prediction = self.generate(prompt)
|
57 |
+
table.add_data(prompt, prediction, label, self.task)
|
58 |
+
|
59 |
+
return table
|
60 |
+
|
61 |
+
def log_samples_to_wandb(self, dataset):
|
62 |
+
samples_table = self.create_samples_table(dataset)
|
63 |
+
wandb.log({"sample_predictions": samples_table})
|
64 |
+
|
65 |
+
|
66 |
+
class ManualTrainer:
|
67 |
+
def __init__(
|
68 |
+
self, model, tokenizer, train_dataset, val_dataset, holdout_dataset, reference_dataset,
|
69 |
+
args, data_collator, compute_metrics, mates_args, data_influence_model, data_influence_tokenizer
|
70 |
+
):
|
71 |
+
self.accelerator = Accelerator()
|
72 |
+
self.model = model
|
73 |
+
self.tokenizer = tokenizer
|
74 |
+
self.args = args
|
75 |
+
self.data_collator = data_collator
|
76 |
+
self.compute_metrics = compute_metrics
|
77 |
+
self.mates_args = mates_args
|
78 |
+
self.data_influence_model = data_influence_model
|
79 |
+
self.data_influence_tokenizer = data_influence_tokenizer
|
80 |
+
|
81 |
+
# Remove unused columns from datasets
|
82 |
+
if train_dataset:
|
83 |
+
self.train_dataset = self._remove_unused_columns(train_dataset, "training")
|
84 |
+
# Prepare data loaders
|
85 |
+
self.full_train_loader = DataLoader(
|
86 |
+
self.train_dataset,
|
87 |
+
batch_size=self.args.per_device_train_batch_size,
|
88 |
+
shuffle=True,
|
89 |
+
collate_fn=self.data_collator,
|
90 |
+
drop_last=self.args.dataloader_drop_last
|
91 |
+
)
|
92 |
+
else:
|
93 |
+
self.train_loader = None
|
94 |
+
self.full_train_loader = None
|
95 |
+
|
96 |
+
if val_dataset:
|
97 |
+
self.val_dataset = self._remove_unused_columns(val_dataset, "validation")
|
98 |
+
self.val_loader = DataLoader(
|
99 |
+
self.val_dataset,
|
100 |
+
batch_size=self.args.per_device_eval_batch_size,
|
101 |
+
shuffle=False,
|
102 |
+
collate_fn=self.data_collator,
|
103 |
+
drop_last=self.args.dataloader_drop_last
|
104 |
+
)
|
105 |
+
else:
|
106 |
+
self.val_loader = None
|
107 |
+
|
108 |
+
if self.mates_args.state:
|
109 |
+
self.holdout_dataset = self._remove_unused_columns(holdout_dataset, "holdout")
|
110 |
+
self.reference_dataset = self._remove_unused_columns(reference_dataset, "reference")
|
111 |
+
|
112 |
+
self.holdout_loader = DataLoader(
|
113 |
+
self.holdout_dataset,
|
114 |
+
batch_size=self.mates_args.holdout_batch_size,
|
115 |
+
shuffle=True,
|
116 |
+
collate_fn=self.data_collator,
|
117 |
+
drop_last=self.args.dataloader_drop_last
|
118 |
+
)
|
119 |
+
|
120 |
+
self.reference_loader = DataLoader(
|
121 |
+
self.reference_dataset,
|
122 |
+
batch_size=self.mates_args.reference_batch_size,
|
123 |
+
shuffle=False,
|
124 |
+
collate_fn=self.data_collator,
|
125 |
+
drop_last=self.args.dataloader_drop_last
|
126 |
+
)
|
127 |
+
|
128 |
+
# Prepare optimizer
|
129 |
+
self.optimizer = torch.optim.AdamW(
|
130 |
+
self.model.parameters(),
|
131 |
+
lr=self.args.learning_rate
|
132 |
+
)
|
133 |
+
|
134 |
+
# Prepare model, optimizer, and data loaders for Accelerator
|
135 |
+
self.model, self.optimizer, self.full_train_loader, self.val_loader = self.accelerator.prepare(
|
136 |
+
self.model, self.optimizer, self.full_train_loader, self.val_loader
|
137 |
+
)
|
138 |
+
|
139 |
+
if self.mates_args.state:
|
140 |
+
# Prepare holdout and reference loaders for Accelerator
|
141 |
+
self.data_influence_model, self.holdout_loader, self.reference_loader = self.accelerator.prepare(
|
142 |
+
self.data_influence_model, self.holdout_loader, self.reference_loader
|
143 |
+
)
|
144 |
+
|
145 |
+
def _remove_unused_columns(self, dataset, description=None):
|
146 |
+
"""
|
147 |
+
Removes columns from a dataset that are not used by the model's forward method.
|
148 |
+
|
149 |
+
Args:
|
150 |
+
dataset: A dataset object (e.g., from datasets.Dataset).
|
151 |
+
description: A string description of the dataset (e.g., "training" or "validation").
|
152 |
+
Returns:
|
153 |
+
The dataset with unused columns removed.
|
154 |
+
"""
|
155 |
+
# Inspect the model forward signature
|
156 |
+
forward_signature = inspect.signature(self.model.forward)
|
157 |
+
signature_columns = list(forward_signature.parameters.keys())
|
158 |
+
|
159 |
+
# Add label columns to the signature columns
|
160 |
+
label_columns = ["labels", "label_ids"]
|
161 |
+
signature_columns += label_columns
|
162 |
+
|
163 |
+
# Determine unused columns
|
164 |
+
dataset_columns = set(dataset.column_names)
|
165 |
+
used_columns = set(signature_columns).intersection(dataset_columns)
|
166 |
+
ignored_columns = list(dataset_columns - used_columns)
|
167 |
+
|
168 |
+
if ignored_columns:
|
169 |
+
logger.info(
|
170 |
+
f"The following columns in the {description} set don't have a corresponding argument in "
|
171 |
+
f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
|
172 |
+
)
|
173 |
+
|
174 |
+
# Ensure at least one column matches the model's expected inputs
|
175 |
+
if not used_columns:
|
176 |
+
raise ValueError(
|
177 |
+
f"No columns in the {description} dataset match the model's forward method signature. "
|
178 |
+
f"The following columns have been ignored: {', '.join(ignored_columns)}."
|
179 |
+
)
|
180 |
+
|
181 |
+
return dataset.remove_columns(ignored_columns)
|
182 |
+
|
183 |
+
def train(self):
|
184 |
+
best_val_loss = float('inf')
|
185 |
+
early_stopping_counter = 0
|
186 |
+
early_stopping_patience = 5
|
187 |
+
training_loss = []
|
188 |
+
|
189 |
+
for epoch in range(self.args.num_train_epochs):
|
190 |
+
# Check if it's time to update the data influence model and state is True
|
191 |
+
if self.mates_args.state and epoch % self.mates_args.update_data_influence_model_step == 0:
|
192 |
+
print("Updating the data influence model and selecting high-quality data...")
|
193 |
+
logger.info("Updating the data influence model and selecting high-quality data...")
|
194 |
+
self.update_data_influence_model()
|
195 |
+
|
196 |
+
# Filter high-quality data using the data influence model
|
197 |
+
high_quality_indices = self.select_high_quality_data(
|
198 |
+
dataset_size=len(self.train_dataset),
|
199 |
+
selection_fraction=self.mates_args.selection_fraction,
|
200 |
+
)
|
201 |
+
self.train_loader = self.accelerator.prepare(
|
202 |
+
self.create_filtered_dataloader(high_quality_indices)
|
203 |
+
)
|
204 |
+
|
205 |
+
self.model.train()
|
206 |
+
epoch_loss = 0.0
|
207 |
+
|
208 |
+
for step, batch in enumerate(self.train_loader):
|
209 |
+
if step >= self.args.max_steps:
|
210 |
+
break
|
211 |
+
|
212 |
+
self.optimizer.zero_grad()
|
213 |
+
|
214 |
+
outputs = self.model(
|
215 |
+
input_ids=batch['input_ids'],
|
216 |
+
attention_mask=batch['attention_mask'],
|
217 |
+
labels=batch['labels']
|
218 |
+
)
|
219 |
+
loss = outputs.loss
|
220 |
+
|
221 |
+
self.accelerator.backward(loss)
|
222 |
+
self.optimizer.step()
|
223 |
+
|
224 |
+
epoch_loss += loss.item()
|
225 |
+
|
226 |
+
if (step + 1) % self.args.logging_steps == 0:
|
227 |
+
# print(f"Step {step + 1}: Train Loss = {epoch_loss / (step + 1):.4f}")
|
228 |
+
logger.info(f"Step {step + 1}: Train Loss = {epoch_loss / (step + 1):.4f}")
|
229 |
+
|
230 |
+
avg_epoch_loss = epoch_loss / len(self.train_loader)
|
231 |
+
training_loss.append(avg_epoch_loss)
|
232 |
+
|
233 |
+
val_results = self.evaluate()
|
234 |
+
|
235 |
+
# print(f"Epoch {epoch + 1}: Train Loss = {avg_epoch_loss:.4f}, Val Loss = {val_results['eval_loss']:.4f}")
|
236 |
+
logger,info(f"Epoch {epoch + 1}: Train Loss = {avg_epoch_loss:.4f}, Val Loss = {val_results['eval_loss']:.4f}")
|
237 |
+
|
238 |
+
# Early stopping logic
|
239 |
+
if val_results["eval_loss"] < best_val_loss:
|
240 |
+
best_val_loss = val_results["eval_loss"]
|
241 |
+
early_stopping_counter = 0
|
242 |
+
else:
|
243 |
+
early_stopping_counter += 1
|
244 |
+
if early_stopping_counter >= early_stopping_patience:
|
245 |
+
print("Early stopping triggered")
|
246 |
+
break
|
247 |
+
|
248 |
+
return {"training_loss": sum(training_loss) / len(training_loss), "best_val_loss": best_val_loss}
|
249 |
+
|
250 |
+
def select_high_quality_data(self, dataset_size, selection_fraction):
|
251 |
+
"""
|
252 |
+
Use the data influence model to predict quality scores and select high-quality data indices.
|
253 |
+
"""
|
254 |
+
print("Selecting high-quality data using the data influence model...")
|
255 |
+
|
256 |
+
# Predict influence scores for the entire dataset
|
257 |
+
influence_scores = []
|
258 |
+
self.data_influence_model.eval()
|
259 |
+
influence_optimizer = self.accelerator.prepare(
|
260 |
+
torch.optim.AdamW(self.data_influence_model.parameters(), lr=self.args.learning_rate)
|
261 |
+
)
|
262 |
+
i = 0
|
263 |
+
with torch.no_grad():
|
264 |
+
for batch in self.full_train_loader: # Full dataset loader
|
265 |
+
text = self.tokenizer.batch_decode(
|
266 |
+
batch['input_ids'],
|
267 |
+
skip_special_tokens=True
|
268 |
+
)
|
269 |
+
|
270 |
+
# Tokenize the text using the BERT tokenizer
|
271 |
+
bert_inputs = self.data_influence_tokenizer(
|
272 |
+
text,
|
273 |
+
truncation=True,
|
274 |
+
padding='max_length',
|
275 |
+
max_length=256,
|
276 |
+
return_tensors='pt'
|
277 |
+
).to(self.accelerator.device)
|
278 |
+
|
279 |
+
# Train the data influence model
|
280 |
+
influence_optimizer.zero_grad()
|
281 |
+
outputs = self.data_influence_model(
|
282 |
+
input_ids=bert_inputs['input_ids'],
|
283 |
+
attention_mask=bert_inputs['attention_mask'],
|
284 |
+
)
|
285 |
+
|
286 |
+
influence_scores.extend(outputs.logits.squeeze(-1).cpu().numpy())
|
287 |
+
|
288 |
+
i += 1
|
289 |
+
|
290 |
+
if i == 100:
|
291 |
+
break
|
292 |
+
|
293 |
+
# Normalize influence scores and apply Gumbel-Top-$k$ selection
|
294 |
+
influence_scores = np.array(influence_scores)
|
295 |
+
print(">> Influence scores shape:", influence_scores.shape)
|
296 |
+
|
297 |
+
# Add Gumbel noise for diversity
|
298 |
+
rng = np.random.default_rng()
|
299 |
+
gumbel_noise = rng.gumbel(size=len(influence_scores))
|
300 |
+
influence_scores += gumbel_noise
|
301 |
+
|
302 |
+
# Select top indices based on influence scores
|
303 |
+
selection_size = int(len(influence_scores)*selection_fraction)
|
304 |
+
high_quality_indices = np.argpartition(-influence_scores, selection_size)[:selection_size]
|
305 |
+
print(f"Selected {len(high_quality_indices)} high-quality samples.")
|
306 |
+
|
307 |
+
return high_quality_indices
|
308 |
+
|
309 |
+
def create_filtered_dataloader(self, indices):
|
310 |
+
"""
|
311 |
+
Create a new dataloader with only the selected high-quality data.
|
312 |
+
"""
|
313 |
+
print("Creating a filtered dataloader with selected high-quality data...")
|
314 |
+
subset_dataset = torch.utils.data.Subset(self.train_dataset, indices)
|
315 |
+
return torch.utils.data.DataLoader(
|
316 |
+
subset_dataset,
|
317 |
+
batch_size=self.args.per_device_train_batch_size,
|
318 |
+
shuffle=True,
|
319 |
+
collate_fn=self.data_collator, # Use the same collate function
|
320 |
+
drop_last=self.args.dataloader_drop_last
|
321 |
+
)
|
322 |
+
|
323 |
+
|
324 |
+
def update_data_influence_model(self):
|
325 |
+
# Train a copy of the model on holdout data and validate on reference data
|
326 |
+
copied_model = copy.deepcopy(self.model)
|
327 |
+
copied_model.train()
|
328 |
+
optimizer = self.accelerator.prepare(
|
329 |
+
torch.optim.Adam(copied_model.parameters(), lr=self.args.learning_rate)
|
330 |
+
)
|
331 |
+
holdout_reference_pairs = []
|
332 |
+
|
333 |
+
# print("Starting to collect holdout-reference pairs...")
|
334 |
+
logger.info("Starting to collect holdout-reference pairs...")
|
335 |
+
for step, holdout_batch in enumerate(self.holdout_loader):
|
336 |
+
# print(f"Processing holdout batch {step+1}/{len(self.holdout_loader)}...")
|
337 |
+
logger.info(f"Processing holdout batch {step+1}/{len(self.holdout_loader)}...")
|
338 |
+
|
339 |
+
optimizer.zero_grad()
|
340 |
+
outputs = copied_model(
|
341 |
+
input_ids=holdout_batch['input_ids'],
|
342 |
+
attention_mask=holdout_batch['attention_mask'],
|
343 |
+
labels=holdout_batch['labels']
|
344 |
+
)
|
345 |
+
holdout_loss = outputs.loss
|
346 |
+
decoded_texts = self.tokenizer.batch_decode(
|
347 |
+
holdout_batch['input_ids'],
|
348 |
+
skip_special_tokens=True
|
349 |
+
)
|
350 |
+
|
351 |
+
holdout_loss.backward()
|
352 |
+
optimizer.step()
|
353 |
+
|
354 |
+
print(f"Evaluating reference losses at step {step}...")
|
355 |
+
logger.info(f"Evaluating reference losses at step {step}...")
|
356 |
+
|
357 |
+
copied_model.eval()
|
358 |
+
reference_losses = []
|
359 |
+
|
360 |
+
with torch.no_grad():
|
361 |
+
for ref_batch in self.reference_loader:
|
362 |
+
outputs = copied_model(
|
363 |
+
input_ids=ref_batch['input_ids'],
|
364 |
+
attention_mask=ref_batch['attention_mask'],
|
365 |
+
labels=ref_batch['labels']
|
366 |
+
)
|
367 |
+
reference_losses.append(outputs.loss.item())
|
368 |
+
|
369 |
+
# Compute the mean of reference losses
|
370 |
+
score = sum(reference_losses) / len(reference_losses) if reference_losses else 0.0
|
371 |
+
holdout_reference_pairs.append((decoded_texts, score))
|
372 |
+
# copied_model.train()
|
373 |
+
|
374 |
+
# Train the data influence model using the generated pairs
|
375 |
+
print("Starting to train the data influence model...")
|
376 |
+
logger.info("Starting to train the data influence model...")
|
377 |
+
|
378 |
+
self.data_influence_model.train()
|
379 |
+
influence_optimizer = torch.optim.AdamW(self.data_influence_model.parameters(), lr=self.args.learning_rate)
|
380 |
+
|
381 |
+
for step, (text, score) in enumerate(holdout_reference_pairs):
|
382 |
+
# Tokenize the text using the BERT tokenizer
|
383 |
+
bert_inputs = self.data_influence_tokenizer(
|
384 |
+
text,
|
385 |
+
truncation=True,
|
386 |
+
padding='max_length',
|
387 |
+
max_length=256,
|
388 |
+
return_tensors='pt'
|
389 |
+
).to(self.accelerator.device)
|
390 |
+
|
391 |
+
# Convert score to tensor and enable gradients
|
392 |
+
score_tensor = torch.tensor([score], device=self.accelerator.device, dtype=torch.float32, requires_grad=True)
|
393 |
+
|
394 |
+
# Train the data influence model
|
395 |
+
influence_optimizer.zero_grad()
|
396 |
+
outputs = self.data_influence_model(
|
397 |
+
input_ids=bert_inputs['input_ids'],
|
398 |
+
attention_mask=bert_inputs['attention_mask'],
|
399 |
+
labels=score_tensor
|
400 |
+
)
|
401 |
+
influence_loss = outputs.loss
|
402 |
+
|
403 |
+
influence_loss.backward()
|
404 |
+
influence_optimizer.step()
|
405 |
+
|
406 |
+
if step % 50 == 0:
|
407 |
+
print(f"[Influence Training] Step {step}: Loss = {influence_loss.item():.4f}")
|
408 |
+
logger.info(f"[Influence Training] Step {step}: Loss = {influence_loss.item():.4f}")
|
409 |
+
|
410 |
+
|
411 |
+
# Distillation for SkipBERT
|
412 |
+
|
413 |
+
|
414 |
+
|
415 |
+
def evaluate(self, wandb_sample=True):
|
416 |
+
self.model.eval()
|
417 |
+
val_loss = 0.0
|
418 |
+
|
419 |
+
all_preds = []
|
420 |
+
all_labels = []
|
421 |
+
|
422 |
+
with torch.no_grad():
|
423 |
+
for batch in self.val_loader:
|
424 |
+
outputs = self.model(
|
425 |
+
input_ids=batch['input_ids'],
|
426 |
+
attention_mask=batch['attention_mask'],
|
427 |
+
labels=batch['labels']
|
428 |
+
)
|
429 |
+
val_loss += outputs.loss.item()
|
430 |
+
|
431 |
+
logits = self.accelerator.gather(outputs.logits)
|
432 |
+
labels = self.accelerator.gather(batch['labels'])
|
433 |
+
|
434 |
+
logits = logits.cpu().numpy()
|
435 |
+
labels = labels.cpu().numpy()
|
436 |
+
|
437 |
+
predictions = np.argmax(logits, axis=-1)
|
438 |
+
attention_mask = batch['attention_mask'].cpu().numpy()
|
439 |
+
|
440 |
+
for pred, label, mask in zip(predictions, labels, attention_mask):
|
441 |
+
valid_pred = pred[mask.astype(bool)]
|
442 |
+
valid_label = label[mask.astype(bool)]
|
443 |
+
|
444 |
+
all_preds.append(valid_pred)
|
445 |
+
all_labels.append(valid_label)
|
446 |
+
|
447 |
+
max_len = max(len(seq) for seq in all_preds)
|
448 |
+
padded_preds = np.array([
|
449 |
+
np.pad(seq, (0, max_len - len(seq)), 'constant', constant_values=self.tokenizer.pad_token_id)
|
450 |
+
for seq in all_preds
|
451 |
+
])
|
452 |
+
|
453 |
+
max_len = max(len(seq) for seq in all_labels)
|
454 |
+
padded_labels = np.array([
|
455 |
+
np.pad(seq, (0, max_len - len(seq)), 'constant', constant_values=-100)
|
456 |
+
for seq in all_labels
|
457 |
+
])
|
458 |
+
|
459 |
+
metrics = self.compute_metrics({"predictions": padded_preds, "label_ids": padded_labels})
|
460 |
+
|
461 |
+
metrics.update({"eval_loss": val_loss / len(self.val_loader)})
|
462 |
+
print("Validation Metrics:", metrics)
|
463 |
+
|
464 |
+
if wandb_sample:
|
465 |
+
# Sample Logging
|
466 |
+
llm_sample_cb = ManualLLMSampleCB(
|
467 |
+
model=self.model,
|
468 |
+
tokenizer=self.tokenizer,
|
469 |
+
task="classification",
|
470 |
+
num_samples=5,
|
471 |
+
max_new_tokens=128
|
472 |
+
)
|
473 |
+
llm_sample_cb.log_samples_to_wandb(self.val_dataset)
|
474 |
+
|
475 |
+
return metrics
|
476 |
+
|
template_FL/src/fedllm/utils.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def clean_output_text(text):
|
2 |
+
"""
|
3 |
+
Clean and normalize text from LLM outputs by removing noise and repetitions.
|
4 |
+
|
5 |
+
Args:
|
6 |
+
text (str): Raw text from LLM prediction
|
7 |
+
|
8 |
+
Returns:
|
9 |
+
str: Cleaned and normalized text
|
10 |
+
"""
|
11 |
+
import re
|
12 |
+
|
13 |
+
def remove_repeats(text):
|
14 |
+
# Remove repeated words
|
15 |
+
pattern_words = r'\b(\w+)(?:\s+\1\b)+'
|
16 |
+
text = re.sub(pattern_words, r'\1', text)
|
17 |
+
|
18 |
+
# Remove repeated character patterns (like 'asasas')
|
19 |
+
pattern_chars = r'(\w+?)\1+'
|
20 |
+
text = re.sub(pattern_chars, r'\1', text)
|
21 |
+
|
22 |
+
return text
|
23 |
+
|
24 |
+
# Remove excessive punctuation
|
25 |
+
def normalize_punctuation(text):
|
26 |
+
# Replace multiple exclamation/question marks with single ones
|
27 |
+
text = re.sub(r'!+', '!', text)
|
28 |
+
text = re.sub(r'\?+', '?', text)
|
29 |
+
# Remove multiple periods (except for ellipsis)
|
30 |
+
text = re.sub(r'\.{4,}', '...', text)
|
31 |
+
text = text.replace('cor', '').replace('asesa', '')
|
32 |
+
return text
|
33 |
+
|
34 |
+
# Main cleaning pipeline
|
35 |
+
cleaned_text = text.strip()
|
36 |
+
|
37 |
+
# Remove common noise patterns
|
38 |
+
noise_patterns = [
|
39 |
+
r'\n+', # Multiple newlines
|
40 |
+
r'\s+', # Multiple spaces
|
41 |
+
r'\\n', # Literal \n
|
42 |
+
r'\\t', # Literal \t
|
43 |
+
]
|
44 |
+
|
45 |
+
for pattern in noise_patterns:
|
46 |
+
cleaned_text = re.sub(pattern, ' ', cleaned_text)
|
47 |
+
|
48 |
+
# Apply cleaning functions
|
49 |
+
# cleaned_text = remove_repetitions(cleaned_text)
|
50 |
+
cleaned_text = remove_repeats(cleaned_text)
|
51 |
+
cleaned_text = normalize_punctuation(cleaned_text)
|
52 |
+
cleaned_text = ' '.join(cleaned_text.split()) # Normalize spacing
|
53 |
+
|
54 |
+
return cleaned_text.strip()
|
template_FL/src/pyproject.toml
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[build-system]
|
2 |
+
requires = ["hatchling"]
|
3 |
+
build-backend = "hatchling.build"
|
4 |
+
|
5 |
+
[project]
|
6 |
+
name = "flowertune-llm"
|
7 |
+
version = "1.0.0"
|
8 |
+
description = "FlowerTune LLM: Federated LLM Fine-tuning with Flower"
|
9 |
+
license = "Apache-2.0"
|
10 |
+
dependencies = [
|
11 |
+
"flwr[simulation]==1.12.0",
|
12 |
+
"flwr-datasets>=0.3.0",
|
13 |
+
"trl==0.8.1",
|
14 |
+
"bitsandbytes==0.43.0",
|
15 |
+
"scipy==1.13.0",
|
16 |
+
"peft==0.6.2",
|
17 |
+
"fschat[model_worker,webui]==0.2.35",
|
18 |
+
"transformers==4.41.1",
|
19 |
+
"sentencepiece==0.2.0",
|
20 |
+
"omegaconf==2.3.0",
|
21 |
+
"hf_transfer==0.1.8",
|
22 |
+
]
|
23 |
+
|
24 |
+
|
25 |
+
[tool.hatch.build.targets.wheel]
|
26 |
+
packages = ["."]
|
27 |
+
|
28 |
+
[tool.flwr.app]
|
29 |
+
publisher = "flwrlabs"
|
30 |
+
|
31 |
+
[tool.flwr.app.components]
|
32 |
+
serverapp = "fedllm.server_app:app"
|
33 |
+
clientapp = "fedllm.client_app:app"
|
34 |
+
|
35 |
+
[tool.flwr.app.config]
|
36 |
+
num-server-rounds = 2
|
37 |
+
num-supernodes = 10
|
38 |
+
|
39 |
+
# Define dataset
|
40 |
+
dataset.type = 'hete' # type = ['homo','hete']
|
41 |
+
dataset.name = "vicgalle/alpaca-gpt4"
|
42 |
+
|
43 |
+
# Define model settings
|
44 |
+
model.name = "Qwen/Qwen2.5-1.5B-Instruct"
|
45 |
+
model.quantization = 4
|
46 |
+
model.gradient-checkpointing = true
|
47 |
+
model.flash_attention = false
|
48 |
+
|
49 |
+
### Use MATES ###
|
50 |
+
mates.state = true
|
51 |
+
mates.holdout-ratio = 0.001
|
52 |
+
mates.reference-ratio = 0.0005
|
53 |
+
mates.holdout-batch-size = 4
|
54 |
+
mates.reference-batch-size = 2
|
55 |
+
mates.update-data-influence-model-step = 20
|
56 |
+
mates.selection-fraction = 0.4
|
57 |
+
### END ###
|
58 |
+
|
59 |
+
### Use SkipBERT ###
|
60 |
+
|
61 |
+
# Model setting
|
62 |
+
skipbert.student-model = "bert-base-uncased"
|
63 |
+
skipbert.num_layers_student = 12
|
64 |
+
skipbert.num_full_hidden_layers_student = 6
|
65 |
+
skipbert.num_masked_layers_teacher = 0
|
66 |
+
skipbert.num_masked_last_layers_teacher = 0
|
67 |
+
|
68 |
+
# Training hyperparameters
|
69 |
+
skipbert.train_batch_size = 8
|
70 |
+
skipbert.gradient_accumulation_steps = 2
|
71 |
+
skipbert.eval_batch_size = 8
|
72 |
+
skipbert.eval_accumulation_steps = 2
|
73 |
+
skipbert.learning_rate = 2.0e-5
|
74 |
+
skipbert.num_train_epochs = 10
|
75 |
+
skipbert.eval_step = 10
|
76 |
+
skipbert.max_seq_length = 128
|
77 |
+
skipbert.weight_decay = 1.0e-4
|
78 |
+
skipbert.warmup_steps = 100 # 500
|
79 |
+
skipbert.do_train = true
|
80 |
+
skipbert.do_eval = true
|
81 |
+
skipbert.max_steps = -1
|
82 |
+
skipbert.evaluation_strategy = "epoch"
|
83 |
+
skipbert.save_strategy = "epoch"
|
84 |
+
skipbert.lr_scheduler_type = "cosine" # or 'warmup_linear'
|
85 |
+
skipbert.logging_dir = './skipbert_logs'
|
86 |
+
skipbert.output_dir = "./skipbert_results"
|
87 |
+
skipbert.report_to = 'wandb'
|
88 |
+
|
89 |
+
# Knowledge distillation parameters
|
90 |
+
skipbert.beta = 0.01
|
91 |
+
skipbert.T = 1.0
|
92 |
+
skipbert.alpha = 1.0
|
93 |
+
skipbert.reduce_T = 1.0
|
94 |
+
skipbert.epochs_no_cls = 5
|
95 |
+
|
96 |
+
# Training schedule and features
|
97 |
+
skipbert.freeze_lower_layers = true
|
98 |
+
|
99 |
+
# Feature usage flags
|
100 |
+
skipbert.use_logits = true
|
101 |
+
skipbert.use_att = true
|
102 |
+
skipbert.use_rep = true
|
103 |
+
skipbert.use_embedding = false
|
104 |
+
|
105 |
+
# Training modes
|
106 |
+
skipbert.do_train = true
|
107 |
+
skipbert.do_eval = true
|
108 |
+
skipbert.do_predict = false
|
109 |
+
skipbert.do_fit = false
|
110 |
+
skipbert.fp16 = false
|
111 |
+
skipbert.no_pretrain = false
|
112 |
+
skipbert.use_init_weight = false
|
113 |
+
skipbert.share_param = true
|
114 |
+
skipbert.do_lower_case = true
|
115 |
+
skipbert.no_cuda = false
|
116 |
+
|
117 |
+
# N-gram settings
|
118 |
+
skipbert.n_gram_left = 1
|
119 |
+
skipbert.n_gram_right = 1
|
120 |
+
|
121 |
+
# Layer mappings
|
122 |
+
skipbert.att_layer_maps: [1, 3, 5, 7, 9, 11]
|
123 |
+
skipbert.hid_layer_maps: [6, 7, 8, 9, 10, 11, 12]
|
124 |
+
|
125 |
+
### END ###
|
126 |
+
|
127 |
+
|
128 |
+
|
129 |
+
|
130 |
+
# Define LoRA settings
|
131 |
+
model.lora.lora-r = 8
|
132 |
+
model.lora.lora-alpha = 16
|
133 |
+
model.lora.lora-dropout = 0.05
|
134 |
+
model.lora.lora-target-modules = "q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj"
|
135 |
+
|
136 |
+
# Define training settings
|
137 |
+
train.save-every-round = 5
|
138 |
+
train.learning-rate-max = 5e-5
|
139 |
+
train.learning-rate-min = 1e-6
|
140 |
+
train.seq-length = 256
|
141 |
+
train.prompt_template_name = "alpaca"
|
142 |
+
train.train_on_inputs = true
|
143 |
+
train.verbose = false
|
144 |
+
|
145 |
+
# Define training agruments for HF Trainer
|
146 |
+
train.training-arguments.output-dir = ""
|
147 |
+
train.training-arguments.learning-rate = 3e-4
|
148 |
+
train.training-arguments.per-device-train-batch-size = 4
|
149 |
+
train.training-arguments.gradient-accumulation-steps = 1
|
150 |
+
train.training-arguments.per-device-eval-batch-size = 2
|
151 |
+
train.training-arguments.eval-accumulation-steps = 1
|
152 |
+
train.training-arguments.logging-steps = 10
|
153 |
+
train.training-arguments.num-train-epochs = 1
|
154 |
+
train.training-arguments.max-steps = 10
|
155 |
+
train.training-arguments.save-steps = 1000
|
156 |
+
train.training-arguments.save-total-limit = 10
|
157 |
+
train.training-arguments.gradient-checkpointing = true
|
158 |
+
train.training-arguments.lr-scheduler-type = "cosine"
|
159 |
+
train.training-arguments.warmup-steps = 0
|
160 |
+
train.training-arguments.do-train = true
|
161 |
+
train.training-arguments.do-eval = true
|
162 |
+
train.training-arguments.dataloader-drop-last = false
|
163 |
+
train.training-arguments.eval-strategy = "epoch"
|
164 |
+
train.training-arguments.save-strategy = "epoch"
|
165 |
+
train.training-arguments.ddp-find-unused-parameters = false
|
166 |
+
train.training-arguments.group-by-length = true
|
167 |
+
train.training-arguments.load_best_model_at_end = true
|
168 |
+
train.training-arguments.report-to = "wandb"
|
169 |
+
|
170 |
+
# Define local training settings
|
171 |
+
train.strategy.fraction-fit = 0.2
|
172 |
+
train.strategy.fraction-evaluate = 0.0
|
173 |
+
|
174 |
+
[tool.flwr.federations]
|
175 |
+
default = "local-simulation"
|
176 |
+
|
177 |
+
[tool.flwr.federations.local-simulation]
|
178 |
+
options.num-supernodes = 10
|
179 |
+
options.backend.client-resources.num-cpus = 8
|
180 |
+
options.backend.client-resources.num-gpus = 1.0
|
template_FL/src/tesst.ipynb
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"id": "013c8cfe-254e-4ee8-81e0-27478628c8e9",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"import tomllib\n",
|
11 |
+
"\n",
|
12 |
+
"with open(\"config.toml\", \"rb\") as f:\n",
|
13 |
+
" config = tomllib.load(f)\n",
|
14 |
+
"\n",
|
15 |
+
"numbers = config[\"my_integers\"] # Returns Python list: [1, 2, 3, 42]"
|
16 |
+
]
|
17 |
+
}
|
18 |
+
],
|
19 |
+
"metadata": {
|
20 |
+
"kernelspec": {
|
21 |
+
"display_name": "Python 3 (ipykernel)",
|
22 |
+
"language": "python",
|
23 |
+
"name": "python3"
|
24 |
+
},
|
25 |
+
"language_info": {
|
26 |
+
"codemirror_mode": {
|
27 |
+
"name": "ipython",
|
28 |
+
"version": 3
|
29 |
+
},
|
30 |
+
"file_extension": ".py",
|
31 |
+
"mimetype": "text/x-python",
|
32 |
+
"name": "python",
|
33 |
+
"nbconvert_exporter": "python",
|
34 |
+
"pygments_lexer": "ipython3",
|
35 |
+
"version": "3.11.7"
|
36 |
+
}
|
37 |
+
},
|
38 |
+
"nbformat": 4,
|
39 |
+
"nbformat_minor": 5
|
40 |
+
}
|