kisejin committed on
Commit 795c49e · 1 Parent(s): 02e64cb

initial: create version skipbert for mates

template_FL/.gitignore ADDED
@@ -0,0 +1,175 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # PyPI configuration file
+ .pypirc
+
+ # Remove unnecessary output files
+ results/
+ wandb/
template_FL/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 KiseJin
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
template_FL/README.md ADDED
@@ -0,0 +1 @@
+ # template_FL
template_FL/requirements.txt ADDED
@@ -0,0 +1,36 @@
+ bitsandbytes==0.45.0
+ cryptography==42.0.8
+ cupy-cuda12x==13.3.0
+ docker-pycreds==0.4.0
+ faiss-cpu==1.9.0.post1
+ fastrlock==0.8.3
+ # flash-attn==2.6.3
+ flwr==1.14.0
+ flwr-datasets==0.5.0
+ fsspec<=2024.10.0,>=2023.1.0
+ gitdb==4.0.12
+ gitpython==3.1.44
+ grpcio==1.64.3
+ iterators==0.0.2
+ multiprocess==0.70.16
+ pathspec==0.12.1
+ protobuf==4.25.5
+ prv-accountant==0.2.0
+ pycryptodome==3.21.0
+ ray==2.10.0
+ rouge-score==0.1.2
+ sentry-sdk==2.19.2
+ setproctitle==1.3.4
+ shellingham==1.5.4
+ smmap==5.0.2
+ sympy==1.13.1
+ thop==0.1.1-2209072238
+ tomli-w==1.1.0
+ typer==0.12.5
+ wandb==0.19.3
+ python-dotenv
+ omegaconf
+ trl
+ evaluate
+ google
+ deepspeed
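
Since this manifest pins most versions exactly, a quick way to catch drift between the pins and an already-provisioned environment is to compare them at runtime. The sketch below is illustrative only (it is not part of this commit, and the three packages checked are just examples from the list above):

```python
# Minimal sketch: compare a few pins from requirements.txt against the
# active environment. Not part of this commit; packages chosen as examples.
from importlib.metadata import PackageNotFoundError, version

PINS = {"flwr": "1.14.0", "wandb": "0.19.3", "ray": "2.10.0"}

for name, expected in PINS.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f"{name}: not installed (expected {expected})")
        continue
    status = "OK" if installed == expected else f"MISMATCH (expected {expected})"
    print(f"{name}: {installed} {status}")
```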
template_FL/src/environment.yml ADDED
@@ -0,0 +1,565 @@
+ name: fedllm
+ channels:
+   - xformers
+   - pytorch
+   - nvidia
+   - defaults
+   - conda-forge
+   - https://repo.anaconda.com/pkgs/main
+   - https://repo.anaconda.com/pkgs/r
+ dependencies:
+   - _libgcc_mutex=0.1=conda_forge
+   - _openmp_mutex=4.5=2_gnu
+   - about-time=4.2.1=pyhd8ed1ab_1
+   - absl-py=2.1.0=pyhd8ed1ab_1
+   - accelerate=1.2.1=pyhd8ed1ab_1
+   - aiohappyeyeballs=2.4.4=pyhd8ed1ab_1
+   - aiohttp=3.11.11=py311h2dc5d0c_0
+   - aiosignal=1.3.2=pyhd8ed1ab_0
+   - alive-progress=3.2.0=pyhd8ed1ab_0
+   - alsa-lib=1.2.8=h166bdaf_0
+   - annotated-types=0.7.0=pyhd8ed1ab_1
+   - antlr-python-runtime=4.9.3=pyhd8ed1ab_1
+   - anyio=4.8.0=pyhd8ed1ab_0
+   - aom=3.5.0=h27087fc_0
+   - argon2-cffi=23.1.0=pyhd8ed1ab_1
+   - argon2-cffi-bindings=21.2.0=py311h9ecbd09_5
+   - arrow=1.3.0=pyhd8ed1ab_1
+   - asttokens=3.0.0=pyhd8ed1ab_1
+   - async-lru=2.0.4=pyhd8ed1ab_1
+   - async-timeout=4.0.3=pyhd8ed1ab_0
+   - attr=2.5.1=h166bdaf_1
+   - attrs=24.3.0=pyh71513ae_0
+   - autograd=1.7.0=pyhd8ed1ab_1
+   - aws-c-auth=0.7.3=he2921ad_3
+   - aws-c-cal=0.6.2=hc309b26_1
+   - aws-c-common=0.9.0=hd590300_0
+   - aws-c-compression=0.2.17=h4d4d85c_2
+   - aws-c-event-stream=0.3.2=h2e3709c_0
+   - aws-c-http=0.7.12=hc865f51_1
+   - aws-c-io=0.13.32=h1a03231_3
+   - aws-c-mqtt=0.9.5=h3a0376c_1
+   - aws-c-s3=0.3.14=h1678ad6_3
+   - aws-c-sdkutils=0.1.12=h4d4d85c_1
+   - aws-checksums=0.1.17=h4d4d85c_1
+   - aws-crt-cpp=0.23.0=h40cdbb9_5
+   - aws-sdk-cpp=1.10.57=h6f6b8fa_21
+   - babel=2.16.0=pyhd8ed1ab_1
+   - beautifulsoup4=4.12.3=pyha770c72_1
+   - binaryornot=0.4.4=pyhd8ed1ab_2
+   - blas=1.0=mkl
+   - bleach=6.2.0=pyhd8ed1ab_3
+   - bleach-with-css=6.2.0=hd8ed1ab_3
+   - blosc=1.21.5=h0f2a231_0
+   - boltons=24.0.0=pyhd8ed1ab_1
+   - brotli=1.0.9=h166bdaf_9
+   - brotli-bin=1.0.9=h166bdaf_9
+   - brotli-python=1.0.9=py311ha362b79_9
+   - brunsli=0.1=h9c3ff4c_0
+   - bzip2=1.0.8=h5eee18b_6
+   - c-ares=1.34.4=hb9d3cd8_0
+   - c-blosc2=2.12.0=hb4ffafa_0
+   - ca-certificates=2024.12.31=h06a4308_0
+   - cached-property=1.5.2=hd8ed1ab_1
+   - cached_property=1.5.2=pyha770c72_1
+   - cachetools=5.5.0=pyhd8ed1ab_1
+   - cairo=1.16.0=ha61ee94_1014
+   - certifi=2024.12.14=pyhd8ed1ab_0
+   - cffi=1.17.1=py311hf29c0ef_0
+   - cfitsio=4.2.0=hd9d235c_0
+   - chardet=5.2.0=py311h38be061_2
+   - charls=2.4.2=h59595ed_0
+   - charset-normalizer=3.4.1=pyhd8ed1ab_0
+   - click=8.1.8=pyh707e725_0
+   - cma=3.2.2=pyh050c7b8_0
+   - colorama=0.4.6=pyhd8ed1ab_1
+   - comm=0.2.2=pyhd8ed1ab_1
+   - conda=23.7.4=py311h38be061_0
+   - conda-package-handling=2.4.0=pyh7900ff3_2
+   - conda-package-streaming=0.11.0=pyhd8ed1ab_0
+   - contourpy=1.3.1=py311hd18a35c_0
+   - cookiecutter=2.6.0=pyhd8ed1ab_1
+   - cpp-expected=1.1.0=hf52228f_0
+   - cuda-cudart=12.4.127=0
+   - cuda-cupti=12.4.127=0
+   - cuda-libraries=12.4.1=0
+   - cuda-nvrtc=12.4.127=0
+   - cuda-nvtx=12.4.127=0
+   - cuda-opencl=12.6.77=0
+   - cuda-runtime=12.4.1=0
+   - cuda-version=12.6=3
+   - cycler=0.12.1=pyhd8ed1ab_1
+   - dataclasses=0.8=pyhc8e2a94_3
+   - dataclasses-json=0.6.7=pyhd8ed1ab_1
+   - datasets=2.19.2=pyhd8ed1ab_0
+   - dav1d=1.2.1=hd590300_0
+   - dbus=1.13.6=h5008d03_3
+   - debugpy=1.8.11=py311hfdbb021_0
+   - decorator=5.1.1=pyhd8ed1ab_1
+   - deepspeed=0.16.2=cpu_py311hd0a74d0_0
+   - defusedxml=0.7.1=pyhd8ed1ab_0
+   - deprecated=1.2.15=pyhd8ed1ab_1
+   - dill=0.3.8=pyhd8ed1ab_0
+   - distro=1.9.0=pyhd8ed1ab_1
+   - docstring_parser=0.16=pyhd8ed1ab_0
+   - einops=0.8.0=pyhd8ed1ab_1
+   - entrypoints=0.4=pyhd8ed1ab_1
+   - eval_type_backport=0.2.2=pyha770c72_0
+   - evaluate=0.4.1=pyhd8ed1ab_0
+   - exceptiongroup=1.2.2=pyhd8ed1ab_1
+   - executing=2.1.0=pyhd8ed1ab_1
+   - expat=2.6.4=h5888daf_0
+   - ffmpeg=4.3=hf484d3e_0
+   - fftw=3.3.10=nompi_hf1063bd_110
+   - filelock=3.16.1=pyhd8ed1ab_1
+   - fmt=10.2.1=h00ab1b0_0
+   - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
+   - font-ttf-inconsolata=3.000=h77eed37_0
+   - font-ttf-source-code-pro=2.038=h77eed37_0
+   - font-ttf-ubuntu=0.83=h77eed37_3
+   - fontconfig=2.14.2=h14ed4e7_0
+   - fonts-conda-ecosystem=1=0
+   - fonts-conda-forge=1=0
+   - fonttools=4.55.3=py311h2dc5d0c_1
+   - fqdn=1.5.1=pyhd8ed1ab_1
+   - freetype=2.12.1=h267a509_2
+   - frozendict=2.4.6=py311h9ecbd09_0
+   - frozenlist=1.5.0=py311h9ecbd09_0
+   - functorch=2.0.0=pyhd8ed1ab_0
+   - fvcore=0.1.5.post20221221=pyhd8ed1ab_0
+   - gdown=5.2.0=pyhd8ed1ab_1
+   - gettext=0.22.5=he02047a_3
+   - gettext-tools=0.22.5=he02047a_3
+   - gflags=2.2.2=h5888daf_1005
+   - giflib=5.2.2=hd590300_0
+   - glib=2.78.4=hfc55251_0
+   - glib-tools=2.78.4=hfc55251_0
+   - glog=0.6.0=h6f12383_0
+   - gmp=6.3.0=hac33072_2
+   - gmpy2=2.1.5=py311h0f6cedb_3
+   - gnutls=3.6.13=h85f3911_1
+   - grapheme=0.6.0=pyhd8ed1ab_1
+   - graphite2=1.3.13=h59595ed_1003
+   - greenlet=3.1.1=py311hfdbb021_1
+   - gst-plugins-base=1.22.0=h4243ec0_2
+   - gstreamer=1.22.0=h25f0c4b_2
+   - gstreamer-orc=0.4.40=hb9d3cd8_0
+   - h11=0.14.0=pyhd8ed1ab_1
+   - h2=4.1.0=pyhd8ed1ab_1
+   - harfbuzz=6.0.0=h8e241bc_0
+   - hjson-py=3.1.0=pyhd8ed1ab_1
+   - hpack=4.0.0=pyhd8ed1ab_1
+   - httpcore=1.0.7=pyh29332c3_1
+   - httpx=0.28.1=pyhd8ed1ab_0
+   - huggingface_hub=0.27.1=pyhd8ed1ab_0
+   - hyperframe=6.0.1=pyhd8ed1ab_1
+   - icu=70.1=h27087fc_0
+   - idna=3.10=pyhd8ed1ab_1
+   - imagecodecs=2023.1.23=py311ha5a3c35_0
+   - imageio=2.36.1=pyh12aca89_1
+   - importlib-metadata=8.5.0=pyha770c72_1
+   - importlib_metadata=8.5.0=hd8ed1ab_1
+   - importlib_resources=6.5.2=pyhd8ed1ab_0
+   - intel-openmp=2023.1.0=hdb19cb5_46306
+   - ipykernel=6.29.5=pyh3099207_0
+   - ipython=8.31.0=pyh707e725_0
+   - ipywidgets=8.1.5=pyhd8ed1ab_1
+   - isoduration=20.11.0=pyhd8ed1ab_1
+   - jack=1.9.22=h11f4161_0
+   - jedi=0.19.2=pyhd8ed1ab_1
+   - jinja2=3.1.5=pyhd8ed1ab_0
+   - jiter=0.8.2=py311h9e33e62_0
+   - joblib=1.4.2=pyhd8ed1ab_1
+   - jpeg=9e=h166bdaf_2
+   - json5=0.10.0=pyhd8ed1ab_1
+   - jsonpatch=1.33=pyhd8ed1ab_1
+   - jsonpointer=3.0.0=py311h38be061_1
+   - jsonschema=4.23.0=pyhd8ed1ab_1
+   - jsonschema-specifications=2024.10.1=pyhd8ed1ab_1
+   - jsonschema-with-format-nongpl=4.23.0=hd8ed1ab_1
+   - jupyter=1.1.1=pyhd8ed1ab_1
+   - jupyter-lsp=2.2.5=pyhd8ed1ab_1
+   - jupyter_client=8.6.3=pyhd8ed1ab_1
+   - jupyter_console=6.6.3=pyhd8ed1ab_1
+   - jupyter_core=5.7.2=pyh31011fe_1
+   - jupyter_events=0.11.0=pyhd8ed1ab_0
+   - jupyter_server=2.15.0=pyhd8ed1ab_0
+   - jupyter_server_terminals=0.5.3=pyhd8ed1ab_1
+   - jupyterlab=4.3.4=pyhd8ed1ab_0
+   - jupyterlab_pygments=0.3.0=pyhd8ed1ab_2
+   - jupyterlab_server=2.27.3=pyhd8ed1ab_1
+   - jupyterlab_widgets=3.0.13=pyhd8ed1ab_1
+   - jxrlib=1.1=hd590300_3
+   - keyutils=1.6.1=h166bdaf_0
+   - kiwisolver=1.4.7=py311hd18a35c_0
+   - krb5=1.20.1=h81ceb04_0
+   - lame=3.100=h166bdaf_1003
+   - langchain=0.2.5=pyhd8ed1ab_0
+   - langchain-core=0.2.43=pyhd8ed1ab_0
+   - langchain-text-splitters=0.2.4=pyhd8ed1ab_0
+   - langsmith=0.1.147=pyhd8ed1ab_0
+   - lazy-loader=0.4=pyhd8ed1ab_2
+   - lazy_loader=0.4=pyhd8ed1ab_2
+   - lcms2=2.15=hfd0df8a_0
+   - ld_impl_linux-64=2.40=h12ee557_0
+   - lerc=4.0.0=h27087fc_0
+   - libabseil=20230125.3=cxx17_h59595ed_0
+   - libaec=1.1.3=h59595ed_0
+   - libaio=0.3.113=h166bdaf_0
+   - libarchive=3.6.2=h3d51595_0
+   - libarrow=13.0.0=hb9dc469_0_cpu
+   - libasprintf=0.22.5=he8f35ee_3
+   - libasprintf-devel=0.22.5=he8f35ee_3
+   - libavif=0.11.1=h8182462_2
+   - libblas=3.9.0=1_h86c2bf4_netlib
+   - libbrotlicommon=1.0.9=h166bdaf_9
+   - libbrotlidec=1.0.9=h166bdaf_9
+   - libbrotlienc=1.0.9=h166bdaf_9
+   - libcap=2.67=he9d0100_0
+   - libcblas=3.9.0=8_h3b12eaf_netlib
+   - libclang=15.0.7=default_h127d8a8_5
+   - libclang13=15.0.7=default_h5d6823c_5
+   - libcrc32c=1.1.2=h9c3ff4c_0
+   - libcublas=12.4.5.8=0
+   - libcufft=11.2.1.3=0
+   - libcufile=1.11.1.6=0
+   - libcups=2.3.3=h36d4200_3
+   - libcurand=10.3.7.77=0
+   - libcurl=8.11.1=hc9e6f67_0
+   - libcusolver=11.6.1.9=0
+   - libcusparse=12.3.1.170=0
+   - libdb=6.2.32=h9c3ff4c_0
+   - libdeflate=1.17=h0b41bf4_0
+   - libedit=3.1.20191231=he28a2e2_2
+   - libev=4.33=hd590300_2
+   - libevent=2.1.10=h28343ad_4
+   - libexpat=2.6.4=h5888daf_0
+   - libffi=3.4.4=h6a678d5_1
+   - libflac=1.4.3=h59595ed_0
+   - libgcc=14.2.0=h77fa898_1
+   - libgcc-ng=14.2.0=h69a702a_1
+   - libgcrypt=1.11.0=ha770c72_2
+   - libgcrypt-devel=1.11.0=hb9d3cd8_2
+   - libgcrypt-lib=1.11.0=hb9d3cd8_2
+   - libgcrypt-tools=1.11.0=hb9d3cd8_2
+   - libgettextpo=0.22.5=he02047a_3
+   - libgettextpo-devel=0.22.5=he02047a_3
+   - libgfortran=14.2.0=h69a702a_1
+   - libgfortran-ng=14.2.0=h69a702a_1
+   - libgfortran5=14.2.0=hd5240d6_1
+   - libglib=2.78.4=h783c2da_0
+   - libgomp=14.2.0=h77fa898_1
+   - libgoogle-cloud=2.12.0=h840a212_1
+   - libgpg-error=1.51=hbd13f7d_1
+   - libgrpc=1.56.2=h3905398_1
+   - libhwloc=2.9.1=hd6dc26d_0
+   - libiconv=1.17=hd590300_2
+   - libjpeg-turbo=2.0.0=h9bf148f_0
+   - liblapack=3.9.0=8_h3b12eaf_netlib
+   - libllvm15=15.0.7=hadd5161_1
+   - libmamba=1.5.1=h744094f_0
+   - libmambapy=1.5.1=py311hf2555c7_0
+   - libnghttp2=1.57.0=h2d74bed_0
+   - libnpp=12.2.5.30=0
+   - libnsl=2.0.1=hd590300_0
+   - libnuma=2.0.18=h4ab18f5_2
+   - libnvfatbin=12.6.77=0
+   - libnvjitlink=12.4.127=0
+   - libnvjpeg=12.3.1.117=0
+   - libogg=1.3.5=h4ab18f5_0
+   - libopus=1.3.1=h7f98852_1
+   - libpng=1.6.39=h5eee18b_0
+   - libpq=15.3=hbcd7760_1
+   - libprotobuf=4.23.3=hd1fb520_1
+   - libsentencepiece=0.1.99=h28b9611_1
+   - libsndfile=1.2.2=hc60ed4a_1
+   - libsodium=1.0.18=h36c2ea0_1
+   - libsolv=0.7.30=he621ea3_1
+   - libsqlite=3.46.0=hde9e2c9_0
+   - libssh2=1.11.1=h251f7ec_0
+   - libstdcxx=14.2.0=hc0a3c3a_1
+   - libstdcxx-ng=14.2.0=h4852527_1
+   - libsystemd0=253=h8c4010b_1
+   - libthrift=0.18.1=h5e4af38_0
+   - libtiff=4.5.0=h6adf6a1_2
+   - libtool=2.4.7=he02047a_1
+   - libudev1=253=h0b41bf4_1
+   - libutf8proc=2.8.0=hf23e847_1
+   - libuuid=2.38.1=h0b41bf4_0
+   - libvorbis=1.3.7=h9c3ff4c_0
+   - libwebp=1.2.4=h1daa5a0_1
+   - libwebp-base=1.2.4=h166bdaf_0
+   - libxcb=1.13=h7f98852_1004
+   - libxkbcommon=1.5.0=h79f4944_1
+   - libxml2=2.10.3=hca2bb57_4
+   - libzlib=1.2.13=h4ab18f5_6
+   - libzopfli=1.0.3=h9c3ff4c_0
+   - lightning=2.5.0.post0=pyhd8ed1ab_0
+   - lightning-utilities=0.11.9=pyhd8ed1ab_1
+   - llvm-openmp=12.0.1=h4bd325d_1
+   - lz4-c=1.9.4=hcb278e6_0
+   - lzo=2.10=hd590300_1001
+   - mamba=1.5.1=py311h3072747_0
+   - markdown=3.6=pyhd8ed1ab_0
+   - markdown-it-py=3.0.0=pyhd8ed1ab_1
+   - markupsafe=3.0.2=py311h2dc5d0c_1
+   - marshmallow=3.25.1=pyhd8ed1ab_0
+   - matplotlib=3.9.1=py311h38be061_1
+   - matplotlib-base=3.9.1=py311h74b4f7c_2
+   - matplotlib-inline=0.1.7=pyhd8ed1ab_1
+   - mdurl=0.1.2=pyhd8ed1ab_1
+   - mistune=3.1.0=pyhd8ed1ab_0
+   - mkl=2023.1.0=h213fc3f_46344
+   - mpc=1.3.1=h24ddda3_1
+   - mpfr=4.2.1=h90cbb55_3
+   - mpg123=1.32.9=hc50e24c_0
+   - mpmath=1.3.0=pyhd8ed1ab_1
+   - msgpack-python=1.1.0=py311hd18a35c_0
+   - multidict=6.1.0=py311h2dc5d0c_2
+   - munkres=1.1.4=pyh9f0ad1d_0
+   - mypy_extensions=1.0.0=pyha770c72_1
+   - mysql-common=8.0.33=hf1915f5_6
+   - mysql-libs=8.0.33=hca2cd23_6
+   - nbclient=0.10.2=pyhd8ed1ab_0
+   - nbconvert-core=7.16.5=pyhd8ed1ab_1
+   - nbformat=5.10.4=pyhd8ed1ab_1
+   - ncurses=6.4=h6a678d5_0
+   - nest-asyncio=1.6.0=pyhd8ed1ab_1
+   - nettle=3.6=he412f7d_0
+   - networkx=3.4.2=pyh267e887_2
+   - ninja=1.12.1=h297d8ca_0
+   - nlohmann_json=3.11.3=he02047a_1
+   - nltk=3.9.1=pyhd8ed1ab_1
+   - notebook=7.3.2=pyhd8ed1ab_0
+   - notebook-shim=0.2.4=pyhd8ed1ab_1
+   - nspr=4.36=h5888daf_0
+   - nss=3.100=hca3bf56_0
+   - numpy=1.26.4=py311h64a7726_0
+   - nvidia-ml-py=12.560.30=pyhd8ed1ab_1
+   - nvitop=1.4.1=pyh707e725_1
+   - omegaconf=2.3.0=pyhd8ed1ab_0
+   - opacus=1.5.2=pyhd8ed1ab_1
+   - openai=1.59.7=pyhd8ed1ab_0
+   - openh264=2.1.1=h780b84a_0
+   - openjpeg=2.5.0=hfec8fc6_2
+   - openssl=3.1.7=hb9d3cd8_0
+   - opt-einsum=3.4.0=hd8ed1ab_1
+   - opt_einsum=3.4.0=pyhd8ed1ab_1
+   - orc=1.9.0=h385abfd_1
+   - orjson=3.10.14=py311h9e33e62_0
+   - overrides=7.7.0=pyhd8ed1ab_1
+   - packaging=24.2=pyhd8ed1ab_2
+   - pandas=2.2.3=py311h7db5c69_1
+   - pandocfilters=1.5.0=pyhd8ed1ab_0
+   - parso=0.8.4=pyhd8ed1ab_1
+   - patsy=1.0.1=pyhd8ed1ab_1
+   - pcre2=10.42=hebb0a14_1
+   - peft=0.14.0=pyhd8ed1ab_0
+   - pexpect=4.9.0=pyhd8ed1ab_1
+   - pickleshare=0.7.5=pyhd8ed1ab_1004
+   - pillow=9.4.0=py311h6a678d5_0
+   - pip=24.2=py311h06a4308_0
+   - pixman=0.44.2=h29eaf8c_0
+   - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_2
+   - platformdirs=4.3.6=pyhd8ed1ab_1
+   - pluggy=1.5.0=pyhd8ed1ab_1
+   - ply=3.11=pyhd8ed1ab_3
+   - portalocker=3.0.0=py311h38be061_0
+   - prometheus_client=0.21.1=pyhd8ed1ab_0
+   - prompt-toolkit=3.0.48=pyha770c72_1
+   - prompt_toolkit=3.0.48=hd8ed1ab_1
+   - propcache=0.2.1=py311h9ecbd09_0
+   - psutil=6.1.1=py311h9ecbd09_0
+   - pthread-stubs=0.4=hb9d3cd8_1002
+   - ptyprocess=0.7.0=pyhd8ed1ab_1
+   - pulseaudio=16.1=hcb278e6_3
+   - pulseaudio-client=16.1=h5195f5e_3
+   - pulseaudio-daemon=16.1=ha8d29e2_3
+   - pure_eval=0.2.3=pyhd8ed1ab_1
+   - py-cpuinfo=9.0.0=pyhd8ed1ab_1
+   - pyarrow=13.0.0=py311h39c9aba_0_cpu
+   - pyarrow-hotfix=0.6=pyhd8ed1ab_1
+   - pybind11-abi=4=hd8ed1ab_3
+   - pycosat=0.6.6=py311h9ecbd09_2
+   - pycparser=2.22=pyh29332c3_1
+   - pydantic=2.10.5=pyh3cfb1c2_0
+   - pydantic-core=2.27.2=py311h9e33e62_0
+   - pygments=2.19.1=pyhd8ed1ab_0
+   - pymoo=0.6.1.3=py311h7db5c69_0
+   - pynvml=12.0.0=pyhd8ed1ab_0
+   - pyopenssl=24.3.0=pyhd8ed1ab_0
+   - pyparsing=3.2.1=pyhd8ed1ab_0
+   - pypdf=3.17.4=pyhd8ed1ab_0
+   - pyqt=5.15.9=py311hf0fb5b6_5
+   - pyqt5-sip=12.12.2=py311hb755f60_5
+   - pysocks=1.7.1=pyha55dd90_7
+   - python=3.11.6=hab00c5b_0_cpython
+   - python-dateutil=2.9.0.post0=pyhff2d567_1
+   - python-dotenv=1.0.1=pyhd8ed1ab_1
+   - python-fastjsonschema=2.21.1=pyhd8ed1ab_0
+   - python-json-logger=2.0.7=pyhd8ed1ab_0
+   - python-slugify=8.0.4=pyhd8ed1ab_1
+   - python-tzdata=2024.2=pyhd8ed1ab_1
+   - python-xxhash=3.5.0=py311h9ecbd09_1
+   - python_abi=3.11=2_cp311
+   - pytorch=2.5.1=py3.11_cuda12.4_cudnn9.1.0_0
+   - pytorch-cuda=12.4=hc786d27_7
+   - pytorch-lightning=2.5.0.post0=pyh101cb37_0
+   - pytorch-mutex=1.0=cuda
+   - pytz=2024.1=pyhd8ed1ab_0
+   - pywavelets=1.8.0=py311h9f3472d_0
+   - pyyaml=6.0.2=py311h9ecbd09_1
+   - pyzmq=26.2.0=py311h7deb3e3_0
+   - qhull=2020.2=h434a139_5
+   - qt-main=5.15.8=h5d23da1_6
+   - rdma-core=28.9=h59595ed_1
+   - re2=2023.03.02=h8c504da_0
+   - readline=8.2=h5eee18b_0
+   - referencing=0.35.1=pyhd8ed1ab_1
+   - regex=2024.11.6=py311h9ecbd09_0
+   - reproc=14.2.5.post0=hb9d3cd8_0
+   - reproc-cpp=14.2.5.post0=h5888daf_0
+   - requests=2.32.3=pyhd8ed1ab_1
+   - requests-toolbelt=1.0.0=pyhd8ed1ab_1
+   - responses=0.18.0=pyhd8ed1ab_0
+   - rfc3339-validator=0.1.4=pyhd8ed1ab_1
+   - rfc3986-validator=0.1.1=pyh9f0ad1d_0
+   - rich=13.9.4=pyhd8ed1ab_1
+   - rpds-py=0.22.3=py311h9e33e62_0
+   - ruamel.yaml=0.17.40=py311h459d7ec_0
+   - ruamel.yaml.clib=0.2.8=py311h9ecbd09_1
+   - s2n=1.3.51=h06160fa_0
+   - sacrebleu=2.1.0=pyhd8ed1ab_0
+   - sacremoses=0.0.53=pyhd8ed1ab_0
+   - safetensors=0.5.2=py311h9e33e62_0
+   - scikit-image=0.25.0=py311h7db5c69_0
+   - scikit-learn=1.6.1=py311h57cc02b_0
+   - scipy=1.15.1=py311hc1ac118_0
+   - seaborn=0.13.2=hd8ed1ab_3
+   - seaborn-base=0.13.2=pyhd8ed1ab_3
+   - send2trash=1.8.3=pyh0d859eb_1
+   - sentence-transformers=2.7.0=pyhd8ed1ab_0
+   - sentencepiece=0.1.99=h38be061_1
+   - sentencepiece-python=0.1.99=py311hf03188e_1
+   - sentencepiece-spm=0.1.99=h28b9611_1
+   - setuptools=75.1.0=py311h06a4308_0
+   - shtab=1.7.1=pyhd8ed1ab_1
+   - simdjson=3.11.5=h84d6215_0
+   - sip=6.7.12=py311hb755f60_0
+   - six=1.17.0=pyhd8ed1ab_0
+   - snappy=1.1.10=hdb0a2a9_1
+   - sniffio=1.3.1=pyhd8ed1ab_1
+   - soupsieve=2.5=pyhd8ed1ab_1
+   - spdlog=1.14.1=h597fd29_0
+   - sqlalchemy=2.0.37=py311h9ecbd09_0
+   - sqlite=3.45.3=h5eee18b_0
+   - stack_data=0.6.3=pyhd8ed1ab_1
+   - statsmodels=0.14.4=py311h9f3472d_0
+   - tabulate=0.9.0=pyhd8ed1ab_2
+   - tbb=2021.9.0=hf52228f_0
+   - tenacity=8.5.0=pyhd8ed1ab_0
+   - tensorboard=2.18.0=pyhd8ed1ab_1
+   - tensorboard-data-server=0.7.0=py311h63ff55d_1
+   - termcolor=2.5.0=pyhd8ed1ab_1
+   - terminado=0.18.1=pyh0d859eb_0
+   - text-unidecode=1.3=pyhd8ed1ab_2
+   - threadpoolctl=3.5.0=pyhc1e730c_0
+   - tifffile=2023.8.12=pyhd8ed1ab_0
+   - tinycss2=1.4.0=pyhd8ed1ab_0
+   - tk=8.6.14=h39e8969_0
+   - tokenizers=0.13.3=py311h1b04a43_0
+   - toml=0.10.2=pyhd8ed1ab_1
+   - tomli=2.2.1=pyhd8ed1ab_1
+   - toolz=1.0.0=pyhd8ed1ab_1
+   - torchaudio=2.5.1=py311_cu124
+   - torchmetrics=1.6.1=pyhd8ed1ab_0
+   - torchtriton=3.1.0=py311
+   - torchvision=0.20.1=py311_cu124
+   - tornado=6.4.2=py311h9ecbd09_0
+   - tqdm=4.67.1=pyhd8ed1ab_1
+   - traitlets=5.14.3=pyhd8ed1ab_1
+   - transformers=4.33.3=pyhd8ed1ab_0
+   - transforms3d=0.4.2=pyhd8ed1ab_1
+   - trl=0.10.1=pyhd8ed1ab_0
+   - types-python-dateutil=2.9.0.20241206=pyhd8ed1ab_0
+   - typing=3.10.0.0=pyhd8ed1ab_2
+   - typing-extensions=4.12.2=hd8ed1ab_1
+   - typing_extensions=4.12.2=pyha770c72_1
+   - typing_inspect=0.9.0=pyhd8ed1ab_1
+   - typing_utils=0.1.0=pyhd8ed1ab_1
+   - tyro=0.9.1=pyhff2d567_0
+   - tzdata=2024b=h04d1e81_0
+   - ucx=1.14.1=h195a15c_5
+   - unicodedata2=16.0.0=py311h9ecbd09_0
+   - uri-template=1.3.0=pyhd8ed1ab_1
+   - urllib3=2.3.0=pyhd8ed1ab_0
+   - wcwidth=0.2.13=pyhd8ed1ab_1
+   - webcolors=24.11.1=pyhd8ed1ab_0
+   - webencodings=0.5.1=pyhd8ed1ab_3
+   - websocket-client=1.8.0=pyhd8ed1ab_1
+   - werkzeug=3.1.3=pyhd8ed1ab_1
+   - wheel=0.44.0=py311h06a4308_0
+   - widgetsnbextension=4.0.13=pyhd8ed1ab_1
+   - wrapt=1.17.1=py311h9ecbd09_0
+   - xcb-util=0.4.0=h516909a_0
+   - xcb-util-image=0.4.0=h166bdaf_0
+   - xcb-util-keysyms=0.4.0=h516909a_0
+   - xcb-util-renderutil=0.3.9=h166bdaf_0
+   - xcb-util-wm=0.4.1=h516909a_0
+   - xformers=0.0.28.post3=py311_cu12.1.0_pyt2.5.1
+   - xkeyboard-config=2.38=h0b41bf4_0
+   - xorg-kbproto=1.0.7=hb9d3cd8_1003
+   - xorg-libice=1.1.2=hb9d3cd8_0
+   - xorg-libsm=1.2.5=he73a12e_0
+   - xorg-libx11=1.8.4=h0b41bf4_0
+   - xorg-libxau=1.0.12=hb9d3cd8_0
+   - xorg-libxdmcp=1.1.5=hb9d3cd8_0
+   - xorg-libxext=1.3.4=h0b41bf4_2
+   - xorg-libxrender=0.9.10=h7f98852_1003
+   - xorg-renderproto=0.11.1=hb9d3cd8_1003
+   - xorg-xextproto=7.3.0=hb9d3cd8_1004
+   - xorg-xproto=7.0.31=hb9d3cd8_1008
+   - xxhash=0.8.2=hd590300_0
+   - xz=5.4.6=h5eee18b_1
+   - yacs=0.1.8=pyhd8ed1ab_1
+   - yaml=0.2.5=h7f98852_2
+   - yaml-cpp=0.7.0=h59595ed_3
+   - yarl=1.18.3=py311h9ecbd09_0
+   - zeromq=4.3.5=h59595ed_1
+   - zfp=1.0.1=h5888daf_2
+   - zipp=3.21.0=pyhd8ed1ab_1
+   - zlib=1.2.13=h4ab18f5_6
+   - zlib-ng=2.0.7=h0b41bf4_0
+   - zstandard=0.23.0=py311hbc35293_1
+   - zstd=1.5.6=hc292b87_0
+   - pip:
+     - bitsandbytes==0.45.0
+     - cryptography==42.0.8
+     - cupy-cuda12x==13.3.0
+     - docker-pycreds==0.4.0
+     - faiss-cpu==1.9.0.post1
+     - fastrlock==0.8.3
+     - flash-attn==2.6.3
+     - flwr==1.14.0
+     - flwr-datasets==0.5.0
+     - fsspec==2024.12.0
+     - gitdb==4.0.12
+     - gitpython==3.1.44
+     - grpcio==1.64.3
+     - iterators==0.0.2
+     - multiprocess==0.70.16
+     - pathspec==0.12.1
+     - protobuf==4.25.5
+     - prv-accountant==0.2.0
+     - pycryptodome==3.21.0
+     - ray==2.10.0
+     - rouge-score==0.1.2
+     - sentry-sdk==2.19.2
+     - setproctitle==1.3.4
+     - shellingham==1.5.4
+     - smmap==5.0.2
+     - sympy==1.13.1
+     - thop==0.1.1-2209072238
+     - tomli-w==1.1.0
+     - typer==0.12.5
+     - wandb==0.19.3
template_FL/src/ex.env.example ADDED
@@ -0,0 +1,3 @@
+ WANDB_API_KEY = ""
+ WANDB_NAME = "FL@CSS25"
+ HF_TOKEN = ""
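
The example file above lists the environment variables the project expects: a Weights & Biases API key, a W&B run name, and a Hugging Face token. A minimal sketch of how they might be consumed with python-dotenv (pinned in requirements.txt) follows; copying the file to `.env` and the failure behavior are assumptions for illustration, not part of this commit:

```python
# Minimal sketch: load the variables defined in ex.env.example after
# copying it to .env. Assumes python-dotenv is installed; the variable
# names match ex.env.example, everything else here is illustrative.
import os
from dotenv import load_dotenv

load_dotenv()  # reads .env from the current working directory

wandb_api_key = os.getenv("WANDB_API_KEY", "")
wandb_name = os.getenv("WANDB_NAME", "FL@CSS25")
hf_token = os.getenv("HF_TOKEN", "")

if not wandb_api_key or not hf_token:
    raise RuntimeError("Set WANDB_API_KEY and HF_TOKEN before running.")
```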
template_FL/src/fedllm/Untitled.ipynb ADDED
@@ -0,0 +1,861 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "5f1cdd6d-0f6c-447c-8b11-665d0201215e",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "ename": "ModuleNotFoundError",
+ "evalue": "No module named 'flwr_datasets'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mflwr_datasets\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpartitioner\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m IidPartitioner\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mflwr_datasets\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m FederatedDataset\n\u001b[1;32m 5\u001b[0m partitioner \u001b[38;5;241m=\u001b[39m IidPartitioner(num_partitions\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m)\n",
+ "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'flwr_datasets'"
+ ]
+ }
+ ],
+ "source": [
+ "from flwr_datasets.partitioner import IidPartitioner\n",
+ "from flwr_datasets import FederatedDataset\n",
+ "\n",
+ "\n",
+ "partitioner = IidPartitioner(num_partitions=10)\n",
+ "FDS = FederatedDataset(\n",
+ " dataset=\"vicgalle/alpaca-gpt4\",\n",
+ " partitioners={\"train\": partitioner},\n",
+ ")\n",
+ "client_trainset = FDS.load_partition(1, \"train\")\n",
+ "client_trainset = client_trainset.rename_column(\"output\", \"response\")\n",
+ "client_trainset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "fecf675d-1279-4fbd-a4e3-15d245582df4",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "import datasets\n",
+ "from datasets import load_dataset, DatasetDict\n",
+ "import pandas as pd\n",
+ "from functools import partial\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "\n",
+ "\n",
+ "def get_dataset(dataset_name, local_data_dir=None):\n",
+ "\n",
+ " if dataset_name in [\"gsm8k\"]:\n",
+ " dataset_name = local_data_dir + dataset_name if local_data_dir is not None else dataset_name\n",
+ " dataset = load_dataset(dataset_name, name=\"main\")\n",
+ " elif dataset_name in [\"lighteval/MATH\"]:\n",
+ " dataset_name = local_data_dir + dataset_name if local_data_dir is not None else dataset_name\n",
+ " dataset = load_dataset(dataset_name, name=\"all\")\n",
+ " else:\n",
+ " dataset_name = local_data_dir + dataset_name if local_data_dir is not None else dataset_name\n",
+ " dataset = load_dataset(dataset_name)\n",
+ "\n",
+ " return dataset\n",
+ "\n",
+ "# Function to split a dataset dictionary into two 50/50 parts\n",
+ "def split_dataset_50_50(dataset_dict):\n",
+ " split_datasets = {\n",
+ " 'ds_1': {},\n",
+ " 'ds_2': {}\n",
+ " }\n",
+ " for split in ['train', 'valid', 'test']:\n",
+ " if split in dataset_dict:\n",
+ " dataset_split_1, dataset_split_2 = train_test_split(\n",
+ " dataset_dict[split], test_size=0.5, shuffle=True, random_state=42\n",
+ " )\n",
+ " print(f\">> ===== After split, Dataset1 {split} has {len(dataset_split_1)} examples. =====\")\n",
+ " print(f\">> ===== After split, Dataset2 {split} has {len(dataset_split_2)} examples. =====\")\n",
+ " split_datasets['ds_1'][split] = dataset_split_1\n",
+ " split_datasets['ds_2'][split] = dataset_split_2\n",
+ " return DatasetDict(split_datasets['ds_1']), DatasetDict(split_datasets['ds_2'])\n",
+ "\n",
+ "\n",
+ "def process_sft_dataset(dataset_name, dataset, dataset_sample):\n",
+ " if dataset_name in [\"lucasmccabe-lmi/CodeAlpaca-20k\", \"yahma/alpaca-cleaned\", \"FinGPT/fingpt-sentiment-train\"]:\n",
+ " dataset = dataset.map(alpaca_format, remove_columns=['input', 'output'], desc=f\"Preprocessing {dataset_name} for unified format.\")\n",
+ " elif dataset_name in [\"WizardLM/WizardLM_evol_instruct_70k\"]:\n",
+ " dataset = dataset.rename_column(\"output\", \"response\")\n",
+ " elif dataset_name in [\"tatsu-lab/alpaca\", \"vicgalle/alpaca-gpt4\", \"gbharti/finance-alpaca\"]:\n",
+ " dataset = dataset.map(alpaca_format, remove_columns=['input', 'output', 'text'], desc=f\"Preprocessing {dataset_name} for unified format.\")\n",
+ " elif dataset_name in [\"TIGER-Lab/MathInstruct\"]:\n",
+ " df = pd.DataFrame(dataset)\n",
+ " df = df.drop_duplicates(subset=['instruction'])\n",
+ " dataset = datasets.Dataset.from_pandas(df)\n",
+ " dataset = dataset.rename_column(\"output\", \"response\")\n",
+ " dataset = dataset.remove_columns(['source'])\n",
+ " elif dataset_name in [\"lighteval/MATH\"]:\n",
+ " dataset = dataset.rename_column(\"solution\", \"response\")\n",
+ " dataset = dataset.rename_column(\"problem\", \"instruction\")\n",
+ " dataset = dataset.remove_columns(['level', 'type'])\n",
+ " elif dataset_name in ['gsm8k']:\n",
+ " dataset = dataset.rename_column(\"question\", \"instruction\")\n",
+ " dataset = dataset.rename_column(\"answer\", \"response\")\n",
+ " elif dataset_name in ['medalpaca/medical_meadow_medical_flashcards']: # TODO: 'lavita/ChatDoctor-HealthCareMagic-100k'. not sure whether to discard the instruction.\n",
+ " dataset = dataset.remove_columns(['instruction'])\n",
+ " dataset = dataset.rename_column(\"input\", \"instruction\")\n",
+ " dataset = dataset.rename_column(\"output\", \"response\")\n",
+ " else:\n",
+ " raise NotImplementedError(f\"Dataset {dataset_name} is not supported.\")\n",
+ " dataset = dataset.shuffle(seed=2023)\n",
+ " if dataset_sample:\n",
+ " num_sample = min(len(dataset), dataset_sample)\n",
+ " dataset = dataset.select(range(num_sample))\n",
+ " print(f\">> ===== After processing, Dataset {dataset_name} has {len(dataset)} examples. =====\")\n",
+ " print(f\">> ===== Splitting two parts datasets =====\")\n",
+ " \n",
+ " if len(dataset['train']) > 10000 and len(dataset['test']) >= 2000:\n",
+ " dataset = split_dataset_50_50(dataset)\n",
+ " return dataset\n",
+ "\n",
+ "def alpaca_format(example):\n",
+ " if example['input'] == \"\":\n",
+ " example[\"instruction\"] = example[\"instruction\"]\n",
+ " else:\n",
+ " example[\"instruction\"] = example[\"instruction\"] + \" \" + example['input']\n",
+ " example[\"response\"] = example['output']\n",
+ " return example\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "73d16824-ec97-4e5f-87bc-c432253b93c5",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "import datasets\n",
+ "import pandas as pd\n",
+ "from datasets import Dataset, DatasetDict, load_dataset\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from functools import partial\n",
+ "\n",
+ "\n",
+ "class DatasetAbstract:\n",
+ " def __init__(self, dataset_name: list[str], category: str):\n",
+ " self.dataset_name = dataset_name\n",
+ " self.metadata = {\n",
+ " 'domain': category\n",
+ " }\n",
+ " \n",
+ " def _processing_data(self):\n",
+ " pass\n",
+ " \n",
+ " @classmethod\n",
+ " def get_dataset(cls, dataset_name, local_data_dir=None):\n",
+ " if dataset_name in [\"gsm8k\"]:\n",
+ " dataset_name = local_data_dir + dataset_name if local_data_dir is not None else dataset_name\n",
+ " dataset = load_dataset(dataset_name, name=\"main\")\n",
+ " else:\n",
+ " dataset_name = local_data_dir + dataset_name if local_data_dir is not None else dataset_name\n",
+ " dataset = load_dataset(dataset_name)\n",
+ " \n",
+ " return dataset\n",
+ " \n",
+ " def get_split_dataset(self, dataset):\n",
+ " print(f\">> ===== After processing, Dataset has {len(dataset)} examples. =====\")\n",
+ " if len(dataset) > 10000:\n",
+ " ds_part1, ds_part2 = train_test_split(\n",
+ " dataset, test_size=0.5, shuffle=True, random_state=42\n",
+ " )\n",
+ " print(f\">> ===== After split, Dataset1 has {len(ds_part1)} examples and Dataset2 has {len(ds_part2)} examples. =====\")\n",
+ " list_dataset = []\n",
+ " for subset in [ds_part1, ds_part2]:\n",
+ " train, test = train_test_split(\n",
+ " subset, test_size=0.2, shuffle=True, random_state=42\n",
+ " )\n",
+ " ds = DatasetDict({\n",
+ " \"train\": Dataset.from_pandas(train).remove_columns(['__index_level_0__']),\n",
+ " \"test\": Dataset.from_pandas(test).remove_columns(['__index_level_0__'])\n",
+ " })\n",
+ " list_dataset.append(ds)\n",
+ " return list_dataset\n",
+ " \n",
+ " else:\n",
+ " train, test = train_test_split(\n",
+ " dataset , test_size=0.2, shuffle=True, random_state=42\n",
+ " )\n",
+ " ds = DatasetDict(\n",
+ " {\n",
+ " \"train\": Dataset.from_pandas(train).remove_columns(['__index_level_0__']),\n",
+ " \"test\": Dataset.from_pandas(test).remove_columns(['__index_level_0__'])\n",
+ " }\n",
+ " )\n",
+ " return [ds]\n",
+ "\n",
+ " \n",
+ "class GeneralDataset(DatasetAbstract):\n",
+ " \n",
+ " def __init__(self):\n",
+ " list_dataset = [\"tatsu-lab/alpaca\", \"vicgalle/alpaca-gpt4\"]\n",
+ " super().__init__(list_dataset, 'general')\n",
+ " self._processing_data()\n",
+ " \n",
+ " def _processing_data(self):\n",
+ " datasets = []\n",
+ " for dataset_name in self.dataset_name:\n",
+ " datasets.append(\n",
+ " pd.DataFrame(super().get_dataset(dataset_name=dataset_name, local_data_dir=None)['train'])\n",
+ " )\n",
+ " dataset = pd.concat(datasets, ignore_index=True)\n",
+ " self.list_dataset = self.get_split_dataset(dataset)\n",
+ " \n",
+ " \n",
+ "\n",
+ "class FinanceDataset(DatasetAbstract):\n",
+ " \n",
+ " def __init__(self):\n",
+ " list_dataset = [\"gbharti/finance-alpaca\", \"FinGPT/fingpt-sentiment-train\"]\n",
+ " super().__init__(list_dataset, 'finance')\n",
+ " \n",
+ " self._processing_data()\n",
+ " \n",
+ " def _processing_data(self):\n",
+ " datasets = []\n",
+ " for dataset_name in self.dataset_name:\n",
+ " ds = super().get_dataset(dataset_name=dataset_name, local_data_dir=None)['train']\n",
+ " if dataset_name == 'gbharti/finance-alpaca':\n",
+ " ds = ds.remove_columns(['text'])\n",
+ " df = pd.DataFrame(ds)\n",
+ " datasets.append(df)\n",
+ " dataset = pd.concat(datasets, ignore_index=True)\n",
+ " self.list_dataset = self.get_split_dataset(dataset)\n",
+ " \n",
+ "\n",
+ "class MathDataset(DatasetAbstract):\n",
+ " \n",
+ " def __init__(self):\n",
+ " list_dataset = [\"TIGER-Lab/MathInstruct\", \"xDAN2099/lighteval-MATH\", \"gsm8k\"]\n",
+ " super().__init__(list_dataset, 'math')\n",
+ " self._processing_data()\n",
+ " \n",
+ " \n",
+ " def get_split_dataset(self, dataset):\n",
+ " dataset_train, dataset_test = dataset[0], dataset[1]\n",
+ " print(f\">> ===== After processing, Dataset has {len(dataset_train)} examples. =====\")\n",
+ " if len(dataset_train) > 10000:\n",
+ " ds_train_part1, ds_train_part2 = train_test_split(\n",
+ " dataset_train, test_size=0.5, shuffle=True, random_state=42\n",
+ " )\n",
+ " ds_test_part1, ds_test_part2 = train_test_split(\n",
+ " dataset_test, test_size=0.5, shuffle=True, random_state=42\n",
+ " )\n",
+ " print(f\">> ===== After split, Dataset1 has {len(ds_train_part1)} examples and Dataset2 has {len(ds_train_part2)} examples. =====\")\n",
+ " list_dataset = []\n",
+ " for i in range(2):\n",
+ " ds = DatasetDict({\n",
+ " \"train\": Dataset.from_pandas(eval(f'ds_train_part{i+1}')).remove_columns(['__index_level_0__']), \n",
+ " \"test\": Dataset.from_pandas(eval(f'ds_test_part{i+1}')).remove_columns(['__index_level_0__'])\n",
+ " })\n",
+ " list_dataset.append(ds)\n",
+ " return list_dataset\n",
+ " \n",
+ " else:\n",
+ " ds = DatasetDict(\n",
+ " {\n",
+ " \"train\": Dataset.from_pandas(dataset_train).remove_columns(['__index_level_0__']),\n",
+ " \"test\": Dataset.from_pandas(dataset_test).remove_columns(['__index_level_0__'])\n",
+ " }\n",
+ " )\n",
+ " return [ds]\n",
+ " \n",
+ " def _processing_data(self):\n",
+ " datasets_train, datasets_test = [], []\n",
+ " for dataset_name in self.dataset_name:\n",
+ " ds_tmp = super().get_dataset(dataset_name=dataset_name, local_data_dir=None)\n",
+ " if dataset_name == 'TIGER-Lab/MathInstruct':\n",
+ " df = pd.DataFrame(ds_tmp['train'])\n",
+ " df = df.drop_duplicates(subset=['instruction'])\n",
+ " df = df.drop(['source'], axis=1)\n",
+ " df_train, df_test = train_test_split(df, test_size=0.3, shuffle=True, random_state=42)\n",
+ " \n",
+ " elif dataset_name == \"xDAN2099/lighteval-MATH\":\n",
+ " ds_tmp = ds_tmp.remove_columns(['level', 'type'])\n",
+ " ds_tmp = ds_tmp.rename_column(\"solution\", \"output\")\n",
+ " ds_tmp = ds_tmp.rename_column(\"problem\", \"instruction\")\n",
+ " df_train, df_test = pd.DataFrame(ds_tmp['train']), pd.DataFrame(ds_tmp['test'])\n",
+ " \n",
+ " elif dataset_name == 'gsm8k':\n",
+ " ds_tmp = ds_tmp.rename_column(\"answer\", \"output\")\n",
+ " ds_tmp = ds_tmp.rename_column(\"question\", \"instruction\")\n",
+ " df_train, df_test = pd.DataFrame(ds_tmp['train']), pd.DataFrame(ds_tmp['test'])\n",
+ " \n",
+ " df_train['input'] = [''] * len(df_train)\n",
+ " df_test['input'] = [''] * len(df_test)\n",
+ " datasets_train.append(df_train)\n",
+ " datasets_test.append(df_test)\n",
+ " \n",
+ " dataset_train = pd.concat(datasets_train, ignore_index=True)\n",
+ " dataset_test = pd.concat(datasets_test, ignore_index=True)\n",
+ " dataset = [dataset_train, dataset_test]\n",
+ " self.list_dataset = self.get_split_dataset(dataset)\n",
+ " \n",
+ "\n",
+ "class MedicalDataset(DatasetAbstract):\n",
+ " \n",
+ " def __init__(self):\n",
+ " list_dataset = [\"medalpaca/medical_meadow_medical_flashcards\"]\n",
+ " super().__init__(list_dataset, 'medical')\n",
+ " self._processing_data()\n",
+ " \n",
+ " def _processing_data(self):\n",
+ " datasets = []\n",
+ " for dataset_name in self.dataset_name:\n",
+ " ds = super().get_dataset(dataset_name=dataset_name, local_data_dir=None)['train']\n",
+ " if dataset_name == 'medalpaca/medical_meadow_medical_flashcards':\n",
+ " ds = ds.remove_columns(['instruction'])\n",
+ " ds = ds.rename_column(\"input\", \"instruction\")\n",
+ " \n",
+ " df = pd.DataFrame(ds)\n",
+ " df['input'] = [''] * len(df)\n",
+ " datasets.append(df)\n",
+ " dataset = pd.concat(datasets, ignore_index=True)\n",
+ " self.list_dataset = self.get_split_dataset(dataset)\n",
+ " \n",
+ "class CodeDataset(DatasetAbstract):\n",
+ " \n",
+ " def __init__(self):\n",
+ " list_dataset = [\"lucasmccabe-lmi/CodeAlpaca-20k\", \"WizardLMTeam/WizardLM_evol_instruct_70k\"]\n",
+ " super().__init__(list_dataset, 'code')\n",
+ " self._processing_data()\n",
+ " \n",
+ " def _processing_data(self):\n",
+ " datasets = []\n",
+ " for dataset_name in self.dataset_name:\n",
+ " ds = super().get_dataset(dataset_name=dataset_name, local_data_dir=None)['train']\n",
+ " df = pd.DataFrame(ds)\n",
+ " if dataset_name == 'WizardLMTeam/WizardLM_evol_instruct_70k':\n",
+ " df['input'] = [''] * len(df)\n",
+ " datasets.append(df)\n",
+ " dataset = pd.concat(datasets, ignore_index=True)\n",
+ " self.list_dataset = self.get_split_dataset(dataset)\n",
+ " \n",
+ "client_id_dataset = {\n",
+ " '1': GeneralDataset().list_dataset[0],\n",
+ " '2': GeneralDataset().list_dataset[1],\n",
+ " '3': FinanceDataset().list_dataset[0],\n",
+ " '4': FinanceDataset().list_dataset[1],\n",
+ " '5': MathDataset().list_dataset[0],\n",
+ " '6': MathDataset().list_dataset[1],\n",
+ " '7': MedicalDataset().list_dataset[0],\n",
+ " '8': MedicalDataset().list_dataset[1],\n",
+ " '9': CodeDataset().list_dataset[0],\n",
+ " '10': CodeDataset().list_dataset[1],\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "85a86706-fd9b-4618-b6e0-b6b2743d1246",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "a5ca5da771024ac5acffe25a7e625d6e",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "README.md: 0%| | 0.00/709 [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "33deb82850224a00912678f33da2f5c7",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Cleaned_date.json: 0%| | 0.00/42.9M [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "2161bc54748c4eaf990d85c8dfb1d616",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Generating train split: 0%| | 0/68912 [00:00<?, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "f0abe4e86e9b40e7a64dbd4ba80a1b3d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "README.md: 0%| | 0.00/529 [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "4b662fd05288467b90064ce383f5ecea",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "(…)-00000-of-00001-dabab110260ac909.parquet: 0%| | 0.00/6.42M [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "5ee6a13f60124db4af266207872fc6fb",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Generating train split: 0%| | 0/76772 [00:00<?, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ ">> ===== After processing, Dataset has 145684 examples. =====\n",
+ ">> ===== After split, Dataset1 has 72842 examples and Dataset2 has 72842 examples. =====\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[DatasetDict({\n",
+ " train: Dataset({\n",
+ " features: ['instruction', 'input', 'output'],\n",
+ " num_rows: 58273\n",
+ " })\n",
+ " test: Dataset({\n",
+ " features: ['instruction', 'input', 'output'],\n",
+ " num_rows: 14569\n",
+ " })\n",
+ " }),\n",
+ " DatasetDict({\n",
+ " train: Dataset({\n",
+ " features: ['instruction', 'input', 'output'],\n",
+ " num_rows: 58273\n",
+ " })\n",
+ " test: Dataset({\n",
+ " features: ['instruction', 'input', 'output'],\n",
+ " num_rows: 14569\n",
+ " })\n",
+ " })]"
+ ]
+ },
+ "execution_count": 27,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "FinanceDataset().list_dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "78c926d4-9462-4a8a-b50a-ef5ed1e5c25e",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>letter</th>\n",
+ " <th>number</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>0</th>\n",
+ " <td>a</td>\n",
+ " <td>1</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>1</th>\n",
+ " <td>b</td>\n",
+ " <td>2</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " letter number\n",
+ "0 a 1\n",
+ "1 b 2"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "client_id_dataset = {\n",
+ " '1':\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "9b9babb9-01fe-4b01-be23-d1f334dd450c",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>letter</th>\n",
+ " <th>number</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>0</th>\n",
+ " <td>a</td>\n",
+ " <td>1</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>1</th>\n",
+ " <td>b</td>\n",
+ " <td>2</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " letter number\n",
+ "0 a 1\n",
+ "1 b 2"
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import pandas as pd\n",
+ "df1 = pd.DataFrame([['a', 1], ['b', 2]],\n",
+ " columns=['letter', 'number'])\n",
+ "\n",
+ "df2 = pd.DataFrame([['c', 3], ['d', 4]],\n",
+ " columns=['letter', 'number'])\n",
+ "\n",
+ "df3 = pd.DataFrame([['e', 5], ['f', 6]],\n",
+ " columns=['letter', 'number'])\n",
+ "\n",
+ "pd.concat([df1], ignore_index=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "fa021e4a-5e8b-470f-bf6f-db7b9e1f9eb6",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ ">> ===== After processing, Dataset gsm8k has 2 examples. =====\n",
+ ">> ===== Splitting two parts datasets =====\n"
648
+ ]
649
+ },
650
+ {
651
+ "data": {
652
+ "text/plain": [
653
+ "DatasetDict({\n",
654
+ " train: Dataset({\n",
655
+ " features: ['instruction', 'output'],\n",
656
+ " num_rows: 7473\n",
657
+ " })\n",
658
+ " test: Dataset({\n",
659
+ " features: ['instruction', 'output'],\n",
660
+ " num_rows: 1319\n",
661
+ " })\n",
662
+ "})"
663
+ ]
664
+ },
665
+ "execution_count": 19,
666
+ "metadata": {},
667
+ "output_type": "execute_result"
668
+ }
669
+ ],
670
+ "source": [
671
+ "dataset = get_dataset(dataset_name=\"gsm8k\", local_data_dir=None)\n",
672
+ "datasets = process_sft_dataset(dataset_name=\"gsm8k\", dataset=dataset, dataset_sample=None)\n",
673
+ "datasets = datasets.rename_column(\"response\", \"output\")\n",
674
+ "datasets"
675
+ ]
676
+ },
677
+ {
678
+ "cell_type": "code",
679
+ "execution_count": 18,
680
+ "id": "c8e02304-5b90-4651-ad36-7feb440d7ea4",
681
+ "metadata": {
682
+ "tags": []
683
+ },
684
+ "outputs": [],
685
+ "source": [
686
+ "def clean_llm_text(text):\n",
687
+ " \"\"\"\n",
688
+ " Clean and normalize text from LLM outputs by removing noise and repetitions.\n",
689
+ " \n",
690
+ " Args:\n",
691
+ " text (str): Raw text from LLM prediction\n",
692
+ " \n",
693
+ " Returns:\n",
694
+ " str: Cleaned and normalized text\n",
695
+ " \"\"\"\n",
696
+ " import re\n",
697
+ " \n",
698
+ " # Remove repetitive patterns (like 'cor cor cor' or 'asesases')\n",
699
+ " def remove_repetitions(text):\n",
700
+ " # Split into words\n",
701
+ " words = text.split()\n",
702
+ " cleaned_words = []\n",
703
+ " prev_word = None\n",
704
+ " repetition_count = 0\n",
705
+ " \n",
706
+ " for word in words:\n",
707
+ " if word == prev_word:\n",
708
+ " repetition_count += 1\n",
709
+ " if repetition_count < 2: # Allow up to 2 repetitions for legitimate cases\n",
710
+ " cleaned_words.append(word)\n",
711
+ " else:\n",
712
+ " repetition_count = 0\n",
713
+ " cleaned_words.append(word)\n",
714
+ " prev_word = word\n",
715
+ " \n",
716
+ " return ' '.join(cleaned_words)\n",
717
+ " \n",
718
+ " def remove_repeats(text):\n",
719
+ " # Remove repeated words\n",
720
+ " pattern_words = r'\\b(\\w+)(?:\\s+\\1\\b)+'\n",
721
+ " text = re.sub(pattern_words, r'\\1', text)\n",
722
+ "\n",
723
+ " # Remove repeated character patterns (like 'asasas')\n",
724
+ " pattern_chars = r'(\\w+?)\\1+'\n",
725
+ " text = re.sub(pattern_chars, r'\\1', text)\n",
726
+ "\n",
727
+ " return text\n",
728
+ " \n",
729
+ " # Remove excessive punctuation\n",
730
+ " def normalize_punctuation(text):\n",
731
+ " # Replace multiple exclamation/question marks with single ones\n",
732
+ " text = re.sub(r'!+', '!', text)\n",
733
+ " text = re.sub(r'\\?+', '?', text)\n",
734
+ " # Remove multiple periods (except for ellipsis)\n",
735
+ " text = re.sub(r'\\.{4,}', '...', text)\n",
736
+ " text = text.replace('cor', '').replace('asesa', '')\n",
737
+ " return text\n",
738
+ " \n",
739
+ " # Main cleaning pipeline\n",
740
+ " cleaned_text = text.strip()\n",
741
+ " \n",
742
+ " # Remove common noise patterns\n",
743
+ " noise_patterns = [\n",
744
+ " r'\\n+', # Multiple newlines\n",
745
+ " r'\\s+', # Multiple spaces\n",
746
+ " r'\\\\n', # Literal \\n\n",
747
+ " r'\\\\t', # Literal \\t\n",
748
+ " ]\n",
749
+ " \n",
750
+ " for pattern in noise_patterns:\n",
751
+ " cleaned_text = re.sub(pattern, ' ', cleaned_text)\n",
752
+ " \n",
753
+ " # Apply cleaning functions\n",
754
+ " # cleaned_text = remove_repetitions(cleaned_text)\n",
755
+ " cleaned_text = remove_repeats(cleaned_text)\n",
756
+ " cleaned_text = normalize_punctuation(cleaned_text)\n",
757
+ " cleaned_text = ' '.join(cleaned_text.split()) # Normalize spacing\n",
758
+ " \n",
759
+ " return cleaned_text"
760
+ ]
761
+ },
762
+ {
763
+ "cell_type": "code",
764
+ "execution_count": 20,
765
+ "id": "9d95fbfe-d824-4247-8628-3bb1cffe9065",
766
+ "metadata": {
767
+ "tags": []
768
+ },
769
+ "outputs": [
770
+ {
771
+ "data": {
772
+ "text/plain": [
773
+ "'folowing into their based mamals animals, water animals. 1 Response: animals: , Elephant,Sea Animals: Dolphin, Dolphin'"
774
+ ]
775
+ },
776
+ "execution_count": 20,
777
+ "metadata": {},
778
+ "output_type": "execute_result"
779
+ }
780
+ ],
781
+ "source": [
782
+ "text_str = f\"\"\"at\\n\\nSea Animals: Whale, Fish cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor', \" cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor One example of a technology that uses artificial intelligence is a virtual personal assistant such as Amazon's Alexa, Apple's Siri, or Google Assistant. These devices use natural language processing and machine learning to understand and respond to user's voice commands, providing assistance in tasks such as setting reminders, playing music, or answering questions. cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor cor\"\"\"\n",
783
+ "text_str1 = f\"\"\"following into their based mammals animals, water animals.\\n\\n1 Response:\\n \\n animals: \\n, Elephant,Sea Animals: Dolphin, Dolphin\\n\\nasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesasesa\"\"\"\n",
784
+ "clean_llm_text(text_str1)"
785
+ ]
786
+ },
787
+ {
788
+ "cell_type": "code",
789
+ "execution_count": 10,
790
+ "id": "c2bd139b-55f9-4848-a3d3-0db55d601e67",
791
+ "metadata": {
792
+ "tags": []
793
+ },
794
+ "outputs": [
795
+ {
796
+ "data": {
797
+ "text/plain": [
798
+ "tensor([[1, 2, 3]])"
799
+ ]
800
+ },
801
+ "execution_count": 10,
802
+ "metadata": {},
803
+ "output_type": "execute_result"
804
+ }
805
+ ],
806
+ "source": [
807
+ "import torch\n",
808
+ "\n",
809
+ "x = torch.tensor([1, 2, 3])\n",
810
+ "torch.unsqueeze(x,dim=0)"
811
+ ]
812
+ },
813
+ {
814
+ "cell_type": "code",
815
+ "execution_count": 16,
816
+ "id": "2eb4071a-3d82-4d86-bb0e-a1403a2b1c1e",
817
+ "metadata": {
818
+ "tags": []
819
+ },
820
+ "outputs": [
821
+ {
822
+ "name": "stdout",
823
+ "output_type": "stream",
824
+ "text": [
825
+ "TF-IDF Matrix Shape: (136, 25)\n"
826
+ ]
827
+ }
828
+ ],
829
+ "source": []
830
+ },
831
+ {
832
+ "cell_type": "code",
833
+ "execution_count": null,
834
+ "id": "4f75512f-13e8-4373-8630-8b5269729c32",
835
+ "metadata": {},
836
+ "outputs": [],
837
+ "source": []
838
+ }
839
+ ],
840
+ "metadata": {
841
+ "kernelspec": {
842
+ "display_name": "py11torch",
843
+ "language": "python",
844
+ "name": "py11torch"
845
+ },
846
+ "language_info": {
847
+ "codemirror_mode": {
848
+ "name": "ipython",
849
+ "version": 3
850
+ },
851
+ "file_extension": ".py",
852
+ "mimetype": "text/x-python",
853
+ "name": "python",
854
+ "nbconvert_exporter": "python",
855
+ "pygments_lexer": "ipython3",
856
+ "version": "3.11.8"
857
+ }
858
+ },
859
+ "nbformat": 4,
860
+ "nbformat_minor": 5
861
+ }
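The cell above defines `df2` and `df3` but concatenates only `df1`, which is why the displayed output has two rows. For reference, a minimal standalone sketch of the same `pd.concat` call extended to all three frames:

```python
import pandas as pd

df1 = pd.DataFrame([['a', 1], ['b', 2]], columns=['letter', 'number'])
df2 = pd.DataFrame([['c', 3], ['d', 4]], columns=['letter', 'number'])
df3 = pd.DataFrame([['e', 5], ['f', 6]], columns=['letter', 'number'])

# Stack row-wise; ignore_index rebuilds a clean 0..5 RangeIndex
combined = pd.concat([df1, df2, df3], ignore_index=True)
print(combined)
#   letter  number
# 0      a       1
# 1      b       2
# 2      c       3
# 3      d       4
# 4      e       5
# 5      f       6
```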
template_FL/src/fedllm/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """flowertune_llm."""
template_FL/src/fedllm/client_app.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """flowertune-llm: A Flower / FlowerTune app."""
2
+
3
+ import os
4
+ import warnings
5
+ from typing import Dict, Tuple
6
+
7
+ import torch
8
+ import wandb
9
+ import numpy as np
10
+ from flwr.client import ClientApp, NumPyClient
11
+ from flwr.common import Context
12
+ from flwr.common.config import unflatten_dict
13
+ from flwr.common.typing import NDArrays, Scalar
14
+ from omegaconf import DictConfig
15
+
16
+ from transformers import TrainingArguments, DataCollatorForSeq2Seq, Trainer, EarlyStoppingCallback, BertForSequenceClassification, GenerationConfig
17
+
18
+ from trl import SFTTrainer, SFTConfig
19
+ from deepspeed.profiling.flops_profiler import get_model_profile
20
+ from deepspeed.accelerator import get_accelerator
21
+
22
+ from .trainer import ManualTrainer
23
+
24
+ from .dataset import (
25
+ get_data_collator_and_prompt_formatting,
26
+ load_data,
27
+ load_data_homo,
28
+ load_data_hete,
29
+ replace_keys,
30
+ )
31
+ from .models import *
32
+
33
+ from .flwr_mods import get_wandb_mod
34
+ from .metrics import exact_match, f1, get_rouge_score
35
+ from .utils import clean_output_text
36
+ from .make_data import Prompter, generate_and_tokenize_prompt
37
+
38
+ # Avoid warnings
39
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
40
+ os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"
41
+ warnings.filterwarnings("ignore", category=UserWarning)
42
+
43
+
44
+ def input_constructor(batch_size, seq_len, tokenizer):
45
+ fake_seq = ""
46
+ for _ in range(seq_len - 2): # ignore the two special tokens [CLS] and [SEP]
47
+ fake_seq += tokenizer.pad_token
48
+ inputs = tokenizer([fake_seq] * batch_size,
49
+ padding=True,
50
+ truncation=True,
51
+ max_length=seq_len,
52
+ return_tensors="pt")
53
+ labels = torch.tensor([1] * batch_size)
54
+ inputs = dict(inputs)
55
+ # inputs.update({"labels": torch.unsqueeze(labels,dim=0)})
56
+
57
+ # To device
58
+ inputs = {k: v.to('cuda:0') for k, v in inputs.items()}
59
+ return inputs
60
+
61
+ def convert_to_float(value_str):
62
+ value, unit = value_str.split()
63
+ value = float(value)
64
+ if 'T' in unit:
65
+ return value * 1e12
66
+ elif 'G' in unit:
67
+ return value * 1e9
68
+ elif 'M' in unit:
69
+ return value * 1e6
70
+ elif 'K' in unit:
71
+ return value * 1e3
72
+ return value
73
+
74
+ # pylint: disable=too-many-arguments
75
+ # pylint: disable=too-many-instance-attributes
76
+ class FlowerClient(NumPyClient):
77
+ """Standard Flower client for CNN training."""
78
+
79
+ def __init__(
80
+ self,
81
+ model_cfg: DictConfig,
82
+ train_cfg: DictConfig,
83
+ mates_args: DictConfig,
84
+ trainset,
85
+ valset,
86
+ num_rounds,
87
+ ): # pylint: disable=too-many-arguments
88
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
89
+ self.train_cfg = train_cfg
90
+
91
+ self.training_arguments = TrainingArguments(**train_cfg.training_arguments)
92
+ # self.training_arguments = SFTConfig(**train_cfg.training_arguments, max_seq_length=train_cfg.seq_length)
93
+
94
+ self.num_rounds = num_rounds
95
+ self.trainset = trainset
96
+ self.valset = valset
97
+ self.mates_args = mates_args
98
+ self.holdoutset = None
99
+ self.refset = None
100
+ self.data_influence_model = None
101
+ self.data_influence_tokenizer = None
102
+
103
+ # instantiate model
104
+ self.model, self.tokenizer = get_model(model_cfg)
105
+
106
+ if self.mates_args.state:
107
+ self.data_influence_model, self.data_influence_tokenizer = get_data_influence_model(model_cfg)
108
+
109
+ # (
110
+ # self.data_collator,
111
+ # self.formatting_prompts_func
112
+ # ) = get_data_collator_and_propt_formatting(self.tokenizer)
113
+
114
+ self.data_collator = DataCollatorForSeq2Seq(
115
+ self.tokenizer,
116
+ pad_to_multiple_of=8,
117
+ return_tensors="pt",
118
+ padding=True,
119
+ )
120
+
121
+ self.train_on_inputs = self.train_cfg.train_on_inputs
122
+
123
+ self._make_dataset()
124
+
125
+ def compute_metrics(self, pred):
126
+ labels_ids = pred['label_ids']
127
+ pred_ids = pred['predictions']
128
+
129
+ # Replace -100 with pad token id in labels
130
+ labels_ids[labels_ids == -100] = self.tokenizer.pad_token_id
131
+
132
+ print(f"Shape of predictions: {np.shape(pred_ids)}")
133
+ print(f"Shape of labels: {np.shape(labels_ids)}")
134
+
135
+ # Decode predictions and labels
136
+ pred_str = self.tokenizer.batch_decode(
137
+ pred_ids, skip_special_tokens=True
138
+ )
139
+ label_str = self.tokenizer.batch_decode(
140
+ labels_ids, skip_special_tokens=True
141
+ )
142
+
143
+ # Remove any extra whitespace from the decoded strings
144
+ pred_str = [s.strip() for s in pred_str]
145
+ label_str = [s.strip() for s in label_str]
146
+
147
+ return {
148
+ **get_rouge_score(predictions=pred_str, targets=label_str),
149
+ **f1(predictions=pred_str, targets=label_str),
150
+ }
151
+
152
+ def _make_dataset(self):
153
+ prompter = Prompter(self.train_cfg.prompt_template_name, self.train_cfg.verbose)
154
+ tmp_dict = {
155
+ "prompter": prompter,
156
+ "seq_length": self.train_cfg.seq_length,
157
+ "train_on_inputs": self.train_on_inputs,
158
+ "tokenizer": self.tokenizer,
159
+ }
160
+
161
+ # Process trainset
162
+ self.trainset = (
163
+ self.trainset
164
+ .shuffle()
165
+ .map(
166
+ lambda x: generate_and_tokenize_prompt(x, **tmp_dict),
167
+ num_proc=8,
168
+ )
169
+ )
170
+
171
+ # Process valset
172
+ self.valset = (
173
+ self.valset
174
+ .shuffle()
175
+ .map(
176
+ lambda x: generate_and_tokenize_prompt(x, **tmp_dict),
177
+ num_proc=8,
178
+ )
179
+ )
180
+
181
+ # Create holdoutset and refset if state is True
182
+ if self.mates_args.state:
183
+ trainset_size = len(self.trainset)
184
+
185
+ # Calculate sizes for holdout and reference sets
186
+ holdout_size = int(trainset_size * self.mates_args.holdout_ratio)
187
+ ref_size = int(trainset_size * self.mates_args.reference_ratio)
188
+
189
+ # Shuffle the dataset itself; contiguous index slices below are then random
190
+ indices = list(range(trainset_size))
191
+ self.trainset = self.trainset.shuffle()
192
+
193
+ # Split the dataset
194
+ holdout_indices = indices[:holdout_size]
195
+ ref_indices = indices[holdout_size:holdout_size + ref_size]
196
+
197
+ # Create holdoutset and refset
198
+ self.holdoutset = self.trainset.select(holdout_indices)
199
+ self.refset = self.trainset.select(ref_indices)
200
+
201
+ print(f"Holdoutset size: {len(self.holdoutset)}, Refset size: {len(self.refset)}")
202
+
203
+
204
+ def fit(
205
+ self, parameters: NDArrays, config: Dict[str, Scalar]
206
+ ) -> Tuple[NDArrays, int, Dict]:
207
+ """Implement distributed fit function for a given client."""
208
+ if self.mates_args.state and int(config["current_round"]) != 1:
209
+ main_model_params, data_influence_model_params = split_models(parameters)
210
+ set_parameters(self.model, main_model_params)
211
+ set_parameters_bert(self.data_influence_model, data_influence_model_params)
212
+ else:
213
+ set_parameters(self.model, parameters)
214
+
215
+ new_lr = cosine_annealing(
216
+ int(config["current_round"]),
217
+ self.num_rounds,
218
+ self.train_cfg.learning_rate_max,
219
+ self.train_cfg.learning_rate_min,
220
+ )
221
+
222
+ self.training_arguments.learning_rate = new_lr
223
+ self.training_arguments.output_dir = config["save_path"]
224
+
225
+ # Initialize callback
226
+ early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=5)
227
+
228
+ # Construct supervised trainer
229
+ # trainer = SFTTrainer(
230
+ # model=self.model,
231
+ # tokenizer=self.tokenizer,
232
+ # args=self.training_arguments,
233
+ # train_dataset=self.trainset,
234
+ # eval_dataset=self.valset,
235
+ # formatting_func=self.formatting_prompts_func,
236
+ # data_collator=self.data_collator,
237
+ # compute_metrics=self.compute_metrics,
238
+ # callbacks=[flops_callback, early_stopping_callback]
239
+ # )
240
+
241
+ # # Constuct baseline Trainer
242
+ # trainer = Trainer(
243
+ # model=self.model,
244
+ # train_dataset=self.trainset,
245
+ # eval_dataset=self.valset.select(range(10)),
246
+ # args=self.training_arguments,
247
+ # data_collator=self.data_collator,
248
+ # compute_metrics=self.compute_metrics,
249
+ # callbacks=[early_stopping_callback]
250
+ # )
251
+
252
+ trainer = ManualTrainer(
253
+ model=self.model,
254
+ tokenizer=self.tokenizer,
255
+ train_dataset=self.trainset,
256
+ val_dataset=self.valset.select(range(10)),
257
+ holdout_dataset=self.holdoutset,
258
+ reference_dataset=self.refset,
259
+ args=self.training_arguments,
260
+ data_collator=self.data_collator,
261
+ compute_metrics=self.compute_metrics,
262
+ mates_args=self.mates_args,
263
+ data_influence_model=self.data_influence_model,
264
+ data_influence_tokenizer=self.data_influence_tokenizer,
265
+ )
266
+
267
+ # Train the model
268
+ results = trainer.train()
269
+
270
+ if self.mates_args.state:
271
+ # After training
272
+ main_model_params = get_parameters(self.model)
273
+ data_influence_model_params = model_parameters_to_ndarrays(self.data_influence_model)
274
+ final_model_params = concatenate_models_with_marker(main_model_params, data_influence_model_params)
275
+ else:
276
+ final_model_params = get_parameters(self.model)
277
+
278
+ # Calculate FLOPs
279
+ with get_accelerator().device('cuda:0'):
280
+ batch_size = self.training_arguments.per_device_eval_batch_size
281
+ seq_len = self.train_cfg.seq_length
282
+ flops1, macs1, params1 = get_model_profile(
283
+ self.model,
284
+ kwargs=input_constructor(batch_size, seq_len, self.tokenizer),
285
+ print_profile=True,
286
+ detailed=False,
287
+ )
288
+ # Guard: the data influence model exists only when MATES is enabled
289
+ flops2, macs2, params2 = ("0 G", "0 G", "0 G")
290
+ if self.mates_args.state:
291
+ flops2, macs2, params2 = get_model_profile(
292
+ self.data_influence_model,
293
+ kwargs=input_constructor(batch_size, seq_len, self.data_influence_tokenizer), print_profile=True, detailed=False)
294
+ flops1_value, flops2_value = convert_to_float(flops1), convert_to_float(flops2)
295
+ macs1_value, macs2_value = convert_to_float(macs1), convert_to_float(macs2)
296
+ params1_value, params2_value = convert_to_float(params1), convert_to_float(params2)
297
+ wandb.log({"total_flops": flops1_value + flops2_value, "macs": macs1_value + macs2_value, "params": params1_value + params2_value}) # wa
298
+
299
+ return (
300
+ final_model_params,
301
+ len(self.trainset),
302
+ {"train_loss": results['training_loss'], "flops": flops1_value + flops2_value},
303
+ )
304
+
305
+
306
+ def client_fn(context: Context) -> FlowerClient:
307
+ """Create a Flower client representing a single organization."""
308
+ partition_id = context.node_config["partition-id"]
309
+ num_partitions = context.node_config["num-partitions"]
310
+ num_rounds = context.run_config["num-server-rounds"]
311
+ cfg = DictConfig(replace_keys(unflatten_dict(context.run_config)))
312
+
313
+ # Let's get the client partition
314
+ if cfg.dataset.type == 'homo':
315
+ client_set = load_data_homo(partition_id, num_partitions, cfg.dataset.name)
316
+ else:
317
+ client_set = load_data_hete(partition_id)
318
+
319
+ return FlowerClient(
320
+ cfg.model,
321
+ cfg.train,
322
+ cfg.mates,
323
+ client_set['train'],
324
+ client_set['test'],
325
+ num_rounds,
326
+ ).to_client()
327
+
328
+
329
+ # Flower ClientApp
330
+ app = ClientApp(
331
+ client_fn,
332
+ mods=[
333
+ get_wandb_mod("FL@CSS25"),
334
+ ],
335
+ )
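`convert_to_float` above turns the human-readable unit strings returned by DeepSpeed's profiler into raw floats before logging. A self-contained sketch of the same parsing, with made-up values standing in for real profiler output:

```python
def convert_to_float(value_str: str) -> float:
    # "3.5 TFLOPS" -> 3.5e12; falls through unchanged for unit-less values
    value, unit = value_str.split()
    value = float(value)
    for prefix, scale in (("T", 1e12), ("G", 1e9), ("M", 1e6), ("K", 1e3)):
        if prefix in unit:
            return value * scale
    return value

print(convert_to_float("3.5 TFLOPS"))   # 3500000000000.0
print(convert_to_float("540.0 GMACs"))  # 540000000000.0
print(convert_to_float("125.5 M"))      # 125500000.0
```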
template_FL/src/fedllm/data_domains.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datasets
2
+ import pandas as pd
3
+ from datasets import Dataset, DatasetDict, load_dataset
4
+ from sklearn.model_selection import train_test_split
5
+ from functools import partial
6
+
7
+ global_test_set_hete = {}
8
+
9
+ class DatasetAbstract:
10
+ def __init__(self, dataset_name: list[str], category: str):
11
+ self.dataset_name = dataset_name
12
+ self.metadata = {
13
+ 'domain': category
14
+ }
15
+
16
+ def _processing_data(self):
17
+ pass
18
+
19
+ @classmethod
20
+ def get_dataset(cls, dataset_name, local_data_dir=None):
21
+ if dataset_name in ["gsm8k"]:
22
+ dataset_name = local_data_dir + dataset_name if local_data_dir is not None else dataset_name
23
+ dataset = load_dataset(dataset_name, name="main")
24
+ else:
25
+ dataset_name = local_data_dir + dataset_name if local_data_dir is not None else dataset_name
26
+ dataset = load_dataset(dataset_name)
27
+
28
+ return dataset
29
+
30
+ def get_split_dataset(self, dataset):
31
+ print(f">> ===== After processing, Dataset has {len(dataset)} examples. =====")
32
+ if len(dataset) > 10000:
33
+ ds_part1, ds_part2 = train_test_split(
34
+ dataset, test_size=0.5, shuffle=True, random_state=42
35
+ )
36
+ print(f">> ===== After split, Dataset1 has {len(ds_part1)} examples and Dataset2 has {len(ds_part2)} examples. =====")
37
+ list_dataset = []
38
+ list_global_set = []
39
+ for subset in [ds_part1, ds_part2]:
40
+ train, test = train_test_split(
41
+ subset, test_size=0.2, shuffle=True, random_state=42
42
+ )
43
+ test, global_test = train_test_split(
44
+ test, test_size=0.1, shuffle=True, random_state=42
45
+ )
46
+ ds = DatasetDict({
47
+ "train": Dataset.from_pandas(train).remove_columns(['__index_level_0__']),
48
+ "test": Dataset.from_pandas(test).remove_columns(['__index_level_0__'])
49
+ })
50
+ list_dataset.append(ds)
51
+ list_global_set.append(global_test)
52
+
53
+ list_global_set = pd.concat(list_global_set, ignore_index=True)
54
+ list_global_set = Dataset.from_pandas(list_global_set)
55
+ return list_dataset, list_global_set
56
+
57
+ else:
58
+ train, test = train_test_split(
59
+ dataset, test_size=0.2, shuffle=True, random_state=42
60
+ )
61
+ test, global_test = train_test_split(
62
+ test, test_size=0.1, shuffle=True, random_state=42
63
+ )
64
+ ds = DatasetDict(
65
+ {
66
+ "train": Dataset.from_pandas(train).remove_columns(['__index_level_0__']),
67
+ "test": Dataset.from_pandas(test).remove_columns(['__index_level_0__'])
68
+ }
69
+ )
70
+ global_set = Dataset.from_pandas(global_test).remove_columns(['__index_level_0__'])
71
+ return [ds], global_set
72
+
73
+
74
+ class GeneralDataset(DatasetAbstract):
75
+
76
+ def __init__(self):
77
+ list_dataset = ["tatsu-lab/alpaca", "vicgalle/alpaca-gpt4"]
78
+ super().__init__(list_dataset, 'general')
79
+ self._processing_data()
80
+
81
+ def _processing_data(self):
82
+ datasets = []
83
+ for dataset_name in self.dataset_name:
84
+ datasets.append(
85
+ pd.DataFrame(super().get_dataset(dataset_name=dataset_name, local_data_dir=None)['train'])
86
+ )
87
+ dataset = pd.concat(datasets, ignore_index=True)
88
+ self.list_dataset, global_test = self.get_split_dataset(dataset)
89
+ global global_test_set_hete
90
+ global_test_set_hete.update(
91
+ {self.metadata['domain']: global_test}
92
+ )
93
+
94
+
95
+ class FinanceDataset(DatasetAbstract):
96
+
97
+ def __init__(self):
98
+ list_dataset = ["gbharti/finance-alpaca", "FinGPT/fingpt-sentiment-train"]
99
+ super().__init__(list_dataset, 'finance')
100
+
101
+ self._processing_data()
102
+
103
+ def _processing_data(self):
104
+ datasets = []
105
+ for dataset_name in self.dataset_name:
106
+ ds = super().get_dataset(dataset_name=dataset_name, local_data_dir=None)['train']
107
+ if dataset_name == 'gbharti/finance-alpaca':
108
+ ds = ds.remove_columns(['text'])
109
+ df = pd.DataFrame(ds)
110
+ datasets.append(df)
111
+ dataset = pd.concat(datasets, ignore_index=True)
112
+ self.list_dataset, global_test = self.get_split_dataset(dataset)
113
+ global global_test_set_hete
114
+ global_test_set_hete.update(
115
+ {self.metadata['domain']: global_test}
116
+ )
117
+
118
+
119
+ class MathDataset(DatasetAbstract):
120
+
121
+ def __init__(self):
122
+ list_dataset = ["TIGER-Lab/MathInstruct", "xDAN2099/lighteval-MATH", "gsm8k"]
123
+ super().__init__(list_dataset, 'math')
124
+ self._processing_data()
125
+
126
+
127
+ def get_split_dataset(self, dataset):
128
+ dataset_train, dataset_test = dataset[0], dataset[1]
129
+ dataset_test, global_test = train_test_split(
130
+ dataset_test, test_size=0.1, shuffle=True, random_state=42
131
+ )
132
+ global_test = Dataset.from_pandas(global_test)
133
+ print(f">> ===== After processing, Dataset has {len(dataset_train)} examples. =====")
134
+ if len(dataset_train) > 10000:
135
+ ds_train_part1, ds_train_part2 = train_test_split(
136
+ dataset_train, test_size=0.5, shuffle=True, random_state=42
137
+ )
138
+ ds_test_part1, ds_test_part2 = train_test_split(
139
+ dataset_test, test_size=0.5, shuffle=True, random_state=42
140
+ )
141
+ print(f">> ===== After split, Dataset1 has {len(ds_train_part1)} examples and Dataset2 has {len(ds_train_part2)} examples. =====")
142
+ list_dataset = []
143
+ for train_part, test_part in [(ds_train_part1, ds_test_part1), (ds_train_part2, ds_test_part2)]:
144
+ ds = DatasetDict({
145
+ "train": Dataset.from_pandas(train_part).remove_columns(['__index_level_0__']),
146
+ "test": Dataset.from_pandas(test_part).remove_columns(['__index_level_0__'])
147
+ })
148
+ list_dataset.append(ds)
149
+ return list_dataset, global_test
150
+
151
+ else:
152
+ ds = DatasetDict(
153
+ {
154
+ "train": Dataset.from_pandas(dataset_train).remove_columns(['__index_level_0__']),
155
+ "test": Dataset.from_pandas(dataset_test).remove_columns(['__index_level_0__'])
156
+ }
157
+ )
158
+ return [ds], global_test
159
+
160
+ def _processing_data(self):
161
+ datasets_train, datasets_test = [], []
162
+ for dataset_name in self.dataset_name:
163
+ ds_tmp = super().get_dataset(dataset_name=dataset_name, local_data_dir=None)
164
+ if dataset_name == 'TIGER-Lab/MathInstruct':
165
+ df = pd.DataFrame(ds_tmp['train'])
166
+ df = df.drop_duplicates(subset=['instruction'])
167
+ df = df.drop(['source'], axis=1)
168
+ df_train, df_test = train_test_split(df, test_size=0.3, shuffle=True, random_state=42)
169
+
170
+ elif dataset_name == "xDAN2099/lighteval-MATH":
171
+ ds_tmp = ds_tmp.remove_columns(['level', 'type'])
172
+ ds_tmp = ds_tmp.rename_column("solution", "output")
173
+ ds_tmp = ds_tmp.rename_column("problem", "instruction")
174
+ df_train, df_test = pd.DataFrame(ds_tmp['train']), pd.DataFrame(ds_tmp['test'])
175
+
176
+ elif dataset_name == 'gsm8k':
177
+ ds_tmp = ds_tmp.rename_column("answer", "output")
178
+ ds_tmp = ds_tmp.rename_column("question", "instruction")
179
+ df_train, df_test = pd.DataFrame(ds_tmp['train']), pd.DataFrame(ds_tmp['test'])
180
+
181
+ df_train['input'] = [''] * len(df_train)
182
+ df_test['input'] = [''] * len(df_test)
183
+ datasets_train.append(df_train)
184
+ datasets_test.append(df_test)
185
+
186
+ dataset_train = pd.concat(datasets_train, ignore_index=True)
187
+ dataset_test = pd.concat(datasets_test, ignore_index=True)
188
+ dataset = [dataset_train, dataset_test]
189
+ self.list_dataset, global_test = self.get_split_dataset(dataset)
190
+ global global_test_set_hete
191
+ global_test_set_hete.update(
192
+ {self.metadata['domain']: global_test}
193
+ )
194
+
195
+
196
+ class MedicalDataset(DatasetAbstract):
197
+
198
+ def __init__(self):
199
+ list_dataset = ["medalpaca/medical_meadow_medical_flashcards"]
200
+ super().__init__(list_dataset, 'medical')
201
+ self._processing_data()
202
+
203
+ def _processing_data(self):
204
+ datasets = []
205
+ for dataset_name in self.dataset_name:
206
+ ds = super().get_dataset(dataset_name=dataset_name, local_data_dir=None)['train']
207
+ if dataset_name == 'medalpaca/medical_meadow_medical_flashcards':
208
+ ds = ds.remove_columns(['instruction'])
209
+ ds = ds.rename_column("input", "instruction")
210
+
211
+ df = pd.DataFrame(ds)
212
+ df['input'] = [''] * len(df)
213
+ datasets.append(df)
214
+ dataset = pd.concat(datasets, ignore_index=True)
215
+ self.list_dataset, global_test = self.get_split_dataset(dataset)
216
+ global global_test_set_hete
217
+ global_test_set_hete.update(
218
+ {self.metadata['domain']: global_test}
219
+ )
220
+
221
+ class CodeDataset(DatasetAbstract):
222
+
223
+ def __init__(self):
224
+ list_dataset = ["lucasmccabe-lmi/CodeAlpaca-20k", "WizardLMTeam/WizardLM_evol_instruct_70k"]
225
+ super().__init__(list_dataset, 'code')
226
+ self._processing_data()
227
+
228
+ def _processing_data(self):
229
+ datasets = []
230
+ for dataset_name in self.dataset_name:
231
+ ds = super().get_dataset(dataset_name=dataset_name, local_data_dir=None)['train']
232
+ df = pd.DataFrame(ds)
233
+ if dataset_name == 'WizardLMTeam/WizardLM_evol_instruct_70k':
234
+ df['input'] = [''] * len(df)
235
+ datasets.append(df)
236
+ dataset = pd.concat(datasets, ignore_index=True)
237
+ self.list_dataset, global_test = self.get_split_dataset(dataset)
238
+ global global_test_set_hete
239
+ global_test_set_hete.update(
240
+ {self.metadata['domain']: global_test}
241
+ )
242
+
243
+ def release_ds():
244
+ data_domain = {
245
+ 'general': GeneralDataset().list_dataset,
246
+ 'finance': FinanceDataset().list_dataset,
247
+ 'math': MathDataset().list_dataset,
248
+ 'medical': MedicalDataset().list_dataset,
249
+ 'code': CodeDataset().list_dataset
250
+ }
251
+ tmp_dataset = {}
252
+ k = 0
253
+ for task in data_domain.keys():
254
+ tmp_dataset[str(k)] = data_domain[task][0]
255
+ tmp_dataset[str(k+1)] = data_domain[task][1]
256
+ k += 2
257
+
258
+ return tmp_dataset
259
+
260
+ # data_domain = {
261
+ # 'general': GeneralDataset().list_dataset,
262
+ # 'finance': FinanceDataset().list_dataset,
263
+ # 'math': MathDataset().list_dataset,
264
+ # 'medical': MedicalDataset().list_dataset,
265
+ # 'code': CodeDataset().list_dataset
266
+ # }
267
+
268
+ # client_id_dataset = {
269
+ # '0': data_domain['general'][0],
270
+ # '1': data_domain['general'][1],
271
+ # '2': data_domain['finance'][0],
272
+ # '3': data_domain['finance'][1],
273
+ # '4': data_domain['math'][0],
274
+ # '5': data_domain['math'][1],
275
+ # '6': data_domain['medical'][0],
276
+ # '7': data_domain['medical'][1],
277
+ # '8': data_domain['code'][0],
278
+ # '9': data_domain['code'][1],
279
+ # }
280
+
281
+ client_id_dataset = release_ds()
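`get_split_dataset` first carves each domain into an 80/20 train/test split, then reserves 10% of the test slice as the shared global test set (`global_test_set_hete`). A quick size sanity check, assuming a made-up 1,000-example domain:

```python
from sklearn.model_selection import train_test_split

rows = list(range(1000))  # hypothetical domain size
train, test = train_test_split(rows, test_size=0.2, shuffle=True, random_state=42)
test, global_test = train_test_split(test, test_size=0.1, shuffle=True, random_state=42)

print(len(train), len(test), len(global_test))  # 800 180 20
```

Domains larger than 10,000 examples are halved first, so each half feeds one of the two clients that `release_ds` assigns to that domain.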
template_FL/src/fedllm/dataset.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from trl import DataCollatorForCompletionOnlyLM
2
+
3
+ from flwr_datasets.partitioner import IidPartitioner
4
+ from flwr_datasets import FederatedDataset
5
+ from datasets import Dataset, DatasetDict
6
+ from sklearn.model_selection import train_test_split
7
+ import pandas as pd
8
+
9
+
10
+ FDS = None # Cache FederatedDataset
11
+ client_id_ds = None
12
+ global_test_set_homo = None
13
+
14
+ def split_train_test(dataset, test_size):
15
+ # Split the dataset into train and test sets
16
+ train_data, test_data = train_test_split(dataset.to_pandas(), test_size=test_size, shuffle=True, random_state=42)
17
+ test_data, global_test = train_test_split(test_data, test_size=0.1, shuffle=True, random_state=42)
18
+
19
+ # Convert to Dataset objects
20
+ train_dataset = Dataset.from_pandas(train_data)
21
+ test_dataset = Dataset.from_pandas(test_data)
22
+ global_test = Dataset.from_pandas(global_test)  # reserved for a shared global test set (not returned here)
23
+
24
+ # Combine into a DatasetDict
25
+ datasets_dict = DatasetDict({
26
+ 'train': train_dataset,
27
+ 'test': test_dataset
28
+ })
29
+ return datasets_dict
30
+
31
+
32
+
33
+ def formatting_prompts_func(example):
34
+ output_texts = []
35
+ # Constructing a standard Alpaca (https://github.com/tatsu-lab/stanford_alpaca#data-release) prompt
36
+ mssg = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
37
+ for i in range(len(example["instruction"])):
38
+ text = f"{mssg}\n### Instruction:\n{example['instruction'][i]}\n### Response: {example['response'][i]}"
39
+ output_texts.append(text)
40
+ return output_texts
41
+
42
+
43
+ def get_data_collator_and_prompt_formatting(tokenizer):
44
+ # From: https://huggingface.co/docs/trl/en/sft_trainer
45
+ response_template_with_context = "\n### Response:" # alpaca response tag
46
+ response_template_ids = tokenizer.encode(
47
+ response_template_with_context, add_special_tokens=False
48
+ )[2:]
49
+ data_collator = DataCollatorForCompletionOnlyLM(
50
+ response_template_ids, tokenizer=tokenizer
51
+ )
52
+
53
+ return data_collator, formatting_prompts_func
54
+
55
+
56
+ def load_data(partition_id: int, num_partitions: int, dataset_name: str):
57
+ """Load partition data."""
58
+ # Only initialize `FederatedDataset` once
59
+ global FDS
60
+ if FDS is None:
61
+ partitioner = IidPartitioner(num_partitions=num_partitions)
62
+ FDS = FederatedDataset(
63
+ dataset=dataset_name,
64
+ partitioners={"train": partitioner},
65
+ )
66
+ print(f"<---- Load client {partition_id} --->")
67
+ client_trainset = FDS.load_partition(partition_id, "train")
68
+ client_trainset = client_trainset.rename_column("output", "response")
69
+ print(client_trainset)
70
+ return client_trainset
71
+
72
+ def load_data_homo(partition_id: int, num_partitions: int, dataset_name: str):
73
+ """Load partition data."""
74
+ # Only initialize `FederatedDataset` once
75
+ global FDS
76
+ global global_test_set_homo
77
+ if FDS is None:
78
+ partitioner = IidPartitioner(num_partitions=num_partitions)
79
+ FDS = FederatedDataset(
80
+ dataset=dataset_name,
81
+ partitioners={"train": partitioner},
82
+ )
83
+ # list_ds = []
84
+ # for cid in range(0,num_partitions):
85
+ # tmp_set = FDS.load_partition(cid, "train")
86
+ # list_ds.append(
87
+ # pd.DataFrame(tmp_set)
88
+ # )
89
+ # list_ds = pd.concat(list_ds, ignore_index=True)
90
+ # _, global_test_set_homo = train_test_split(
91
+ # list_ds, test_size=0.1, shuffle=True, random_state=42
92
+ # )
93
+ # global_test_set_homo = Dataset.from_pandas(global_test_set_homo).remove_columns(['__index_level_0__'])
94
+
95
+ print(f"<---- Load client {partition_id} --->")
96
+ client_trainset = FDS.load_partition(partition_id, "train")
97
+ # client_trainset = client_trainset.rename_column("output", "response")
98
+ client_set = split_train_test(client_trainset, test_size=0.2)
99
+ return client_set
100
+
101
+
102
+ def load_data_hete(partition_id: int):
103
+ """Load partition data heterogeneous"""
104
+ global client_id_ds
105
+ if client_id_ds is None:
106
+ from .data_domains import client_id_dataset
107
+ client_id_ds = client_id_dataset
108
+ print(f"<---- Load client {partition_id} --->")
109
+ client_set = client_id_ds[str(partition_id)]
110
+ return client_set
111
+
112
+
113
+ def replace_keys(input_dict, match="-", target="_"):
114
+ """Recursively replace match string with target string in dictionary keys."""
115
+ new_dict = {}
116
+ for key, value in input_dict.items():
117
+ new_key = key.replace(match, target)
118
+ if isinstance(value, dict):
119
+ new_dict[new_key] = replace_keys(value, match, target)
120
+ else:
121
+ new_dict[new_key] = value
122
+ return new_dict
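`replace_keys` exists because Flower run-config keys use dashes while the `DictConfig` built from them is accessed with underscore attributes (see `client_fn` in client_app.py). A minimal sketch, assuming the package layout above for the import:

```python
from fedllm.dataset import replace_keys  # hypothetical import path

cfg = {"train": {"learning-rate-max": 5e-5,
                 "training-arguments": {"per-device-train-batch-size": 16}}}
print(replace_keys(cfg))
# {'train': {'learning_rate_max': 5e-05,
#            'training_arguments': {'per_device_train_batch_size': 16}}}
```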
template_FL/src/fedllm/flwr_mods.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flwr.common import Context, Message, MessageType, ConfigsRecord
2
+ from flwr.client.typing import ClientAppCallable
3
+ from typing import Callable, Optional
4
+ import wandb
5
+ import time
6
+ from .myfedavg import client_id_idx
7
+
8
+
9
+ # Define type alias for Mod
10
+ Mod = Callable[[Message, Context, ClientAppCallable], Message]
11
+
12
+ def get_wandb_mod(name: str) -> Mod:
13
+ # Keep track of active runs
14
+ active_run: Optional[wandb.Run] = None
15
+
16
+ def wandb_mod(msg: Message, context: Context, app: ClientAppCallable) -> Message:
17
+ nonlocal active_run
18
+ server_round = int(msg.metadata.group_id)
19
+ run_id = msg.metadata.run_id
20
+ group_name = f"Run ID: {run_id}"
21
+ node_id = str(msg.metadata.dst_node_id)
22
+ run_name = f"Client ID: {client_id_idx[node_id]}"
23
+
24
+ wandb.init(
25
+ project=name,
26
+ group=group_name,
27
+ name=run_name,
28
+ id=f"{run_id}_{client_id_idx[node_id]}",
29
+ resume="allow",
30
+ reinit=True,
31
+ # settings=wandb.Settings(start_method="thread")
32
+ )
33
+
34
+ start_time = time.time()
35
+ reply = app(msg, context)
36
+
37
+ if reply.metadata.message_type == MessageType.TRAIN and reply.has_content():
38
+
39
+ time_diff = time.time() - start_time
40
+ metrics = reply.content.configs_records
41
+ results_to_log = dict(metrics.get("fitres.metrics", ConfigsRecord()))
42
+ results_to_log["fit_time"] = time_diff
43
+
44
+ wandb.log(results_to_log, step=int(server_round), commit=True)
45
+
46
+ return reply
47
+
48
+ return wandb_mod
49
+
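`get_wandb_mod` follows Flower's mod contract: a callable taking `(Message, Context, app)` and returning the reply `Message`. A minimal sketch of the same contract with only the timing bookkeeping and none of the wandb setup:

```python
import time

from flwr.client.typing import ClientAppCallable
from flwr.common import Context, Message


def timing_mod(msg: Message, context: Context, app: ClientAppCallable) -> Message:
    # Forward the message to the wrapped ClientApp and time the round trip
    start = time.time()
    reply = app(msg, context)
    print(f"{msg.metadata.message_type} handled in {time.time() - start:.2f}s")
    return reply
```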
template_FL/src/fedllm/make_data.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import os.path as osp
4
+ from typing import Union
5
+
6
+ class Prompter(object):
7
+ __slots__ = ("template", "_verbose")
8
+
9
+ def __init__(self, template_name: str = "", verbose: bool = False):
10
+ self._verbose = verbose
11
+ if not template_name:
12
+ # Enforce the default here, so the constructor can be called with '' and will not break.
13
+ template_name = "alpaca"
14
+ file_name = osp.join(
15
+ os.getcwd(), "fedllm/templates", f"{template_name}.json"
16
+ )
17
+
18
+ if not osp.exists(file_name):
19
+ raise ValueError(f"Can't read {file_name}")
20
+ with open(file=file_name) as fp:
21
+ self.template = json.load(fp)
22
+ if self._verbose:
23
+ print(
24
+ f"Using prompt template {template_name}: {self.template['description']}"
25
+ )
26
+
27
+ def generate_prompt(
28
+ self,
29
+ instruction: str,
30
+ input: Union[None, str] = None,
31
+ label: Union[None, str] = None,
32
+ ) -> str:
33
+ # returns the full prompt from instruction and optional input
34
+ # if a label (=response, =output) is provided, it's also appended.
35
+ if input:
36
+ res = self.template["prompt_input"].format(
37
+ instruction=instruction,
38
+ input=input,
39
+ )
40
+ else:
41
+ res = self.template["prompt_no_input"].format(
42
+ instruction=instruction,
43
+ )
44
+ if label:
45
+ res = f"{res}{label}"
46
+ if self._verbose:
47
+ print(res)
48
+ return res
49
+
50
+ def get_response(self, output: str) -> str:
51
+ return output.split(self.template["response_split"])[1].strip()
52
+
53
+
54
+ def tokenize(tokenizer, prompt, cutoff_len=512, add_eos_token=True):
55
+ result = tokenizer(
56
+ prompt,
57
+ truncation=True,
58
+ max_length=cutoff_len,
59
+ padding=False,
60
+ return_tensors=None,
61
+ )
62
+ if (
63
+ result["input_ids"][-1] != tokenizer.eos_token_id
64
+ and len(result["input_ids"]) < cutoff_len
65
+ and add_eos_token
66
+ ):
67
+ result["input_ids"].append(tokenizer.eos_token_id)
68
+ result["attention_mask"].append(1)
69
+
70
+ result["labels"] = result["input_ids"].copy()
71
+
72
+ return result
73
+
74
+
75
+ def generate_and_tokenize_prompt(data_point, **kwargs):
76
+ full_prompt = kwargs["prompter"].generate_prompt(
77
+ data_point["instruction"],
78
+ data_point["input"],
79
+ data_point["output"],
80
+ )
81
+ tokenized_full_prompt = tokenize(
82
+ kwargs["tokenizer"],
83
+ full_prompt,
84
+ cutoff_len=kwargs["seq_length"],
85
+ add_eos_token=True,
86
+ )
87
+ if not kwargs["train_on_inputs"]:
88
+ user_prompt = kwargs["prompter"].generate_prompt(
89
+ data_point["instruction"], data_point["input"]
90
+ )
91
+ tokenized_user_prompt = kwargs["tokenizer"](
92
+ user_prompt, add_eos_token=False
93
+ )
94
+ user_prompt_len = len(tokenized_user_prompt["input_ids"])
95
+
96
+ tokenized_full_prompt["labels"] = [
97
+ -100
98
+ ] * user_prompt_len + tokenized_full_prompt["labels"][
99
+ user_prompt_len:
100
+ ] # could be sped up, probably
101
+ return tokenized_full_prompt
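When `train_on_inputs` is false, the labels for the prompt portion are overwritten with -100, the index that PyTorch's cross-entropy loss (and hence the HF `Trainer`) ignores, so the loss is computed on the response tokens only. A toy illustration with made-up token ids:

```python
input_ids = [101, 7, 8, 9, 10, 11, 102]  # prompt tokens then response tokens (made up)
labels = input_ids.copy()
user_prompt_len = 4                      # length of the tokenized prompt-only text

labels = [-100] * user_prompt_len + labels[user_prompt_len:]
print(labels)  # [-100, -100, -100, -100, 10, 11, 102] -> loss on the response only
```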
template_FL/src/fedllm/metrics.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import evaluate
3
+ from rouge_score import rouge_scorer
4
+ import numpy as np
5
+ import copy
6
+ from collections import OrderedDict, Counter
7
+ from .utils import clean_output_text
8
+
9
+ def get_answer(text):
10
+ # text = text.lower()
11
+ label = text.split('Response:')[-1].strip()
12
+ return label
13
+
14
+ def check_data_state(preds, targets):
15
+ assert len(preds) == len(targets)
16
+
17
+ def get_rouge_score(predictions, targets):
18
+ check_data_state(predictions, targets)
19
+ rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL', 'rougeLsum'], use_stemmer=True)
20
+ scores = {
21
+ 'rouge1': 0.0,
22
+ 'rouge2': 0.0,
23
+ 'rougeL': 0.0,
24
+ 'rougeLsum': 0.0
25
+ }
26
+ for prediction, target in zip(predictions, targets):
27
+ prediction = get_answer(clean_output_text(prediction))
28
+ target = get_answer(clean_output_text(target))
29
+ rouge_output = rouge.score(prediction=prediction, target=target)
30
+ scores['rouge1'] += round(rouge_output["rouge1"].fmeasure, 4)
31
+ scores['rouge2'] += round(rouge_output["rouge2"].fmeasure, 4)
32
+ scores['rougeL'] += round(rouge_output["rougeL"].fmeasure, 4)
33
+ scores['rougeLsum'] += round(rouge_output["rougeLsum"].fmeasure, 4)
34
+
35
+
36
+ scores['rouge1'] /= len(predictions)
37
+ scores['rouge2'] /= len(predictions)
38
+ scores['rougeL'] /= len(predictions)
39
+ scores['rougeLsum'] /= len(predictions)
40
+ return scores
41
+
42
+ def exact_match(predictions, targets):
43
+ check_data_state(predictions, targets)
44
+ predictions = [get_answer(clean_output_text(prediction)) for prediction in predictions]
45
+ targets = [get_answer(clean_output_text(target)) for target in targets]
46
+
47
+ preds, targets = np.asarray(predictions, dtype="<U16"), np.asarray(targets, dtype="<U16")
48
+
49
+ # print(preds, targets)
50
+ return {"exact_match": np.sum(preds == targets) / preds.size}
51
+
52
+ def _f1_score(prediction, target):
53
+ prediction_tokens = prediction.split()
54
+ target_tokens = target.split()
55
+ common = Counter(prediction_tokens) & Counter(target_tokens)
56
+ num_same = sum(common.values())
57
+ if num_same == 0:
58
+ return 0
59
+ precision = 1.0 * num_same / len(prediction_tokens)
60
+ recall = 1.0 * num_same / len(target_tokens)
61
+ f1 = (2 * precision * recall) / (precision + recall)
62
+ return f1
63
+
64
+ def f1(predictions, targets):
65
+ check_data_state(predictions, targets)
66
+ f1_score = 0.0
67
+ for prediction, target in zip(predictions, targets):
68
+ prediction = get_answer(clean_output_text(prediction))
69
+ target = get_answer(clean_output_text(target))
70
+ f1_score += _f1_score(prediction=prediction, target=target)
71
+
72
+ f1_score /= len(predictions)
73
+ return {'f1': f1_score}
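`_f1_score` is the standard token-overlap F1 (as in SQuAD-style evaluation): count multiset-shared tokens, divide by prediction length for precision and target length for recall. A worked example with made-up strings:

```python
from collections import Counter

prediction = "the answer is 42"
target = "answer is 42"

common = Counter(prediction.split()) & Counter(target.split())
num_same = sum(common.values())  # 3 shared tokens
precision = num_same / 4         # 4 prediction tokens -> 0.75
recall = num_same / 3            # 3 target tokens    -> 1.0
f1 = 2 * precision * recall / (precision + recall)
print(round(f1, 4))              # 0.8571
```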
template_FL/src/fedllm/models.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from omegaconf import DictConfig
6
+ from collections import OrderedDict
7
+ from peft import (
8
+ LoraConfig,
9
+ get_peft_model,
10
+ get_peft_model_state_dict,
11
+ set_peft_model_state_dict,
12
+ )
13
+ from peft.utils import prepare_model_for_kbit_training
14
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainerCallback, BertForSequenceClassification
15
+
16
+ from flwr.common.typing import NDArrays
17
+ from transformers.trainer_callback import TrainerControl, TrainerState
18
+ from transformers.training_args import TrainingArguments
19
+ from thop import profile
20
+ import wandb
21
+ from typing import Dict, List
22
+ import copy
23
+ import time
24
+ import numpy as np
25
+
26
+
27
+ def cosine_annealing(
28
+ current_round: int,
29
+ total_round: int,
30
+ lrate_max: float = 0.001,
31
+ lrate_min: float = 0.0,
32
+ ) -> float:
33
+ """Implement cosine annealing learning rate schedule."""
34
+
35
+ cos_inner = math.pi * current_round / total_round
36
+ return lrate_min + 0.5 * (lrate_max - lrate_min) * (1 + math.cos(cos_inner))
37
+
38
+
39
+ def get_model(model_cfg: DictConfig):
40
+ """Load model with appropriate quantization config and other optimizations.
41
+
42
+ Please refer to this example for `peft + BitsAndBytes`:
43
+ https://github.com/huggingface/peft/blob/main/examples/fp4_finetuning/finetune_fp4_opt_bnb_peft.py
44
+ """
45
+ use_cuda = torch.cuda.is_available()
46
+ device_map = torch.device("cuda:0" if use_cuda else "cpu")
47
+ if model_cfg.quantization == 4:
48
+ quantization_config = BitsAndBytesConfig(
49
+ load_in_4bit=True,
50
+ bnb_4bit_use_double_quant=True,
51
+ bnb_4bit_quant_type="nf4",
52
+ bnb_4bit_compute_dtype=torch.bfloat16,
53
+ )
54
+ elif model_cfg.quantization == 8:
55
+ quantization_config = BitsAndBytesConfig(load_in_8bit=True)
56
+ else:
57
+ raise ValueError(
58
+ f"Use 4-bit or 8-bit quantization. You passed: {model_cfg.quantization}/"
59
+ )
60
+
61
+ model = AutoModelForCausalLM.from_pretrained(
62
+ model_cfg.name,
63
+ quantization_config=quantization_config,
64
+ # torch_dtype=torch.bfloat16,
65
+ attn_implementation=(
66
+ "flash_attention_2" if model_cfg.flash_attention else "eager"
67
+ ),
68
+ ).to(device_map)
69
+
70
+ if use_cuda:
71
+ model = prepare_model_for_kbit_training(
72
+ model, use_gradient_checkpointing=model_cfg.gradient_checkpointing
73
+ )
74
+
75
+
76
+ # Get tokenizer
77
+ tokenizer = AutoTokenizer.from_pretrained(
78
+ model_cfg.name, use_fast=True, padding_side="right"
79
+ )
80
+ tokenizer.pad_token = tokenizer.eos_token
81
+
82
+ peft_config = LoraConfig(
83
+ r=model_cfg.lora.lora_r,
84
+ lora_alpha=model_cfg.lora.lora_alpha,
85
+ lora_dropout=model_cfg.lora.lora_dropout,
86
+ target_modules=model_cfg.lora.lora_target_modules.split(", "),
87
+ bias="none",
88
+ task_type="CAUSAL_LM",
89
+ )
90
+
91
+ return get_peft_model(model, peft_config), tokenizer
92
+
93
+ def get_data_influence_model(model_cfg: DictConfig):
94
+ use_cuda = torch.cuda.is_available()
95
+ device_map = torch.device("cuda:0" if use_cuda else "cpu")
96
+
97
+ # Load model with num_labels=1
98
+ model = BertForSequenceClassification.from_pretrained(
99
+ "bert-base-uncased",
100
+ num_labels=1, # Set number of labels to 1 for regression or single-class tasks
101
+ ).to(device_map)
102
+
103
+ tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
104
+
105
+ if use_cuda:
106
+ model = prepare_model_for_kbit_training(
107
+ model, use_gradient_checkpointing=model_cfg.gradient_checkpointing
108
+ )
109
+
110
+ return model, tokenizer
111
+
112
+
113
+ def set_parameters(model, parameters: NDArrays) -> None:
114
+ """Change the parameters of the model using the given ones."""
115
+ peft_state_dict_keys = get_peft_model_state_dict(model).keys()
116
+ params_dict = zip(peft_state_dict_keys, parameters)
117
+ state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
118
+ set_peft_model_state_dict(model, state_dict)
119
+
120
+
121
+ def get_parameters(model) -> NDArrays:
122
+ """Return the parameters of the current net."""
123
+ state_dict = get_peft_model_state_dict(model)
124
+ return [val.cpu().numpy() for _, val in state_dict.items()]
125
+
126
+ def model_parameters_to_ndarrays(model):
127
+ """
128
+ Convert the parameters of a HuggingFace model into a list of NDArrays.
129
+
130
+ Args:
131
+ model (torch.nn.Module): The HuggingFace model.
132
+
133
+ Returns:
134
+ NDArrays: A list of NumPy arrays representing the model's parameters.
135
+ """
136
+ ndarrays = []
137
+ for param_tensor in model.state_dict().values():
138
+ # Convert PyTorch tensor to NumPy array
139
+ ndarrays.append(param_tensor.cpu().numpy())
140
+ return ndarrays
141
+
142
+
143
+ def concatenate_models_with_marker(main_model_params: NDArrays,
144
+ data_influence_model_params: NDArrays,
145
+ marker_value: float = np.nan) -> NDArrays:
146
+ """
147
+ Concatenate two models' parameters with a unique marker.
148
+
149
+ Args:
150
+ main_model_params (NDArrays): Parameters of the main model as NDArrays.
151
+ data_influence_model_params (NDArrays): Parameters of the data influence model as NDArrays.
152
+ marker_value (float): A unique marker value to separate the two models.
153
+
154
+ Returns:
155
+ NDArrays: A single list of NDArrays with the unique marker separating the models.
156
+ """
157
+ marker = np.array([marker_value]) # Unique marker
158
+ concatenated_params = main_model_params + [marker] + data_influence_model_params
159
+ return concatenated_params
160
+
161
+
162
+ def split_models(concatenated_model: NDArrays) -> tuple[NDArrays, NDArrays]:
163
+ """Split the concatenated model back into main and data influence models."""
164
+ # Find the marker's index
165
+ marker_index = next(
166
+ (i for i, param in enumerate(concatenated_model) if np.isnan(param).all()),
167
+ -1,
168
+ )
169
+ if marker_index == -1:
170
+ raise ValueError("Marker not found in the concatenated model parameters.")
171
+
172
+ main_model = concatenated_model[:marker_index]
173
+ data_influence_model = concatenated_model[marker_index + 1 :]
174
+ return main_model, data_influence_model
175
+
176
+
177
+ def set_parameters_bert(model: BertForSequenceClassification, parameters: NDArrays) -> None:
178
+ """
179
+ Set the parameters of a BertForSequenceClassification model using the given ones.
180
+
181
+ Args:
182
+ model (BertForSequenceClassification): The model whose parameters need to be updated.
183
+ parameters (NDArrays): A list of NumPy arrays representing the parameters.
184
+ """
185
+ # Get the state_dict keys from the model
186
+ state_dict_keys = model.state_dict().keys()
187
+
188
+ # Ensure the number of parameters matches the model's state_dict
189
+ if len(parameters) != len(state_dict_keys):
190
+ raise ValueError(
191
+ f"Number of parameters ({len(parameters)}) does not match "
192
+ f"the number of state_dict keys ({len(state_dict_keys)})."
193
+ )
194
+
195
+ # Create an OrderedDict to update the model
196
+ params_dict = zip(state_dict_keys, parameters)
197
+ state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
198
+
199
+ # Load the updated state_dict into the model
200
+ model.load_state_dict(state_dict)
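`concatenate_models_with_marker` and `split_models` round-trip the two parameter lists through a single flat list by inserting an all-NaN array as a separator; this works as long as no genuine parameter tensor is entirely NaN. A self-contained sketch with made-up shapes:

```python
import numpy as np

main = [np.ones((2, 2)), np.zeros(3)]  # stand-ins for the main model's weights
aux = [np.full(4, 0.5)]                # stand-ins for the data influence model

packed = main + [np.array([np.nan])] + aux  # concatenate_models_with_marker
marker_index = next(i for i, p in enumerate(packed) if np.isnan(p).all())
main_back, aux_back = packed[:marker_index], packed[marker_index + 1:]

assert all((a == b).all() for a, b in zip(main, main_back))
assert all((a == b).all() for a, b in zip(aux, aux_back))
```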
template_FL/src/fedllm/myaggregation.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Aggregation functions for strategy implementations."""
16
+ # mypy: disallow_untyped_calls=False
17
+
18
+ from functools import partial, reduce
19
+ from typing import Any, Callable, Union
20
+
21
+ import numpy as np
22
+
23
+ from flwr.common import FitRes, NDArray, NDArrays, parameters_to_ndarrays
24
+ from flwr.server.client_proxy import ClientProxy
25
+
26
+ from .models import concatenate_models_with_marker, split_models
27
+
28
+
29
+ def aggregate(results: list[tuple[NDArrays, int]]) -> NDArrays:
30
+ """Compute weighted average."""
31
+ # Calculate the total number of examples used during training
32
+ num_examples_total = sum(num_examples for (_, num_examples) in results)
33
+
34
+ # Create a list of weights, each multiplied by the related number of examples
35
+ weighted_weights = [
36
+ [layer * num_examples for layer in weights] for weights, num_examples in results
37
+ ]
38
+
39
+ # Compute average weights of each layer
40
+ weights_prime: NDArrays = [
41
+ reduce(np.add, layer_updates) / num_examples_total
42
+ for layer_updates in zip(*weighted_weights)
43
+ ]
44
+ return weights_prime
45
+
46
+
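`aggregate` above is plain FedAvg: each client's layers are scaled by its example count and the sums are divided by the total. A quick arithmetic check with two made-up clients:

```python
import numpy as np

# Two hypothetical clients holding 100 and 300 examples
results = [([np.array([1.0, 1.0])], 100),
           ([np.array([3.0, 3.0])], 300)]

num_examples_total = sum(n for _, n in results)                 # 400
weighted = [[layer * n for layer in w] for w, n in results]
avg = [sum(layers) / num_examples_total for layers in zip(*weighted)]
print(avg)  # [array([2.5, 2.5])] -> (1*100 + 3*300) / 400
```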
47
+ def aggregate_inplace(results: list[tuple[ClientProxy, FitRes]]) -> NDArrays:
48
+ """Compute in-place weighted average."""
49
+ # Count total examples
50
+ num_examples_total = sum(fit_res.num_examples for (_, fit_res) in results)
51
+
52
+ # Compute scaling factors for each result
53
+ scaling_factors = np.asarray(
54
+ [fit_res.num_examples / num_examples_total for _, fit_res in results]
55
+ )
56
+
57
+ def _try_inplace(
58
+ x: NDArray, y: Union[NDArray, np.float64], np_binary_op: np.ufunc
59
+ ) -> NDArray:
60
+ return ( # type: ignore[no-any-return]
61
+ np_binary_op(x, y, out=x)
62
+ if np.can_cast(y, x.dtype, casting="same_kind")
63
+ else np_binary_op(x, np.array(y, x.dtype), out=x)
64
+ )
65
+
66
+ # Let's do in-place aggregation
67
+ # Get first result, then add up each other
68
+ # print(f"Param : {results[0][1].parameters}. Type: {type(results[0][1].parameters)}")
69
+ params = [
70
+ _try_inplace(x, scaling_factors[0], np_binary_op=np.multiply)
71
+ for x in parameters_to_ndarrays(results[0][1].parameters)
72
+ ]
73
+
74
+ for i, (_, fit_res) in enumerate(results[1:], start=1):
75
+ res = (
76
+ _try_inplace(x, scaling_factors[i], np_binary_op=np.multiply)
77
+ for x in parameters_to_ndarrays(fit_res.parameters)
78
+ )
79
+ params = [
80
+ reduce(partial(_try_inplace, np_binary_op=np.add), layer_updates)
81
+ for layer_updates in zip(params, res)
82
+ ]
83
+
84
+ return params
85
+
86
+ def aggregate_inplace_mates(results: list[tuple[ClientProxy, FitRes]]) -> NDArrays:
87
+ """Aggregate main model and data influence model separately."""
88
+ num_examples_total = sum(fit_res.num_examples for _, fit_res in results)
89
+ scaling_factors = [
90
+ fit_res.num_examples / num_examples_total for _, fit_res in results
91
+ ]
92
+
93
+ aggregated_main_model = None
94
+ aggregated_data_influence_model = None
95
+
96
+ for i, (_, fit_res) in enumerate(results):
97
+ # Convert parameters to NDArrays and split into main and data influence models
98
+ concatenated_params = parameters_to_ndarrays(fit_res.parameters)
99
+ main_model, data_influence_model = split_models(concatenated_params)
100
+
101
+ # Scale the models by the scaling factor
102
+ scaled_main_model = [x * scaling_factors[i] for x in main_model]
103
+ scaled_data_influence_model = [x * scaling_factors[i] for x in data_influence_model]
104
+
105
+ # Aggregate in-place
106
+ if aggregated_main_model is None:
107
+ aggregated_main_model = scaled_main_model
108
+ aggregated_data_influence_model = scaled_data_influence_model
109
+ else:
110
+ aggregated_main_model = [
111
+ x + y for x, y in zip(aggregated_main_model, scaled_main_model)
112
+ ]
113
+ aggregated_data_influence_model = [
114
+ x + y for x, y in zip(aggregated_data_influence_model, scaled_data_influence_model)
115
+ ]
116
+
117
+ return concatenate_models_with_marker(aggregated_main_model, aggregated_data_influence_model)
118
+
119
+
+ def aggregate_median(results: list[tuple[NDArrays, int]]) -> NDArrays:
+     """Compute median."""
+     # Create a list of weights and ignore the number of examples
+     weights = [weights for weights, _ in results]
+
+     # Compute median weight of each layer
+     median_w: NDArrays = [
+         np.median(np.asarray(layer), axis=0) for layer in zip(*weights)
+     ]
+     return median_w
+
+
+ def aggregate_krum(
+     results: list[tuple[NDArrays, int]], num_malicious: int, to_keep: int
+ ) -> NDArrays:
+     """Choose one parameter vector according to the Krum function.
+
+     If `to_keep` is greater than 0, MultiKrum is applied.
+     """
+     # Create a list of weights and ignore the number of examples
+     weights = [weights for weights, _ in results]
+
+     # Compute distances between vectors
+     distance_matrix = _compute_distances(weights)
+
+     # For each client, take the n-f-2 closest parameter vectors
+     num_closest = max(1, len(weights) - num_malicious - 2)
+     closest_indices = []
+     for distance in distance_matrix:
+         closest_indices.append(
+             np.argsort(distance)[1 : num_closest + 1].tolist()  # noqa: E203
+         )
+
+     # Compute the score for each client: the sum of the distances to its
+     # n-f-2 closest parameter vectors
+     scores = [
+         np.sum(distance_matrix[i, closest_indices[i]])
+         for i in range(len(distance_matrix))
+     ]
+
+     if to_keep > 0:
+         # Choose to_keep clients and return their average (MultiKrum)
+         best_indices = np.argsort(scores)[::-1][len(scores) - to_keep :]  # noqa: E203
+         best_results = [results[i] for i in best_indices]
+         return aggregate(best_results)
+
+     # Return the model parameters that minimize the score (Krum)
+     return weights[np.argmin(scores)]
+
+
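+ # Intuition sketch (illustrative): with five honest-looking vectors and one
+ # outlier, the Krum score (sum of squared distances to the n - f - 2 nearest
+ # neighbours) is largest for the outlier, so it is never selected.
+ #
+ #     honest = [([np.array([0.0 + 0.01 * i])], 1) for i in range(5)]
+ #     attacker = ([np.array([100.0])], 1)
+ #     aggregate_krum(honest + [attacker], num_malicious=1, to_keep=0)
+ #     # -> one of the honest vectors, never array([100.])
+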
+ # pylint: disable=too-many-locals
+ def aggregate_bulyan(
+     results: list[tuple[NDArrays, int]],
+     num_malicious: int,
+     aggregation_rule: Callable,  # type: ignore
+     **aggregation_rule_kwargs: Any,
+ ) -> NDArrays:
+     """Perform Bulyan aggregation.
+
+     Parameters
+     ----------
+     results: list[tuple[NDArrays, int]]
+         Weights and number of samples for each of the clients.
+     num_malicious: int
+         The maximum number of malicious clients.
+     aggregation_rule: Callable
+         Byzantine-resilient aggregation rule used as the first step of Bulyan.
+     aggregation_rule_kwargs: Any
+         The arguments to the aggregation rule.
+
+     Returns
+     -------
+     aggregated_parameters: NDArrays
+         Aggregated parameters according to the Bulyan strategy.
+     """
+     byzantine_resilient_single_ret_model_aggregation = [aggregate_krum]
+     # also GeoMed (but not implemented yet)
+     byzantine_resilient_many_return_models_aggregation = []  # type: ignore
+     # Brute, Medoid (but not implemented yet)
+
+     num_clients = len(results)
+     if num_clients < 4 * num_malicious + 3:
+         raise ValueError(
+             "The Bulyan aggregation requires the number of clients to be greater "
+             "than or equal to 4 * num_malicious + 3. This is the assumption of "
+             "this method; it is needed to ensure that the method reduces the "
+             "attacker's leeway to the one proved in the paper."
+         )
+     selected_models_set: list[tuple[NDArrays, int]] = []
+
+     theta = len(results) - 2 * num_malicious
+     beta = theta - 2 * num_malicious
+
+     for _ in range(theta):
+         best_model = aggregation_rule(
+             results=results, num_malicious=num_malicious, **aggregation_rule_kwargs
+         )
+         list_of_weights = [weights for weights, num_samples in results]
+         # This group gives an exact result
+         if aggregation_rule in byzantine_resilient_single_ret_model_aggregation:
+             best_idx = _find_reference_weights(best_model, list_of_weights)
+         # This group requires finding the closest model to the returned one
+         # (weights distance-wise)
+         elif aggregation_rule in byzantine_resilient_many_return_models_aggregation:
+             # when different aggregation strategies are available,
+             # write a function to find the closest model
+             raise NotImplementedError(
+                 "aggregate_bulyan currently does not support aggregation rules that"
+                 " return many models as results. "
+                 "Such aggregation rules are currently not available in Flower."
+             )
+         else:
+             raise ValueError(
+                 "The given aggregation rule is not added as Byzantine resilient. "
+                 "Please choose from the Byzantine-resilient rules."
+             )
+
+         selected_models_set.append(results[best_idx])
+
+         # Remove the selected model from the candidate pool
+         results.pop(best_idx)
+
+     # Compute the median parameter vector across selected_models_set
+     median_vect = aggregate_median(selected_models_set)
+
+     # Average the beta parameters with the closest distance to the median
+     # (coordinate-wise)
+     parameters_aggregated = _aggregate_n_closest_weights(
+         median_vect, selected_models_set, beta_closest=beta
+     )
+     return parameters_aggregated
+
+
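+ # Usage sketch (illustrative): Bulyan first runs the selection rule theta times,
+ # then trims per coordinate. With n = 7 and f = 1 (satisfying n >= 4f + 3),
+ # theta = 5 models are selected and beta = 3 are averaged. Note that `results`
+ # is mutated by the selection loop.
+ #
+ #     clients = [([np.array([float(i)])], 1) for i in range(7)]
+ #     aggregate_bulyan(clients, num_malicious=1,
+ #                      aggregation_rule=aggregate_krum, to_keep=0)
+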
+ def weighted_loss_avg(results: list[tuple[int, float]]) -> float:
+     """Aggregate evaluation results obtained from multiple clients."""
+     num_total_evaluation_examples = sum(num_examples for (num_examples, _) in results)
+     weighted_losses = [num_examples * loss for num_examples, loss in results]
+     return sum(weighted_losses) / num_total_evaluation_examples
+
+
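+ # Quick check (illustrative): weighted_loss_avg([(10, 1.0), (30, 3.0)]) returns
+ # (10 * 1.0 + 30 * 3.0) / 40 = 2.5, i.e. clients with more evaluation examples
+ # pull the aggregate loss harder.
+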
+ def aggregate_qffl(
+     parameters: NDArrays, deltas: list[NDArrays], hs_fll: list[NDArrays]
+ ) -> NDArrays:
+     """Compute weighted average based on the q-FFL paper."""
+     denominator: float = np.sum(np.asarray(hs_fll))
+     scaled_deltas = []
+     for client_delta in deltas:
+         scaled_deltas.append([layer * 1.0 / denominator for layer in client_delta])
+     updates = []
+     for i in range(len(deltas[0])):
+         tmp = scaled_deltas[0][i]
+         for j in range(1, len(deltas)):
+             tmp += scaled_deltas[j][i]
+         updates.append(tmp)
+     new_parameters = [(u - v) * 1.0 for u, v in zip(parameters, updates)]
+     return new_parameters
+
+
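+ # In q-FFL terms (sketch): the new global model is
+ #     w_{t+1} = w_t - (sum_k Delta_k) / (sum_k h_k)
+ # so each client's update Delta_k is normalized by the summed curvature-style
+ # estimates h_k rather than by example counts.
+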
+ def _compute_distances(weights: list[NDArrays]) -> NDArray:
+     """Compute distances between vectors.
+
+     Input: weights - list of weight vectors
+     Output: distance_matrix - matrix of squared distances between the vectors
+     """
+     flat_w = np.array([np.concatenate(p, axis=None).ravel() for p in weights])
+     distance_matrix = np.zeros((len(weights), len(weights)))
+     for i, flat_w_i in enumerate(flat_w):
+         for j, flat_w_j in enumerate(flat_w):
+             delta = flat_w_i - flat_w_j
+             norm = np.linalg.norm(delta)
+             distance_matrix[i, j] = norm**2
+     return distance_matrix
+
+
+ def _trim_mean(array: NDArray, proportiontocut: float) -> NDArray:
+     """Compute trimmed mean along axis=0.
+
+     It is based on the scipy implementation:
+     https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.trim_mean.html
+     """
+     axis = 0
+     nobs = array.shape[axis]
+     lowercut = int(proportiontocut * nobs)
+     uppercut = nobs - lowercut
+     if lowercut > uppercut:
+         raise ValueError("Proportion too big.")
+
+     atmp = np.partition(array, (lowercut, uppercut - 1), axis)
+
+     slice_list = [slice(None)] * atmp.ndim
+     slice_list[axis] = slice(lowercut, uppercut)
+     result: NDArray = np.mean(atmp[tuple(slice_list)], axis=axis)
+     return result
+
+
+ def aggregate_trimmed_avg(
+     results: list[tuple[NDArrays, int]], proportiontocut: float
+ ) -> NDArrays:
+     """Compute trimmed average."""
+     # Create a list of weights and ignore the number of examples
+     weights = [weights for weights, _ in results]
+
+     trimmed_w: NDArrays = [
+         _trim_mean(np.asarray(layer), proportiontocut=proportiontocut)
+         for layer in zip(*weights)
+     ]
+
+     return trimmed_w
+
+
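+ # Quick check (illustrative): with five clients reporting a scalar layer of
+ # 1, 2, 3, 4 and 100, a 20% trim drops the lowest and highest value:
+ #
+ #     results = [([np.array([v])], 1) for v in (1.0, 2.0, 3.0, 4.0, 100.0)]
+ #     aggregate_trimmed_avg(results, proportiontocut=0.2)  # -> [array([3.])]
+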
+ def _check_weights_equality(weights1: NDArrays, weights2: NDArrays) -> bool:
+     """Check if two sets of weights are the same."""
+     if len(weights1) != len(weights2):
+         return False
+     return all(
+         np.array_equal(layer_weights1, layer_weights2)
+         for layer_weights1, layer_weights2 in zip(weights1, weights2)
+     )
+
+
+ def _find_reference_weights(
+     reference_weights: NDArrays, list_of_weights: list[NDArrays]
+ ) -> int:
+     """Find the reference weights by looping through `list_of_weights`.
+
+     Raises an error if the reference weights are not found.
+
+     Parameters
+     ----------
+     reference_weights: NDArrays
+         Weights that will be searched for.
+     list_of_weights: list[NDArrays]
+         List of weights that will be searched through.
+
+     Returns
+     -------
+     index: int
+         The index of `reference_weights` in `list_of_weights`.
+
+     Raises
+     ------
+     ValueError
+         If `reference_weights` is not found in `list_of_weights`.
+     """
+     for idx, weights in enumerate(list_of_weights):
+         if _check_weights_equality(reference_weights, weights):
+             return idx
+     raise ValueError("The reference weights were not found in list_of_weights.")
+
+
+ def _aggregate_n_closest_weights(
+     reference_weights: NDArrays, results: list[tuple[NDArrays, int]], beta_closest: int
+ ) -> NDArrays:
+     """Calculate the element-wise mean of the `beta_closest` closest values.
+
+     Note: each i-th coordinate of the result is the average of the beta_closest
+     i-th coordinates that are nearest to the reference weights.
+
+     Parameters
+     ----------
+     reference_weights: NDArrays
+         The weights from which the distances will be computed.
+     results: list[tuple[NDArrays, int]]
+         The weights from the models.
+     beta_closest: int
+         The number of closest-distance weights that will be averaged.
+
+     Returns
+     -------
+     aggregated_weights: NDArrays
+         Element-wise average of the beta weights with the closest distance to
+         the reference weights.
+     """
+     list_of_weights = [weights for weights, num_examples in results]
+     aggregated_weights = []
+
+     for layer_id, layer_weights in enumerate(reference_weights):
+         other_weights_layer_list = []
+         for other_w in list_of_weights:
+             other_weights_layer = other_w[layer_id]
+             other_weights_layer_list.append(other_weights_layer)
+         other_weights_layer_np = np.array(other_weights_layer_list)
+         diff_np = np.abs(layer_weights - other_weights_layer_np)
+         # Create indices of the smallest differences; we do not need the exact
+         # order, just the beta closest weights, therefore np.argpartition is
+         # used instead of np.argsort
+         indices = np.argpartition(diff_np, kth=beta_closest - 1, axis=0)
+         # Take the weights (coordinate-wise) corresponding to the beta
+         # closest distances
+         beta_closest_weights = np.take_along_axis(
+             other_weights_layer_np, indices=indices, axis=0
+         )[:beta_closest]
+         aggregated_weights.append(np.mean(beta_closest_weights, axis=0))
+     return aggregated_weights
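+
+
+ # Coordinate-wise sketch (illustrative): with reference value 3 and candidate
+ # values 1, 2.5, 3.5 and 50, the two closest coordinates are 2.5 and 3.5, so
+ # the aggregate is 3.0 and the outlier 50 never contributes.
+ #
+ #     ref = [np.array([3.0])]
+ #     cands = [([np.array([v])], 1) for v in (1.0, 2.5, 3.5, 50.0)]
+ #     _aggregate_n_closest_weights(ref, cands, beta_closest=2)  # -> [array([3.])]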
template_FL/src/fedllm/myfedavg.py ADDED
@@ -0,0 +1,295 @@
+ # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Federated Averaging (FedAvg) [McMahan et al., 2016] strategy.
+
+ Paper: arxiv.org/abs/1602.05629
+ """
+
+
+ from logging import WARNING
+ from typing import Callable, Optional, Union
+
+ from flwr.common import (
+     EvaluateIns,
+     EvaluateRes,
+     FitIns,
+     FitRes,
+     MetricsAggregationFn,
+     NDArrays,
+     Parameters,
+     Scalar,
+     ndarrays_to_parameters,
+     parameters_to_ndarrays,
+ )
+ from flwr.common.logger import log
+ from flwr.server.client_manager import ClientManager
+ from flwr.server.client_proxy import ClientProxy
+ from flwr.server.strategy import Strategy
+
+ from .myaggregation import (
+     aggregate,
+     aggregate_inplace,
+     aggregate_inplace_mates,
+     weighted_loss_avg,
+ )
+
+ WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """
+ Setting `min_available_clients` lower than `min_fit_clients` or
+ `min_evaluate_clients` can cause the server to fail when there are too few clients
+ connected to the server. `min_available_clients` must be set to a value larger
+ than or equal to the values of `min_fit_clients` and `min_evaluate_clients`.
+ """
+
+ client_id_idx = {}
+
+
+ # pylint: disable=line-too-long
+ class FedAvg(Strategy):
+     """Federated Averaging strategy.
+
+     Implementation based on https://arxiv.org/abs/1602.05629
+
+     Parameters
+     ----------
+     fraction_fit : float, optional
+         Fraction of clients used during training. In case `min_fit_clients`
+         is larger than `fraction_fit * available_clients`, `min_fit_clients`
+         will still be sampled. Defaults to 1.0.
+     fraction_evaluate : float, optional
+         Fraction of clients used during validation. In case `min_evaluate_clients`
+         is larger than `fraction_evaluate * available_clients`,
+         `min_evaluate_clients` will still be sampled. Defaults to 1.0.
+     min_fit_clients : int, optional
+         Minimum number of clients used during training. Defaults to 2.
+     min_evaluate_clients : int, optional
+         Minimum number of clients used during validation. Defaults to 2.
+     min_available_clients : int, optional
+         Minimum number of total clients in the system. Defaults to 2.
+     evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]]
+         Optional function used for validation. Defaults to None.
+     on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional
+         Function used to configure training. Defaults to None.
+     on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional
+         Function used to configure validation. Defaults to None.
+     accept_failures : bool, optional
+         Whether or not to accept rounds containing failures. Defaults to True.
+     initial_parameters : Parameters, optional
+         Initial global model parameters.
+     fit_metrics_aggregation_fn : Optional[MetricsAggregationFn]
+         Metrics aggregation function, optional.
+     evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn]
+         Metrics aggregation function, optional.
+     inplace : bool (default: True)
+         Enable (True) or disable (False) in-place aggregation of model updates.
+     use_mates : bool (default: False)
+         Aggregate the main model and the data influence model separately
+         (see `aggregate_inplace_mates`).
+     """
+
+     # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
+     def __init__(
+         self,
+         *,
+         fraction_fit: float = 1.0,
+         fraction_evaluate: float = 1.0,
+         min_fit_clients: int = 2,
+         min_evaluate_clients: int = 2,
+         min_available_clients: int = 2,
+         evaluate_fn: Optional[
+             Callable[
+                 [int, NDArrays, dict[str, Scalar]],
+                 Optional[tuple[float, dict[str, Scalar]]],
+             ]
+         ] = None,
+         on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
+         on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
+         accept_failures: bool = True,
+         initial_parameters: Optional[Parameters] = None,
+         fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
+         evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
+         inplace: bool = True,
+         use_mates: bool = False,
+     ) -> None:
+         super().__init__()
+
+         if (
+             min_fit_clients > min_available_clients
+             or min_evaluate_clients > min_available_clients
+         ):
+             log(WARNING, WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW)
+
+         self.fraction_fit = fraction_fit
+         self.fraction_evaluate = fraction_evaluate
+         self.min_fit_clients = min_fit_clients
+         self.min_evaluate_clients = min_evaluate_clients
+         self.min_available_clients = min_available_clients
+         self.evaluate_fn = evaluate_fn
+         self.on_fit_config_fn = on_fit_config_fn
+         self.on_evaluate_config_fn = on_evaluate_config_fn
+         self.accept_failures = accept_failures
+         self.initial_parameters = initial_parameters
+         self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn
+         self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn
+         self.inplace = inplace
+         self.use_mates = use_mates
+
+     def __repr__(self) -> str:
+         """Compute a string representation of the strategy."""
+         rep = f"FedAvg(accept_failures={self.accept_failures})"
+         return rep
+
+     def num_fit_clients(self, num_available_clients: int) -> tuple[int, int]:
+         """Return the sample size and the required number of available clients."""
+         num_clients = int(num_available_clients * self.fraction_fit)
+         return max(num_clients, self.min_fit_clients), self.min_available_clients
+
+     def num_evaluation_clients(self, num_available_clients: int) -> tuple[int, int]:
+         """Use a fraction of available clients for evaluation."""
+         num_clients = int(num_available_clients * self.fraction_evaluate)
+         return max(num_clients, self.min_evaluate_clients), self.min_available_clients
+
+     def initialize_parameters(
+         self, client_manager: ClientManager
+     ) -> Optional[Parameters]:
+         """Initialize global model parameters."""
+         initial_parameters = self.initial_parameters
+         self.initial_parameters = None  # Don't keep initial parameters in memory
+         return initial_parameters
+
+     def evaluate(
+         self, server_round: int, parameters: Parameters
+     ) -> Optional[tuple[float, dict[str, Scalar]]]:
+         """Evaluate model parameters using an evaluation function."""
+         if self.evaluate_fn is None:
+             # No evaluation function provided
+             return None
+         parameters_ndarrays = parameters_to_ndarrays(parameters)
+         eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})
+         if eval_res is None:
+             return None
+         loss, metrics = eval_res
+         return loss, metrics
+
+     def configure_fit(
+         self, server_round: int, parameters: Parameters, client_manager: ClientManager
+     ) -> list[tuple[ClientProxy, FitIns]]:
+         """Configure the next round of training."""
+         config = {}
+         if self.on_fit_config_fn is not None:
+             # Custom fit config function provided
+             config = self.on_fit_config_fn(server_round)
+         fit_ins = FitIns(parameters, config)
+
+         # Remember a stable index for each client id
+         if not client_id_idx:
+             for i, (client_id, _) in enumerate(client_manager.clients.items()):
+                 client_id_idx[client_id] = i
+
+         # Sample clients
+         sample_size, min_num_clients = self.num_fit_clients(
+             client_manager.num_available()
+         )
+         clients = client_manager.sample(
+             num_clients=sample_size, min_num_clients=min_num_clients
+         )
+
+         # Return client/config pairs
+         return [(client, fit_ins) for client in clients]
+
+     def configure_evaluate(
+         self, server_round: int, parameters: Parameters, client_manager: ClientManager
+     ) -> list[tuple[ClientProxy, EvaluateIns]]:
+         """Configure the next round of evaluation."""
+         # Do not configure federated evaluation if fraction eval is 0.
+         if self.fraction_evaluate == 0.0:
+             return []
+
+         # Parameters and config
+         config = {}
+         if self.on_evaluate_config_fn is not None:
+             # Custom evaluation config function provided
+             config = self.on_evaluate_config_fn(server_round)
+         evaluate_ins = EvaluateIns(parameters, config)
+
+         # Sample clients
+         sample_size, min_num_clients = self.num_evaluation_clients(
+             client_manager.num_available()
+         )
+         clients = client_manager.sample(
+             num_clients=sample_size, min_num_clients=min_num_clients
+         )
+
+         # Return client/config pairs
+         return [(client, evaluate_ins) for client in clients]
+
+     def aggregate_fit(
+         self,
+         server_round: int,
+         results: list[tuple[ClientProxy, FitRes]],
+         failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
+     ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
+         """Aggregate fit results using weighted average."""
+         if not results:
+             return None, {}
+         # Do not aggregate if there are failures and failures are not accepted
+         if not self.accept_failures and failures:
+             return None, {}
+
+         if self.use_mates:
+             # MATES: aggregate the main model and the data influence model
+             # separately; checked first because `inplace` defaults to True
+             aggregated_ndarrays = aggregate_inplace_mates(results)
+         elif self.inplace:
+             # Does in-place weighted average of results
+             aggregated_ndarrays = aggregate_inplace(results)
+         else:
+             # Convert results
+             weights_results = [
+                 (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
+                 for _, fit_res in results
+             ]
+             aggregated_ndarrays = aggregate(weights_results)
+
+         parameters_aggregated = ndarrays_to_parameters(aggregated_ndarrays)
+
+         # Aggregate custom metrics if aggregation fn was provided
+         metrics_aggregated = {}
+         if self.fit_metrics_aggregation_fn:
+             fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
+             metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
+         elif server_round == 1:  # Only log this warning once
+             log(WARNING, "No fit_metrics_aggregation_fn provided")
+
+         return parameters_aggregated, metrics_aggregated
+
+     def aggregate_evaluate(
+         self,
+         server_round: int,
+         results: list[tuple[ClientProxy, EvaluateRes]],
+         failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
+     ) -> tuple[Optional[float], dict[str, Scalar]]:
+         """Aggregate evaluation losses using weighted average."""
+         if not results:
+             return None, {}
+         # Do not aggregate if there are failures and failures are not accepted
+         if not self.accept_failures and failures:
+             return None, {}
+
+         # Aggregate loss
+         loss_aggregated = weighted_loss_avg(
+             [
+                 (evaluate_res.num_examples, evaluate_res.loss)
+                 for _, evaluate_res in results
+             ]
+         )
+
+         # Aggregate custom metrics if aggregation fn was provided
+         metrics_aggregated = {}
+         if self.evaluate_metrics_aggregation_fn:
+             eval_metrics = [(res.num_examples, res.metrics) for _, res in results]
+             metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics)
+         elif server_round == 1:  # Only log this warning once
+             log(WARNING, "No evaluate_metrics_aggregation_fn provided")
+
+         return loss_aggregated, metrics_aggregated
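+
+
+ # Usage sketch (illustrative; argument values are placeholders): the MATES
+ # variant is enabled via `use_mates`, which routes aggregation through
+ # `aggregate_inplace_mates` above.
+ #
+ #     strategy = FedAvg(
+ #         fraction_fit=0.5,
+ #         min_available_clients=2,
+ #         use_mates=True,
+ #     )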
template_FL/src/fedllm/server_app.py ADDED
@@ -0,0 +1,309 @@
+ """flowertune-llm: A Flower / FlowerTune app."""
+
+ import os
+ import torch
+ import wandb
+ import numpy as np
+ from dotenv import load_dotenv
+ from datetime import datetime
+ from tqdm import tqdm
+
+ from transformers import DataCollatorForSeq2Seq, DataCollatorWithPadding, TrainingArguments, Trainer, GenerationConfig
+ from .trainer import ManualTrainer
+ from transformers.integrations import WandbCallback
+ from torch.utils.data import DataLoader
+ from flwr.common import Context, ndarrays_to_parameters
+ from flwr.common.config import unflatten_dict
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
+ # from flwr.server.strategy import FedAvg
+ from omegaconf import DictConfig
+
+ from .models import *
+ from .dataset import replace_keys
+ from .myfedavg import FedAvg
+ from .data_domains import global_test_set_hete
+ from .make_data import Prompter, generate_and_tokenize_prompt
+ from .metrics import exact_match, f1, get_rouge_score
+
+ from datasets import load_dataset, Dataset
+ from sklearn.model_selection import train_test_split
+
+
+ load_dotenv(".env")
+
+ os.environ["WANDB_API_KEY"] = os.getenv("WANDB_API_KEY")
+ os.environ["WANDB_NAME"] = os.getenv("WANDB_NAME")
+ os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN")
+ # os.environ["WANDB_LOG_MODEL"] = "checkpoint"
+
+
+ class LLMSampleCB(WandbCallback):
+     def __init__(self, trainer, test_dataset, task, num_samples=10, max_new_tokens=256, log_model="checkpoint"):
+         """A callback that logs sample generations as a wandb.Table during training."""
+         super().__init__()
+         # self._log_model = log_model
+         self.task = task
+         self.sample_dataset = test_dataset.shuffle().select(range(num_samples))
+         self.model, self.tokenizer = trainer.model, trainer.tokenizer
+         self.max_new_tokens = max_new_tokens
+         self.gen_config = GenerationConfig.from_pretrained(trainer.model.name_or_path,
+                                                            max_new_tokens=max_new_tokens)
+
+     def generate(self, prompt):
+         tokenized_prompt = self.tokenizer(
+             prompt,
+             # padding='max_length', max_length=self.max_new_tokens,
+             return_tensors='pt'
+         )
+         input_ids = tokenized_prompt['input_ids'].to('cuda:0')
+
+         with torch.inference_mode():
+             output = self.model.generate(input_ids, generation_config=self.gen_config)
+         # Strip the prompt tokens and decode only the generated continuation
+         return self.tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
+
+     def samples_table(self, examples):
+         """Create a wandb.Table to store the generations."""
+         records_table = wandb.Table(columns=["input", "prediction", "label", "task"] + list(self.gen_config.to_dict().keys()))
+         for example in tqdm(examples, leave=False):
+             instruction = example["instruction"]
+             inputt = example["input"]
+             output = example['output']
+             if inputt == '':
+                 prompt = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: {instruction} ### Response: """
+             else:
+                 prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. ### Instruction: {instruction} ### Input: {inputt} ### Response:"""
+
+             generation = self.generate(prompt=prompt)
+             records_table.add_data(prompt, generation, output, self.task, *list(self.gen_config.to_dict().values()))
+         return records_table
+
+     def on_evaluate(self, args, state, control, **kwargs):
+         """Log the wandb.Table after calling trainer.evaluate."""
+         super().on_evaluate(args, state, control, **kwargs)
+         records_table = self.samples_table(self.sample_dataset)
+         self._wandb.log({"sample_predictions": records_table})
+
+
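+ # Attaching the callback (illustrative; `trainer` and `test_ds` are assumed to
+ # exist at this point in a training script):
+ #
+ #     trainer.add_callback(LLMSampleCB(trainer, test_ds, task="general",
+ #                                      num_samples=5))
+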
+ def test_model(dataset, model, tokenizer, train_cfg, tmp_dict, sround, mates_args, task):
+
+     wandb.init(
+         project='FL@CSS25',
+         name=f'global_eval_round_{sround}',
+         id=f"round_{sround}",
+         resume="allow",
+         reinit=True,
+         # settings=wandb.Settings(start_method="thread")
+     )
+
+     def compute_metrics(pred):
+         labels_ids = pred['label_ids']
+         # Replace the ignored positions (-100) with a valid token id before decoding
+         labels_ids[labels_ids == -100] = 1829
+         pred_ids = pred['predictions']
+
+         # All unnecessary tokens are removed
+         pred_str = tokenizer.batch_decode(
+             pred_ids, skip_special_tokens=True
+         )
+         label_str = tokenizer.batch_decode(
+             labels_ids, skip_special_tokens=True
+         )
+         return {
+             **get_rouge_score(predictions=pred_str, targets=label_str),
+             **f1(predictions=pred_str, targets=label_str),
+         }
+
+     data_collator = DataCollatorForSeq2Seq(
+         tokenizer,
+         pad_to_multiple_of=8,
+         return_tensors="pt",
+         padding=True,
+     )
+
+     testset = (
+         dataset
+         .shuffle()
+         .map(
+             lambda x: generate_and_tokenize_prompt(x, **tmp_dict),
+             num_proc=8,
+         )
+     )
+
+     training_arguments = TrainingArguments(**train_cfg.training_arguments)
+     training_arguments.output_dir = './global_results'
+     training_arguments.logging_dir = './global_logs'
+     # training_arguments.run_name = f'global_eval_round_{sround}'
+
+     # # Construct baseline Trainer
+     # trainer = Trainer(
+     #     model=model,
+     #     eval_dataset=testset.select(range(10)),
+     #     args=training_arguments,
+     #     data_collator=data_collator,
+     #     compute_metrics=compute_metrics,
+     #     tokenizer=tokenizer
+     # )
+
+     mates_args.state = False
+
+     trainer = ManualTrainer(
+         model=model,
+         tokenizer=tokenizer,
+         train_dataset=None,
+         val_dataset=testset.select(range(10)),
+         holdout_dataset=None,
+         reference_dataset=None,
+         args=training_arguments,
+         data_collator=data_collator,
+         compute_metrics=compute_metrics,
+         mates_args=mates_args,
+         data_influence_model=None,
+         data_influence_tokenizer=None,
+     )
+
+     # Evaluate the global model on the test split
+     results = trainer.evaluate(wandb_sample=True)
+
+     # Extract the loss and metrics
+     eval_loss = results["eval_loss"]
+     eval_metrics = {
+         f'{task}_f1': results["f1"],
+         f'{task}_rouge1': results["rouge1"],
+         f'{task}_rouge2': results['rouge2'],
+         f'{task}_rougeL': results['rougeL'],
+         f'{task}_rougeLsum': results['rougeLsum'],
+     }
+
+     wandb.finish()
+
+     return eval_loss, eval_metrics
+
+
+ # Get the function that will be executed by the strategy's evaluate() method;
+ # here we use it to save global model checkpoints.
+ def get_evaluate_fn(train_cfg, model_cfg, dataset_cfg, save_every_round, total_round, total_nodes, save_path, mates_args):
+     """Return an evaluation function for saving the global model."""
+
+     def evaluate(server_round: int, parameters, config):
+         total_loss, result_metric = 0, {}
+         prompter = Prompter(train_cfg.prompt_template_name, train_cfg.verbose)
+         if server_round != 0 and (
+             server_round == total_round or server_round % save_every_round == 0
+         ):
+             # Init model: keep only the main-model half of the concatenated
+             # MATES parameters
+             main_model_params, _ = split_models(parameters)
+             model, tokenizer = get_model(model_cfg)
+             set_parameters(model, main_model_params)
+
+             tmp_dict = {
+                 "prompter": prompter,
+                 "seq_length": train_cfg.seq_length,
+                 "train_on_inputs": train_cfg.train_on_inputs,
+                 "tokenizer": tokenizer,
+             }
+             if dataset_cfg.type == 'homo':
+                 ds = load_dataset(dataset_cfg.name)
+                 _, test = train_test_split(
+                     ds, test_size=0.09, shuffle=True, random_state=42
+                 )
+                 global_test_set_homo = Dataset.from_pandas(test).remove_columns(['__index_level_0__'])
+
+                 loss, metrics = test_model(global_test_set_homo, model, tokenizer, train_cfg, tmp_dict, server_round, mates_args, 'homo')
+                 total_loss = loss
+                 result_metric = {'homo_f1': metrics['homo_f1']}
+             else:
+                 (
+                     list_loss, list_f1,
+                     list_rouge1, list_rouge2,
+                     list_rougeL, list_rougeLsum
+                 ) = [], {}, {}, {}, {}, {}
+
+                 for task in ['general', 'finance', 'math', 'medical', 'code']:
+                     ds = global_test_set_hete[task]
+                     loss, metrics = test_model(ds, model, tokenizer, train_cfg, tmp_dict, server_round, mates_args, task)
+                     list_loss.append(loss)
+
+                     list_f1[f'{task}_f1'] = metrics[f'{task}_f1']
+                     # list_rouge1[f'{task}_rouge1'] = metrics['rouge1']
+                     # list_rouge2[f'{task}_rouge2'] = metrics['rouge2']
+                     # list_rougeL[f'{task}_rougeL'] = metrics['rougeL']
+                     # list_rougeLsum[f'{task}_rougeLsum'] = metrics['rougeLsum']
+
+                 total_loss = sum(list_loss) / len(list_loss)
+                 avg_f1 = sum(list_f1.values()) / len(list_f1)
+                 result_metric = {**list_f1, 'avg_hete_f1': avg_f1}
+
+             model.save_pretrained(f"{save_path}/peft_{server_round}")
+
+         return total_loss, result_metric
+
+     return evaluate
+
+
+ def get_on_fit_config(save_path):
+     """Return a function that constructs the config that the client's fit()
+     method will receive."""
+
+     def fit_config_fn(server_round: int):
+         fit_config = {}
+         fit_config["current_round"] = server_round
+         fit_config["save_path"] = save_path
+         return fit_config
+
+     return fit_config_fn
+
+
+ def fit_weighted_average(metrics):
+     """Aggregate (federated) fit metrics."""
+     # Multiply the train loss of each client by the number of examples used
+     losses = [num_examples * m["train_loss"] for num_examples, m in metrics]
+     total_flops = [m["flops"] for num_examples, m in metrics]
+     examples = [num_examples for num_examples, _ in metrics]
+
+     # Aggregate and return custom metrics (weighted average)
+     return {"train_loss": round(sum(losses) / sum(examples), 3), "total_flops": f"{sum(total_flops)/1e12:.2f}T"}
+
+
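+ # Quick check (illustrative): fit_weighted_average([(10, {"train_loss": 1.0,
+ # "flops": 1e12}), (30, {"train_loss": 3.0, "flops": 3e12})]) returns
+ # {"train_loss": 2.5, "total_flops": "4.00T"}.
+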
+ def server_fn(context: Context):
+     """Construct components that set the ServerApp behaviour."""
+     # Create output directory given current timestamp
+     current_time = datetime.now()
+     folder_name = current_time.strftime("%Y-%m-%d_%H-%M-%S")
+     save_path = os.path.join(os.getcwd(), f"results/{folder_name}")
+     os.makedirs(save_path, exist_ok=True)
+
+     # Read from config
+     num_rounds = context.run_config["num-server-rounds"]
+     num_nodes = context.run_config['num-supernodes']
+     cfg = DictConfig(replace_keys(unflatten_dict(context.run_config)))
+
+     # Get initial model weights
+     init_model, tokenizer = get_model(cfg.model)
+     init_model_parameters = get_parameters(init_model)
+     init_model_parameters = ndarrays_to_parameters(init_model_parameters)
+
+     # Define strategy
+     strategy = FedAvg(
+         fraction_fit=cfg.train.strategy.fraction_fit,
+         fraction_evaluate=cfg.train.strategy.fraction_evaluate,
+         on_fit_config_fn=get_on_fit_config(save_path),
+         fit_metrics_aggregation_fn=fit_weighted_average,
+         initial_parameters=init_model_parameters,
+         evaluate_fn=get_evaluate_fn(
+             cfg.train, cfg.model, cfg.dataset, cfg.train.save_every_round, num_rounds, num_nodes, save_path, cfg.mates
+         ),
+         use_mates=cfg.mates.state,
+     )
+     config = ServerConfig(num_rounds=num_rounds)
+
+     return ServerAppComponents(strategy=strategy, config=config)
+
+
+ # Flower ServerApp
+ app = ServerApp(server_fn=server_fn)
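+
+ # The app is launched by the Flower runtime rather than run directly, e.g. with
+ # `flwr run .` from the project root (assuming the usual pyproject.toml app
+ # configuration).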
template_FL/src/fedllm/skipbert/__init__.py ADDED
File without changes
template_FL/src/fedllm/skipbert/modeling.py ADDED
@@ -0,0 +1,922 @@
+ """SkipBERT modeling."""
+
+ from __future__ import absolute_import, division, print_function, unicode_literals
+
+ import copy
+ import json
+ import math
+ import os
+ import sys
+ import time
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from typing import List, Optional, Tuple, Union
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+ import transformers
+ from transformers import BertPreTrainedModel, BertModel
+ from transformers.models.bert.modeling_bert import BertEmbeddings, BertEncoder, BertPooler, BertLayer
+ from transformers.models.bert.modeling_bert import BertPreTrainingHeads
+ from transformers.modeling_outputs import SequenceClassifierOutput
+ from . import plot
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, add_code_sample_docstrings
+
+ import logging
+ logger = logging.getLogger(__name__)
+
+ logger.warning('Hacking BertSelfAttention! Now it returns attention scores rather than probabilities.')
+
+
+ class BertSelfAttention(transformers.models.bert.modeling_bert.BertSelfAttention):
+
+     def forward(
+         self,
+         hidden_states,
+         attention_mask=None,
+         head_mask=None,
+         encoder_hidden_states=None,
+         encoder_attention_mask=None,
+         past_key_value=None,
+         output_attentions=False,
+     ):
+
+         device = hidden_states.device
+         mixed_query_layer = self.query(hidden_states)
+
+         # Most of this code is copied from transformers v4.3.3
+
+         is_cross_attention = encoder_hidden_states is not None
+
+         if is_cross_attention and past_key_value is not None:
+             # reuse k, v, cross attentions
+             key_layer = past_key_value[0]
+             value_layer = past_key_value[1]
+             attention_mask = encoder_attention_mask
+         elif is_cross_attention:
+             key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+             value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+             attention_mask = encoder_attention_mask
+         elif past_key_value is not None:
+             key_layer = self.transpose_for_scores(self.key(hidden_states))
+             value_layer = self.transpose_for_scores(self.value(hidden_states))
+             key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+             value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+         else:
+             key_layer = self.transpose_for_scores(self.key(hidden_states))
+             value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+         query_layer = self.transpose_for_scores(mixed_query_layer)
+
+         if self.is_decoder:
+             past_key_value = (key_layer, value_layer)
+
+         # Take the dot product between "query" and "key" to get the raw attention scores.
+         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+             seq_length = hidden_states.size()[1]
+             position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+             position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+             distance = position_ids_l - position_ids_r
+             positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+             positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
+
+             if self.position_embedding_type == "relative_key":
+                 relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                 attention_scores = attention_scores + relative_position_scores
+             elif self.position_embedding_type == "relative_key_query":
+                 relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                 relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+                 attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+         if attention_mask is not None:
+             # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
+             attention_scores = attention_scores.to(device) + attention_mask.to(device)
+
+         # Normalize the attention scores to probabilities.
+         attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+         # This is actually dropping out entire tokens to attend to, which might
+         # seem a bit unusual, but is taken from the original Transformer paper.
+         attention_probs = self.dropout(attention_probs)
+
+         # Mask heads if we want to
+         if head_mask is not None:
+             attention_probs = attention_probs * head_mask
+             # attention_scores = attention_scores * head_mask
+
+         context_layer = torch.matmul(attention_probs, value_layer)
+
+         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+         context_layer = context_layer.view(*new_context_layer_shape)
+
+         # hacked: return attention_scores instead of attention_probs
+         outputs = (context_layer, attention_scores) if output_attentions else (context_layer,)
+
+         if self.is_decoder:
+             outputs = outputs + (past_key_value,)
+         return outputs
+
+
+ transformers.models.bert.modeling_bert.BertSelfAttention = BertSelfAttention
+
+
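+ # Note: the assignment above monkey-patches the module-level class inside
+ # `transformers`, so any BertModel constructed *after* this point builds its
+ # layers with the patched attention (in the transformers versions this file
+ # targets, BertAttention looks the class up by its module-level name). Models
+ # created before the patch keep the original class.
+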
+ class BertForPreTraining(BertPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         fit_size = getattr(config, 'fit_size', 768)
+         self.bert = BertModel(config)
+         self.cls = BertPreTrainingHeads(config)
+         self.fit_denses = nn.ModuleList(
+             [nn.Linear(config.hidden_size, fit_size) for _ in range(config.num_hidden_layers + 1)]
+         )
+
+     def forward(self, input_ids, token_type_ids=None,
+                 attention_mask=None, masked_lm_labels=None,
+                 next_sentence_label=None, labels=None,
+                 output_attentions=True, output_hidden_states=True,):
+         outputs = self.bert(
+             input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
+             output_attentions=output_attentions, output_hidden_states=output_hidden_states)
+         sequence_output, att_output, pooled_output = outputs.hidden_states, outputs.attentions, outputs.pooler_output
+         # Project every hidden state to the fit width (e.g. a teacher's 768)
+         tmp = []
+         for s_id, sequence_layer in enumerate(sequence_output):
+             tmp.append(self.fit_denses[s_id](sequence_layer))
+         sequence_output = tmp
+
+         return att_output, sequence_output
+
+
+
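+ # Distillation sketch (illustrative): the projected hidden states returned
+ # above are typically matched layer-by-layer against a teacher with an MSE
+ # loss, in the TinyBERT style. Assuming `student` is a BertForPreTraining and
+ # `teacher_hidden` is a matching list of teacher hidden states:
+ #
+ #     student_atts, student_reps = student(input_ids, attention_mask=mask)
+ #     rep_loss = sum(
+ #         torch.nn.functional.mse_loss(s, t)
+ #         for s, t in zip(student_reps, teacher_hidden)
+ #     )
+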
+ # class BertForSequenceClassification(BertPreTrainedModel):
+ #     def __init__(self, config, do_fit=False, share_param=True):
+ #         super().__init__(config)
+ #         num_labels = config.num_labels
+ #         self.hidden_size = config.hidden_size
+ #         self.num_labels = num_labels
+ #         self.bert = BertModel(config)
+ #         self.dropout = nn.Dropout(
+ #             config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob)
+ #         self.classifier = nn.Linear(config.hidden_size, num_labels)
+ #
+ #         self.do_fit, self.share_param = do_fit, share_param
+ #         if self.do_fit:
+ #             fit_size = getattr(config, 'fit_size', 768)
+ #             self.fit_size = fit_size
+ #             if self.share_param:
+ #                 self.fit_dense = nn.Linear(config.hidden_size, fit_size)
+ #             else:
+ #                 self.fit_denses = nn.ModuleList(
+ #                     [nn.Linear(config.hidden_size, fit_size) for _ in range(config.num_hidden_layers + 1)]
+ #                 )
+ #
+ #     def do_fit_dense(self, sequence_output):
+ #         tmp = []
+ #         if self.do_fit:
+ #             for s_id, sequence_layer in enumerate(sequence_output):
+ #                 if self.share_param:
+ #                     tmp.append(self.fit_dense(sequence_layer))
+ #                 else:
+ #                     tmp.append(self.fit_denses[s_id](sequence_layer))
+ #             sequence_output = tmp
+ #         return sequence_output
+ #
+ #     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
+ #         outputs = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
+ #                             output_hidden_states=True, output_attentions=True)
+ #         sequence_output, att_output, pooled_output = outputs.hidden_states, outputs.attentions, outputs.pooler_output
+ #         logits = self.classifier(pooled_output)
+ #         sequence_output = self.do_fit_dense(sequence_output)
+ #         return logits, att_output, sequence_output
+
+
+ # SequenceClassification docstring
+ _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-polarity"
+ _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
+ _SEQ_CLASS_EXPECTED_LOSS = 0.01
+ _CONFIG_FOR_DOC = "BertConfig"
+
+
+ BERT_START_DOCSTRING = r"""
+
+     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+     etc.)
+
+     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+     and behavior.
+
+     Parameters:
+         config ([`BertConfig`]): Model configuration class with all the parameters of the model.
+             Initializing with a config file does not load the weights associated with the model, only the
+             configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+ """
+
+
+ BERT_INPUTS_DOCSTRING = r"""
+     Args:
+         input_ids (`torch.LongTensor` of shape `({0})`):
+             Indices of input sequence tokens in the vocabulary.
+
+             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+             [`PreTrainedTokenizer.__call__`] for details.
+
+             [What are input IDs?](../glossary#input-ids)
+         attention_mask (`torch.FloatTensor` of shape `({0})` or `(batch_size, sequence_length, target_length)`, *optional*):
+             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+             - 1 for tokens that are **not masked**,
+             - 0 for tokens that are **masked**.
+
+             [What are attention masks?](../glossary#attention-mask)
+         token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+             Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+             1]`:
+
+             - 0 corresponds to a *sentence A* token,
+             - 1 corresponds to a *sentence B* token.
+
+             [What are token type IDs?](../glossary#token-type-ids)
+         position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+             Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+             config.max_position_embeddings - 1]`.
+
+             [What are position IDs?](../glossary#position-ids)
+         head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+             Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+             - 1 indicates the head is **not masked**,
+             - 0 indicates the head is **masked**.
+
+         inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
+             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+             is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+             model's internal embedding lookup matrix.
+         output_attentions (`bool`, *optional*):
+             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+             tensors for more detail.
+         output_hidden_states (`bool`, *optional*):
+             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+             more detail.
+         return_dict (`bool`, *optional*):
+             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+
+
+ @add_start_docstrings(
+     """
+     Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
+     output) e.g. for GLUE tasks.
+     """,
+     BERT_START_DOCSTRING,
+ )
+ class BertForSequenceClassification(BertPreTrainedModel):
+     def __init__(self, config, do_fit=False, share_param=True):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.config = config
+         self.do_fit = do_fit
+         self.share_param = share_param
+
+         self.bert = BertModel(config)
+         classifier_dropout = (
+             config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+         )
+         self.dropout = nn.Dropout(classifier_dropout)
+         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+         # Add fit layers if enabled
+         if self.do_fit:
+             fit_size = getattr(config, 'fit_size', 768)
+             if self.share_param:
+                 self.fit_dense = nn.Linear(config.hidden_size, fit_size)
+             else:
+                 self.fit_denses = nn.ModuleList(
+                     [nn.Linear(config.hidden_size, fit_size) for _ in range(config.num_hidden_layers + 1)]
+                 )
+
+         self.post_init()
+
+     def do_fit_dense(self, hidden_states):
+         """Process hidden states through the fit layers if enabled."""
+         if not self.do_fit:
+             return hidden_states
+
+         processed_states = []
+         for layer_idx, state in enumerate(hidden_states):
+             if self.share_param:
+                 processed_states.append(self.fit_dense(state))
+             else:
+                 processed_states.append(self.fit_denses[layer_idx](state))
+         return processed_states
+
+     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+     @add_code_sample_docstrings(
+         checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
+         output_type=SequenceClassifierOutput,
+         config_class=_CONFIG_FOR_DOC,
+         expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+         expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
+     )
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # Force output hidden states if fit layers are enabled
+         if self.do_fit:
+             output_hidden_states = True
+
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         pooled_output = outputs[1]
+         pooled_output = self.dropout(pooled_output)
+         logits = self.classifier(pooled_output)
+
+         # Process hidden states through fit layers
+         hidden_states = outputs.hidden_states
+         if self.do_fit and hidden_states is not None:
+             hidden_states = self.do_fit_dense(hidden_states)
+
+         loss = None
+         if labels is not None:
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "regression":
+                 loss_fct = MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels)
+
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             # Replace the original hidden states with the processed ones
+             if self.do_fit and hidden_states is not None:
+                 output = (logits,) + (hidden_states,) + outputs[3:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=hidden_states if output_hidden_states else None,
+             attentions=outputs.attentions,
+         )
+
+
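+ # Usage sketch (illustrative): `do_fit` adds projection layers for
+ # distillation; extra keyword arguments not consumed by the config are
+ # forwarded to __init__ by from_pretrained.
+ #
+ #     model = BertForSequenceClassification.from_pretrained(
+ #         "bert-base-uncased", num_labels=2, do_fit=True, share_param=True
+ #     )
+ #     out = model(input_ids=ids, attention_mask=mask, labels=y)
+ #     out.loss, out.logits, out.hidden_states  # projected states when do_fit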
417
+
418
+ class BertForSequenceClassificationPrediction(BertForSequenceClassification):
419
+ def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
420
+
421
+ assert not self.training
422
+
423
+ _, pooled_output, sequence_output, att_output = self.bert(
424
+ input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
425
+ output_hidden_states=True, output_attentions=True)
426
+
427
+ logits = self.classifier(pooled_output)
428
+
429
+ loss = None
430
+ if labels is not None:
431
+ loss = torch.tensor(0.)
432
+
433
+ return SequenceClassifierOutput(
434
+ loss=loss,
435
+ logits=logits,
436
+ )
437
+
438
+
439
+ class ShallowSkipping(nn.Module):
440
+
441
+ def __init__(self, model):
442
+ super().__init__()
443
+ # self.model = model # do not register
444
+ self.config = model.config
445
+ self.shallow_config = model.shallow_config
446
+ # current only support trigram
447
+ self.ngram = 3
448
+
449
+ if self.shallow_config.hidden_size != self.config.hidden_size:
450
+ self.linear = nn.Linear(self.shallow_config.hidden_size, self.config.hidden_size)
451
+
452
+ self.plot = plot.Plot(self.config.max_num_entries, self.config.hidden_size)
453
+
454
+ def _build_tri_gram_ids(self, input_ids:torch.Tensor) -> torch.Tensor:
455
+ return torch.from_numpy(
456
+ self.plot.input_ids_to_tri_grams(input_ids.cpu().numpy())
457
+ ).to(input_ids.device)
458
+
459
+ def build_input_ngrams(self, input_ids:torch.Tensor, token_type_ids:torch.Tensor):
460
+
461
+ input_ngram_ids = self._build_tri_gram_ids(input_ids)
462
+
463
+ token_ngram_type_ids = None  # token type ids are not used at the n-gram level
464
+
465
+ attention_mask = (input_ngram_ids > 0).float()
466
+
467
+ if self.training:
468
+ _mask = torch.rand(attention_mask.shape).to(attention_mask.device)
469
+ _mask = (_mask > self.config.ngram_masking)
470
+ attention_mask *= _mask
471
+
472
+ attention_mask[:, self.ngram//2] = 1 # avoid masking all tokens in a tri-gram
473
+ return input_ngram_ids, token_ngram_type_ids, attention_mask
474
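+ # Editorial note: during training each tri-gram position is dropped
+ # independently with probability config.ngram_masking (a light dropout over
+ # contexts), and the center position is then restored so a token never
+ # loses its own embedding.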
+
475
+ @torch.jit.script
476
+ def merge_ngrams(input_ids, ngram_hidden_states, aux_embeddings):
477
+ batch_size, seq_length = input_ids.shape
478
+ lens = (input_ids!=0).sum(1)
479
+ hidden_state = torch.zeros([batch_size, seq_length, ngram_hidden_states.size(-1)], dtype=ngram_hidden_states.dtype, device=ngram_hidden_states.device)
480
+
481
+ # assert to be trigrams
482
+ flat_hidden_state = ngram_hidden_states[:, 1]
483
+ flat_hidden_state[:-1] = flat_hidden_state[:-1] + ngram_hidden_states[1:, 0]
484
+ flat_hidden_state[1:] = flat_hidden_state[1:] + ngram_hidden_states[:-1, 2]
485
+ k = 0
486
+ for i in range(batch_size):
487
+ hidden_state[i, :lens[i]] = flat_hidden_state[k: k+lens[i]]
488
+ k += 1 + lens[i] # 1 for skipping one padding tri-gram
489
+ hidden_state = hidden_state + aux_embeddings
490
+ return hidden_state
491
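+ # Editorial notes on merge_ngrams: it is compiled as a TorchScript free
+ # function (it takes no `self`), so self.merge_ngrams(ids, states, embs)
+ # passes exactly these three arguments. Worked example: for token i, the
+ # merged state is the center vector of its own tri-gram plus the left
+ # context vector of tri-gram i+1 and the right context vector of tri-gram
+ # i-1, i.e. h[i] = s[i, 1] + s[i+1, 0] + s[i-1, 2], before the auxiliary
+ # position/token-type embeddings are added.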
+
492
+ def forward_shallow_layers(
493
+ self,
494
+ input_ids,
495
+ token_type_ids,
496
+ attention_mask,
497
+ ngram_mask_position=None,
498
+ head_mask=None,
499
+ encoder_hidden_states=None,
500
+ encoder_attention_mask=None,
501
+ past_key_value=None,
502
+ output_attentions=True,
503
+ output_hidden_states=True,
504
+ model=None,
505
+ ):
506
+ device = model.device
507
+
508
+ input_ngram_ids, token_ngram_type_ids, attention_mask = self.build_input_ngrams(input_ids, token_type_ids)
509
+ ngram_attention_mask = attention_mask.clone()
510
+
511
+ if ngram_mask_position is not None:
512
+ input_ngram_ids[:, ngram_mask_position] = 0
513
+ ngram_attention_mask[:, ngram_mask_position] = 0
514
+
515
+ extended_attention_mask = model.get_extended_attention_mask(attention_mask, input_ngram_ids.shape, device)
516
+
517
+ ngram_index = (input_ngram_ids[:, self.ngram//2] > 0)  # rows whose center token is non-pad
518
+
519
+ embedding_output = model.embeddings(input_ids=input_ngram_ids, token_type_ids=token_ngram_type_ids)
520
+
521
+ hidden_states = embedding_output
522
+ attention_mask = extended_attention_mask
523
+
524
+ for i, layer_module in enumerate(
525
+ model.encoder.layer[:self.config.num_hidden_layers - self.config.num_full_hidden_layers]):
526
+ layer_head_mask = head_mask[i] if head_mask is not None else None
527
+
528
+ layer_outputs = layer_module(
529
+ hidden_states=hidden_states,
530
+ attention_mask=attention_mask,
531
+ head_mask=layer_head_mask,
532
+ encoder_hidden_states=encoder_hidden_states,
533
+ encoder_attention_mask=encoder_attention_mask,
534
+ past_key_value=past_key_value,
535
+ output_attentions=output_attentions,
536
+ )
537
+
538
+ hidden_states = layer_outputs[0]
539
+
540
+ if self.shallow_config.hidden_size != self.config.hidden_size:
541
+ hidden_states = self.linear(hidden_states)
542
+
543
+ # Set zero the padding ngrams: (..., [PAD], ...)
544
+ hidden_states = hidden_states * ngram_index[:, None, None]
545
+
546
+ hidden_states = hidden_states * model.attn(hidden_states).sigmoid() * ngram_attention_mask.unsqueeze(-1)
547
+
548
+ return input_ngram_ids, hidden_states
549
+
550
+ def forward(
551
+ self,
552
+ input_ids,
553
+ token_type_ids,
554
+ attention_mask,
555
+ head_mask=None,
556
+ encoder_hidden_states=None,
557
+ encoder_attention_mask=None,
558
+ past_key_value=None,
559
+ output_attentions=True,
560
+ output_hidden_states=True,
561
+ model=None,
562
+ ):
563
+
564
+ device = model.device
565
+
566
+ batch_size, seq_length = input_ids.shape
567
+ aux_embeddings = model.embeddings.position_embeddings2.weight[:seq_length].unsqueeze(0)
568
+ aux_embeddings = aux_embeddings + model.embeddings.token_type_embeddings2(token_type_ids)
569
+
570
+ if self.config.plot_mode == 'force_compute':
571
+ '''
572
+ compute only, ignore PLOT
573
+ '''
574
+ input_ngram_ids, hidden_states = self.forward_shallow_layers(
575
+ input_ids=input_ids,
576
+ token_type_ids=token_type_ids,
577
+ attention_mask=attention_mask,
578
+ head_mask=head_mask,
579
+ encoder_hidden_states=encoder_hidden_states,
580
+ encoder_attention_mask=encoder_attention_mask,
581
+ past_key_value=past_key_value,
582
+ output_attentions=output_attentions,
583
+ output_hidden_states=output_hidden_states,
584
+ ngram_mask_position=None,
585
+ model=model,
586
+ )
587
+
588
+ elif self.config.plot_mode == 'update_all':
589
+ '''
590
+ build PLOT
591
+ '''
592
+ # uni-grams
593
+ input_ngram_ids, hidden_states = self.forward_shallow_layers(
594
+ input_ids=input_ids,
595
+ token_type_ids=token_type_ids,
596
+ attention_mask=attention_mask,
597
+ head_mask=head_mask,
598
+ encoder_hidden_states=encoder_hidden_states,
599
+ encoder_attention_mask=encoder_attention_mask,
600
+ past_key_value=past_key_value,
601
+ output_attentions=output_attentions,
602
+ output_hidden_states=output_hidden_states,
603
+ ngram_mask_position=(0,2),
604
+ model=model,
605
+ )
606
+ self.plot.update_data(input_ngram_ids, hidden_states)
607
+
608
+ # bi-grams
609
+ input_ngram_ids, hidden_states = self.forward_shallow_layers(
610
+ input_ids=input_ids,
611
+ token_type_ids=token_type_ids,
612
+ attention_mask=attention_mask,
613
+ head_mask=head_mask,
614
+ encoder_hidden_states=encoder_hidden_states,
615
+ encoder_attention_mask=encoder_attention_mask,
616
+ past_key_value=past_key_value,
617
+ output_attentions=output_attentions,
618
+ output_hidden_states=output_hidden_states,
619
+ ngram_mask_position=0,
620
+ model=model,
621
+ )
622
+ self.plot.update_data(input_ngram_ids, hidden_states)
623
+
624
+ # tri-grams
625
+ input_ngram_ids, hidden_states = self.forward_shallow_layers(
626
+ input_ids=input_ids,
627
+ token_type_ids=token_type_ids,
628
+ attention_mask=attention_mask,
629
+ head_mask=head_mask,
630
+ encoder_hidden_states=encoder_hidden_states,
631
+ encoder_attention_mask=encoder_attention_mask,
632
+ past_key_value=past_key_value,
633
+ output_attentions=output_attentions,
634
+ output_hidden_states=output_hidden_states,
635
+ ngram_mask_position=None,
636
+ model=model,
637
+ )
638
+ self.plot.update_data(input_ngram_ids, hidden_states)
639
+
640
+ elif self.config.plot_mode == 'plot_passive':
641
+ '''
642
+ use plot if no oov
643
+ '''
644
+
645
+ if input_ids.is_cuda:
646
+ input_ids = input_ids.cpu()
647
+ if not self.plot.has_oov(input_ids):
648
+ hidden_states = self.plot.retrieve_data(input_ids)
649
+ hidden_states = hidden_states.to(device)
650
+ else:
651
+ input_ids = input_ids.to(device)
652
+ input_ngram_ids, hidden_states = self.forward_shallow_layers(
653
+ input_ids=input_ids,
654
+ token_type_ids=token_type_ids,
655
+ attention_mask=attention_mask,
656
+ head_mask=head_mask,
657
+ encoder_hidden_states=encoder_hidden_states,
658
+ encoder_attention_mask=encoder_attention_mask,
659
+ past_key_value=past_key_value,
660
+ output_attentions=output_attentions,
661
+ output_hidden_states=output_hidden_states,
662
+ ngram_mask_position=None,
663
+ model=model,
664
+ )
665
+ self.plot.update_data(input_ngram_ids, hidden_states)
666
+
667
+ elif self.config.plot_mode == 'plot_only':
668
+ '''
669
+ plot only
670
+ looking up order: trigram -> bigram -> unigram -> 0
671
+ '''
672
+ if input_ids.is_cuda:
673
+ logger.warning("'input_ids' should be placed on the CPU for PLOT lookup.")
674
+ input_ids = input_ids.cpu()
675
+ hidden_states = self.plot.retrieve_data(input_ids)
676
+ hidden_states = hidden_states.to(device)
677
+
678
+ hidden_states = F.dropout(hidden_states, self.config.hidden_dropout_prob, self.training)
679
+ hidden_states = self.merge_ngrams(input_ids, hidden_states, aux_embeddings)
680
+ hidden_states = model.norm(hidden_states)
681
+
682
+ return hidden_states
683
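+ # Editorial summary of the plot modes handled above: 'force_compute' always
+ # runs the shallow layers and ignores the PLOT; 'update_all' computes uni-,
+ # bi- and tri-gram representations and writes them all into the PLOT;
+ # 'plot_passive' serves lookups from the PLOT and falls back to computation
+ # (plus a PLOT update) when the batch contains an unseen n-gram; 'plot_only'
+ # serves lookups exclusively from the PLOT.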
+
684
+
685
+ class SkipBertEncoder(BertEncoder):
686
+ def __init__(self, shallow_config, config):
687
+ super(BertEncoder, self).__init__()
688
+ self.config = config
689
+ self.shallow_config = shallow_config
690
+ self.layer = nn.ModuleList(
691
+ [
692
+ BertLayer(shallow_config) for _ in range(config.num_hidden_layers - config.num_full_hidden_layers)
693
+ ] + [
694
+ BertLayer(config) for _ in range(config.num_full_hidden_layers)
695
+ ])
696
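+ # Editorial note: super(BertEncoder, self).__init__() deliberately skips
+ # BertEncoder.__init__ so its default layer list is never built. The first
+ # (num_hidden_layers - num_full_hidden_layers) layers use the shallow
+ # config and run per tri-gram (their outputs are cacheable in the PLOT);
+ # only the last num_full_hidden_layers layers attend over the full sequence.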
+
697
+ class SkipBertModel(BertModel):
698
+ def __init__(self, config, add_pooling_layer=True):
699
+ super().__init__(config)
700
+ self.config = config
701
+ self.shallow_config = copy.deepcopy(config)
702
+
703
+ self.shallow_config.hidden_size = getattr(config, 'shallow_hidden_size', 768)
704
+ self.shallow_config.intermediate_size = getattr(config, 'shallow_intermediate_size', 3072)
705
+
706
+ self.embeddings = BertEmbeddings(self.shallow_config)
707
+ self.encoder = SkipBertEncoder(self.shallow_config, config)
708
+
709
+ self.pooler = BertPooler(config) if add_pooling_layer else None
710
+
711
+ self.embeddings.position_embeddings2 = nn.Embedding(self.config.max_position_embeddings, self.config.hidden_size)
712
+ self.embeddings.token_type_embeddings2 = nn.Embedding(self.config.type_vocab_size, self.config.hidden_size)
713
+
714
+ self.norm = nn.LayerNorm(self.config.hidden_size)
715
+ self.attn = nn.Linear(self.config.hidden_size, 1)
716
+ self.shallow_skipping = ShallowSkipping(self)
717
+
718
+ self.init_weights()
719
+
720
+ def forward(
721
+ self,
722
+ input_ids=None,
723
+ attention_mask=None,
724
+ token_type_ids=None,
725
+ position_ids=None,
726
+ head_mask=None,
727
+ encoder_hidden_states=None,
728
+ encoder_attention_mask=None,
729
+ output_attentions=True,
730
+ output_hidden_states=True,
731
+ ):
732
+
733
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
734
+ output_hidden_states = (
735
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
736
+ )
737
+
738
+ input_shape = input_ids.size()
739
+ device = self.device
740
+
741
+ if attention_mask is None:
742
+ attention_mask = (input_ids != 0).float()
743
+ if token_type_ids is None:
744
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
745
+
746
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
747
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
748
+
749
+ hidden_states = self.shallow_skipping(
750
+ input_ids=input_ids,
751
+ token_type_ids=token_type_ids,
752
+ attention_mask=attention_mask,
753
+ head_mask=head_mask,
754
+ encoder_hidden_states=encoder_hidden_states,
755
+ encoder_attention_mask=encoder_attention_mask,
756
+ model=self,
757
+ )
758
+
759
+ # Global transformer layers
760
+ attention_mask = extended_attention_mask
761
+
762
+ all_hidden_states = ()
763
+ all_self_attentions = ()
764
+
765
+ for i, layer_module in enumerate(self.encoder.layer[-self.config.num_full_hidden_layers:]):
766
+
767
+ if output_hidden_states:
768
+ all_hidden_states = all_hidden_states + (hidden_states,)
769
+
770
+ layer_head_mask = head_mask[i + self.config.num_hidden_layers - self.config.num_full_hidden_layers] if head_mask is not None else None
771
+
772
+ layer_outputs = layer_module(
773
+ hidden_states=hidden_states,
774
+ attention_mask=attention_mask,
775
+ head_mask=layer_head_mask,
776
+ encoder_hidden_states=encoder_hidden_states,
777
+ encoder_attention_mask=encoder_attention_mask,
778
+ past_key_value=None,
779
+ output_attentions=output_attentions,
780
+ )
781
+
782
+ hidden_states = layer_outputs[0]
783
+
784
+ if output_attentions:
785
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
786
+
787
+ if output_hidden_states:
788
+ all_hidden_states = all_hidden_states + (hidden_states,)
789
+
790
+ sequence_output = hidden_states
791
+ pooled_output = self.pooler(sequence_output)
792
+
793
+ return (sequence_output, pooled_output, all_hidden_states, all_self_attentions)
794
+
795
+ def freeze_shallow_layers(self):
796
+ for p in self.embeddings.parameters():
797
+ p.requires_grad = False
798
+ for layer in self.encoder.layer[:self.config.num_hidden_layers - self.config.num_full_hidden_layers]:
799
+ for p in layer.parameters():
800
+ p.requires_grad = False
801
+ try:
802
+ for p in self.shallow_skipping.linear.parameters():
803
+ p.requires_grad = False
804
+ except AttributeError:  # no dimension-matching linear layer to freeze
805
+ pass
806
+ try:
807
+ for p in self.attn.parameters():
808
+ p.requires_grad = False
809
+ except AttributeError:  # no gating layer to freeze
810
+ pass
811
+
812
+ self.embeddings.dropout.p = 0.
813
+ for layer in self.encoder.layer[:self.config.num_hidden_layers - self.config.num_full_hidden_layers]:
814
+ for m in layer.modules():
815
+ if isinstance(m, torch.nn.Dropout):
816
+ m.p = 0.
817
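+ # Minimal usage sketch (editorial, not part of the original commit; the
+ # config fields below are assumptions based on how they are read above):
+ #
+ #   from transformers import BertConfig
+ #   config = BertConfig.from_pretrained('bert-base-uncased')
+ #   config.num_full_hidden_layers = 6
+ #   config.max_num_entries = 100000
+ #   config.plot_mode = 'force_compute'
+ #   config.ngram_masking = 0.1
+ #   model = SkipBertModel(config)
+ #   model.freeze_shallow_layers()  # only the global layers stay trainable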
+
818
+
819
+ class SkipBertForPreTraining(BertPreTrainedModel):
820
+ def __init__(self, config):
821
+ super().__init__(config)
822
+ self.fit_size = getattr(config, 'fit_size', 768)
823
+ self.bert = SkipBertModel(config)
824
+ self.cls = BertPreTrainingHeads(config)
825
+
826
+ if self.fit_size != config.hidden_size:
827
+ self.fit_denses = nn.ModuleList(
828
+ [nn.Linear(config.hidden_size, self.fit_size) for _ in range(config.num_hidden_layers + 1)]
829
+ )
830
+
831
+ def forward(self, input_ids, token_type_ids=None,
832
+ attention_mask=None, masked_lm_labels=None,
833
+ next_sentence_label=None, labels=None,
834
+ output_attentions=True, output_hidden_states=True,):
835
+ _, pooled_output, sequence_output, att_output = self.bert(
836
+ input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
837
+ output_attentions=output_attentions, output_hidden_states=output_hidden_states)
838
+
839
+ if self.fit_size != self.config.hidden_size:
840
+ tmp = []
841
+ for s_id, sequence_layer in enumerate(sequence_output):
842
+ tmp.append(self.fit_denses[s_id](sequence_layer))
843
+ sequence_output = tmp
844
+
845
+ return att_output, sequence_output
846
+
847
+
848
+ def freeze_shallow_layers(self):
849
+ self.bert.freeze_shallow_layers()
850
+
851
+
852
+ class SkipBertForSequenceClassification(BertPreTrainedModel):
853
+ def __init__(self, config, do_fit=False, share_param=True):
854
+ super().__init__(config)
855
+ num_labels = config.num_labels
856
+ self.hidden_size = config.hidden_size
857
+ self.num_labels = num_labels
858
+ self.bert = SkipBertModel(config)
859
+ self.dropout = nn.Dropout(
860
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob)
861
+ self.classifier = nn.Linear(config.hidden_size, num_labels)
862
+
863
+ self.do_fit, self.share_param = do_fit, share_param
864
+ if self.do_fit:
865
+ fit_size = getattr(config, 'fit_size', 768)
866
+ self.fit_size = fit_size
867
+ if self.share_param:
868
+ self.share_fit_dense = nn.Linear(config.hidden_size, fit_size)
869
+ else:
870
+ self.fit_denses = nn.ModuleList(
871
+ [nn.Linear(config.hidden_size, fit_size) for _ in range(config.num_hidden_layers + 1)]
872
+ )
873
+
874
+ def do_fit_dense(self, sequence_output):
875
+
876
+ tmp = []
877
+ if self.do_fit:
878
+ for s_id, sequence_layer in enumerate(sequence_output):
879
+ if self.share_param:
880
+ tmp.append(self.share_fit_dense(sequence_layer))
881
+ else:
882
+ tmp.append(self.fit_denses[s_id](sequence_layer))
883
+ sequence_output = tmp
884
+
885
+ return sequence_output
886
+
887
+ def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
888
+
889
+ _, pooled_output, sequence_output, att_output = self.bert(
890
+ input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
891
+ output_hidden_states=True, output_attentions=True)
892
+
893
+ sequence_output = self.do_fit_dense(sequence_output)
894
+
895
+ pooled_output = self.dropout(pooled_output)
896
+ logits = self.classifier(pooled_output)
897
+
898
+ return logits, att_output, sequence_output
899
+
900
+ def freeze_shallow_layers(self):
901
+ self.bert.freeze_shallow_layers()
902
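+ # Fine-tuning sketch (editorial; assumes a config prepared as for
+ # SkipBertModel above):
+ #
+ #   model = SkipBertForSequenceClassification(config, do_fit=False)
+ #   model.freeze_shallow_layers()  # keep the PLOT-backed shallow stack fixed
+ #   logits, att_output, sequence_output = model(input_ids, token_type_ids, attention_mask)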
+
903
+
904
+ class SkipBertForSequenceClassificationPrediction(SkipBertForSequenceClassification):
905
+ def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
906
+
907
+ assert not self.training
908
+
909
+ _, pooled_output, sequence_output, att_output = self.bert(
910
+ input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
911
+ output_hidden_states=True, output_attentions=True)
912
+
913
+ logits = self.classifier(pooled_output)
914
+
915
+ loss = None
916
+ if labels is not None:
917
+ loss = torch.tensor(0.)
918
+
919
+ return SequenceClassifierOutput(
920
+ loss=loss,
921
+ logits=logits,
922
+ )
template_FL/src/fedllm/skipbert/plot.py ADDED
@@ -0,0 +1,173 @@
1
+ import copy
2
+ import json
3
+ import time
4
+ import torch.nn.functional as F
5
+ import torch
6
+ import numpy as np
7
+ import sys, os
8
+ import numba
9
+
10
+ def _set_madvise(large_data, advise=1):
11
+ '''
12
+ 0: MADV_NORMAL
13
+ 1: MADV_RANDOM
14
+ 2: MADV_SEQUENTIAL
15
+ 3: MADV_WILLNEED
16
+ 4: MADV_DONTNEED
17
+ '''
18
+ import ctypes
19
+ madvise = ctypes.CDLL("libc.so.6").madvise
20
+ madvise.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int]
21
+ madvise.restype = ctypes.c_int
22
+ assert madvise(large_data.ctypes.data, large_data.size * large_data.dtype.itemsize, advise) == 0, "MADVISE FAILED" # 1 means MADV_RANDOM
23
+
24
+ def _read_or_create_memmap(path, return_tensor=True, *args, **kargs):
25
+ if os.path.exists(path):
26
+ a = np.memmap(path, mode='r+', *args, **kargs)
27
+ else:
28
+ a = np.memmap(path, mode='w+', *args, **kargs)
29
+ # first row is reserved for oovs
30
+ a[0] = 0
31
+ _set_madvise(a, advise=1)
32
+ if return_tensor:
33
+ a = torch.from_numpy(a) # zero-copy
34
+ return a
35
+
36
+ def _to_key(k):
37
+ return tuple(k.tolist())
38
+
39
+ @numba.njit()
40
+ def _input_ids_to_tri_grams(x: np.array):
41
+ bs, seq_len = x.shape
42
+ ret = np.zeros((bs*(seq_len+1), 3), dtype=np.int64)
43
+ i_ret = 0
44
+ for i_bs in range(bs):
45
+ for i_token in range(seq_len):
46
+ if x[i_bs, i_token] == 0:
47
+ break
48
+ if i_token == 0:
49
+ ret[i_ret][1] = x[i_bs, i_token]
50
+ ret[i_ret][2] = x[i_bs, i_token+1]
51
+ elif i_token == seq_len - 1:
52
+ ret[i_ret][0] = x[i_bs, i_token-1]
53
+ ret[i_ret][1] = x[i_bs, i_token]
54
+ else:
55
+ ret[i_ret] = x[i_bs, i_token-1:i_token+2]
56
+ i_ret += 1
57
+ i_ret += 1 # add a pad trigram between seqs
58
+ return ret[:i_ret]
59
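+ # Worked example (editorial): for x = [[7, 8, 9, 0]] the function returns
+ # [[0, 7, 8], [7, 8, 9], [8, 9, 0], [0, 0, 0]] -- one centered tri-gram per
+ # non-pad token, followed by one all-zero padding tri-gram separating sequences.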
+
60
+
61
+ @numba.njit()
62
+ def _input_ids_to_ngram_ids(d: dict, x: np.array):
63
+ '''
64
+ Map input_ids to n-gram ids.
+ First try to match (x0, x1, x2) -> id;
+ if not found, fall back to (0, x1, x2) -> id;
+ if still not found, fall back to (0, x1, 0) -> id.
68
+ '''
69
+ bs, seq_len = x.shape
70
+ ret = np.zeros(bs*(seq_len+1), dtype=np.int64)
71
+ i_ret = 0
72
+ for i_bs in range(bs):
73
+ for i_token in range(seq_len):
74
+ if x[i_bs, i_token] == 0:
75
+ break
76
+ if i_token == 0:
77
+ k = (0, x[i_bs, i_token], x[i_bs, i_token+1])
78
+ elif i_token == seq_len - 1:
79
+ k = (x[i_bs, i_token-1], x[i_bs, i_token], 0)
80
+ else:
81
+ k = (x[i_bs, i_token-1], x[i_bs, i_token], x[i_bs, i_token+1])
82
+ if k in d: # tri-gram
83
+ ret[i_ret] = d[k]
84
+ else:
85
+ k = (0, k[1], k[2])
86
+ if k in d: # bi-gram
87
+ ret[i_ret] = d[k]
88
+ else:
89
+ k = (0, k[1], 0)
90
+ if k in d: # uni-gram
91
+ ret[i_ret] = d[k]
92
+ i_ret += 1
93
+ i_ret += 1 # add a pad trigram between seqs
94
+ return ret[:i_ret]
95
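+ # Example of the fallback (editorial): with d = {(7, 8, 9): 3, (0, 8, 0): 1},
+ # the tri-gram (7, 8, 9) resolves to id 3; an unseen (5, 8, 6) falls back to
+ # (0, 8, 6) and then to the uni-gram key (0, 8, 0), giving id 1; a fully
+ # unknown token keeps id 0, the reserved OOV row of the memmap.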
+
96
+ @numba.njit()
97
+ def _has_oov(d: dict, x: np.array):
98
+ bs, seq_len = x.shape
99
+ for i_bs in range(bs):
100
+ for i_token in range(seq_len):
101
+ if x[i_bs, i_token] == 0:
102
+ break
103
+ if i_token == 0:
104
+ k = (0, x[i_bs, i_token], x[i_bs, i_token+1])
105
+ elif i_token == seq_len - 1:
106
+ k = (x[i_bs, i_token-1], x[i_bs, i_token], 0)
107
+ else:
108
+ k = (x[i_bs, i_token-1], x[i_bs, i_token], x[i_bs, i_token+1])
109
+ if k not in d:
110
+ return True
111
+ return False
112
+
113
+
114
+ class Plot:
115
+ def __init__(self, max_num_entries=100000, hidden_size=768):
116
+
117
+ self.max_num_entries = max_num_entries
118
+ self.hidden_size = hidden_size
119
+
120
+ self.trigram_to_id, self.id_to_trigram = self.build_hash_table('input_ids_tri_gram.memmap', max_num_entries)
121
+ self.orig_trigram_hidden_states = _read_or_create_memmap("plot_hidden_states_tri_gram.memmap", dtype='float16', shape=(max_num_entries, 3, hidden_size))
122
+
123
+ def build_hash_table(self, path, max_num_entries):
124
+ n_gram = 3
125
+ hash_table1 = numba.typed.Dict()
126
+ hash_table1[tuple([0]*n_gram)] = 0 # dummy entry
127
+ orig_ngram_ids_mmap = _read_or_create_memmap(
128
+ path, return_tensor=False, dtype='int32', shape=(max_num_entries, n_gram))
129
+
130
+ for i in range(1, self.max_num_entries):
131
+ _tmp = orig_ngram_ids_mmap[i]
132
+ # break when meet all 0 ngram
133
+ if (_tmp==0).all():
134
+ break
135
+ tmp_hash = _to_key(_tmp)
136
+ if tmp_hash not in hash_table1:
137
+ hash_table1[tmp_hash] = i
138
+
139
+ return hash_table1, orig_ngram_ids_mmap
140
+
141
+ def input_ids_to_tri_grams(self, input_ids):
142
+ return _input_ids_to_tri_grams(input_ids)
143
+
144
+ def update_data(self, ngram_input_ids, ngram_hidden_states):
145
+ ngram_input_ids = ngram_input_ids.cpu().numpy()
146
+ ngram_hidden_states = ngram_hidden_states.detach().half().cpu() # FP16
147
+ bs, _ = ngram_input_ids.shape
+ ngram_to_id, id_to_ngram, id_to_hidden_state = \
+ self.trigram_to_id, self.id_to_trigram, self.orig_trigram_hidden_states
+ # TODO: optimize the for-loop later
+ id_to_save = []
+ rows_to_save = []  # batch rows actually stored (entries over capacity are skipped)
+ for i in range(bs):
+ key = _to_key(ngram_input_ids[i])
+ ngram_id = ngram_to_id.get(key, len(ngram_to_id))
+ if ngram_id >= self.max_num_entries:
+ print('Exceeded max number of entries; skipping entry...')
+ continue
+ ngram_to_id[key] = ngram_id
+ id_to_ngram[ngram_id] = key
+ id_to_save.append(ngram_id)
+ rows_to_save.append(i)
+ id_to_hidden_state[id_to_save] = ngram_hidden_states[rows_to_save]
164
+
165
+ def retrieve_data(self, input_ids):
166
+ input_ids = input_ids.numpy()
167
+ id_to_get = _input_ids_to_ngram_ids(self.trigram_to_id, input_ids)
168
+ hidden_states = self.orig_trigram_hidden_states[id_to_get]
169
+ return hidden_states
170
+
171
+ def has_oov(self, input_ids):
172
+ return _has_oov(self.trigram_to_id, input_ids.numpy())
173
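+ # Minimal usage sketch (editorial; tensor values are illustrative and the
+ # memmap files are created in the working directory):
+ #
+ #   plot = Plot(max_num_entries=1000, hidden_size=768)
+ #   ids = torch.tensor([[101, 2023, 102, 0]])
+ #   if plot.has_oov(ids):
+ #       tri = torch.from_numpy(plot.input_ids_to_tri_grams(ids.numpy()))
+ #       states = torch.zeros(tri.shape[0], 3, 768)  # stand-in hidden states
+ #       plot.update_data(tri, states)
+ #   cached = plot.retrieve_data(ids)  # fp16, shape (num_trigrams, 3, hidden_size)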
+
template_FL/src/fedllm/templates/alpaca.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "description": "Template used by Alpaca-LoRA.",
3
+ "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
4
+ "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
5
+ "response_split": "### Response:"
6
+ }
template_FL/src/fedllm/trainer.py ADDED
@@ -0,0 +1,476 @@
1
+ from accelerate import Accelerator
2
+ from torch.utils.data import DataLoader
3
+ import torch
4
+ import copy
5
+ import numpy as np
6
+ from transformers import BertForSequenceClassification, GenerationConfig, AutoTokenizer
7
+ import inspect
8
+ import logging
9
+ import wandb
10
+ from tqdm import tqdm
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ class ManualLLMSampleCB:
15
+ def __init__(self, model, tokenizer, task, num_samples=10, max_new_tokens=256):
16
+ self.model = model
17
+ self.concat_model = None
18
+ self.tokenizer = tokenizer
19
+ self.task = task
20
+ self.num_samples = num_samples
21
+ self.max_new_tokens = max_new_tokens
22
+ self.gen_config = GenerationConfig.from_pretrained(
23
+ model.config.name_or_path, max_new_tokens=max_new_tokens
24
+ )
25
+
26
+ def generate(self, prompt):
27
+ # Tokenize the input prompt and include the attention mask
28
+ tokenized_prompt = self.tokenizer(prompt, return_tensors='pt').to(self.model.device)
29
+ input_ids = tokenized_prompt['input_ids']
30
+ attention_mask = tokenized_prompt['attention_mask'] # Extract attention mask
31
+
32
+ with torch.no_grad():
33
+ output = self.model.generate(
34
+ input_ids=input_ids,
35
+ attention_mask=attention_mask,
36
+ max_new_tokens=self.max_new_tokens,
37
+ generation_config=self.gen_config
38
+ )
39
+ return self.tokenizer.decode(output[0], skip_special_tokens=True)
40
+
41
+
42
+ def create_samples_table(self, dataset):
43
+ table = wandb.Table(columns=["input", "prediction", "label", "task"])
44
+ sampled_dataset = dataset.shuffle(seed=42).select(range(self.num_samples))
45
+
46
+ for example in tqdm(sampled_dataset, desc="Generating Samples"):
47
+ instruction = example.get("instruction", "")
48
+ input_text = example.get("input", "")
49
+ label = example.get("output", "")
50
+
51
+ if input_text:
52
+ prompt = f"Instruction: {instruction} Input: {input_text} Response:"
53
+ else:
54
+ prompt = f"Instruction: {instruction} Response:"
55
+
56
+ prediction = self.generate(prompt)
57
+ table.add_data(prompt, prediction, label, self.task)
58
+
59
+ return table
60
+
61
+ def log_samples_to_wandb(self, dataset):
62
+ samples_table = self.create_samples_table(dataset)
63
+ wandb.log({"sample_predictions": samples_table})
64
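+ # Usage sketch (editorial; model, tokenizer and val_dataset are assumed to
+ # exist and wandb to be initialised):
+ #
+ #   cb = ManualLLMSampleCB(model, tokenizer, task="alpaca", num_samples=5)
+ #   cb.log_samples_to_wandb(val_dataset)  # logs a wandb.Table of generations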
+
65
+
66
+ class ManualTrainer:
67
+ def __init__(
68
+ self, model, tokenizer, train_dataset, val_dataset, holdout_dataset, reference_dataset,
69
+ args, data_collator, compute_metrics, mates_args, data_influence_model, data_influence_tokenizer
70
+ ):
71
+ self.accelerator = Accelerator()
72
+ self.model = model
73
+ self.tokenizer = tokenizer
74
+ self.args = args
75
+ self.data_collator = data_collator
76
+ self.compute_metrics = compute_metrics
77
+ self.mates_args = mates_args
78
+ self.data_influence_model = data_influence_model
79
+ self.data_influence_tokenizer = data_influence_tokenizer
80
+
81
+ # Remove unused columns from datasets
82
+ if train_dataset:
83
+ self.train_dataset = self._remove_unused_columns(train_dataset, "training")
84
+ # Prepare data loaders
85
+ # Sequential loader used for influence scoring, so score order matches dataset order
+ self.full_train_loader = DataLoader(
+ self.train_dataset,
+ batch_size=self.args.per_device_train_batch_size,
+ shuffle=False,
+ collate_fn=self.data_collator,
+ drop_last=self.args.dataloader_drop_last
+ )
+ # Default shuffled training loader; MATES may later swap in a filtered one
+ self.train_loader = DataLoader(
+ self.train_dataset,
+ batch_size=self.args.per_device_train_batch_size,
+ shuffle=True,
+ collate_fn=self.data_collator,
+ drop_last=self.args.dataloader_drop_last
+ )
92
+ else:
93
+ self.train_loader = None
94
+ self.full_train_loader = None
95
+
96
+ if val_dataset:
97
+ self.val_dataset = self._remove_unused_columns(val_dataset, "validation")
98
+ self.val_loader = DataLoader(
99
+ self.val_dataset,
100
+ batch_size=self.args.per_device_eval_batch_size,
101
+ shuffle=False,
102
+ collate_fn=self.data_collator,
103
+ drop_last=self.args.dataloader_drop_last
104
+ )
105
+ else:
106
+ self.val_loader = None
107
+
108
+ if self.mates_args.state:
109
+ self.holdout_dataset = self._remove_unused_columns(holdout_dataset, "holdout")
110
+ self.reference_dataset = self._remove_unused_columns(reference_dataset, "reference")
111
+
112
+ self.holdout_loader = DataLoader(
113
+ self.holdout_dataset,
114
+ batch_size=self.mates_args.holdout_batch_size,
115
+ shuffle=True,
116
+ collate_fn=self.data_collator,
117
+ drop_last=self.args.dataloader_drop_last
118
+ )
119
+
120
+ self.reference_loader = DataLoader(
121
+ self.reference_dataset,
122
+ batch_size=self.mates_args.reference_batch_size,
123
+ shuffle=False,
124
+ collate_fn=self.data_collator,
125
+ drop_last=self.args.dataloader_drop_last
126
+ )
127
+
128
+ # Prepare optimizer
129
+ self.optimizer = torch.optim.AdamW(
130
+ self.model.parameters(),
131
+ lr=self.args.learning_rate
132
+ )
133
+
134
+ # Prepare model, optimizer, and data loaders for Accelerator
135
+ self.model, self.optimizer, self.full_train_loader, self.train_loader, self.val_loader = self.accelerator.prepare(
+ self.model, self.optimizer, self.full_train_loader, self.train_loader, self.val_loader
+ )
138
+
139
+ if self.mates_args.state:
140
+ # Prepare holdout and reference loaders for Accelerator
141
+ self.data_influence_model, self.holdout_loader, self.reference_loader = self.accelerator.prepare(
142
+ self.data_influence_model, self.holdout_loader, self.reference_loader
143
+ )
144
+
145
+ def _remove_unused_columns(self, dataset, description=None):
146
+ """
147
+ Removes columns from a dataset that are not used by the model's forward method.
148
+
149
+ Args:
150
+ dataset: A dataset object (e.g., from datasets.Dataset).
151
+ description: A string description of the dataset (e.g., "training" or "validation").
152
+ Returns:
153
+ The dataset with unused columns removed.
154
+ """
155
+ # Inspect the model forward signature
156
+ forward_signature = inspect.signature(self.model.forward)
157
+ signature_columns = list(forward_signature.parameters.keys())
158
+
159
+ # Add label columns to the signature columns
160
+ label_columns = ["labels", "label_ids"]
161
+ signature_columns += label_columns
162
+
163
+ # Determine unused columns
164
+ dataset_columns = set(dataset.column_names)
165
+ used_columns = set(signature_columns).intersection(dataset_columns)
166
+ ignored_columns = list(dataset_columns - used_columns)
167
+
168
+ if ignored_columns:
169
+ logger.info(
170
+ f"The following columns in the {description} set don't have a corresponding argument in "
171
+ f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
172
+ )
173
+
174
+ # Ensure at least one column matches the model's expected inputs
175
+ if not used_columns:
176
+ raise ValueError(
177
+ f"No columns in the {description} dataset match the model's forward method signature. "
178
+ f"The following columns have been ignored: {', '.join(ignored_columns)}."
179
+ )
180
+
181
+ return dataset.remove_columns(ignored_columns)
182
+
183
+ def train(self):
184
+ best_val_loss = float('inf')
185
+ early_stopping_counter = 0
186
+ early_stopping_patience = 5
187
+ training_loss = []
188
+
189
+ for epoch in range(self.args.num_train_epochs):
190
+ # Check if it's time to update the data influence model and state is True
191
+ if self.mates_args.state and epoch % self.mates_args.update_data_influence_model_step == 0:
192
+ print("Updating the data influence model and selecting high-quality data...")
193
+ logger.info("Updating the data influence model and selecting high-quality data...")
194
+ self.update_data_influence_model()
195
+
196
+ # Filter high-quality data using the data influence model
197
+ high_quality_indices = self.select_high_quality_data(
198
+ dataset_size=len(self.train_dataset),
199
+ selection_fraction=self.mates_args.selection_fraction,
200
+ )
201
+ self.train_loader = self.accelerator.prepare(
202
+ self.create_filtered_dataloader(high_quality_indices)
203
+ )
204
+
205
+ self.model.train()
206
+ epoch_loss = 0.0
207
+
208
+ for step, batch in enumerate(self.train_loader):
209
+ if step >= self.args.max_steps:
210
+ break
211
+
212
+ self.optimizer.zero_grad()
213
+
214
+ outputs = self.model(
215
+ input_ids=batch['input_ids'],
216
+ attention_mask=batch['attention_mask'],
217
+ labels=batch['labels']
218
+ )
219
+ loss = outputs.loss
220
+
221
+ self.accelerator.backward(loss)
222
+ self.optimizer.step()
223
+
224
+ epoch_loss += loss.item()
225
+
226
+ if (step + 1) % self.args.logging_steps == 0:
227
+ # print(f"Step {step + 1}: Train Loss = {epoch_loss / (step + 1):.4f}")
228
+ logger.info(f"Step {step + 1}: Train Loss = {epoch_loss / (step + 1):.4f}")
229
+
230
+ avg_epoch_loss = epoch_loss / len(self.train_loader)
231
+ training_loss.append(avg_epoch_loss)
232
+
233
+ val_results = self.evaluate()
234
+
235
+ # print(f"Epoch {epoch + 1}: Train Loss = {avg_epoch_loss:.4f}, Val Loss = {val_results['eval_loss']:.4f}")
236
+ logger.info(f"Epoch {epoch + 1}: Train Loss = {avg_epoch_loss:.4f}, Val Loss = {val_results['eval_loss']:.4f}")
237
+
238
+ # Early stopping logic
239
+ if val_results["eval_loss"] < best_val_loss:
240
+ best_val_loss = val_results["eval_loss"]
241
+ early_stopping_counter = 0
242
+ else:
243
+ early_stopping_counter += 1
244
+ if early_stopping_counter >= early_stopping_patience:
245
+ print("Early stopping triggered")
246
+ break
247
+
248
+ return {"training_loss": sum(training_loss) / len(training_loss), "best_val_loss": best_val_loss}
249
+
250
+ def select_high_quality_data(self, dataset_size, selection_fraction):
251
+ """
252
+ Use the data influence model to predict quality scores and select high-quality data indices.
253
+ """
254
+ print("Selecting high-quality data using the data influence model...")
255
+
256
+ # Predict influence scores for the entire dataset
257
+ influence_scores = []
258
+ self.data_influence_model.eval()
259
262
+ i = 0
263
+ with torch.no_grad():
264
+ for batch in self.full_train_loader: # Full dataset loader
265
+ text = self.tokenizer.batch_decode(
266
+ batch['input_ids'],
267
+ skip_special_tokens=True
268
+ )
269
+
270
+ # Tokenize the text using the BERT tokenizer
271
+ bert_inputs = self.data_influence_tokenizer(
272
+ text,
273
+ truncation=True,
274
+ padding='max_length',
275
+ max_length=256,
276
+ return_tensors='pt'
277
+ ).to(self.accelerator.device)
278
+
279
+ # Score the batch with the data influence model (inference only)
281
+ outputs = self.data_influence_model(
282
+ input_ids=bert_inputs['input_ids'],
283
+ attention_mask=bert_inputs['attention_mask'],
284
+ )
285
+
286
+ influence_scores.extend(outputs.logits.squeeze(-1).cpu().numpy())
287
+
288
+ i += 1
289
+
290
+ if i == 100:
+ break  # only the first 100 batches are scored, to bound scoring cost
292
+
293
+ # Apply Gumbel-Top-k selection over the influence scores
294
+ influence_scores = np.array(influence_scores)
295
+ print(">> Influence scores shape:", influence_scores.shape)
296
+
297
+ # Add Gumbel noise for diversity
298
+ rng = np.random.default_rng()
299
+ gumbel_noise = rng.gumbel(size=len(influence_scores))
300
+ influence_scores += gumbel_noise
301
+
302
+ # Select top indices based on influence scores
303
+ selection_size = int(len(influence_scores)*selection_fraction)
304
+ high_quality_indices = np.argpartition(-influence_scores, selection_size)[:selection_size]
305
+ print(f"Selected {len(high_quality_indices)} high-quality samples.")
306
+
307
+ return high_quality_indices
308
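+ # Editorial note on the selection rule: adding i.i.d. Gumbel(0, 1) noise to
+ # the scores and keeping the top-k is equivalent to sampling k examples
+ # without replacement with probabilities proportional to exp(score)
+ # (Gumbel-Top-k), trading a little selection quality for diversity across
+ # rounds instead of always picking the same highest-scored examples.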
+
309
+ def create_filtered_dataloader(self, indices):
310
+ """
311
+ Create a new dataloader with only the selected high-quality data.
312
+ """
313
+ print("Creating a filtered dataloader with selected high-quality data...")
314
+ subset_dataset = torch.utils.data.Subset(self.train_dataset, indices)
315
+ return torch.utils.data.DataLoader(
316
+ subset_dataset,
317
+ batch_size=self.args.per_device_train_batch_size,
318
+ shuffle=True,
319
+ collate_fn=self.data_collator, # Use the same collate function
320
+ drop_last=self.args.dataloader_drop_last
321
+ )
322
+
323
+
324
+ def update_data_influence_model(self):
325
+ # Train a copy of the model on holdout data and validate on reference data
326
+ copied_model = copy.deepcopy(self.model)
327
+ copied_model.train()
328
+ optimizer = self.accelerator.prepare(
329
+ torch.optim.Adam(copied_model.parameters(), lr=self.args.learning_rate)
330
+ )
331
+ holdout_reference_pairs = []
332
+
333
+ # print("Starting to collect holdout-reference pairs...")
334
+ logger.info("Starting to collect holdout-reference pairs...")
335
+ for step, holdout_batch in enumerate(self.holdout_loader):
336
+ # print(f"Processing holdout batch {step+1}/{len(self.holdout_loader)}...")
337
+ logger.info(f"Processing holdout batch {step+1}/{len(self.holdout_loader)}...")
338
+
339
+ optimizer.zero_grad()
340
+ outputs = copied_model(
341
+ input_ids=holdout_batch['input_ids'],
342
+ attention_mask=holdout_batch['attention_mask'],
343
+ labels=holdout_batch['labels']
344
+ )
345
+ holdout_loss = outputs.loss
346
+ decoded_texts = self.tokenizer.batch_decode(
347
+ holdout_batch['input_ids'],
348
+ skip_special_tokens=True
349
+ )
350
+
351
+ holdout_loss.backward()
352
+ optimizer.step()
353
+
354
+ print(f"Evaluating reference losses at step {step}...")
355
+ logger.info(f"Evaluating reference losses at step {step}...")
356
+
357
+ copied_model.eval()
358
+ reference_losses = []
359
+
360
+ with torch.no_grad():
361
+ for ref_batch in self.reference_loader:
362
+ outputs = copied_model(
363
+ input_ids=ref_batch['input_ids'],
364
+ attention_mask=ref_batch['attention_mask'],
365
+ labels=ref_batch['labels']
366
+ )
367
+ reference_losses.append(outputs.loss.item())
368
+
369
+ # Compute the mean of reference losses
370
+ score = sum(reference_losses) / len(reference_losses) if reference_losses else 0.0
371
+ holdout_reference_pairs.append((decoded_texts, score))
372
+ copied_model.train()  # switch back to train mode for the next holdout step
373
+
374
+ # Train the data influence model using the generated pairs
375
+ print("Starting to train the data influence model...")
376
+ logger.info("Starting to train the data influence model...")
377
+
378
+ self.data_influence_model.train()
379
+ influence_optimizer = torch.optim.AdamW(self.data_influence_model.parameters(), lr=self.args.learning_rate)
380
+
381
+ for step, (text, score) in enumerate(holdout_reference_pairs):
382
+ # Tokenize the text using the BERT tokenizer
383
+ bert_inputs = self.data_influence_tokenizer(
384
+ text,
385
+ truncation=True,
386
+ padding='max_length',
387
+ max_length=256,
388
+ return_tensors='pt'
389
+ ).to(self.accelerator.device)
390
+
391
+ # Convert the score to a float tensor target (labels need no gradients)
+ score_tensor = torch.tensor([score], device=self.accelerator.device, dtype=torch.float32)
393
+
394
+ # Train the data influence model
395
+ influence_optimizer.zero_grad()
396
+ outputs = self.data_influence_model(
397
+ input_ids=bert_inputs['input_ids'],
398
+ attention_mask=bert_inputs['attention_mask'],
399
+ labels=score_tensor
400
+ )
401
+ influence_loss = outputs.loss
402
+
403
+ influence_loss.backward()
404
+ influence_optimizer.step()
405
+
406
+ if step % 50 == 0:
407
+ print(f"[Influence Training] Step {step}: Loss = {influence_loss.item():.4f}")
408
+ logger.info(f"[Influence Training] Step {step}: Loss = {influence_loss.item():.4f}")
409
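+ # Editorial summary of the MATES-style update above: each holdout batch
+ # drives one gradient step on a throwaway copy of the main model; the mean
+ # reference loss after that step becomes the batch's influence score; the
+ # resulting (text, score) pairs then train the BERT-based data influence
+ # model as a regressor.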
+
410
+
411
+ # TODO: distillation for SkipBERT (not implemented in this commit)
412
+
413
+
414
+
415
+ def evaluate(self, wandb_sample=True):
416
+ self.model.eval()
417
+ val_loss = 0.0
418
+
419
+ all_preds = []
420
+ all_labels = []
421
+
422
+ with torch.no_grad():
423
+ for batch in self.val_loader:
424
+ outputs = self.model(
425
+ input_ids=batch['input_ids'],
426
+ attention_mask=batch['attention_mask'],
427
+ labels=batch['labels']
428
+ )
429
+ val_loss += outputs.loss.item()
430
+
431
+ logits = self.accelerator.gather(outputs.logits)
432
+ labels = self.accelerator.gather(batch['labels'])
433
+
434
+ logits = logits.cpu().numpy()
435
+ labels = labels.cpu().numpy()
436
+
437
+ predictions = np.argmax(logits, axis=-1)
438
+ attention_mask = batch['attention_mask'].cpu().numpy()
439
+
440
+ for pred, label, mask in zip(predictions, labels, attention_mask):
441
+ valid_pred = pred[mask.astype(bool)]
442
+ valid_label = label[mask.astype(bool)]
443
+
444
+ all_preds.append(valid_pred)
445
+ all_labels.append(valid_label)
446
+
447
+ max_len = max(len(seq) for seq in all_preds)
448
+ padded_preds = np.array([
449
+ np.pad(seq, (0, max_len - len(seq)), 'constant', constant_values=self.tokenizer.pad_token_id)
450
+ for seq in all_preds
451
+ ])
452
+
453
+ max_len = max(len(seq) for seq in all_labels)
454
+ padded_labels = np.array([
455
+ np.pad(seq, (0, max_len - len(seq)), 'constant', constant_values=-100)
456
+ for seq in all_labels
457
+ ])
458
+
459
+ metrics = self.compute_metrics({"predictions": padded_preds, "label_ids": padded_labels})
460
+
461
+ metrics.update({"eval_loss": val_loss / len(self.val_loader)})
462
+ print("Validation Metrics:", metrics)
463
+
464
+ if wandb_sample:
465
+ # Sample Logging
466
+ llm_sample_cb = ManualLLMSampleCB(
467
+ model=self.model,
468
+ tokenizer=self.tokenizer,
469
+ task="classification",
470
+ num_samples=5,
471
+ max_new_tokens=128
472
+ )
473
+ llm_sample_cb.log_samples_to_wandb(self.val_dataset)
474
+
475
+ return metrics
476
+
template_FL/src/fedllm/utils.py ADDED
@@ -0,0 +1,54 @@
1
+ def clean_output_text(text):
2
+ """
3
+ Clean and normalize text from LLM outputs by removing noise and repetitions.
4
+
5
+ Args:
6
+ text (str): Raw text from LLM prediction
7
+
8
+ Returns:
9
+ str: Cleaned and normalized text
10
+ """
11
+ import re
12
+
13
+ def remove_repeats(text):
14
+ # Remove repeated words
15
+ pattern_words = r'\b(\w+)(?:\s+\1\b)+'
16
+ text = re.sub(pattern_words, r'\1', text)
17
+
18
+ # Remove repeated character patterns (like 'asasas')
19
+ pattern_chars = r'(\w{2,}?)\1{2,}'  # collapse 2+ repeats of 2+ chars so normal double letters survive
20
+ text = re.sub(pattern_chars, r'\1', text)
21
+
22
+ return text
23
+
24
+ # Remove excessive punctuation
25
+ def normalize_punctuation(text):
26
+ # Replace multiple exclamation/question marks with single ones
27
+ text = re.sub(r'!+', '!', text)
28
+ text = re.sub(r'\?+', '?', text)
29
+ # Remove multiple periods (except for ellipsis)
30
+ text = re.sub(r'\.{4,}', '...', text)
31
+ text = re.sub(r'\b(?:cor|asesa)\b', '', text)  # drop known noise tokens only when they appear as whole words
32
+ return text
33
+
34
+ # Main cleaning pipeline
35
+ cleaned_text = text.strip()
36
+
37
+ # Remove common noise patterns
38
+ noise_patterns = [
39
+ r'\n+', # Multiple newlines
40
+ r'\s+', # Multiple spaces
41
+ r'\\n', # Literal \n
42
+ r'\\t', # Literal \t
43
+ ]
44
+
45
+ for pattern in noise_patterns:
46
+ cleaned_text = re.sub(pattern, ' ', cleaned_text)
47
+
48
+ # Apply cleaning functions
49
+ cleaned_text = remove_repeats(cleaned_text)
51
+ cleaned_text = normalize_punctuation(cleaned_text)
52
+ cleaned_text = ' '.join(cleaned_text.split()) # Normalize spacing
53
+
54
+ return cleaned_text.strip()
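+ # Example (editorial): clean_output_text("the the cat!!!  sat")
+ # returns "the cat! sat" after de-duplication and punctuation normalisation.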
template_FL/src/pyproject.toml ADDED
@@ -0,0 +1,180 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "flowertune-llm"
7
+ version = "1.0.0"
8
+ description = "FlowerTune LLM: Federated LLM Fine-tuning with Flower"
9
+ license = "Apache-2.0"
10
+ dependencies = [
11
+ "flwr[simulation]==1.12.0",
12
+ "flwr-datasets>=0.3.0",
13
+ "trl==0.8.1",
14
+ "bitsandbytes==0.43.0",
15
+ "scipy==1.13.0",
16
+ "peft==0.6.2",
17
+ "fschat[model_worker,webui]==0.2.35",
18
+ "transformers==4.41.1",
19
+ "sentencepiece==0.2.0",
20
+ "omegaconf==2.3.0",
21
+ "hf_transfer==0.1.8",
22
+ ]
23
+
24
+
25
+ [tool.hatch.build.targets.wheel]
26
+ packages = ["."]
27
+
28
+ [tool.flwr.app]
29
+ publisher = "flwrlabs"
30
+
31
+ [tool.flwr.app.components]
32
+ serverapp = "fedllm.server_app:app"
33
+ clientapp = "fedllm.client_app:app"
34
+
35
+ [tool.flwr.app.config]
36
+ num-server-rounds = 2
37
+ num-supernodes = 10
38
+
39
+ # Define dataset
40
+ dataset.type = 'hete' # type = ['homo','hete']
41
+ dataset.name = "vicgalle/alpaca-gpt4"
42
+
43
+ # Define model settings
44
+ model.name = "Qwen/Qwen2.5-1.5B-Instruct"
45
+ model.quantization = 4
46
+ model.gradient-checkpointing = true
47
+ model.flash_attention = false
48
+
49
+ ### Use MATES ###
50
+ mates.state = true
51
+ mates.holdout-ratio = 0.001
52
+ mates.reference-ratio = 0.0005
53
+ mates.holdout-batch-size = 4
54
+ mates.reference-batch-size = 2
55
+ mates.update-data-influence-model-step = 20
56
+ mates.selection-fraction = 0.4
57
+ ### END ###
58
+
59
+ ### Use SkipBERT ###
60
+
61
+ # Model setting
62
+ skipbert.student-model = "bert-base-uncased"
63
+ skipbert.num_layers_student = 12
64
+ skipbert.num_full_hidden_layers_student = 6
65
+ skipbert.num_masked_layers_teacher = 0
66
+ skipbert.num_masked_last_layers_teacher = 0
67
+
68
+ # Training hyperparameters
69
+ skipbert.train_batch_size = 8
70
+ skipbert.gradient_accumulation_steps = 2
71
+ skipbert.eval_batch_size = 8
72
+ skipbert.eval_accumulation_steps = 2
73
+ skipbert.learning_rate = 2.0e-5
74
+ skipbert.num_train_epochs = 10
75
+ skipbert.eval_step = 10
76
+ skipbert.max_seq_length = 128
77
+ skipbert.weight_decay = 1.0e-4
78
+ skipbert.warmup_steps = 100 # 500
79
+ skipbert.do_train = true
80
+ skipbert.do_eval = true
81
+ skipbert.max_steps = -1
82
+ skipbert.evaluation_strategy = "epoch"
83
+ skipbert.save_strategy = "epoch"
84
+ skipbert.lr_scheduler_type = "cosine" # or 'warmup_linear'
85
+ skipbert.logging_dir = './skipbert_logs'
86
+ skipbert.output_dir = "./skipbert_results"
87
+ skipbert.report_to = 'wandb'
88
+
89
+ # Knowledge distillation parameters
90
+ skipbert.beta = 0.01
91
+ skipbert.T = 1.0
92
+ skipbert.alpha = 1.0
93
+ skipbert.reduce_T = 1.0
94
+ skipbert.epochs_no_cls = 5
95
+
96
+ # Training schedule and features
97
+ skipbert.freeze_lower_layers = true
98
+
99
+ # Feature usage flags
100
+ skipbert.use_logits = true
101
+ skipbert.use_att = true
102
+ skipbert.use_rep = true
103
+ skipbert.use_embedding = false
104
+
105
+ # Training modes
106
+ skipbert.do_predict = false
109
+ skipbert.do_fit = false
110
+ skipbert.fp16 = false
111
+ skipbert.no_pretrain = false
112
+ skipbert.use_init_weight = false
113
+ skipbert.share_param = true
114
+ skipbert.do_lower_case = true
115
+ skipbert.no_cuda = false
116
+
117
+ # N-gram settings
118
+ skipbert.n_gram_left = 1
119
+ skipbert.n_gram_right = 1
120
+
121
+ # Layer mappings
122
+ skipbert.att_layer_maps = [1, 3, 5, 7, 9, 11]
+ skipbert.hid_layer_maps = [6, 7, 8, 9, 10, 11, 12]
124
+
125
+ ### END ###
126
+
127
+
128
+
129
+
130
+ # Define LoRA settings
131
+ model.lora.lora-r = 8
132
+ model.lora.lora-alpha = 16
133
+ model.lora.lora-dropout = 0.05
134
+ model.lora.lora-target-modules = "q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj"
135
+
136
+ # Define training settings
137
+ train.save-every-round = 5
138
+ train.learning-rate-max = 5e-5
139
+ train.learning-rate-min = 1e-6
140
+ train.seq-length = 256
141
+ train.prompt_template_name = "alpaca"
142
+ train.train_on_inputs = true
143
+ train.verbose = false
144
+
145
+ # Define training arguments for the HF Trainer
146
+ train.training-arguments.output-dir = ""
147
+ train.training-arguments.learning-rate = 3e-4
148
+ train.training-arguments.per-device-train-batch-size = 4
149
+ train.training-arguments.gradient-accumulation-steps = 1
150
+ train.training-arguments.per-device-eval-batch-size = 2
151
+ train.training-arguments.eval-accumulation-steps = 1
152
+ train.training-arguments.logging-steps = 10
153
+ train.training-arguments.num-train-epochs = 1
154
+ train.training-arguments.max-steps = 10
155
+ train.training-arguments.save-steps = 1000
156
+ train.training-arguments.save-total-limit = 10
157
+ train.training-arguments.gradient-checkpointing = true
158
+ train.training-arguments.lr-scheduler-type = "cosine"
159
+ train.training-arguments.warmup-steps = 0
160
+ train.training-arguments.do-train = true
161
+ train.training-arguments.do-eval = true
162
+ train.training-arguments.dataloader-drop-last = false
163
+ train.training-arguments.eval-strategy = "epoch"
164
+ train.training-arguments.save-strategy = "epoch"
165
+ train.training-arguments.ddp-find-unused-parameters = false
166
+ train.training-arguments.group-by-length = true
167
+ train.training-arguments.load_best_model_at_end = true
168
+ train.training-arguments.report-to = "wandb"
169
+
170
+ # Define local training settings
171
+ train.strategy.fraction-fit = 0.2
172
+ train.strategy.fraction-evaluate = 0.0
173
+
174
+ [tool.flwr.federations]
175
+ default = "local-simulation"
176
+
177
+ [tool.flwr.federations.local-simulation]
178
+ options.num-supernodes = 10
179
+ options.backend.client-resources.num-cpus = 8
180
+ options.backend.client-resources.num-gpus = 1.0
template_FL/src/tesst.ipynb ADDED
@@ -0,0 +1,40 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "013c8cfe-254e-4ee8-81e0-27478628c8e9",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import tomllib\n",
11
+ "\n",
12
+ "with open(\"config.toml\", \"rb\") as f:\n",
13
+ " config = tomllib.load(f)\n",
14
+ "\n",
15
+ "numbers = config[\"my_integers\"] # Returns Python list: [1, 2, 3, 42]"
16
+ ]
17
+ }
18
+ ],
19
+ "metadata": {
20
+ "kernelspec": {
21
+ "display_name": "Python 3 (ipykernel)",
22
+ "language": "python",
23
+ "name": "python3"
24
+ },
25
+ "language_info": {
26
+ "codemirror_mode": {
27
+ "name": "ipython",
28
+ "version": 3
29
+ },
30
+ "file_extension": ".py",
31
+ "mimetype": "text/x-python",
32
+ "name": "python",
33
+ "nbconvert_exporter": "python",
34
+ "pygments_lexer": "ipython3",
35
+ "version": "3.11.7"
36
+ }
37
+ },
38
+ "nbformat": 4,
39
+ "nbformat_minor": 5
40
+ }