rmaphoh committed on
Commit 2ac90ad · Parent: 80fed16
Example.ipynb ADDED
@@ -0,0 +1,208 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "b2e049c7-d5db-45e6-b651-2601c02f4b7d",
+ "metadata": {},
+ "source": [
+ "## Data organisation example - IDRiD"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "16b65740-249b-4eef-9298-1db01f72d050",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import shutil\n",
+ "import pandas as pd\n",
+ "from sklearn.model_selection import train_test_split"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "ff0bf26e-c657-49de-8761-89d5a94c390d",
+ "metadata": {},
+ "source": [
+ "### Split val set from train data\n",
+ "- Download the dataset from the [official website](https://ieee-dataport.org/open-access/indian-diabetic-retinopathy-image-dataset-idrid)\n",
+ "- Images can be preprocessed, if necessary, with any processing tool such as [AutoMorph](https://github.com/rmaphoh/AutoMorph)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4bc1cb67-0adf-4640-8640-d0740a39366b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# read the official training labels\n",
+ "list_ = pd.read_csv('IDRiD_Disease_Grading_Training_Labels.csv')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b85fc0d1-2049-4550-bdec-76240b1bc759",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# image names per diabetic retinopathy grade (0-4)\n",
+ "noDR = list_.loc[list_['Retinopathy grade']==0, 'Image name']\n",
+ "mildDR = list_.loc[list_['Retinopathy grade']==1, 'Image name']\n",
+ "moderateDR = list_.loc[list_['Retinopathy grade']==2, 'Image name']\n",
+ "severeDR = list_.loc[list_['Retinopathy grade']==3, 'Image name']\n",
+ "proDR = list_.loc[list_['Retinopathy grade']==4, 'Image name']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d0617e35-8b91-45d3-90d5-d5e5bf2d7762",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# hold out 20% of each grade as the validation set\n",
+ "noDR_train, noDR_val = train_test_split(noDR, test_size=0.2, random_state=1)\n",
+ "mildDR_train, mildDR_val = train_test_split(mildDR, test_size=0.2, random_state=1)\n",
+ "moderateDR_train, moderateDR_val = train_test_split(moderateDR, test_size=0.2, random_state=1)\n",
+ "severeDR_train, severeDR_val = train_test_split(severeDR, test_size=0.2, random_state=1)\n",
+ "proDR_train, proDR_val = train_test_split(proDR, test_size=0.2, random_state=1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f30ce03f-5730-4e68-b6c5-8e1b6b9167f8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# make sure the destination class folders exist before copying\n",
+ "for folder in ['a_noDR', 'b_mildDR', 'c_moderateDR', 'd_severeDR', 'e_proDR']:\n",
+ "    os.makedirs('./train/{}'.format(folder), exist_ok=True)\n",
+ "\n",
+ "for i in noDR_train:\n",
+ "    shutil.copy('./train_processed/{}.png'.format(i), './train/a_noDR/{}.png'.format(i))\n",
+ "\n",
+ "for i in mildDR_train:\n",
+ "    shutil.copy('./train_processed/{}.png'.format(i), './train/b_mildDR/{}.png'.format(i))\n",
+ "\n",
+ "for i in moderateDR_train:\n",
+ "    shutil.copy('./train_processed/{}.png'.format(i), './train/c_moderateDR/{}.png'.format(i))\n",
+ "\n",
+ "for i in severeDR_train:\n",
+ "    shutil.copy('./train_processed/{}.png'.format(i), './train/d_severeDR/{}.png'.format(i))\n",
+ "\n",
+ "for i in proDR_train:\n",
+ "    shutil.copy('./train_processed/{}.png'.format(i), './train/e_proDR/{}.png'.format(i))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "196d1845-3e5e-4d38-82e5-66057a693962",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# make sure the destination class folders exist before copying\n",
+ "for folder in ['a_noDR', 'b_mildDR', 'c_moderateDR', 'd_severeDR', 'e_proDR']:\n",
+ "    os.makedirs('./val/{}'.format(folder), exist_ok=True)\n",
+ "\n",
+ "for i in noDR_val:\n",
+ "    shutil.copy('./train_processed/{}.png'.format(i), './val/a_noDR/{}.png'.format(i))\n",
+ "\n",
+ "for i in mildDR_val:\n",
+ "    shutil.copy('./train_processed/{}.png'.format(i), './val/b_mildDR/{}.png'.format(i))\n",
+ "\n",
+ "for i in moderateDR_val:\n",
+ "    shutil.copy('./train_processed/{}.png'.format(i), './val/c_moderateDR/{}.png'.format(i))\n",
+ "\n",
+ "for i in severeDR_val:\n",
+ "    shutil.copy('./train_processed/{}.png'.format(i), './val/d_severeDR/{}.png'.format(i))\n",
+ "\n",
+ "for i in proDR_val:\n",
+ "    shutil.copy('./train_processed/{}.png'.format(i), './val/e_proDR/{}.png'.format(i))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "faf285f4-9079-49ca-9d99-8f3f5718afbf",
+ "metadata": {},
+ "source": [
+ "### Organise test set"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "118d15d0-9e94-4f6e-855d-dfa3796b24d2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "list_test = pd.read_csv('IDRiD_Disease_Grading_Testing_Labels.csv')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "89a098fe-0aad-41d4-ab09-476ff0354c77",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "noDR_test = list_test.loc[list_test['Retinopathy grade']==0, 'Image name']\n",
+ "mildDR_test = list_test.loc[list_test['Retinopathy grade']==1, 'Image name']\n",
+ "moderateDR_test = list_test.loc[list_test['Retinopathy grade']==2, 'Image name']\n",
+ "severeDR_test = list_test.loc[list_test['Retinopathy grade']==3, 'Image name']\n",
+ "proDR_test = list_test.loc[list_test['Retinopathy grade']==4, 'Image name']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "33a207c1-1fef-4e79-8ff2-84329062495b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# make sure the destination class folders exist before copying\n",
+ "for folder in ['a_noDR', 'b_mildDR', 'c_moderateDR', 'd_severeDR', 'e_proDR']:\n",
+ "    os.makedirs('./test/{}'.format(folder), exist_ok=True)\n",
+ "\n",
+ "for i in noDR_test:\n",
+ "    shutil.copy('./test_processed/{}.png'.format(i), './test/a_noDR/{}.png'.format(i))\n",
+ "\n",
+ "for i in mildDR_test:\n",
+ "    shutil.copy('./test_processed/{}.png'.format(i), './test/b_mildDR/{}.png'.format(i))\n",
+ "\n",
+ "for i in moderateDR_test:\n",
+ "    shutil.copy('./test_processed/{}.png'.format(i), './test/c_moderateDR/{}.png'.format(i))\n",
+ "\n",
+ "for i in severeDR_test:\n",
+ "    shutil.copy('./test_processed/{}.png'.format(i), './test/d_severeDR/{}.png'.format(i))\n",
+ "\n",
+ "for i in proDR_test:\n",
+ "    shutil.copy('./test_processed/{}.png'.format(i), './test/e_proDR/{}.png'.format(i))"
+ ]
+ }
+ ],
+ "metadata": {
+ "environment": {
+ "kernel": "python3",
+ "name": "common-cu110.m91",
+ "type": "gcloud",
+ "uri": "gcr.io/deeplearning-platform-release/base-cu110:m91"
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
LICENSE ADDED
@@ -0,0 +1,400 @@
+
+ Attribution-NonCommercial 4.0 International
+
+ =======================================================================
+
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
+ does not provide legal services or legal advice. Distribution of
+ Creative Commons public licenses does not create a lawyer-client or
+ other relationship. Creative Commons makes its licenses and related
+ information available on an "as-is" basis. Creative Commons gives no
+ warranties regarding its licenses, any material licensed under their
+ terms and conditions, or any related information. Creative Commons
+ disclaims all liability for damages resulting from their use to the
+ fullest extent possible.
+
+ Using Creative Commons Public Licenses
+
+ Creative Commons public licenses provide a standard set of terms and
+ conditions that creators and other rights holders may use to share
+ original works of authorship and other material subject to copyright
+ and certain other rights specified in the public license below. The
+ following considerations are for informational purposes only, are not
+ exhaustive, and do not form part of our licenses.
+
+      Considerations for licensors: Our public licenses are
+      intended for use by those authorized to give the public
+      permission to use material in ways otherwise restricted by
+      copyright and certain other rights. Our licenses are
+      irrevocable. Licensors should read and understand the terms
+      and conditions of the license they choose before applying it.
+      Licensors should also secure all rights necessary before
+      applying our licenses so that the public can reuse the
+      material as expected. Licensors should clearly mark any
+      material not subject to the license. This includes other CC-
+      licensed material, or material used under an exception or
+      limitation to copyright. More considerations for licensors:
+      wiki.creativecommons.org/Considerations_for_licensors
+
+      Considerations for the public: By using one of our public
+      licenses, a licensor grants the public permission to use the
+      licensed material under specified terms and conditions. If
+      the licensor's permission is not necessary for any reason--for
+      example, because of any applicable exception or limitation to
+      copyright--then that use is not regulated by the license. Our
+      licenses grant only permissions under copyright and certain
+      other rights that a licensor has authority to grant. Use of
+      the licensed material may still be restricted for other
+      reasons, including because others have copyright or other
+      rights in the material. A licensor may make special requests,
+      such as asking that all changes be marked or described.
+      Although not required by our licenses, you are encouraged to
+      respect those requests where reasonable. More_considerations
+      for the public:
+      wiki.creativecommons.org/Considerations_for_licensees
+
+ =======================================================================
+
+ Creative Commons Attribution-NonCommercial 4.0 International Public
+ License
+
+ By exercising the Licensed Rights (defined below), You accept and agree
+ to be bound by the terms and conditions of this Creative Commons
+ Attribution-NonCommercial 4.0 International Public License ("Public
+ License"). To the extent this Public License may be interpreted as a
+ contract, You are granted the Licensed Rights in consideration of Your
+ acceptance of these terms and conditions, and the Licensor grants You
+ such rights in consideration of benefits the Licensor receives from
+ making the Licensed Material available under these terms and
+ conditions.
+
+ Section 1 -- Definitions.
+
+   a. Adapted Material means material subject to Copyright and Similar
+      Rights that is derived from or based upon the Licensed Material
+      and in which the Licensed Material is translated, altered,
+      arranged, transformed, or otherwise modified in a manner requiring
+      permission under the Copyright and Similar Rights held by the
+      Licensor. For purposes of this Public License, where the Licensed
+      Material is a musical work, performance, or sound recording,
+      Adapted Material is always produced where the Licensed Material is
+      synched in timed relation with a moving image.
+
+   b. Adapter's License means the license You apply to Your Copyright
+      and Similar Rights in Your contributions to Adapted Material in
+      accordance with the terms and conditions of this Public License.
+
+   c. Copyright and Similar Rights means copyright and/or similar rights
+      closely related to copyright including, without limitation,
+      performance, broadcast, sound recording, and Sui Generis Database
+      Rights, without regard to how the rights are labeled or
+      categorized. For purposes of this Public License, the rights
+      specified in Section 2(b)(1)-(2) are not Copyright and Similar
+      Rights.
+   d. Effective Technological Measures means those measures that, in the
+      absence of proper authority, may not be circumvented under laws
+      fulfilling obligations under Article 11 of the WIPO Copyright
+      Treaty adopted on December 20, 1996, and/or similar international
+      agreements.
+
+   e. Exceptions and Limitations means fair use, fair dealing, and/or
+      any other exception or limitation to Copyright and Similar Rights
+      that applies to Your use of the Licensed Material.
+
+   f. Licensed Material means the artistic or literary work, database,
+      or other material to which the Licensor applied this Public
+      License.
+
+   g. Licensed Rights means the rights granted to You subject to the
+      terms and conditions of this Public License, which are limited to
+      all Copyright and Similar Rights that apply to Your use of the
+      Licensed Material and that the Licensor has authority to license.
+
+   h. Licensor means the individual(s) or entity(ies) granting rights
+      under this Public License.
+
+   i. NonCommercial means not primarily intended for or directed towards
+      commercial advantage or monetary compensation. For purposes of
+      this Public License, the exchange of the Licensed Material for
+      other material subject to Copyright and Similar Rights by digital
+      file-sharing or similar means is NonCommercial provided there is
+      no payment of monetary compensation in connection with the
+      exchange.
+
+   j. Share means to provide material to the public by any means or
+      process that requires permission under the Licensed Rights, such
+      as reproduction, public display, public performance, distribution,
+      dissemination, communication, or importation, and to make material
+      available to the public including in ways that members of the
+      public may access the material from a place and at a time
+      individually chosen by them.
+
+   k. Sui Generis Database Rights means rights other than copyright
+      resulting from Directive 96/9/EC of the European Parliament and of
+      the Council of 11 March 1996 on the legal protection of databases,
+      as amended and/or succeeded, as well as other essentially
+      equivalent rights anywhere in the world.
+
+   l. You means the individual or entity exercising the Licensed Rights
+      under this Public License. Your has a corresponding meaning.
+
+ Section 2 -- Scope.
+
+   a. License grant.
+
+        1. Subject to the terms and conditions of this Public License,
+           the Licensor hereby grants You a worldwide, royalty-free,
+           non-sublicensable, non-exclusive, irrevocable license to
+           exercise the Licensed Rights in the Licensed Material to:
+
+             a. reproduce and Share the Licensed Material, in whole or
+                in part, for NonCommercial purposes only; and
+
+             b. produce, reproduce, and Share Adapted Material for
+                NonCommercial purposes only.
+
+        2. Exceptions and Limitations. For the avoidance of doubt, where
+           Exceptions and Limitations apply to Your use, this Public
+           License does not apply, and You do not need to comply with
+           its terms and conditions.
+
+        3. Term. The term of this Public License is specified in Section
+           6(a).
+
+        4. Media and formats; technical modifications allowed. The
+           Licensor authorizes You to exercise the Licensed Rights in
+           all media and formats whether now known or hereafter created,
+           and to make technical modifications necessary to do so. The
+           Licensor waives and/or agrees not to assert any right or
+           authority to forbid You from making technical modifications
+           necessary to exercise the Licensed Rights, including
+           technical modifications necessary to circumvent Effective
+           Technological Measures. For purposes of this Public License,
+           simply making modifications authorized by this Section 2(a)
+           (4) never produces Adapted Material.
+
+        5. Downstream recipients.
+
+             a. Offer from the Licensor -- Licensed Material. Every
+                recipient of the Licensed Material automatically
+                receives an offer from the Licensor to exercise the
+                Licensed Rights under the terms and conditions of this
+                Public License.
+
+             b. No downstream restrictions. You may not offer or impose
+                any additional or different terms or conditions on, or
+                apply any Effective Technological Measures to, the
+                Licensed Material if doing so restricts exercise of the
+                Licensed Rights by any recipient of the Licensed
+                Material.
+
+        6. No endorsement. Nothing in this Public License constitutes or
+           may be construed as permission to assert or imply that You
+           are, or that Your use of the Licensed Material is, connected
+           with, or sponsored, endorsed, or granted official status by,
+           the Licensor or others designated to receive attribution as
+           provided in Section 3(a)(1)(A)(i).
+
+   b. Other rights.
+
+        1. Moral rights, such as the right of integrity, are not
+           licensed under this Public License, nor are publicity,
+           privacy, and/or other similar personality rights; however, to
+           the extent possible, the Licensor waives and/or agrees not to
+           assert any such rights held by the Licensor to the limited
+           extent necessary to allow You to exercise the Licensed
+           Rights, but not otherwise.
+
+        2. Patent and trademark rights are not licensed under this
+           Public License.
+
+        3. To the extent possible, the Licensor waives any right to
+           collect royalties from You for the exercise of the Licensed
+           Rights, whether directly or through a collecting society
+           under any voluntary or waivable statutory or compulsory
+           licensing scheme. In all other cases the Licensor expressly
+           reserves any right to collect such royalties, including when
+           the Licensed Material is used other than for NonCommercial
+           purposes.
+
+ Section 3 -- License Conditions.
+
+ Your exercise of the Licensed Rights is expressly made subject to the
+ following conditions.
+
+   a. Attribution.
+
+        1. If You Share the Licensed Material (including in modified
+           form), You must:
+
+             a. retain the following if it is supplied by the Licensor
+                with the Licensed Material:
+
+                  i. identification of the creator(s) of the Licensed
+                     Material and any others designated to receive
+                     attribution, in any reasonable manner requested by
+                     the Licensor (including by pseudonym if
+                     designated);
+
+                 ii. a copyright notice;
+
+                iii. a notice that refers to this Public License;
+
+                 iv. a notice that refers to the disclaimer of
+                     warranties;
+
+                  v. a URI or hyperlink to the Licensed Material to the
+                     extent reasonably practicable;
+
+             b. indicate if You modified the Licensed Material and
+                retain an indication of any previous modifications; and
+
+             c. indicate the Licensed Material is licensed under this
+                Public License, and include the text of, or the URI or
+                hyperlink to, this Public License.
+
+        2. You may satisfy the conditions in Section 3(a)(1) in any
+           reasonable manner based on the medium, means, and context in
+           which You Share the Licensed Material. For example, it may be
+           reasonable to satisfy the conditions by providing a URI or
+           hyperlink to a resource that includes the required
+           information.
+
+        3. If requested by the Licensor, You must remove any of the
+           information required by Section 3(a)(1)(A) to the extent
+           reasonably practicable.
+
+        4. If You Share Adapted Material You produce, the Adapter's
+           License You apply must not prevent recipients of the Adapted
+           Material from complying with this Public License.
+
+ Section 4 -- Sui Generis Database Rights.
+
+ Where the Licensed Rights include Sui Generis Database Rights that
+ apply to Your use of the Licensed Material:
+
+   a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+      to extract, reuse, reproduce, and Share all or a substantial
+      portion of the contents of the database for NonCommercial purposes
+      only;
+
+   b. if You include all or a substantial portion of the database
+      contents in a database in which You have Sui Generis Database
+      Rights, then the database in which You have Sui Generis Database
+      Rights (but not its individual contents) is Adapted Material; and
+
+   c. You must comply with the conditions in Section 3(a) if You Share
+      all or a substantial portion of the contents of the database.
+
+ For the avoidance of doubt, this Section 4 supplements and does not
+ replace Your obligations under this Public License where the Licensed
+ Rights include other Copyright and Similar Rights.
+
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+   a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+      EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+      AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+      ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+      IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+      WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+      PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+      ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+      KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+      ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+   b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+      TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+      NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+      INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+      COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+      USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+      ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+      DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+      IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+   c. The disclaimer of warranties and limitation of liability provided
+      above shall be interpreted in a manner that, to the extent
+      possible, most closely approximates an absolute disclaimer and
+      waiver of all liability.
+
+ Section 6 -- Term and Termination.
+
+   a. This Public License applies for the term of the Copyright and
+      Similar Rights licensed here. However, if You fail to comply with
+      this Public License, then Your rights under this Public License
+      terminate automatically.
+
+   b. Where Your right to use the Licensed Material has terminated under
+      Section 6(a), it reinstates:
+
+        1. automatically as of the date the violation is cured, provided
+           it is cured within 30 days of Your discovery of the
+           violation; or
+
+        2. upon express reinstatement by the Licensor.
+
+      For the avoidance of doubt, this Section 6(b) does not affect any
+      right the Licensor may have to seek remedies for Your violations
+      of this Public License.
+
+   c. For the avoidance of doubt, the Licensor may also offer the
+      Licensed Material under separate terms or conditions or stop
+      distributing the Licensed Material at any time; however, doing so
+      will not terminate this Public License.
+
+   d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+      License.
+
+ Section 7 -- Other Terms and Conditions.
+
+   a. The Licensor shall not be bound by any additional or different
+      terms or conditions communicated by You unless expressly agreed.
+
+   b. Any arrangements, understandings, or agreements regarding the
+      Licensed Material not stated herein are separate from and
+      independent of the terms and conditions of this Public License.
+
+ Section 8 -- Interpretation.
+
+   a. For the avoidance of doubt, this Public License does not, and
+      shall not be interpreted to, reduce, limit, restrict, or impose
+      conditions on any use of the Licensed Material that could lawfully
+      be made without permission under this Public License.
+
+   b. To the extent possible, if any provision of this Public License is
+      deemed unenforceable, it shall be automatically reformed to the
+      minimum extent necessary to make it enforceable. If the provision
+      cannot be reformed, it shall be severed from this Public License
+      without affecting the enforceability of the remaining terms and
+      conditions.
+
+   c. No term or condition of this Public License will be waived and no
+      failure to comply consented to unless expressly agreed to by the
+      Licensor.
+
+   d. Nothing in this Public License constitutes or may be interpreted
+      as a limitation upon, or waiver of, any privileges and immunities
+      that apply to the Licensor or You, including from the legal
+      processes of any jurisdiction or authority.
+
+ =======================================================================
+
+ Creative Commons is not a party to its public
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
+ its public licenses to material it publishes and in those instances
+ will be considered the “Licensor.” The text of the Creative Commons
+ public licenses is dedicated to the public domain under the CC0 Public
+ Domain Dedication. Except for the limited purpose of indicating that
+ material is shared under a Creative Commons public license or as
+ otherwise permitted by the Creative Commons policies published at
+ creativecommons.org/policies, Creative Commons does not authorize the
+ use of the trademark "Creative Commons" or any other trademark or logo
+ of Creative Commons without its prior written consent including,
+ without limitation, in connection with any unauthorized modifications
+ to any of its public licenses or any other arrangements,
+ understandings, or agreements concerning use of licensed material. For
+ the avoidance of doubt, this paragraph does not form part of the
+ public licenses.
+
+ Creative Commons may be contacted at creativecommons.org.
README.md CHANGED
@@ -1,3 +1,144 @@
- ---
- license: cc-by-nc-4.0
- ---
+ ## RETFound - A foundation model for retinal imaging
+
+
+ This is the official repo for RETFound, which is based on [MAE](https://github.com/facebookresearch/mae).
+
+ Please contact **[email protected]** or **[email protected]** if you have questions.
+
+ A Keras version implemented by Yuka Kihara can be found [here](https://github.com/uw-biomedical-ml/RETFound_MAE).
+
+
+ ### Key features
+
+ - RETFound is pre-trained on 1.6 million retinal images with self-supervised learning
+ - RETFound has been validated in multiple disease detection tasks
+ - RETFound can be efficiently adapted to customised tasks
+
+
+ ### News
+
+ - A [visualisation demo](https://github.com/rmaphoh/RETFound_MAE/blob/main/RETFound_visualize.ipynb) has been added
+
+ ### Install environment
+
+ Create the environment with conda:
+
+ ```
+ conda create -n retfound python=3.7.5 -y
+ conda activate retfound
+ ```
+
+ Install PyTorch 1.8.1 (CUDA 11.1):
+ ```
+ pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
+ ```
+
+ Install the remaining dependencies:
+ ```
+ git clone https://github.com/rmaphoh/RETFound_MAE/
+ cd RETFound_MAE
+ pip install -r requirement.txt
+ ```
+
+
+ ### Fine-tuning with RETFound weights
+
+ - RETFound pre-trained weights
+ <table><tbody>
+ <!-- START TABLE -->
+ <!-- TABLE HEADER -->
+ <th valign="bottom"></th>
+ <th valign="bottom">ViT-Large</th>
+ <!-- TABLE BODY -->
+ <tr><td align="left">Colour fundus image</td>
+ <td align="center"><a href="https://drive.google.com/file/d/1l62zbWUFTlp214SvK6eMwPQZAzcwoeBE/view?usp=sharing">download</a></td>
+ </tr>
+ <!-- TABLE BODY -->
+ <tr><td align="left">OCT</td>
+ <td align="center"><a href="https://drive.google.com/file/d/1m6s7QYkjyjJDlpEuXm7Xp3PmjN-elfW2/view?usp=sharing">download</a></td>
+ </tr>
+ </tbody></table>
+
+ - Organise data (using IDRiD as an [example](Example.ipynb)); the expected folder layout is shown below
+
+ <p align="left">
+ <img src="./pic/file_index.jpg" width="160">
+ </p>
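+
+ The scripts expect an `ImageFolder`-style layout with one sub-folder per class under each split. A minimal sketch that creates the empty structure (the root folder name `IDRiD_data` and the five grade folders are taken from the fine-tuning command and [Example.ipynb](Example.ipynb)):
+
+ ```
+ import os
+
+ for split in ['train', 'val', 'test']:
+     for grade in ['a_noDR', 'b_mildDR', 'c_moderateDR', 'd_severeDR', 'e_proDR']:
+         # e.g. ./IDRiD_data/train/a_noDR/
+         os.makedirs(os.path.join('./IDRiD_data', split, grade), exist_ok=True)
+ ```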
+
+
+ - Start fine-tuning (using IDRiD as an example). A fine-tuned checkpoint will be saved during training, and evaluation will be run after training.
+
+
+ ```
+ python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_finetune.py \
+     --batch_size 16 \
+     --world_size 1 \
+     --model vit_large_patch16 \
+     --epochs 50 \
+     --blr 5e-3 --layer_decay 0.65 \
+     --weight_decay 0.05 --drop_path 0.2 \
+     --nb_classes 5 \
+     --data_path ./IDRiD_data/ \
+     --task ./finetune_IDRiD/ \
+     --finetune ./RETFound_cfp_weights.pth
+
+ ```
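+
+ Note that `main_finetune.py` derives the actual learning rate from the base rate as `lr = blr * effective_batch_size / 256`; with `--blr 5e-3` and an effective batch size of 16 (batch 16, one GPU, no gradient accumulation) the applied rate is 5e-3 * 16 / 256 = 3.125e-4.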
+
+
+ - For evaluation only
+
+
+ ```
+ python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_finetune.py \
+     --eval --batch_size 16 \
+     --world_size 1 \
+     --model vit_large_patch16 \
+     --epochs 50 \
+     --blr 5e-3 --layer_decay 0.65 \
+     --weight_decay 0.05 --drop_path 0.2 \
+     --nb_classes 5 \
+     --data_path ./IDRiD_data/ \
+     --task ./internal_IDRiD/ \
+     --resume ./finetune_IDRiD/checkpoint-best.pth
+
+ ```
+
+
+ ### Load the model and weights (if you want to call the model in your code)
+
+ ```
+ import torch
+ import models_vit
+ from util.pos_embed import interpolate_pos_embed
+ from timm.models.layers import trunc_normal_
+
+ # build the model
+ model = models_vit.__dict__['vit_large_patch16'](
+     num_classes=2,
+     drop_path_rate=0.2,
+     global_pool=True,
+ )
+
+ # load RETFound weights
+ checkpoint = torch.load('RETFound_cfp_weights.pth', map_location='cpu')
+ checkpoint_model = checkpoint['model']
+ state_dict = model.state_dict()
+ for k in ['head.weight', 'head.bias']:
+     if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
+         print(f"Removing key {k} from pretrained checkpoint")
+         del checkpoint_model[k]
+
+ # interpolate position embedding
+ interpolate_pos_embed(model, checkpoint_model)
+
+ # load pre-trained model
+ msg = model.load_state_dict(checkpoint_model, strict=False)
+
+ assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
+
+ # manually initialize fc layer
+ trunc_normal_(model.head.weight, std=2e-5)
+
+ print("Model = %s" % str(model))
+ ```
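+
+ As a quick sanity check on the loaded weights, you can run a dummy batch through the network (assuming the default 224x224 input size):
+
+ ```
+ model.eval()
+ with torch.no_grad():
+     output = model(torch.randn(1, 3, 224, 224))  # one dummy RGB image
+ print(output.shape)  # torch.Size([1, 2]) for num_classes=2
+ ```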
+
RETFound_visualize.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
engine_finetune.py ADDED
@@ -0,0 +1,208 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ # Partly revised by YZ @UCL&Moorfields
+ # --------------------------------------------------------
+
+ import math
+ import sys
+ import csv
+ import os
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from timm.data import Mixup
+ from timm.utils import accuracy
+ from typing import Iterable, Optional
+ import util.misc as misc
+ import util.lr_sched as lr_sched
+ from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, average_precision_score, multilabel_confusion_matrix
+ from pycm import *
+ import matplotlib.pyplot as plt
+ import numpy as np
+
+
+ def misc_measures(confusion_matrix):
+     """Average per-class metrics from a stack of 2x2 confusion matrices.
+
+     Note: the loop starts at index 1, so the first class (index 0) is
+     excluded from the averages.
+     """
+     acc = []
+     sensitivity = []
+     specificity = []
+     precision = []
+     G = []
+     F1_score_2 = []
+     mcc_ = []
+
+     for i in range(1, confusion_matrix.shape[0]):
+         cm1 = confusion_matrix[i]
+         acc.append(1. * (cm1[0, 0] + cm1[1, 1]) / np.sum(cm1))
+         sensitivity_ = 1. * cm1[1, 1] / (cm1[1, 0] + cm1[1, 1])
+         sensitivity.append(sensitivity_)
+         specificity_ = 1. * cm1[0, 0] / (cm1[0, 1] + cm1[0, 0])
+         specificity.append(specificity_)
+         precision_ = 1. * cm1[1, 1] / (cm1[1, 1] + cm1[0, 1])
+         precision.append(precision_)
+         G.append(np.sqrt(sensitivity_ * specificity_))
+         F1_score_2.append(2 * precision_ * sensitivity_ / (precision_ + sensitivity_))
+         mcc = (cm1[0, 0] * cm1[1, 1] - cm1[0, 1] * cm1[1, 0]) / np.sqrt((cm1[0, 0] + cm1[0, 1]) * (cm1[0, 0] + cm1[1, 0]) * (cm1[1, 1] + cm1[1, 0]) * (cm1[1, 1] + cm1[0, 1]))
+         mcc_.append(mcc)
+
+     acc = np.array(acc).mean()
+     sensitivity = np.array(sensitivity).mean()
+     specificity = np.array(specificity).mean()
+     precision = np.array(precision).mean()
+     G = np.array(G).mean()
+     F1_score_2 = np.array(F1_score_2).mean()
+     mcc_ = np.array(mcc_).mean()
+
+     return acc, sensitivity, specificity, precision, G, F1_score_2, mcc_
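+
+ # Minimal usage sketch (hypothetical labels), matching how evaluate() below
+ # builds the confusion matrix:
+ #   cm = multilabel_confusion_matrix([0, 1, 2, 1], [0, 2, 2, 1], labels=[0, 1, 2])
+ #   acc, sens, spec, prec, G, F1, mcc = misc_measures(cm)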
+
+
+ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
+                     data_loader: Iterable, optimizer: torch.optim.Optimizer,
+                     device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
+                     mixup_fn: Optional[Mixup] = None, log_writer=None,
+                     args=None):
+     model.train(True)
+     metric_logger = misc.MetricLogger(delimiter="  ")
+     metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+     header = 'Epoch: [{}]'.format(epoch)
+     print_freq = 20
+
+     accum_iter = args.accum_iter
+
+     optimizer.zero_grad()
+
+     if log_writer is not None:
+         print('log_dir: {}'.format(log_writer.log_dir))
+
+     for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+
+         # we use a per-iteration (instead of per-epoch) lr scheduler
+         if data_iter_step % accum_iter == 0:
+             lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
+
+         samples = samples.to(device, non_blocking=True)
+         targets = targets.to(device, non_blocking=True)
+
+         if mixup_fn is not None:
+             samples, targets = mixup_fn(samples, targets)
+
+         with torch.cuda.amp.autocast():
+             outputs = model(samples)
+             loss = criterion(outputs, targets)
+
+         loss_value = loss.item()
+
+         if not math.isfinite(loss_value):
+             print("Loss is {}, stopping training".format(loss_value))
+             sys.exit(1)
+
+         loss /= accum_iter
+         loss_scaler(loss, optimizer, clip_grad=max_norm,
+                     parameters=model.parameters(), create_graph=False,
+                     update_grad=(data_iter_step + 1) % accum_iter == 0)
+         if (data_iter_step + 1) % accum_iter == 0:
+             optimizer.zero_grad()
+
+         torch.cuda.synchronize()
+
+         metric_logger.update(loss=loss_value)
+         min_lr = 10.
+         max_lr = 0.
+         for group in optimizer.param_groups:
+             min_lr = min(min_lr, group["lr"])
+             max_lr = max(max_lr, group["lr"])
+
+         metric_logger.update(lr=max_lr)
+
+         loss_value_reduce = misc.all_reduce_mean(loss_value)
+         if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
+             """ We use epoch_1000x as the x-axis in tensorboard.
+             This calibrates different curves when batch size changes.
+             """
+             epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
+             log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)
+             log_writer.add_scalar('lr', max_lr, epoch_1000x)
+
+     # gather the stats from all processes
+     metric_logger.synchronize_between_processes()
+     print("Averaged stats:", metric_logger)
+     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+ @torch.no_grad()
+ def evaluate(data_loader, model, device, task, epoch, mode, num_class):
+     criterion = torch.nn.CrossEntropyLoss()
+
+     metric_logger = misc.MetricLogger(delimiter="  ")
+     header = 'Test:'
+
+     if not os.path.exists(task):
+         os.makedirs(task)
+
+     prediction_decode_list = []
+     prediction_list = []
+     true_label_decode_list = []
+     true_label_onehot_list = []
+
+     # switch to evaluation mode
+     model.eval()
+
+     for batch in metric_logger.log_every(data_loader, 10, header):
+         images = batch[0]
+         target = batch[-1]
+         images = images.to(device, non_blocking=True)
+         target = target.to(device, non_blocking=True)
+         true_label = F.one_hot(target.to(torch.int64), num_classes=num_class)
+
+         # compute output
+         with torch.cuda.amp.autocast():
+             output = model(images)
+             loss = criterion(output, target)
+             prediction_softmax = nn.Softmax(dim=1)(output)
+             _, prediction_decode = torch.max(prediction_softmax, 1)
+             _, true_label_decode = torch.max(true_label, 1)
+
+             prediction_decode_list.extend(prediction_decode.cpu().detach().numpy())
+             true_label_decode_list.extend(true_label_decode.cpu().detach().numpy())
+             true_label_onehot_list.extend(true_label.cpu().detach().numpy())
+             prediction_list.extend(prediction_softmax.cpu().detach().numpy())
+
+         acc1, _ = accuracy(output, target, topk=(1, 2))
+
+         batch_size = images.shape[0]
+         metric_logger.update(loss=loss.item())
+         metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
+     # gather the stats from all processes
+     true_label_decode_list = np.array(true_label_decode_list)
+     prediction_decode_list = np.array(prediction_decode_list)
+     confusion_matrix = multilabel_confusion_matrix(true_label_decode_list, prediction_decode_list, labels=[i for i in range(num_class)])
+     acc, sensitivity, specificity, precision, G, F1, mcc = misc_measures(confusion_matrix)
+
+     auc_roc = roc_auc_score(true_label_onehot_list, prediction_list, multi_class='ovr', average='macro')
+     auc_pr = average_precision_score(true_label_onehot_list, prediction_list, average='macro')
+
+     metric_logger.synchronize_between_processes()
+
+     print('Sklearn Metrics - Acc: {:.4f} AUC-roc: {:.4f} AUC-pr: {:.4f} F1-score: {:.4f} MCC: {:.4f}'.format(acc, auc_roc, auc_pr, F1, mcc))
+     # appended row columns: acc, sensitivity, specificity, precision,
+     # auc_roc, auc_pr, F1, mcc, loss
+     results_path = task + '_metrics_{}.csv'.format(mode)
+     with open(results_path, mode='a', newline='', encoding='utf8') as cfa:
+         wf = csv.writer(cfa)
+         data2 = [[acc, sensitivity, specificity, precision, auc_roc, auc_pr, F1, mcc, metric_logger.loss]]
+         for i in data2:
+             wf.writerow(i)
+
+     if mode == 'test':
+         cm = ConfusionMatrix(actual_vector=true_label_decode_list, predict_vector=prediction_decode_list)
+         cm.plot(cmap=plt.cm.Blues, number_label=True, normalized=True, plot_lib="matplotlib")
+         plt.savefig(task + 'confusion_matrix_test.jpg', dpi=600, bbox_inches='tight')
+
+     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, auc_roc
+
main_finetune.py ADDED
@@ -0,0 +1,376 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ # Partly revised by YZ @UCL&Moorfields
+ # --------------------------------------------------------
+
+ import argparse
+ import datetime
+ import json
+ import numpy as np
+ import os
+ import time
+ from pathlib import Path
+
+ import torch
+ import torch.backends.cudnn as cudnn
+ from torch.utils.tensorboard import SummaryWriter
+
+ import timm
+
+ assert timm.__version__ == "0.3.2"  # version check
+ from timm.models.layers import trunc_normal_
+ from timm.data.mixup import Mixup
+ from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+
+ import util.lr_decay as lrd
+ import util.misc as misc
+ from util.datasets import build_dataset
+ from util.pos_embed import interpolate_pos_embed
+ from util.misc import NativeScalerWithGradNormCount as NativeScaler
+
+ import models_vit
+
+ from engine_finetune import train_one_epoch, evaluate
+
+
+ def get_args_parser():
+     parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False)
+     parser.add_argument('--batch_size', default=64, type=int,
+                         help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
+     parser.add_argument('--epochs', default=50, type=int)
+     parser.add_argument('--accum_iter', default=1, type=int,
+                         help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
+
+     # Model parameters
+     parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
+                         help='Name of model to train')
+
+     parser.add_argument('--input_size', default=224, type=int,
+                         help='images input size')
+
+     parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
+                         help='Drop path rate (default: 0.1)')
+
+     # Optimizer parameters
+     parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
+                         help='Clip gradient norm (default: None, no clipping)')
+     parser.add_argument('--weight_decay', type=float, default=0.05,
+                         help='weight decay (default: 0.05)')
+
+     parser.add_argument('--lr', type=float, default=None, metavar='LR',
+                         help='learning rate (absolute lr)')
+     parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
+                         help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
+     parser.add_argument('--layer_decay', type=float, default=0.75,
+                         help='layer-wise lr decay from ELECTRA/BEiT')
+
+     parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
+                         help='lower lr bound for cyclic schedulers that hit 0')
+
+     parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
+                         help='epochs to warmup LR')
+
+     # Augmentation parameters
+     parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
+                         help='Color jitter factor (enabled only when not using Auto/RandAug)')
+     parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
+                         help='Use AutoAugment policy, "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
+     parser.add_argument('--smoothing', type=float, default=0.1,
+                         help='Label smoothing (default: 0.1)')
+
+     # * Random Erase params
+     parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
+                         help='Random erase prob (default: 0.25)')
+     parser.add_argument('--remode', type=str, default='pixel',
+                         help='Random erase mode (default: "pixel")')
+     parser.add_argument('--recount', type=int, default=1,
+                         help='Random erase count (default: 1)')
+     parser.add_argument('--resplit', action='store_true', default=False,
+                         help='Do not random erase first (clean) augmentation split')
+
+     # * Mixup params
+     parser.add_argument('--mixup', type=float, default=0,
+                         help='mixup alpha, mixup enabled if > 0.')
+     parser.add_argument('--cutmix', type=float, default=0,
+                         help='cutmix alpha, cutmix enabled if > 0.')
+     parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
+                         help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
+     parser.add_argument('--mixup_prob', type=float, default=1.0,
+                         help='Probability of performing mixup or cutmix when either/both is enabled')
+     parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
+                         help='Probability of switching to cutmix when both mixup and cutmix enabled')
+     parser.add_argument('--mixup_mode', type=str, default='batch',
+                         help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
+
+     # * Finetuning params
+     parser.add_argument('--finetune', default='', type=str,
+                         help='finetune from checkpoint')
+     parser.add_argument('--task', default='', type=str,
+                         help='task name; used as the folder prefix for checkpoints, logs and metric files')
+     parser.add_argument('--global_pool', action='store_true')
+     parser.set_defaults(global_pool=True)
+     parser.add_argument('--cls_token', action='store_false', dest='global_pool',
+                         help='Use class token instead of global pool for classification')
+
+     # Dataset parameters
+     parser.add_argument('--data_path', default='/home/jupyter/Mor_DR_data/data/data/IDRID/Disease_Grading/', type=str,
+                         help='dataset path')
+     parser.add_argument('--nb_classes', default=1000, type=int,
+                         help='number of the classification types')
+
+     parser.add_argument('--output_dir', default='./output_dir',
+                         help='path where to save, empty for no saving')
+     parser.add_argument('--log_dir', default='./output_dir',
+                         help='path where to tensorboard log')
+     parser.add_argument('--device', default='cuda',
+                         help='device to use for training / testing')
+     parser.add_argument('--seed', default=0, type=int)
+     parser.add_argument('--resume', default='',
+                         help='resume from checkpoint')
+
+     parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
+                         help='start epoch')
+     parser.add_argument('--eval', action='store_true',
+                         help='Perform evaluation only')
+     parser.add_argument('--dist_eval', action='store_true', default=False,
+                         help='Enabling distributed evaluation (recommended during training for faster monitoring)')
+     parser.add_argument('--num_workers', default=10, type=int)
+     parser.add_argument('--pin_mem', action='store_true',
+                         help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
+     parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
+     parser.set_defaults(pin_mem=True)
+
+     # distributed training parameters
+     parser.add_argument('--world_size', default=1, type=int,
+                         help='number of distributed processes')
+     parser.add_argument('--local_rank', default=-1, type=int)
+     parser.add_argument('--dist_on_itp', action='store_true')
+     parser.add_argument('--dist_url', default='env://',
+                         help='url used to set up distributed training')
+
+     return parser
+
+
+ def main(args):
+     misc.init_distributed_mode(args)
+
+     print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
+     print("{}".format(args).replace(', ', ',\n'))
+
+     device = torch.device(args.device)
+
+     # fix the seed for reproducibility
+     seed = args.seed + misc.get_rank()
+     torch.manual_seed(seed)
+     np.random.seed(seed)
+
+     cudnn.benchmark = True
+
+     dataset_train = build_dataset(is_train='train', args=args)
+     dataset_val = build_dataset(is_train='val', args=args)
+     dataset_test = build_dataset(is_train='test', args=args)
+
+     if True:  # args.distributed:
+         num_tasks = misc.get_world_size()
+         global_rank = misc.get_rank()
+         sampler_train = torch.utils.data.DistributedSampler(
+             dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+         )
+         print("Sampler_train = %s" % str(sampler_train))
+         if args.dist_eval:
+             if len(dataset_val) % num_tasks != 0:
+                 print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
+                       'This will slightly alter validation results as extra duplicate entries are added to achieve '
+                       'equal num of samples per-process.')
+             sampler_val = torch.utils.data.DistributedSampler(
+                 dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True)  # shuffle=True to reduce monitor bias
+         else:
+             sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+
+         if args.dist_eval:
+             if len(dataset_test) % num_tasks != 0:
+                 print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
+                       'This will slightly alter validation results as extra duplicate entries are added to achieve '
+                       'equal num of samples per-process.')
+             sampler_test = torch.utils.data.DistributedSampler(
+                 dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=True)  # shuffle=True to reduce monitor bias
+         else:
+             sampler_test = torch.utils.data.SequentialSampler(dataset_test)
+
+     if global_rank == 0 and args.log_dir is not None and not args.eval:
+         os.makedirs(args.log_dir, exist_ok=True)
+         log_writer = SummaryWriter(log_dir=args.log_dir + args.task)
+     else:
+         log_writer = None
+
+     data_loader_train = torch.utils.data.DataLoader(
+         dataset_train, sampler=sampler_train,
+         batch_size=args.batch_size,
+         num_workers=args.num_workers,
+         pin_memory=args.pin_mem,
+         drop_last=True,
+     )
+
+     data_loader_val = torch.utils.data.DataLoader(
+         dataset_val, sampler=sampler_val,
+         batch_size=args.batch_size,
+         num_workers=args.num_workers,
+         pin_memory=args.pin_mem,
+         drop_last=False
+     )
+
+     data_loader_test = torch.utils.data.DataLoader(
+         dataset_test, sampler=sampler_test,
+         batch_size=args.batch_size,
+         num_workers=args.num_workers,
+         pin_memory=args.pin_mem,
+         drop_last=False
+     )
+
+     mixup_fn = None
+     mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
+     if mixup_active:
+         print("Mixup is activated!")
+         mixup_fn = Mixup(
+             mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
+             prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
+             label_smoothing=args.smoothing, num_classes=args.nb_classes)
+
+     model = models_vit.__dict__[args.model](
+         num_classes=args.nb_classes,
+         drop_path_rate=args.drop_path,
+         global_pool=args.global_pool,
+     )
+
+     if args.finetune and not args.eval:
+         checkpoint = torch.load(args.finetune, map_location='cpu')
+
+         print("Load pre-trained checkpoint from: %s" % args.finetune)
+         checkpoint_model = checkpoint['model']
+         state_dict = model.state_dict()
+         for k in ['head.weight', 'head.bias']:
+             if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
+                 print(f"Removing key {k} from pretrained checkpoint")
+                 del checkpoint_model[k]
+
+         # interpolate position embedding
+         interpolate_pos_embed(model, checkpoint_model)
+
+         # load pre-trained model
+         msg = model.load_state_dict(checkpoint_model, strict=False)
+         print(msg)
+
+         if args.global_pool:
+             assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
+         else:
+             assert set(msg.missing_keys) == {'head.weight', 'head.bias'}
+
+         # manually initialize fc layer
+         trunc_normal_(model.head.weight, std=2e-5)
+
+     model.to(device)
+
+     model_without_ddp = model
+     n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+     print("Model = %s" % str(model_without_ddp))
+     print('number of params (M): %.2f' % (n_parameters / 1.e6))
+
+     eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
+
+     if args.lr is None:  # only base_lr is specified
+         args.lr = args.blr * eff_batch_size / 256
+
+     print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
+     print("actual lr: %.2e" % args.lr)
+
+     print("accumulate grad iterations: %d" % args.accum_iter)
+     print("effective batch size: %d" % eff_batch_size)
+
+     if args.distributed:
+         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+         model_without_ddp = model.module
+
+     # build optimizer with layer-wise lr decay (lrd)
+     param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
+                                         no_weight_decay_list=model_without_ddp.no_weight_decay(),
+                                         layer_decay=args.layer_decay
+                                         )
+     optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
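+     # Note: with layer-wise decay each parameter group's lr is scaled by
+     # roughly layer_decay ** (num_layers - layer_id), so e.g. 0.65 keeps the
+     # early ViT blocks training far more slowly than the newly initialised
+     # head (see util/lr_decay.py).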
302
+ loss_scaler = NativeScaler()
303
+
304
+ if mixup_fn is not None:
305
+ # smoothing is handled with mixup label transform
306
+ criterion = SoftTargetCrossEntropy()
307
+ elif args.smoothing > 0.:
308
+ criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
309
+ else:
310
+ criterion = torch.nn.CrossEntropyLoss()
311
+
312
+ print("criterion = %s" % str(criterion))
313
+
314
+ misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
315
+
316
+ if args.eval:
317
+ test_stats,auc_roc = evaluate(data_loader_test, model, device, args.task, epoch=0, mode='test',num_class=args.nb_classes)
318
+ exit(0)
319
+
320
+ print(f"Start training for {args.epochs} epochs")
321
+ start_time = time.time()
322
+ max_accuracy = 0.0
323
+ max_auc = 0.0
324
+ for epoch in range(args.start_epoch, args.epochs):
325
+ if args.distributed:
326
+ data_loader_train.sampler.set_epoch(epoch)
327
+ train_stats = train_one_epoch(
328
+ model, criterion, data_loader_train,
329
+ optimizer, device, epoch, loss_scaler,
330
+ args.clip_grad, mixup_fn,
331
+ log_writer=log_writer,
332
+ args=args
333
+ )
334
+
335
+ val_stats,val_auc_roc = evaluate(data_loader_val, model, device,args.task,epoch, mode='val',num_class=args.nb_classes)
336
+ if max_auc<val_auc_roc:
337
+ max_auc = val_auc_roc
338
+
339
+ if args.output_dir:
340
+ misc.save_model(
341
+ args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
342
+ loss_scaler=loss_scaler, epoch=epoch)
343
+
344
+
345
+ if epoch==(args.epochs-1):
346
+ test_stats,auc_roc = evaluate(data_loader_test, model, device,args.task,epoch, mode='test',num_class=args.nb_classes)
347
+
348
+
349
+ if log_writer is not None:
350
+ log_writer.add_scalar('perf/val_acc1', val_stats['acc1'], epoch)
351
+ log_writer.add_scalar('perf/val_auc', val_auc_roc, epoch)
352
+ log_writer.add_scalar('perf/val_loss', val_stats['loss'], epoch)
353
+
354
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
355
+ 'epoch': epoch,
356
+ 'n_parameters': n_parameters}
357
+
358
+ if args.output_dir and misc.is_main_process():
359
+ if log_writer is not None:
360
+ log_writer.flush()
361
+ with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
362
+ f.write(json.dumps(log_stats) + "\n")
363
+
364
+
365
+ total_time = time.time() - start_time
366
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
367
+ print('Training time {}'.format(total_time_str))
368
+
369
+
370
+ if __name__ == '__main__':
371
+ args = get_args_parser()
372
+ args = args.parse_args()
373
+
374
+ if args.output_dir:
375
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
376
+ main(args)
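
The learning rate actually used above is derived from the base lr (`--blr`) via the linear scaling rule `lr = blr * eff_batch_size / 256`. A minimal sketch of that arithmetic (not part of the commit; batch size, accumulation steps, world size, and base lr below are illustrative assumptions, not values from this repo):

    # Illustrative numbers only; mirrors the scaling logic in main() above.
    batch_size, accum_iter, world_size = 16, 4, 2
    blr = 5e-4                                              # base lr, as passed via --blr
    eff_batch_size = batch_size * accum_iter * world_size   # 128
    lr = blr * eff_batch_size / 256                         # 2.50e-04 actual lr
    print("effective batch size: %d, actual lr: %.2e" % (eff_batch_size, lr))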
models_mae.py ADDED
@@ -0,0 +1,226 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ from functools import partial
+
+ import torch
+ import torch.nn as nn
+
+ from timm.models.vision_transformer import PatchEmbed, Block
+
+ from util.pos_embed import get_2d_sincos_pos_embed
+
+
+ class MaskedAutoencoderViT(nn.Module):
+     """ Masked Autoencoder with VisionTransformer backbone
+     """
+     def __init__(self, img_size=224, patch_size=16, in_chans=3,
+                  embed_dim=1024, depth=24, num_heads=16,
+                  decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
+                  mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
+         super().__init__()
+
+         # --------------------------------------------------------------------------
+         # MAE encoder specifics
+         self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
+         num_patches = self.patch_embed.num_patches
+
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding
+
+         self.blocks = nn.ModuleList([
+             Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
+             for i in range(depth)])
+         self.norm = norm_layer(embed_dim)
+         # --------------------------------------------------------------------------
+
+         # --------------------------------------------------------------------------
+         # MAE decoder specifics
+         self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
+
+         self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
+
+         self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding
+
+         self.decoder_blocks = nn.ModuleList([
+             Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
+             for i in range(decoder_depth)])
+
+         self.decoder_norm = norm_layer(decoder_embed_dim)
+         self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)  # decoder to patch
+         # --------------------------------------------------------------------------
+
+         self.norm_pix_loss = norm_pix_loss
+
+         self.initialize_weights()
+
+     def initialize_weights(self):
+         # initialization
+         # initialize (and freeze) pos_embed by sin-cos embedding
+         pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
+         self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+         decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
+         self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
+
+         # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
+         w = self.patch_embed.proj.weight.data
+         torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+         # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
+         torch.nn.init.normal_(self.cls_token, std=.02)
+         torch.nn.init.normal_(self.mask_token, std=.02)
+
+         # initialize nn.Linear and nn.LayerNorm
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             # we use xavier_uniform following official JAX ViT:
+             torch.nn.init.xavier_uniform_(m.weight)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     def patchify(self, imgs):
+         """
+         imgs: (N, 3, H, W)
+         x: (N, L, patch_size**2 *3)
+         """
+         p = self.patch_embed.patch_size[0]
+         assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
+
+         h = w = imgs.shape[2] // p
+         x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
+         x = torch.einsum('nchpwq->nhwpqc', x)
+         x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
+         return x
+
+     def unpatchify(self, x):
+         """
+         x: (N, L, patch_size**2 *3)
+         imgs: (N, 3, H, W)
+         """
+         p = self.patch_embed.patch_size[0]
+         h = w = int(x.shape[1]**.5)
+         assert h * w == x.shape[1]
+
+         x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
+         x = torch.einsum('nhwpqc->nchpwq', x)
+         imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
+         return imgs
+
+     def random_masking(self, x, mask_ratio):
+         """
+         Perform per-sample random masking by per-sample shuffling.
+         Per-sample shuffling is done by argsort random noise.
+         x: [N, L, D], sequence
+         """
+         N, L, D = x.shape  # batch, length, dim
+         len_keep = int(L * (1 - mask_ratio))
+
+         noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
+
+         # sort noise for each sample
+         ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
+         ids_restore = torch.argsort(ids_shuffle, dim=1)
+
+         # keep the first subset
+         ids_keep = ids_shuffle[:, :len_keep]
+         x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+
+         # generate the binary mask: 0 is keep, 1 is remove
+         mask = torch.ones([N, L], device=x.device)
+         mask[:, :len_keep] = 0
+         # unshuffle to get the binary mask
+         mask = torch.gather(mask, dim=1, index=ids_restore)
+
+         return x_masked, mask, ids_restore
+
+     def forward_encoder(self, x, mask_ratio):
+         # embed patches
+         x = self.patch_embed(x)
+
+         # add pos embed w/o cls token
+         x = x + self.pos_embed[:, 1:, :]
+
+         # masking: length -> length * mask_ratio
+         x, mask, ids_restore = self.random_masking(x, mask_ratio)
+
+         # append cls token
+         cls_token = self.cls_token + self.pos_embed[:, :1, :]
+         cls_tokens = cls_token.expand(x.shape[0], -1, -1)
+         x = torch.cat((cls_tokens, x), dim=1)
+
+         # apply Transformer blocks
+         for blk in self.blocks:
+             x = blk(x)
+         x = self.norm(x)
+
+         return x, mask, ids_restore
+
+     def forward_decoder(self, x, ids_restore):
+         # embed tokens
+         x = self.decoder_embed(x)
+
+         # append mask tokens to sequence
+         mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
+         x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
+         x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
+         x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token
+
+         # add pos embed
+         x = x + self.decoder_pos_embed
+
+         # apply Transformer blocks
+         for blk in self.decoder_blocks:
+             x = blk(x)
+         x = self.decoder_norm(x)
+
+         # predictor projection
+         x = self.decoder_pred(x)
+
+         # remove cls token
+         x = x[:, 1:, :]
+
+         return x
+
+     def forward_loss(self, imgs, pred, mask):
+         """
+         imgs: [N, 3, H, W]
+         pred: [N, L, p*p*3]
+         mask: [N, L], 0 is keep, 1 is remove,
+         """
+         target = self.patchify(imgs)
+         if self.norm_pix_loss:
+             mean = target.mean(dim=-1, keepdim=True)
+             var = target.var(dim=-1, keepdim=True)
+             target = (target - mean) / (var + 1.e-6)**.5
+
+         loss = (pred - target) ** 2
+         loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
+
+         loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
+         return loss
+
+     def forward(self, imgs, mask_ratio=0.75):
+         latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
+         pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
+         loss = self.forward_loss(imgs, pred, mask)
+         return loss, pred, mask
+
+
+
+ def mae_vit_large_patch16_dec512d8b(**kwargs):
+     model = MaskedAutoencoderViT(
+         patch_size=16, embed_dim=1024, depth=24, num_heads=16,
+         decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
+         mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     return model
+
+
+
+ # set recommended archs
+ mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
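
As a quick smoke test (not part of the commit), the masked autoencoder can be exercised end-to-end on random input. This sketch assumes models_mae.py is importable from the working directory and the pinned timm==0.3.2 is installed:

    import torch
    import models_mae

    # ViT-Large MAE; a 224x224 input with 16x16 patches gives L = 196 tokens.
    model = models_mae.mae_vit_large_patch16(norm_pix_loss=True)
    imgs = torch.randn(2, 3, 224, 224)               # two dummy RGB images
    loss, pred, mask = model(imgs, mask_ratio=0.75)  # mask 75% of patches
    print(loss.item())                               # scalar reconstruction loss
    print(pred.shape, mask.shape)                    # [2, 196, 768] and [2, 196]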
models_vit.py ADDED
@@ -0,0 +1,55 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ # Partly revised by YZ @UCL&Moorfields
+ # --------------------------------------------------------
+
+ from functools import partial
+
+ import torch
+ import torch.nn as nn
+
+ import timm.models.vision_transformer
+
+
+ class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
+     """ Vision Transformer with support for global average pooling
+     """
+     def __init__(self, global_pool=False, **kwargs):
+         super(VisionTransformer, self).__init__(**kwargs)
+
+         self.global_pool = global_pool
+         if self.global_pool:
+             norm_layer = kwargs['norm_layer']
+             embed_dim = kwargs['embed_dim']
+             self.fc_norm = norm_layer(embed_dim)
+
+             del self.norm  # remove the original norm
+
+     def forward_features(self, x):
+         B = x.shape[0]
+         x = self.patch_embed(x)
+
+         cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+         x = torch.cat((cls_tokens, x), dim=1)
+         x = x + self.pos_embed
+         x = self.pos_drop(x)
+
+         for blk in self.blocks:
+             x = blk(x)
+
+         if self.global_pool:
+             x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
+             outcome = self.fc_norm(x)
+         else:
+             x = self.norm(x)
+             outcome = x[:, 0]
+
+         return outcome
+
+
+ def vit_large_patch16(**kwargs):
+     model = VisionTransformer(
+         img_size=224, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     return model
+
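
A usage sketch for the classifier (not part of the commit; the five classes and drop path rate below are illustrative, matching the five IDRiD grades in the notebook rather than anything fixed by this file):

    import torch
    import models_vit

    # Build the fine-tuning backbone with global average pooling.
    model = models_vit.vit_large_patch16(num_classes=5, drop_path_rate=0.2, global_pool=True)
    logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # torch.Size([1, 5])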
pic/file_index.jpg ADDED
requirement.txt ADDED
@@ -0,0 +1,15 @@
+ opencv-python==4.5.3.56
+ pandas==0.25.3
+ Pillow==8.3.1
+ protobuf==3.17.3
+ pycm==3.2
+ pydicom==2.3.0
+ scikit-image==0.17.2
+ scikit-learn==0.24.2
+ scipy==1.5.4
+ tensorboard==2.6.0
+ tensorboard-data-server==0.6.1
+ tensorboard-plugin-wit==1.8.0
+ timm==0.3.2
+ tqdm==4.62.1
+
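
Note that torch and torchvision themselves are not pinned here, while timm==0.3.2 predates several PyTorch releases; on PyTorch 1.9+ importing timm 0.3.2 is known to fail because it relies on the since-removed torch._six module, so a small compatibility patch to timm is typically needed. A quick sanity-check sketch:

    # Sketch: confirm the pinned timm resolved and imports cleanly.
    # On newer PyTorch this import itself raises if timm 0.3.2 is unpatched.
    import timm
    assert timm.__version__ == "0.3.2", timm.__version__
    print("timm", timm.__version__, "ok")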
util/datasets.py ADDED
@@ -0,0 +1,54 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ # Partly revised by YZ @UCL&Moorfields
+ # --------------------------------------------------------
+
+ import os
+ from torchvision import datasets, transforms
+ from timm.data import create_transform
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+
+
+ def build_dataset(is_train, args):
+
+     transform = build_transform(is_train, args)
+     root = os.path.join(args.data_path, is_train)
+     dataset = datasets.ImageFolder(root, transform=transform)
+
+     return dataset
+
+
+ def build_transform(is_train, args):
+     mean = IMAGENET_DEFAULT_MEAN
+     std = IMAGENET_DEFAULT_STD
+     # train transform
+     if is_train == 'train':
+         # this should always dispatch to transforms_imagenet_train
+         transform = create_transform(
+             input_size=args.input_size,
+             is_training=True,
+             color_jitter=args.color_jitter,
+             auto_augment=args.aa,
+             interpolation='bicubic',
+             re_prob=args.reprob,
+             re_mode=args.remode,
+             re_count=args.recount,
+             mean=mean,
+             std=std,
+         )
+         return transform
+
+     # eval transform
+     t = []
+     if args.input_size <= 224:
+         crop_pct = 224 / 256
+     else:
+         crop_pct = 1.0
+     size = int(args.input_size / crop_pct)
+     t.append(
+         transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
+     )
+     t.append(transforms.CenterCrop(args.input_size))
+     t.append(transforms.ToTensor())
+     t.append(transforms.Normalize(mean, std))
+     return transforms.Compose(t)
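
A usage sketch (not part of the commit): build_dataset expects the ImageFolder layout <data_path>/<split>/<class>/<image> produced by the notebook above, and is_train is simply the split folder name. All paths and hyperparameter values below are illustrative assumptions, and the repo root is assumed to be on PYTHONPATH:

    from types import SimpleNamespace
    from util.datasets import build_dataset

    args = SimpleNamespace(data_path='./IDRiD_data/', input_size=224,
                           color_jitter=None, aa='rand-m9-mstd0.5-inc1',
                           reprob=0.25, remode='pixel', recount=1)
    train_set = build_dataset(is_train='train', args=args)  # RandAug train pipeline
    val_set = build_dataset(is_train='val', args=args)      # resize + center-crop eval pipeline
    print(len(train_set), train_set.classes)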
util/lr_decay.py ADDED
@@ -0,0 +1,70 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ # Partly revised by YZ @UCL&Moorfields
+ # --------------------------------------------------------
+
+ import json
+
+
+ def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
+     """
+     Parameter groups for layer-wise lr decay
+     Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
+     """
+     param_group_names = {}
+     param_groups = {}
+
+     num_layers = len(model.blocks) + 1
+
+     layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
+
+     for n, p in model.named_parameters():
+         if not p.requires_grad:
+             continue
+
+         # no decay: all 1D parameters and model specific ones
+         if p.ndim == 1 or n in no_weight_decay_list:
+             g_decay = "no_decay"
+             this_decay = 0.
+         else:
+             g_decay = "decay"
+             this_decay = weight_decay
+
+         layer_id = get_layer_id_for_vit(n, num_layers)
+         group_name = "layer_%d_%s" % (layer_id, g_decay)
+
+         if group_name not in param_group_names:
+             this_scale = layer_scales[layer_id]
+
+             param_group_names[group_name] = {
+                 "lr_scale": this_scale,
+                 "weight_decay": this_decay,
+                 "params": [],
+             }
+             param_groups[group_name] = {
+                 "lr_scale": this_scale,
+                 "weight_decay": this_decay,
+                 "params": [],
+             }
+
+         param_group_names[group_name]["params"].append(n)
+         param_groups[group_name]["params"].append(p)
+
+     # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
+
+     return list(param_groups.values())
+
+
+ def get_layer_id_for_vit(name, num_layers):
+     """
+     Assign a parameter with its layer id
+     Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
+     """
+     if name in ['cls_token', 'pos_embed']:
+         return 0
+     elif name.startswith('patch_embed'):
+         return 0
+     elif name.startswith('blocks'):
+         return int(name.split('.')[1]) + 1
+     else:
+         return num_layers
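
To see what the decay schedule does, the per-layer lr multipliers can be computed directly. A sketch with illustrative values (layer_decay=0.65 on a 24-block ViT-Large, so num_layers = 25; these are assumptions, not values fixed by this file):

    layer_decay, num_layers = 0.65, 25
    scales = [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]
    print(scales[0])   # ~2.1e-05: patch_embed / pos_embed barely move
    print(scales[-1])  # 1.0: the classification head trains at the full lr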
util/lr_sched.py ADDED
@@ -0,0 +1,20 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ # Partly revised by YZ @UCL&Moorfields
+ # --------------------------------------------------------
+
+ import math
+
+ def adjust_learning_rate(optimizer, epoch, args):
+     """Decay the learning rate with half-cycle cosine after warmup"""
+     if epoch < args.warmup_epochs:
+         lr = args.lr * epoch / args.warmup_epochs
+     else:
+         lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
+             (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
+     for param_group in optimizer.param_groups:
+         if "lr_scale" in param_group:
+             param_group["lr"] = lr * param_group["lr_scale"]
+         else:
+             param_group["lr"] = lr
+     return lr
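
The schedule can be inspected by evaluating the same formula standalone; all hyperparameters below are illustrative, not values from this commit:

    import math

    lr_base, min_lr, warmup, total = 2.5e-4, 1e-6, 10, 50
    for epoch in [0, 5, 10, 30, 49]:
        if epoch < warmup:
            lr = lr_base * epoch / warmup          # linear warmup from 0
        else:
            lr = min_lr + (lr_base - min_lr) * 0.5 * \
                (1. + math.cos(math.pi * (epoch - warmup) / (total - warmup)))
        print(epoch, "%.2e" % lr)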
util/misc.py ADDED
@@ -0,0 +1,335 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ # Partly revised by YZ @UCL&Moorfields
+ # --------------------------------------------------------
+
+ import builtins
+ import datetime
+ import os
+ import time
+ from collections import defaultdict, deque
+ from pathlib import Path
+
+ import torch
+ import torch.distributed as dist
+ from math import inf  # was `from torch._six import inf`; torch._six is removed in recent PyTorch and math.inf is identical
+
+
+ class SmoothedValue(object):
+     """Track a series of values and provide access to smoothed values over a
+     window or the global series average.
+     """
+
+     def __init__(self, window_size=20, fmt=None):
+         if fmt is None:
+             fmt = "{median:.4f} ({global_avg:.4f})"
+         self.deque = deque(maxlen=window_size)
+         self.total = 0.0
+         self.count = 0
+         self.fmt = fmt
+
+     def update(self, value, n=1):
+         self.deque.append(value)
+         self.count += n
+         self.total += value * n
+
+     def synchronize_between_processes(self):
+         """
+         Warning: does not synchronize the deque!
+         """
+         if not is_dist_avail_and_initialized():
+             return
+         t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+         dist.barrier()
+         dist.all_reduce(t)
+         t = t.tolist()
+         self.count = int(t[0])
+         self.total = t[1]
+
+     @property
+     def median(self):
+         d = torch.tensor(list(self.deque))
+         return d.median().item()
+
+     @property
+     def avg(self):
+         d = torch.tensor(list(self.deque), dtype=torch.float32)
+         return d.mean().item()
+
+     @property
+     def global_avg(self):
+         return self.total / self.count
+
+     @property
+     def max(self):
+         return max(self.deque)
+
+     @property
+     def value(self):
+         return self.deque[-1]
+
+     def __str__(self):
+         return self.fmt.format(
+             median=self.median,
+             avg=self.avg,
+             global_avg=self.global_avg,
+             max=self.max,
+             value=self.value)
+
+
+ class MetricLogger(object):
+     def __init__(self, delimiter="\t"):
+         self.meters = defaultdict(SmoothedValue)
+         self.delimiter = delimiter
+
+     def update(self, **kwargs):
+         for k, v in kwargs.items():
+             if v is None:
+                 continue
+             if isinstance(v, torch.Tensor):
+                 v = v.item()
+             assert isinstance(v, (float, int))
+             self.meters[k].update(v)
+
+     def __getattr__(self, attr):
+         if attr in self.meters:
+             return self.meters[attr]
+         if attr in self.__dict__:
+             return self.__dict__[attr]
+         raise AttributeError("'{}' object has no attribute '{}'".format(
+             type(self).__name__, attr))
+
+     def __str__(self):
+         loss_str = []
+         for name, meter in self.meters.items():
+             loss_str.append(
+                 "{}: {}".format(name, str(meter))
+             )
+         return self.delimiter.join(loss_str)
+
+     def synchronize_between_processes(self):
+         for meter in self.meters.values():
+             meter.synchronize_between_processes()
+
+     def add_meter(self, name, meter):
+         self.meters[name] = meter
+
+     def log_every(self, iterable, print_freq, header=None):
+         i = 0
+         if not header:
+             header = ''
+         start_time = time.time()
+         end = time.time()
+         iter_time = SmoothedValue(fmt='{avg:.4f}')
+         data_time = SmoothedValue(fmt='{avg:.4f}')
+         space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+         log_msg = [
+             header,
+             '[{0' + space_fmt + '}/{1}]',
+             'eta: {eta}',
+             '{meters}',
+             'time: {time}',
+             'data: {data}'
+         ]
+         if torch.cuda.is_available():
+             log_msg.append('max mem: {memory:.0f}')
+         log_msg = self.delimiter.join(log_msg)
+         MB = 1024.0 * 1024.0
+         for obj in iterable:
+             data_time.update(time.time() - end)
+             yield obj
+             iter_time.update(time.time() - end)
+             if i % print_freq == 0 or i == len(iterable) - 1:
+                 eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                 eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                 if torch.cuda.is_available():
+                     print(log_msg.format(
+                         i, len(iterable), eta=eta_string,
+                         meters=str(self),
+                         time=str(iter_time), data=str(data_time),
+                         memory=torch.cuda.max_memory_allocated() / MB))
+                 else:
+                     print(log_msg.format(
+                         i, len(iterable), eta=eta_string,
+                         meters=str(self),
+                         time=str(iter_time), data=str(data_time)))
+             i += 1
+             end = time.time()
+         total_time = time.time() - start_time
+         total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+         print('{} Total time: {} ({:.4f} s / it)'.format(
+             header, total_time_str, total_time / len(iterable)))
+
+
+ def setup_for_distributed(is_master):
+     """
+     This function disables printing when not in master process
+     """
+     builtin_print = builtins.print
+
+     def print(*args, **kwargs):
+         force = kwargs.pop('force', False)
+         force = force or (get_world_size() > 8)
+         if is_master or force:
+             now = datetime.datetime.now().time()
+             builtin_print('[{}] '.format(now), end='')  # print with time stamp
+             builtin_print(*args, **kwargs)
+
+     builtins.print = print
+
+
+ def is_dist_avail_and_initialized():
+     if not dist.is_available():
+         return False
+     if not dist.is_initialized():
+         return False
+     return True
+
+
+ def get_world_size():
+     if not is_dist_avail_and_initialized():
+         return 1
+     return dist.get_world_size()
+
+
+ def get_rank():
+     if not is_dist_avail_and_initialized():
+         return 0
+     return dist.get_rank()
+
+
+ def is_main_process():
+     return get_rank() == 0
+
+
+ def save_on_master(*args, **kwargs):
+     if is_main_process():
+         torch.save(*args, **kwargs)
+
+
+ def init_distributed_mode(args):
+     if args.dist_on_itp:
+         args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+         args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+         args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+         args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
+         os.environ['LOCAL_RANK'] = str(args.gpu)
+         os.environ['RANK'] = str(args.rank)
+         os.environ['WORLD_SIZE'] = str(args.world_size)
+         # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
+     elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+         args.rank = int(os.environ["RANK"])
+         args.world_size = int(os.environ['WORLD_SIZE'])
+         args.gpu = int(os.environ['LOCAL_RANK'])
+     elif 'SLURM_PROCID' in os.environ:
+         args.rank = int(os.environ['SLURM_PROCID'])
+         args.gpu = args.rank % torch.cuda.device_count()
+     else:
+         print('Not using distributed mode')
+         setup_for_distributed(is_master=True)  # hack
+         args.distributed = False
+         return
+
+     args.distributed = True
+
+     torch.cuda.set_device(args.gpu)
+     args.dist_backend = 'nccl'
+     print('| distributed init (rank {}): {}, gpu {}'.format(
+         args.rank, args.dist_url, args.gpu), flush=True)
+     torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                          world_size=args.world_size, rank=args.rank)
+     torch.distributed.barrier()
+     setup_for_distributed(args.rank == 0)
+
+
+ class NativeScalerWithGradNormCount:
+     state_dict_key = "amp_scaler"
+
+     def __init__(self):
+         self._scaler = torch.cuda.amp.GradScaler()
+
+     def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
+         self._scaler.scale(loss).backward(create_graph=create_graph)
+         if update_grad:
+             if clip_grad is not None:
+                 assert parameters is not None
+                 self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
+                 norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
+             else:
+                 self._scaler.unscale_(optimizer)
+                 norm = get_grad_norm_(parameters)
+             self._scaler.step(optimizer)
+             self._scaler.update()
+         else:
+             norm = None
+         return norm
+
+     def state_dict(self):
+         return self._scaler.state_dict()
+
+     def load_state_dict(self, state_dict):
+         self._scaler.load_state_dict(state_dict)
+
+
+ def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
+     if isinstance(parameters, torch.Tensor):
+         parameters = [parameters]
+     parameters = [p for p in parameters if p.grad is not None]
+     norm_type = float(norm_type)
+     if len(parameters) == 0:
+         return torch.tensor(0.)
+     device = parameters[0].grad.device
+     if norm_type == inf:
+         total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
+     else:
+         total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
+     return total_norm
+
+
+ def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
+     output_dir = Path(args.output_dir)
+     epoch_name = str(epoch)
+     if loss_scaler is not None:
+         checkpoint_paths = [args.task + 'checkpoint-best.pth']
+         for checkpoint_path in checkpoint_paths:
+             to_save = {
+                 'model': model_without_ddp.state_dict(),
+                 'optimizer': optimizer.state_dict(),
+                 'epoch': epoch,
+                 'scaler': loss_scaler.state_dict(),
+                 'args': args,
+             }
+
+             save_on_master(to_save, checkpoint_path)
+     else:
+         client_state = {'epoch': epoch}
+         model.save_checkpoint(save_dir=args.task, tag="checkpoint-best", client_state=client_state)
+
+
+ def load_model(args, model_without_ddp, optimizer, loss_scaler):
+     if args.resume:
+         if args.resume.startswith('https'):
+             checkpoint = torch.hub.load_state_dict_from_url(
+                 args.resume, map_location='cpu', check_hash=True)
+         else:
+             checkpoint = torch.load(args.resume, map_location='cpu')
+         model_without_ddp.load_state_dict(checkpoint['model'])
+         print("Resume checkpoint %s" % args.resume)
+         if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
+             optimizer.load_state_dict(checkpoint['optimizer'])
+             args.start_epoch = checkpoint['epoch'] + 1
+             if 'scaler' in checkpoint:
+                 loss_scaler.load_state_dict(checkpoint['scaler'])
+             print("With optim & sched!")
+
+
+ def all_reduce_mean(x):
+     world_size = get_world_size()
+     if world_size > 1:
+         x_reduce = torch.tensor(x).cuda()
+         dist.all_reduce(x_reduce)
+         x_reduce /= world_size
+         return x_reduce.item()
+     else:
+         return x
+
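
A minimal sketch of how the AMP loss scaler above is typically driven with gradient accumulation (not part of the commit; the toy model, data, and accum_iter=2 are illustrative, and a CUDA device is assumed):

    import torch
    from util.misc import NativeScalerWithGradNormCount

    model = torch.nn.Linear(4, 2).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    loss_scaler = NativeScalerWithGradNormCount()
    accum_iter = 2

    for step in range(4):
        samples = torch.randn(8, 4, device='cuda')
        targets = torch.randint(0, 2, (8,), device='cuda')
        with torch.cuda.amp.autocast():
            loss = torch.nn.functional.cross_entropy(model(samples), targets)
        update = (step + 1) % accum_iter == 0   # step the optimizer every accum_iter batches
        loss_scaler(loss / accum_iter, optimizer, clip_grad=None,
                    parameters=model.parameters(), update_grad=update)
        if update:
            optimizer.zero_grad()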
util/pos_embed.py ADDED
@@ -0,0 +1,92 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ # Partly revised by YZ @UCL&Moorfields
+ # --------------------------------------------------------
+
+ import numpy as np
+
+ import torch
+
+ # --------------------------------------------------------
+ # 2D sine-cosine position embedding
+ # References:
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
+ # --------------------------------------------------------
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+     """
+     grid_size: int of the grid height and width
+     return:
+     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+     """
+     grid_h = np.arange(grid_size, dtype=np.float32)
+     grid_w = np.arange(grid_size, dtype=np.float32)
+     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+     grid = np.stack(grid, axis=0)
+
+     grid = grid.reshape([2, 1, grid_size, grid_size])
+     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+     if cls_token:
+         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+     return pos_embed
+
+
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+     assert embed_dim % 2 == 0
+
+     # use half of dimensions to encode grid_h
+     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+     return emb
+
+
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+     """
+     embed_dim: output dimension for each position
+     pos: a list of positions to be encoded: size (M,)
+     out: (M, D)
+     """
+     assert embed_dim % 2 == 0
+     omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float64: the old np.float alias is removed in recent NumPy
+     omega /= embed_dim / 2.
+     omega = 1. / 10000**omega  # (D/2,)
+
+     pos = pos.reshape(-1)  # (M,)
+     out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+     emb_sin = np.sin(out)  # (M, D/2)
+     emb_cos = np.cos(out)  # (M, D/2)
+
+     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+     return emb
+
+
+ # --------------------------------------------------------
+ # Interpolate position embeddings for high-resolution
+ # References:
+ # DeiT: https://github.com/facebookresearch/deit
+ # --------------------------------------------------------
+ def interpolate_pos_embed(model, checkpoint_model):
+     if 'pos_embed' in checkpoint_model:
+         pos_embed_checkpoint = checkpoint_model['pos_embed']
+         embedding_size = pos_embed_checkpoint.shape[-1]
+         num_patches = model.patch_embed.num_patches
+         num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+         # height (== width) for the checkpoint position embedding
+         orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+         # height (== width) for the new position embedding
+         new_size = int(num_patches ** 0.5)
+         # class_token and dist_token are kept unchanged
+         if orig_size != new_size:
+             print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+             extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+             # only the position tokens are interpolated
+             pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+             pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+             pos_tokens = torch.nn.functional.interpolate(
+                 pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+             pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+             new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+             checkpoint_model['pos_embed'] = new_pos_embed
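
The fixed table produced by get_2d_sincos_pos_embed has one row per patch plus an all-zero cls row. A sketch for the 224/16 setup used throughout this repo (not part of the commit):

    from util.pos_embed import get_2d_sincos_pos_embed

    # 224x224 image with 16x16 patches -> 14x14 grid; cls token prepended.
    pos_embed = get_2d_sincos_pos_embed(embed_dim=1024, grid_size=14, cls_token=True)
    print(pos_embed.shape)  # (197, 1024)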