ckadirt committed
Commit 626cbe5 · verified · 1 Parent(s): 87518f7

Upload folder using huggingface_hub

.gitignore ADDED
@@ -0,0 +1,164 @@
1
+ wandb/
2
+ train_logs/
3
+ slurms/
4
+
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+ cover/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ .pybuilder/
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ # For a library or package, you might want to ignore these files since the code is
91
+ # intended to run in multiple environments; otherwise, check them in:
92
+ # .python-version
93
+
94
+ # pipenv
95
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
97
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
98
+ # install all needed dependencies.
99
+ #Pipfile.lock
100
+
101
+ # poetry
102
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
104
+ # commonly ignored for libraries.
105
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106
+ #poetry.lock
107
+
108
+ # pdm
109
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110
+ #pdm.lock
111
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112
+ # in version control.
113
+ # https://pdm.fming.dev/#use-with-ide
114
+ .pdm.toml
115
+
116
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117
+ __pypackages__/
118
+
119
+ # Celery stuff
120
+ celerybeat-schedule
121
+ celerybeat.pid
122
+
123
+ # SageMath parsed files
124
+ *.sage.py
125
+
126
+ # Environments
127
+ .env
128
+ .venv
129
+ env/
130
+ venv/
131
+ ENV/
132
+ env.bak/
133
+ venv.bak/
134
+
135
+ # Spyder project settings
136
+ .spyderproject
137
+ .spyproject
138
+
139
+ # Rope project settings
140
+ .ropeproject
141
+
142
+ # mkdocs documentation
143
+ /site
144
+
145
+ # mypy
146
+ .mypy_cache/
147
+ .dmypy.json
148
+ dmypy.json
149
+
150
+ # Pyre type checker
151
+ .pyre/
152
+
153
+ # pytype static type analyzer
154
+ .pytype/
155
+
156
+ # Cython debug symbols
157
+ cython_debug/
158
+
159
+ # PyCharm
160
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
163
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164
+ #.idea/
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 MedARC
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,11 @@
1
+ # MindEyeV2
2
+
3
+ In progress.
4
+
5
+ 1. Download all of https://huggingface.co/datasets/pscotti/mindeyev2 and place the files in a single folder. You will need to set the "data_path" variable to the path of this folder (a download sketch follows this README).
6
+
7
+ 2. Run setup.sh to create a new "fmri" conda environment.
8
+
9
+ 3. Activate the conda environment with "conda activate fmri".
10
+
11
+ 4. Run Train.ipynb or Train.py (they contain the same code).
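
A minimal sketch of step 1 for anyone who prefers to script the download with huggingface_hub rather than fetch the files manually; the repo id comes from the link above, while the local folder is a placeholder you should replace with your own path:

from huggingface_hub import snapshot_download

# Hypothetical download helper (not part of this commit): pull the full
# pscotti/mindeyev2 dataset into a local folder, then reuse that folder as data_path.
data_path = snapshot_download(
    repo_id="pscotti/mindeyev2",
    repo_type="dataset",
    local_dir="/path/to/mindeyev2_dataset",  # placeholder; any local folder works
)
print("pass this as --data_path:", data_path)

The same folder is what src/accel.slurm passes on the cluster via --data_path=/fsx/proj-fmri/shared/mindeyev2_dataset.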
src/Train.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/Train.py ADDED
@@ -0,0 +1,708 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[1]:
5
+
6
+
7
+ # # Code to convert this notebook to .py if you want to run it via command line or with Slurm
8
+ # from subprocess import call
9
+ # command = "jupyter nbconvert Train.ipynb --to python"
10
+ # call(command,shell=True)
11
+
12
+
13
+ # # Import packages & functions
14
+
15
+ # In[2]:
16
+
17
+
18
+ import os
19
+ import sys
20
+ import json
21
+ import argparse
22
+ import numpy as np
23
+ import time
24
+ import random
25
+ import h5py
26
+ from tqdm import tqdm
27
+
28
+ import webdataset as wds
29
+ import gc
30
+
31
+ import matplotlib.pyplot as plt
32
+ import torch
33
+ import torch.nn as nn
34
+ from torchvision import transforms
35
+
36
+ # allow TF32 matmuls, which are faster than full-precision float32 on recent GPUs
37
+ torch.backends.cuda.matmul.allow_tf32 = True
38
+
39
+ # custom functions #
40
+ import utils
41
+
42
+
43
+ # In[ ]:
44
+
45
+
46
+ local_rank = os.getenv('RANK')
47
+ if local_rank is None:
48
+ local_rank = 0
49
+ else:
50
+ local_rank = int(local_rank)
51
+ print("LOCAL RANK ", local_rank)
52
+
53
+ ### Single-GPU config ###
54
+ ## Feel free to uncomment the commented lines below and comment out all the multi-GPU config code to simplify things for single-GPU training
55
+ # from accelerate import Accelerator
56
+ # num_devices = torch.cuda.device_count()
57
+ # if num_devices==0: num_devices = 1
58
+ # accelerator = Accelerator(split_batches=False)
59
+ # global_batch_size = 128
60
+
61
+ ### Multi-GPU config ###
62
+ from accelerate import Accelerator, DeepSpeedPlugin
63
+ num_devices = torch.cuda.device_count()
64
+ if num_devices==0: num_devices = 1
65
+ if num_devices <= 1 and utils.is_interactive():
66
+ # can emulate a distributed environment for deepspeed to work in jupyter notebook
67
+ os.environ["MASTER_ADDR"] = "localhost"
68
+ os.environ["MASTER_PORT"] = str(np.random.randint(10000)+9000)
69
+ os.environ["RANK"] = "0"
70
+ os.environ["LOCAL_RANK"] = "0"
71
+ os.environ["WORLD_SIZE"] = "1"
72
+ os.environ["GLOBAL_BATCH_SIZE"] = "128" # set this to your batch size!
73
+ global_batch_size = os.environ["GLOBAL_BATCH_SIZE"]
74
+
75
+ # alter the deepspeed config according to your global and local batch size
76
+ if local_rank == 0:
77
+ with open('deepspeed_config_stage2.json', 'r') as file:
78
+ config = json.load(file)
79
+ config['train_batch_size'] = int(os.environ["GLOBAL_BATCH_SIZE"])
80
+ config['train_micro_batch_size_per_gpu'] = int(os.environ["GLOBAL_BATCH_SIZE"]) // num_devices
81
+ with open('deepspeed_config_stage2.json', 'w') as file:
82
+ json.dump(config, file)
83
+ else:
84
+ # give some time for the local_rank=0 gpu to prep new deepspeed config file
85
+ time.sleep(10)
86
+ deepspeed_plugin = DeepSpeedPlugin("deepspeed_config_stage2.json")
87
+ accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)
88
+
89
+
90
+ # In[ ]:
91
+
92
+
93
+ print("PID of this process =",os.getpid())
94
+ device = accelerator.device
95
+ print("device:",device)
96
+ num_workers = num_devices
97
+ print(accelerator.state)
98
+ world_size = accelerator.state.num_processes
99
+ distributed = not accelerator.state.distributed_type == 'NO'
100
+ print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size)
101
+ print = accelerator.print # only print if local_rank=0
102
+
103
+
104
+ # # Configurations
105
+
106
+ # In[3]:
107
+
108
+
109
+ # if running this interactively, can specify jupyter_args here for argparser to use
110
+ if utils.is_interactive():
111
+ # Example use
112
+ jupyter_args = f"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \
113
+ --model_name=test \
114
+ --subj=1 --batch_size={global_batch_size} --n_samples_save=0 \
115
+ --max_lr=3e-4 --mixup_pct=.66 --num_epochs=12 --ckpt_interval=999 --no-use_image_aug"
116
+
117
+ jupyter_args = jupyter_args.split()
118
+ print(jupyter_args)
119
+
120
+ from IPython.display import clear_output # function to clear print outputs in cell
121
+ get_ipython().run_line_magic('load_ext', 'autoreload')
122
+ # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
123
+ get_ipython().run_line_magic('autoreload', '2')
124
+
125
+
126
+ # In[4]:
127
+
128
+
129
+ parser = argparse.ArgumentParser(description="Model Training Configuration")
130
+ parser.add_argument(
131
+ "--model_name", type=str, default="testing",
132
+ help="name of model, used for ckpt saving and wandb logging (if enabled)",
133
+ )
134
+ parser.add_argument(
135
+ "--data_path", type=str, default="/fsx/proj-fmri/shared/natural-scenes-dataset",
136
+ help="Path to where NSD data is stored / where to download it to",
137
+ )
138
+ parser.add_argument(
139
+ "--subj",type=int, default=1, choices=[1,2,5,7],
140
+ )
141
+ parser.add_argument(
142
+ "--batch_size", type=int, default=32,
143
+ help="Batch size can be increased by 10x if only training v2c and not diffusion prior",
144
+ )
145
+ parser.add_argument(
146
+ "--wandb_log",action=argparse.BooleanOptionalAction,default=False,
147
+ help="whether to log to wandb",
148
+ )
149
+ parser.add_argument(
150
+ "--resume_from_ckpt",action=argparse.BooleanOptionalAction,default=False,
151
+ help="if not using wandb and want to resume from a ckpt",
152
+ )
153
+ parser.add_argument(
154
+ "--wandb_project",type=str,default="stability",
155
+ help="wandb project name",
156
+ )
157
+ parser.add_argument(
158
+ "--mixup_pct",type=float,default=.33,
159
+ help="proportion of way through training when to switch from BiMixCo to SoftCLIP",
160
+ )
161
+ parser.add_argument(
162
+ "--use_image_aug",action=argparse.BooleanOptionalAction,default=True,
163
+ help="whether to use image augmentation",
164
+ )
165
+ parser.add_argument(
166
+ "--num_epochs",type=int,default=240,
167
+ help="number of epochs of training",
168
+ )
169
+ parser.add_argument(
170
+ "--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'],
171
+ )
172
+ parser.add_argument(
173
+ "--ckpt_saving",action=argparse.BooleanOptionalAction,default=True,
174
+ )
175
+ parser.add_argument(
176
+ "--ckpt_interval",type=int,default=5,
177
+ help="save backup ckpt and reconstruct every x epochs",
178
+ )
179
+ parser.add_argument(
180
+ "--seed",type=int,default=42,
181
+ )
182
+ parser.add_argument(
183
+ "--max_lr",type=float,default=3e-4,
184
+ )
185
+ parser.add_argument(
186
+ "--n_samples_save",type=int,default=0,choices=[0,1],
187
+ help="Number of reconstructions for monitoring progress, 0 will speed up training",
188
+ )
189
+
190
+ if utils.is_interactive():
191
+ args = parser.parse_args(jupyter_args)
192
+ else:
193
+ args = parser.parse_args()
194
+
195
+ # create global variables without the args prefix
196
+ for attribute_name in vars(args).keys():
197
+ globals()[attribute_name] = getattr(args, attribute_name)
198
+
199
+ print("global batch_size", batch_size)
200
+ batch_size = int(batch_size / num_devices)
201
+ print("batch_size", batch_size)
202
+
203
+
204
+ # In[5]:
205
+
206
+
207
+ outdir = os.path.abspath(f'../train_logs/{model_name}')
208
+ if not os.path.exists(outdir):
209
+ os.makedirs(outdir,exist_ok=True)
210
+ if use_image_aug:
211
+ import kornia
212
+ from kornia.augmentation.container import AugmentationSequential
213
+ img_augment = AugmentationSequential(
214
+ kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),
215
+ kornia.augmentation.Resize((224, 224)),
216
+ kornia.augmentation.RandomHorizontalFlip(p=0.3),
217
+ kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),
218
+ kornia.augmentation.RandomGrayscale(p=0.3),
219
+ same_on_batch=False,
220
+ data_keys=["input"],
221
+ )
222
+
223
+
224
+ # # Prep data, models, and dataloaders
225
+
226
+ # ## Dataloader
227
+
228
+ # In[6]:
229
+
230
+
231
+ if subj==1:
232
+ num_train = 24958
233
+ num_test = 2770
234
+ test_batch_size = num_test
235
+
236
+ def my_split_by_node(urls): return urls
237
+
238
+ train_url = f"{data_path}/wds/subj0{subj}/train/" + "{0..36}.tar"
239
+ print(train_url)
240
+
241
+ train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\
242
+ .shuffle(750, initial=1500, rng=random.Random(42))\
243
+ .decode("torch")\
244
+ .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
245
+ .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
246
+ train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True)
247
+
248
+ test_url = f"{data_path}/wds/subj0{subj}/test/" + "0.tar"
249
+ print(test_url)
250
+
251
+ test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\
252
+ .shuffle(750, initial=1500, rng=random.Random(42))\
253
+ .decode("torch")\
254
+ .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
255
+ .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
256
+ test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=False, pin_memory=True)
257
+
258
+
259
+ # ### check dataloaders are working
260
+
261
+ # In[7]:
262
+
263
+
264
+ # test_indices = []
265
+ # test_images = []
266
+ # for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
267
+ # test_indices = np.append(test_indices, behav[:,0,5].numpy())
268
+ # test_images = np.append(test_images, behav[:,0,0].numpy())
269
+ # test_indices = test_indices.astype(np.int16)
270
+ # print(test_i, (test_i+1) * test_batch_size, len(test_indices))
271
+ # print("---\n")
272
+
273
+ # train_indices = []
274
+ # train_images = []
275
+ # for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
276
+ # train_indices = np.append(train_indices, behav[:,0,5].long().numpy())
277
+ # train_images = np.append(train_images, behav[:,0,0].numpy())
278
+ # train_indices = train_indices.astype(np.int16)
279
+ # print(train_i, (train_i+1) * batch_size, len(train_indices))
280
+
281
+
282
+ # ## Load voxel betas and images
283
+
284
+ # In[8]:
285
+
286
+
287
+ # load betas
288
+ f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r')
289
+ voxels = f['betas'][:]
290
+ print(f"subj0{subj} betas loaded into memory")
291
+ voxels = torch.Tensor(voxels).to("cpu").half()
292
+ if subj==1:
293
+ voxels = torch.hstack((voxels, torch.zeros((len(voxels), 5))))
294
+ print("voxels", voxels.shape)
295
+ num_voxels = voxels.shape[-1]
296
+
297
+ # load orig images
298
+ f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
299
+ images = f['images'][:]
300
+ images = torch.Tensor(images).to("cpu").half()
301
+ print("images", images.shape)
302
+
303
+
304
+ # In[9]:
305
+
306
+
307
+ from models import Clipper
308
+ eva02_model = Clipper("ViT-L/14", device=torch.device(f"cuda:{local_rank}"), hidden_state=True, norm_embs=True) # note: despite the variable name, this loads CLIP ViT-L/14, not EVA02
309
+
310
+ clip_seq_dim = 257
311
+ clip_emb_dim = 768
312
+ hidden_dim = 4096
313
+
314
+
315
+ # In[10]:
316
+
317
+
318
+ class MindEyeModule(nn.Module):
319
+ def __init__(self):
320
+ super(MindEyeModule, self).__init__()
321
+ def forward(self, x):
322
+ return x
323
+
324
+ model = MindEyeModule()
325
+ model
326
+
327
+
328
+ # In[11]:
329
+
330
+
331
+ class RidgeRegression(torch.nn.Module):
332
+ # make sure to add weight_decay when initializing optimizer
333
+ def __init__(self, input_size, out_features):
334
+ super(RidgeRegression, self).__init__()
335
+ self.linear = torch.nn.Linear(input_size, out_features)
336
+ def forward(self, x):
337
+ return self.linear(x)
338
+
339
+ model.ridge = RidgeRegression(voxels.shape[1], out_features=hidden_dim)
340
+ utils.count_params(model.ridge)
341
+ utils.count_params(model)
342
+
343
+ b = torch.randn((2,voxels.shape[1]))
344
+ print(b.shape, model.ridge(b).shape)
345
+
346
+
347
+ # In[12]:
348
+
349
+
350
+ from functools import partial
351
+ class BrainNetwork(nn.Module):
352
+ def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, use_projector=True, drop1=.5, drop2=.15):
353
+ super().__init__()
354
+ norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
355
+ act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
356
+ act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
357
+ self.mlp = nn.ModuleList([
358
+ nn.Sequential(
359
+ nn.Linear(h, h),
360
+ *[item() for item in act_and_norm],
361
+ nn.Dropout(drop2)
362
+ ) for _ in range(n_blocks)
363
+ ])
364
+ self.lin1 = nn.Linear(h, out_dim, bias=True)
365
+ self.n_blocks = n_blocks
366
+ self.clip_size = clip_size
367
+ self.use_projector = use_projector
368
+ if use_projector:
369
+ self.projector = nn.Sequential(
370
+ nn.LayerNorm(clip_size),
371
+ nn.GELU(),
372
+ nn.Linear(clip_size, 2048),
373
+ nn.LayerNorm(2048),
374
+ nn.GELU(),
375
+ nn.Linear(2048, 2048),
376
+ nn.LayerNorm(2048),
377
+ nn.GELU(),
378
+ nn.Linear(2048, clip_size)
379
+ )
380
+
381
+ def forward(self, x):
382
+ residual = x
383
+ for res_block in range(self.n_blocks):
384
+ x = self.mlp[res_block](x)
385
+ x += residual
386
+ residual = x
387
+ x = x.reshape(len(x), -1)
388
+ x = self.lin1(x)
389
+ if self.use_projector:
390
+ return self.projector(x.reshape(len(x), -1, self.clip_size))
391
+ return x
392
+
393
+ model.backbone = BrainNetwork(in_dim=hidden_dim, clip_size=clip_emb_dim, out_dim=clip_seq_dim*clip_emb_dim, use_projector=True)
394
+ utils.count_params(model.backbone)
395
+ utils.count_params(model)
396
+
397
+ b = torch.randn((2,hidden_dim))
398
+ print(b.shape, model.backbone(b).shape)
399
+
400
+
401
+ # In[13]:
402
+
403
+
404
+ no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
405
+ opt_grouped_parameters = [
406
+ {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},
407
+ {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
408
+ {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
409
+ ]
410
+
411
+ optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr)
412
+
413
+ if lr_scheduler_type == 'linear':
414
+ lr_scheduler = torch.optim.lr_scheduler.LinearLR(
415
+ optimizer,
416
+ total_iters=int(num_epochs*(num_train*num_devices//batch_size)),
417
+ last_epoch=-1
418
+ )
419
+ elif lr_scheduler_type == 'cycle':
420
+ total_steps=int(num_epochs*(num_train*num_devices//batch_size))
421
+ lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
422
+ optimizer,
423
+ max_lr=max_lr,
424
+ total_steps=total_steps,
425
+ final_div_factor=1000,
426
+ last_epoch=-1, pct_start=2/num_epochs
427
+ )
428
+
429
+ def save_ckpt(tag):
430
+ ckpt_path = outdir+f'/{tag}.pth'
431
+ print(f'saving {ckpt_path}',flush=True)
432
+ unwrapped_model = accelerator.unwrap_model(model)
433
+ try:
434
+ torch.save({
435
+ 'epoch': epoch,
436
+ 'model_state_dict': unwrapped_model.state_dict(),
437
+ 'optimizer_state_dict': optimizer.state_dict(),
438
+ 'lr_scheduler': lr_scheduler.state_dict(),
439
+ 'train_losses': losses,
440
+ 'test_losses': test_losses,
441
+ 'lrs': lrs,
442
+ }, ckpt_path)
443
+ except:
444
+ print("Couldn't save... moving on to prevent crashing.")
445
+ del unwrapped_model
446
+
447
+ print("\nDone with model preparations!")
448
+
449
+
450
+ # # Weights and Biases
451
+
452
+ # In[14]:
453
+
454
+
455
+ # params for wandb
456
+ if local_rank==0 and wandb_log: # only use main process for wandb logging
457
+ import wandb
458
+
459
+ wandb_project = 'stability'
460
+ wandb_run = model_name
461
+ wandb_notes = ''
462
+
463
+ print(f"wandb {wandb_project} run {wandb_run}")
464
+ wandb.login(host='https://stability.wandb.io')#, relogin=True)
465
+ wandb_config = {
466
+ "model_name": model_name,
467
+ "clip_variant": clip_variant,
468
+ "batch_size": batch_size,
469
+ "num_epochs": num_epochs,
470
+ "use_image_aug": use_image_aug,
471
+ "max_lr": max_lr,
472
+ "lr_scheduler_type": lr_scheduler_type,
473
+ "mixup_pct": mixup_pct,
474
+ "num_train": num_train,
475
+ "num_test": num_test,
476
+ "seed": seed,
477
+ "distributed": distributed,
478
+ "num_devices": num_devices,
479
+ "world_size": world_size,
480
+ }
481
+ print("wandb_config:\n",wandb_config)
482
+ if True: # wandb_auto_resume
483
+ print("wandb_id:",model_name)
484
+ wandb.init(
485
+ id = model_name,
486
+ project=wandb_project,
487
+ name=wandb_run,
488
+ config=wandb_config,
489
+ notes=wandb_notes,
490
+ resume="allow",
491
+ )
492
+ else:
493
+ wandb.init(
494
+ project=wandb_project,
495
+ name=wandb_run,
496
+ config=wandb_config,
497
+ notes=wandb_notes,
498
+ )
499
+ else:
500
+ wandb_log = False
501
+
502
+
503
+ # # Main
504
+
505
+ # In[15]:
506
+
507
+
508
+ epoch = 0
509
+ losses, test_losses, lrs = [], [], []
510
+ best_test_loss = 1e9
511
+ soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))
512
+
513
+ # Optionally resume from checkpoint #
514
+ if resume_from_ckpt:
515
+ print("\n---resuming from last.pth ckpt---\n")
516
+ try:
517
+ checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
518
+ except:
519
+ print('last.pth failed... trying last_backup.pth')
520
+ checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
521
+ epoch = checkpoint['epoch']
522
+ print("Epoch",epoch)
523
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
524
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
525
+ model.load_state_dict(checkpoint['model_state_dict'])
526
+ del checkpoint
527
+ elif wandb_log:
528
+ if wandb.run.resumed:
529
+ print("\n---resuming from last.pth ckpt---\n")
530
+ try:
531
+ checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
532
+ except:
533
+ print('last.pth failed... trying last_backup.pth')
534
+ checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
535
+ epoch = checkpoint['epoch']
536
+ print("Epoch",epoch)
537
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
538
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
539
+ model.load_state_dict(checkpoint['model_state_dict'])
540
+ del checkpoint
541
+ torch.cuda.empty_cache()
542
+
543
+
544
+ # In[16]:
545
+
546
+
547
+ model, optimizer, train_dl, test_dl, lr_scheduler = accelerator.prepare(
548
+ model, optimizer, train_dl, test_dl, lr_scheduler
549
+ )
550
+
551
+
552
+ # In[17]:
553
+
554
+
555
+ print(f"{model_name} starting with epoch {epoch} / {num_epochs}")
556
+ progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0))
557
+ test_image, test_voxel = None, None
558
+ mse = nn.MSELoss()
559
+ for epoch in progress_bar:
560
+ model.train()
561
+
562
+ fwd_percent_correct = 0.
563
+ bwd_percent_correct = 0.
564
+ test_fwd_percent_correct = 0.
565
+ test_bwd_percent_correct = 0.
566
+
567
+ for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
568
+ with torch.cuda.amp.autocast():
569
+ optimizer.zero_grad()
570
+
571
+ voxel = voxels[behav[:,0,5].cpu().long()].to(device)
572
+ image = images[behav[:,0,0].cpu().long()].to(device)
573
+
574
+ if use_image_aug: image = img_augment(image)
575
+
576
+ clip_target = eva02_model.embed_image(image.float())
577
+ assert not torch.any(torch.isnan(clip_target))
578
+
579
+ if epoch < int(mixup_pct * num_epochs):
580
+ voxel, perm, betas, select = utils.mixco(voxel)
581
+
582
+ voxel_ridge = model.ridge(voxel)
583
+
584
+ clip_voxels = model.backbone(voxel_ridge)
585
+
586
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
587
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
588
+
589
+ if epoch < int(mixup_pct * num_epochs):
590
+ loss_clip = utils.mixco_nce(
591
+ clip_voxels_norm,
592
+ clip_target_norm,
593
+ temp=.006,
594
+ perm=perm, betas=betas, select=select)
595
+ else:
596
+ epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]
597
+ loss_clip = utils.soft_clip_loss(
598
+ clip_voxels_norm,
599
+ clip_target_norm,
600
+ temp=epoch_temp)
601
+
602
+ loss = loss_clip
603
+
604
+ utils.check_loss(loss)
605
+
606
+ accelerator.backward(loss)
607
+ optimizer.step()
608
+
609
+ losses.append(loss.item())
610
+ lrs.append(optimizer.param_groups[0]['lr'])
611
+
612
+ # forward and backward top 1 accuracy
613
+ labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
614
+ fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
615
+ bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)
616
+
617
+ if lr_scheduler_type is not None:
618
+ lr_scheduler.step()
619
+
620
+ model.eval()
621
+ for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
622
+ with torch.no_grad():
623
+ with torch.cuda.amp.autocast():
624
+ # all test samples should be loaded per batch such that test_i should never exceed 0
625
+ if len(behav) != num_test: print("!",len(behav),num_test)
626
+
627
+ ## Average same-image repeats ##
628
+ if test_image is None:
629
+ voxel = voxels[behav[:,0,5].cpu().long()]
630
+ image = behav[:,0,0].cpu().long()
631
+
632
+ unique_image, sort_indices = torch.unique(image, return_inverse=True)
633
+ for im in unique_image:
634
+ locs = torch.where(im == image)[0]
635
+ if test_image is None:
636
+ test_image = images[im][None]
637
+ test_voxel = torch.mean(voxel[locs],axis=0)[None]
638
+ else:
639
+ test_image = torch.vstack((test_image, images[im][None]))
640
+ test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None]))
641
+
642
+ # random sample of 300
643
+ random_indices = torch.randperm(len(test_voxel))[:300]
644
+ voxel = test_voxel[random_indices].to(device)
645
+ image = test_image[random_indices].to(device)
646
+ assert len(image) == 300
647
+
648
+ clip_target = eva02_model.embed_image(image.float())
649
+
650
+ voxel_ridge = model.ridge(voxel)
651
+
652
+ clip_voxels = model.backbone(voxel_ridge)
653
+
654
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
655
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
656
+
657
+ loss_clip = utils.soft_clip_loss(
658
+ clip_voxels_norm,
659
+ clip_target_norm,
660
+ temp=.006)
661
+
662
+ loss = loss_clip
663
+
664
+ utils.check_loss(loss)
665
+
666
+ test_losses.append(loss.item())
667
+
668
+ # forward and backward top 1 accuracy
669
+ labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
670
+ test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
671
+ test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)
672
+
673
+ if local_rank==0:
674
+ if utils.is_interactive():
675
+ # clear_output(wait=True)
676
+ print("---")
677
+
678
+ assert (test_i+1) == 1
679
+ logs = {"train/loss": np.mean(losses[-(train_i+1):]),
680
+ "test/loss": np.mean(test_losses[-(test_i+1):]),
681
+ "train/lr": lrs[-1],
682
+ "train/num_steps": len(losses),
683
+ "test/num_steps": len(test_losses),
684
+ "train/fwd_pct_correct": fwd_percent_correct.item() / (train_i + 1),
685
+ "train/bwd_pct_correct": bwd_percent_correct.item() / (train_i + 1),
686
+ "test/test_fwd_pct_correct": test_fwd_percent_correct.item() / (test_i + 1),
687
+ "test/test_bwd_pct_correct": test_bwd_percent_correct.item() / (test_i + 1),
688
+ }
689
+ progress_bar.set_postfix(**logs)
690
+
691
+ # Save model checkpoint and reconstruct
692
+ if epoch % ckpt_interval == 0:
693
+ if not utils.is_interactive():
694
+ save_ckpt('last')
695
+
696
+ if wandb_log: wandb.log(logs)
697
+
698
+ # wait for other GPUs to catch up if needed
699
+ accelerator.wait_for_everyone()
700
+ torch.cuda.empty_cache()
701
+ gc.collect()
702
+
703
+ print("\n===Finished!===\n")
704
+ if ckpt_saving:
705
+ save_ckpt('last')
706
+ if not utils.is_interactive():
707
+ sys.exit(0)
708
+
src/accel.slurm ADDED
@@ -0,0 +1,38 @@
1
+ #!/bin/bash
2
+ #SBATCH --account=topfmri
3
+ #SBATCH --partition=g40x
4
+ #SBATCH --job-name=ms
5
+ #SBATCH --nodes=1
6
+ #SBATCH --ntasks-per-node=4 # should = number of gpus
7
+ #SBATCH --gres=gpu:4
8
+ #SBATCH --time=32:00:00 # total run time limit (HH:MM:SS)
9
+ #SBATCH -e slurms/%j.err
10
+ #SBATCH -o slurms/%j.out
11
+ #SBATCH --comment=topfmri
12
+
13
+ module load cuda/11.7 # should match torch.cuda.version
14
+
15
+ export NUM_GPUS=4 # Set to equal gres=gpu:#
16
+ export GLOBAL_BATCH_SIZE=512
17
+
18
+ # Make sure another job doesn't use the same port; pick a random one
19
+ export MASTER_PORT=$((RANDOM % (19000 - 11000 + 1) + 11000))
20
+
21
+ export HOSTNAMES=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
22
+ export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
23
+ export COUNT_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l)
24
+
25
+ export WANDB_DIR="/fsx/proj-fmri/paulscotti/MindEyeV2/wandb/"
26
+ export WANDB_CACHE_DIR="/fsx/home-paulscotti/.cache"
27
+ export WANDB_MODE="online"
28
+
29
+ echo MASTER_ADDR=${MASTER_ADDR}
30
+ echo MASTER_PORT=${MASTER_PORT}
31
+ echo WORLD_SIZE=${COUNT_NODE}
32
+
33
+ ###########
34
+
35
+ cd /fsx/proj-fmri/paulscotti/MindEyeV2
36
+ accelerate launch --num_processes=$(($NUM_GPUS * $COUNT_NODE)) --num_machines=$COUNT_NODE --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT Train.py --data_path=/fsx/proj-fmri/shared/mindeyev2_dataset --model_name=test --subj=1 --batch_size=${GLOBAL_BATCH_SIZE} --n_samples_save=0 --max_lr=3e-4 --mixup_pct=.33 --num_epochs=240 --ckpt_interval=999 --no-use_image_aug
37
+
38
+ # --wandb_log
src/checking_models.ipynb ADDED
@@ -0,0 +1,1526 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 25,
6
+ "id": "ef9e1556-7840-4004-b181-a2c97ac2ab17",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import os\n",
11
+ "import torch\n",
12
+ "import torch.nn as nn\n",
13
+ "import numpy as np\n",
14
+ "import matplotlib.pyplot as plt"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "markdown",
19
+ "id": "b6f12dd4-f3aa-4981-b604-b72e67229011",
20
+ "metadata": {},
21
+ "source": [
22
+ "# DinoV2"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": 26,
28
+ "id": "2a604617-b602-4503-b288-e9828684505e",
29
+ "metadata": {},
30
+ "outputs": [
31
+ {
32
+ "name": "stderr",
33
+ "output_type": "stream",
34
+ "text": [
35
+ "Using cache found in /fsx/proj-fmri/shared/cache/dinov2/hub/facebookresearch_dinov2_main\n"
36
+ ]
37
+ }
38
+ ],
39
+ "source": [
40
+ "# need to change TORCH_HOME env variable to specify pretrained model should go in shared folder, not home directory\n",
41
+ "os.environ['TORCH_HOME'] = '/fsx/proj-fmri/shared/cache/dinov2'\n",
42
+ "dinov2_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')\n",
43
+ "# remove initial image patching\n",
44
+ "dinov2_model.patch_embed = nn.Identity()\n",
45
+ "dinov2_model.patch_embed = nn.Identity()"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": 27,
51
+ "id": "32da913d-d931-4967-a5e8-bd40c21d1ad9",
52
+ "metadata": {},
53
+ "outputs": [
54
+ {
55
+ "name": "stdout",
56
+ "output_type": "stream",
57
+ "text": [
58
+ "torch.Size([2, 33, 1024])\n"
59
+ ]
60
+ }
61
+ ],
62
+ "source": [
63
+ "dinov2_model.to(\"cuda\")\n",
64
+ "input = torch.randn((2,33,1024)).to(\"cuda\")\n",
65
+ "\n",
66
+ "for block in dinov2_model.blocks: input = block(input)\n",
67
+ "input = dinov2_model.norm(input)\n",
68
+ "\n",
69
+ "print(input.shape)"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "markdown",
74
+ "id": "febe89c0-06d0-4309-b378-a8d58b99bf4c",
75
+ "metadata": {},
76
+ "source": [
77
+ "# eva"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": 28,
83
+ "id": "690204d0-13d7-452b-97af-14d144800e81",
84
+ "metadata": {},
85
+ "outputs": [],
86
+ "source": [
87
+ "from urllib.request import urlopen\n",
88
+ "from PIL import Image\n",
89
+ "import timm\n",
90
+ "\n",
91
+ "img = Image.open(urlopen(\n",
92
+ " 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'\n",
93
+ "))\n",
94
+ "\n",
95
+ "model = timm.create_model(\n",
96
+ " \"eva02_enormous_patch14_clip_224.laion2b\",\n",
97
+ " pretrained=True,\n",
98
+ " num_classes=0, # remove classifier nn.Linear\n",
99
+ ")\n",
100
+ "model = model.eval()"
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "code",
105
+ "execution_count": 39,
106
+ "id": "035e3e9d-86c9-4ddf-b760-7b78dded7d2e",
107
+ "metadata": {},
108
+ "outputs": [
109
+ {
110
+ "ename": "ValueError",
111
+ "evalue": "You have to specify pixel_values",
112
+ "output_type": "error",
113
+ "traceback": [
114
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
115
+ "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
116
+ "Cell \u001b[0;32mIn[39], line 5\u001b[0m\n\u001b[1;32m 2\u001b[0m data_config \u001b[38;5;241m=\u001b[39m timm\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mresolve_model_data_config(model)\n\u001b[1;32m 3\u001b[0m transforms \u001b[38;5;241m=\u001b[39m timm\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mcreate_transform(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mdata_config, is_training\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m----> 5\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtransforms\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimg\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munsqueeze\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# output is (batch_size, num_features) shaped tensor\u001b[39;00m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28mprint\u001b[39m(output\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m 7\u001b[0m \u001b[38;5;66;03m# or equivalently (without needing to set num_classes=0)\u001b[39;00m\n",
117
+ "File \u001b[0;32m~/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
118
+ "File \u001b[0;32m~/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/clipseg/modeling_clipseg.py:1433\u001b[0m, in \u001b[0;36mCLIPSegForImageSegmentation.forward\u001b[0;34m(self, input_ids, pixel_values, conditional_pixel_values, conditional_embeddings, attention_mask, position_ids, labels, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1431\u001b[0m \u001b[38;5;66;03m# step 1: forward the query images through the frozen CLIP vision encoder\u001b[39;00m\n\u001b[1;32m 1432\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m-> 1433\u001b[0m vision_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclip\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvision_model\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1434\u001b[0m \u001b[43m \u001b[49m\u001b[43mpixel_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpixel_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1435\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1436\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# we need the intermediate hidden states\u001b[39;49;00m\n\u001b[1;32m 1437\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1438\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1439\u001b[0m pooled_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mclip\u001b[38;5;241m.\u001b[39mvisual_projection(vision_outputs[\u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m 1441\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m vision_outputs\u001b[38;5;241m.\u001b[39mhidden_states \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;28;01melse\u001b[39;00m vision_outputs[\u001b[38;5;241m2\u001b[39m]\n",
119
+ "File \u001b[0;32m~/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
120
+ "File \u001b[0;32m~/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/clipseg/modeling_clipseg.py:872\u001b[0m, in \u001b[0;36mCLIPSegVisionTransformer.forward\u001b[0;34m(self, pixel_values, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 869\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m 871\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m pixel_values \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 872\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou have to specify pixel_values\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 874\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39membeddings(pixel_values)\n\u001b[1;32m 875\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpre_layrnorm(hidden_states)\n",
121
+ "\u001b[0;31mValueError\u001b[0m: You have to specify pixel_values"
122
+ ]
123
+ }
124
+ ],
125
+ "source": [
126
+ "# get model specific transforms (normalization, resize)\n",
127
+ "data_config = timm.data.resolve_model_data_config(model)\n",
128
+ "transforms = timm.data.create_transform(**data_config, is_training=False)\n",
129
+ "\n",
130
+ "output = model(transforms(img).unsqueeze(0)) # output is (batch_size, num_features) shaped tensor\n",
131
+ "print(output.shape)\n",
132
+ "# or equivalently (without needing to set num_classes=0)\n",
133
+ "\n",
134
+ "output = model.forward_features(transforms(img).unsqueeze(0))\n",
135
+ "# output is unpooled, a (1, 257, 768) shaped tensor\n",
136
+ "print(output.shape)\n",
137
+ "\n",
138
+ "output = model.forward_head(output, pre_logits=True)\n",
139
+ "# output is a (1, num_features) shaped tensor\n",
140
+ "print(output.shape)"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": null,
146
+ "id": "54275c4c-e506-4959-92f1-29e584f5ce51",
147
+ "metadata": {},
148
+ "outputs": [],
149
+ "source": [
150
+ "model.forward_features(transforms(img).unsqueeze(0)).shape"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "markdown",
155
+ "id": "6546c673-f3ab-4d43-a051-cab20e782bab",
156
+ "metadata": {},
157
+ "source": [
158
+ "# Eva02-clip"
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": 29,
164
+ "id": "dfbc95de-9af9-4583-98fc-b8061114ef64",
165
+ "metadata": {},
166
+ "outputs": [],
167
+ "source": [
168
+ "import timm \n",
169
+ "# couldnt figure out how to load pretrained model from shared folder rather than home directory using timm...\n",
170
+ "eva02_model = timm.create_model(\"eva02_enormous_patch14_clip_224.laion2b\", pretrained=True)\n",
171
+ "# eva02_model.head_drop = nn.Identity()\n",
172
+ "# eva02_model.head = nn.Identity()"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "execution_count": 30,
178
+ "id": "97e3ea29-ae6b-4bd2-b3d7-17839098a6e4",
179
+ "metadata": {},
180
+ "outputs": [
181
+ {
182
+ "data": {
183
+ "text/plain": [
184
+ "torch.Size([2, 1024])"
185
+ ]
186
+ },
187
+ "execution_count": 30,
188
+ "metadata": {},
189
+ "output_type": "execute_result"
190
+ }
191
+ ],
192
+ "source": [
193
+ "eva02_model(torch.randn((2,3,224,224))).shape"
194
+ ]
195
+ },
196
+ {
197
+ "cell_type": "code",
198
+ "execution_count": 31,
199
+ "id": "069b76f0-029f-42b1-85f5-a492ee1cc5d1",
200
+ "metadata": {},
201
+ "outputs": [
202
+ {
203
+ "name": "stdout",
204
+ "output_type": "stream",
205
+ "text": [
206
+ "torch.Size([2, 256, 1024])\n"
207
+ ]
208
+ }
209
+ ],
210
+ "source": [
211
+ "image = torch.randn((2,3,224,224))\n",
212
+ "\n",
213
+ "input = eva02_model.patch_embed(image)\n",
214
+ "input = eva02_model.pos_drop(input)\n",
215
+ "for block in eva02_model.blocks: input = block(input)\n",
216
+ "input = eva02_model.norm(input)\n",
217
+ "input = eva02_model.fc_norm(input)\n",
218
+ "input = eva02_model.head_drop(input)\n",
219
+ "input = eva02_model.head(input)\n",
220
+ "\n",
221
+ "print(input.shape)"
222
+ ]
223
+ },
224
+ {
225
+ "cell_type": "code",
226
+ "execution_count": 32,
227
+ "id": "90e4e8e7-3dd1-43b0-a305-066a6ec13c2e",
228
+ "metadata": {},
229
+ "outputs": [
230
+ {
231
+ "name": "stdout",
232
+ "output_type": "stream",
233
+ "text": [
234
+ "Help on Eva in module timm.models.eva object:\n",
235
+ "\n",
236
+ "class Eva(torch.nn.modules.module.Module)\n",
237
+ " | Eva(img_size: Union[int, Tuple[int, int]] = 224, patch_size: Union[int, Tuple[int, int]] = 16, in_chans: int = 3, num_classes: int = 1000, global_pool: str = 'avg', embed_dim: int = 768, depth: int = 12, num_heads: int = 12, qkv_bias: bool = True, qkv_fused: bool = True, mlp_ratio: float = 4.0, swiglu_mlp: bool = False, scale_mlp: bool = False, scale_attn_inner: bool = False, drop_rate: float = 0.0, pos_drop_rate: float = 0.0, patch_drop_rate: float = 0.0, proj_drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, norm_layer: Callable = <class 'timm.layers.norm.LayerNorm'>, init_values: Optional[float] = None, class_token: bool = True, use_abs_pos_emb: bool = True, use_rot_pos_emb: bool = False, use_post_norm: bool = False, ref_feat_shape: Union[int, Tuple[int, int], NoneType] = None, head_init_scale: float = 0.001)\n",
238
+ " | \n",
239
+ " | Eva Vision Transformer w/ Abs & Rotary Pos Embed\n",
240
+ " | \n",
241
+ " | This class implements the EVA and EVA02 models that were based on the BEiT ViT variant\n",
242
+ " | * EVA - abs pos embed, global avg pool\n",
243
+ " | * EVA02 - abs + rope pos embed, global avg pool, SwiGLU, scale Norm in MLP (ala normformer)\n",
244
+ " | \n",
245
+ " | Method resolution order:\n",
246
+ " | Eva\n",
247
+ " | torch.nn.modules.module.Module\n",
248
+ " | builtins.object\n",
249
+ " | \n",
250
+ " | Methods defined here:\n",
251
+ " | \n",
252
+ " | __init__(self, img_size: Union[int, Tuple[int, int]] = 224, patch_size: Union[int, Tuple[int, int]] = 16, in_chans: int = 3, num_classes: int = 1000, global_pool: str = 'avg', embed_dim: int = 768, depth: int = 12, num_heads: int = 12, qkv_bias: bool = True, qkv_fused: bool = True, mlp_ratio: float = 4.0, swiglu_mlp: bool = False, scale_mlp: bool = False, scale_attn_inner: bool = False, drop_rate: float = 0.0, pos_drop_rate: float = 0.0, patch_drop_rate: float = 0.0, proj_drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, norm_layer: Callable = <class 'timm.layers.norm.LayerNorm'>, init_values: Optional[float] = None, class_token: bool = True, use_abs_pos_emb: bool = True, use_rot_pos_emb: bool = False, use_post_norm: bool = False, ref_feat_shape: Union[int, Tuple[int, int], NoneType] = None, head_init_scale: float = 0.001)\n",
253
+ " | Args:\n",
254
+ " | img_size:\n",
255
+ " | patch_size:\n",
256
+ " | in_chans:\n",
257
+ " | num_classes:\n",
258
+ " | global_pool:\n",
259
+ " | embed_dim:\n",
260
+ " | depth:\n",
261
+ " | num_heads:\n",
262
+ " | qkv_bias:\n",
263
+ " | qkv_fused:\n",
264
+ " | mlp_ratio:\n",
265
+ " | swiglu_mlp:\n",
266
+ " | scale_mlp:\n",
267
+ " | scale_attn_inner:\n",
268
+ " | drop_rate:\n",
269
+ " | pos_drop_rate:\n",
270
+ " | proj_drop_rate:\n",
271
+ " | attn_drop_rate:\n",
272
+ " | drop_path_rate:\n",
273
+ " | norm_layer:\n",
274
+ " | init_values:\n",
275
+ " | class_token:\n",
276
+ " | use_abs_pos_emb:\n",
277
+ " | use_rot_pos_emb:\n",
278
+ " | use_post_norm:\n",
279
+ " | ref_feat_shape:\n",
280
+ " | head_init_scale:\n",
281
+ " | \n",
282
+ " | fix_init_weight(self)\n",
283
+ " | \n",
284
+ " | forward(self, x)\n",
285
+ " | Defines the computation performed at every call.\n",
286
+ " | \n",
287
+ " | Should be overridden by all subclasses.\n",
288
+ " | \n",
289
+ " | .. note::\n",
290
+ " | Although the recipe for forward pass needs to be defined within\n",
291
+ " | this function, one should call the :class:`Module` instance afterwards\n",
292
+ " | instead of this since the former takes care of running the\n",
293
+ " | registered hooks while the latter silently ignores them.\n",
294
+ " | \n",
295
+ " | forward_features(self, x)\n",
296
+ " | \n",
297
+ " | forward_head(self, x, pre_logits: bool = False)\n",
298
+ " | \n",
299
+ " | get_classifier(self)\n",
300
+ " | \n",
301
+ " | group_matcher(self, coarse=False)\n",
302
+ " | \n",
303
+ " | no_weight_decay(self)\n",
304
+ " | \n",
305
+ " | reset_classifier(self, num_classes, global_pool=None)\n",
306
+ " | \n",
307
+ " | set_grad_checkpointing(self, enable=True)\n",
308
+ " | \n",
309
+ " | ----------------------------------------------------------------------\n",
310
+ " | Data and other attributes defined here:\n",
311
+ " | \n",
312
+ " | __annotations__ = {}\n",
313
+ " | \n",
314
+ " | ----------------------------------------------------------------------\n",
315
+ " | Methods inherited from torch.nn.modules.module.Module:\n",
316
+ " | \n",
317
+ " | __call__ = _call_impl(self, *args, **kwargs)\n",
318
+ " | \n",
319
+ " | __delattr__(self, name)\n",
320
+ " | Implement delattr(self, name).\n",
321
+ " | \n",
322
+ " | __dir__(self)\n",
323
+ " | Default dir() implementation.\n",
324
+ " | \n",
325
+ " | __getattr__(self, name: str) -> Union[torch.Tensor, ForwardRef('Module')]\n",
326
+ " | \n",
327
+ " | __repr__(self)\n",
328
+ " | Return repr(self).\n",
329
+ " | \n",
330
+ " | __setattr__(self, name: str, value: Union[torch.Tensor, ForwardRef('Module')]) -> None\n",
331
+ " | Implement setattr(self, name, value).\n",
332
+ " | \n",
333
+ " | __setstate__(self, state)\n",
334
+ " | \n",
335
+ " | add_module(self, name: str, module: Optional[ForwardRef('Module')]) -> None\n",
336
+ " | Adds a child module to the current module.\n",
337
+ " | \n",
338
+ " | The module can be accessed as an attribute using the given name.\n",
339
+ " | \n",
340
+ " | Args:\n",
341
+ " | name (str): name of the child module. The child module can be\n",
342
+ " | accessed from this module using the given name\n",
343
+ " | module (Module): child module to be added to the module.\n",
344
+ " | \n",
345
+ " | apply(self: ~T, fn: Callable[[ForwardRef('Module')], NoneType]) -> ~T\n",
346
+ " | Applies ``fn`` recursively to every submodule (as returned by ``.children()``)\n",
347
+ " | as well as self. Typical use includes initializing the parameters of a model\n",
348
+ " | (see also :ref:`nn-init-doc`).\n",
349
+ " | \n",
350
+ " | Args:\n",
351
+ " | fn (:class:`Module` -> None): function to be applied to each submodule\n",
352
+ " | \n",
353
+ " | Returns:\n",
354
+ " | Module: self\n",
355
+ " | \n",
356
+ " | Example::\n",
357
+ " | \n",
358
+ " | >>> @torch.no_grad()\n",
359
+ " | >>> def init_weights(m):\n",
360
+ " | >>> print(m)\n",
361
+ " | >>> if type(m) == nn.Linear:\n",
362
+ " | >>> m.weight.fill_(1.0)\n",
363
+ " | >>> print(m.weight)\n",
364
+ " | >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))\n",
365
+ " | >>> net.apply(init_weights)\n",
366
+ " | Linear(in_features=2, out_features=2, bias=True)\n",
367
+ " | Parameter containing:\n",
368
+ " | tensor([[1., 1.],\n",
369
+ " | [1., 1.]], requires_grad=True)\n",
370
+ " | Linear(in_features=2, out_features=2, bias=True)\n",
371
+ " | Parameter containing:\n",
372
+ " | tensor([[1., 1.],\n",
373
+ " | [1., 1.]], requires_grad=True)\n",
374
+ " | Sequential(\n",
375
+ " | (0): Linear(in_features=2, out_features=2, bias=True)\n",
376
+ " | (1): Linear(in_features=2, out_features=2, bias=True)\n",
377
+ " | )\n",
378
+ " | \n",
379
+ " | bfloat16(self: ~T) -> ~T\n",
380
+ " | Casts all floating point parameters and buffers to ``bfloat16`` datatype.\n",
381
+ " | \n",
382
+ " | .. note::\n",
383
+ " | This method modifies the module in-place.\n",
384
+ " | \n",
385
+ " | Returns:\n",
386
+ " | Module: self\n",
387
+ " | \n",
388
+ " | buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]\n",
389
+ " | Returns an iterator over module buffers.\n",
390
+ " | \n",
391
+ " | Args:\n",
392
+ " | recurse (bool): if True, then yields buffers of this module\n",
393
+ " | and all submodules. Otherwise, yields only buffers that\n",
394
+ " | are direct members of this module.\n",
395
+ " | \n",
396
+ " | Yields:\n",
397
+ " | torch.Tensor: module buffer\n",
398
+ " | \n",
399
+ " | Example::\n",
400
+ " | \n",
401
+ " | >>> # xdoctest: +SKIP(\"undefined vars\")\n",
402
+ " | >>> for buf in model.buffers():\n",
403
+ " | >>> print(type(buf), buf.size())\n",
404
+ " | <class 'torch.Tensor'> (20L,)\n",
405
+ " | <class 'torch.Tensor'> (20L, 1L, 5L, 5L)\n",
406
+ " | \n",
407
+ " | children(self) -> Iterator[ForwardRef('Module')]\n",
408
+ " | Returns an iterator over immediate children modules.\n",
409
+ " | \n",
410
+ " | Yields:\n",
411
+ " | Module: a child module\n",
412
+ " | \n",
413
+ " | cpu(self: ~T) -> ~T\n",
414
+ " | Moves all model parameters and buffers to the CPU.\n",
415
+ " | \n",
416
+ " | .. note::\n",
417
+ " | This method modifies the module in-place.\n",
418
+ " | \n",
419
+ " | Returns:\n",
420
+ " | Module: self\n",
421
+ " | \n",
422
+ " | cuda(self: ~T, device: Union[int, torch.device, NoneType] = None) -> ~T\n",
423
+ " | Moves all model parameters and buffers to the GPU.\n",
424
+ " | \n",
425
+ " | This also makes associated parameters and buffers different objects. So\n",
426
+ " | it should be called before constructing optimizer if the module will\n",
427
+ " | live on GPU while being optimized.\n",
428
+ " | \n",
429
+ " | .. note::\n",
430
+ " | This method modifies the module in-place.\n",
431
+ " | \n",
432
+ " | Args:\n",
433
+ " | device (int, optional): if specified, all parameters will be\n",
434
+ " | copied to that device\n",
435
+ " | \n",
436
+ " | Returns:\n",
437
+ " | Module: self\n",
438
+ " | \n",
439
+ " | double(self: ~T) -> ~T\n",
440
+ " | Casts all floating point parameters and buffers to ``double`` datatype.\n",
441
+ " | \n",
442
+ " | .. note::\n",
443
+ " | This method modifies the module in-place.\n",
444
+ " | \n",
445
+ " | Returns:\n",
446
+ " | Module: self\n",
447
+ " | \n",
448
+ " | eval(self: ~T) -> ~T\n",
449
+ " | Sets the module in evaluation mode.\n",
450
+ " | \n",
451
+ " | This has any effect only on certain modules. See documentations of\n",
452
+ " | particular modules for details of their behaviors in training/evaluation\n",
453
+ " | mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,\n",
454
+ " | etc.\n",
455
+ " | \n",
456
+ " | This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.\n",
457
+ " | \n",
458
+ " | See :ref:`locally-disable-grad-doc` for a comparison between\n",
459
+ " | `.eval()` and several similar mechanisms that may be confused with it.\n",
460
+ " | \n",
461
+ " | Returns:\n",
462
+ " | Module: self\n",
463
+ " | \n",
464
+ " | extra_repr(self) -> str\n",
465
+ " | Set the extra representation of the module\n",
466
+ " | \n",
467
+ " | To print customized extra information, you should re-implement\n",
468
+ " | this method in your own modules. Both single-line and multi-line\n",
469
+ " | strings are acceptable.\n",
470
+ " | \n",
471
+ " | float(self: ~T) -> ~T\n",
472
+ " | Casts all floating point parameters and buffers to ``float`` datatype.\n",
473
+ " | \n",
474
+ " | .. note::\n",
475
+ " | This method modifies the module in-place.\n",
476
+ " | \n",
477
+ " | Returns:\n",
478
+ " | Module: self\n",
479
+ " | \n",
480
+ " | get_buffer(self, target: str) -> 'Tensor'\n",
481
+ " | Returns the buffer given by ``target`` if it exists,\n",
482
+ " | otherwise throws an error.\n",
483
+ " | \n",
484
+ " | See the docstring for ``get_submodule`` for a more detailed\n",
485
+ " | explanation of this method's functionality as well as how to\n",
486
+ " | correctly specify ``target``.\n",
487
+ " | \n",
488
+ " | Args:\n",
489
+ " | target: The fully-qualified string name of the buffer\n",
490
+ " | to look for. (See ``get_submodule`` for how to specify a\n",
491
+ " | fully-qualified string.)\n",
492
+ " | \n",
493
+ " | Returns:\n",
494
+ " | torch.Tensor: The buffer referenced by ``target``\n",
495
+ " | \n",
496
+ " | Raises:\n",
497
+ " | AttributeError: If the target string references an invalid\n",
498
+ " | path or resolves to something that is not a\n",
499
+ " | buffer\n",
500
+ " | \n",
501
+ " | get_extra_state(self) -> Any\n",
502
+ " | Returns any extra state to include in the module's state_dict.\n",
503
+ " | Implement this and a corresponding :func:`set_extra_state` for your module\n",
504
+ " | if you need to store extra state. This function is called when building the\n",
505
+ " | module's `state_dict()`.\n",
506
+ " | \n",
507
+ " | Note that extra state should be picklable to ensure working serialization\n",
508
+ " | of the state_dict. We only provide provide backwards compatibility guarantees\n",
509
+ " | for serializing Tensors; other objects may break backwards compatibility if\n",
510
+ " | their serialized pickled form changes.\n",
511
+ " | \n",
512
+ " | Returns:\n",
513
+ " | object: Any extra state to store in the module's state_dict\n",
514
+ " | \n",
515
+ " | get_parameter(self, target: str) -> 'Parameter'\n",
516
+ " | Returns the parameter given by ``target`` if it exists,\n",
517
+ " | otherwise throws an error.\n",
518
+ " | \n",
519
+ " | See the docstring for ``get_submodule`` for a more detailed\n",
520
+ " | explanation of this method's functionality as well as how to\n",
521
+ " | correctly specify ``target``.\n",
522
+ " | \n",
523
+ " | Args:\n",
524
+ " | target: The fully-qualified string name of the Parameter\n",
525
+ " | to look for. (See ``get_submodule`` for how to specify a\n",
526
+ " | fully-qualified string.)\n",
527
+ " | \n",
528
+ " | Returns:\n",
529
+ " | torch.nn.Parameter: The Parameter referenced by ``target``\n",
530
+ " | \n",
531
+ " | Raises:\n",
532
+ " | AttributeError: If the target string references an invalid\n",
533
+ " | path or resolves to something that is not an\n",
534
+ " | ``nn.Parameter``\n",
535
+ " | \n",
536
+ " | get_submodule(self, target: str) -> 'Module'\n",
537
+ " | Returns the submodule given by ``target`` if it exists,\n",
538
+ " | otherwise throws an error.\n",
539
+ " | \n",
540
+ " | For example, let's say you have an ``nn.Module`` ``A`` that\n",
541
+ " | looks like this:\n",
542
+ " | \n",
543
+ " | .. code-block:: text\n",
544
+ " | \n",
545
+ " | A(\n",
546
+ " | (net_b): Module(\n",
547
+ " | (net_c): Module(\n",
548
+ " | (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))\n",
549
+ " | )\n",
550
+ " | (linear): Linear(in_features=100, out_features=200, bias=True)\n",
551
+ " | )\n",
552
+ " | )\n",
553
+ " | \n",
554
+ " | (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested\n",
555
+ " | submodule ``net_b``, which itself has two submodules ``net_c``\n",
556
+ " | and ``linear``. ``net_c`` then has a submodule ``conv``.)\n",
557
+ " | \n",
558
+ " | To check whether or not we have the ``linear`` submodule, we\n",
559
+ " | would call ``get_submodule(\"net_b.linear\")``. To check whether\n",
560
+ " | we have the ``conv`` submodule, we would call\n",
561
+ " | ``get_submodule(\"net_b.net_c.conv\")``.\n",
562
+ " | \n",
563
+ " | The runtime of ``get_submodule`` is bounded by the degree\n",
564
+ " | of module nesting in ``target``. A query against\n",
565
+ " | ``named_modules`` achieves the same result, but it is O(N) in\n",
566
+ " | the number of transitive modules. So, for a simple check to see\n",
567
+ " | if some submodule exists, ``get_submodule`` should always be\n",
568
+ " | used.\n",
569
+ " | \n",
570
+ " | Args:\n",
571
+ " | target: The fully-qualified string name of the submodule\n",
572
+ " | to look for. (See above example for how to specify a\n",
573
+ " | fully-qualified string.)\n",
574
+ " | \n",
575
+ " | Returns:\n",
576
+ " | torch.nn.Module: The submodule referenced by ``target``\n",
577
+ " | \n",
578
+ " | Raises:\n",
579
+ " | AttributeError: If the target string references an invalid\n",
580
+ " | path or resolves to something that is not an\n",
581
+ " | ``nn.Module``\n",
582
+ " | \n",
583
+ " | half(self: ~T) -> ~T\n",
584
+ " | Casts all floating point parameters and buffers to ``half`` datatype.\n",
585
+ " | \n",
586
+ " | .. note::\n",
587
+ " | This method modifies the module in-place.\n",
588
+ " | \n",
589
+ " | Returns:\n",
590
+ " | Module: self\n",
591
+ " | \n",
592
+ " | ipu(self: ~T, device: Union[int, torch.device, NoneType] = None) -> ~T\n",
593
+ " | Moves all model parameters and buffers to the IPU.\n",
594
+ " | \n",
595
+ " | This also makes associated parameters and buffers different objects. So\n",
596
+ " | it should be called before constructing optimizer if the module will\n",
597
+ " | live on IPU while being optimized.\n",
598
+ " | \n",
599
+ " | .. note::\n",
600
+ " | This method modifies the module in-place.\n",
601
+ " | \n",
602
+ " | Arguments:\n",
603
+ " | device (int, optional): if specified, all parameters will be\n",
604
+ " | copied to that device\n",
605
+ " | \n",
606
+ " | Returns:\n",
607
+ " | Module: self\n",
608
+ " | \n",
609
+ " | load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True)\n",
610
+ " | Copies parameters and buffers from :attr:`state_dict` into\n",
611
+ " | this module and its descendants. If :attr:`strict` is ``True``, then\n",
612
+ " | the keys of :attr:`state_dict` must exactly match the keys returned\n",
613
+ " | by this module's :meth:`~torch.nn.Module.state_dict` function.\n",
614
+ " | \n",
615
+ " | Args:\n",
616
+ " | state_dict (dict): a dict containing parameters and\n",
617
+ " | persistent buffers.\n",
618
+ " | strict (bool, optional): whether to strictly enforce that the keys\n",
619
+ " | in :attr:`state_dict` match the keys returned by this module's\n",
620
+ " | :meth:`~torch.nn.Module.state_dict` function. Default: ``True``\n",
621
+ " | \n",
622
+ " | Returns:\n",
623
+ " | ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:\n",
624
+ " | * **missing_keys** is a list of str containing the missing keys\n",
625
+ " | * **unexpected_keys** is a list of str containing the unexpected keys\n",
626
+ " | \n",
627
+ " | Note:\n",
628
+ " | If a parameter or buffer is registered as ``None`` and its corresponding key\n",
629
+ " | exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a\n",
630
+ " | ``RuntimeError``.\n",
631
+ " | \n",
632
+ " | modules(self) -> Iterator[ForwardRef('Module')]\n",
633
+ " | Returns an iterator over all modules in the network.\n",
634
+ " | \n",
635
+ " | Yields:\n",
636
+ " | Module: a module in the network\n",
637
+ " | \n",
638
+ " | Note:\n",
639
+ " | Duplicate modules are returned only once. In the following\n",
640
+ " | example, ``l`` will be returned only once.\n",
641
+ " | \n",
642
+ " | Example::\n",
643
+ " | \n",
644
+ " | >>> l = nn.Linear(2, 2)\n",
645
+ " | >>> net = nn.Sequential(l, l)\n",
646
+ " | >>> for idx, m in enumerate(net.modules()):\n",
647
+ " | ... print(idx, '->', m)\n",
648
+ " | \n",
649
+ " | 0 -> Sequential(\n",
650
+ " | (0): Linear(in_features=2, out_features=2, bias=True)\n",
651
+ " | (1): Linear(in_features=2, out_features=2, bias=True)\n",
652
+ " | )\n",
653
+ " | 1 -> Linear(in_features=2, out_features=2, bias=True)\n",
654
+ " | \n",
655
+ " | named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.Tensor]]\n",
656
+ " | Returns an iterator over module buffers, yielding both the\n",
657
+ " | name of the buffer as well as the buffer itself.\n",
658
+ " | \n",
659
+ " | Args:\n",
660
+ " | prefix (str): prefix to prepend to all buffer names.\n",
661
+ " | recurse (bool, optional): if True, then yields buffers of this module\n",
662
+ " | and all submodules. Otherwise, yields only buffers that\n",
663
+ " | are direct members of this module. Defaults to True.\n",
664
+ " | remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.\n",
665
+ " | \n",
666
+ " | Yields:\n",
667
+ " | (str, torch.Tensor): Tuple containing the name and buffer\n",
668
+ " | \n",
669
+ " | Example::\n",
670
+ " | \n",
671
+ " | >>> # xdoctest: +SKIP(\"undefined vars\")\n",
672
+ " | >>> for name, buf in self.named_buffers():\n",
673
+ " | >>> if name in ['running_var']:\n",
674
+ " | >>> print(buf.size())\n",
675
+ " | \n",
676
+ " | named_children(self) -> Iterator[Tuple[str, ForwardRef('Module')]]\n",
677
+ " | Returns an iterator over immediate children modules, yielding both\n",
678
+ " | the name of the module as well as the module itself.\n",
679
+ " | \n",
680
+ " | Yields:\n",
681
+ " | (str, Module): Tuple containing a name and child module\n",
682
+ " | \n",
683
+ " | Example::\n",
684
+ " | \n",
685
+ " | >>> # xdoctest: +SKIP(\"undefined vars\")\n",
686
+ " | >>> for name, module in model.named_children():\n",
687
+ " | >>> if name in ['conv4', 'conv5']:\n",
688
+ " | >>> print(module)\n",
689
+ " | \n",
690
+ " | named_modules(self, memo: Optional[Set[ForwardRef('Module')]] = None, prefix: str = '', remove_duplicate: bool = True)\n",
691
+ " | Returns an iterator over all modules in the network, yielding\n",
692
+ " | both the name of the module as well as the module itself.\n",
693
+ " | \n",
694
+ " | Args:\n",
695
+ " | memo: a memo to store the set of modules already added to the result\n",
696
+ " | prefix: a prefix that will be added to the name of the module\n",
697
+ " | remove_duplicate: whether to remove the duplicated module instances in the result\n",
698
+ " | or not\n",
699
+ " | \n",
700
+ " | Yields:\n",
701
+ " | (str, Module): Tuple of name and module\n",
702
+ " | \n",
703
+ " | Note:\n",
704
+ " | Duplicate modules are returned only once. In the following\n",
705
+ " | example, ``l`` will be returned only once.\n",
706
+ " | \n",
707
+ " | Example::\n",
708
+ " | \n",
709
+ " | >>> l = nn.Linear(2, 2)\n",
710
+ " | >>> net = nn.Sequential(l, l)\n",
711
+ " | >>> for idx, m in enumerate(net.named_modules()):\n",
712
+ " | ... print(idx, '->', m)\n",
713
+ " | \n",
714
+ " | 0 -> ('', Sequential(\n",
715
+ " | (0): Linear(in_features=2, out_features=2, bias=True)\n",
716
+ " | (1): Linear(in_features=2, out_features=2, bias=True)\n",
717
+ " | ))\n",
718
+ " | 1 -> ('0', Linear(in_features=2, out_features=2, bias=True))\n",
719
+ " | \n",
720
+ " | named_parameters(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.nn.parameter.Parameter]]\n",
721
+ " | Returns an iterator over module parameters, yielding both the\n",
722
+ " | name of the parameter as well as the parameter itself.\n",
723
+ " | \n",
724
+ " | Args:\n",
725
+ " | prefix (str): prefix to prepend to all parameter names.\n",
726
+ " | recurse (bool): if True, then yields parameters of this module\n",
727
+ " | and all submodules. Otherwise, yields only parameters that\n",
728
+ " | are direct members of this module.\n",
729
+ " | remove_duplicate (bool, optional): whether to remove the duplicated\n",
730
+ " | parameters in the result. Defaults to True.\n",
731
+ " | \n",
732
+ " | Yields:\n",
733
+ " | (str, Parameter): Tuple containing the name and parameter\n",
734
+ " | \n",
735
+ " | Example::\n",
736
+ " | \n",
737
+ " | >>> # xdoctest: +SKIP(\"undefined vars\")\n",
738
+ " | >>> for name, param in self.named_parameters():\n",
739
+ " | >>> if name in ['bias']:\n",
740
+ " | >>> print(param.size())\n",
741
+ " | \n",
742
+ " | parameters(self, recurse: bool = True) -> Iterator[torch.nn.parameter.Parameter]\n",
743
+ " | Returns an iterator over module parameters.\n",
744
+ " | \n",
745
+ " | This is typically passed to an optimizer.\n",
746
+ " | \n",
747
+ " | Args:\n",
748
+ " | recurse (bool): if True, then yields parameters of this module\n",
749
+ " | and all submodules. Otherwise, yields only parameters that\n",
750
+ " | are direct members of this module.\n",
751
+ " | \n",
752
+ " | Yields:\n",
753
+ " | Parameter: module parameter\n",
754
+ " | \n",
755
+ " | Example::\n",
756
+ " | \n",
757
+ " | >>> # xdoctest: +SKIP(\"undefined vars\")\n",
758
+ " | >>> for param in model.parameters():\n",
759
+ " | >>> print(type(param), param.size())\n",
760
+ " | <class 'torch.Tensor'> (20L,)\n",
761
+ " | <class 'torch.Tensor'> (20L, 1L, 5L, 5L)\n",
762
+ " | \n",
763
+ " | register_backward_hook(self, hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]]) -> torch.utils.hooks.RemovableHandle\n",
764
+ " | Registers a backward hook on the module.\n",
765
+ " | \n",
766
+ " | This function is deprecated in favor of :meth:`~torch.nn.Module.register_full_backward_hook` and\n",
767
+ " | the behavior of this function will change in future versions.\n",
768
+ " | \n",
769
+ " | Returns:\n",
770
+ " | :class:`torch.utils.hooks.RemovableHandle`:\n",
771
+ " | a handle that can be used to remove the added hook by calling\n",
772
+ " | ``handle.remove()``\n",
773
+ " | \n",
774
+ " | register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None\n",
775
+ " | Adds a buffer to the module.\n",
776
+ " | \n",
777
+ " | This is typically used to register a buffer that should not to be\n",
778
+ " | considered a model parameter. For example, BatchNorm's ``running_mean``\n",
779
+ " | is not a parameter, but is part of the module's state. Buffers, by\n",
780
+ " | default, are persistent and will be saved alongside parameters. This\n",
781
+ " | behavior can be changed by setting :attr:`persistent` to ``False``. The\n",
782
+ " | only difference between a persistent buffer and a non-persistent buffer\n",
783
+ " | is that the latter will not be a part of this module's\n",
784
+ " | :attr:`state_dict`.\n",
785
+ " | \n",
786
+ " | Buffers can be accessed as attributes using given names.\n",
787
+ " | \n",
788
+ " | Args:\n",
789
+ " | name (str): name of the buffer. The buffer can be accessed\n",
790
+ " | from this module using the given name\n",
791
+ " | tensor (Tensor or None): buffer to be registered. If ``None``, then operations\n",
792
+ " | that run on buffers, such as :attr:`cuda`, are ignored. If ``None``,\n",
793
+ " | the buffer is **not** included in the module's :attr:`state_dict`.\n",
794
+ " | persistent (bool): whether the buffer is part of this module's\n",
795
+ " | :attr:`state_dict`.\n",
796
+ " | \n",
797
+ " | Example::\n",
798
+ " | \n",
799
+ " | >>> # xdoctest: +SKIP(\"undefined vars\")\n",
800
+ " | >>> self.register_buffer('running_mean', torch.zeros(num_features))\n",
801
+ " | \n",
802
+ " | register_forward_hook(self, hook: Union[Callable[[~T, Tuple[Any, ...], Any], Optional[Any]], Callable[[~T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]]], *, prepend: bool = False, with_kwargs: bool = False) -> torch.utils.hooks.RemovableHandle\n",
803
+ " | Registers a forward hook on the module.\n",
804
+ " | \n",
805
+ " | The hook will be called every time after :func:`forward` has computed an output.\n",
806
+ " | \n",
807
+ " | If ``with_kwargs`` is ``False`` or not specified, the input contains only\n",
808
+ " | the positional arguments given to the module. Keyword arguments won't be\n",
809
+ " | passed to the hooks and only to the ``forward``. The hook can modify the\n",
810
+ " | output. It can modify the input inplace but it will not have effect on\n",
811
+ " | forward since this is called after :func:`forward` is called. The hook\n",
812
+ " | should have the following signature::\n",
813
+ " | \n",
814
+ " | hook(module, args, output) -> None or modified output\n",
815
+ " | \n",
816
+ " | If ``with_kwargs`` is ``True``, the forward hook will be passed the\n",
817
+ " | ``kwargs`` given to the forward function and be expected to return the\n",
818
+ " | output possibly modified. The hook should have the following signature::\n",
819
+ " | \n",
820
+ " | hook(module, args, kwargs, output) -> None or modified output\n",
821
+ " | \n",
822
+ " | Args:\n",
823
+ " | hook (Callable): The user defined hook to be registered.\n",
824
+ " | prepend (bool): If ``True``, the provided ``hook`` will be fired\n",
825
+ " | before all existing ``forward`` hooks on this\n",
826
+ " | :class:`torch.nn.modules.Module`. Otherwise, the provided\n",
827
+ " | ``hook`` will be fired after all existing ``forward`` hooks on\n",
828
+ " | this :class:`torch.nn.modules.Module`. Note that global\n",
829
+ " | ``forward`` hooks registered with\n",
830
+ " | :func:`register_module_forward_hook` will fire before all hooks\n",
831
+ " | registered by this method.\n",
832
+ " | Default: ``False``\n",
833
+ " | with_kwargs (bool): If ``True``, the ``hook`` will be passed the\n",
834
+ " | kwargs given to the forward function.\n",
835
+ " | Default: ``False``\n",
836
+ " | \n",
837
+ " | Returns:\n",
838
+ " | :class:`torch.utils.hooks.RemovableHandle`:\n",
839
+ " | a handle that can be used to remove the added hook by calling\n",
840
+ " | ``handle.remove()``\n",
841
+ " | \n",
842
+ " | register_forward_pre_hook(self, hook: Union[Callable[[~T, Tuple[Any, ...]], Optional[Any]], Callable[[~T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]]], *, prepend: bool = False, with_kwargs: bool = False) -> torch.utils.hooks.RemovableHandle\n",
843
+ " | Registers a forward pre-hook on the module.\n",
844
+ " | \n",
845
+ " | The hook will be called every time before :func:`forward` is invoked.\n",
846
+ " | \n",
847
+ " | \n",
848
+ " | If ``with_kwargs`` is false or not specified, the input contains only\n",
849
+ " | the positional arguments given to the module. Keyword arguments won't be\n",
850
+ " | passed to the hooks and only to the ``forward``. The hook can modify the\n",
851
+ " | input. User can either return a tuple or a single modified value in the\n",
852
+ " | hook. We will wrap the value into a tuple if a single value is returned\n",
853
+ " | (unless that value is already a tuple). The hook should have the\n",
854
+ " | following signature::\n",
855
+ " | \n",
856
+ " | hook(module, args) -> None or modified input\n",
857
+ " | \n",
858
+ " | If ``with_kwargs`` is true, the forward pre-hook will be passed the\n",
859
+ " | kwargs given to the forward function. And if the hook modifies the\n",
860
+ " | input, both the args and kwargs should be returned. The hook should have\n",
861
+ " | the following signature::\n",
862
+ " | \n",
863
+ " | hook(module, args, kwargs) -> None or a tuple of modified input and kwargs\n",
864
+ " | \n",
865
+ " | Args:\n",
866
+ " | hook (Callable): The user defined hook to be registered.\n",
867
+ " | prepend (bool): If true, the provided ``hook`` will be fired before\n",
868
+ " | all existing ``forward_pre`` hooks on this\n",
869
+ " | :class:`torch.nn.modules.Module`. Otherwise, the provided\n",
870
+ " | ``hook`` will be fired after all existing ``forward_pre`` hooks\n",
871
+ " | on this :class:`torch.nn.modules.Module`. Note that global\n",
872
+ " | ``forward_pre`` hooks registered with\n",
873
+ " | :func:`register_module_forward_pre_hook` will fire before all\n",
874
+ " | hooks registered by this method.\n",
875
+ " | Default: ``False``\n",
876
+ " | with_kwargs (bool): If true, the ``hook`` will be passed the kwargs\n",
877
+ " | given to the forward function.\n",
878
+ " | Default: ``False``\n",
879
+ " | \n",
880
+ " | Returns:\n",
881
+ " | :class:`torch.utils.hooks.RemovableHandle`:\n",
882
+ " | a handle that can be used to remove the added hook by calling\n",
883
+ " | ``handle.remove()``\n",
884
+ " | \n",
885
+ " | register_full_backward_hook(self, hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]], prepend: bool = False) -> torch.utils.hooks.RemovableHandle\n",
886
+ " | Registers a backward hook on the module.\n",
887
+ " | \n",
888
+ " | The hook will be called every time the gradients with respect to a module\n",
889
+ " | are computed, i.e. the hook will execute if and only if the gradients with\n",
890
+ " | respect to module outputs are computed. The hook should have the following\n",
891
+ " | signature::\n",
892
+ " | \n",
893
+ " | hook(module, grad_input, grad_output) -> tuple(Tensor) or None\n",
894
+ " | \n",
895
+ " | The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients\n",
896
+ " | with respect to the inputs and outputs respectively. The hook should\n",
897
+ " | not modify its arguments, but it can optionally return a new gradient with\n",
898
+ " | respect to the input that will be used in place of :attr:`grad_input` in\n",
899
+ " | subsequent computations. :attr:`grad_input` will only correspond to the inputs given\n",
900
+ " | as positional arguments and all kwarg arguments are ignored. Entries\n",
901
+ " | in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor\n",
902
+ " | arguments.\n",
903
+ " | \n",
904
+ " | For technical reasons, when this hook is applied to a Module, its forward function will\n",
905
+ " | receive a view of each Tensor passed to the Module. Similarly the caller will receive a view\n",
906
+ " | of each Tensor returned by the Module's forward function.\n",
907
+ " | \n",
908
+ " | .. warning ::\n",
909
+ " | Modifying inputs or outputs inplace is not allowed when using backward hooks and\n",
910
+ " | will raise an error.\n",
911
+ " | \n",
912
+ " | Args:\n",
913
+ " | hook (Callable): The user-defined hook to be registered.\n",
914
+ " | prepend (bool): If true, the provided ``hook`` will be fired before\n",
915
+ " | all existing ``backward`` hooks on this\n",
916
+ " | :class:`torch.nn.modules.Module`. Otherwise, the provided\n",
917
+ " | ``hook`` will be fired after all existing ``backward`` hooks on\n",
918
+ " | this :class:`torch.nn.modules.Module`. Note that global\n",
919
+ " | ``backward`` hooks registered with\n",
920
+ " | :func:`register_module_full_backward_hook` will fire before\n",
921
+ " | all hooks registered by this method.\n",
922
+ " | \n",
923
+ " | Returns:\n",
924
+ " | :class:`torch.utils.hooks.RemovableHandle`:\n",
925
+ " | a handle that can be used to remove the added hook by calling\n",
926
+ " | ``handle.remove()``\n",
927
+ " | \n",
928
+ " | register_full_backward_pre_hook(self, hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]], prepend: bool = False) -> torch.utils.hooks.RemovableHandle\n",
929
+ " | Registers a backward pre-hook on the module.\n",
930
+ " | \n",
931
+ " | The hook will be called every time the gradients for the module are computed.\n",
932
+ " | The hook should have the following signature::\n",
933
+ " | \n",
934
+ " | hook(module, grad_output) -> Tensor or None\n",
935
+ " | \n",
936
+ " | The :attr:`grad_output` is a tuple. The hook should\n",
937
+ " | not modify its arguments, but it can optionally return a new gradient with\n",
938
+ " | respect to the output that will be used in place of :attr:`grad_output` in\n",
939
+ " | subsequent computations. Entries in :attr:`grad_output` will be ``None`` for\n",
940
+ " | all non-Tensor arguments.\n",
941
+ " | \n",
942
+ " | For technical reasons, when this hook is applied to a Module, its forward function will\n",
943
+ " | receive a view of each Tensor passed to the Module. Similarly the caller will receive a view\n",
944
+ " | of each Tensor returned by the Module's forward function.\n",
945
+ " | \n",
946
+ " | .. warning ::\n",
947
+ " | Modifying inputs inplace is not allowed when using backward hooks and\n",
948
+ " | will raise an error.\n",
949
+ " | \n",
950
+ " | Args:\n",
951
+ " | hook (Callable): The user-defined hook to be registered.\n",
952
+ " | prepend (bool): If true, the provided ``hook`` will be fired before\n",
953
+ " | all existing ``backward_pre`` hooks on this\n",
954
+ " | :class:`torch.nn.modules.Module`. Otherwise, the provided\n",
955
+ " | ``hook`` will be fired after all existing ``backward_pre`` hooks\n",
956
+ " | on this :class:`torch.nn.modules.Module`. Note that global\n",
957
+ " | ``backward_pre`` hooks registered with\n",
958
+ " | :func:`register_module_full_backward_pre_hook` will fire before\n",
959
+ " | all hooks registered by this method.\n",
960
+ " | \n",
961
+ " | Returns:\n",
962
+ " | :class:`torch.utils.hooks.RemovableHandle`:\n",
963
+ " | a handle that can be used to remove the added hook by calling\n",
964
+ " | ``handle.remove()``\n",
965
+ " | \n",
966
+ " | register_load_state_dict_post_hook(self, hook)\n",
967
+ " | Registers a post hook to be run after module's ``load_state_dict``\n",
968
+ " | is called.\n",
969
+ " | \n",
970
+ " | It should have the following signature::\n",
971
+ " | hook(module, incompatible_keys) -> None\n",
972
+ " | \n",
973
+ " | The ``module`` argument is the current module that this hook is registered\n",
974
+ " | on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting\n",
975
+ " | of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys``\n",
976
+ " | is a ``list`` of ``str`` containing the missing keys and\n",
977
+ " | ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys.\n",
978
+ " | \n",
979
+ " | The given incompatible_keys can be modified inplace if needed.\n",
980
+ " | \n",
981
+ " | Note that the checks performed when calling :func:`load_state_dict` with\n",
982
+ " | ``strict=True`` are affected by modifications the hook makes to\n",
983
+ " | ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either\n",
984
+ " | set of keys will result in an error being thrown when ``strict=True``, and\n",
985
+ " | clearing out both missing and unexpected keys will avoid an error.\n",
986
+ " | \n",
987
+ " | Returns:\n",
988
+ " | :class:`torch.utils.hooks.RemovableHandle`:\n",
989
+ " | a handle that can be used to remove the added hook by calling\n",
990
+ " | ``handle.remove()``\n",
991
+ " | \n",
992
+ " | register_module(self, name: str, module: Optional[ForwardRef('Module')]) -> None\n",
993
+ " | Alias for :func:`add_module`.\n",
994
+ " | \n",
995
+ " | register_parameter(self, name: str, param: Optional[torch.nn.parameter.Parameter]) -> None\n",
996
+ " | Adds a parameter to the module.\n",
997
+ " | \n",
998
+ " | The parameter can be accessed as an attribute using given name.\n",
999
+ " | \n",
1000
+ " | Args:\n",
1001
+ " | name (str): name of the parameter. The parameter can be accessed\n",
1002
+ " | from this module using the given name\n",
1003
+ " | param (Parameter or None): parameter to be added to the module. If\n",
1004
+ " | ``None``, then operations that run on parameters, such as :attr:`cuda`,\n",
1005
+ " | are ignored. If ``None``, the parameter is **not** included in the\n",
1006
+ " | module's :attr:`state_dict`.\n",
1007
+ " | \n",
1008
+ " | register_state_dict_pre_hook(self, hook)\n",
1009
+ " | These hooks will be called with arguments: ``self``, ``prefix``,\n",
1010
+ " | and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered\n",
1011
+ " | hooks can be used to perform pre-processing before the ``state_dict``\n",
1012
+ " | call is made.\n",
1013
+ " | \n",
1014
+ " | requires_grad_(self: ~T, requires_grad: bool = True) -> ~T\n",
1015
+ " | Change if autograd should record operations on parameters in this\n",
1016
+ " | module.\n",
1017
+ " | \n",
1018
+ " | This method sets the parameters' :attr:`requires_grad` attributes\n",
1019
+ " | in-place.\n",
1020
+ " | \n",
1021
+ " | This method is helpful for freezing part of the module for finetuning\n",
1022
+ " | or training parts of a model individually (e.g., GAN training).\n",
1023
+ " | \n",
1024
+ " | See :ref:`locally-disable-grad-doc` for a comparison between\n",
1025
+ " | `.requires_grad_()` and several similar mechanisms that may be confused with it.\n",
1026
+ " | \n",
1027
+ " | Args:\n",
1028
+ " | requires_grad (bool): whether autograd should record operations on\n",
1029
+ " | parameters in this module. Default: ``True``.\n",
1030
+ " | \n",
1031
+ " | Returns:\n",
1032
+ " | Module: self\n",
1033
+ " | \n",
1034
+ " | set_extra_state(self, state: Any)\n",
1035
+ " | This function is called from :func:`load_state_dict` to handle any extra state\n",
1036
+ " | found within the `state_dict`. Implement this function and a corresponding\n",
1037
+ " | :func:`get_extra_state` for your module if you need to store extra state within its\n",
1038
+ " | `state_dict`.\n",
1039
+ " | \n",
1040
+ " | Args:\n",
1041
+ " | state (dict): Extra state from the `state_dict`\n",
1042
+ " | \n",
1043
+ " | share_memory(self: ~T) -> ~T\n",
1044
+ " | See :meth:`torch.Tensor.share_memory_`\n",
1045
+ " | \n",
1046
+ " | state_dict(self, *args, destination=None, prefix='', keep_vars=False)\n",
1047
+ " | Returns a dictionary containing references to the whole state of the module.\n",
1048
+ " | \n",
1049
+ " | Both parameters and persistent buffers (e.g. running averages) are\n",
1050
+ " | included. Keys are corresponding parameter and buffer names.\n",
1051
+ " | Parameters and buffers set to ``None`` are not included.\n",
1052
+ " | \n",
1053
+ " | .. note::\n",
1054
+ " | The returned object is a shallow copy. It contains references\n",
1055
+ " | to the module's parameters and buffers.\n",
1056
+ " | \n",
1057
+ " | .. warning::\n",
1058
+ " | Currently ``state_dict()`` also accepts positional arguments for\n",
1059
+ " | ``destination``, ``prefix`` and ``keep_vars`` in order. However,\n",
1060
+ " | this is being deprecated and keyword arguments will be enforced in\n",
1061
+ " | future releases.\n",
1062
+ " | \n",
1063
+ " | .. warning::\n",
1064
+ " | Please avoid the use of argument ``destination`` as it is not\n",
1065
+ " | designed for end-users.\n",
1066
+ " | \n",
1067
+ " | Args:\n",
1068
+ " | destination (dict, optional): If provided, the state of module will\n",
1069
+ " | be updated into the dict and the same object is returned.\n",
1070
+ " | Otherwise, an ``OrderedDict`` will be created and returned.\n",
1071
+ " | Default: ``None``.\n",
1072
+ " | prefix (str, optional): a prefix added to parameter and buffer\n",
1073
+ " | names to compose the keys in state_dict. Default: ``''``.\n",
1074
+ " | keep_vars (bool, optional): by default the :class:`~torch.Tensor` s\n",
1075
+ " | returned in the state dict are detached from autograd. If it's\n",
1076
+ " | set to ``True``, detaching will not be performed.\n",
1077
+ " | Default: ``False``.\n",
1078
+ " | \n",
1079
+ " | Returns:\n",
1080
+ " | dict:\n",
1081
+ " | a dictionary containing a whole state of the module\n",
1082
+ " | \n",
1083
+ " | Example::\n",
1084
+ " | \n",
1085
+ " | >>> # xdoctest: +SKIP(\"undefined vars\")\n",
1086
+ " | >>> module.state_dict().keys()\n",
1087
+ " | ['bias', 'weight']\n",
1088
+ " | \n",
1089
+ " | to(self, *args, **kwargs)\n",
1090
+ " | Moves and/or casts the parameters and buffers.\n",
1091
+ " | \n",
1092
+ " | This can be called as\n",
1093
+ " | \n",
1094
+ " | .. function:: to(device=None, dtype=None, non_blocking=False)\n",
1095
+ " | :noindex:\n",
1096
+ " | \n",
1097
+ " | .. function:: to(dtype, non_blocking=False)\n",
1098
+ " | :noindex:\n",
1099
+ " | \n",
1100
+ " | .. function:: to(tensor, non_blocking=False)\n",
1101
+ " | :noindex:\n",
1102
+ " | \n",
1103
+ " | .. function:: to(memory_format=torch.channels_last)\n",
1104
+ " | :noindex:\n",
1105
+ " | \n",
1106
+ " | Its signature is similar to :meth:`torch.Tensor.to`, but only accepts\n",
1107
+ " | floating point or complex :attr:`dtype`\\ s. In addition, this method will\n",
1108
+ " | only cast the floating point or complex parameters and buffers to :attr:`dtype`\n",
1109
+ " | (if given). The integral parameters and buffers will be moved\n",
1110
+ " | :attr:`device`, if that is given, but with dtypes unchanged. When\n",
1111
+ " | :attr:`non_blocking` is set, it tries to convert/move asynchronously\n",
1112
+ " | with respect to the host if possible, e.g., moving CPU Tensors with\n",
1113
+ " | pinned memory to CUDA devices.\n",
1114
+ " | \n",
1115
+ " | See below for examples.\n",
1116
+ " | \n",
1117
+ " | .. note::\n",
1118
+ " | This method modifies the module in-place.\n",
1119
+ " | \n",
1120
+ " | Args:\n",
1121
+ " | device (:class:`torch.device`): the desired device of the parameters\n",
1122
+ " | and buffers in this module\n",
1123
+ " | dtype (:class:`torch.dtype`): the desired floating point or complex dtype of\n",
1124
+ " | the parameters and buffers in this module\n",
1125
+ " | tensor (torch.Tensor): Tensor whose dtype and device are the desired\n",
1126
+ " | dtype and device for all parameters and buffers in this module\n",
1127
+ " | memory_format (:class:`torch.memory_format`): the desired memory\n",
1128
+ " | format for 4D parameters and buffers in this module (keyword\n",
1129
+ " | only argument)\n",
1130
+ " | \n",
1131
+ " | Returns:\n",
1132
+ " | Module: self\n",
1133
+ " | \n",
1134
+ " | Examples::\n",
1135
+ " | \n",
1136
+ " | >>> # xdoctest: +IGNORE_WANT(\"non-deterministic\")\n",
1137
+ " | >>> linear = nn.Linear(2, 2)\n",
1138
+ " | >>> linear.weight\n",
1139
+ " | Parameter containing:\n",
1140
+ " | tensor([[ 0.1913, -0.3420],\n",
1141
+ " | [-0.5113, -0.2325]])\n",
1142
+ " | >>> linear.to(torch.double)\n",
1143
+ " | Linear(in_features=2, out_features=2, bias=True)\n",
1144
+ " | >>> linear.weight\n",
1145
+ " | Parameter containing:\n",
1146
+ " | tensor([[ 0.1913, -0.3420],\n",
1147
+ " | [-0.5113, -0.2325]], dtype=torch.float64)\n",
1148
+ " | >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)\n",
1149
+ " | >>> gpu1 = torch.device(\"cuda:1\")\n",
1150
+ " | >>> linear.to(gpu1, dtype=torch.half, non_blocking=True)\n",
1151
+ " | Linear(in_features=2, out_features=2, bias=True)\n",
1152
+ " | >>> linear.weight\n",
1153
+ " | Parameter containing:\n",
1154
+ " | tensor([[ 0.1914, -0.3420],\n",
1155
+ " | [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')\n",
1156
+ " | >>> cpu = torch.device(\"cpu\")\n",
1157
+ " | >>> linear.to(cpu)\n",
1158
+ " | Linear(in_features=2, out_features=2, bias=True)\n",
1159
+ " | >>> linear.weight\n",
1160
+ " | Parameter containing:\n",
1161
+ " | tensor([[ 0.1914, -0.3420],\n",
1162
+ " | [-0.5112, -0.2324]], dtype=torch.float16)\n",
1163
+ " | \n",
1164
+ " | >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)\n",
1165
+ " | >>> linear.weight\n",
1166
+ " | Parameter containing:\n",
1167
+ " | tensor([[ 0.3741+0.j, 0.2382+0.j],\n",
1168
+ " | [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)\n",
1169
+ " | >>> linear(torch.ones(3, 2, dtype=torch.cdouble))\n",
1170
+ " | tensor([[0.6122+0.j, 0.1150+0.j],\n",
1171
+ " | [0.6122+0.j, 0.1150+0.j],\n",
1172
+ " | [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)\n",
1173
+ " | \n",
1174
+ " | to_empty(self: ~T, *, device: Union[str, torch.device]) -> ~T\n",
1175
+ " | Moves the parameters and buffers to the specified device without copying storage.\n",
1176
+ " | \n",
1177
+ " | Args:\n",
1178
+ " | device (:class:`torch.device`): The desired device of the parameters\n",
1179
+ " | and buffers in this module.\n",
1180
+ " | \n",
1181
+ " | Returns:\n",
1182
+ " | Module: self\n",
1183
+ " | \n",
1184
+ " | train(self: ~T, mode: bool = True) -> ~T\n",
1185
+ " | Sets the module in training mode.\n",
1186
+ " | \n",
1187
+ " | This has any effect only on certain modules. See documentations of\n",
1188
+ " | particular modules for details of their behaviors in training/evaluation\n",
1189
+ " | mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,\n",
1190
+ " | etc.\n",
1191
+ " | \n",
1192
+ " | Args:\n",
1193
+ " | mode (bool): whether to set training mode (``True``) or evaluation\n",
1194
+ " | mode (``False``). Default: ``True``.\n",
1195
+ " | \n",
1196
+ " | Returns:\n",
1197
+ " | Module: self\n",
1198
+ " | \n",
1199
+ " | type(self: ~T, dst_type: Union[torch.dtype, str]) -> ~T\n",
1200
+ " | Casts all parameters and buffers to :attr:`dst_type`.\n",
1201
+ " | \n",
1202
+ " | .. note::\n",
1203
+ " | This method modifies the module in-place.\n",
1204
+ " | \n",
1205
+ " | Args:\n",
1206
+ " | dst_type (type or string): the desired type\n",
1207
+ " | \n",
1208
+ " | Returns:\n",
1209
+ " | Module: self\n",
1210
+ " | \n",
1211
+ " | xpu(self: ~T, device: Union[int, torch.device, NoneType] = None) -> ~T\n",
1212
+ " | Moves all model parameters and buffers to the XPU.\n",
1213
+ " | \n",
1214
+ " | This also makes associated parameters and buffers different objects. So\n",
1215
+ " | it should be called before constructing optimizer if the module will\n",
1216
+ " | live on XPU while being optimized.\n",
1217
+ " | \n",
1218
+ " | .. note::\n",
1219
+ " | This method modifies the module in-place.\n",
1220
+ " | \n",
1221
+ " | Arguments:\n",
1222
+ " | device (int, optional): if specified, all parameters will be\n",
1223
+ " | copied to that device\n",
1224
+ " | \n",
1225
+ " | Returns:\n",
1226
+ " | Module: self\n",
1227
+ " | \n",
1228
+ " | zero_grad(self, set_to_none: bool = True) -> None\n",
1229
+ " | Sets gradients of all model parameters to zero. See similar function\n",
1230
+ " | under :class:`torch.optim.Optimizer` for more context.\n",
1231
+ " | \n",
1232
+ " | Args:\n",
1233
+ " | set_to_none (bool): instead of setting to zero, set the grads to None.\n",
1234
+ " | See :meth:`torch.optim.Optimizer.zero_grad` for details.\n",
1235
+ " | \n",
1236
+ " | ----------------------------------------------------------------------\n",
1237
+ " | Data descriptors inherited from torch.nn.modules.module.Module:\n",
1238
+ " | \n",
1239
+ " | __dict__\n",
1240
+ " | dictionary for instance variables (if defined)\n",
1241
+ " | \n",
1242
+ " | __weakref__\n",
1243
+ " | list of weak references to the object (if defined)\n",
1244
+ " | \n",
1245
+ " | ----------------------------------------------------------------------\n",
1246
+ " | Data and other attributes inherited from torch.nn.modules.module.Module:\n",
1247
+ " | \n",
1248
+ " | T_destination = ~T_destination\n",
1249
+ " | \n",
1250
+ " | call_super_init = False\n",
1251
+ " | \n",
1252
+ " | dump_patches = False\n",
1253
+ "\n"
1254
+ ]
1255
+ }
1256
+ ],
1257
+ "source": [
1258
+ "help(eva02_model)"
1259
+ ]
1260
+ },
1261
+ {
1262
+ "cell_type": "markdown",
1263
+ "id": "2f5ac1a7-6f1b-4417-8a67-1b2e32d385dd",
1264
+ "metadata": {},
1265
+ "source": [
1266
+ "# DETR"
1267
+ ]
1268
+ },
1269
+ {
1270
+ "cell_type": "code",
1271
+ "execution_count": 33,
1272
+ "id": "5c3ade1b-18ea-4368-abd9-53be1fdfb610",
1273
+ "metadata": {},
1274
+ "outputs": [
1275
+ {
1276
+ "name": "stdout",
1277
+ "output_type": "stream",
1278
+ "text": [
1279
+ "[2023-08-28 01:51:14,033] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
1280
+ ]
1281
+ },
1282
+ {
1283
+ "name": "stderr",
1284
+ "output_type": "stream",
1285
+ "text": [
1286
+ "The `max_size` parameter is deprecated and will be removed in v4.26. Please specify in `size['longest_edge'] instead`.\n"
1287
+ ]
1288
+ }
1289
+ ],
1290
+ "source": [
1291
+ "from transformers import DetrImageProcessor, DetrForObjectDetection\n",
1292
+ "import torch\n",
1293
+ "from PIL import Image\n",
1294
+ "import requests\n",
1295
+ "\n",
1296
+ "url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
1297
+ "image = Image.open(requests.get(url, stream=True).raw)\n",
1298
+ "\n",
1299
+ "processor = DetrImageProcessor.from_pretrained(\"facebook/detr-resnet-50\", cache_dir='/fsx/proj-fmri/shared/cache')\n",
1300
+ "model = DetrForObjectDetection.from_pretrained(\"facebook/detr-resnet-50\", cache_dir='/fsx/proj-fmri/shared/cache')"
1301
+ ]
1302
+ },
1303
+ {
1304
+ "cell_type": "code",
1305
+ "execution_count": 34,
1306
+ "id": "1d5aa2d7-4868-4751-8d90-7c52be028cd9",
1307
+ "metadata": {},
1308
+ "outputs": [],
1309
+ "source": [
1310
+ "inputs = processor(images=image, return_tensors=\"pt\")\n",
1311
+ "outputs = model(**inputs)"
1312
+ ]
1313
+ },
1314
+ {
1315
+ "cell_type": "code",
1316
+ "execution_count": 35,
1317
+ "id": "ae6bafc6-cee4-4e59-b7ba-12efc2a65b74",
1318
+ "metadata": {},
1319
+ "outputs": [
1320
+ {
1321
+ "name": "stdout",
1322
+ "output_type": "stream",
1323
+ "text": [
1324
+ "Detected remote with confidence 0.998 at location [40.16, 70.81, 175.55, 117.98]\n",
1325
+ "Detected remote with confidence 0.996 at location [333.24, 72.55, 368.33, 187.66]\n",
1326
+ "Detected couch with confidence 0.995 at location [-0.02, 1.15, 639.73, 473.76]\n",
1327
+ "Detected cat with confidence 0.999 at location [13.24, 52.05, 314.02, 470.93]\n",
1328
+ "Detected cat with confidence 0.999 at location [345.4, 23.85, 640.37, 368.72]\n"
1329
+ ]
1330
+ }
1331
+ ],
1332
+ "source": [
1333
+ "# convert outputs (bounding boxes and class logits) to COCO API\n",
1334
+ "# let's only keep detections with score > 0.9\n",
1335
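+ "# PIL's image.size is (width, height); post_process_object_detection expects (height, width)\n",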
+ "target_sizes = torch.tensor([image.size[::-1]])\n",
1336
+ "results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]\n",
1337
+ "\n",
1338
+ "for score, label, box in zip(results[\"scores\"], results[\"labels\"], results[\"boxes\"]):\n",
1339
+ " box = [round(i, 2) for i in box.tolist()]\n",
1340
+ " print(\n",
1341
+ " f\"Detected {model.config.id2label[label.item()]} with confidence \"\n",
1342
+ " f\"{round(score.item(), 3)} at location {box}\"\n",
1343
+ " )"
1344
+ ]
1345
+ },
1346
+ {
1347
+ "cell_type": "code",
1348
+ "execution_count": 36,
1349
+ "id": "6dcc5934-79d4-4062-8b32-e42b3ebcdc0f",
1350
+ "metadata": {},
1351
+ "outputs": [
1352
+ {
1353
+ "data": {
1354
+ "text/plain": [
1355
+ "DetrImageProcessor {\n",
1356
+ " \"do_normalize\": true,\n",
1357
+ " \"do_pad\": true,\n",
1358
+ " \"do_rescale\": true,\n",
1359
+ " \"do_resize\": true,\n",
1360
+ " \"feature_extractor_type\": \"DetrFeatureExtractor\",\n",
1361
+ " \"format\": \"coco_detection\",\n",
1362
+ " \"image_mean\": [\n",
1363
+ " 0.485,\n",
1364
+ " 0.456,\n",
1365
+ " 0.406\n",
1366
+ " ],\n",
1367
+ " \"image_processor_type\": \"DetrImageProcessor\",\n",
1368
+ " \"image_std\": [\n",
1369
+ " 0.229,\n",
1370
+ " 0.224,\n",
1371
+ " 0.225\n",
1372
+ " ],\n",
1373
+ " \"resample\": 2,\n",
1374
+ " \"rescale_factor\": 0.00392156862745098,\n",
1375
+ " \"size\": {\n",
1376
+ " \"longest_edge\": 1333,\n",
1377
+ " \"shortest_edge\": 800\n",
1378
+ " }\n",
1379
+ "}"
1380
+ ]
1381
+ },
1382
+ "execution_count": 36,
1383
+ "metadata": {},
1384
+ "output_type": "execute_result"
1385
+ }
1386
+ ],
1387
+ "source": [
1388
+ "processor"
1389
+ ]
1390
+ },
1391
+ {
1392
+ "cell_type": "markdown",
1393
+ "id": "db1d89cc-b432-473e-af69-d81c435ac731",
1394
+ "metadata": {},
1395
+ "source": [
1396
+ "# CLIPSeg"
1397
+ ]
1398
+ },
1399
+ {
1400
+ "cell_type": "code",
1401
+ "execution_count": 37,
1402
+ "id": "15db14d1-ee4d-4429-9286-054c4498293b",
1403
+ "metadata": {},
1404
+ "outputs": [],
1405
+ "source": [
1406
+ "from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation\n",
1407
+ "\n",
1408
+ "processor = CLIPSegProcessor.from_pretrained(\"CIDAS/clipseg-rd16\",cache_dir='/fsx/proj-fmri/shared/cache')\n",
1409
+ "model = CLIPSegForImageSegmentation.from_pretrained(\"CIDAS/clipseg-rd16\",cache_dir='/fsx/proj-fmri/shared/cache')"
1410
+ ]
1411
+ },
1412
+ {
1413
+ "cell_type": "code",
1414
+ "execution_count": 38,
1415
+ "id": "4aa225d4-5a3b-4dbb-ae57-dea2872ff492",
1416
+ "metadata": {},
1417
+ "outputs": [
1418
+ {
1419
+ "ename": "AttributeError",
1420
+ "evalue": "'JpegImageFile' object has no attribute 'shape'",
1421
+ "output_type": "error",
1422
+ "traceback": [
1423
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1424
+ "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
1425
+ "Cell \u001b[0;32mIn[38], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mimage\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\n",
1426
+ "\u001b[0;31mAttributeError\u001b[0m: 'JpegImageFile' object has no attribute 'shape'"
1427
+ ]
1428
+ }
1429
+ ],
1430
+ "source": [
1431
+ "image.shape"
1432
+ ]
1433
+ },
1434
+ {
1435
+ "cell_type": "code",
1436
+ "execution_count": null,
1437
+ "id": "ad7e2daf-0c7c-4fec-b29e-9ba47a037c6b",
1438
+ "metadata": {},
1439
+ "outputs": [],
1440
+ "source": [
1441
+ "from PIL import Image\n",
1442
+ "import requests\n",
1443
+ "import h5py\n",
1444
+ "\n",
1445
+ "# url = \"https://unsplash.com/photos/8Nc_oQsc2qQ/download?ixid=MnwxMjA3fDB8MXxhbGx8fHx8fHx8fHwxNjcxMjAwNzI0&force=true&w=640\"\n",
1446
+ "# image = Image.open(requests.get(url, stream=True).raw)\n",
1447
+ "\n",
1448
+ "image_path = \"/fsx/proj-fmri/shared/mindeyev2_dataset/coco_images_224_float16.hdf5\"\n",
1449
+ "with h5py.File(image_path, 'r') as file:\n",
1450
+ " image = file['images'][0]\n",
1451
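+ "# stored as float16, channel-first; convert to (H, W, C) float32 for matplotlib\n",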
+ "image = np.moveaxis(image, 0, -1).astype(np.float32)\n",
1452
+ "plt.imshow(image)\n",
1453
+ "\n",
1454
+ "prompts = [\"person\",\"animal\",\"object\",\"background\"]\n",
1455
+ "import torch\n",
1456
+ "\n",
1457
+ "# Rescale to [0, 255]\n",
1458
+ "array = (image * 255).astype(np.uint8)\n",
1459
+ "\n",
1460
+ "# Convert to PIL image\n",
1461
+ "image = Image.fromarray(array)\n",
1462
+ "\n",
1463
+ "inputs = processor(text=prompts, images=[image] * len(prompts), padding=\"max_length\", return_tensors=\"pt\")\n",
1464
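+ "# one (image, prompt) pair per batch element, so the image is repeated for each prompt\n",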
+ "# predict\n",
1465
+ "with torch.no_grad():\n",
1466
+ " outputs = model(**inputs)\n",
1467
+ "preds = outputs.logits.unsqueeze(1)\n",
1468
+ "print(preds.shape)"
1469
+ ]
1470
+ },
1471
+ {
1472
+ "cell_type": "code",
1473
+ "execution_count": null,
1474
+ "id": "131eb5b7-2f16-4a79-8402-edc1a1d8c348",
1475
+ "metadata": {},
1476
+ "outputs": [],
1477
+ "source": [
1478
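+ "# average the person/animal/object maps with the inverted background map\n",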
+ "preds = ((preds[0] + preds[1] + preds[2] + preds[-1].max() - preds[-1]) / 4)[None]\n",
1479
+ "preds.shape"
1480
+ ]
1481
+ },
1482
+ {
1483
+ "cell_type": "code",
1484
+ "execution_count": null,
1485
+ "id": "e2bf99e7-064d-4c22-997f-aa1a35dbab82",
1486
+ "metadata": {},
1487
+ "outputs": [],
1488
+ "source": [
1489
+ "_, ax = plt.subplots(1, len(prompts) + 1, figsize=(3*(len(prompts) + 1), 4))\n",
1490
+ "[a.axis('off') for a in ax.flatten()]\n",
1491
+ "ax[0].imshow(image)\n",
1492
+ "[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(1)];\n",
1493
+ "# [ax[i+1].text(0, -15, prompt) for i, prompt in enumerate(prompts)];"
1494
+ ]
1495
+ },
1496
+ {
1497
+ "cell_type": "code",
1498
+ "execution_count": null,
1499
+ "id": "b58b926f-a2b2-423b-b367-18808cf6b4f7",
1500
+ "metadata": {},
1501
+ "outputs": [],
1502
+ "source": []
1503
+ }
1504
+ ],
1505
+ "metadata": {
1506
+ "kernelspec": {
1507
+ "display_name": "Python 3 (ipykernel)",
1508
+ "language": "python",
1509
+ "name": "python3"
1510
+ },
1511
+ "language_info": {
1512
+ "codemirror_mode": {
1513
+ "name": "ipython",
1514
+ "version": 3
1515
+ },
1516
+ "file_extension": ".py",
1517
+ "mimetype": "text/x-python",
1518
+ "name": "python",
1519
+ "nbconvert_exporter": "python",
1520
+ "pygments_lexer": "ipython3",
1521
+ "version": "3.10.8"
1522
+ }
1523
+ },
1524
+ "nbformat": 4,
1525
+ "nbformat_minor": 5
1526
+ }
src/cnd_prov/cnd_prov-Copy1.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #new one
2
+ import pandas as pd
3
+ import datasets
4
+ import os
5
+ import pickle
6
+ import numpy as np
7
+ from PIL import Image
8
+ import matplotlib
+ import io
9
+
10
+ _VERSION = datasets.Version("0.0.3")
11
+
12
+ _DESCRIPTION = "TODO"
13
+ _HOMEPAGE = "TODO"
14
+ _LICENSE = "TODO"
15
+ _CITATION = "TODO"
16
+
17
+ _FEATURES = datasets.Features(
18
+ {
19
+ "target": datasets.Image(),
20
+ "source": datasets.Image(),
21
+ "heatmap": datasets.Image(),
22
+ "depth": datasets.Image(),
23
+ "prompt": datasets.Value("string"),
24
+ },
25
+ )
26
+
27
+ METADATA_DIR = "/fsx/proj-fmri/ckadirt/MindEyeV2/src/cnd_prov/data.pkl"
28
+ SOURCE_DIR = "/fsx/proj-fmri/shared/controlNetData/source"
29
+ TARGET_DIR = "/fsx/proj-fmri/shared/controlNetData/target"
30
+ HEATMAP_DIR = "/fsx/proj-fmri/shared/controlNetData/seg"
31
+ DEPTH_DIR = "/fsx/proj-fmri/shared/dinov2_depth"
32
+
33
+ # METADATA_URL = hf_hub_url(
34
+ # "fusing/fill50k",
35
+ # filename="train.jsonl",
36
+ # repo_type="dataset",
37
+ # )
38
+
39
+ # IMAGES_URL = hf_hub_url(
40
+ # "fusing/fill50k",
41
+ # filename="images.zip",
42
+ # repo_type="dataset",
43
+ # )
44
+
45
+ # CONDITIONING_IMAGES_URL = hf_hub_url(
46
+ # "fusing/fill50k",
47
+ # filename="conditioning_images.zip",
48
+ # repo_type="dataset",
49
+ # )
50
+
51
+ _DEFAULT_CONFIG = datasets.BuilderConfig(name="default", version=_VERSION)
52
+
53
+
54
+ class CocoTest(datasets.GeneratorBasedBuilder):
55
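+ # Builder that pairs each target image with its source image, segmentation heatmap, depth map, and text prompt.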
+ BUILDER_CONFIGS = [_DEFAULT_CONFIG]
56
+ DEFAULT_CONFIG_NAME = "default"
57
+
58
+ def _info(self):
59
+ return datasets.DatasetInfo(
60
+ description=_DESCRIPTION,
61
+ features=_FEATURES,
62
+ supervised_keys=None,
63
+ homepage=_HOMEPAGE,
64
+ license=_LICENSE,
65
+ citation=_CITATION,
66
+ )
67
+
68
+ def _split_generators(self, dl_manager):
69
+ metadata_path = METADATA_DIR
70
+ target_dir = TARGET_DIR
71
+ source_dir = SOURCE_DIR
72
+ heatmap_dir = HEATMAP_DIR
73
+ depth_dir = DEPTH_DIR
74
+
75
+ return [
76
+ datasets.SplitGenerator(
77
+ name=datasets.Split.TRAIN,
78
+ # These kwargs will be passed to _generate_examples
79
+ gen_kwargs={
80
+ "metadata_path": metadata_path,
81
+ "target_dir": TARGET_DIR,
82
+ "source_dir": SOURCE_DIR,
83
+ "heatmap_dir": HEATMAP_DIR,
84
+ "depth_dir": DEPTH_DIR,
85
+ "num_examples": 120,
86
+ },
87
+ ),
88
+ datasets.SplitGenerator(
89
+ name=datasets.Split.VALIDATION,
90
+ # These kwargs will be passed to _generate_examples
91
+ gen_kwargs={
92
+ "metadata_path": metadata_path,
93
+ "target_dir": TARGET_DIR,
94
+ "source_dir": SOURCE_DIR,
95
+ "heatmap_dir": HEATMAP_DIR,
96
+ "depth_dir": DEPTH_DIR,
97
+ "num_examples": 120,
98
+ },
99
+ ),
100
+ ]
101
+
102
+
103
+
104
+ def _generate_examples(self, metadata_path, target_dir, source_dir, heatmap_dir, depth_dir, num_examples):
105
+ data = []
106
+ with open(metadata_path, 'rb') as f:
107
+ loaded_data = pickle.load(f)
108
+ for line in loaded_data[:num_examples]:
109
+ data.append(line)
110
+
111
+ for _, item in enumerate(data):
112
+ source_filename = item['source']
113
+ target_filename = item['target']
114
+ heatmap_filename = item['h_map']
115
+ depth_filename = item['depth']
116
+ prompt = item['prompt']
117
+
118
+
119
+
120
+ tgt_img = open(target_filename, "rb").read()
121
+ src_img = open(source_filename, "rb").read()
122
+ h_img = open(heatmap_filename, "rb").read()
123
+
124
+
125
+ # depth map loaded as a PIL Image; convert('1') gives a 1-bit bilevel image, not 8-bit grayscale
126
+ d_img = Image.open(depth_filename).convert('1')
127
+
128
+
129
+
130
+ yield item["target"], {
131
+ "prompt": prompt,
132
+ "target": {
133
+ "path": target_filename,
134
+ "bytes": tgt_img,
135
+ },
136
+ "source": {
137
+ "path": source_filename,
138
+ "bytes": src_img,
139
+ },
140
+ "heatmap": {
141
+ "path": heatmap_filename,
142
+ "bytes": h_img,
143
+ },
144
+ "depth": {
145
+ "path": depth_filename,
146
+ "bytes": d_img,
147
+ },
148
+ }
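Note on the depth handling in cnd_prov-Copy1.py above: the depth map is opened with convert('1') (1-bit bilevel) and the resulting PIL Image is placed in the "bytes" slot, while every other feature is fed raw file bytes. A minimal sketch of one way to keep the features uniform, assuming the intent is an 8-bit single-channel depth map; the helper name depth_to_bytes and the PNG re-encoding are illustrative assumptions, not part of the committed script:

import io
from PIL import Image

def depth_to_bytes(depth_filename: str) -> bytes:
    # assumption: 'L' (8-bit grayscale) is the intended one-channel mode,
    # rather than the bilevel '1' mode used in the script above
    img = Image.open(depth_filename).convert("L")
    buf = io.BytesIO()
    img.save(buf, format="PNG")  # re-encode so the "bytes" field holds real bytes
    return buf.getvalue()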
src/cnd_prov/cnd_prov.py ADDED
@@ -0,0 +1,138 @@
1
+ #new one
2
+ import pandas as pd
3
+ import datasets
4
+ import os
5
+ import pickle
6
+
7
+ _VERSION = datasets.Version("0.0.3")
8
+
9
+ _DESCRIPTION = "TODO"
10
+ _HOMEPAGE = "TODO"
11
+ _LICENSE = "TODO"
12
+ _CITATION = "TODO"
13
+
14
+ _FEATURES = datasets.Features(
15
+ {
16
+ "target": datasets.Image(),
17
+ "source": datasets.Image(),
18
+ "heatmap": datasets.Image(),
19
+ "depth": datasets.Image(),
20
+ "prompt": datasets.Value("string"),
21
+ },
22
+ )
23
+
24
+ METADATA_DIR = "/fsx/proj-fmri/ckadirt/MindEyeV2/src/cnd_prov/data.pkl"
25
+ SOURCE_DIR = "/fsx/proj-fmri/shared/controlNetData/source"
26
+ TARGET_DIR = "/fsx/proj-fmri/shared/controlNetData/target"
27
+ HEATMAP_DIR = "/fsx/proj-fmri/shared/controlNetData/seg"
28
+ DEPTH_DIR = "/fsx/proj-fmri/shared/dinov2_depth"
29
+
30
+ # METADATA_URL = hf_hub_url(
31
+ # "fusing/fill50k",
32
+ # filename="train.jsonl",
33
+ # repo_type="dataset",
34
+ # )
35
+
36
+ # IMAGES_URL = hf_hub_url(
37
+ # "fusing/fill50k",
38
+ # filename="images.zip",
39
+ # repo_type="dataset",
40
+ # )
41
+
42
+ # CONDITIONING_IMAGES_URL = hf_hub_url(
43
+ # "fusing/fill50k",
44
+ # filename="conditioning_images.zip",
45
+ # repo_type="dataset",
46
+ # )
47
+
48
+ _DEFAULT_CONFIG = datasets.BuilderConfig(name="default", version=_VERSION)
49
+
50
+
51
+ class CocoTest(datasets.GeneratorBasedBuilder):
52
+ BUILDER_CONFIGS = [_DEFAULT_CONFIG]
53
+ DEFAULT_CONFIG_NAME = "default"
54
+
55
+ def _info(self):
56
+ return datasets.DatasetInfo(
57
+ description=_DESCRIPTION,
58
+ features=_FEATURES,
59
+ supervised_keys=None,
60
+ homepage=_HOMEPAGE,
61
+ license=_LICENSE,
62
+ citation=_CITATION,
63
+ )
64
+
65
+ def _split_generators(self, dl_manager):
66
+ metadata_path = METADATA_DIR
67
+ target_dir = TARGET_DIR
68
+ source_dir = SOURCE_DIR
69
+ heatmap_dir = HEATMAP_DIR
70
+ depth_dir = DEPTH_DIR
71
+
72
+ return [
73
+ datasets.SplitGenerator(
74
+ name=datasets.Split.TRAIN,
75
+ # These kwargs will be passed to _generate_examples
76
+ gen_kwargs={
77
+ "metadata_path": metadata_path,
78
+ "target_dir": TARGET_DIR,
79
+ "source_dir": SOURCE_DIR,
80
+ "heatmap_dir": HEATMAP_DIR,
81
+ "depth_dir": DEPTH_DIR,
82
+ "num_examples": 190573,
83
+ },
84
+ ),
85
+ datasets.SplitGenerator(
86
+ name=datasets.Split.VALIDATION,
87
+ # These kwargs will be passed to _generate_examples
88
+ gen_kwargs={
89
+ "metadata_path": metadata_path,
90
+ "target_dir": TARGET_DIR,
91
+ "source_dir": SOURCE_DIR,
92
+ "heatmap_dir": HEATMAP_DIR,
93
+ "depth_dir": DEPTH_DIR,
94
+ "num_examples": 20000,
95
+ },
96
+ ),
97
+ ]
98
+
99
+ def _generate_examples(self, metadata_path, target_dir, source_dir, heatmap_dir, depth_dir, num_examples):
100
+ data = []
101
+ with open(metadata_path, 'rb') as f:
102
+ loaded_data = pickle.load(f)
103
+ for line in loaded_data[:num_examples]:
104
+ data.append(line)
105
+
106
+ for _, item in enumerate(data):
107
+ source_filename = item['source']
108
+ target_filename = item['target']
109
+ heatmap_filename = item['h_map']
110
+ depth_filename = item['depth']
111
+ prompt = item['prompt']
112
+
113
+
114
+
115
+ tgt_img = open(target_filename, "rb").read()
116
+ src_img = open(source_filename, "rb").read()
117
+ h_img = open(heatmap_filename, "rb").read()
118
+ d_img = open(depth_filename, "rb").read()
119
+
120
+ yield item["target"], {
121
+ "prompt": prompt,
122
+ "target": {
123
+ "path": target_filename,
124
+ "bytes": tgt_img,
125
+ },
126
+ "source": {
127
+ "path": source_filename,
128
+ "bytes": src_img,
129
+ },
130
+ "heatmap": {
131
+ "path": heatmap_filename,
132
+ "bytes": h_img,
133
+ },
134
+ "depth": {
135
+ "path": depth_filename,
136
+ "bytes": d_img,
137
+ },
138
+ }
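Once data.pkl and the image directories are in place, the builder above can be loaded like any local dataset script. A minimal usage sketch, assuming the repository root as the working directory; recent releases of the datasets library may additionally require trust_remote_code=True, and the hard-coded /fsx paths only resolve on the original cluster:

from datasets import load_dataset

ds = load_dataset("src/cnd_prov/cnd_prov.py")  # runs the CocoTest builder defined above
train = ds["train"]                            # splits come from _split_generators
print(train[0]["prompt"])                      # caption string
image = train[0]["target"]                     # decoded to a PIL Image by datasets.Image()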
src/cnd_prov/data.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5bab2a7c32b2da8c899f18015e17b88a876a3baa6f676a315d47e07c8713d313
3
+ size 58375712
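data.pkl is stored through Git LFS, so only the pointer file is shown. From the _generate_examples code above it is expected to unpickle into a list of dicts keyed by 'source', 'target', 'h_map', 'depth' and 'prompt'. An illustrative sketch of that layout; the field names follow the loader code, but the paths and caption are invented placeholders, not real entries from the 58 MB file:

import pickle

# placeholder records only; keys mirror _generate_examples, values are made up
records = [
    {
        "source": "/fsx/proj-fmri/shared/controlNetData/source/000000.jpg",
        "target": "/fsx/proj-fmri/shared/controlNetData/target/000000.jpg",
        "h_map": "/fsx/proj-fmri/shared/controlNetData/seg/000000.png",
        "depth": "/fsx/proj-fmri/shared/dinov2_depth/000000.png",
        "prompt": "a caption describing the target image",
    },
]

with open("data.pkl", "wb") as f:
    pickle.dump(records, f)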
src/deepspeed_config_stage1.json ADDED
@@ -0,0 +1 @@
1
+ {"bf16": {"enabled": false}, "fp16": {"enabled": true}, "zero_optimization": {"stage": 1, "contiguous_gradients": true, "stage3_gather_16bit_weights_on_model_save": true, "stage3_max_live_parameters": 1000000000.0, "stage3_max_reuse_distance": 1000000000.0, "stage3_prefetch_bucket_size": 10000000.0, "stage3_param_persistence_threshold": 100000.0, "reduce_bucket_size": 10000000.0, "sub_group_size": 1000000000.0, "offload_optimizer": {"device": "none", "nvme_path": "/scratch", "pin_memory": true}, "offload_param": {"device": "none", "nvme_path": "/scratch", "buffer_size": 4000000000.0, "pin_memory": true}}, "aio": {"block_size": 26214400, "queue_depth": 32, "thread_count": 1, "single_submit": false, "overlap_events": true}, "gradient_accumulation_steps": 1, "gradient_clipping": 1.0, "steps_per_print": 20000, "train_batch_size": 8, "train_micro_batch_size_per_gpu": 8, "wall_clock_breakdown": false, "zero_allow_untested_optimizer": true}
src/deepspeed_config_stage2.json ADDED
@@ -0,0 +1 @@
1
+ {"bf16": {"enabled": false}, "fp16": {"enabled": true}, "zero_optimization": {"stage": 2, "contiguous_gradients": true, "stage3_gather_16bit_weights_on_model_save": true, "stage3_max_live_parameters": 1000000000.0, "stage3_max_reuse_distance": 1000000000.0, "stage3_prefetch_bucket_size": 10000000.0, "stage3_param_persistence_threshold": 100000.0, "reduce_bucket_size": 10000000.0, "sub_group_size": 1000000000.0, "offload_optimizer": {"device": "none", "nvme_path": "/scratch", "pin_memory": true}, "offload_param": {"device": "none", "nvme_path": "/scratch", "buffer_size": 4000000000.0, "pin_memory": true}}, "aio": {"block_size": 26214400, "queue_depth": 32, "thread_count": 1, "single_submit": false, "overlap_events": true}, "gradient_accumulation_steps": 1, "gradient_clipping": 1.0, "steps_per_print": 20000, "train_batch_size": 128, "train_micro_batch_size_per_gpu": 128, "wall_clock_breakdown": false, "zero_allow_untested_optimizer": true}
src/deepspeed_config_stage3.json ADDED
@@ -0,0 +1,44 @@
1
+ {
2
+ "bf16": {
3
+ "enabled": false
4
+ },
5
+ "fp16": {
6
+ "enabled": true
7
+ },
8
+ "zero_optimization": {
9
+ "stage": 3,
10
+ "contiguous_gradients": true,
11
+ "stage3_gather_16bit_weights_on_model_save": true,
12
+ "stage3_max_live_parameters": 1e9,
13
+ "stage3_max_reuse_distance": 1e9,
14
+ "stage3_prefetch_bucket_size": 1e7,
15
+ "stage3_param_persistence_threshold": 1e5,
16
+ "reduce_bucket_size": 1e7,
17
+ "sub_group_size": 1e9,
18
+ "offload_optimizer": {
19
+ "device": "none",
20
+ "nvme_path": "/scratch",
21
+ "pin_memory": true
22
+ },
23
+ "offload_param": {
24
+ "device": "none",
25
+ "nvme_path": "/scratch",
26
+ "buffer_size": 4e9,
27
+ "pin_memory": true
28
+ }
29
+ },
30
+ "aio": {
31
+ "block_size": 26214400,
32
+ "queue_depth": 32,
33
+ "thread_count": 1,
34
+ "single_submit": false,
35
+ "overlap_events": true
36
+ },
37
+ "gradient_accumulation_steps": 1,
38
+ "gradient_clipping": 1.0,
39
+ "steps_per_print": 20000,
40
+ "train_batch_size": 128,
41
+ "train_micro_batch_size_per_gpu": 16,
42
+ "wall_clock_breakdown": false,
43
+ "zero_allow_untested_optimizer": true
44
+ }
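The three DeepSpeed configs differ mainly in ZeRO stage and batch sizing (stage 1: train_batch_size 8 with micro-batch 8; stage 2: 128/128; stage 3: 128 global with 16 per GPU); all enable fp16, disable bf16, and keep optimizer/parameter offloading off. A minimal sketch of how such a JSON file can be consumed through Hugging Face Accelerate, which is one common path (the repository's own launch scripts are not shown here); it assumes the process is started with accelerate launch:

from accelerate import Accelerator, DeepSpeedPlugin

# hf_ds_config points Accelerate at one of the JSON files committed above
ds_plugin = DeepSpeedPlugin(hf_ds_config="src/deepspeed_config_stage2.json")
accelerator = Accelerator(deepspeed_plugin=ds_plugin)

# model, optimizer and dataloaders would then be wrapped the usual way:
# model, optimizer, train_dl = accelerator.prepare(model, optimizer, train_dl)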
src/getdepthimages.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/huggingface_to_s3.ipynb ADDED
@@ -0,0 +1,422 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "cf698d59-1cc2-4859-9c43-9a5d4d924ee1",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Transfer huggingface mindeyev2 dataset to Stability aws s3"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 1,
14
+ "id": "94c7404c-7a0f-4508-a630-954bc9af11fa",
15
+ "metadata": {},
16
+ "outputs": [
17
+ {
18
+ "name": "stdout",
19
+ "output_type": "stream",
20
+ "text": [
21
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/shared1000.npy -O /fsx/proj-fmri/shared/mindeyev2_dataset/shared1000.npy\n",
22
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/betas_all_subj01.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/betas_all_subj01.hdf5\n",
23
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/betas_all_subj02.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/betas_all_subj02.hdf5\n",
24
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/betas_all_subj03.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/betas_all_subj03.hdf5\n",
25
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/betas_all_subj04.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/betas_all_subj04.hdf5\n",
26
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/betas_all_subj05.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/betas_all_subj05.hdf5\n",
27
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/betas_all_subj06.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/betas_all_subj06.hdf5\n",
28
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/betas_all_subj07.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/betas_all_subj07.hdf5\n",
29
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/betas_all_subj08.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/betas_all_subj08.hdf5\n",
30
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/coco_images_224_float16.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/coco_images_224_float16.hdf5\n",
31
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/COCO_73k_subj_indices.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/COCO_73k_subj_indices.hdf5\n",
32
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/0.tar\n",
33
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/1.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/1.tar\n",
34
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/2.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/2.tar\n",
35
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/3.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/3.tar\n",
36
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/4.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/4.tar\n",
37
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/5.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/5.tar\n",
38
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/6.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/6.tar\n",
39
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/7.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/7.tar\n",
40
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/8.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/8.tar\n",
41
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/9.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/9.tar\n",
42
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/10.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/10.tar\n",
43
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/11.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/11.tar\n",
44
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/12.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/12.tar\n",
45
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/13.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/13.tar\n",
46
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/14.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/14.tar\n",
47
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/15.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/15.tar\n",
48
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/16.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/16.tar\n",
49
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/17.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/17.tar\n",
50
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/18.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/18.tar\n",
51
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/19.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/19.tar\n",
52
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/20.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/20.tar\n",
53
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/21.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/21.tar\n",
54
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/22.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/22.tar\n",
55
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/23.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/23.tar\n",
56
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/24.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/24.tar\n",
57
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/25.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/25.tar\n",
58
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/26.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/26.tar\n",
59
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/27.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/27.tar\n",
60
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/28.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/28.tar\n",
61
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/29.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/29.tar\n",
62
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/30.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/30.tar\n",
63
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/31.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/31.tar\n",
64
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/32.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/32.tar\n",
65
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/33.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/33.tar\n",
66
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/34.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/34.tar\n",
67
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/35.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/35.tar\n",
68
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/36.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/36.tar\n",
69
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/test/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/test/0.tar\n",
70
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/0.tar\n",
71
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/1.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/1.tar\n",
72
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/2.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/2.tar\n",
73
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/3.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/3.tar\n",
74
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/4.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/4.tar\n",
75
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/5.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/5.tar\n",
76
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/6.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/6.tar\n",
77
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/7.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/7.tar\n",
78
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/8.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/8.tar\n",
79
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/9.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/9.tar\n",
80
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/10.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/10.tar\n",
81
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/11.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/11.tar\n",
82
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/12.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/12.tar\n",
83
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/13.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/13.tar\n",
84
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/14.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/14.tar\n",
85
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/15.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/15.tar\n",
86
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/16.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/16.tar\n",
87
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/17.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/17.tar\n",
88
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/18.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/18.tar\n",
89
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/19.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/19.tar\n",
90
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/20.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/20.tar\n",
91
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/21.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/21.tar\n",
92
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/22.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/22.tar\n",
93
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/23.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/23.tar\n",
94
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/24.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/24.tar\n",
95
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/25.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/25.tar\n",
96
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/26.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/26.tar\n",
97
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/27.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/27.tar\n",
98
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/28.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/28.tar\n",
99
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/29.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/29.tar\n",
100
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/30.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/30.tar\n",
101
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/31.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/31.tar\n",
102
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/32.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/32.tar\n",
103
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/33.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/33.tar\n",
104
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/34.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/34.tar\n",
105
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/35.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/35.tar\n",
106
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/36.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/36.tar\n",
107
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/test/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/test/0.tar\n",
108
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/0.tar\n",
109
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/1.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/1.tar\n",
110
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/2.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/2.tar\n",
111
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/3.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/3.tar\n",
112
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/4.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/4.tar\n",
113
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/5.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/5.tar\n",
114
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/6.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/6.tar\n",
115
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/7.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/7.tar\n",
116
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/8.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/8.tar\n",
117
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/9.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/9.tar\n",
118
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/10.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/10.tar\n",
119
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/11.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/11.tar\n",
120
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/12.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/12.tar\n",
121
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/13.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/13.tar\n",
122
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/14.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/14.tar\n",
123
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/15.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/15.tar\n",
124
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/16.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/16.tar\n",
125
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/17.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/17.tar\n",
126
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/18.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/18.tar\n",
127
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/19.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/19.tar\n",
128
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/20.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/20.tar\n",
129
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/21.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/21.tar\n",
130
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/22.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/22.tar\n",
131
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/23.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/23.tar\n",
132
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/24.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/24.tar\n",
133
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/25.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/25.tar\n",
134
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/26.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/26.tar\n",
135
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/27.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/27.tar\n",
136
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/28.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/28.tar\n",
137
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/29.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/29.tar\n",
138
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/30.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/30.tar\n",
139
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/31.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/31.tar\n",
140
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/32.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/32.tar\n",
141
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/33.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/33.tar\n",
142
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/34.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/34.tar\n",
143
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/35.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/35.tar\n",
144
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/36.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/36.tar\n",
145
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/test/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/test/0.tar\n",
146
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/0.tar\n",
147
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/1.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/1.tar\n",
148
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/2.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/2.tar\n",
149
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/3.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/3.tar\n",
150
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/4.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/4.tar\n",
151
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/5.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/5.tar\n",
152
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/6.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/6.tar\n",
153
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/7.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/7.tar\n",
154
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/8.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/8.tar\n",
155
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/9.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/9.tar\n",
156
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/10.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/10.tar\n",
157
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/11.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/11.tar\n",
158
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/12.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/12.tar\n",
159
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/13.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/13.tar\n",
160
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/14.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/14.tar\n",
161
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/15.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/15.tar\n",
162
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/16.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/16.tar\n",
163
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/17.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/17.tar\n",
164
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/18.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/18.tar\n",
165
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/19.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/19.tar\n",
166
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/20.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/20.tar\n",
167
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/21.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/21.tar\n",
168
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/22.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/22.tar\n",
169
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/23.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/23.tar\n",
170
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/24.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/24.tar\n",
171
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/25.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/25.tar\n",
172
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/26.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/26.tar\n",
173
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/27.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/27.tar\n",
174
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/28.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/28.tar\n",
175
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/29.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/29.tar\n",
176
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/30.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/30.tar\n",
177
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/31.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/31.tar\n",
178
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/32.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/32.tar\n",
179
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/33.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/33.tar\n",
180
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/34.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/34.tar\n",
181
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/35.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/35.tar\n",
182
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/36.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/36.tar\n",
183
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/test/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/test/0.tar\n",
184
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/0.tar\n",
185
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/1.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/1.tar\n",
186
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/2.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/2.tar\n",
187
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/3.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/3.tar\n",
188
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/4.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/4.tar\n",
189
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/5.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/5.tar\n",
190
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/6.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/6.tar\n",
191
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/7.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/7.tar\n",
192
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/8.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/8.tar\n",
193
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/9.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/9.tar\n",
194
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/10.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/10.tar\n",
195
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/11.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/11.tar\n",
196
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/12.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/12.tar\n",
197
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/13.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/13.tar\n",
198
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/14.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/14.tar\n",
199
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/15.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/15.tar\n",
200
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/16.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/16.tar\n",
201
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/17.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/17.tar\n",
202
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/18.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/18.tar\n",
203
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/19.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/19.tar\n",
204
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/20.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/20.tar\n",
205
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/21.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/21.tar\n",
206
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/22.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/22.tar\n",
207
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/23.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/23.tar\n",
208
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/24.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/24.tar\n",
209
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/25.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/25.tar\n",
210
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/26.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/26.tar\n",
211
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/27.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/27.tar\n",
212
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/28.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/28.tar\n",
213
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/29.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/29.tar\n",
214
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/30.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/30.tar\n",
215
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/31.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/31.tar\n",
216
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/32.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/32.tar\n",
217
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/33.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/33.tar\n",
218
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/34.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/34.tar\n",
219
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/35.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/35.tar\n",
220
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/36.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/36.tar\n",
221
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/test/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/test/0.tar\n",
222
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/0.tar\n",
223
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/1.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/1.tar\n",
224
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/2.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/2.tar\n",
225
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/3.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/3.tar\n",
226
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/4.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/4.tar\n",
227
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/5.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/5.tar\n",
228
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/6.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/6.tar\n",
229
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/7.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/7.tar\n",
230
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/8.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/8.tar\n",
231
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/9.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/9.tar\n",
232
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/10.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/10.tar\n",
233
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/11.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/11.tar\n",
234
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/12.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/12.tar\n",
235
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/13.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/13.tar\n",
236
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/14.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/14.tar\n",
237
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/15.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/15.tar\n",
238
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/16.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/16.tar\n",
239
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/17.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/17.tar\n",
240
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/18.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/18.tar\n",
241
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/19.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/19.tar\n",
242
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/20.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/20.tar\n",
243
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/21.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/21.tar\n",
244
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/22.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/22.tar\n",
245
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/23.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/23.tar\n",
246
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/24.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/24.tar\n",
247
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/25.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/25.tar\n",
248
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/26.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/26.tar\n",
249
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/27.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/27.tar\n",
250
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/28.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/28.tar\n",
251
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/29.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/29.tar\n",
252
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/30.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/30.tar\n",
253
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/31.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/31.tar\n",
254
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/32.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/32.tar\n",
255
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/33.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/33.tar\n",
256
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/34.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/34.tar\n",
257
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/35.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/35.tar\n",
258
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/36.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/36.tar\n",
259
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/test/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/test/0.tar\n",
260
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/0.tar\n",
261
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/1.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/1.tar\n",
262
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/2.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/2.tar\n",
263
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/3.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/3.tar\n",
264
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/4.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/4.tar\n",
265
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/5.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/5.tar\n",
266
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/6.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/6.tar\n",
267
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/7.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/7.tar\n",
268
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/8.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/8.tar\n",
269
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/9.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/9.tar\n",
270
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/10.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/10.tar\n",
271
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/11.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/11.tar\n",
272
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/12.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/12.tar\n",
273
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/13.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/13.tar\n",
274
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/14.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/14.tar\n",
275
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/15.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/15.tar\n",
276
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/16.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/16.tar\n",
277
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/17.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/17.tar\n",
278
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/18.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/18.tar\n",
279
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/19.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/19.tar\n",
280
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/20.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/20.tar\n",
281
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/21.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/21.tar\n",
282
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/22.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/22.tar\n",
283
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/23.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/23.tar\n",
284
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/24.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/24.tar\n",
285
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/25.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/25.tar\n",
286
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/26.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/26.tar\n",
287
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/27.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/27.tar\n",
288
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/28.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/28.tar\n",
289
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/29.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/29.tar\n",
290
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/30.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/30.tar\n",
291
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/31.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/31.tar\n",
292
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/32.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/32.tar\n",
293
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/33.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/33.tar\n",
294
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/34.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/34.tar\n",
295
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/35.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/35.tar\n",
296
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/36.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/36.tar\n",
297
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/test/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/test/0.tar\n",
298
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/0.tar\n",
299
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/1.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/1.tar\n",
300
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/2.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/2.tar\n",
301
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/3.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/3.tar\n",
302
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/4.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/4.tar\n",
303
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/5.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/5.tar\n",
304
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/6.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/6.tar\n",
305
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/7.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/7.tar\n",
306
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/8.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/8.tar\n",
307
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/9.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/9.tar\n",
308
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/10.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/10.tar\n",
309
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/11.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/11.tar\n",
310
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/12.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/12.tar\n",
311
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/13.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/13.tar\n",
312
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/14.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/14.tar\n",
313
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/15.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/15.tar\n",
314
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/16.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/16.tar\n",
315
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/17.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/17.tar\n",
316
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/18.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/18.tar\n",
317
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/19.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/19.tar\n",
318
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/20.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/20.tar\n",
319
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/21.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/21.tar\n",
320
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/22.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/22.tar\n",
321
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/23.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/23.tar\n",
322
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/24.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/24.tar\n",
323
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/25.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/25.tar\n",
324
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/26.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/26.tar\n",
325
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/27.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/27.tar\n",
326
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/28.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/28.tar\n",
327
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/29.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/29.tar\n",
328
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/30.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/30.tar\n",
329
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/31.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/31.tar\n",
330
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/32.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/32.tar\n",
331
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/33.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/33.tar\n",
332
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/34.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/34.tar\n",
333
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/35.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/35.tar\n",
334
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/36.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/36.tar\n",
335
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/test/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/test/0.tar\n",
336
+ "aws s3 sync /scratch/mindeyev2_dataset s3://proj-fmri/mindeyev2_dataset --region us-west-2\n"
337
+ ]
338
+ }
339
+ ],
340
+ "source": [
341
+ "import os\n",
342
+ "# from subprocess import call\n",
343
+ "# PS Note: it's faster to print the wget statements and then manually copy paste all them into terminal than to use subprocess call()\n",
344
+ "tmp = '/fsx/proj-fmri/shared/mindeyev2_dataset/' #'/scratch/mindeyev2_dataset/'\n",
345
+ "\n",
346
+ "hf_base_link = 'https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/'\n",
347
+ "\n",
348
+ "os.makedirs(tmp,exist_ok=True)\n",
349
+ "\n",
350
+ "files = [\n",
351
+ " \"shared1000.npy\",\n",
352
+ " \"betas_all_subj01.hdf5\",\n",
353
+ " \"betas_all_subj02.hdf5\",\n",
354
+ " \"betas_all_subj03.hdf5\",\n",
355
+ " \"betas_all_subj04.hdf5\",\n",
356
+ " \"betas_all_subj05.hdf5\",\n",
357
+ " \"betas_all_subj06.hdf5\",\n",
358
+ " \"betas_all_subj07.hdf5\",\n",
359
+ " \"betas_all_subj08.hdf5\",\n",
360
+ " \"coco_images_224_float16.hdf5\",\n",
361
+ " \"COCO_73k_subj_indices.hdf5\",\n",
362
+ "]\n",
363
+ "\n",
364
+ "for f in files: \n",
365
+ " command = f\"wget --show-progress {hf_base_link}{f} -O {tmp}{f}\"\n",
366
+ " print(command)\n",
367
+ " # call(command,shell=True)\n",
368
+ "\n",
369
+ "for sub in range(1,9):\n",
370
+ " subject = f'subj0{sub}'\n",
371
+ "\n",
372
+ " tmp_fol = f'{tmp}wds/{subject}/'\n",
373
+ " os.makedirs(tmp_fol,exist_ok=True)\n",
374
+ " os.makedirs(tmp_fol+'train',exist_ok=True)\n",
375
+ " os.makedirs(tmp_fol+'test',exist_ok=True)\n",
376
+ "\n",
377
+ " for i in range(37):\n",
378
+ " link = f'train/{i}.tar'\n",
379
+ " command = f\"wget --show-progress {hf_base_link}wds/{subject}/{link} -O {tmp}wds/{subject}/{link}\"\n",
380
+ " print(command)\n",
381
+ " # call(command,shell=True)\n",
382
+ "\n",
383
+ " link = f'test/0.tar'\n",
384
+ " command = f\"wget --show-progress {hf_base_link}wds/{subject}/{link} -O {tmp}wds/{subject}/{link}\"\n",
385
+ " print(command)\n",
386
+ " # call(command,shell=True)\n",
387
+ "\n",
388
+ "command = \"aws s3 sync /scratch/mindeyev2_dataset s3://proj-fmri/mindeyev2_dataset --region us-west-2\"\n",
389
+ "print(command)"
390
+ ]
391
+ },
392
+ {
393
+ "cell_type": "code",
394
+ "execution_count": null,
395
+ "id": "30966082-59c2-411c-9b2e-4f4e3f9eb0f3",
396
+ "metadata": {},
397
+ "outputs": [],
398
+ "source": []
399
+ }
400
+ ],
401
+ "metadata": {
402
+ "kernelspec": {
403
+ "display_name": "Python 3 (ipykernel)",
404
+ "language": "python",
405
+ "name": "python3"
406
+ },
407
+ "language_info": {
408
+ "codemirror_mode": {
409
+ "name": "ipython",
410
+ "version": 3
411
+ },
412
+ "file_extension": ".py",
413
+ "mimetype": "text/x-python",
414
+ "name": "python",
415
+ "nbconvert_exporter": "python",
416
+ "pygments_lexer": "ipython3",
417
+ "version": "3.10.8"
418
+ }
419
+ },
420
+ "nbformat": 4,
421
+ "nbformat_minor": 5
422
+ }
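As an alternative to copy-pasting the printed wget commands, the same files could also be pulled programmatically with huggingface_hub's snapshot_download. The sketch below is not part of the uploaded notebook: it assumes a recent huggingface_hub is installed, and the glob patterns and target directory (mirroring `tmp` above) are illustrative and should be trimmed to whatever subset is actually needed.

    # Sketch only: fetch the dataset files listed above via huggingface_hub
    # instead of copy-pasting printed wget commands. Patterns are illustrative.
    from huggingface_hub import snapshot_download

    local_dir = "/fsx/proj-fmri/shared/mindeyev2_dataset/"  # same target as `tmp` in the cell above
    snapshot_download(
        repo_id="pscotti/mindeyev2",
        repo_type="dataset",
        local_dir=local_dir,
        allow_patterns=[
            "shared1000.npy",
            "betas_all_subj0*.hdf5",
            "coco_images_224_float16.hdf5",
            "COCO_73k_subj_indices.hdf5",
            "wds/subj0*/train/*.tar",
            "wds/subj0*/test/*.tar",
        ],
    )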
src/models.py ADDED
@@ -0,0 +1,1344 @@
1
+ import os
2
+ import numpy as np
3
+ from torchvision import transforms
4
+ import torch
5
+ import torch.nn as nn
6
+ import PIL
7
+ import clip
8
+ import open_clip
9
+ from functools import partial
10
+
11
+ # for prior
12
+ from dalle2_pytorch import DiffusionPrior
13
+ from dalle2_pytorch.dalle2_pytorch import l2norm, default, exists
14
+ from tqdm.auto import tqdm
15
+ import random
16
+ import json
17
+ from dalle2_pytorch.train_configs import DiffusionPriorNetworkConfig
18
+ # vd prior
19
+ from dalle2_pytorch.dalle2_pytorch import RotaryEmbedding, CausalTransformer, SinusoidalPosEmb, MLP, Rearrange, repeat, rearrange, prob_mask_like, LayerNorm, RelPosBias, FeedForward, Attention
20
+
21
+ # for pipeline
22
+ from diffusers import StableDiffusionImageVariationPipeline, VersatileDiffusionDualGuidedPipeline
23
+ from typing import Callable, List, Optional, Union
24
+
25
+ from diffusers.models.vae import Decoder
26
+
27
+ # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
28
+
29
+
30
+ # class BrainMLP(nn.Module):
31
+ # def __init__(self, out_dim=257*768, in_dim=15724, clip_size=768, h=4096):
32
+ # super().__init__()
33
+ # self.lin0 = nn.Sequential(
34
+ # nn.Linear(in_dim, h, bias=False),
35
+ # nn.LayerNorm(h),
36
+ # nn.GELU(inplace=True),
37
+ # nn.Dropout(0.5))
38
+ # self.mlp = nn.ModuleList([
39
+ # nn.Sequential(
40
+ # nn.Linear(h, h),
41
+ # nn.LayerNorm(h),
42
+ # nn.GELU(inplace=True),
43
+ # nn.Dropout(0.15)
44
+ # ) for _ in range(4)])
45
+ # self.lin1 = nn.Linear(h, out_dim, bias=True)
46
+ # self.proj = nn.Sequential(
47
+ # nn.LayerNorm(clip_size),
48
+ # nn.GELU(),
49
+ # nn.Linear(clip_size, 2048),
50
+ # nn.LayerNorm(2048),
51
+ # nn.GELU(),
52
+ # nn.Linear(2048, 2048),
53
+ # nn.LayerNorm(2048),
54
+ # nn.GELU(),
55
+ # nn.Linear(2048, clip_size))
56
+ # def forward(self, x):
57
+ # x = self.lin0(x)
58
+ # residual = x
59
+ # for res_block in range(self.n_blocks):
60
+ # x = self.mlp[res_block](x)
61
+ # x += residual
62
+ # residual = x
63
+ # diffusion_prior_input = self.lin1(x.reshape(len(x), -1))
64
+ # disjointed_clip_fmri = self.proj(diffusion_prior_input.reshape(
65
+ # len(x),-1, self.clip_size))
66
+ # return diffusion_prior_input, disjointed_clip_fmri
67
+
68
+
69
+
70
+ class Clipper(torch.nn.Module):
71
+ def __init__(self, clip_variant, clamp_embs=False, norm_embs=False,
72
+ hidden_state=False, device=torch.device('cpu')):
73
+ super().__init__()
74
+ assert clip_variant in ("RN50", "ViT-L/14", "ViT-B/32", "RN50x64"), \
75
+ "clip_variant must be one of RN50, ViT-L/14, ViT-B/32, RN50x64"
76
+ print(clip_variant, device)
77
+
78
+ if clip_variant=="ViT-L/14" and hidden_state:
79
+ # from transformers import CLIPVisionModelWithProjection
80
+ # image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14",cache_dir="/fsx/proj-medarc/fmri/cache")
81
+ from transformers import CLIPVisionModelWithProjection
82
+ sd_cache_dir = '/fsx/proj-fmri/shared/cache/models--shi-labs--versatile-diffusion/snapshots/2926f8e11ea526b562cd592b099fcf9c2985d0b7'
83
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(sd_cache_dir, subfolder='image_encoder').eval()
84
+ image_encoder = image_encoder.to(device)
85
+ for param in image_encoder.parameters():
86
+ param.requires_grad = False # dont need to calculate gradients
87
+ self.image_encoder = image_encoder
88
+ elif hidden_state:
89
+ raise Exception("hidden_state embeddings only works with ViT-L/14 right now")
90
+
91
+ clip_model, preprocess = clip.load(clip_variant, device=device)
92
+ clip_model.eval() # dont want to train model
93
+ for param in clip_model.parameters():
94
+ param.requires_grad = False # dont need to calculate gradients
95
+
96
+ self.clip = clip_model
97
+ self.clip_variant = clip_variant
98
+ if clip_variant == "RN50x64":
99
+ self.clip_size = (448,448)
100
+ else:
101
+ self.clip_size = (224,224)
102
+
103
+ preproc = transforms.Compose([
104
+ transforms.Resize(size=self.clip_size[0], interpolation=transforms.InterpolationMode.BICUBIC),
105
+ transforms.CenterCrop(size=self.clip_size),
106
+ transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
107
+ ])
108
+ self.preprocess = preproc
109
+ self.hidden_state = hidden_state
110
+ self.mean = np.array([0.48145466, 0.4578275, 0.40821073])
111
+ self.std = np.array([0.26862954, 0.26130258, 0.27577711])
112
+ self.normalize = transforms.Normalize(self.mean, self.std)
113
+ self.denormalize = transforms.Normalize((-self.mean / self.std).tolist(), (1.0 / self.std).tolist())
114
+ self.clamp_embs = clamp_embs
115
+ self.norm_embs = norm_embs
116
+ self.device= device
117
+
118
+ def versatile_normalize_embeddings(encoder_output):
119
+ embeds = encoder_output.last_hidden_state
120
+ embeds = image_encoder.vision_model.post_layernorm(embeds)
121
+ embeds = image_encoder.visual_projection(embeds)
122
+ return embeds
123
+ self.versatile_normalize_embeddings = versatile_normalize_embeddings
124
+
125
+ def resize_image(self, image):
126
+ # note: antialias should be False if planning to use Pinkney's Image Variation SD model
127
+ return transforms.Resize(self.clip_size)(image.to(self.device))
128
+
129
+ def embed_image(self, image):
130
+ """Expects images in -1 to 1 range"""
131
+ if self.hidden_state:
132
+ # clip_emb = self.preprocess((image/1.5+.25).to(self.device)) # for some reason the /1.5+.25 prevents oversaturation
133
+ clip_emb = self.preprocess((image).to(self.device))
134
+ clip_emb = self.image_encoder(clip_emb)
135
+ clip_emb = self.versatile_normalize_embeddings(clip_emb)
136
+ else:
137
+ clip_emb = self.preprocess(image.to(self.device))
138
+ clip_emb = self.clip.encode_image(clip_emb)
139
+ # input is now in CLIP space, but mind-reader preprint further processes embeddings:
140
+ if self.clamp_embs:
141
+ clip_emb = torch.clamp(clip_emb, -1.5, 1.5)
142
+ if self.norm_embs:
143
+ if self.hidden_state:
144
+ # normalize all tokens by cls token's norm
145
+ clip_emb = clip_emb / torch.norm(clip_emb[:, 0], dim=-1).reshape(-1, 1, 1)
146
+ else:
147
+ clip_emb = nn.functional.normalize(clip_emb, dim=-1)
148
+ return clip_emb
149
+
150
+ def embed_text(self, text_samples):
151
+ clip_text = clip.tokenize(text_samples).to(self.device)
152
+ clip_text = self.clip.encode_text(clip_text)
153
+ if self.clamp_embs:
154
+ clip_text = torch.clamp(clip_text, -1.5, 1.5)
155
+ if self.norm_embs:
156
+ clip_text = nn.functional.normalize(clip_text, dim=-1)
157
+ return clip_text
158
+
159
+ def embed_curated_annotations(self, annots):
160
+ for i,b in enumerate(annots):
161
+ t = ''
162
+ while t == '':
163
+ rand = torch.randint(5,(1,1))[0][0]
164
+ t = b[0,rand]
165
+ if i==0:
166
+ txt = np.array(t)
167
+ else:
168
+ txt = np.vstack((txt,t))
169
+ txt = txt.flatten()
170
+ return self.embed_text(txt)
171
+
172
+ class OpenClipper(torch.nn.Module):
173
+ def __init__(self, clip_variant, norm_embs=False, device=torch.device('cpu')):
174
+ super().__init__()
175
+ print(clip_variant, device)
176
+ assert clip_variant == 'ViT-H-14' # not setup for other models yet
177
+
178
+ clip_model, _, preprocess = open_clip.create_model_and_transforms('ViT-H-14',
179
+ pretrained='laion2b_s32b_b79k', device=device)
180
+ clip_model.eval() # dont want to train model
181
+ for param in clip_model.parameters():
182
+ param.requires_grad = False # dont need to calculate gradients
183
+
184
+ # overwrite preprocess to accept torch inputs instead of PIL Image
185
+ preprocess = transforms.Compose([
186
+ transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC, antialias=None),
187
+ transforms.CenterCrop(224),
188
+ transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
189
+ ])
190
+
191
+ tokenizer = open_clip.get_tokenizer('ViT-H-14')
192
+
193
+ self.clip = clip_model
194
+ self.norm_embs = norm_embs
195
+ self.preprocess = preprocess
196
+ self.tokenizer = tokenizer
197
+ self.device = device
198
+
199
+ def embed_image(self, image):
200
+ """Expects images in -1 to 1 range"""
201
+ image = self.preprocess(image).to(self.device)
202
+ with torch.no_grad(), torch.cuda.amp.autocast():
203
+ image_features = self.clip.encode_image(image)
204
+ if self.norm_embs:
205
+ image_features = nn.functional.normalize(image_features, dim=-1)
206
+ return image_features
207
+
208
+ def embed_text(self, text_samples):
209
+ text = self.tokenizer(text_samples).to(self.device)
210
+ with torch.no_grad(), torch.cuda.amp.autocast():
211
+ text_features = self.clip.encode_text(text)
212
+ if self.norm_embs:
213
+ text_features = nn.functional.normalize(text_features, dim=-1)
214
+ return text_features
215
+
216
+ def embed_curated_annotations(self, annots):
217
+ for i,b in enumerate(annots):
218
+ t = ''
219
+ while t == '':
220
+ rand = torch.randint(5,(1,1))[0][0]
221
+ t = b[0,rand]
222
+ if i==0:
223
+ txt = np.array(t)
224
+ else:
225
+ txt = np.vstack((txt,t))
226
+ txt = txt.flatten()
227
+ return self.embed_text(txt)
228
+
229
+ class BrainNetwork(nn.Module):
230
+ def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, use_projector=True, drop1=.5, drop2=.15):
231
+ super().__init__()
232
+ norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
233
+ act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
234
+ act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
235
+ # self.temp = nn.Parameter(torch.tensor(.006))
236
+ self.lin0 = nn.Sequential(
237
+ nn.Linear(in_dim, h),
238
+ *[item() for item in act_and_norm],
239
+ nn.Dropout(drop1),
240
+ )
241
+ self.mlp = nn.ModuleList([
242
+ nn.Sequential(
243
+ nn.Linear(h, h),
244
+ *[item() for item in act_and_norm],
245
+ nn.Dropout(drop2)
246
+ ) for _ in range(n_blocks)
247
+ ])
248
+ self.lin1 = nn.Linear(h, out_dim, bias=True)
249
+ self.n_blocks = n_blocks
250
+ self.clip_size = clip_size
251
+
252
+ self.use_projector = use_projector
253
+ if use_projector:
254
+ self.projector = nn.Sequential(
255
+ nn.LayerNorm(clip_size),
256
+ nn.GELU(),
257
+ nn.Linear(clip_size, 2048),
258
+ nn.LayerNorm(2048),
259
+ nn.GELU(),
260
+ nn.Linear(2048, 2048),
261
+ nn.LayerNorm(2048),
262
+ nn.GELU(),
263
+ nn.Linear(2048, clip_size)
264
+ )
265
+
266
+ def forward(self, x):
267
+ '''
268
+ bs, 1, 15724 -> bs, 32, h
269
+ bs, 32, h -> bs, 32h
270
+ bs, 32h -> bs, 768
271
+ '''
272
+ if x.ndim == 4:
273
+ # case when we passed 3D data of shape [N, 81, 104, 83]
274
+ assert x.shape[1] == 81 and x.shape[2] == 104 and x.shape[3] == 83
275
+ # [N, 699192]
276
+ x = x.reshape(x.shape[0], -1)
277
+ x = self.lin0(x) # bs, h
278
+ residual = x
279
+ for res_block in range(self.n_blocks):
280
+ x = self.mlp[res_block](x)
281
+ x += residual
282
+ residual = x
283
+ x = x.reshape(len(x), -1)
284
+ x = self.lin1(x)
285
+ if self.use_projector:
286
+ return x, self.projector(x.reshape(len(x), -1, self.clip_size))
287
+ return x
288
+
289
+ class BrainDiffusionPriorOld(DiffusionPrior):
290
+ """
291
+ Differences from original:
292
+ - Allow for passing of generators to torch random functions
293
+ - Option to include the voxel2clip model and pass voxels into forward method
294
+ - Return predictions when computing loss
295
+ - Load pretrained model from @nousr trained on LAION aesthetics
296
+ """
297
+ def __init__(self, *args, **kwargs):
298
+ voxel2clip = kwargs.pop('voxel2clip', None)
299
+ super().__init__(*args, **kwargs)
300
+ self.voxel2clip = voxel2clip
301
+
302
+ @torch.no_grad()
303
+ def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1.,
304
+ generator=None):
305
+ b, *_, device = *x.shape, x.device
306
+ model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
307
+ if generator is None:
308
+ noise = torch.randn_like(x)
309
+ else:
310
+ #noise = torch.randn_like(x)
311
+ noise = torch.randn(x.size(), device=x.device, dtype=x.dtype, generator=generator)
312
+ # no noise when t == 0
313
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
314
+ pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
315
+ return pred, x_start
316
+
317
+ @torch.no_grad()
318
+ def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1., generator=None):
319
+ batch, device = shape[0], self.device
320
+
321
+ if generator is None:
322
+ image_embed = torch.randn(shape, device = device)
323
+ else:
324
+ image_embed = torch.randn(shape, device = device, generator=generator)
325
+ x_start = None # for self-conditioning
326
+
327
+ if self.init_image_embed_l2norm:
328
+ image_embed = l2norm(image_embed) * self.image_embed_scale
329
+
330
+ for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps, disable=True):
331
+ times = torch.full((batch,), i, device = device, dtype = torch.long)
332
+
333
+ self_cond = x_start if self.net.self_cond else None
334
+ image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale,
335
+ generator=generator)
336
+
337
+ if self.sampling_final_clamp_l2norm and self.predict_x_start:
338
+ image_embed = self.l2norm_clamp_embed(image_embed)
339
+
340
+ return image_embed
341
+
342
+ def p_losses(self, image_embed, times, text_cond, noise = None):
343
+ noise = default(noise, lambda: torch.randn_like(image_embed))
344
+
345
+ image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
346
+
347
+ self_cond = None
348
+ if self.net.self_cond and random.random() < 0.5:
349
+ with torch.no_grad():
350
+ self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
351
+
352
+ pred = self.net(
353
+ image_embed_noisy,
354
+ times,
355
+ self_cond = self_cond,
356
+ text_cond_drop_prob = self.text_cond_drop_prob,
357
+ image_cond_drop_prob = self.image_cond_drop_prob,
358
+ **text_cond
359
+ )
360
+
361
+ if self.predict_x_start and self.training_clamp_l2norm:
362
+ pred = self.l2norm_clamp_embed(pred)
363
+
364
+ if self.predict_v:
365
+ target = self.noise_scheduler.calculate_v(image_embed, times, noise)
366
+ elif self.predict_x_start:
367
+ target = image_embed
368
+ else:
369
+ target = noise
370
+
371
+ loss = self.noise_scheduler.loss_fn(pred, target)
372
+ return loss, pred
373
+
374
+ def forward(
375
+ self,
376
+ text = None,
377
+ image = None,
378
+ voxel = None,
379
+ text_embed = None, # allow for training on preprocessed CLIP text and image embeddings
380
+ image_embed = None,
381
+ text_encodings = None, # as well as CLIP text encodings
382
+ *args,
383
+ **kwargs
384
+ ):
385
+ assert exists(text) ^ exists(text_embed) ^ exists(voxel), 'either text, text embedding, or voxel must be supplied'
386
+ assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied'
387
+ assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
388
+
389
+ if exists(voxel):
390
+ assert exists(self.voxel2clip), 'voxel2clip must be trained if you wish to pass in voxels'
391
+ assert not exists(text_embed), 'cannot pass in both text and voxels'
392
+ text_embed = self.voxel2clip(voxel)
393
+
394
+ if exists(image):
395
+ image_embed, _ = self.clip.embed_image(image)
396
+
397
+ # calculate text conditionings, based on what is passed in
398
+
399
+ if exists(text):
400
+ text_embed, text_encodings = self.clip.embed_text(text)
401
+
402
+ text_cond = dict(text_embed = text_embed)
403
+
404
+ if self.condition_on_text_encodings:
405
+ assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
406
+ text_cond = {**text_cond, 'text_encodings': text_encodings}
407
+
408
+ # timestep conditioning from ddpm
409
+
410
+ batch, device = image_embed.shape[0], image_embed.device
411
+ times = self.noise_scheduler.sample_random_times(batch)
412
+
413
+ # scale image embed (Katherine)
414
+
415
+ image_embed *= self.image_embed_scale
416
+
417
+ # calculate forward loss
418
+
419
+ loss, pred = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
420
+
421
+ return loss, pred#, text_embed
422
+
423
+ @staticmethod
424
+ def from_pretrained(net_kwargs={}, prior_kwargs={}, voxel2clip_path=None, ckpt_dir='./checkpoints'):
425
+ # "https://huggingface.co/nousr/conditioned-prior/raw/main/vit-l-14/aesthetic/prior_config.json"
426
+ config_url = os.path.join(ckpt_dir, "prior_config.json")
427
+ config = json.load(open(config_url))
428
+
429
+ config['prior']['net']['max_text_len'] = 256
430
+ config['prior']['net'].update(net_kwargs)
431
+ # print('net_config', config['prior']['net'])
432
+ net_config = DiffusionPriorNetworkConfig(**config['prior']['net'])
433
+
434
+ kwargs = config['prior']
435
+ kwargs.pop('clip')
436
+ kwargs.pop('net')
437
+ kwargs.update(prior_kwargs)
438
+ # print('prior_config', kwargs)
439
+
440
+ diffusion_prior_network = net_config.create()
441
+ diffusion_prior = BrainDiffusionPriorOld(net=diffusion_prior_network, clip=None, **kwargs).to(torch.device('cpu'))
442
+
443
+ # 'https://huggingface.co/nousr/conditioned-prior/resolve/main/vit-l-14/aesthetic/best.pth'
444
+ ckpt_url = os.path.join(ckpt_dir, 'best.pth')
445
+ ckpt = torch.load(ckpt_url, map_location=torch.device('cpu'))
446
+
447
+ # Note these keys will be missing (maybe due to an update to the code since training):
448
+ # "net.null_text_encodings", "net.null_text_embeds", "net.null_image_embed"
449
+ # I don't think these get used if `cond_drop_prob = 0` though (which is the default here)
450
+ diffusion_prior.load_state_dict(ckpt, strict=False)
451
+ # keys = diffusion_prior.load_state_dict(ckpt, strict=False)
452
+ # print("missing keys in prior checkpoint (probably ok)", keys.missing_keys)
453
+
454
+ if voxel2clip_path:
455
+ # load the voxel2clip weights
456
+ checkpoint = torch.load(voxel2clip_path, map_location=torch.device('cpu'))
457
+
458
+ state_dict = checkpoint['model_state_dict']
459
+ for key in list(state_dict.keys()):
460
+ if 'module.' in key:
461
+ state_dict[key.replace('module.', '')] = state_dict[key]
462
+ del state_dict[key]
463
+ diffusion_prior.voxel2clip.load_state_dict(state_dict)
464
+
465
+ return diffusion_prior
466
+
467
+ class BrainDiffusionPrior(DiffusionPrior):
468
+ """
469
+ Differences from original:
470
+ - Allow for passing of generators to torch random functions
471
+ - Option to include the voxel2clip model and pass voxels into forward method
472
+ - Return predictions when computing loss
473
+ - Load pretrained model from @nousr trained on LAION aesthetics
474
+ """
475
+ def __init__(self, *args, **kwargs):
476
+ voxel2clip = kwargs.pop('voxel2clip', None)
477
+ super().__init__(*args, **kwargs)
478
+ self.voxel2clip = voxel2clip
479
+
480
+ @torch.no_grad()
481
+ def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1.,
482
+ generator=None):
483
+ b, *_, device = *x.shape, x.device
484
+ model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
485
+ if generator is None:
486
+ noise = torch.randn_like(x)
487
+ else:
488
+ #noise = torch.randn_like(x)
489
+ noise = torch.randn(x.size(), device=x.device, dtype=x.dtype, generator=generator)
490
+ # no noise when t == 0
491
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
492
+ pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
493
+ return pred, x_start
494
+
495
+ @torch.no_grad()
496
+ def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1., generator=None):
497
+ batch, device = shape[0], self.device
498
+
499
+ if generator is None:
500
+ image_embed = torch.randn(shape, device = device)
501
+ else:
502
+ image_embed = torch.randn(shape, device = device, generator=generator)
503
+ x_start = None # for self-conditioning
504
+
505
+ if self.init_image_embed_l2norm:
506
+ image_embed = l2norm(image_embed) * self.image_embed_scale
507
+
508
+ for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps, disable=True):
509
+ times = torch.full((batch,), i, device = device, dtype = torch.long)
510
+
511
+ self_cond = x_start if self.net.self_cond else None
512
+ image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale,
513
+ generator=generator)
514
+
515
+ if self.sampling_final_clamp_l2norm and self.predict_x_start:
516
+ image_embed = self.l2norm_clamp_embed(image_embed)
517
+
518
+ return image_embed
519
+
520
+ def p_losses(self, image_embed, times, text_cond, noise = None):
521
+ noise = default(noise, lambda: torch.randn_like(image_embed))
522
+
523
+ image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
524
+
525
+ self_cond = None
526
+ if self.net.self_cond and random.random() < 0.5:
527
+ with torch.no_grad():
528
+ self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
529
+
530
+ pred = self.net(
531
+ image_embed_noisy,
532
+ times,
533
+ self_cond = self_cond,
534
+ text_cond_drop_prob = self.text_cond_drop_prob,
535
+ image_cond_drop_prob = self.image_cond_drop_prob,
536
+ **text_cond
537
+ )
538
+
539
+ if self.predict_x_start and self.training_clamp_l2norm:
540
+ pred = self.l2norm_clamp_embed(pred)
541
+
542
+ if self.predict_v:
543
+ target = self.noise_scheduler.calculate_v(image_embed, times, noise)
544
+ elif self.predict_x_start:
545
+ target = image_embed
546
+ else:
547
+ target = noise
548
+
549
+ loss = self.noise_scheduler.loss_fn(pred, target)
550
+ return loss, pred
551
+
552
+ def forward(
553
+ self,
554
+ text = None,
555
+ image = None,
556
+ voxel = None,
557
+ text_embed = None, # allow for training on preprocessed CLIP text and image embeddings
558
+ image_embed = None,
559
+ text_encodings = None, # as well as CLIP text encodings
560
+ *args,
561
+ **kwargs
562
+ ):
563
+ assert exists(text) ^ exists(text_embed) ^ exists(voxel), 'either text, text embedding, or voxel must be supplied'
564
+ assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied'
565
+ assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
566
+
567
+ if exists(voxel):
568
+ assert exists(self.voxel2clip), 'voxel2clip must be trained if you wish to pass in voxels'
569
+ assert not exists(text_embed), 'cannot pass in both text and voxels'
570
+ if self.voxel2clip.use_projector:
571
+ clip_voxels_mse, clip_voxels = self.voxel2clip(voxel)
572
+ text_embed = clip_voxels_mse
573
+ else:
574
+ clip_voxels = self.voxel2clip(voxel)
575
+ text_embed = clip_voxels_mse = clip_voxels
576
+ # text_embed = self.voxel2clip(voxel)
577
+
578
+ if exists(image):
579
+ image_embed, _ = self.clip.embed_image(image)
580
+
581
+ # calculate text conditionings, based on what is passed in
582
+
583
+ if exists(text):
584
+ text_embed, text_encodings = self.clip.embed_text(text)
585
+
586
+ text_cond = dict(text_embed = text_embed)
587
+
588
+ if self.condition_on_text_encodings:
589
+ assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
590
+ text_cond = {**text_cond, 'text_encodings': text_encodings}
591
+
592
+ # timestep conditioning from ddpm
593
+
594
+ batch, device = image_embed.shape[0], image_embed.device
595
+ times = self.noise_scheduler.sample_random_times(batch)
596
+
597
+ # PS: I dont think we need this? also if uncommented this does in-place global variable change
598
+ # scale image embed (Katherine)
599
+ # image_embed *= self.image_embed_scale
600
+
601
+ # calculate forward loss
602
+
603
+ loss, pred = self.p_losses(image_embed*self.image_embed_scale, times, text_cond = text_cond, *args, **kwargs)
604
+
605
+ # undo the scaling so we can directly use it for real mse loss and reconstruction
606
+ return loss, pred
607
+
608
+ class BrainSD(StableDiffusionImageVariationPipeline):
609
+ """
610
+ Differences from original:
611
+ - Keep generated images on GPU and return tensors
612
+ - No NSFW checker
613
+ - Can pass in image or image_embedding to generate a variation
614
+ NOTE: requires latest version of diffusers to avoid the latent dims not being correct.
615
+ """
616
+
617
+ def decode_latents(self, latents):
618
+ latents = 1 / 0.18215 * latents
619
+ image = self.vae.decode(latents).sample
620
+ image = (image / 2 + 0.5).clamp(0, 1)
621
+ # # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
622
+ # image = image.cpu().permute(0, 2, 3, 1).float().numpy()
623
+ return image
624
+
625
+ @torch.no_grad()
626
+ def __call__(
627
+ self,
628
+ image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]] = None,
629
+ height: Optional[int] = None,
630
+ width: Optional[int] = None,
631
+ num_inference_steps: int = 50,
632
+ guidance_scale: float = 7.5,
633
+ num_images_per_prompt: Optional[int] = 1,
634
+ eta: float = 0.0,
635
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
636
+ latents: Optional[torch.FloatTensor] = None,
637
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
638
+ callback_steps: Optional[int] = 1,
639
+ image_embeddings: Optional[torch.FloatTensor] = None,
640
+ ):
641
+
642
+ # 0. Default height and width to unet
643
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
644
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
645
+
646
+ device = self._execution_device
647
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
648
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
649
+ # corresponds to doing no classifier free guidance.
650
+ do_classifier_free_guidance = guidance_scale > 1.0
651
+
652
+ if image_embeddings is None:
653
+ assert image is not None, "If image_embeddings is None, image must not be None"
654
+
655
+ # resize and normalize the way that's recommended
656
+ tform = transforms.Compose([
657
+ #transforms.ToTensor(), ## don't need this since we've already got tensors
658
+ transforms.Resize(
659
+ (224, 224),
660
+ interpolation=transforms.InterpolationMode.BICUBIC,
661
+ antialias=False,
662
+ ),
663
+ transforms.Normalize(
664
+ [0.48145466, 0.4578275, 0.40821073],
665
+ [0.26862954, 0.26130258, 0.27577711]),
666
+ ])
667
+ image = tform(image)
668
+
669
+ # 1. Check inputs. Raise error if not correct
670
+ self.check_inputs(image, height, width, callback_steps)
671
+
672
+ # 2. Define call parameters
673
+ if isinstance(image, PIL.Image.Image):
674
+ batch_size = 1
675
+ elif isinstance(image, list):
676
+ batch_size = len(image)
677
+ else:
678
+ batch_size = image.shape[0]
679
+
680
+ # 3. Encode input image
681
+ image_embeddings = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance)
682
+ else:
683
+ batch_size = image_embeddings.shape[0] // 2
684
+
685
+ # 4. Prepare timesteps
686
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
687
+ timesteps = self.scheduler.timesteps
688
+
689
+ # 5. Prepare latent variables
690
+ num_channels_latents = self.unet.in_channels
691
+
692
+ latents = self.prepare_latents(
693
+ batch_size * num_images_per_prompt,
694
+ num_channels_latents,
695
+ height,
696
+ width,
697
+ image_embeddings.dtype,
698
+ device,
699
+ generator,
700
+ latents,
701
+ )
702
+
703
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
704
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
705
+
706
+ # 7. Denoising loop
707
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
708
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
709
+ for i, t in enumerate(timesteps):
710
+ # expand the latents if we are doing classifier free guidance
711
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
712
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
713
+
714
+ # predict the noise residual
715
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample
716
+
717
+ # perform guidance
718
+ if do_classifier_free_guidance:
719
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
720
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
721
+
722
+ # compute the previous noisy sample x_t -> x_t-1
723
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
724
+
725
+ # call the callback, if provided
726
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
727
+ progress_bar.update()
728
+ if callback is not None and i % callback_steps == 0:
729
+ callback(i, t, latents)
730
+
731
+ # 8. Post-processing
732
+ image = self.decode_latents(latents)
733
+
734
+ # # 9. Run safety checker
735
+ # image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
736
+
737
+ # # 10. Convert to PIL
738
+ # if output_type == "pil":
739
+ # image = self.numpy_to_pil(image)
740
+
741
+ # if not return_dict:
742
+ # return (image, has_nsfw_concept)
743
+
744
+ # return StableDiffusionPipelineOutput(images=image)
745
+
746
+ return image
747
+
748
+ class Voxel2StableDiffusionModel(torch.nn.Module):
749
+ def __init__(self, in_dim=15724, h=4096, n_blocks=4, use_cont=False):
750
+ super().__init__()
751
+ self.lin0 = nn.Sequential(
752
+ nn.Linear(in_dim, h, bias=False),
753
+ nn.LayerNorm(h),
754
+ nn.SiLU(inplace=True),
755
+ nn.Dropout(0.5),
756
+ )
757
+
758
+ self.mlp = nn.ModuleList([
759
+ nn.Sequential(
760
+ nn.Linear(h, h, bias=False),
761
+ nn.LayerNorm(h),
762
+ nn.SiLU(inplace=True),
763
+ nn.Dropout(0.25)
764
+ ) for _ in range(n_blocks)
765
+ ])
766
+ self.lin1 = nn.Linear(h, 16384, bias=False)
767
+ self.norm = nn.LayerNorm(512)
768
+
769
+ self.register_parameter('queries', nn.Parameter(torch.randn(1, 256, 512) * 0.044))
770
+ self.transformer = nn.TransformerDecoder(
771
+ nn.TransformerDecoderLayer(d_model=512, nhead=8, norm_first=True,
772
+ dim_feedforward=1024, activation=nn.functional.gelu,
773
+ batch_first=True, dropout=0.25),
774
+ num_layers=n_blocks
775
+ )
776
+
777
+ # option 1 -> 124.56M
778
+ # self.lin1 = nn.Linear(h, 32768, bias=True)
779
+ # self.upsampler = Decoder(
780
+ # in_channels=64,
781
+ # out_channels=4,
782
+ # up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
783
+ # block_out_channels=[64, 128, 256, 256],
784
+ # layers_per_block=1,
785
+ # )
786
+
787
+ # option2 -> 132.52M
788
+ # self.lin1 = nn.Linear(h, 1024, bias=True)
789
+ # self.upsampler = Decoder(
790
+ # in_channels=64,
791
+ # out_channels=4,
792
+ # up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D", "UpDecoderBlock2D"],
793
+ # block_out_channels=[64, 128, 256, 256, 512],
794
+ # layers_per_block=1,
795
+ # )
796
+
797
+ if use_cont:
798
+ self.maps_projector = nn.Sequential(
799
+ nn.LayerNorm(512),
800
+ nn.Linear(512, 512),
801
+ nn.LayerNorm(512),
802
+ nn.ReLU(True),
803
+ nn.Linear(512, 512),
804
+ nn.LayerNorm(512),
805
+ nn.ReLU(True),
806
+ nn.Linear(512, 512)
807
+ )
808
+ else:
809
+ self.maps_projector = nn.Identity()
810
+
811
+ self.upsampler = nn.Sequential(
812
+ nn.GroupNorm(1, 32),
813
+ nn.SiLU(inplace=True),
814
+ nn.Conv2d(32, 320, 3, padding=1),
815
+ nn.GroupNorm(32, 320),
816
+ nn.SiLU(inplace=True),
817
+ nn.Conv2d(320, 320, 3, padding=1),
818
+ nn.GroupNorm(32, 320),
819
+ nn.SiLU(inplace=True),
820
+ nn.Conv2d(320, 4, 3, padding=1)
821
+ )
822
+
823
+ def forward(self, x, return_transformer_feats=False):
824
+ x = self.lin0(x)
825
+ residual = x
826
+ for res_block in self.mlp:
827
+ x = res_block(x)
828
+ x = x + residual
829
+ residual = x
830
+ x = x.reshape(len(x), -1)
831
+ x = self.lin1(x) # bs, 4096
832
+
833
+ # # x = x.reshape(x.shape[0], -1, 8, 8).contiguous() # bs, 64, 8, 8
834
+ # x = x.reshape(x.shape[0], -1, 64, 64).contiguous()
835
+ # return self.upsampler(x)
836
+
837
+ # decoder
838
+ x = self.norm(x.reshape(x.shape[0], 32, 512))
839
+ preds = self.transformer(self.queries.expand(x.shape[0], -1, -1), x)
840
+ sd_embeds = preds.permute(0,2,1).reshape(-1, 512, 16, 16)
841
+ sd_embeds = nn.functional.pixel_shuffle(sd_embeds, 4) # bs, 32, 32, 32
842
+
843
+ # contrastive
844
+ if return_transformer_feats:
845
+ return self.upsampler(sd_embeds), self.maps_projector(preds)
846
+
847
+ return self.upsampler(sd_embeds)
848
+
849
+ class BrainNetworkDETR(BrainNetwork):
850
+ # 133M
851
+ def __init__(self, out_dim=768, in_dim=15724, h=4096, n_blocks=4, norm_type='ln', act_first=False,
852
+ encoder_tokens=32, decoder_tokens=257):
853
+ # encoder
854
+ super().__init__(out_dim*encoder_tokens, in_dim, h, n_blocks, norm_type, act_first)
855
+ self.norm = nn.LayerNorm(out_dim)
856
+ self.encoder_tokens = encoder_tokens
857
+
858
+ self.register_parameter('queries', nn.Parameter(torch.randn(1, decoder_tokens, out_dim)))
859
+ self.transformer = nn.TransformerDecoder(
860
+ nn.TransformerDecoderLayer(d_model=out_dim, nhead=8,
861
+ dim_feedforward=1024,
862
+ batch_first=True, dropout=0.25),
863
+ num_layers=n_blocks
864
+ )
865
+ self.decoder_projector = nn.Sequential(
866
+ nn.LayerNorm(out_dim),
867
+ nn.Linear(out_dim, out_dim)
868
+ )
869
+
870
+
871
+ def forward(self, x):
872
+ enc = super().forward(x)
873
+ enc = self.norm(enc.reshape(enc.shape[0], self.encoder_tokens, -1))
874
+
875
+ dec = self.transformer(self.queries.expand(x.shape[0], -1, -1), enc)
876
+ dec = self.decoder_projector(dec)
877
+ return dec
878
+
879
+ class VersatileDiffusionPriorNetwork(nn.Module):
880
+ def __init__(
881
+ self,
882
+ dim,
883
+ num_timesteps = None,
884
+ num_time_embeds = 1,
885
+ # num_image_embeds = 1,
886
+ # num_brain_embeds = 1,
887
+ num_tokens = 257,
888
+ causal = True,
889
+ learned_query_mode = 'none',
890
+ **kwargs
891
+ ):
892
+ super().__init__()
893
+ self.dim = dim
894
+ self.num_time_embeds = num_time_embeds
895
+ self.continuous_embedded_time = not exists(num_timesteps)
896
+ self.learned_query_mode = learned_query_mode
897
+
898
+ self.to_time_embeds = nn.Sequential(
899
+ nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
900
+ Rearrange('b (n d) -> b n d', n = num_time_embeds)
901
+ )
902
+
903
+ if self.learned_query_mode == 'token':
904
+ self.learned_query = nn.Parameter(torch.randn(num_tokens, dim))
905
+ if self.learned_query_mode == 'pos_emb':
906
+ scale = dim ** -0.5
907
+ self.learned_query = nn.Parameter(torch.randn(num_tokens, dim) * scale)
908
+ if self.learned_query_mode == 'all_pos_emb':
909
+ scale = dim ** -0.5
910
+ self.learned_query = nn.Parameter(torch.randn(num_tokens*2+1, dim) * scale)
911
+ self.causal_transformer = FlaggedCausalTransformer(dim = dim, causal=causal, **kwargs)
912
+
913
+ self.null_brain_embeds = nn.Parameter(torch.randn(num_tokens, dim))
914
+ self.null_image_embed = nn.Parameter(torch.randn(num_tokens, dim))
915
+
916
+ self.num_tokens = num_tokens
917
+ self.self_cond = False
918
+
919
+ def forward_with_cond_scale(
920
+ self,
921
+ *args,
922
+ cond_scale = 1.,
923
+ **kwargs
924
+ ):
925
+ logits = self.forward(*args, **kwargs)
926
+
927
+ if cond_scale == 1:
928
+ return logits
929
+
930
+ null_logits = self.forward(*args, brain_cond_drop_prob = 1., image_cond_drop_prob = 1, **kwargs)
931
+ return null_logits + (logits - null_logits) * cond_scale
932
+
933
+ def forward(
934
+ self,
935
+ image_embed,
936
+ diffusion_timesteps,
937
+ *,
938
+ self_cond=None,
939
+ brain_embed=None,
940
+ text_embed=None,
941
+ brain_cond_drop_prob = 0.,
942
+ text_cond_drop_prob = None,
943
+ image_cond_drop_prob = 0.
944
+ ):
945
+ if text_embed is not None:
946
+ brain_embed = text_embed
947
+ if text_cond_drop_prob is not None:
948
+ brain_cond_drop_prob = text_cond_drop_prob
949
+
950
+ image_embed = image_embed.view(len(image_embed),-1,768)
951
+ # text_embed = text_embed.view(len(text_embed),-1,768)
952
+ brain_embed = brain_embed.view(len(brain_embed),-1,768)
953
+ # print(*image_embed.shape)
954
+ # print(*image_embed.shape, image_embed.device, image_embed.dtype)
955
+
956
+ batch, _, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
957
+ # num_time_embeds, num_image_embeds, num_brain_embeds = self.num_time_embeds, self.num_image_embeds, self.num_brain_embeds
958
+
959
+ # classifier free guidance masks
960
+ brain_keep_mask = prob_mask_like((batch,), 1 - brain_cond_drop_prob, device = device)
961
+ brain_keep_mask = rearrange(brain_keep_mask, 'b -> b 1 1')
962
+
963
+ image_keep_mask = prob_mask_like((batch,), 1 - image_cond_drop_prob, device = device)
964
+ image_keep_mask = rearrange(image_keep_mask, 'b -> b 1 1')
965
+
966
+ # mask out brain embeddings with null brain embeddings
967
+
968
+ # import pdb; pdb.set_trace()
969
+ null_brain_embeds = self.null_brain_embeds.to(brain_embed.dtype)
970
+ brain_embed = torch.where(
971
+ brain_keep_mask,
972
+ brain_embed,
973
+ null_brain_embeds[None]
974
+ )
975
+
976
+ # mask out image embeddings with null image embeddings
977
+ null_image_embed = self.null_image_embed.to(image_embed.dtype)
978
+ image_embed = torch.where(
979
+ image_keep_mask,
980
+ image_embed,
981
+ null_image_embed[None]
982
+ )
983
+
984
+ # whether brain embedding is used for conditioning depends on whether brain encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
985
+ # but let's just do it right
986
+ if self.continuous_embedded_time:
987
+ # if continuous cast to flat, else keep int for indexing embeddings
988
+ diffusion_timesteps = diffusion_timesteps.type(dtype)
989
+ time_embed = self.to_time_embeds(diffusion_timesteps)
990
+
991
+ if self.learned_query_mode == 'token':
992
+ learned_queries = repeat(self.learned_query, 'n d -> b n d', b = batch)
993
+ elif self.learned_query_mode == 'pos_emb':
994
+ pos_embs = repeat(self.learned_query, 'n d -> b n d', b = batch)
995
+ image_embed = image_embed + pos_embs
996
+ learned_queries = torch.empty((batch, 0, dim), device=brain_embed.device)
997
+ elif self.learned_query_mode == 'all_pos_emb':
998
+ pos_embs = repeat(self.learned_query, 'n d -> b n d', b = batch)
999
+ learned_queries = torch.empty((batch, 0, dim), device=brain_embed.device)
1000
+ else:
1001
+ learned_queries = torch.empty((batch, 0, dim), device=brain_embed.device)
1002
+
1003
+ tokens = torch.cat((
1004
+ brain_embed, # 257
1005
+ time_embed, # 1
1006
+ image_embed, # 257
1007
+ learned_queries # 257
1008
+ ), dim = -2)
1009
+ if self.learned_query_mode == 'all_pos_emb':
1010
+ tokens = tokens + pos_embs
1011
+
1012
+ # attend
1013
+ tokens = self.causal_transformer(tokens)
1014
+
1015
+ # get learned query, which should predict the image embedding (per DDPM timestep)
1016
+ pred_image_embed = tokens[..., -self.num_tokens:, :]
1017
+
1018
+ return pred_image_embed
1019
+
1020
+ # import math
1021
+ # from collections import namedtuple
1022
+ # from einops import rearrange, repeat, reduce, pack, unpack
1023
+ # from einops.layers.torch import Rearrange
1024
+ # from torch import einsum
1025
+ # class Attention(nn.Module):
1026
+ # def __init__(
1027
+ # self,
1028
+ # dim,
1029
+ # *,
1030
+ # dim_head = 64,
1031
+ # heads = 8,
1032
+ # dropout = 0.,
1033
+ # causal = False,
1034
+ # rotary_emb = None,
1035
+ # cosine_sim = True,
1036
+ # cosine_sim_scale = 16
1037
+ # ):
1038
+ # super().__init__()
1039
+ # self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)
1040
+ # self.cosine_sim = cosine_sim
1041
+
1042
+ # self.heads = heads
1043
+ # inner_dim = dim_head * heads
1044
+ # self.dim = dim
1045
+ # self.inner_dim = inner_dim
1046
+
1047
+ # self.causal = causal
1048
+ # self.norm = LayerNorm(dim)
1049
+ # self.dropout = nn.Dropout(dropout)
1050
+
1051
+ # self.null_kv = nn.Parameter(torch.randn(2, dim_head))
1052
+ # self.to_q = nn.Linear(dim, inner_dim, bias = False)
1053
+ # self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
1054
+
1055
+ # self.rotary_emb = rotary_emb
1056
+
1057
+ # self.to_out = nn.Sequential(
1058
+ # nn.Linear(inner_dim, dim, bias = False),
1059
+ # LayerNorm(dim)
1060
+ # )
1061
+
1062
+ # def forward(self, x, mask = None, attn_bias = None):
1063
+ # b, n, device = *x.shape[:2], x.device
1064
+ # print("xinit", torch.any(torch.isnan(x)))
1065
+ # x = self.norm(x)
1066
+ # print("xnorm", torch.any(torch.isnan(x)))
1067
+ # print("xnorm.shape", x.shape)
1068
+
1069
+ # q = self.to_q(x)
1070
+ # print("q0", torch.any(torch.isnan(q)))
1071
+
1072
+ # k, v = self.to_kv(x).chunk(2, dim = -1)
1073
+ # print("k0", torch.any(torch.isnan(k)))
1074
+ # print("v0", torch.any(torch.isnan(v)))
1075
+
1076
+
1077
+ # q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
1078
+ # q = q * self.scale
1079
+
1080
+ # # rotary embeddings
1081
+
1082
+ # if exists(self.rotary_emb):
1083
+ # q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))
1084
+
1085
+ # # add null key / value for classifier free guidance in prior net
1086
+
1087
+ # nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))
1088
+ # k = torch.cat((nk, k), dim = -2)
1089
+ # v = torch.cat((nv, v), dim = -2)
1090
+
1091
+ # # whether to use cosine sim
1092
+
1093
+ # if self.cosine_sim:
1094
+ # q, k = map(l2norm, (q, k))
1095
+
1096
+ # q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))
1097
+ # print("q2", torch.any(torch.isnan(q)))
1098
+ # print("k2", torch.any(torch.isnan(k)))
1099
+
1100
+ # # calculate query / key similarities
1101
+
1102
+ # sim = einsum('b h i d, b j d -> b h i j', q, k)
1103
+
1104
+ # # relative positional encoding (T5 style)
1105
+
1106
+ # if exists(attn_bias):
1107
+ # sim = sim + attn_bias
1108
+
1109
+ # # masking
1110
+
1111
+ # max_neg_value = -torch.finfo(sim.dtype).max
1112
+
1113
+ # print("sim1", torch.any(torch.isnan(sim)))
1114
+
1115
+ # if exists(mask):
1116
+ # mask = F.pad(mask, (1, 0), value = True)
1117
+ # mask = rearrange(mask, 'b j -> b 1 1 j')
1118
+ # sim = sim.masked_fill(~mask, max_neg_value)
1119
+
1120
+ # print("sim2", torch.any(torch.isnan(sim)))
1121
+
1122
+ # if self.causal:
1123
+ # i, j = sim.shape[-2:]
1124
+ # causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
1125
+ # sim = sim.masked_fill(causal_mask, max_neg_value)
1126
+
1127
+ # # attention
1128
+
1129
+ # print("simFinal", torch.any(torch.isnan(sim)))
1130
+ # attn = sim.softmax(dim = -1, dtype = torch.float32)
1131
+ # print("attn", torch.any(torch.isnan(attn)))
1132
+ # attn = attn.type(sim.dtype)
1133
+
1134
+ # attn = self.dropout(attn)
1135
+
1136
+ # # aggregate values
1137
+
1138
+ # out = einsum('b h i j, b j d -> b h i d', attn, v)
1139
+
1140
+ # out = rearrange(out, 'b h n d -> b n (h d)')
1141
+ # return self.to_out(out)
1142
+
1143
+ class FlaggedCausalTransformer(nn.Module):
1144
+ def __init__(
1145
+ self,
1146
+ *,
1147
+ dim,
1148
+ depth,
1149
+ dim_head = 64,
1150
+ heads = 8,
1151
+ ff_mult = 4,
1152
+ norm_in = False,
1153
+ norm_out = True,
1154
+ attn_dropout = 0.,
1155
+ ff_dropout = 0.,
1156
+ final_proj = True,
1157
+ normformer = False,
1158
+ rotary_emb = False,
1159
+ causal=True
1160
+ ):
1161
+ super().__init__()
1162
+ self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM
1163
+
1164
+ self.rel_pos_bias = RelPosBias(heads = heads)
1165
+
1166
+ rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
1167
+
1168
+ self.layers = nn.ModuleList([])
1169
+ for _ in range(depth):
1170
+ self.layers.append(nn.ModuleList([
1171
+ Attention(dim = dim, causal = causal, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
1172
+ FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
1173
+ ]))
1174
+
1175
+ self.norm = LayerNorm(dim, stable = True) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
1176
+ self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
1177
+
1178
+ def forward(self, x):
1179
+ n, device = x.shape[1], x.device
1180
+
1181
+ x = self.init_norm(x)
1182
+
1183
+ attn_bias = self.rel_pos_bias(n, n + 1, device = device)
1184
+
1185
+ for attn, ff in self.layers:
1186
+ x = attn(x, attn_bias = attn_bias) + x
1187
+ x = ff(x) + x
1188
+
1189
+ out = self.norm(x)
1190
+ return self.project_out(out)
1191
+
1192
+ class BrainVD(VersatileDiffusionDualGuidedPipeline):
1193
+ """
1194
+ Differences from original:
1195
+ - Keep generated images on GPU and return tensors
1196
+ - No NSFW checker
1197
+ - Can pass in image or image_embedding to generate a variation
1198
+ NOTE: requires the latest version of diffusers, otherwise the latent dimensions will not be correct.
1199
+ """
1200
+
1201
+ def decode_latents(self, latents):
1202
+ latents = 1 / self.vae.config.scaling_factor * latents
1203
+ image = self.vae.decode(latents).sample
1204
+ image = (image / 2 + 0.5).clamp(0, 1)
1205
+ # the original pipeline cast to float32 numpy on CPU here; this version keeps the tensor on the GPU instead
1206
+ # image = image.cpu().permute(0, 2, 3, 1).float().numpy()
1207
+ return image
1208
+
1209
+ def check_inputs(self, prompt, image, height, width, callback_steps):
1210
+ if prompt is not None and not isinstance(prompt, str) and not isinstance(prompt, list):
1211
+ raise ValueError(f"`prompt` has to be of type None, `str` or `list` but is {type(prompt)}")
1212
+ if image is not None and not isinstance(image, PIL.Image.Image) and not isinstance(image, list):
1213
+ raise ValueError(f"`image` has to be of type None, `PIL.Image` or `list` but is {type(image)}")
1214
+
1215
+ if height % 8 != 0 or width % 8 != 0:
1216
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
1217
+
1218
+ if (callback_steps is None) or (
1219
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
1220
+ ):
1221
+ raise ValueError(
1222
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
1223
+ f" {type(callback_steps)}."
1224
+ )
1225
+
1226
+ @torch.no_grad()
1227
+ def __call__(
1228
+ self,
1229
+ prompt: Union[PIL.Image.Image, List[PIL.Image.Image]] = None,
1230
+ image: Union[str, List[str]] = None,
1231
+ text_to_image_strength: float = 0.5,
1232
+ height: Optional[int] = None,
1233
+ width: Optional[int] = None,
1234
+ num_inference_steps: int = 50,
1235
+ guidance_scale: float = 7.5,
1236
+ num_images_per_prompt: Optional[int] = 1,
1237
+ eta: float = 0.0,
1238
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1239
+ latents: Optional[torch.FloatTensor] = None,
1240
+ output_type: Optional[str] = "pil",
1241
+ return_dict: bool = True,
1242
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1243
+ callback_steps: Optional[int] = 1,
1244
+ image_embeddings: Optional[torch.FloatTensor] = None,
1245
+ prompt_embeddings: Optional[torch.FloatTensor] = None,
1246
+ **kwargs,
1247
+ ):
1248
+
1249
+ height = height or self.image_unet.config.sample_size * self.vae_scale_factor
1250
+ width = width or self.image_unet.config.sample_size * self.vae_scale_factor
1251
+
1252
+ self.check_inputs(prompt, image, height, width, callback_steps)
1253
+
1254
+ prompt = [prompt] if prompt is not None and not isinstance(prompt, list) else prompt
1255
+ image = [image] if image is not None and not isinstance(image, list) else image
1256
+ device = self._execution_device
1257
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
1258
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1259
+ # corresponds to doing no classifier free guidance.
1260
+ do_classifier_free_guidance = guidance_scale > 1.0
1261
+
1262
+
1263
+ # 3. Encode input prompt
1264
+ if image_embeddings is None:
1265
+ if image is not None:
1266
+ image_embeddings = self._encode_image_prompt(
1267
+ image, device, num_images_per_prompt, do_classifier_free_guidance
1268
+ )
1269
+ batch_size = len(image)
1270
+ else:
1271
+ image_embeddings = None
1272
+
1273
+ if prompt_embeddings is None:
1274
+ if prompt is not None:
1275
+ prompt_embeddings = self._encode_text_prompt(
1276
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance
1277
+ )
1278
+ batch_size = len(prompt)
1279
+ else:
1280
+ prompt_embeddings = None
1281
+ if image_embeddings is not None:
1282
+ batch_size = image_embeddings.shape[0] // 2
1283
+ elif prompt_embeddings is not None:
1284
+ batch_size = prompt_embeddings.shape[0] // 2
1285
+
1286
+ if image_embeddings is not None and prompt_embeddings is not None:
1287
+ dual_prompt_embeddings = torch.cat([prompt_embeddings, image_embeddings], dim=1)
1288
+ elif image_embeddings is None:
1289
+ dual_prompt_embeddings = prompt_embeddings
1290
+ text_to_image_strength = 1.
1291
+ elif prompt_embeddings is None:
1292
+ dual_prompt_embeddings = image_embeddings
1293
+ text_to_image_strength = 0.
1294
+ else:
1295
+ raise ValueError("at least one of image_embeddings/image or prompt_embeddings/prompt must be provided")
1296
+
1297
+ # 4. Prepare timesteps
1298
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1299
+ timesteps = self.scheduler.timesteps
1300
+
1301
+ # 5. Prepare latent variables
1302
+ num_channels_latents = self.image_unet.in_channels
1303
+ latents = self.prepare_latents(
1304
+ batch_size * num_images_per_prompt,
1305
+ num_channels_latents,
1306
+ height,
1307
+ width,
1308
+ dual_prompt_embeddings.dtype,
1309
+ device,
1310
+ generator,
1311
+ latents,
1312
+ )
1313
+
1314
+ # 6. Prepare extra step kwargs.
1315
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1316
+
1317
+ # 7. Combine the attention blocks of the image and text UNets
1318
+ self.set_transformer_params(text_to_image_strength, ("text", "image"))
1319
+
1320
+ # 8. Denoising loop
1321
+ for i, t in enumerate(self.progress_bar(timesteps)):
1322
+ # expand the latents if we are doing classifier free guidance
1323
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1324
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1325
+
1326
+ # predict the noise residual
1327
+ noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=dual_prompt_embeddings).sample
1328
+
1329
+ # perform guidance
1330
+ if do_classifier_free_guidance:
1331
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1332
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1333
+
1334
+ # compute the previous noisy sample x_t -> x_t-1
1335
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1336
+
1337
+ # call the callback, if provided
1338
+ if callback is not None and i % callback_steps == 0:
1339
+ callback(i, t, latents)
1340
+
1341
+ # 8. Post-processing
1342
+ image = self.decode_latents(latents)
1343
+
1344
+ return image
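
For reference, the two passes in forward_with_cond_scale above are combined with the standard classifier-free guidance extrapolation, null + (cond - null) * cond_scale. The snippet below is a minimal toy sketch of just that blend, not the prior network itself; the (batch, 257, 768) shape simply mirrors the embeddings used above.

import torch

def cfg_blend(cond_pred, null_pred, cond_scale=1.0):
    # cond_scale == 1 returns the conditional prediction unchanged;
    # cond_scale > 1 extrapolates away from the unconditional ("null") prediction
    return null_pred + (cond_pred - null_pred) * cond_scale

cond_pred = torch.randn(2, 257, 768)   # prediction with brain/image conditioning kept
null_pred = torch.randn(2, 257, 768)   # prediction with all conditioning dropped
assert torch.allclose(cfg_blend(cond_pred, null_pred, 1.0), cond_pred)
guided = cfg_blend(cond_pred, null_pred, 3.5)  # more strongly guided prediction
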
src/setup.sh ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Commands to set up a new conda environment and install all the necessary packages
3
+ # After running this, environment.yaml can be regenerated with "conda env export > environment.yaml".
4
+
5
+ set -e
6
+
7
+ conda create -n fmri python=3.10.8 -y
8
+ conda activate fmri
9
+
10
+ conda install numpy matplotlib tqdm scikit-image jupyterlab -y
11
+
12
+ pip install accelerate webdataset clip pandas matplotlib ftfy regex kornia umap-learn h5py
13
+ pip install torchvision==0.15.2 torch==2.0.1
14
+ pip install diffusers
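
After running the commands above inside the fmri environment, a quick import check like the one below (a small sketch added here, not part of the repository) can confirm that the core packages installed by setup.sh resolve and that CUDA is visible.

import torch, torchvision, diffusers, webdataset, accelerate

print("torch", torch.__version__, "| cuda available:", torch.cuda.is_available())
print("torchvision", torchvision.__version__)
print("diffusers", diffusers.__version__)
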
src/utils.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from torchvision import transforms
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import PIL
7
+ import random
8
+ import os
9
+ import matplotlib.pyplot as plt
10
+ import pandas as pd
11
+ import math
12
+ import webdataset as wds
13
+ import tempfile
14
+ from torchvision.utils import make_grid
15
+ from diffusers.utils import randn_tensor
16
+
17
+ import json
18
+ from torchmetrics.image.fid import FrechetInceptionDistance
19
+ from PIL import Image
20
+ import requests
21
+ import io
22
+ import time
23
+
24
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
25
+
26
+ def is_interactive():
27
+ import __main__ as main
28
+ return not hasattr(main, '__file__')
29
+
30
+ def seed_everything(seed=0, cudnn_deterministic=True):
31
+ random.seed(seed)
32
+ os.environ['PYTHONHASHSEED'] = str(seed)
33
+ np.random.seed(seed)
34
+ torch.manual_seed(seed)
35
+ torch.cuda.manual_seed(seed)
36
+ torch.cuda.manual_seed_all(seed)
37
+ if cudnn_deterministic:
38
+ torch.backends.cudnn.deterministic = True
39
+ else:
40
+ ## needs to be False to use conv3D
41
+ print('Note: not using cudnn.deterministic')
42
+
43
+ def np_to_Image(x):
44
+ if x.ndim==4:
45
+ x=x[0]
46
+ return PIL.Image.fromarray((x.transpose(1, 2, 0)*127.5+128).clip(0,255).astype('uint8'))
47
+
48
+ def torch_to_Image(x):
49
+ if x.ndim==4:
50
+ x=x[0]
51
+ return transforms.ToPILImage()(x)
52
+
53
+ def Image_to_torch(x):
54
+ try:
55
+ x = (transforms.ToTensor()(x)[:3].unsqueeze(0)-.5)/.5
56
+ except:
57
+ x = (transforms.ToTensor()(x[0])[:3].unsqueeze(0)-.5)/.5
58
+ return x
59
+
60
+ def torch_to_matplotlib(x,device=device):
61
+ if torch.mean(x)>10:
62
+ x = (x.permute(0, 2, 3, 1)).clamp(0, 255).to(torch.uint8)
63
+ else:
64
+ x = (x.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8)
65
+ if device=='cpu':
66
+ return x[0]
67
+ else:
68
+ return x.cpu().numpy()[0]
69
+
70
+ def pairwise_cosine_similarity(A, B, dim=1, eps=1e-8):
71
+ #https://stackoverflow.com/questions/67199317/pytorch-cosine-similarity-nxn-elements
72
+ numerator = A @ B.T
73
+ A_l2 = torch.mul(A, A).sum(axis=dim)
74
+ B_l2 = torch.mul(B, B).sum(axis=dim)
75
+ denominator = torch.max(torch.sqrt(torch.outer(A_l2, B_l2)), torch.tensor(eps))
76
+ return torch.div(numerator, denominator)
77
+
78
+ def batchwise_pearson_correlation(Z, B):
79
+ # Calculate means
80
+ Z_mean = torch.mean(Z, dim=1, keepdim=True)
81
+ B_mean = torch.mean(B, dim=1, keepdim=True)
82
+
83
+ # Subtract means
84
+ Z_centered = Z - Z_mean
85
+ B_centered = B - B_mean
86
+
87
+ # Calculate Pearson correlation coefficient
88
+ numerator = Z_centered @ B_centered.T
89
+ Z_centered_norm = torch.linalg.norm(Z_centered, dim=1, keepdim=True)
90
+ B_centered_norm = torch.linalg.norm(B_centered, dim=1, keepdim=True)
91
+ denominator = Z_centered_norm @ B_centered_norm.T
92
+
93
+ pearson_correlation = (numerator / denominator)
94
+ return pearson_correlation
95
+
96
+ def batchwise_cosine_similarity(Z,B):
97
+ # https://www.h4pz.co/blog/2021/4/2/batch-cosine-similarity-in-pytorch-or-numpy-jax-cupy-etc
98
+ B = B.T
99
+ Z_norm = torch.linalg.norm(Z, dim=1, keepdim=True) # Size (n, 1).
100
+ B_norm = torch.linalg.norm(B, dim=0, keepdim=True) # Size (1, b).
101
+ cosine_similarity = ((Z @ B) / (Z_norm @ B_norm)).T
102
+ return cosine_similarity
103
+
104
+ def topk(similarities,labels,k=5):
105
+ if k > similarities.shape[0]:
106
+ k = similarities.shape[0]
107
+ topsum=0
108
+ for i in range(k):
109
+ topsum += torch.sum(torch.argsort(similarities,axis=1)[:,-(i+1)] == labels)/len(labels)
110
+ return topsum
111
+
112
+ def get_non_diagonals(a):
113
+ a = torch.triu(a,diagonal=1)+torch.tril(a,diagonal=-1)
114
+ # make diagonals -1
115
+ a=a.fill_diagonal_(-1)
116
+ return a
117
+
118
+ def gather_features(image_features, voxel_features, accelerator):
119
+ all_image_features = accelerator.gather(image_features.contiguous())
120
+ if voxel_features is not None:
121
+ all_voxel_features = accelerator.gather(voxel_features.contiguous())
122
+ return all_image_features, all_voxel_features
123
+ return all_image_features
124
+
125
+ def soft_clip_loss(preds, targs, temp=0.125): #, distributed=False, accelerator=None):
126
+ # if not distributed:
127
+ clip_clip = (targs @ targs.T)/temp
128
+ brain_clip = (preds @ targs.T)/temp
129
+ # else:
130
+ # all_targs = gather_features(targs, None, accelerator)
131
+ # clip_clip = (targs @ all_targs.T)/temp
132
+ # brain_clip = (preds @ all_targs.T)/temp
133
+
134
+ loss1 = -(brain_clip.log_softmax(-1) * clip_clip.softmax(-1)).sum(-1).mean()
135
+ loss2 = -(brain_clip.T.log_softmax(-1) * clip_clip.softmax(-1)).sum(-1).mean()
136
+
137
+ loss = (loss1 + loss2)/2
138
+ return loss
139
+
140
+ def mixco(voxels, beta=0.15, s_thresh=0.5):
141
+ perm = torch.randperm(voxels.shape[0])
142
+ voxels_shuffle = voxels[perm].to(voxels.device,dtype=voxels.dtype)
143
+ betas = torch.distributions.Beta(beta, beta).sample([voxels.shape[0]]).to(voxels.device,dtype=voxels.dtype)
144
+ select = (torch.rand(voxels.shape[0]) <= s_thresh).to(voxels.device)
145
+ betas_shape = [-1] + [1]*(len(voxels.shape)-1)
146
+ voxels[select] = voxels[select] * betas[select].reshape(*betas_shape) + \
147
+ voxels_shuffle[select] * (1 - betas[select]).reshape(*betas_shape)
148
+ betas[~select] = 1
149
+ return voxels, perm, betas, select
150
+
151
+ def mixco_clip_target(clip_target, perm, select, betas):
152
+ clip_target_shuffle = clip_target[perm]
153
+ clip_target[select] = clip_target[select] * betas[select].reshape(-1, 1) + \
154
+ clip_target_shuffle[select] * (1 - betas[select]).reshape(-1, 1)
155
+ return clip_target
156
+
157
+ def mixco_nce(preds, targs, temp=0.1, perm=None, betas=None, select=None, distributed=False,
158
+ accelerator=None, local_rank=None, bidirectional=True):
159
+ brain_clip = (preds @ targs.T)/temp
160
+
161
+ if perm is not None and betas is not None and select is not None:
162
+ probs = torch.diag(betas)
163
+ probs[torch.arange(preds.shape[0]).to(preds.device), perm] = 1 - betas
164
+
165
+ loss = -(brain_clip.log_softmax(-1) * probs).sum(-1).mean()
166
+ if bidirectional:
167
+ loss2 = -(brain_clip.T.log_softmax(-1) * probs.T).sum(-1).mean()
168
+ loss = (loss + loss2)/2
169
+ return loss
170
+ else:
171
+ loss = F.cross_entropy(brain_clip, torch.arange(brain_clip.shape[0]).to(brain_clip.device))
172
+ if bidirectional:
173
+ loss2 = F.cross_entropy(brain_clip.T, torch.arange(brain_clip.shape[0]).to(brain_clip.device))
174
+ loss = (loss + loss2)/2
175
+ return loss
176
+
177
+ def count_params(model):
178
+ total = sum(p.numel() for p in model.parameters())
179
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
180
+ print('param counts:\n{:,} total\n{:,} trainable'.format(total, trainable))
181
+
182
+ def image_grid(imgs, rows, cols):
183
+ w, h = imgs[0].size
184
+ grid = PIL.Image.new('RGB', size=(cols*w, rows*h))
185
+ for i, img in enumerate(imgs):
186
+ grid.paste(img, box=(i%cols*w, i//cols*h))
187
+ return grid
188
+
189
+ def check_loss(loss):
190
+ if loss.isnan().any():
191
+ raise ValueError('NaN loss')
192
+
193
+ def cosine_anneal(start, end, steps):
194
+ return end + (start - end)/2 * (1 + torch.cos(torch.pi*torch.arange(steps)/(steps-1)))
195
+
196
+ import braceexpand
197
+ def get_dataloaders(
198
+ batch_size,
199
+ image_var='images',
200
+ num_devices=None,
201
+ num_workers=None,
202
+ train_url=None,
203
+ val_url=None,
204
+ meta_url=None,
205
+ num_train=None,
206
+ num_val=None,
207
+ cache_dir="/scratch/tmp/wds-cache",
208
+ seed=0,
209
+ voxels_key="nsdgeneral.npy",
210
+ val_batch_size=None,
211
+ to_tuple=["voxels", "images", "trial"],
212
+ local_rank=0,
213
+ world_size=1,
214
+ ):
215
+ print("Getting dataloaders...")
216
+ assert image_var == 'images'
217
+
218
+ def my_split_by_node(urls):
219
+ return urls
220
+
221
+ train_url = list(braceexpand.braceexpand(train_url))
222
+ val_url = list(braceexpand.braceexpand(val_url))
223
+
224
+ if num_devices is None:
225
+ num_devices = torch.cuda.device_count()
226
+
227
+ if num_workers is None:
228
+ num_workers = num_devices
229
+
230
+ if num_train is None:
231
+ metadata = json.load(open(meta_url))
232
+ num_train = metadata['totals']['train']
233
+ if num_val is None:
234
+ metadata = json.load(open(meta_url))
235
+ num_val = metadata['totals']['val']
236
+
237
+ if val_batch_size is None:
238
+ val_batch_size = batch_size
239
+
240
+ global_batch_size = batch_size * num_devices
241
+ num_batches = math.floor(num_train / global_batch_size)
242
+ num_worker_batches = math.floor(num_batches / num_workers)
243
+ if num_worker_batches == 0: num_worker_batches = 1
244
+
245
+ print("\nnum_train",num_train)
246
+ print("global_batch_size",global_batch_size)
247
+ print("batch_size",batch_size)
248
+ print("num_workers",num_workers)
249
+ print("num_batches",num_batches)
250
+ print("num_worker_batches", num_worker_batches)
251
+
252
+ # train_url = train_url[local_rank:world_size]
253
+ train_data = wds.WebDataset(train_url, resampled=False, cache_dir=cache_dir, nodesplitter=my_split_by_node)\
254
+ .shuffle(500, initial=500, rng=random.Random(42))\
255
+ .decode("torch")\
256
+ .rename(images="jpg;png", voxels=voxels_key, trial="trial.npy", coco="coco73k.npy", reps="num_uniques.npy")\
257
+ .to_tuple(*to_tuple)#\
258
+ # .batched(batch_size, partial=True)#\
259
+ # .with_epoch(num_worker_batches)
260
+
261
+ # BATCH SIZE SHOULD BE NONE!!! FOR TRAIN AND VAL | resampled=True for train | .batched(val_batch_size, partial=False)
262
+ train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=1, shuffle=False)
263
+
264
+ # Validation
265
+ print("val_batch_size",val_batch_size)
266
+ val_data = wds.WebDataset(val_url, resampled=False, cache_dir=cache_dir, nodesplitter=my_split_by_node)\
267
+ .shuffle(500, initial=500, rng=random.Random(42))\
268
+ .decode("torch")\
269
+ .rename(images="jpg;png", voxels=voxels_key, trial="trial.npy", coco="coco73k.npy", reps="num_uniques.npy")\
270
+ .to_tuple(*to_tuple)#\
271
+ # .batched(val_batch_size, partial=True)
272
+ val_dl = torch.utils.data.DataLoader(val_data, batch_size=val_batch_size, num_workers=1, shuffle=False, drop_last=True)
273
+
274
+ return train_dl, val_dl, num_train, num_val
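
As a usage illustration for the retrieval helpers defined above, the toy sketch below (not part of the repository) scores nearest-neighbour retrieval between two random embedding sets. It assumes the environment from setup.sh and that it is run from src/ so utils.py is importable.

import torch
from utils import batchwise_cosine_similarity, topk  # assumes src/ is on the path

n, d = 16, 768
preds = torch.randn(n, d)                    # stand-in for brain-derived embeddings
targs = preds + 0.1 * torch.randn(n, d)      # noisy copies, so retrieval should be easy
sims = batchwise_cosine_similarity(preds, targs)   # (n, n) similarity matrix
labels = torch.arange(n)
print("top-1 retrieval:", topk(sims, labels, k=1).item())
print("top-5 retrieval:", topk(sims, labels, k=5).item())
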