Upload folder using huggingface_hub
- .gitignore +164 -0
- LICENSE +21 -0
- README.md +11 -0
- src/Train.ipynb +0 -0
- src/Train.py +708 -0
- src/accel.slurm +38 -0
- src/checking_models.ipynb +1526 -0
- src/cnd_prov/cnd_prov-Copy1.py +148 -0
- src/cnd_prov/cnd_prov.py +138 -0
- src/cnd_prov/data.pkl +3 -0
- src/deepspeed_config_stage1.json +1 -0
- src/deepspeed_config_stage2.json +1 -0
- src/deepspeed_config_stage3.json +44 -0
- src/getdepthimages.ipynb +0 -0
- src/huggingface_to_s3.ipynb +422 -0
- src/models.py +1344 -0
- src/setup.sh +14 -0
- src/utils.py +274 -0
.gitignore
ADDED
@@ -0,0 +1,164 @@
+wandb/
+train_logs/
+slurms/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 MedARC
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md
ADDED
@@ -0,0 +1,11 @@
+# MindEyeV2
+
+In-progress
+
+1. Download all of https://huggingface.co/datasets/pscotti/mindeyev2 and place the files in a folder. You will need to specify the path to this folder as the "data_path" variable.
+
+2. Run setup.sh to install a new "fmri" conda environment.
+
+3. Activate the conda environment with "conda activate fmri".
+
+4. Run Train.ipynb or Train.py (they are the same code).
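
Train.py expects a specific layout inside data_path. Below is a minimal sketch to sanity-check the download before launching; this is a hypothetical helper, not part of the repo, with file names taken from what Train.py opens for subject 1:

```python
import os

def check_mindeyev2_download(data_path, subj=1):
    """Verify the files Train.py opens for one subject exist under data_path."""
    expected = [
        f"betas_all_subj0{subj}.hdf5",       # voxel betas, opened with h5py
        "coco_images_224_float16.hdf5",      # stimulus images, opened with h5py
        f"wds/subj0{subj}/test/0.tar",       # webdataset test shard
    ] + [f"wds/subj0{subj}/train/{i}.tar" for i in range(37)]  # train shards {0..36}
    missing = [p for p in expected if not os.path.exists(os.path.join(data_path, p))]
    if missing:
        raise FileNotFoundError(f"data_path is missing: {missing}")
    print("All expected files found.")

# check_mindeyev2_download("/path/to/mindeyev2_dataset")
```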
src/Train.ipynb
ADDED
The diff for this file is too large to render.
src/Train.py
ADDED
@@ -0,0 +1,708 @@
+#!/usr/bin/env python
+# coding: utf-8
+
+# In[1]:
+
+
+# # Code to convert this notebook to .py if you want to run it via command line or with Slurm
+# from subprocess import call
+# command = "jupyter nbconvert Train.ipynb --to python"
+# call(command,shell=True)
+
+
+# # Import packages & functions
+
+# In[2]:
+
+
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import time
+import random
+import h5py
+from tqdm import tqdm
+
+import webdataset as wds
+import gc
+
+import matplotlib.pyplot as plt
+import torch
+import torch.nn as nn
+from torchvision import transforms
+
+# tf32 data type is faster than standard float32
+torch.backends.cuda.matmul.allow_tf32 = True
+
+# custom functions #
+import utils
+
+
+# In[ ]:
+
+
+local_rank = os.getenv('RANK')
+if local_rank is None:
+    local_rank = 0
+else:
+    local_rank = int(local_rank)
+print("LOCAL RANK ", local_rank)
+
+### Single-GPU config ###
+## Feel free to uncomment the below 4 lines and comment out all the multi-gpu config code to simplify things for single-gpu
+# from accelerate import Accelerator
+# num_devices = torch.cuda.device_count()
+# if num_devices==0: num_devices = 1
+# accelerator = Accelerator(split_batches=False)
+# global_batch_size = 128
+
+### Multi-GPU config ###
+from accelerate import Accelerator, DeepSpeedPlugin
+num_devices = torch.cuda.device_count()
+if num_devices==0: num_devices = 1
+if num_devices <= 1 and utils.is_interactive():
+    # can emulate a distributed environment for deepspeed to work in jupyter notebook
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = str(np.random.randint(10000)+9000)
+    os.environ["RANK"] = "0"
+    os.environ["LOCAL_RANK"] = "0"
+    os.environ["WORLD_SIZE"] = "1"
+    os.environ["GLOBAL_BATCH_SIZE"] = "128" # set this to your batch size!
+global_batch_size = os.environ["GLOBAL_BATCH_SIZE"]
+
+# alter the deepspeed config according to your global and local batch size
+if local_rank == 0:
+    with open('deepspeed_config_stage2.json', 'r') as file:
+        config = json.load(file)
+    config['train_batch_size'] = int(os.environ["GLOBAL_BATCH_SIZE"])
+    config['train_micro_batch_size_per_gpu'] = int(os.environ["GLOBAL_BATCH_SIZE"]) // num_devices
+    with open('deepspeed_config_stage2.json', 'w') as file:
+        json.dump(config, file)
+else:
+    # give some time for the local_rank=0 gpu to prep new deepspeed config file
+    time.sleep(10)
+deepspeed_plugin = DeepSpeedPlugin("deepspeed_config_stage2.json")
+accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)
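
DeepSpeed enforces the invariant train_batch_size = train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size, which is why the stage-2 config is rewritten above before the Accelerator is constructed. A minimal sketch of that consistency check, with values matching the accel.slurm launch further down (the actual deepspeed_config_stage2.json is not expanded in this diff):

```python
def check_deepspeed_batch_config(config, world_size, grad_accum_steps=1):
    """Assert the batch-size invariant DeepSpeed enforces at startup."""
    micro = config["train_micro_batch_size_per_gpu"]
    assert config["train_batch_size"] == micro * grad_accum_steps * world_size

# GLOBAL_BATCH_SIZE=512 on 4 GPUs -> micro batch of 128 per GPU
check_deepspeed_batch_config(
    {"train_batch_size": 512, "train_micro_batch_size_per_gpu": 128},
    world_size=4,
)
```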
+
+
+# In[ ]:
+
+
+print("PID of this process =",os.getpid())
+device = accelerator.device
+print("device:",device)
+num_workers = num_devices
+print(accelerator.state)
+world_size = accelerator.state.num_processes
+distributed = not accelerator.state.distributed_type == 'NO'
+print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size)
+print = accelerator.print # only print if local_rank=0
+
+
+# # Configurations
+
+# In[3]:
+
+
+# if running this interactively, can specify jupyter_args here for argparser to use
+if utils.is_interactive():
+    # Example use
+    jupyter_args = f"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \
+                    --model_name=test \
+                    --subj=1 --batch_size={global_batch_size} --n_samples_save=0 \
+                    --max_lr=3e-4 --mixup_pct=.66 --num_epochs=12 --ckpt_interval=999 --no-use_image_aug"
+
+    jupyter_args = jupyter_args.split()
+    print(jupyter_args)
+
+    from IPython.display import clear_output # function to clear print outputs in cell
+    get_ipython().run_line_magic('load_ext', 'autoreload')
+    # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
+    get_ipython().run_line_magic('autoreload', '2')
+
+
+# In[4]:
+
+
+parser = argparse.ArgumentParser(description="Model Training Configuration")
+parser.add_argument(
+    "--model_name", type=str, default="testing",
+    help="name of model, used for ckpt saving and wandb logging (if enabled)",
+)
+parser.add_argument(
+    "--data_path", type=str, default="/fsx/proj-fmri/shared/natural-scenes-dataset",
+    help="Path to where NSD data is stored / where to download it to",
+)
+parser.add_argument(
+    "--subj",type=int, default=1, choices=[1,2,5,7],
+)
+parser.add_argument(
+    "--batch_size", type=int, default=32,
+    help="Batch size can be increased by 10x if only training v2c and not diffusion prior",
+)
+parser.add_argument(
+    "--wandb_log",action=argparse.BooleanOptionalAction,default=False,
+    help="whether to log to wandb",
+)
+parser.add_argument(
+    "--resume_from_ckpt",action=argparse.BooleanOptionalAction,default=False,
+    help="if not using wandb and want to resume from a ckpt",
+)
+parser.add_argument(
+    "--wandb_project",type=str,default="stability",
+    help="wandb project name",
+)
+parser.add_argument(
+    "--mixup_pct",type=float,default=.33,
+    help="proportion of way through training when to switch from BiMixCo to SoftCLIP",
+)
+parser.add_argument(
+    "--use_image_aug",action=argparse.BooleanOptionalAction,default=True,
+    help="whether to use image augmentation",
+)
+parser.add_argument(
+    "--num_epochs",type=int,default=240,
+    help="number of epochs of training",
+)
+parser.add_argument(
+    "--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'],
+)
+parser.add_argument(
+    "--ckpt_saving",action=argparse.BooleanOptionalAction,default=True,
+)
+parser.add_argument(
+    "--ckpt_interval",type=int,default=5,
+    help="save backup ckpt and reconstruct every x epochs",
+)
+parser.add_argument(
+    "--seed",type=int,default=42,
+)
+parser.add_argument(
+    "--max_lr",type=float,default=3e-4,
+)
+parser.add_argument(
+    "--n_samples_save",type=int,default=0,choices=[0,1],
+    help="Number of reconstructions for monitoring progress, 0 will speed up training",
+)
+
+if utils.is_interactive():
+    args = parser.parse_args(jupyter_args)
+else:
+    args = parser.parse_args()
+
+# create global variables without the args prefix
+for attribute_name in vars(args).keys():
+    globals()[attribute_name] = getattr(args, attribute_name)
+
+print("global batch_size", batch_size)
+batch_size = int(batch_size / num_devices)
+print("batch_size", batch_size)
+
+
+# In[5]:
+
+
+outdir = os.path.abspath(f'../train_logs/{model_name}')
+if not os.path.exists(outdir):
+    os.makedirs(outdir,exist_ok=True)
+if use_image_aug:
+    import kornia
+    from kornia.augmentation.container import AugmentationSequential
+    img_augment = AugmentationSequential(
+        kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),
+        kornia.augmentation.Resize((224, 224)),
+        kornia.augmentation.RandomHorizontalFlip(p=0.3),
+        kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),
+        kornia.augmentation.RandomGrayscale(p=0.3),
+        same_on_batch=False,
+        data_keys=["input"],
+    )
+
+
+# # Prep data, models, and dataloaders
+
+# ## Dataloader
+
+# In[6]:
+
+
+if subj==1:
+    num_train = 24958
+    num_test = 2770
+test_batch_size = num_test
+
+def my_split_by_node(urls): return urls
+
+train_url = f"{data_path}/wds/subj0{subj}/train/" + "{0..36}.tar"
+print(train_url)
+
+train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\
+    .shuffle(750, initial=1500, rng=random.Random(42))\
+    .decode("torch")\
+    .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
+    .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
+train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True)
+
+test_url = f"{data_path}/wds/subj0{subj}/test/" + "0.tar"
+print(test_url)
+
+test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\
+    .shuffle(750, initial=1500, rng=random.Random(42))\
+    .decode("torch")\
+    .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
+    .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
+test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=False, pin_memory=True)
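
The webdataset shards above only provide behavioral metadata; the stimuli and voxel betas are indexed out of HDF5 arrays loaded in the cells that follow. A rough sketch of the indexing convention, inferred from how the training loop below consumes a batch (the column meanings are not otherwise documented here):

```python
# behav has shape (batch, repeats, columns); the training loop uses
#   behav[:, 0, 0] -> index into the coco_images hdf5 array (the shown image)
#   behav[:, 0, 5] -> row index into the betas hdf5 array (the fMRI response)
# so a batch is assembled roughly as:
#   voxel = voxels[behav[:, 0, 5].cpu().long()].to(device)
#   image = images[behav[:, 0, 0].cpu().long()].to(device)
```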
+
+
+# ### check dataloaders are working
+
+# In[7]:
+
+
+# test_indices = []
+# test_images = []
+# for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
+#     test_indices = np.append(test_indices, behav[:,0,5].numpy())
+#     test_images = np.append(test_images, behav[:,0,0].numpy())
+# test_indices = test_indices.astype(np.int16)
+# print(test_i, (test_i+1) * test_batch_size, len(test_indices))
+# print("---\n")
+
+# train_indices = []
+# train_images = []
+# for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
+#     train_indices = np.append(train_indices, behav[:,0,5].long().numpy())
+#     train_images = np.append(train_images, behav[:,0,0].numpy())
+# train_indices = train_indices.astype(np.int16)
+# print(train_i, (train_i+1) * batch_size, len(train_indices))
+
+
+# ## Load voxel betas, K-means clustering model, and images
+
+# In[8]:
+
+
+# load betas
+f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r')
+voxels = f['betas'][:]
+print(f"subj0{subj} betas loaded into memory")
+voxels = torch.Tensor(voxels).to("cpu").half()
+if subj==1:
+    voxels = torch.hstack((voxels, torch.zeros((len(voxels), 5))))
+print("voxels", voxels.shape)
+num_voxels = voxels.shape[-1]
+
+# load orig images
+f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
+images = f['images'][:]
+images = torch.Tensor(images).to("cpu").half()
+print("images", images.shape)
+
+
+# In[9]:
+
+
+from models import Clipper
+eva02_model = Clipper("ViT-L/14", device=torch.device(f"cuda:{local_rank}"), hidden_state=True, norm_embs=True)
+
+clip_seq_dim = 257
+clip_emb_dim = 768
+hidden_dim = 4096
+
+
+# In[10]:
+
+
+class MindEyeModule(nn.Module):
+    def __init__(self):
+        super(MindEyeModule, self).__init__()
+    def forward(self, x):
+        return x
+
+model = MindEyeModule()
+model
+
+
+# In[11]:
+
+
+class RidgeRegression(torch.nn.Module):
+    # make sure to add weight_decay when initializing optimizer
+    def __init__(self, input_size, out_features):
+        super(RidgeRegression, self).__init__()
+        self.linear = torch.nn.Linear(input_size, out_features)
+    def forward(self, x):
+        return self.linear(x)
+
+model.ridge = RidgeRegression(voxels.shape[1], out_features=hidden_dim)
+utils.count_params(model.ridge)
+utils.count_params(model)
+
+b = torch.randn((2,voxels.shape[1]))
+print(b.shape, model.ridge(b).shape)
+
+
+# In[12]:
+
+
+from functools import partial
+class BrainNetwork(nn.Module):
+    def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, use_projector=True, drop1=.5, drop2=.15):
+        super().__init__()
+        norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
+        act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
+        act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
+        self.mlp = nn.ModuleList([
+            nn.Sequential(
+                nn.Linear(h, h),
+                *[item() for item in act_and_norm],
+                nn.Dropout(drop2)
+            ) for _ in range(n_blocks)
+        ])
+        self.lin1 = nn.Linear(h, out_dim, bias=True)
+        self.n_blocks = n_blocks
+        self.clip_size = clip_size
+        self.use_projector = use_projector
+        if use_projector:
+            self.projector = nn.Sequential(
+                nn.LayerNorm(clip_size),
+                nn.GELU(),
+                nn.Linear(clip_size, 2048),
+                nn.LayerNorm(2048),
+                nn.GELU(),
+                nn.Linear(2048, 2048),
+                nn.LayerNorm(2048),
+                nn.GELU(),
+                nn.Linear(2048, clip_size)
+            )
+
+    def forward(self, x):
+        residual = x
+        for res_block in range(self.n_blocks):
+            x = self.mlp[res_block](x)
+            x += residual
+            residual = x
+        x = x.reshape(len(x), -1)
+        x = self.lin1(x)
+        if self.use_projector:
+            return self.projector(x.reshape(len(x), -1, self.clip_size))
+        return x
+
+model.backbone = BrainNetwork(in_dim=hidden_dim, clip_size=clip_emb_dim, out_dim=clip_seq_dim*clip_emb_dim, use_projector=True)
+utils.count_params(model.backbone)
+utils.count_params(model)
+
+b = torch.randn((2,hidden_dim))
+print(b.shape, model.backbone(b).shape)
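
The shape test above works out as follows; this is plain arithmetic over the dims defined in this file, shown as a standalone snippet rather than repo code:

```python
clip_seq_dim, clip_emb_dim, hidden_dim = 257, 768, 4096
out_dim = clip_seq_dim * clip_emb_dim          # 197376 features out of lin1
lin1_params = hidden_dim * out_dim + out_dim   # 808649472 weights; lin1 dominates count_params
# lin1 output is reshaped to (batch, 257, 768) and passed token-wise through the
# projector, so model.backbone(b) on b of shape (2, 4096) returns (2, 257, 768),
# matching the clip_seq_dim x clip_emb_dim embedding used as the training target.
```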
+
+
+# In[13]:
+
+
+no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
+opt_grouped_parameters = [
+    {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},
+    {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
+    {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
+]
+
+optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr)
+
+if lr_scheduler_type == 'linear':
+    lr_scheduler = torch.optim.lr_scheduler.LinearLR(
+        optimizer,
+        total_iters=int(num_epochs*(num_train*num_devices//batch_size)),
+        last_epoch=-1
+    )
+elif lr_scheduler_type == 'cycle':
+    total_steps=int(num_epochs*(num_train*num_devices//batch_size))
+    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
+        optimizer,
+        max_lr=max_lr,
+        total_steps=total_steps,
+        final_div_factor=1000,
+        last_epoch=-1, pct_start=2/num_epochs
+    )
+
+def save_ckpt(tag):
+    ckpt_path = outdir+f'/{tag}.pth'
+    print(f'saving {ckpt_path}',flush=True)
+    unwrapped_model = accelerator.unwrap_model(model)
+    try:
+        torch.save({
+            'epoch': epoch,
+            'model_state_dict': unwrapped_model.state_dict(),
+            'optimizer_state_dict': optimizer.state_dict(),
+            'lr_scheduler': lr_scheduler.state_dict(),
+            'train_losses': losses,
+            'test_losses': test_losses,
+            'lrs': lrs,
+        }, ckpt_path)
+    except:
+        print("Couldn't save... moving on to prevent crashing.")
+    del unwrapped_model
+
+print("\nDone with model preparations!")
+
+
+# # Weights and Biases
+
+# In[14]:
+
+
+# params for wandb
+if local_rank==0 and wandb_log: # only use main process for wandb logging
+    import wandb
+
+    wandb_project = 'stability'
+    wandb_run = model_name
+    wandb_notes = ''
+
+    print(f"wandb {wandb_project} run {wandb_run}")
+    wandb.login(host='https://stability.wandb.io')#, relogin=True)
+    wandb_config = {
+        "model_name": model_name,
+        "clip_variant": clip_variant,
+        "batch_size": batch_size,
+        "num_epochs": num_epochs,
+        "use_image_aug": use_image_aug,
+        "max_lr": max_lr,
+        "lr_scheduler_type": lr_scheduler_type,
+        "mixup_pct": mixup_pct,
+        "num_train": num_train,
+        "num_test": num_test,
+        "seed": seed,
+        "distributed": distributed,
+        "num_devices": num_devices,
+        "world_size": world_size,
+    }
+    print("wandb_config:\n",wandb_config)
+    if True: # wandb_auto_resume
+        print("wandb_id:",model_name)
+        wandb.init(
+            id = model_name,
+            project=wandb_project,
+            name=wandb_run,
+            config=wandb_config,
+            notes=wandb_notes,
+            resume="allow",
+        )
+    else:
+        wandb.init(
+            project=wandb_project,
+            name=wandb_run,
+            config=wandb_config,
+            notes=wandb_notes,
+        )
+else:
+    wandb_log = False
+
+
+# # Main
+
+# In[15]:
+
+
+epoch = 0
+losses, test_losses, lrs = [], [], []
+best_test_loss = 1e9
+soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))
+
+# Optionally resume from checkpoint #
+if resume_from_ckpt:
+    print("\n---resuming from last.pth ckpt---\n")
+    try:
+        checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
+    except:
+        print('last.pth failed... trying last_backup.pth')
+        checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
+    epoch = checkpoint['epoch']
+    print("Epoch",epoch)
+    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+    diffusion_prior.load_state_dict(checkpoint['model_state_dict'])
+    del checkpoint
+elif wandb_log:
+    if wandb.run.resumed:
+        print("\n---resuming from last.pth ckpt---\n")
+        try:
+            checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
+        except:
+            print('last.pth failed... trying last_backup.pth')
+            checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
+        epoch = checkpoint['epoch']
+        print("Epoch",epoch)
+        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+        diffusion_prior.load_state_dict(checkpoint['model_state_dict'])
+        del checkpoint
+torch.cuda.empty_cache()
+
+
+# In[16]:
+
+
+model, optimizer, train_dl, test_dl, lr_scheduler = accelerator.prepare(
+    model, optimizer, train_dl, test_dl, lr_scheduler
+)
+
+
+# In[17]:
+
+
+print(f"{model_name} starting with epoch {epoch} / {num_epochs}")
+progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0))
+test_image, test_voxel = None, None
+mse = nn.MSELoss()
+for epoch in progress_bar:
+    model.train()
+
+    fwd_percent_correct = 0.
+    bwd_percent_correct = 0.
+    test_fwd_percent_correct = 0.
+    test_bwd_percent_correct = 0.
+
+    for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
+        with torch.cuda.amp.autocast():
+            optimizer.zero_grad()
+
+            voxel = voxels[behav[:,0,5].cpu().long()].to(device)
+            image = images[behav[:,0,0].cpu().long()].to(device)
+
+            if use_image_aug: image = img_augment(image)
+
+            clip_target = eva02_model.embed_image(image.float())
+            assert not torch.any(torch.isnan(clip_target))
+
+            if epoch < int(mixup_pct * num_epochs):
+                voxel, perm, betas, select = utils.mixco(voxel)
+
+            voxel_ridge = model.ridge(voxel)
+
+            clip_voxels = model.backbone(voxel_ridge)
+
+            clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
+            clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
+
+            if epoch < int(mixup_pct * num_epochs):
+                loss_clip = utils.mixco_nce(
+                    clip_voxels_norm,
+                    clip_target_norm,
+                    temp=.006,
+                    perm=perm, betas=betas, select=select)
+            else:
+                epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]
+                loss_clip = utils.soft_clip_loss(
+                    clip_voxels_norm,
+                    clip_target_norm,
+                    temp=epoch_temp)
+
+            loss = loss_clip
+
+            utils.check_loss(loss)
+
+            accelerator.backward(loss)
+            optimizer.step()
+
+            losses.append(loss.item())
+            lrs.append(optimizer.param_groups[0]['lr'])
+
+            # forward and backward top 1 accuracy
+            labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
+            fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
+            bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)
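
utils.py is part of this upload but not expanded in this view, so the loss and metric helpers used above are defined elsewhere. For orientation, a soft CLIP-style contrastive loss and a batchwise top-1 retrieval metric consistent with how they are called here can be sketched as follows (an illustration, not the repository's actual implementations):

```python
import torch

def soft_clip_loss_sketch(preds, targs, temp=0.006):
    # preds/targs: (batch, dim), already L2-normalized
    targ_sim = (targs @ targs.T) / temp   # soft targets from target-target similarity
    pred_sim = (preds @ targs.T) / temp
    loss_fwd = -(pred_sim.log_softmax(-1) * targ_sim.softmax(-1)).sum(-1).mean()
    loss_bwd = -(pred_sim.T.log_softmax(-1) * targ_sim.softmax(-1)).sum(-1).mean()
    return (loss_fwd + loss_bwd) / 2

def top1_retrieval_acc_sketch(preds, targs):
    # fraction of rows whose most-similar target is the matching one
    sims = preds @ targs.T  # cosine similarity, since inputs are normalized
    labels = torch.arange(len(preds), device=preds.device)
    return (sims.argmax(dim=-1) == labels).float().mean()
```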
+
+            if lr_scheduler_type is not None:
+                lr_scheduler.step()
+
+    model.eval()
+    for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
+        with torch.no_grad():
+            with torch.cuda.amp.autocast():
+                # all test samples should be loaded per batch such that test_i should never exceed 0
+                if len(behav) != num_test: print("!",len(behav),num_test)
+
+                ## Average same-image repeats ##
+                if test_image is None:
+                    voxel = voxels[behav[:,0,5].cpu().long()]
+                    image = behav[:,0,0].cpu().long()
+
+                    unique_image, sort_indices = torch.unique(image, return_inverse=True)
+                    for im in unique_image:
+                        locs = torch.where(im == image)[0]
+                        if test_image is None:
+                            test_image = images[im][None]
+                            test_voxel = torch.mean(voxel[locs],axis=0)[None]
+                        else:
+                            test_image = torch.vstack((test_image, images[im][None]))
+                            test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None]))
+
+                # random sample of 300
+                random_indices = torch.randperm(len(test_voxel))[:300]
+                voxel = test_voxel[random_indices].to(device)
+                image = test_image[random_indices].to(device)
+                assert len(image) == 300
+
+                clip_target = eva02_model.embed_image(image.float())
+
+                voxel_ridge = model.ridge(voxel)
+
+                clip_voxels = model.backbone(voxel_ridge)
+
+                clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
+                clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
+
+                loss_clip = utils.soft_clip_loss(
+                    clip_voxels_norm,
+                    clip_target_norm,
+                    temp=.006)
+
+                loss = loss_clip
+
+                utils.check_loss(loss)
+
+                test_losses.append(loss.item())
+
+                # forward and backward top 1 accuracy
+                labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
+                test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
+                test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)
+
+    if local_rank==0:
+        if utils.is_interactive():
+            # clear_output(wait=True)
+            print("---")
+
+        assert (test_i+1) == 1
+        logs = {"train/loss": np.mean(losses[-(train_i+1):]),
+            "test/loss": np.mean(test_losses[-(test_i+1):]),
+            "train/lr": lrs[-1],
+            "train/num_steps": len(losses),
+            "test/num_steps": len(test_losses),
+            "train/fwd_pct_correct": fwd_percent_correct.item() / (train_i + 1),
+            "train/bwd_pct_correct": bwd_percent_correct.item() / (train_i + 1),
+            "test/test_fwd_pct_correct": test_fwd_percent_correct.item() / (test_i + 1),
+            "test/test_bwd_pct_correct": test_bwd_percent_correct.item() / (test_i + 1),
+            }
+        progress_bar.set_postfix(**logs)
+
+        # Save model checkpoint and reconstruct
+        if epoch % ckpt_interval == 0:
+            if not utils.is_interactive():
+                save_ckpt(f'last')
+
+        if wandb_log: wandb.log(logs)
+
+    # wait for other GPUs to catch up if needed
+    accelerator.wait_for_everyone()
+    torch.cuda.empty_cache()
+    gc.collect()
+
+print("\n===Finished!===\n")
+if ckpt_saving:
+    save_ckpt(f'last')
+if not utils.is_interactive():
+    sys.exit(0)
+
src/accel.slurm
ADDED
@@ -0,0 +1,38 @@
+#!/bin/bash
+#SBATCH --account=topfmri
+#SBATCH --partition=g40x
+#SBATCH --job-name=ms
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=4 # should = number of gpus
+#SBATCH --gres=gpu:4
+#SBATCH --time=32:00:00 # total run time limit (HH:MM:SS)
+#SBATCH -e slurms/%j.err
+#SBATCH -o slurms/%j.out
+#SBATCH --comment=topfmri
+
+module load cuda/11.7 # should match torch.cuda.version
+
+export NUM_GPUS=4 # Set to equal gres=gpu:#
+export GLOBAL_BATCH_SIZE=512
+
+# Make sure another job doesnt use same port, here using random number
+export MASTER_PORT=$((RANDOM % (19000 - 11000 + 1) + 11000))
+
+export HOSTNAMES=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
+export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
+export COUNT_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l)
+
+export WANDB_DIR="/fsx/proj-fmri/paulscotti/MindEyeV2/wandb/"
+export WANDB_CACHE_DIR="/fsx/home-paulscotti/.cache"
+export WANDB_MODE="online"
+
+echo MASTER_ADDR=${MASTER_ADDR}
+echo MASTER_PORT=${MASTER_PORT}
+echo WORLD_SIZE=${COUNT_NODE}
+
+###########
+
+cd /fsx/proj-fmri/paulscotti/MindEyeV2
+accelerate launch --num_processes=$(($NUM_GPUS * $COUNT_NODE)) --num_machines=$COUNT_NODE --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT Train.py --data_path=/fsx/proj-fmri/shared/mindeyev2_dataset --model_name=test --subj=1 --batch_size=${GLOBAL_BATCH_SIZE} --n_samples_save=0 --max_lr=3e-4 --mixup_pct=.33 --num_epochs=240 --ckpt_interval=999 --no-use_image_aug
+
+# --wandb_log
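
Note how this launch and Train.py fit together: --batch_size receives the global batch (${GLOBAL_BATCH_SIZE}), and Train.py divides it by the visible device count before building its dataloaders, which also matches the micro batch written into the DeepSpeed config. A quick worked check (plain arithmetic, standalone snippet):

```python
GLOBAL_BATCH_SIZE = 512  # exported above and passed as --batch_size
NUM_GPUS = 4             # gres=gpu:4, one process per GPU on a single node
per_gpu_batch = GLOBAL_BATCH_SIZE // NUM_GPUS
assert per_gpu_batch == 128  # the per-GPU batch Train.py computes and DeepSpeed expects
```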
src/checking_models.ipynb
ADDED
@@ -0,0 +1,1526 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 25,
|
6 |
+
"id": "ef9e1556-7840-4004-b181-a2c97ac2ab17",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"import os\n",
|
11 |
+
"import torch\n",
|
12 |
+
"import torch.nn as nn\n",
|
13 |
+
"import numpy as np\n",
|
14 |
+
"import matplotlib.pyplot as plt"
|
15 |
+
]
|
16 |
+
},
|
17 |
+
{
|
18 |
+
"cell_type": "markdown",
|
19 |
+
"id": "b6f12dd4-f3aa-4981-b604-b72e67229011",
|
20 |
+
"metadata": {},
|
21 |
+
"source": [
|
22 |
+
"# DinoV2"
|
23 |
+
]
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"cell_type": "code",
|
27 |
+
"execution_count": 26,
|
28 |
+
"id": "2a604617-b602-4503-b288-e9828684505e",
|
29 |
+
"metadata": {},
|
30 |
+
"outputs": [
|
31 |
+
{
|
32 |
+
"name": "stderr",
|
33 |
+
"output_type": "stream",
|
34 |
+
"text": [
|
35 |
+
"Using cache found in /fsx/proj-fmri/shared/cache/dinov2/hub/facebookresearch_dinov2_main\n"
|
36 |
+
]
|
37 |
+
}
|
38 |
+
],
|
39 |
+
"source": [
|
40 |
+
"# need to change TORCH_HOME env variable to specify pretrained model should go in shared folder, not home directory\n",
|
41 |
+
"os.environ['TORCH_HOME'] = '/fsx/proj-fmri/shared/cache/dinov2'\n",
|
42 |
+
"dinov2_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')\n",
|
43 |
+
"# remove initial image patching\n",
|
44 |
+
"dinov2_model.patch_embed = nn.Identity()\n",
|
45 |
+
"dinov2_model.patch_embed = nn.Identity()"
|
46 |
+
]
|
47 |
+
},
|
48 |
+
{
|
49 |
+
"cell_type": "code",
|
50 |
+
"execution_count": 27,
|
51 |
+
"id": "32da913d-d931-4967-a5e8-bd40c21d1ad9",
|
52 |
+
"metadata": {},
|
53 |
+
"outputs": [
|
54 |
+
{
|
55 |
+
"name": "stdout",
|
56 |
+
"output_type": "stream",
|
57 |
+
"text": [
|
58 |
+
"torch.Size([2, 33, 1024])\n"
|
59 |
+
]
|
60 |
+
}
|
61 |
+
],
|
62 |
+
"source": [
|
63 |
+
"dinov2_model.to(\"cuda\")\n",
|
64 |
+
"input = torch.randn((2,33,1024)).to(\"cuda\")\n",
|
65 |
+
"\n",
|
66 |
+
"for block in dinov2_model.blocks: input = block(input)\n",
|
67 |
+
"input = dinov2_model.norm(input)\n",
|
68 |
+
"\n",
|
69 |
+
"print(input.shape)"
|
70 |
+
]
|
71 |
+
},
|
72 |
+
{
|
73 |
+
"cell_type": "markdown",
|
74 |
+
"id": "febe89c0-06d0-4309-b378-a8d58b99bf4c",
|
75 |
+
"metadata": {},
|
76 |
+
"source": [
|
77 |
+
"# eva"
|
78 |
+
]
|
79 |
+
},
|
80 |
+
{
|
81 |
+
"cell_type": "code",
|
82 |
+
"execution_count": 28,
|
83 |
+
"id": "690204d0-13d7-452b-97af-14d144800e81",
|
84 |
+
"metadata": {},
|
85 |
+
"outputs": [],
|
86 |
+
"source": [
|
87 |
+
"from urllib.request import urlopen\n",
|
88 |
+
"from PIL import Image\n",
|
89 |
+
"import timm\n",
|
90 |
+
"\n",
|
91 |
+
"img = Image.open(urlopen(\n",
|
92 |
+
" 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'\n",
|
93 |
+
"))\n",
|
94 |
+
"\n",
|
95 |
+
"model = timm.create_model(\n",
|
96 |
+
" \"eva02_enormous_patch14_clip_224.laion2b\",\n",
|
97 |
+
" pretrained=True,\n",
|
98 |
+
" num_classes=0, # remove classifier nn.Linear\n",
|
99 |
+
")\n",
|
100 |
+
"model = model.eval()"
|
101 |
+
]
|
102 |
+
},
|
103 |
+
{
|
104 |
+
"cell_type": "code",
|
105 |
+
"execution_count": 39,
|
106 |
+
"id": "035e3e9d-86c9-4ddf-b760-7b78dded7d2e",
|
107 |
+
"metadata": {},
|
108 |
+
"outputs": [
|
109 |
+
{
|
110 |
+
"ename": "ValueError",
|
111 |
+
"evalue": "You have to specify pixel_values",
|
112 |
+
"output_type": "error",
|
113 |
+
"traceback": [
|
114 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
115 |
+
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
|
116 |
+
"Cell \u001b[0;32mIn[39], line 5\u001b[0m\n\u001b[1;32m 2\u001b[0m data_config \u001b[38;5;241m=\u001b[39m timm\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mresolve_model_data_config(model)\n\u001b[1;32m 3\u001b[0m transforms \u001b[38;5;241m=\u001b[39m timm\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mcreate_transform(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mdata_config, is_training\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m----> 5\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtransforms\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimg\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munsqueeze\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# output is (batch_size, num_features) shaped tensor\u001b[39;00m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28mprint\u001b[39m(output\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m 7\u001b[0m \u001b[38;5;66;03m# or equivalently (without needing to set num_classes=0)\u001b[39;00m\n",
|
117 |
+
"File \u001b[0;32m~/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
|
118 |
+
"File \u001b[0;32m~/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/clipseg/modeling_clipseg.py:1433\u001b[0m, in \u001b[0;36mCLIPSegForImageSegmentation.forward\u001b[0;34m(self, input_ids, pixel_values, conditional_pixel_values, conditional_embeddings, attention_mask, position_ids, labels, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1431\u001b[0m \u001b[38;5;66;03m# step 1: forward the query images through the frozen CLIP vision encoder\u001b[39;00m\n\u001b[1;32m 1432\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m-> 1433\u001b[0m vision_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclip\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvision_model\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1434\u001b[0m \u001b[43m \u001b[49m\u001b[43mpixel_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpixel_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1435\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1436\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# we need the intermediate hidden states\u001b[39;49;00m\n\u001b[1;32m 1437\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1438\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1439\u001b[0m pooled_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mclip\u001b[38;5;241m.\u001b[39mvisual_projection(vision_outputs[\u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m 1441\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m vision_outputs\u001b[38;5;241m.\u001b[39mhidden_states \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;28;01melse\u001b[39;00m vision_outputs[\u001b[38;5;241m2\u001b[39m]\n",
|
119 |
+
"File \u001b[0;32m~/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
|
120 |
+
"File \u001b[0;32m~/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/clipseg/modeling_clipseg.py:872\u001b[0m, in \u001b[0;36mCLIPSegVisionTransformer.forward\u001b[0;34m(self, pixel_values, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 869\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m 871\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m pixel_values \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 872\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou have to specify pixel_values\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 874\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39membeddings(pixel_values)\n\u001b[1;32m 875\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpre_layrnorm(hidden_states)\n",
|
121 |
+
"\u001b[0;31mValueError\u001b[0m: You have to specify pixel_values"
|
122 |
+
]
|
123 |
+
}
|
124 |
+
],
|
125 |
+
"source": [
|
126 |
+
"# get model specific transforms (normalization, resize)\n",
|
127 |
+
"data_config = timm.data.resolve_model_data_config(model)\n",
|
128 |
+
"transforms = timm.data.create_transform(**data_config, is_training=False)\n",
|
129 |
+
"\n",
|
130 |
+
"output = model(transforms(img).unsqueeze(0)) # output is (batch_size, num_features) shaped tensor\n",
|
131 |
+
"print(output.shape)\n",
|
132 |
+
"# or equivalently (without needing to set num_classes=0)\n",
|
133 |
+
"\n",
|
134 |
+
"output = model.forward_features(transforms(img).unsqueeze(0))\n",
|
135 |
+
"# output is unpooled, a (1, 257, 768) shaped tensor\n",
|
136 |
+
"print(output.shape)\n",
|
137 |
+
"\n",
|
138 |
+
"output = model.forward_head(output, pre_logits=True)\n",
|
139 |
+
"# output is a (1, num_features) shaped tensor\n",
|
140 |
+
"print(output.shape)"
|
141 |
+
]
|
142 |
+
},
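{
"cell_type": "markdown",
"id": "editor-note-clipseg",
"metadata": {},
"source": [
"Note on the `ValueError` above: per the traceback, `CLIPSegForImageSegmentation.forward` takes `input_ids` as its first positional argument, so passing the image tensor positionally leaves `pixel_values` unset. A minimal sketch of the likely fix, assuming `model` here is still the CLIPSeg model from the earlier cells:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "editor-sketch-clipseg",
"metadata": {},
"outputs": [],
"source": [
"# hedged sketch: bind the image tensor to pixel_values by keyword so it does not\n",
"# land in input_ids (assumes `model` is CLIPSegForImageSegmentation and `img` a PIL image)\n",
"pixel_values = transforms(img).unsqueeze(0)\n",
"output = model(pixel_values=pixel_values)"
]
},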
|
143 |
+
{
|
144 |
+
"cell_type": "code",
|
145 |
+
"execution_count": null,
|
146 |
+
"id": "54275c4c-e506-4959-92f1-29e584f5ce51",
|
147 |
+
"metadata": {},
|
148 |
+
"outputs": [],
|
149 |
+
"source": [
|
150 |
+
"model.forward_features(transforms(img).unsqueeze(0)).shape"
|
151 |
+
]
|
152 |
+
},
|
153 |
+
{
|
154 |
+
"cell_type": "markdown",
|
155 |
+
"id": "6546c673-f3ab-4d43-a051-cab20e782bab",
|
156 |
+
"metadata": {},
|
157 |
+
"source": [
|
158 |
+
"# Eva02-clip"
|
159 |
+
]
|
160 |
+
},
|
161 |
+
{
|
162 |
+
"cell_type": "code",
|
163 |
+
"execution_count": 29,
|
164 |
+
"id": "dfbc95de-9af9-4583-98fc-b8061114ef64",
|
165 |
+
"metadata": {},
|
166 |
+
"outputs": [],
|
167 |
+
"source": [
|
168 |
+
"import timm \n",
|
169 |
+
"# couldnt figure out how to load pretrained model from shared folder rather than home directory using timm...\n",
|
170 |
+
"eva02_model = timm.create_model(\"eva02_enormous_patch14_clip_224.laion2b\", pretrained=True)\n",
|
171 |
+
"# eva02_model.head_drop = nn.Identity()\n",
|
172 |
+
"# eva02_model.head = nn.Identity()"
|
173 |
+
]
|
174 |
+
},
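{
"cell_type": "markdown",
"id": "editor-note-cache",
"metadata": {},
"source": [
"A possible workaround for the cache-location comment above (an untested sketch, not a confirmed fix): timm fetches these EVA02 weights through `huggingface_hub`, whose cache directory can be redirected with the `HF_HOME` environment variable, provided it is set before the hub library is imported. The path below is the shared cache already used for the DETR cells later in this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "editor-sketch-cache",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"# hedged sketch: redirect the huggingface_hub cache to the shared folder; this must\n",
"# run before huggingface_hub/timm are imported for the variable to take effect\n",
"os.environ[\"HF_HOME\"] = \"/fsx/proj-fmri/shared/cache\"\n",
"\n",
"import timm\n",
"eva02_model = timm.create_model(\"eva02_enormous_patch14_clip_224.laion2b\", pretrained=True)"
]
},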
|
175 |
+
{
|
176 |
+
"cell_type": "code",
|
177 |
+
"execution_count": 30,
|
178 |
+
"id": "97e3ea29-ae6b-4bd2-b3d7-17839098a6e4",
|
179 |
+
"metadata": {},
|
180 |
+
"outputs": [
|
181 |
+
{
|
182 |
+
"data": {
|
183 |
+
"text/plain": [
|
184 |
+
"torch.Size([2, 1024])"
|
185 |
+
]
|
186 |
+
},
|
187 |
+
"execution_count": 30,
|
188 |
+
"metadata": {},
|
189 |
+
"output_type": "execute_result"
|
190 |
+
}
|
191 |
+
],
|
192 |
+
"source": [
|
193 |
+
"eva02_model(torch.randn((2,3,224,224))).shape"
|
194 |
+
]
|
195 |
+
},
|
196 |
+
{
|
197 |
+
"cell_type": "code",
|
198 |
+
"execution_count": 31,
|
199 |
+
"id": "069b76f0-029f-42b1-85f5-a492ee1cc5d1",
|
200 |
+
"metadata": {},
|
201 |
+
"outputs": [
|
202 |
+
{
|
203 |
+
"name": "stdout",
|
204 |
+
"output_type": "stream",
|
205 |
+
"text": [
|
206 |
+
"torch.Size([2, 256, 1024])\n"
|
207 |
+
]
|
208 |
+
}
|
209 |
+
],
|
210 |
+
"source": [
|
211 |
+
"image = torch.randn((2,3,224,224))\n",
|
212 |
+
"\n",
|
213 |
+
"input = eva02_model.patch_embed(image)\n",
|
214 |
+
"input = eva02_model.pos_drop(input)\n",
|
215 |
+
"for block in eva02_model.blocks: input = block(input)\n",
|
216 |
+
"input = eva02_model.norm(input)\n",
|
217 |
+
"input = eva02_model.fc_norm(input)\n",
|
218 |
+
"input = eva02_model.head_drop(input)\n",
|
219 |
+
"input = eva02_model.head(input)\n",
|
220 |
+
"\n",
|
221 |
+
"print(input.shape)"
|
222 |
+
]
|
223 |
+
},
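{
"cell_type": "markdown",
"id": "editor-note-eva-forward",
"metadata": {},
"source": [
"Caveat on the block-by-block loop above: if timm's `eva.py` is read correctly here, `Eva.forward_features` also runs `self._pos_embed(x)` (class token / absolute position embedding) and hands the rotary embedding to every block via `blk(x, rope=rot_pos_embed)`. The manual loop skips both, so its output need not match `forward_features`. A hedged sketch that stays closer to the real path (verify against the installed timm version):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "editor-sketch-eva-forward",
"metadata": {},
"outputs": [],
"source": [
"# hedged sketch of Eva.forward_features, based on a reading of timm's eva.py --\n",
"# keeps the position/rotary embeddings that the manual loop above drops\n",
"x = eva02_model.patch_embed(image)\n",
"x, rot_pos_embed = eva02_model._pos_embed(x)  # adds cls/abs pos embeds, returns rope\n",
"for block in eva02_model.blocks:\n",
"    x = block(x, rope=rot_pos_embed)\n",
"x = eva02_model.norm(x)\n",
"print(x.shape)  # should match eva02_model.forward_features(image).shape"
]
},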
|
224 |
+
{
|
225 |
+
"cell_type": "code",
|
226 |
+
"execution_count": 32,
|
227 |
+
"id": "90e4e8e7-3dd1-43b0-a305-066a6ec13c2e",
|
228 |
+
"metadata": {},
|
229 |
+
"outputs": [
|
230 |
+
{
|
231 |
+
"name": "stdout",
|
232 |
+
"output_type": "stream",
|
233 |
+
"text": [
|
234 |
+
"Help on Eva in module timm.models.eva object:\n",
|
235 |
+
"\n",
|
236 |
+
"class Eva(torch.nn.modules.module.Module)\n",
|
237 |
+
" | Eva(img_size: Union[int, Tuple[int, int]] = 224, patch_size: Union[int, Tuple[int, int]] = 16, in_chans: int = 3, num_classes: int = 1000, global_pool: str = 'avg', embed_dim: int = 768, depth: int = 12, num_heads: int = 12, qkv_bias: bool = True, qkv_fused: bool = True, mlp_ratio: float = 4.0, swiglu_mlp: bool = False, scale_mlp: bool = False, scale_attn_inner: bool = False, drop_rate: float = 0.0, pos_drop_rate: float = 0.0, patch_drop_rate: float = 0.0, proj_drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, norm_layer: Callable = <class 'timm.layers.norm.LayerNorm'>, init_values: Optional[float] = None, class_token: bool = True, use_abs_pos_emb: bool = True, use_rot_pos_emb: bool = False, use_post_norm: bool = False, ref_feat_shape: Union[int, Tuple[int, int], NoneType] = None, head_init_scale: float = 0.001)\n",
|
238 |
+
" | \n",
|
239 |
+
" | Eva Vision Transformer w/ Abs & Rotary Pos Embed\n",
|
240 |
+
" | \n",
|
241 |
+
" | This class implements the EVA and EVA02 models that were based on the BEiT ViT variant\n",
|
242 |
+
" | * EVA - abs pos embed, global avg pool\n",
|
243 |
+
" | * EVA02 - abs + rope pos embed, global avg pool, SwiGLU, scale Norm in MLP (ala normformer)\n",
|
244 |
+
" | \n",
|
245 |
+
" | Method resolution order:\n",
|
246 |
+
" | Eva\n",
|
247 |
+
" | torch.nn.modules.module.Module\n",
|
248 |
+
" | builtins.object\n",
|
249 |
+
" | \n",
|
250 |
+
" | Methods defined here:\n",
|
251 |
+
" | \n",
|
252 |
+
" | __init__(self, img_size: Union[int, Tuple[int, int]] = 224, patch_size: Union[int, Tuple[int, int]] = 16, in_chans: int = 3, num_classes: int = 1000, global_pool: str = 'avg', embed_dim: int = 768, depth: int = 12, num_heads: int = 12, qkv_bias: bool = True, qkv_fused: bool = True, mlp_ratio: float = 4.0, swiglu_mlp: bool = False, scale_mlp: bool = False, scale_attn_inner: bool = False, drop_rate: float = 0.0, pos_drop_rate: float = 0.0, patch_drop_rate: float = 0.0, proj_drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, norm_layer: Callable = <class 'timm.layers.norm.LayerNorm'>, init_values: Optional[float] = None, class_token: bool = True, use_abs_pos_emb: bool = True, use_rot_pos_emb: bool = False, use_post_norm: bool = False, ref_feat_shape: Union[int, Tuple[int, int], NoneType] = None, head_init_scale: float = 0.001)\n",
|
253 |
+
" | Args:\n",
|
254 |
+
" | img_size:\n",
|
255 |
+
" | patch_size:\n",
|
256 |
+
" | in_chans:\n",
|
257 |
+
" | num_classes:\n",
|
258 |
+
" | global_pool:\n",
|
259 |
+
" | embed_dim:\n",
|
260 |
+
" | depth:\n",
|
261 |
+
" | num_heads:\n",
|
262 |
+
" | qkv_bias:\n",
|
263 |
+
" | qkv_fused:\n",
|
264 |
+
" | mlp_ratio:\n",
|
265 |
+
" | swiglu_mlp:\n",
|
266 |
+
" | scale_mlp:\n",
|
267 |
+
" | scale_attn_inner:\n",
|
268 |
+
" | drop_rate:\n",
|
269 |
+
" | pos_drop_rate:\n",
|
270 |
+
" | proj_drop_rate:\n",
|
271 |
+
" | attn_drop_rate:\n",
|
272 |
+
" | drop_path_rate:\n",
|
273 |
+
" | norm_layer:\n",
|
274 |
+
" | init_values:\n",
|
275 |
+
" | class_token:\n",
|
276 |
+
" | use_abs_pos_emb:\n",
|
277 |
+
" | use_rot_pos_emb:\n",
|
278 |
+
" | use_post_norm:\n",
|
279 |
+
" | ref_feat_shape:\n",
|
280 |
+
" | head_init_scale:\n",
|
281 |
+
" | \n",
|
282 |
+
" | fix_init_weight(self)\n",
|
283 |
+
" | \n",
|
284 |
+
" | forward(self, x)\n",
|
285 |
+
" | Defines the computation performed at every call.\n",
|
286 |
+
" | \n",
|
287 |
+
" | Should be overridden by all subclasses.\n",
|
288 |
+
" | \n",
|
289 |
+
" | .. note::\n",
|
290 |
+
" | Although the recipe for forward pass needs to be defined within\n",
|
291 |
+
" | this function, one should call the :class:`Module` instance afterwards\n",
|
292 |
+
" | instead of this since the former takes care of running the\n",
|
293 |
+
" | registered hooks while the latter silently ignores them.\n",
|
294 |
+
" | \n",
|
295 |
+
" | forward_features(self, x)\n",
|
296 |
+
" | \n",
|
297 |
+
" | forward_head(self, x, pre_logits: bool = False)\n",
|
298 |
+
" | \n",
|
299 |
+
" | get_classifier(self)\n",
|
300 |
+
" | \n",
|
301 |
+
" | group_matcher(self, coarse=False)\n",
|
302 |
+
" | \n",
|
303 |
+
" | no_weight_decay(self)\n",
|
304 |
+
" | \n",
|
305 |
+
" | reset_classifier(self, num_classes, global_pool=None)\n",
|
306 |
+
" | \n",
|
307 |
+
" | set_grad_checkpointing(self, enable=True)\n",
|
308 |
+
" | \n",
|
309 |
+
" | ----------------------------------------------------------------------\n",
|
310 |
+
" | Data and other attributes defined here:\n",
|
311 |
+
" | \n",
|
312 |
+
" | __annotations__ = {}\n",
|
313 |
+
" | \n",
|
314 |
+
" | ----------------------------------------------------------------------\n",
|
315 |
+
" | Methods inherited from torch.nn.modules.module.Module:\n",
|
316 |
+
" | \n",
|
317 |
+
" | __call__ = _call_impl(self, *args, **kwargs)\n",
|
318 |
+
" | \n",
|
319 |
+
" | __delattr__(self, name)\n",
|
320 |
+
" | Implement delattr(self, name).\n",
|
321 |
+
" | \n",
|
322 |
+
" | __dir__(self)\n",
|
323 |
+
" | Default dir() implementation.\n",
|
324 |
+
" | \n",
|
325 |
+
" | __getattr__(self, name: str) -> Union[torch.Tensor, ForwardRef('Module')]\n",
|
326 |
+
" | \n",
|
327 |
+
" | __repr__(self)\n",
|
328 |
+
" | Return repr(self).\n",
|
329 |
+
" | \n",
|
330 |
+
" | __setattr__(self, name: str, value: Union[torch.Tensor, ForwardRef('Module')]) -> None\n",
|
331 |
+
" | Implement setattr(self, name, value).\n",
|
332 |
+
" | \n",
|
333 |
+
" | __setstate__(self, state)\n",
|
334 |
+
" | \n",
|
335 |
+
" | add_module(self, name: str, module: Optional[ForwardRef('Module')]) -> None\n",
|
336 |
+
" | Adds a child module to the current module.\n",
|
337 |
+
" | \n",
|
338 |
+
" | The module can be accessed as an attribute using the given name.\n",
|
339 |
+
" | \n",
|
340 |
+
" | Args:\n",
|
341 |
+
" | name (str): name of the child module. The child module can be\n",
|
342 |
+
" | accessed from this module using the given name\n",
|
343 |
+
" | module (Module): child module to be added to the module.\n",
|
344 |
+
" | \n",
|
345 |
+
" | apply(self: ~T, fn: Callable[[ForwardRef('Module')], NoneType]) -> ~T\n",
|
346 |
+
" | Applies ``fn`` recursively to every submodule (as returned by ``.children()``)\n",
|
347 |
+
" | as well as self. Typical use includes initializing the parameters of a model\n",
|
348 |
+
" | (see also :ref:`nn-init-doc`).\n",
|
349 |
+
" | \n",
|
350 |
+
" | Args:\n",
|
351 |
+
" | fn (:class:`Module` -> None): function to be applied to each submodule\n",
|
352 |
+
" | \n",
|
353 |
+
" | Returns:\n",
|
354 |
+
" | Module: self\n",
|
355 |
+
" | \n",
|
356 |
+
" | Example::\n",
|
357 |
+
" | \n",
|
358 |
+
" | >>> @torch.no_grad()\n",
|
359 |
+
" | >>> def init_weights(m):\n",
|
360 |
+
" | >>> print(m)\n",
|
361 |
+
" | >>> if type(m) == nn.Linear:\n",
|
362 |
+
" | >>> m.weight.fill_(1.0)\n",
|
363 |
+
" | >>> print(m.weight)\n",
|
364 |
+
" | >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))\n",
|
365 |
+
" | >>> net.apply(init_weights)\n",
|
366 |
+
" | Linear(in_features=2, out_features=2, bias=True)\n",
|
367 |
+
" | Parameter containing:\n",
|
368 |
+
" | tensor([[1., 1.],\n",
|
369 |
+
" | [1., 1.]], requires_grad=True)\n",
|
370 |
+
" | Linear(in_features=2, out_features=2, bias=True)\n",
|
371 |
+
" | Parameter containing:\n",
|
372 |
+
" | tensor([[1., 1.],\n",
|
373 |
+
" | [1., 1.]], requires_grad=True)\n",
|
374 |
+
" | Sequential(\n",
|
375 |
+
" | (0): Linear(in_features=2, out_features=2, bias=True)\n",
|
376 |
+
" | (1): Linear(in_features=2, out_features=2, bias=True)\n",
|
377 |
+
" | )\n",
|
378 |
+
" | \n",
|
379 |
+
" | bfloat16(self: ~T) -> ~T\n",
|
380 |
+
" | Casts all floating point parameters and buffers to ``bfloat16`` datatype.\n",
|
381 |
+
" | \n",
|
382 |
+
" | .. note::\n",
|
383 |
+
" | This method modifies the module in-place.\n",
|
384 |
+
" | \n",
|
385 |
+
" | Returns:\n",
|
386 |
+
" | Module: self\n",
|
387 |
+
" | \n",
|
388 |
+
" | buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]\n",
|
389 |
+
" | Returns an iterator over module buffers.\n",
|
390 |
+
" | \n",
|
391 |
+
" | Args:\n",
|
392 |
+
" | recurse (bool): if True, then yields buffers of this module\n",
|
393 |
+
" | and all submodules. Otherwise, yields only buffers that\n",
|
394 |
+
" | are direct members of this module.\n",
|
395 |
+
" | \n",
|
396 |
+
" | Yields:\n",
|
397 |
+
" | torch.Tensor: module buffer\n",
|
398 |
+
" | \n",
|
399 |
+
" | Example::\n",
|
400 |
+
" | \n",
|
401 |
+
" | >>> # xdoctest: +SKIP(\"undefined vars\")\n",
|
402 |
+
" | >>> for buf in model.buffers():\n",
|
403 |
+
" | >>> print(type(buf), buf.size())\n",
|
404 |
+
" | <class 'torch.Tensor'> (20L,)\n",
|
405 |
+
" | <class 'torch.Tensor'> (20L, 1L, 5L, 5L)\n",
|
406 |
+
" | \n",
|
407 |
+
" | children(self) -> Iterator[ForwardRef('Module')]\n",
|
408 |
+
" | Returns an iterator over immediate children modules.\n",
|
409 |
+
" | \n",
|
410 |
+
" | Yields:\n",
|
411 |
+
" | Module: a child module\n",
|
412 |
+
" | \n",
|
413 |
+
" | cpu(self: ~T) -> ~T\n",
|
414 |
+
" | Moves all model parameters and buffers to the CPU.\n",
|
415 |
+
" | \n",
|
416 |
+
" | .. note::\n",
|
417 |
+
" | This method modifies the module in-place.\n",
|
418 |
+
" | \n",
|
419 |
+
" | Returns:\n",
|
420 |
+
" | Module: self\n",
|
421 |
+
" | \n",
|
422 |
+
" | cuda(self: ~T, device: Union[int, torch.device, NoneType] = None) -> ~T\n",
|
423 |
+
" | Moves all model parameters and buffers to the GPU.\n",
|
424 |
+
" | \n",
|
425 |
+
" | This also makes associated parameters and buffers different objects. So\n",
|
426 |
+
" | it should be called before constructing optimizer if the module will\n",
|
427 |
+
" | live on GPU while being optimized.\n",
|
428 |
+
" | \n",
|
429 |
+
" | .. note::\n",
|
430 |
+
" | This method modifies the module in-place.\n",
|
431 |
+
" | \n",
|
432 |
+
" | Args:\n",
|
433 |
+
" | device (int, optional): if specified, all parameters will be\n",
|
434 |
+
" | copied to that device\n",
|
435 |
+
" | \n",
|
436 |
+
" | Returns:\n",
|
437 |
+
" | Module: self\n",
|
438 |
+
" | \n",
|
439 |
+
" | double(self: ~T) -> ~T\n",
|
440 |
+
" | Casts all floating point parameters and buffers to ``double`` datatype.\n",
|
441 |
+
" | \n",
|
442 |
+
" | .. note::\n",
|
443 |
+
" | This method modifies the module in-place.\n",
|
444 |
+
" | \n",
|
445 |
+
" | Returns:\n",
|
446 |
+
" | Module: self\n",
|
447 |
+
" | \n",
|
448 |
+
" | eval(self: ~T) -> ~T\n",
|
449 |
+
" | Sets the module in evaluation mode.\n",
|
450 |
+
" | \n",
|
451 |
+
" | This has any effect only on certain modules. See documentations of\n",
|
452 |
+
" | particular modules for details of their behaviors in training/evaluation\n",
|
453 |
+
" | mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,\n",
|
454 |
+
" | etc.\n",
|
455 |
+
" | \n",
|
456 |
+
" | This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.\n",
|
457 |
+
" | \n",
|
458 |
+
" | See :ref:`locally-disable-grad-doc` for a comparison between\n",
|
459 |
+
" | `.eval()` and several similar mechanisms that may be confused with it.\n",
|
460 |
+
" | \n",
|
461 |
+
" | Returns:\n",
|
462 |
+
" | Module: self\n",
|
463 |
+
" | \n",
|
464 |
+
" | extra_repr(self) -> str\n",
|
465 |
+
" | Set the extra representation of the module\n",
|
466 |
+
" | \n",
|
467 |
+
" | To print customized extra information, you should re-implement\n",
|
468 |
+
" | this method in your own modules. Both single-line and multi-line\n",
|
469 |
+
" | strings are acceptable.\n",
|
470 |
+
" | \n",
|
471 |
+
" | float(self: ~T) -> ~T\n",
|
472 |
+
" | Casts all floating point parameters and buffers to ``float`` datatype.\n",
|
473 |
+
" | \n",
|
474 |
+
" | .. note::\n",
|
475 |
+
" | This method modifies the module in-place.\n",
|
476 |
+
" | \n",
|
477 |
+
" | Returns:\n",
|
478 |
+
" | Module: self\n",
|
479 |
+
" | \n",
|
480 |
+
" | get_buffer(self, target: str) -> 'Tensor'\n",
|
481 |
+
" | Returns the buffer given by ``target`` if it exists,\n",
|
482 |
+
" | otherwise throws an error.\n",
|
483 |
+
" | \n",
|
484 |
+
" | See the docstring for ``get_submodule`` for a more detailed\n",
|
485 |
+
" | explanation of this method's functionality as well as how to\n",
|
486 |
+
" | correctly specify ``target``.\n",
|
487 |
+
" | \n",
|
488 |
+
" | Args:\n",
|
489 |
+
" | target: The fully-qualified string name of the buffer\n",
|
490 |
+
" | to look for. (See ``get_submodule`` for how to specify a\n",
|
491 |
+
" | fully-qualified string.)\n",
|
492 |
+
" | \n",
|
493 |
+
" | Returns:\n",
|
494 |
+
" | torch.Tensor: The buffer referenced by ``target``\n",
|
495 |
+
" | \n",
|
496 |
+
" | Raises:\n",
|
497 |
+
" | AttributeError: If the target string references an invalid\n",
|
498 |
+
" | path or resolves to something that is not a\n",
|
499 |
+
" | buffer\n",
|
500 |
+
" | \n",
|
501 |
+
" | get_extra_state(self) -> Any\n",
|
502 |
+
" | Returns any extra state to include in the module's state_dict.\n",
|
503 |
+
" | Implement this and a corresponding :func:`set_extra_state` for your module\n",
|
504 |
+
" | if you need to store extra state. This function is called when building the\n",
|
505 |
+
" | module's `state_dict()`.\n",
|
506 |
+
" | \n",
|
507 |
+
" | Note that extra state should be picklable to ensure working serialization\n",
|
508 |
+
" | of the state_dict. We only provide provide backwards compatibility guarantees\n",
|
509 |
+
" | for serializing Tensors; other objects may break backwards compatibility if\n",
|
510 |
+
" | their serialized pickled form changes.\n",
|
511 |
+
" | \n",
|
512 |
+
" | Returns:\n",
|
513 |
+
" | object: Any extra state to store in the module's state_dict\n",
|
514 |
+
" | \n",
|
515 |
+
" | get_parameter(self, target: str) -> 'Parameter'\n",
|
516 |
+
" | Returns the parameter given by ``target`` if it exists,\n",
|
517 |
+
" | otherwise throws an error.\n",
|
518 |
+
" | \n",
|
519 |
+
" | See the docstring for ``get_submodule`` for a more detailed\n",
|
520 |
+
" | explanation of this method's functionality as well as how to\n",
|
521 |
+
" | correctly specify ``target``.\n",
|
522 |
+
" | \n",
|
523 |
+
" | Args:\n",
|
524 |
+
" | target: The fully-qualified string name of the Parameter\n",
|
525 |
+
" | to look for. (See ``get_submodule`` for how to specify a\n",
|
526 |
+
" | fully-qualified string.)\n",
|
527 |
+
" | \n",
|
528 |
+
" | Returns:\n",
|
529 |
+
" | torch.nn.Parameter: The Parameter referenced by ``target``\n",
|
530 |
+
" | \n",
|
531 |
+
" | Raises:\n",
|
532 |
+
" | AttributeError: If the target string references an invalid\n",
|
533 |
+
" | path or resolves to something that is not an\n",
|
534 |
+
" | ``nn.Parameter``\n",
|
535 |
+
" | \n",
|
536 |
+
" | get_submodule(self, target: str) -> 'Module'\n",
|
537 |
+
" | Returns the submodule given by ``target`` if it exists,\n",
|
538 |
+
" | otherwise throws an error.\n",
|
539 |
+
" | \n",
|
540 |
+
" | For example, let's say you have an ``nn.Module`` ``A`` that\n",
|
541 |
+
" | looks like this:\n",
|
542 |
+
" | \n",
|
543 |
+
" | .. code-block:: text\n",
|
544 |
+
" | \n",
|
545 |
+
" | A(\n",
|
546 |
+
" | (net_b): Module(\n",
|
547 |
+
" | (net_c): Module(\n",
|
548 |
+
" | (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))\n",
|
549 |
+
" | )\n",
|
550 |
+
" | (linear): Linear(in_features=100, out_features=200, bias=True)\n",
|
551 |
+
" | )\n",
|
552 |
+
" | )\n",
|
553 |
+
" | \n",
|
554 |
+
" | (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested\n",
|
555 |
+
" | submodule ``net_b``, which itself has two submodules ``net_c``\n",
|
556 |
+
" | and ``linear``. ``net_c`` then has a submodule ``conv``.)\n",
|
557 |
+
" | \n",
|
558 |
+
" | To check whether or not we have the ``linear`` submodule, we\n",
|
559 |
+
" | would call ``get_submodule(\"net_b.linear\")``. To check whether\n",
|
560 |
+
" | we have the ``conv`` submodule, we would call\n",
|
561 |
+
" | ``get_submodule(\"net_b.net_c.conv\")``.\n",
|
562 |
+
" | \n",
|
563 |
+
" | The runtime of ``get_submodule`` is bounded by the degree\n",
|
564 |
+
" | of module nesting in ``target``. A query against\n",
|
565 |
+
" | ``named_modules`` achieves the same result, but it is O(N) in\n",
|
566 |
+
" | the number of transitive modules. So, for a simple check to see\n",
|
567 |
+
" | if some submodule exists, ``get_submodule`` should always be\n",
|
568 |
+
" | used.\n",
|
569 |
+
" | \n",
|
570 |
+
" | Args:\n",
|
571 |
+
" | target: The fully-qualified string name of the submodule\n",
|
572 |
+
" | to look for. (See above example for how to specify a\n",
|
573 |
+
" | fully-qualified string.)\n",
|
574 |
+
" | \n",
|
575 |
+
" | Returns:\n",
|
576 |
+
" | torch.nn.Module: The submodule referenced by ``target``\n",
|
577 |
+
" | \n",
|
578 |
+
" | Raises:\n",
|
579 |
+
" | AttributeError: If the target string references an invalid\n",
|
580 |
+
" | path or resolves to something that is not an\n",
|
581 |
+
" | ``nn.Module``\n",
|
582 |
+
" | \n",
|
583 |
+
" | half(self: ~T) -> ~T\n",
|
584 |
+
" | Casts all floating point parameters and buffers to ``half`` datatype.\n",
|
585 |
+
" | \n",
|
586 |
+
" | .. note::\n",
|
587 |
+
" | This method modifies the module in-place.\n",
|
588 |
+
" | \n",
|
589 |
+
" | Returns:\n",
|
590 |
+
" | Module: self\n",
|
591 |
+
" | \n",
|
592 |
+
" | ipu(self: ~T, device: Union[int, torch.device, NoneType] = None) -> ~T\n",
|
593 |
+
" | Moves all model parameters and buffers to the IPU.\n",
|
594 |
+
" | \n",
|
595 |
+
" | This also makes associated parameters and buffers different objects. So\n",
|
596 |
+
" | it should be called before constructing optimizer if the module will\n",
|
597 |
+
" | live on IPU while being optimized.\n",
|
598 |
+
" | \n",
|
599 |
+
" | .. note::\n",
|
600 |
+
" | This method modifies the module in-place.\n",
|
601 |
+
" | \n",
|
602 |
+
" | Arguments:\n",
|
603 |
+
" | device (int, optional): if specified, all parameters will be\n",
|
604 |
+
" | copied to that device\n",
|
605 |
+
" | \n",
|
606 |
+
" | Returns:\n",
|
607 |
+
" | Module: self\n",
|
608 |
+
" | \n",
|
609 |
+
" | load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True)\n",
|
610 |
+
" | Copies parameters and buffers from :attr:`state_dict` into\n",
|
611 |
+
" | this module and its descendants. If :attr:`strict` is ``True``, then\n",
|
612 |
+
" | the keys of :attr:`state_dict` must exactly match the keys returned\n",
|
613 |
+
" | by this module's :meth:`~torch.nn.Module.state_dict` function.\n",
|
614 |
+
" | \n",
|
615 |
+
" | Args:\n",
|
616 |
+
" | state_dict (dict): a dict containing parameters and\n",
|
617 |
+
" | persistent buffers.\n",
|
618 |
+
" | strict (bool, optional): whether to strictly enforce that the keys\n",
|
619 |
+
" | in :attr:`state_dict` match the keys returned by this module's\n",
|
620 |
+
" | :meth:`~torch.nn.Module.state_dict` function. Default: ``True``\n",
|
621 |
+
" | \n",
|
622 |
+
" | Returns:\n",
|
623 |
+
" | ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:\n",
|
624 |
+
" | * **missing_keys** is a list of str containing the missing keys\n",
|
625 |
+
" | * **unexpected_keys** is a list of str containing the unexpected keys\n",
|
626 |
+
" | \n",
|
627 |
+
" | Note:\n",
|
628 |
+
" | If a parameter or buffer is registered as ``None`` and its corresponding key\n",
|
629 |
+
" | exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a\n",
|
630 |
+
" | ``RuntimeError``.\n",
|
631 |
+
" | \n",
|
632 |
+
" | modules(self) -> Iterator[ForwardRef('Module')]\n",
|
633 |
+
" | Returns an iterator over all modules in the network.\n",
|
634 |
+
" | \n",
|
635 |
+
" | Yields:\n",
|
636 |
+
" | Module: a module in the network\n",
|
637 |
+
" | \n",
|
638 |
+
" | Note:\n",
|
639 |
+
" | Duplicate modules are returned only once. In the following\n",
|
640 |
+
" | example, ``l`` will be returned only once.\n",
|
641 |
+
" | \n",
|
642 |
+
" | Example::\n",
|
643 |
+
" | \n",
|
644 |
+
" | >>> l = nn.Linear(2, 2)\n",
|
645 |
+
" | >>> net = nn.Sequential(l, l)\n",
|
646 |
+
" | >>> for idx, m in enumerate(net.modules()):\n",
|
647 |
+
" | ... print(idx, '->', m)\n",
|
648 |
+
" | \n",
|
649 |
+
" | 0 -> Sequential(\n",
|
650 |
+
" | (0): Linear(in_features=2, out_features=2, bias=True)\n",
|
651 |
+
" | (1): Linear(in_features=2, out_features=2, bias=True)\n",
|
652 |
+
" | )\n",
|
653 |
+
" | 1 -> Linear(in_features=2, out_features=2, bias=True)\n",
|
654 |
+
" | \n",
|
655 |
+
" | named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.Tensor]]\n",
|
656 |
+
" | Returns an iterator over module buffers, yielding both the\n",
|
657 |
+
" | name of the buffer as well as the buffer itself.\n",
|
658 |
+
" | \n",
|
659 |
+
" | Args:\n",
|
660 |
+
" | prefix (str): prefix to prepend to all buffer names.\n",
|
661 |
+
" | recurse (bool, optional): if True, then yields buffers of this module\n",
|
662 |
+
" | and all submodules. Otherwise, yields only buffers that\n",
|
663 |
+
" | are direct members of this module. Defaults to True.\n",
|
664 |
+
" | remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.\n",
|
665 |
+
" | \n",
|
666 |
+
" | Yields:\n",
|
667 |
+
" | (str, torch.Tensor): Tuple containing the name and buffer\n",
|
668 |
+
" | \n",
|
669 |
+
" | Example::\n",
|
670 |
+
" | \n",
|
671 |
+
" | >>> # xdoctest: +SKIP(\"undefined vars\")\n",
|
672 |
+
" | >>> for name, buf in self.named_buffers():\n",
|
673 |
+
" | >>> if name in ['running_var']:\n",
|
674 |
+
" | >>> print(buf.size())\n",
|
675 |
+
" | \n",
|
676 |
+
" | named_children(self) -> Iterator[Tuple[str, ForwardRef('Module')]]\n",
|
677 |
+
" | Returns an iterator over immediate children modules, yielding both\n",
|
678 |
+
" | the name of the module as well as the module itself.\n",
|
679 |
+
" | \n",
|
680 |
+
" | Yields:\n",
|
681 |
+
" | (str, Module): Tuple containing a name and child module\n",
|
682 |
+
" | \n",
|
683 |
+
" | Example::\n",
|
684 |
+
" | \n",
|
685 |
+
" | >>> # xdoctest: +SKIP(\"undefined vars\")\n",
|
686 |
+
" | >>> for name, module in model.named_children():\n",
|
687 |
+
" | >>> if name in ['conv4', 'conv5']:\n",
|
688 |
+
" | >>> print(module)\n",
|
689 |
+
" | \n",
|
690 |
+
" | named_modules(self, memo: Optional[Set[ForwardRef('Module')]] = None, prefix: str = '', remove_duplicate: bool = True)\n",
|
691 |
+
" | Returns an iterator over all modules in the network, yielding\n",
|
692 |
+
" | both the name of the module as well as the module itself.\n",
|
693 |
+
" | \n",
|
694 |
+
" | Args:\n",
|
695 |
+
" | memo: a memo to store the set of modules already added to the result\n",
|
696 |
+
" | prefix: a prefix that will be added to the name of the module\n",
|
697 |
+
" | remove_duplicate: whether to remove the duplicated module instances in the result\n",
|
698 |
+
" | or not\n",
|
699 |
+
" | \n",
|
700 |
+
" | Yields:\n",
|
701 |
+
" | (str, Module): Tuple of name and module\n",
|
702 |
+
" | \n",
|
703 |
+
" | Note:\n",
|
704 |
+
" | Duplicate modules are returned only once. In the following\n",
|
705 |
+
" | example, ``l`` will be returned only once.\n",
|
706 |
+
" | \n",
|
707 |
+
" | Example::\n",
|
708 |
+
" | \n",
|
709 |
+
" | >>> l = nn.Linear(2, 2)\n",
|
710 |
+
" | >>> net = nn.Sequential(l, l)\n",
|
711 |
+
" | >>> for idx, m in enumerate(net.named_modules()):\n",
|
712 |
+
" | ... print(idx, '->', m)\n",
|
713 |
+
" | \n",
|
714 |
+
" | 0 -> ('', Sequential(\n",
|
715 |
+
" | (0): Linear(in_features=2, out_features=2, bias=True)\n",
|
716 |
+
" | (1): Linear(in_features=2, out_features=2, bias=True)\n",
|
717 |
+
" | ))\n",
|
718 |
+
" | 1 -> ('0', Linear(in_features=2, out_features=2, bias=True))\n",
|
719 |
+
" | \n",
|
720 |
+
" | named_parameters(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.nn.parameter.Parameter]]\n",
|
721 |
+
" | Returns an iterator over module parameters, yielding both the\n",
|
722 |
+
" | name of the parameter as well as the parameter itself.\n",
|
723 |
+
" | \n",
|
724 |
+
" | Args:\n",
|
725 |
+
" | prefix (str): prefix to prepend to all parameter names.\n",
|
726 |
+
" | recurse (bool): if True, then yields parameters of this module\n",
|
727 |
+
" | and all submodules. Otherwise, yields only parameters that\n",
|
728 |
+
" | are direct members of this module.\n",
|
729 |
+
" | remove_duplicate (bool, optional): whether to remove the duplicated\n",
|
730 |
+
" | parameters in the result. Defaults to True.\n",
|
731 |
+
" | \n",
|
732 |
+
" | Yields:\n",
|
733 |
+
" | (str, Parameter): Tuple containing the name and parameter\n",
|
734 |
+
" | \n",
|
735 |
+
" | Example::\n",
|
736 |
+
" | \n",
|
737 |
+
" | >>> # xdoctest: +SKIP(\"undefined vars\")\n",
|
738 |
+
" | >>> for name, param in self.named_parameters():\n",
|
739 |
+
" | >>> if name in ['bias']:\n",
|
740 |
+
" | >>> print(param.size())\n",
|
741 |
+
" | \n",
|
742 |
+
" | parameters(self, recurse: bool = True) -> Iterator[torch.nn.parameter.Parameter]\n",
|
743 |
+
" | Returns an iterator over module parameters.\n",
|
744 |
+
" | \n",
|
745 |
+
" | This is typically passed to an optimizer.\n",
|
746 |
+
" | \n",
|
747 |
+
" | Args:\n",
|
748 |
+
" | recurse (bool): if True, then yields parameters of this module\n",
|
749 |
+
" | and all submodules. Otherwise, yields only parameters that\n",
|
750 |
+
" | are direct members of this module.\n",
|
751 |
+
" | \n",
|
752 |
+
" | Yields:\n",
|
753 |
+
" | Parameter: module parameter\n",
|
754 |
+
" | \n",
|
755 |
+
" | Example::\n",
|
756 |
+
" | \n",
|
757 |
+
" | >>> # xdoctest: +SKIP(\"undefined vars\")\n",
|
758 |
+
" | >>> for param in model.parameters():\n",
|
759 |
+
" | >>> print(type(param), param.size())\n",
|
760 |
+
" | <class 'torch.Tensor'> (20L,)\n",
|
761 |
+
" | <class 'torch.Tensor'> (20L, 1L, 5L, 5L)\n",
|
762 |
+
" | \n",
|
763 |
+
" | register_backward_hook(self, hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]]) -> torch.utils.hooks.RemovableHandle\n",
|
764 |
+
" | Registers a backward hook on the module.\n",
|
765 |
+
" | \n",
|
766 |
+
" | This function is deprecated in favor of :meth:`~torch.nn.Module.register_full_backward_hook` and\n",
|
767 |
+
" | the behavior of this function will change in future versions.\n",
|
768 |
+
" | \n",
|
769 |
+
" | Returns:\n",
|
770 |
+
" | :class:`torch.utils.hooks.RemovableHandle`:\n",
|
771 |
+
" | a handle that can be used to remove the added hook by calling\n",
|
772 |
+
" | ``handle.remove()``\n",
|
773 |
+
" | \n",
|
774 |
+
" | register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None\n",
|
775 |
+
" | Adds a buffer to the module.\n",
|
776 |
+
" | \n",
|
777 |
+
" | This is typically used to register a buffer that should not to be\n",
|
778 |
+
" | considered a model parameter. For example, BatchNorm's ``running_mean``\n",
|
779 |
+
" | is not a parameter, but is part of the module's state. Buffers, by\n",
|
780 |
+
" | default, are persistent and will be saved alongside parameters. This\n",
|
781 |
+
" | behavior can be changed by setting :attr:`persistent` to ``False``. The\n",
|
782 |
+
" | only difference between a persistent buffer and a non-persistent buffer\n",
|
783 |
+
" | is that the latter will not be a part of this module's\n",
|
784 |
+
" | :attr:`state_dict`.\n",
|
785 |
+
" | \n",
|
786 |
+
" | Buffers can be accessed as attributes using given names.\n",
|
787 |
+
" | \n",
|
788 |
+
" | Args:\n",
|
789 |
+
" | name (str): name of the buffer. The buffer can be accessed\n",
|
790 |
+
" | from this module using the given name\n",
|
791 |
+
" | tensor (Tensor or None): buffer to be registered. If ``None``, then operations\n",
|
792 |
+
" | that run on buffers, such as :attr:`cuda`, are ignored. If ``None``,\n",
|
793 |
+
" | the buffer is **not** included in the module's :attr:`state_dict`.\n",
|
794 |
+
" | persistent (bool): whether the buffer is part of this module's\n",
|
795 |
+
" | :attr:`state_dict`.\n",
|
796 |
+
" | \n",
|
797 |
+
" | Example::\n",
|
798 |
+
" | \n",
|
799 |
+
" | >>> # xdoctest: +SKIP(\"undefined vars\")\n",
|
800 |
+
" | >>> self.register_buffer('running_mean', torch.zeros(num_features))\n",
|
801 |
+
" | \n",
|
802 |
+
" | register_forward_hook(self, hook: Union[Callable[[~T, Tuple[Any, ...], Any], Optional[Any]], Callable[[~T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]]], *, prepend: bool = False, with_kwargs: bool = False) -> torch.utils.hooks.RemovableHandle\n",
|
803 |
+
" | Registers a forward hook on the module.\n",
|
804 |
+
" | \n",
|
805 |
+
" | The hook will be called every time after :func:`forward` has computed an output.\n",
|
806 |
+
" | \n",
|
807 |
+
" | If ``with_kwargs`` is ``False`` or not specified, the input contains only\n",
|
808 |
+
" | the positional arguments given to the module. Keyword arguments won't be\n",
|
809 |
+
" | passed to the hooks and only to the ``forward``. The hook can modify the\n",
|
810 |
+
" | output. It can modify the input inplace but it will not have effect on\n",
|
811 |
+
" | forward since this is called after :func:`forward` is called. The hook\n",
|
812 |
+
" | should have the following signature::\n",
|
813 |
+
" | \n",
|
814 |
+
" | hook(module, args, output) -> None or modified output\n",
|
815 |
+
" | \n",
|
816 |
+
" | If ``with_kwargs`` is ``True``, the forward hook will be passed the\n",
|
817 |
+
" | ``kwargs`` given to the forward function and be expected to return the\n",
|
818 |
+
" | output possibly modified. The hook should have the following signature::\n",
|
819 |
+
" | \n",
|
820 |
+
" | hook(module, args, kwargs, output) -> None or modified output\n",
|
821 |
+
" | \n",
|
822 |
+
" | Args:\n",
|
823 |
+
" | hook (Callable): The user defined hook to be registered.\n",
|
824 |
+
" | prepend (bool): If ``True``, the provided ``hook`` will be fired\n",
|
825 |
+
" | before all existing ``forward`` hooks on this\n",
|
826 |
+
" | :class:`torch.nn.modules.Module`. Otherwise, the provided\n",
|
827 |
+
" | ``hook`` will be fired after all existing ``forward`` hooks on\n",
|
828 |
+
" | this :class:`torch.nn.modules.Module`. Note that global\n",
|
829 |
+
" | ``forward`` hooks registered with\n",
|
830 |
+
" | :func:`register_module_forward_hook` will fire before all hooks\n",
|
831 |
+
" | registered by this method.\n",
|
832 |
+
" | Default: ``False``\n",
|
833 |
+
" | with_kwargs (bool): If ``True``, the ``hook`` will be passed the\n",
|
834 |
+
" | kwargs given to the forward function.\n",
|
835 |
+
" | Default: ``False``\n",
|
836 |
+
" | \n",
|
837 |
+
" | Returns:\n",
|
838 |
+
" | :class:`torch.utils.hooks.RemovableHandle`:\n",
|
839 |
+
" | a handle that can be used to remove the added hook by calling\n",
|
840 |
+
" | ``handle.remove()``\n",
|
841 |
+
" | \n",
|
842 |
+
" | register_forward_pre_hook(self, hook: Union[Callable[[~T, Tuple[Any, ...]], Optional[Any]], Callable[[~T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]]], *, prepend: bool = False, with_kwargs: bool = False) -> torch.utils.hooks.RemovableHandle\n",
|
843 |
+
" | Registers a forward pre-hook on the module.\n",
|
844 |
+
" | \n",
|
845 |
+
" | The hook will be called every time before :func:`forward` is invoked.\n",
|
846 |
+
" | \n",
|
847 |
+
" | \n",
|
848 |
+
" | If ``with_kwargs`` is false or not specified, the input contains only\n",
|
849 |
+
" | the positional arguments given to the module. Keyword arguments won't be\n",
|
850 |
+
" | passed to the hooks and only to the ``forward``. The hook can modify the\n",
|
851 |
+
" | input. User can either return a tuple or a single modified value in the\n",
|
852 |
+
" | hook. We will wrap the value into a tuple if a single value is returned\n",
|
853 |
+
" | (unless that value is already a tuple). The hook should have the\n",
|
854 |
+
" | following signature::\n",
|
855 |
+
" | \n",
|
856 |
+
" | hook(module, args) -> None or modified input\n",
|
857 |
+
" | \n",
|
858 |
+
" | If ``with_kwargs`` is true, the forward pre-hook will be passed the\n",
|
859 |
+
" | kwargs given to the forward function. And if the hook modifies the\n",
|
860 |
+
" | input, both the args and kwargs should be returned. The hook should have\n",
|
861 |
+
" | the following signature::\n",
|
862 |
+
" | \n",
|
863 |
+
" | hook(module, args, kwargs) -> None or a tuple of modified input and kwargs\n",
|
864 |
+
" | \n",
|
865 |
+
" | Args:\n",
|
866 |
+
" | hook (Callable): The user defined hook to be registered.\n",
|
867 |
+
" | prepend (bool): If true, the provided ``hook`` will be fired before\n",
|
868 |
+
" | all existing ``forward_pre`` hooks on this\n",
|
869 |
+
" | :class:`torch.nn.modules.Module`. Otherwise, the provided\n",
|
870 |
+
" | ``hook`` will be fired after all existing ``forward_pre`` hooks\n",
|
871 |
+
" | on this :class:`torch.nn.modules.Module`. Note that global\n",
|
872 |
+
" | ``forward_pre`` hooks registered with\n",
|
873 |
+
" | :func:`register_module_forward_pre_hook` will fire before all\n",
|
874 |
+
" | hooks registered by this method.\n",
|
875 |
+
" | Default: ``False``\n",
|
876 |
+
" | with_kwargs (bool): If true, the ``hook`` will be passed the kwargs\n",
|
877 |
+
" | given to the forward function.\n",
|
878 |
+
" | Default: ``False``\n",
|
879 |
+
" | \n",
|
880 |
+
" | Returns:\n",
|
881 |
+
" | :class:`torch.utils.hooks.RemovableHandle`:\n",
|
882 |
+
" | a handle that can be used to remove the added hook by calling\n",
|
883 |
+
" | ``handle.remove()``\n",
|
884 |
+
" | \n",
|
885 |
+
" | register_full_backward_hook(self, hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]], prepend: bool = False) -> torch.utils.hooks.RemovableHandle\n",
|
886 |
+
" | Registers a backward hook on the module.\n",
|
887 |
+
" | \n",
|
888 |
+
" | The hook will be called every time the gradients with respect to a module\n",
|
889 |
+
" | are computed, i.e. the hook will execute if and only if the gradients with\n",
|
890 |
+
" | respect to module outputs are computed. The hook should have the following\n",
|
891 |
+
" | signature::\n",
|
892 |
+
" | \n",
|
893 |
+
" | hook(module, grad_input, grad_output) -> tuple(Tensor) or None\n",
|
894 |
+
" | \n",
|
895 |
+
" | The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients\n",
|
896 |
+
" | with respect to the inputs and outputs respectively. The hook should\n",
|
897 |
+
" | not modify its arguments, but it can optionally return a new gradient with\n",
|
898 |
+
" | respect to the input that will be used in place of :attr:`grad_input` in\n",
|
899 |
+
" | subsequent computations. :attr:`grad_input` will only correspond to the inputs given\n",
|
900 |
+
" | as positional arguments and all kwarg arguments are ignored. Entries\n",
|
901 |
+
" | in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor\n",
|
902 |
+
" | arguments.\n",
|
903 |
+
" | \n",
|
904 |
+
" | For technical reasons, when this hook is applied to a Module, its forward function will\n",
|
905 |
+
" | receive a view of each Tensor passed to the Module. Similarly the caller will receive a view\n",
|
906 |
+
" | of each Tensor returned by the Module's forward function.\n",
|
907 |
+
" | \n",
|
908 |
+
" | .. warning ::\n",
|
909 |
+
" | Modifying inputs or outputs inplace is not allowed when using backward hooks and\n",
|
910 |
+
" | will raise an error.\n",
|
911 |
+
" | \n",
|
912 |
+
" | Args:\n",
|
913 |
+
" | hook (Callable): The user-defined hook to be registered.\n",
|
914 |
+
" | prepend (bool): If true, the provided ``hook`` will be fired before\n",
|
915 |
+
" | all existing ``backward`` hooks on this\n",
|
916 |
+
" | :class:`torch.nn.modules.Module`. Otherwise, the provided\n",
|
917 |
+
" | ``hook`` will be fired after all existing ``backward`` hooks on\n",
|
918 |
+
" | this :class:`torch.nn.modules.Module`. Note that global\n",
|
919 |
+
" | ``backward`` hooks registered with\n",
|
920 |
+
" | :func:`register_module_full_backward_hook` will fire before\n",
|
921 |
+
" | all hooks registered by this method.\n",
|
922 |
+
" | \n",
|
923 |
+
" | Returns:\n",
|
924 |
+
" | :class:`torch.utils.hooks.RemovableHandle`:\n",
|
925 |
+
" | a handle that can be used to remove the added hook by calling\n",
|
926 |
+
" | ``handle.remove()``\n",
|
927 |
+
" | \n",
|
928 |
+
" | register_full_backward_pre_hook(self, hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]], prepend: bool = False) -> torch.utils.hooks.RemovableHandle\n",
|
929 |
+
" | Registers a backward pre-hook on the module.\n",
|
930 |
+
" | \n",
|
931 |
+
" | The hook will be called every time the gradients for the module are computed.\n",
|
932 |
+
" | The hook should have the following signature::\n",
|
933 |
+
" | \n",
|
934 |
+
" | hook(module, grad_output) -> Tensor or None\n",
|
935 |
+
" | \n",
|
936 |
+
" | The :attr:`grad_output` is a tuple. The hook should\n",
|
937 |
+
" | not modify its arguments, but it can optionally return a new gradient with\n",
|
938 |
+
" | respect to the output that will be used in place of :attr:`grad_output` in\n",
|
939 |
+
" | subsequent computations. Entries in :attr:`grad_output` will be ``None`` for\n",
|
940 |
+
" | all non-Tensor arguments.\n",
|
941 |
+
" | \n",
|
942 |
+
" | For technical reasons, when this hook is applied to a Module, its forward function will\n",
|
943 |
+
" | receive a view of each Tensor passed to the Module. Similarly the caller will receive a view\n",
|
944 |
+
" | of each Tensor returned by the Module's forward function.\n",
|
945 |
+
" | \n",
|
946 |
+
" | .. warning ::\n",
|
947 |
+
" | Modifying inputs inplace is not allowed when using backward hooks and\n",
|
948 |
+
" | will raise an error.\n",
|
949 |
+
" | \n",
|
950 |
+
" | Args:\n",
|
951 |
+
" | hook (Callable): The user-defined hook to be registered.\n",
|
952 |
+
" | prepend (bool): If true, the provided ``hook`` will be fired before\n",
|
953 |
+
" | all existing ``backward_pre`` hooks on this\n",
|
954 |
+
" | :class:`torch.nn.modules.Module`. Otherwise, the provided\n",
|
955 |
+
" | ``hook`` will be fired after all existing ``backward_pre`` hooks\n",
|
956 |
+
" | on this :class:`torch.nn.modules.Module`. Note that global\n",
|
957 |
+
" | ``backward_pre`` hooks registered with\n",
|
958 |
+
" | :func:`register_module_full_backward_pre_hook` will fire before\n",
|
959 |
+
" | all hooks registered by this method.\n",
|
960 |
+
" | \n",
|
961 |
+
" | Returns:\n",
|
962 |
+
" | :class:`torch.utils.hooks.RemovableHandle`:\n",
|
963 |
+
" | a handle that can be used to remove the added hook by calling\n",
|
964 |
+
" | ``handle.remove()``\n",
|
965 |
+
" | \n",
|
966 |
+
" | register_load_state_dict_post_hook(self, hook)\n",
|
967 |
+
" | Registers a post hook to be run after module's ``load_state_dict``\n",
|
968 |
+
" | is called.\n",
|
969 |
+
" | \n",
|
970 |
+
" | It should have the following signature::\n",
|
971 |
+
" | hook(module, incompatible_keys) -> None\n",
|
972 |
+
" | \n",
|
973 |
+
" | The ``module`` argument is the current module that this hook is registered\n",
|
974 |
+
" | on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting\n",
|
975 |
+
" | of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys``\n",
|
976 |
+
" | is a ``list`` of ``str`` containing the missing keys and\n",
|
977 |
+
" | ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys.\n",
|
978 |
+
" | \n",
|
979 |
+
" | The given incompatible_keys can be modified inplace if needed.\n",
|
980 |
+
" | \n",
|
981 |
+
" | Note that the checks performed when calling :func:`load_state_dict` with\n",
|
982 |
+
" | ``strict=True`` are affected by modifications the hook makes to\n",
|
983 |
+
" | ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either\n",
|
984 |
+
" | set of keys will result in an error being thrown when ``strict=True``, and\n",
|
985 |
+
" | clearing out both missing and unexpected keys will avoid an error.\n",
|
986 |
+
" | \n",
|
987 |
+
" | Returns:\n",
|
988 |
+
" | :class:`torch.utils.hooks.RemovableHandle`:\n",
|
989 |
+
" | a handle that can be used to remove the added hook by calling\n",
|
990 |
+
" | ``handle.remove()``\n",
|
991 |
+
" | \n",
|
992 |
+
" | register_module(self, name: str, module: Optional[ForwardRef('Module')]) -> None\n",
|
993 |
+
" | Alias for :func:`add_module`.\n",
|
994 |
+
" | \n",
|
995 |
+
" | register_parameter(self, name: str, param: Optional[torch.nn.parameter.Parameter]) -> None\n",
|
996 |
+
" | Adds a parameter to the module.\n",
|
997 |
+
" | \n",
|
998 |
+
" | The parameter can be accessed as an attribute using given name.\n",
|
999 |
+
" | \n",
|
1000 |
+
" | Args:\n",
|
1001 |
+
" | name (str): name of the parameter. The parameter can be accessed\n",
|
1002 |
+
" | from this module using the given name\n",
|
1003 |
+
" | param (Parameter or None): parameter to be added to the module. If\n",
|
1004 |
+
" | ``None``, then operations that run on parameters, such as :attr:`cuda`,\n",
|
1005 |
+
" | are ignored. If ``None``, the parameter is **not** included in the\n",
|
1006 |
+
" | module's :attr:`state_dict`.\n",
|
1007 |
+
" | \n",
|
1008 |
+
" | register_state_dict_pre_hook(self, hook)\n",
|
1009 |
+
" | These hooks will be called with arguments: ``self``, ``prefix``,\n",
|
1010 |
+
" | and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered\n",
|
1011 |
+
" | hooks can be used to perform pre-processing before the ``state_dict``\n",
|
1012 |
+
" | call is made.\n",
|
1013 |
+
" | \n",
|
1014 |
+
" | requires_grad_(self: ~T, requires_grad: bool = True) -> ~T\n",
|
1015 |
+
" | Change if autograd should record operations on parameters in this\n",
|
1016 |
+
" | module.\n",
|
1017 |
+
" | \n",
|
1018 |
+
" | This method sets the parameters' :attr:`requires_grad` attributes\n",
|
1019 |
+
" | in-place.\n",
|
1020 |
+
" | \n",
|
1021 |
+
" | This method is helpful for freezing part of the module for finetuning\n",
|
1022 |
+
" | or training parts of a model individually (e.g., GAN training).\n",
|
1023 |
+
" | \n",
|
1024 |
+
" | See :ref:`locally-disable-grad-doc` for a comparison between\n",
|
1025 |
+
" | `.requires_grad_()` and several similar mechanisms that may be confused with it.\n",
|
1026 |
+
" | \n",
|
1027 |
+
" | Args:\n",
|
1028 |
+
" | requires_grad (bool): whether autograd should record operations on\n",
|
1029 |
+
" | parameters in this module. Default: ``True``.\n",
|
1030 |
+
" | \n",
|
1031 |
+
" | Returns:\n",
|
1032 |
+
" | Module: self\n",
|
1033 |
+
" | \n",
|
1034 |
+
" | set_extra_state(self, state: Any)\n",
|
1035 |
+
" | This function is called from :func:`load_state_dict` to handle any extra state\n",
|
1036 |
+
" | found within the `state_dict`. Implement this function and a corresponding\n",
|
1037 |
+
" | :func:`get_extra_state` for your module if you need to store extra state within its\n",
|
1038 |
+
" | `state_dict`.\n",
|
1039 |
+
" | \n",
|
1040 |
+
" | Args:\n",
|
1041 |
+
" | state (dict): Extra state from the `state_dict`\n",
|
1042 |
+
" | \n",
|
1043 |
+
" | share_memory(self: ~T) -> ~T\n",
|
1044 |
+
" | See :meth:`torch.Tensor.share_memory_`\n",
|
1045 |
+
" | \n",
|
1046 |
+
" | state_dict(self, *args, destination=None, prefix='', keep_vars=False)\n",
|
1047 |
+
" | Returns a dictionary containing references to the whole state of the module.\n",
|
1048 |
+
" | \n",
|
1049 |
+
" | Both parameters and persistent buffers (e.g. running averages) are\n",
|
1050 |
+
" | included. Keys are corresponding parameter and buffer names.\n",
|
1051 |
+
" | Parameters and buffers set to ``None`` are not included.\n",
|
1052 |
+
" | \n",
|
1053 |
+
" | .. note::\n",
|
1054 |
+
" | The returned object is a shallow copy. It contains references\n",
|
1055 |
+
" | to the module's parameters and buffers.\n",
|
1056 |
+
" | \n",
|
1057 |
+
" | .. warning::\n",
|
1058 |
+
" | Currently ``state_dict()`` also accepts positional arguments for\n",
|
1059 |
+
" | ``destination``, ``prefix`` and ``keep_vars`` in order. However,\n",
|
1060 |
+
" | this is being deprecated and keyword arguments will be enforced in\n",
|
1061 |
+
" | future releases.\n",
|
1062 |
+
" | \n",
|
1063 |
+
" | .. warning::\n",
|
1064 |
+
" | Please avoid the use of argument ``destination`` as it is not\n",
|
1065 |
+
" | designed for end-users.\n",
|
1066 |
+
" | \n",
|
1067 |
+
" | Args:\n",
|
1068 |
+
" | destination (dict, optional): If provided, the state of module will\n",
|
1069 |
+
" | be updated into the dict and the same object is returned.\n",
|
1070 |
+
" | Otherwise, an ``OrderedDict`` will be created and returned.\n",
|
1071 |
+
" | Default: ``None``.\n",
|
1072 |
+
" | prefix (str, optional): a prefix added to parameter and buffer\n",
|
1073 |
+
" | names to compose the keys in state_dict. Default: ``''``.\n",
|
1074 |
+
" | keep_vars (bool, optional): by default the :class:`~torch.Tensor` s\n",
|
1075 |
+
" | returned in the state dict are detached from autograd. If it's\n",
|
1076 |
+
" | set to ``True``, detaching will not be performed.\n",
|
1077 |
+
" | Default: ``False``.\n",
|
1078 |
+
" | \n",
|
1079 |
+
" | Returns:\n",
|
1080 |
+
" | dict:\n",
|
1081 |
+
" | a dictionary containing a whole state of the module\n",
|
1082 |
+
" | \n",
|
1083 |
+
" | Example::\n",
|
1084 |
+
" | \n",
|
1085 |
+
" | >>> # xdoctest: +SKIP(\"undefined vars\")\n",
|
1086 |
+
" | >>> module.state_dict().keys()\n",
|
1087 |
+
" | ['bias', 'weight']\n",
|
1088 |
+
" | \n",
|
1089 |
+
" | to(self, *args, **kwargs)\n",
|
1090 |
+
" | Moves and/or casts the parameters and buffers.\n",
|
1091 |
+
" | \n",
|
1092 |
+
" | This can be called as\n",
|
1093 |
+
" | \n",
|
1094 |
+
" | .. function:: to(device=None, dtype=None, non_blocking=False)\n",
|
1095 |
+
" | :noindex:\n",
|
1096 |
+
" | \n",
|
1097 |
+
" | .. function:: to(dtype, non_blocking=False)\n",
|
1098 |
+
" | :noindex:\n",
|
1099 |
+
" | \n",
|
1100 |
+
" | .. function:: to(tensor, non_blocking=False)\n",
|
1101 |
+
" | :noindex:\n",
|
1102 |
+
" | \n",
|
1103 |
+
" | .. function:: to(memory_format=torch.channels_last)\n",
|
1104 |
+
" | :noindex:\n",
|
1105 |
+
" | \n",
|
1106 |
+
" | Its signature is similar to :meth:`torch.Tensor.to`, but only accepts\n",
|
1107 |
+
" | floating point or complex :attr:`dtype`\\ s. In addition, this method will\n",
|
1108 |
+
" | only cast the floating point or complex parameters and buffers to :attr:`dtype`\n",
|
1109 |
+
" | (if given). The integral parameters and buffers will be moved\n",
|
1110 |
+
" | :attr:`device`, if that is given, but with dtypes unchanged. When\n",
|
1111 |
+
" | :attr:`non_blocking` is set, it tries to convert/move asynchronously\n",
|
1112 |
+
" | with respect to the host if possible, e.g., moving CPU Tensors with\n",
|
1113 |
+
" | pinned memory to CUDA devices.\n",
|
1114 |
+
" | \n",
|
1115 |
+
" | See below for examples.\n",
|
1116 |
+
" | \n",
|
1117 |
+
" | .. note::\n",
|
1118 |
+
" | This method modifies the module in-place.\n",
|
1119 |
+
" | \n",
|
1120 |
+
" | Args:\n",
|
1121 |
+
" | device (:class:`torch.device`): the desired device of the parameters\n",
|
1122 |
+
" | and buffers in this module\n",
|
1123 |
+
" | dtype (:class:`torch.dtype`): the desired floating point or complex dtype of\n",
|
1124 |
+
" | the parameters and buffers in this module\n",
|
1125 |
+
" | tensor (torch.Tensor): Tensor whose dtype and device are the desired\n",
|
1126 |
+
" | dtype and device for all parameters and buffers in this module\n",
|
1127 |
+
" | memory_format (:class:`torch.memory_format`): the desired memory\n",
|
1128 |
+
" | format for 4D parameters and buffers in this module (keyword\n",
|
1129 |
+
" | only argument)\n",
|
1130 |
+
" | \n",
|
1131 |
+
" | Returns:\n",
|
1132 |
+
" | Module: self\n",
|
1133 |
+
" | \n",
|
1134 |
+
" | Examples::\n",
|
1135 |
+
" | \n",
|
1136 |
+
" | >>> # xdoctest: +IGNORE_WANT(\"non-deterministic\")\n",
|
1137 |
+
" | >>> linear = nn.Linear(2, 2)\n",
|
1138 |
+
" | >>> linear.weight\n",
|
1139 |
+
" | Parameter containing:\n",
|
1140 |
+
" | tensor([[ 0.1913, -0.3420],\n",
|
1141 |
+
" | [-0.5113, -0.2325]])\n",
|
1142 |
+
" | >>> linear.to(torch.double)\n",
|
1143 |
+
" | Linear(in_features=2, out_features=2, bias=True)\n",
|
1144 |
+
" | >>> linear.weight\n",
|
1145 |
+
" | Parameter containing:\n",
|
1146 |
+
" | tensor([[ 0.1913, -0.3420],\n",
|
1147 |
+
" | [-0.5113, -0.2325]], dtype=torch.float64)\n",
|
1148 |
+
" | >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)\n",
|
1149 |
+
" | >>> gpu1 = torch.device(\"cuda:1\")\n",
|
1150 |
+
" | >>> linear.to(gpu1, dtype=torch.half, non_blocking=True)\n",
|
1151 |
+
" | Linear(in_features=2, out_features=2, bias=True)\n",
|
1152 |
+
" | >>> linear.weight\n",
|
1153 |
+
" | Parameter containing:\n",
|
1154 |
+
" | tensor([[ 0.1914, -0.3420],\n",
|
1155 |
+
" | [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')\n",
|
1156 |
+
" | >>> cpu = torch.device(\"cpu\")\n",
|
1157 |
+
" | >>> linear.to(cpu)\n",
|
1158 |
+
" | Linear(in_features=2, out_features=2, bias=True)\n",
|
1159 |
+
" | >>> linear.weight\n",
|
1160 |
+
" | Parameter containing:\n",
|
1161 |
+
" | tensor([[ 0.1914, -0.3420],\n",
|
1162 |
+
" | [-0.5112, -0.2324]], dtype=torch.float16)\n",
|
1163 |
+
" | \n",
|
1164 |
+
" | >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)\n",
|
1165 |
+
" | >>> linear.weight\n",
|
1166 |
+
" | Parameter containing:\n",
|
1167 |
+
" | tensor([[ 0.3741+0.j, 0.2382+0.j],\n",
|
1168 |
+
" | [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)\n",
|
1169 |
+
" | >>> linear(torch.ones(3, 2, dtype=torch.cdouble))\n",
|
1170 |
+
" | tensor([[0.6122+0.j, 0.1150+0.j],\n",
|
1171 |
+
" | [0.6122+0.j, 0.1150+0.j],\n",
|
1172 |
+
" | [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)\n",
|
1173 |
+
" | \n",
|
1174 |
+
" | to_empty(self: ~T, *, device: Union[str, torch.device]) -> ~T\n",
|
1175 |
+
" | Moves the parameters and buffers to the specified device without copying storage.\n",
|
1176 |
+
" | \n",
|
1177 |
+
" | Args:\n",
|
1178 |
+
" | device (:class:`torch.device`): The desired device of the parameters\n",
|
1179 |
+
" | and buffers in this module.\n",
|
1180 |
+
" | \n",
|
1181 |
+
" | Returns:\n",
|
1182 |
+
" | Module: self\n",
|
1183 |
+
" | \n",
|
1184 |
+
" | train(self: ~T, mode: bool = True) -> ~T\n",
|
1185 |
+
" | Sets the module in training mode.\n",
|
1186 |
+
" | \n",
|
1187 |
+
" | This has any effect only on certain modules. See documentations of\n",
|
1188 |
+
" | particular modules for details of their behaviors in training/evaluation\n",
|
1189 |
+
" | mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,\n",
|
1190 |
+
" | etc.\n",
|
1191 |
+
" | \n",
|
1192 |
+
" | Args:\n",
|
1193 |
+
" | mode (bool): whether to set training mode (``True``) or evaluation\n",
|
1194 |
+
" | mode (``False``). Default: ``True``.\n",
|
1195 |
+
" | \n",
|
1196 |
+
" | Returns:\n",
|
1197 |
+
" | Module: self\n",
|
1198 |
+
" | \n",
|
1199 |
+
" | type(self: ~T, dst_type: Union[torch.dtype, str]) -> ~T\n",
|
1200 |
+
" | Casts all parameters and buffers to :attr:`dst_type`.\n",
|
1201 |
+
" | \n",
|
1202 |
+
" | .. note::\n",
|
1203 |
+
" | This method modifies the module in-place.\n",
|
1204 |
+
" | \n",
|
1205 |
+
" | Args:\n",
|
1206 |
+
" | dst_type (type or string): the desired type\n",
|
1207 |
+
" | \n",
|
1208 |
+
" | Returns:\n",
|
1209 |
+
" | Module: self\n",
|
1210 |
+
" | \n",
|
1211 |
+
" | xpu(self: ~T, device: Union[int, torch.device, NoneType] = None) -> ~T\n",
|
1212 |
+
" | Moves all model parameters and buffers to the XPU.\n",
|
1213 |
+
" | \n",
|
1214 |
+
" | This also makes associated parameters and buffers different objects. So\n",
|
1215 |
+
" | it should be called before constructing optimizer if the module will\n",
|
1216 |
+
" | live on XPU while being optimized.\n",
|
1217 |
+
" | \n",
|
1218 |
+
" | .. note::\n",
|
1219 |
+
" | This method modifies the module in-place.\n",
|
1220 |
+
" | \n",
|
1221 |
+
" | Arguments:\n",
|
1222 |
+
" | device (int, optional): if specified, all parameters will be\n",
|
1223 |
+
" | copied to that device\n",
|
1224 |
+
" | \n",
|
1225 |
+
" | Returns:\n",
|
1226 |
+
" | Module: self\n",
|
1227 |
+
" | \n",
|
1228 |
+
" | zero_grad(self, set_to_none: bool = True) -> None\n",
|
1229 |
+
" | Sets gradients of all model parameters to zero. See similar function\n",
|
1230 |
+
" | under :class:`torch.optim.Optimizer` for more context.\n",
|
1231 |
+
" | \n",
|
1232 |
+
" | Args:\n",
|
1233 |
+
" | set_to_none (bool): instead of setting to zero, set the grads to None.\n",
|
1234 |
+
" | See :meth:`torch.optim.Optimizer.zero_grad` for details.\n",
|
1235 |
+
" | \n",
|
1236 |
+
" | ----------------------------------------------------------------------\n",
|
1237 |
+
" | Data descriptors inherited from torch.nn.modules.module.Module:\n",
|
1238 |
+
" | \n",
|
1239 |
+
" | __dict__\n",
|
1240 |
+
" | dictionary for instance variables (if defined)\n",
|
1241 |
+
" | \n",
|
1242 |
+
" | __weakref__\n",
|
1243 |
+
" | list of weak references to the object (if defined)\n",
|
1244 |
+
" | \n",
|
1245 |
+
" | ----------------------------------------------------------------------\n",
|
1246 |
+
" | Data and other attributes inherited from torch.nn.modules.module.Module:\n",
|
1247 |
+
" | \n",
|
1248 |
+
" | T_destination = ~T_destination\n",
|
1249 |
+
" | \n",
|
1250 |
+
" | call_super_init = False\n",
|
1251 |
+
" | \n",
|
1252 |
+
" | dump_patches = False\n",
|
1253 |
+
"\n"
|
1254 |
+
]
|
1255 |
+
}
|
1256 |
+
],
|
1257 |
+
"source": [
|
1258 |
+
"help(eva02_model)"
|
1259 |
+
]
|
1260 |
+
},
|
1261 |
+
{
|
1262 |
+
"cell_type": "markdown",
|
1263 |
+
"id": "2f5ac1a7-6f1b-4417-8a67-1b2e32d385dd",
|
1264 |
+
"metadata": {},
|
1265 |
+
"source": [
|
1266 |
+
"# DETR"
|
1267 |
+
]
|
1268 |
+
},
|
1269 |
+
{
|
1270 |
+
"cell_type": "code",
|
1271 |
+
"execution_count": 33,
|
1272 |
+
"id": "5c3ade1b-18ea-4368-abd9-53be1fdfb610",
|
1273 |
+
"metadata": {},
|
1274 |
+
"outputs": [
|
1275 |
+
{
|
1276 |
+
"name": "stdout",
|
1277 |
+
"output_type": "stream",
|
1278 |
+
"text": [
|
1279 |
+
"[2023-08-28 01:51:14,033] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
|
1280 |
+
]
|
1281 |
+
},
|
1282 |
+
{
|
1283 |
+
"name": "stderr",
|
1284 |
+
"output_type": "stream",
|
1285 |
+
"text": [
|
1286 |
+
"The `max_size` parameter is deprecated and will be removed in v4.26. Please specify in `size['longest_edge'] instead`.\n"
|
1287 |
+
]
|
1288 |
+
}
|
1289 |
+
],
|
1290 |
+
"source": [
|
1291 |
+
"from transformers import DetrImageProcessor, DetrForObjectDetection\n",
|
1292 |
+
"import torch\n",
|
1293 |
+
"from PIL import Image\n",
|
1294 |
+
"import requests\n",
|
1295 |
+
"\n",
|
1296 |
+
"url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
|
1297 |
+
"image = Image.open(requests.get(url, stream=True).raw)\n",
|
1298 |
+
"\n",
|
1299 |
+
"processor = DetrImageProcessor.from_pretrained(\"facebook/detr-resnet-50\", cache_dir='/fsx/proj-fmri/shared/cache')\n",
|
1300 |
+
"model = DetrForObjectDetection.from_pretrained(\"facebook/detr-resnet-50\", cache_dir='/fsx/proj-fmri/shared/cache')"
|
1301 |
+
]
|
1302 |
+
},
|
1303 |
+
{
|
1304 |
+
"cell_type": "code",
|
1305 |
+
"execution_count": 34,
|
1306 |
+
"id": "1d5aa2d7-4868-4751-8d90-7c52be028cd9",
|
1307 |
+
"metadata": {},
|
1308 |
+
"outputs": [],
|
1309 |
+
"source": [
|
1310 |
+
"inputs = processor(images=image, return_tensors=\"pt\")\n",
|
1311 |
+
"outputs = model(**inputs)"
|
1312 |
+
]
|
1313 |
+
},
|
1314 |
+
{
|
1315 |
+
"cell_type": "code",
|
1316 |
+
"execution_count": 35,
|
1317 |
+
"id": "ae6bafc6-cee4-4e59-b7ba-12efc2a65b74",
|
1318 |
+
"metadata": {},
|
1319 |
+
"outputs": [
|
1320 |
+
{
|
1321 |
+
"name": "stdout",
|
1322 |
+
"output_type": "stream",
|
1323 |
+
"text": [
|
1324 |
+
"Detected remote with confidence 0.998 at location [40.16, 70.81, 175.55, 117.98]\n",
|
1325 |
+
"Detected remote with confidence 0.996 at location [333.24, 72.55, 368.33, 187.66]\n",
|
1326 |
+
"Detected couch with confidence 0.995 at location [-0.02, 1.15, 639.73, 473.76]\n",
|
1327 |
+
"Detected cat with confidence 0.999 at location [13.24, 52.05, 314.02, 470.93]\n",
|
1328 |
+
"Detected cat with confidence 0.999 at location [345.4, 23.85, 640.37, 368.72]\n"
|
1329 |
+
]
|
1330 |
+
}
|
1331 |
+
],
|
1332 |
+
"source": [
|
1333 |
+
"# convert outputs (bounding boxes and class logits) to COCO API\n",
|
1334 |
+
"# let's only keep detections with score > 0.9\n",
|
1335 |
+
"target_sizes = torch.tensor([image.size[::-1]])\n",
|
1336 |
+
"results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]\n",
|
1337 |
+
"\n",
|
1338 |
+
"for score, label, box in zip(results[\"scores\"], results[\"labels\"], results[\"boxes\"]):\n",
|
1339 |
+
" box = [round(i, 2) for i in box.tolist()]\n",
|
1340 |
+
" print(\n",
|
1341 |
+
" f\"Detected {model.config.id2label[label.item()]} with confidence \"\n",
|
1342 |
+
" f\"{round(score.item(), 3)} at location {box}\"\n",
|
1343 |
+
" )"
|
1344 |
+
]
|
1345 |
+
},
|
1346 |
+
{
|
1347 |
+
"cell_type": "code",
|
1348 |
+
"execution_count": 36,
|
1349 |
+
"id": "6dcc5934-79d4-4062-8b32-e42b3ebcdc0f",
|
1350 |
+
"metadata": {},
|
1351 |
+
"outputs": [
|
1352 |
+
{
|
1353 |
+
"data": {
|
1354 |
+
"text/plain": [
|
1355 |
+
"DetrImageProcessor {\n",
|
1356 |
+
" \"do_normalize\": true,\n",
|
1357 |
+
" \"do_pad\": true,\n",
|
1358 |
+
" \"do_rescale\": true,\n",
|
1359 |
+
" \"do_resize\": true,\n",
|
1360 |
+
" \"feature_extractor_type\": \"DetrFeatureExtractor\",\n",
|
1361 |
+
" \"format\": \"coco_detection\",\n",
|
1362 |
+
" \"image_mean\": [\n",
|
1363 |
+
" 0.485,\n",
|
1364 |
+
" 0.456,\n",
|
1365 |
+
" 0.406\n",
|
1366 |
+
" ],\n",
|
1367 |
+
" \"image_processor_type\": \"DetrImageProcessor\",\n",
|
1368 |
+
" \"image_std\": [\n",
|
1369 |
+
" 0.229,\n",
|
1370 |
+
" 0.224,\n",
|
1371 |
+
" 0.225\n",
|
1372 |
+
" ],\n",
|
1373 |
+
" \"resample\": 2,\n",
|
1374 |
+
" \"rescale_factor\": 0.00392156862745098,\n",
|
1375 |
+
" \"size\": {\n",
|
1376 |
+
" \"longest_edge\": 1333,\n",
|
1377 |
+
" \"shortest_edge\": 800\n",
|
1378 |
+
" }\n",
|
1379 |
+
"}"
|
1380 |
+
]
|
1381 |
+
},
|
1382 |
+
"execution_count": 36,
|
1383 |
+
"metadata": {},
|
1384 |
+
"output_type": "execute_result"
|
1385 |
+
}
|
1386 |
+
],
|
1387 |
+
"source": [
|
1388 |
+
"processor"
|
1389 |
+
]
|
1390 |
+
},
|
1391 |
+
{
|
1392 |
+
"cell_type": "markdown",
|
1393 |
+
"id": "db1d89cc-b432-473e-af69-d81c435ac731",
|
1394 |
+
"metadata": {},
|
1395 |
+
"source": [
|
1396 |
+
"# CLIPSeg"
|
1397 |
+
]
|
1398 |
+
},
|
1399 |
+
{
|
1400 |
+
"cell_type": "code",
|
1401 |
+
"execution_count": 37,
|
1402 |
+
"id": "15db14d1-ee4d-4429-9286-054c4498293b",
|
1403 |
+
"metadata": {},
|
1404 |
+
"outputs": [],
|
1405 |
+
"source": [
|
1406 |
+
"from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation\n",
|
1407 |
+
"\n",
|
1408 |
+
"processor = CLIPSegProcessor.from_pretrained(\"CIDAS/clipseg-rd16\",cache_dir='/fsx/proj-fmri/shared/cache')\n",
|
1409 |
+
"model = CLIPSegForImageSegmentation.from_pretrained(\"CIDAS/clipseg-rd16\",cache_dir='/fsx/proj-fmri/shared/cache')"
|
1410 |
+
]
|
1411 |
+
},
|
1412 |
+
{
|
1413 |
+
"cell_type": "code",
|
1414 |
+
"execution_count": 38,
|
1415 |
+
"id": "4aa225d4-5a3b-4dbb-ae57-dea2872ff492",
|
1416 |
+
"metadata": {},
|
1417 |
+
"outputs": [
|
1418 |
+
{
|
1419 |
+
"ename": "AttributeError",
|
1420 |
+
"evalue": "'JpegImageFile' object has no attribute 'shape'",
|
1421 |
+
"output_type": "error",
|
1422 |
+
"traceback": [
|
1423 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
1424 |
+
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
1425 |
+
"Cell \u001b[0;32mIn[38], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mimage\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\n",
|
1426 |
+
"\u001b[0;31mAttributeError\u001b[0m: 'JpegImageFile' object has no attribute 'shape'"
|
1427 |
+
]
|
1428 |
+
}
|
1429 |
+
],
|
1430 |
+
"source": [
|
1431 |
+
"image.shape"
|
1432 |
+
]
|
1433 |
+
},
|
1434 |
+
{
|
1435 |
+
"cell_type": "code",
|
1436 |
+
"execution_count": null,
|
1437 |
+
"id": "ad7e2daf-0c7c-4fec-b29e-9ba47a037c6b",
|
1438 |
+
"metadata": {},
|
1439 |
+
"outputs": [],
|
1440 |
+
"source": [
|
1441 |
+
"from PIL import Image\n",
|
1442 |
+
"import requests\n",
|
1443 |
+
"import h5py\n",
|
1444 |
+
"\n",
|
1445 |
+
"# url = \"https://unsplash.com/photos/8Nc_oQsc2qQ/download?ixid=MnwxMjA3fDB8MXxhbGx8fHx8fHx8fHwxNjcxMjAwNzI0&force=true&w=640\"\n",
|
1446 |
+
"# image = Image.open(requests.get(url, stream=True).raw)\n",
|
1447 |
+
"\n",
|
1448 |
+
"image_path = \"/fsx/proj-fmri/shared/mindeyev2_dataset/coco_images_224_float16.hdf5\"\n",
|
1449 |
+
"with h5py.File(image_path, 'r') as file:\n",
|
1450 |
+
" image = file['images'][0]\n",
|
1451 |
+
"image = np.moveaxis(image, 0, -1).astype(np.float32)\n",
|
1452 |
+
"plt.imshow(image)\n",
|
1453 |
+
"\n",
|
1454 |
+
"prompts = [\"person\",\"animal\",\"object\",\"background\"]\n",
|
1455 |
+
"import torch\n",
|
1456 |
+
"\n",
|
1457 |
+
"# Rescale to [0, 255]\n",
|
1458 |
+
"array = (image * 255).astype(np.uint8)\n",
|
1459 |
+
"\n",
|
1460 |
+
"# Convert to PIL image\n",
|
1461 |
+
"image = Image.fromarray(array)\n",
|
1462 |
+
"\n",
|
1463 |
+
"inputs = processor(text=prompts, images=[image] * len(prompts), padding=\"max_length\", return_tensors=\"pt\")\n",
|
1464 |
+
"# predict\n",
|
1465 |
+
"with torch.no_grad():\n",
|
1466 |
+
" outputs = model(**inputs)\n",
|
1467 |
+
"preds = outputs.logits.unsqueeze(1)\n",
|
1468 |
+
"print(preds.shape)"
|
1469 |
+
]
|
1470 |
+
},
|
1471 |
+
{
|
1472 |
+
"cell_type": "code",
|
1473 |
+
"execution_count": null,
|
1474 |
+
"id": "131eb5b7-2f16-4a79-8402-edc1a1d8c348",
|
1475 |
+
"metadata": {},
|
1476 |
+
"outputs": [],
|
1477 |
+
"source": [
|
1478 |
+
"preds = ((preds[0] + preds[1] + preds[2] + preds[-1].max() - preds[-1]) / 4)[None]\n",
|
1479 |
+
"preds.shape"
|
1480 |
+
]
|
1481 |
+
},
|
1482 |
+
{
|
1483 |
+
"cell_type": "code",
|
1484 |
+
"execution_count": null,
|
1485 |
+
"id": "e2bf99e7-064d-4c22-997f-aa1a35dbab82",
|
1486 |
+
"metadata": {},
|
1487 |
+
"outputs": [],
|
1488 |
+
"source": [
|
1489 |
+
"_, ax = plt.subplots(1, len(prompts) + 1, figsize=(3*(len(prompts) + 1), 4))\n",
|
1490 |
+
"[a.axis('off') for a in ax.flatten()]\n",
|
1491 |
+
"ax[0].imshow(image)\n",
|
1492 |
+
"[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(1)];\n",
|
1493 |
+
"# [ax[i+1].text(0, -15, prompt) for i, prompt in enumerate(prompts)];"
|
1494 |
+
]
|
1495 |
+
},
|
1496 |
+
{
|
1497 |
+
"cell_type": "code",
|
1498 |
+
"execution_count": null,
|
1499 |
+
"id": "b58b926f-a2b2-423b-b367-18808cf6b4f7",
|
1500 |
+
"metadata": {},
|
1501 |
+
"outputs": [],
|
1502 |
+
"source": []
|
1503 |
+
}
|
1504 |
+
],
|
1505 |
+
"metadata": {
|
1506 |
+
"kernelspec": {
|
1507 |
+
"display_name": "Python 3 (ipykernel)",
|
1508 |
+
"language": "python",
|
1509 |
+
"name": "python3"
|
1510 |
+
},
|
1511 |
+
"language_info": {
|
1512 |
+
"codemirror_mode": {
|
1513 |
+
"name": "ipython",
|
1514 |
+
"version": 3
|
1515 |
+
},
|
1516 |
+
"file_extension": ".py",
|
1517 |
+
"mimetype": "text/x-python",
|
1518 |
+
"name": "python",
|
1519 |
+
"nbconvert_exporter": "python",
|
1520 |
+
"pygments_lexer": "ipython3",
|
1521 |
+
"version": "3.10.8"
|
1522 |
+
}
|
1523 |
+
},
|
1524 |
+
"nbformat": 4,
|
1525 |
+
"nbformat_minor": 5
|
1526 |
+
}
|
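A note on the AttributeError captured in the CLIPSeg cells above: PIL images expose `.size` as (width, height) rather than a numpy-style `.shape`, which is why `image.shape` fails on a `JpegImageFile`. A minimal sketch of the distinction (illustrative only, not part of the uploaded notebook):

import numpy as np
from PIL import Image

img = Image.new("RGB", (640, 480))  # stand-in for the COCO test image
print(img.size)                     # (640, 480) -> (width, height)
arr = np.asarray(img)               # convert before asking for a shape
print(arr.shape)                    # (480, 640, 3) -> (H, W, C)

The notebook takes the same route, converting through numpy (and back via Image.fromarray) before handing the image to the processor.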
src/cnd_prov/cnd_prov-Copy1.py
ADDED
@@ -0,0 +1,148 @@
+#new one
+import pandas as pd
+import datasets
+import os
+import pickle
+import numpy as np
+from PIL import Image
+import matplotlib
+
+_VERSION = datasets.Version("0.0.3")
+
+_DESCRIPTION = "TODO"
+_HOMEPAGE = "TODO"
+_LICENSE = "TODO"
+_CITATION = "TODO"
+
+_FEATURES = datasets.Features(
+    {
+        "target": datasets.Image(),
+        "source": datasets.Image(),
+        "heatmap": datasets.Image(),
+        "depth": datasets.Image(),
+        "prompt": datasets.Value("string"),
+    },
+)
+
+METADATA_DIR = "/fsx/proj-fmri/ckadirt/MindEyeV2/src/cnd_prov/data.pkl"
+SOURCE_DIR = "/fsx/proj-fmri/shared/controlNetData/source"
+TARGET_DIR = "/fsx/proj-fmri/shared/controlNetData/target"
+HEATMAP_DIR = "/fsx/proj-fmri/shared/controlNetData/seg"
+DEPTH_DIR = "/fsx/proj-fmri/shared/dinov2_depth"
+
+# METADATA_URL = hf_hub_url(
+#     "fusing/fill50k",
+#     filename="train.jsonl",
+#     repo_type="dataset",
+# )
+
+# IMAGES_URL = hf_hub_url(
+#     "fusing/fill50k",
+#     filename="images.zip",
+#     repo_type="dataset",
+# )
+
+# CONDITIONING_IMAGES_URL = hf_hub_url(
+#     "fusing/fill50k",
+#     filename="conditioning_images.zip",
+#     repo_type="dataset",
+# )
+
+_DEFAULT_CONFIG = datasets.BuilderConfig(name="default", version=_VERSION)
+
+
+class CocoTest(datasets.GeneratorBasedBuilder):
+    BUILDER_CONFIGS = [_DEFAULT_CONFIG]
+    DEFAULT_CONFIG_NAME = "default"
+
+    def _info(self):
+        return datasets.DatasetInfo(
+            description=_DESCRIPTION,
+            features=_FEATURES,
+            supervised_keys=None,
+            homepage=_HOMEPAGE,
+            license=_LICENSE,
+            citation=_CITATION,
+        )
+
+    def _split_generators(self, dl_manager):
+        metadata_path = METADATA_DIR
+        target_dir = TARGET_DIR
+        source_dir = SOURCE_DIR
+        heatmap_dir = HEATMAP_DIR
+        depth_dir = DEPTH_DIR
+
+        return [
+            datasets.SplitGenerator(
+                name=datasets.Split.TRAIN,
+                # These kwargs will be passed to _generate_examples
+                gen_kwargs={
+                    "metadata_path": metadata_path,
+                    "target_dir": TARGET_DIR,
+                    "source_dir": SOURCE_DIR,
+                    "heatmap_dir": HEATMAP_DIR,
+                    "depth_dir": DEPTH_DIR,
+                    "num_examples": 120,
+                },
+            ),
+            datasets.SplitGenerator(
+                name=datasets.Split.VALIDATION,
+                # These kwargs will be passed to _generate_examples
+                gen_kwargs={
+                    "metadata_path": metadata_path,
+                    "target_dir": TARGET_DIR,
+                    "source_dir": SOURCE_DIR,
+                    "heatmap_dir": HEATMAP_DIR,
+                    "depth_dir": DEPTH_DIR,
+                    "num_examples": 120,
+                },
+            ),
+        ]
+
+
+
+    def _generate_examples(self, metadata_path, target_dir, source_dir, heatmap_dir, depth_dir, num_examples):
+        data = []
+        with open(metadata_path, 'rb') as f:
+            loaded_data = pickle.load(f)
+            for line in loaded_data[:num_examples]:
+                data.append(line)
+
+        for _, item in enumerate(data):
+            source_filename = item['source']
+            target_filename = item['target']
+            heatmap_filename = item['h_map']
+            depth_filename = item['depth']
+            prompt = item['prompt']
+
+
+
+            tgt_img = open(target_filename, "rb").read()
+            src_img = open(source_filename, "rb").read()
+            h_img = open(heatmap_filename, "rb").read()
+
+
+            # one channel depth image
+            d_img = Image.open(depth_filename).convert('1')
+
+
+
+            yield item["target"], {
+                "prompt": prompt,
+                "target": {
+                    "path": target_filename,
+                    "bytes": tgt_img,
+                },
+                "source": {
+                    "path": source_filename,
+                    "bytes": src_img,
+                },
+                "heatmap": {
+                    "path": heatmap_filename,
+                    "bytes": h_img,
+                },
+                "depth": {
+                    "path": depth_filename,
+                    "bytes": d_img,
+                },
+            }
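One detail worth flagging in this copy: `convert('1')` maps the depth image to Pillow mode "1", which is 1-bit bilevel (every pixel thresholded to 0 or 255), and the resulting `Image` object is then passed where the sibling script below passes raw bytes. If a single 8-bit depth channel is what is wanted, mode "L" preserves the gradations. A hedged sketch of the difference (the path is hypothetical):

from PIL import Image

depth = Image.open("depth_example.png")  # hypothetical depth-map file
bilevel = depth.convert("1")             # what this copy does: 0/255 only
gray = depth.convert("L")                # one 8-bit channel, keeps gradations
print(bilevel.getextrema(), gray.getextrema())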
src/cnd_prov/cnd_prov.py
ADDED
@@ -0,0 +1,138 @@
+#new one
+import pandas as pd
+import datasets
+import os
+import pickle
+
+_VERSION = datasets.Version("0.0.3")
+
+_DESCRIPTION = "TODO"
+_HOMEPAGE = "TODO"
+_LICENSE = "TODO"
+_CITATION = "TODO"
+
+_FEATURES = datasets.Features(
+    {
+        "target": datasets.Image(),
+        "source": datasets.Image(),
+        "heatmap": datasets.Image(),
+        "depth": datasets.Image(),
+        "prompt": datasets.Value("string"),
+    },
+)
+
+METADATA_DIR = "/fsx/proj-fmri/ckadirt/MindEyeV2/src/cnd_prov/data.pkl"
+SOURCE_DIR = "/fsx/proj-fmri/shared/controlNetData/source"
+TARGET_DIR = "/fsx/proj-fmri/shared/controlNetData/target"
+HEATMAP_DIR = "/fsx/proj-fmri/shared/controlNetData/seg"
+DEPTH_DIR = "/fsx/proj-fmri/shared/dinov2_depth"
+
+# METADATA_URL = hf_hub_url(
+#     "fusing/fill50k",
+#     filename="train.jsonl",
+#     repo_type="dataset",
+# )
+
+# IMAGES_URL = hf_hub_url(
+#     "fusing/fill50k",
+#     filename="images.zip",
+#     repo_type="dataset",
+# )
+
+# CONDITIONING_IMAGES_URL = hf_hub_url(
+#     "fusing/fill50k",
+#     filename="conditioning_images.zip",
+#     repo_type="dataset",
+# )
+
+_DEFAULT_CONFIG = datasets.BuilderConfig(name="default", version=_VERSION)
+
+
+class CocoTest(datasets.GeneratorBasedBuilder):
+    BUILDER_CONFIGS = [_DEFAULT_CONFIG]
+    DEFAULT_CONFIG_NAME = "default"
+
+    def _info(self):
+        return datasets.DatasetInfo(
+            description=_DESCRIPTION,
+            features=_FEATURES,
+            supervised_keys=None,
+            homepage=_HOMEPAGE,
+            license=_LICENSE,
+            citation=_CITATION,
+        )
+
+    def _split_generators(self, dl_manager):
+        metadata_path = METADATA_DIR
+        target_dir = TARGET_DIR
+        source_dir = SOURCE_DIR
+        heatmap_dir = HEATMAP_DIR
+        depth_dir = DEPTH_DIR
+
+        return [
+            datasets.SplitGenerator(
+                name=datasets.Split.TRAIN,
+                # These kwargs will be passed to _generate_examples
+                gen_kwargs={
+                    "metadata_path": metadata_path,
+                    "target_dir": TARGET_DIR,
+                    "source_dir": SOURCE_DIR,
+                    "heatmap_dir": HEATMAP_DIR,
+                    "depth_dir": DEPTH_DIR,
+                    "num_examples": 190573,
+                },
+            ),
+            datasets.SplitGenerator(
+                name=datasets.Split.VALIDATION,
+                # These kwargs will be passed to _generate_examples
+                gen_kwargs={
+                    "metadata_path": metadata_path,
+                    "target_dir": TARGET_DIR,
+                    "source_dir": SOURCE_DIR,
+                    "heatmap_dir": HEATMAP_DIR,
+                    "depth_dir": DEPTH_DIR,
+                    "num_examples": 20000,
+                },
+            ),
+        ]
+
+    def _generate_examples(self, metadata_path, target_dir, source_dir, heatmap_dir, depth_dir, num_examples):
+        data = []
+        with open(metadata_path, 'rb') as f:
+            loaded_data = pickle.load(f)
+            for line in loaded_data[:num_examples]:
+                data.append(line)
+
+        for _, item in enumerate(data):
+            source_filename = item['source']
+            target_filename = item['target']
+            heatmap_filename = item['h_map']
+            depth_filename = item['depth']
+            prompt = item['prompt']
+
+
+
+            tgt_img = open(target_filename, "rb").read()
+            src_img = open(source_filename, "rb").read()
+            h_img = open(heatmap_filename, "rb").read()
+            d_img = open(depth_filename, "rb").read()
+
+            yield item["target"], {
+                "prompt": prompt,
+                "target": {
+                    "path": target_filename,
+                    "bytes": tgt_img,
+                },
+                "source": {
+                    "path": source_filename,
+                    "bytes": src_img,
+                },
+                "heatmap": {
+                    "path": heatmap_filename,
+                    "bytes": h_img,
+                },
+                "depth": {
+                    "path": depth_filename,
+                    "bytes": d_img,
+                },
+            }
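Since this is a script-based `GeneratorBasedBuilder`, it would normally be consumed by pointing `datasets.load_dataset` at the script file; the data paths are hard-coded to the cluster's /fsx filesystem, so it only runs where those paths exist. Note also that both splits slice `loaded_data[:num_examples]` from the front of the same pickle, so the 20000 validation examples are a subset of the 190573 training examples. A sketch of the expected usage (an assumption, not shown in the upload):

from datasets import load_dataset

# Load the builder script directly; requires the /fsx paths to be reachable.
ds = load_dataset(
    "/fsx/proj-fmri/ckadirt/MindEyeV2/src/cnd_prov/cnd_prov.py",
    split="train",
)
print(ds[0]["prompt"], ds[0]["target"].size)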
src/cnd_prov/data.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5bab2a7c32b2da8c899f18015e17b88a876a3baa6f676a315d47e07c8713d313
+size 58375712
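data.pkl is checked in as a Git LFS pointer, so the oid/size lines above stand in for the ~58 MB payload. Judging from how the builder scripts index it, the pickle holds a list of dicts keyed by 'source', 'target', 'h_map', 'depth', and 'prompt'; a hedged peek, assuming the LFS object has been fetched:

import pickle

# The builders read this file as a list of per-example metadata dicts.
with open("src/cnd_prov/data.pkl", "rb") as f:
    records = pickle.load(f)
print(len(records))       # should cover the 190573-example train slice
print(records[0].keys())  # expected: source, target, h_map, depth, prompt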
src/deepspeed_config_stage1.json
ADDED
@@ -0,0 +1 @@
+{"bf16": {"enabled": false}, "fp16": {"enabled": true}, "zero_optimization": {"stage": 1, "contiguous_gradients": true, "stage3_gather_16bit_weights_on_model_save": true, "stage3_max_live_parameters": 1000000000.0, "stage3_max_reuse_distance": 1000000000.0, "stage3_prefetch_bucket_size": 10000000.0, "stage3_param_persistence_threshold": 100000.0, "reduce_bucket_size": 10000000.0, "sub_group_size": 1000000000.0, "offload_optimizer": {"device": "none", "nvme_path": "/scratch", "pin_memory": true}, "offload_param": {"device": "none", "nvme_path": "/scratch", "buffer_size": 4000000000.0, "pin_memory": true}}, "aio": {"block_size": 26214400, "queue_depth": 32, "thread_count": 1, "single_submit": false, "overlap_events": true}, "gradient_accumulation_steps": 1, "gradient_clipping": 1.0, "steps_per_print": 20000, "train_batch_size": 8, "train_micro_batch_size_per_gpu": 8, "wall_clock_breakdown": false, "zero_allow_untested_optimizer": true}
src/deepspeed_config_stage2.json
ADDED
@@ -0,0 +1 @@
+{"bf16": {"enabled": false}, "fp16": {"enabled": true}, "zero_optimization": {"stage": 2, "contiguous_gradients": true, "stage3_gather_16bit_weights_on_model_save": true, "stage3_max_live_parameters": 1000000000.0, "stage3_max_reuse_distance": 1000000000.0, "stage3_prefetch_bucket_size": 10000000.0, "stage3_param_persistence_threshold": 100000.0, "reduce_bucket_size": 10000000.0, "sub_group_size": 1000000000.0, "offload_optimizer": {"device": "none", "nvme_path": "/scratch", "pin_memory": true}, "offload_param": {"device": "none", "nvme_path": "/scratch", "buffer_size": 4000000000.0, "pin_memory": true}}, "aio": {"block_size": 26214400, "queue_depth": 32, "thread_count": 1, "single_submit": false, "overlap_events": true}, "gradient_accumulation_steps": 1, "gradient_clipping": 1.0, "steps_per_print": 20000, "train_batch_size": 128, "train_micro_batch_size_per_gpu": 128, "wall_clock_breakdown": false, "zero_allow_untested_optimizer": true}
src/deepspeed_config_stage3.json
ADDED
@@ -0,0 +1,44 @@
+{
+    "bf16": {
+        "enabled": false
+    },
+    "fp16": {
+        "enabled": true
+    },
+    "zero_optimization": {
+        "stage": 3,
+        "contiguous_gradients": true,
+        "stage3_gather_16bit_weights_on_model_save": true,
+        "stage3_max_live_parameters": 1e9,
+        "stage3_max_reuse_distance": 1e9,
+        "stage3_prefetch_bucket_size": 1e7,
+        "stage3_param_persistence_threshold": 1e5,
+        "reduce_bucket_size": 1e7,
+        "sub_group_size": 1e9,
+        "offload_optimizer": {
+            "device": "none",
+            "nvme_path": "/scratch",
+            "pin_memory": true
+        },
+        "offload_param": {
+            "device": "none",
+            "nvme_path": "/scratch",
+            "buffer_size": 4e9,
+            "pin_memory": true
+        }
+    },
+    "aio": {
+        "block_size": 26214400,
+        "queue_depth": 32,
+        "thread_count": 1,
+        "single_submit": false,
+        "overlap_events": true
+    },
+    "gradient_accumulation_steps": 1,
+    "gradient_clipping": 1.0,
+    "steps_per_print": 20000,
+    "train_batch_size": 128,
+    "train_micro_batch_size_per_gpu": 16,
+    "wall_clock_breakdown": false,
+    "zero_allow_untested_optimizer": true
+}
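A consistency note on these three configs: DeepSpeed requires train_batch_size = train_micro_batch_size_per_gpu × gradient_accumulation_steps × world_size, so each file implicitly fixes a GPU count (this stage-3 file: 128 = 16 × 1 × 8). A small sketch that checks the relation before launch (illustrative; launchers such as Accelerate may also rewrite these fields):

import json

with open("deepspeed_config_stage3.json") as f:
    cfg = json.load(f)

# Back out the world size the config assumes from DeepSpeed's batch identity.
world_size = cfg["train_batch_size"] // (
    cfg["train_micro_batch_size_per_gpu"] * cfg["gradient_accumulation_steps"]
)
print(world_size)  # -> 8: the GPU count this config implies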
src/getdepthimages.ipynb
ADDED
The diff for this file is too large to render.
src/huggingface_to_s3.ipynb
ADDED
@@ -0,0 +1,422 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "cf698d59-1cc2-4859-9c43-9a5d4d924ee1",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"# Transfer huggingface mindeyev2 dataset to Stability aws s3"
|
9 |
+
]
|
10 |
+
},
|
11 |
+
{
|
12 |
+
"cell_type": "code",
|
13 |
+
"execution_count": 1,
|
14 |
+
"id": "94c7404c-7a0f-4508-a630-954bc9af11fa",
|
15 |
+
"metadata": {},
|
16 |
+
"outputs": [
|
17 |
+
{
|
18 |
+
"name": "stdout",
|
19 |
+
"output_type": "stream",
|
20 |
+
"text": [
|
21 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/shared1000.npy -O /fsx/proj-fmri/shared/mindeyev2_dataset/shared1000.npy\n",
|
22 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/betas_all_subj01.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/betas_all_subj01.hdf5\n",
|
23 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/betas_all_subj02.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/betas_all_subj02.hdf5\n",
|
24 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/betas_all_subj03.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/betas_all_subj03.hdf5\n",
|
25 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/betas_all_subj04.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/betas_all_subj04.hdf5\n",
|
26 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/betas_all_subj05.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/betas_all_subj05.hdf5\n",
|
27 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/betas_all_subj06.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/betas_all_subj06.hdf5\n",
|
28 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/betas_all_subj07.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/betas_all_subj07.hdf5\n",
|
29 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/betas_all_subj08.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/betas_all_subj08.hdf5\n",
|
30 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/coco_images_224_float16.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/coco_images_224_float16.hdf5\n",
|
31 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/COCO_73k_subj_indices.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/COCO_73k_subj_indices.hdf5\n",
|
32 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/0.tar\n",
|
33 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/1.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/1.tar\n",
|
34 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/2.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/2.tar\n",
|
35 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/3.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/3.tar\n",
|
36 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/4.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/4.tar\n",
|
37 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/5.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/5.tar\n",
|
38 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/6.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/6.tar\n",
|
39 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/7.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/7.tar\n",
|
40 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/8.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/8.tar\n",
|
41 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/9.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/9.tar\n",
|
42 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/10.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/10.tar\n",
|
43 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/11.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/11.tar\n",
|
44 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/12.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/12.tar\n",
|
45 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/13.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/13.tar\n",
|
46 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/14.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/14.tar\n",
|
47 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/15.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/15.tar\n",
|
48 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/16.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/16.tar\n",
|
49 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/17.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/17.tar\n",
|
50 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/18.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/18.tar\n",
|
51 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/19.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/19.tar\n",
|
52 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/20.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/20.tar\n",
|
53 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/21.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/21.tar\n",
|
54 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/22.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/22.tar\n",
|
55 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/23.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/23.tar\n",
|
56 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/24.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/24.tar\n",
|
57 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/25.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/25.tar\n",
|
58 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/26.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/26.tar\n",
|
59 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/27.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/27.tar\n",
|
60 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/28.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/28.tar\n",
|
61 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/29.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/29.tar\n",
|
62 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/30.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/30.tar\n",
|
63 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/31.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/31.tar\n",
|
64 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/32.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/32.tar\n",
|
65 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/33.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/33.tar\n",
|
66 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/34.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/34.tar\n",
|
67 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/35.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/35.tar\n",
|
68 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/36.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/36.tar\n",
|
69 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/test/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/test/0.tar\n",
|
70 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/0.tar\n",
|
71 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/1.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/1.tar\n",
|
72 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/2.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/2.tar\n",
|
73 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/3.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/3.tar\n",
|
74 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/4.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/4.tar\n",
|
75 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/5.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/5.tar\n",
|
76 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/6.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/6.tar\n",
|
77 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/7.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/7.tar\n",
|
78 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/8.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/8.tar\n",
|
79 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/9.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/9.tar\n",
|
80 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/10.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/10.tar\n",
|
81 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/11.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/11.tar\n",
|
82 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/12.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/12.tar\n",
|
83 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/13.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/13.tar\n",
|
84 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/14.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/14.tar\n",
|
85 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/15.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/15.tar\n",
|
86 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/16.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/16.tar\n",
|
87 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/17.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/17.tar\n",
|
88 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/18.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/18.tar\n",
|
89 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/19.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/19.tar\n",
|
90 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/20.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/20.tar\n",
|
91 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/21.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/21.tar\n",
|
92 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/22.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/22.tar\n",
|
93 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/23.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/23.tar\n",
|
94 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/24.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/24.tar\n",
|
95 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/25.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/25.tar\n",
|
96 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/26.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/26.tar\n",
|
97 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/27.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/27.tar\n",
|
98 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/28.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/28.tar\n",
|
99 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/29.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/29.tar\n",
|
100 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/30.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/30.tar\n",
|
101 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/31.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/31.tar\n",
|
102 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/32.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/32.tar\n",
|
103 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/33.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/33.tar\n",
|
104 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/34.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/34.tar\n",
|
105 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/35.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/35.tar\n",
|
106 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/36.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/36.tar\n",
|
107 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/test/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/test/0.tar\n",
|
108 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/0.tar\n",
|
109 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/1.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/1.tar\n",
|
110 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/2.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/2.tar\n",
|
111 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/3.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/3.tar\n",
|
112 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/4.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/4.tar\n",
|
113 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/5.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/5.tar\n",
|
114 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/6.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/6.tar\n",
|
115 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/7.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/7.tar\n",
|
116 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/8.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/8.tar\n",
|
117 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/9.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/9.tar\n",
|
118 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/10.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/10.tar\n",
|
119 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/11.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/11.tar\n",
|
120 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/12.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/12.tar\n",
|
121 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/13.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/13.tar\n",
|
122 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/14.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/14.tar\n",
|
123 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/15.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/15.tar\n",
|
124 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/16.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/16.tar\n",
|
125 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/17.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/17.tar\n",
|
126 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/18.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/18.tar\n",
|
127 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/19.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/19.tar\n",
|
128 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/20.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/20.tar\n",
|
129 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/21.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/21.tar\n",
|
130 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/22.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/22.tar\n",
|
131 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/23.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/23.tar\n",
|
132 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/24.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/24.tar\n",
|
133 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/25.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/25.tar\n",
|
134 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/26.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/26.tar\n",
|
135 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/27.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/27.tar\n",
|
136 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/28.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/28.tar\n",
|
137 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/29.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/29.tar\n",
|
138 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/30.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/30.tar\n",
|
139 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/31.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/31.tar\n",
|
140 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/32.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/32.tar\n",
|
141 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/33.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/33.tar\n",
|
142 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/34.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/34.tar\n",
|
143 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/35.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/35.tar\n",
|
144 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/36.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/36.tar\n",
|
145 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/test/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/test/0.tar\n",
|
146 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/0.tar\n",
|
147 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/1.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/1.tar\n",
|
148 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/2.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/2.tar\n",
|
149 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/3.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/3.tar\n",
|
150 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/4.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/4.tar\n",
|
151 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/5.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/5.tar\n",
|
152 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/6.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/6.tar\n",
|
153 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/7.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/7.tar\n",
|
154 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/8.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/8.tar\n",
|
155 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/9.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/9.tar\n",
|
156 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/10.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/10.tar\n",
|
157 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/11.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/11.tar\n",
|
158 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/12.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/12.tar\n",
|
159 |
+
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/13.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/13.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/14.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/14.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/15.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/15.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/16.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/16.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/17.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/17.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/18.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/18.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/19.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/19.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/20.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/20.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/21.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/21.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/22.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/22.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/23.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/23.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/24.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/24.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/25.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/25.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/26.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/26.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/27.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/27.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/28.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/28.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/29.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/29.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/30.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/30.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/31.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/31.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/32.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/32.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/33.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/33.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/34.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/34.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/35.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/35.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/36.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/36.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/test/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/test/0.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/0.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/1.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/1.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/2.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/2.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/3.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/3.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/4.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/4.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/5.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/5.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/6.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/6.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/7.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/7.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/8.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/8.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/9.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/9.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/10.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/10.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/11.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/11.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/12.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/12.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/13.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/13.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/14.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/14.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/15.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/15.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/16.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/16.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/17.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/17.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/18.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/18.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/19.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/19.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/20.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/20.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/21.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/21.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/22.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/22.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/23.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/23.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/24.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/24.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/25.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/25.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/26.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/26.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/27.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/27.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/28.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/28.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/29.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/29.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/30.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/30.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/31.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/31.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/32.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/32.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/33.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/33.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/34.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/34.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/35.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/35.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/36.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/36.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/test/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/test/0.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/0.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/1.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/1.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/2.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/2.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/3.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/3.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/4.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/4.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/5.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/5.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/6.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/6.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/7.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/7.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/8.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/8.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/9.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/9.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/10.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/10.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/11.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/11.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/12.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/12.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/13.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/13.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/14.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/14.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/15.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/15.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/16.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/16.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/17.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/17.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/18.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/18.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/19.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/19.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/20.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/20.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/21.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/21.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/22.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/22.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/23.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/23.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/24.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/24.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/25.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/25.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/26.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/26.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/27.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/27.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/28.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/28.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/29.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/29.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/30.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/30.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/31.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/31.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/32.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/32.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/33.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/33.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/34.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/34.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/35.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/35.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/36.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/36.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/test/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/test/0.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/0.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/1.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/1.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/2.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/2.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/3.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/3.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/4.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/4.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/5.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/5.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/6.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/6.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/7.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/7.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/8.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/8.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/9.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/9.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/10.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/10.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/11.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/11.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/12.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/12.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/13.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/13.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/14.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/14.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/15.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/15.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/16.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/16.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/17.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/17.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/18.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/18.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/19.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/19.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/20.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/20.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/21.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/21.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/22.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/22.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/23.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/23.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/24.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/24.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/25.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/25.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/26.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/26.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/27.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/27.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/28.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/28.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/29.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/29.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/30.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/30.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/31.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/31.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/32.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/32.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/33.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/33.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/34.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/34.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/35.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/35.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/36.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/36.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/test/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/test/0.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/0.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/1.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/1.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/2.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/2.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/3.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/3.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/4.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/4.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/5.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/5.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/6.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/6.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/7.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/7.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/8.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/8.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/9.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/9.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/10.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/10.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/11.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/11.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/12.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/12.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/13.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/13.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/14.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/14.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/15.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/15.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/16.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/16.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/17.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/17.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/18.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/18.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/19.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/19.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/20.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/20.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/21.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/21.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/22.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/22.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/23.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/23.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/24.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/24.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/25.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/25.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/26.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/26.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/27.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/27.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/28.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/28.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/29.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/29.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/30.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/30.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/31.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/31.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/32.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/32.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/33.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/33.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/34.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/34.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/35.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/35.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/36.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/36.tar\n",
"wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/test/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/test/0.tar\n",
"aws s3 sync /scratch/mindeyev2_dataset s3://proj-fmri/mindeyev2_dataset --region us-west-2\n"
}
],
"source": [
"import os\n",
"# from subprocess import call\n",
"# PS Note: it's faster to print the wget statements and then manually copy-paste them all into the terminal than to use subprocess call()\n",
"tmp = '/fsx/proj-fmri/shared/mindeyev2_dataset/' #'/scratch/mindeyev2_dataset/'\n",
"\n",
"hf_base_link = 'https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/'\n",
"\n",
"os.makedirs(tmp,exist_ok=True)\n",
"\n",
"files = [\n",
"    \"shared1000.npy\",\n",
"    \"betas_all_subj01.hdf5\",\n",
"    \"betas_all_subj02.hdf5\",\n",
"    \"betas_all_subj03.hdf5\",\n",
"    \"betas_all_subj04.hdf5\",\n",
"    \"betas_all_subj05.hdf5\",\n",
"    \"betas_all_subj06.hdf5\",\n",
"    \"betas_all_subj07.hdf5\",\n",
"    \"betas_all_subj08.hdf5\",\n",
"    \"coco_images_224_float16.hdf5\",\n",
"    \"COCO_73k_subj_indices.hdf5\",\n",
"]\n",
"\n",
"for f in files: \n",
"    command = f\"wget --show-progress {hf_base_link}{f} -O {tmp}{f}\"\n",
"    print(command)\n",
"    # call(command,shell=True)\n",
"\n",
"for sub in range(1,9):\n",
"    subject = f'subj0{sub}'\n",
"\n",
"    tmp_fol = f'{tmp}wds/{subject}/'\n",
"    os.makedirs(tmp_fol,exist_ok=True)\n",
"    os.makedirs(tmp_fol+'train',exist_ok=True)\n",
"    os.makedirs(tmp_fol+'test',exist_ok=True)\n",
"\n",
"    for i in range(37):\n",
"        link = f'train/{i}.tar'\n",
"        command = f\"wget --show-progress {hf_base_link}wds/{subject}/{link} -O {tmp}wds/{subject}/{link}\"\n",
"        print(command)\n",
"        # call(command,shell=True)\n",
"\n",
"    link = f'test/0.tar'\n",
"    command = f\"wget --show-progress {hf_base_link}wds/{subject}/{link} -O {tmp}wds/{subject}/{link}\"\n",
"    print(command)\n",
"    # call(command,shell=True)\n",
"\n",
"command = \"aws s3 sync /scratch/mindeyev2_dataset s3://proj-fmri/mindeyev2_dataset --region us-west-2\"\n",
"print(command)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "30966082-59c2-411c-9b2e-4f4e3f9eb0f3",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
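
As a minimal alternative sketch (not part of the uploaded notebook; it assumes the pscotti/mindeyev2 dataset layout used above and a writable target directory), the same files can be pulled in one call with huggingface_hub instead of hand-running wget commands:

from huggingface_hub import snapshot_download

# Fetch the webdataset shards plus the hdf5/npy files referenced in the notebook
# into the shared dataset directory (assumed path, matching the wget targets above).
snapshot_download(
    repo_id="pscotti/mindeyev2",
    repo_type="dataset",
    local_dir="/fsx/proj-fmri/shared/mindeyev2_dataset",
    allow_patterns=["wds/*", "*.hdf5", "shared1000.npy"],
)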
src/models.py
ADDED
@@ -0,0 +1,1344 @@
import os
import numpy as np
from torchvision import transforms
import torch
import torch.nn as nn
import PIL
import clip
import open_clip
from functools import partial

# for prior
from dalle2_pytorch import DiffusionPrior
from dalle2_pytorch.dalle2_pytorch import l2norm, default, exists
from tqdm.auto import tqdm
import random
import json
from dalle2_pytorch.train_configs import DiffusionPriorNetworkConfig
# vd prior
from dalle2_pytorch.dalle2_pytorch import RotaryEmbedding, CausalTransformer, SinusoidalPosEmb, MLP, Rearrange, repeat, rearrange, prob_mask_like, LayerNorm, RelPosBias, FeedForward, Attention

# for pipeline
from diffusers import StableDiffusionImageVariationPipeline, VersatileDiffusionDualGuidedPipeline
from typing import Callable, List, Optional, Union

from diffusers.models.vae import Decoder

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# class BrainMLP(nn.Module):
#     def __init__(self, out_dim=257*768, in_dim=15724, clip_size=768, h=4096):
#         super().__init__()
#         self.lin0 = nn.Sequential(
#             nn.Linear(in_dim, h, bias=False),
#             nn.LayerNorm(h),
#             nn.GELU(inplace=True),
#             nn.Dropout(0.5))
#         self.mlp = nn.ModuleList([
#             nn.Sequential(
#                 nn.Linear(h, h),
#                 nn.LayerNorm(h),
#                 nn.GELU(inplace=True),
#                 nn.Dropout(0.15)
#             ) for _ in range(4)])
#         self.lin1 = nn.Linear(h, out_dim, bias=True)
#         self.proj = nn.Sequential(
#             nn.LayerNorm(clip_size),
#             nn.GELU(),
#             nn.Linear(clip_size, 2048),
#             nn.LayerNorm(2048),
#             nn.GELU(),
#             nn.Linear(2048, 2048),
#             nn.LayerNorm(2048),
#             nn.GELU(),
#             nn.Linear(2048, clip_size))
#     def forward(self, x):
#         x = self.lin0(x)
#         residual = x
#         for res_block in range(self.n_blocks):
#             x = self.mlp[res_block](x)
#             x += residual
#             residual = x
#         diffusion_prior_input = self.lin1(x.reshape(len(x), -1))
#         disjointed_clip_fmri = self.proj(diffusion_prior_input.reshape(
#             len(x),-1, self.clip_size))
#         return diffusion_prior_input, disjointed_clip_fmri


class Clipper(torch.nn.Module):
    def __init__(self, clip_variant, clamp_embs=False, norm_embs=False,
                 hidden_state=False, device=torch.device('cpu')):
        super().__init__()
        assert clip_variant in ("RN50", "ViT-L/14", "ViT-B/32", "RN50x64"), \
            "clip_variant must be one of RN50, ViT-L/14, ViT-B/32, RN50x64"
        print(clip_variant, device)

        if clip_variant=="ViT-L/14" and hidden_state:
            # from transformers import CLIPVisionModelWithProjection
            # image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14",cache_dir="/fsx/proj-medarc/fmri/cache")
            from transformers import CLIPVisionModelWithProjection
            sd_cache_dir = '/fsx/proj-fmri/shared/cache/models--shi-labs--versatile-diffusion/snapshots/2926f8e11ea526b562cd592b099fcf9c2985d0b7'
            image_encoder = CLIPVisionModelWithProjection.from_pretrained(sd_cache_dir, subfolder='image_encoder').eval()
            image_encoder = image_encoder.to(device)
            for param in image_encoder.parameters():
                param.requires_grad = False # don't need to calculate gradients
            self.image_encoder = image_encoder
        elif hidden_state:
            raise Exception("hidden_state embeddings only work with ViT-L/14 right now")

        clip_model, preprocess = clip.load(clip_variant, device=device)
        clip_model.eval() # don't want to train model
        for param in clip_model.parameters():
            param.requires_grad = False # don't need to calculate gradients

        self.clip = clip_model
        self.clip_variant = clip_variant
        if clip_variant == "RN50x64":
            self.clip_size = (448,448)
        else:
            self.clip_size = (224,224)

        preproc = transforms.Compose([
            transforms.Resize(size=self.clip_size[0], interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(size=self.clip_size),
            transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
        ])
        self.preprocess = preproc
        self.hidden_state = hidden_state
        self.mean = np.array([0.48145466, 0.4578275, 0.40821073])
        self.std = np.array([0.26862954, 0.26130258, 0.27577711])
        self.normalize = transforms.Normalize(self.mean, self.std)
        self.denormalize = transforms.Normalize((-self.mean / self.std).tolist(), (1.0 / self.std).tolist())
        self.clamp_embs = clamp_embs
        self.norm_embs = norm_embs
        self.device = device

        def versatile_normalize_embeddings(encoder_output):
            # closure over image_encoder: project the hidden states through the
            # vision model's final layernorm and visual projection
            embeds = encoder_output.last_hidden_state
            embeds = image_encoder.vision_model.post_layernorm(embeds)
            embeds = image_encoder.visual_projection(embeds)
            return embeds
        self.versatile_normalize_embeddings = versatile_normalize_embeddings

    def resize_image(self, image):
        # note: antialias should be False if planning to use Pinkney's Image Variation SD model
        return transforms.Resize(self.clip_size)(image.to(self.device))

    def embed_image(self, image):
        """Expects images in -1 to 1 range"""
        if self.hidden_state:
            # clip_emb = self.preprocess((image/1.5+.25).to(self.device)) # for some reason the /1.5+.25 prevents oversaturation
            clip_emb = self.preprocess((image).to(self.device))
            clip_emb = self.image_encoder(clip_emb)
            clip_emb = self.versatile_normalize_embeddings(clip_emb)
        else:
            clip_emb = self.preprocess(image.to(self.device))
            clip_emb = self.clip.encode_image(clip_emb)
        # input is now in CLIP space, but mind-reader preprint further processes embeddings:
        if self.clamp_embs:
            clip_emb = torch.clamp(clip_emb, -1.5, 1.5)
        if self.norm_embs:
            if self.hidden_state:
                # normalize all tokens by cls token's norm
                clip_emb = clip_emb / torch.norm(clip_emb[:, 0], dim=-1).reshape(-1, 1, 1)
            else:
                clip_emb = nn.functional.normalize(clip_emb, dim=-1)
        return clip_emb

    def embed_text(self, text_samples):
        clip_text = clip.tokenize(text_samples).to(self.device)
        clip_text = self.clip.encode_text(clip_text)
        if self.clamp_embs:
            clip_text = torch.clamp(clip_text, -1.5, 1.5)
        if self.norm_embs:
            clip_text = nn.functional.normalize(clip_text, dim=-1)
        return clip_text

    def embed_curated_annotations(self, annots):
        for i,b in enumerate(annots):
            t = ''
            while t == '':
                # pick a random non-empty annotation out of the 5 per image
                rand = torch.randint(5,(1,1))[0][0]
                t = b[0,rand]
            if i==0:
                txt = np.array(t)
            else:
                txt = np.vstack((txt,t))
        txt = txt.flatten()
        return self.embed_text(txt)
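
# Usage sketch (added; not part of the original file). Shapes assume the
# hidden-state ViT-L/14 path used for Versatile Diffusion conditioning:
#
#   clipper = Clipper("ViT-L/14", hidden_state=True, norm_embs=True,
#                     device=torch.device("cuda"))
#   imgs = torch.rand(4, 3, 224, 224) * 2 - 1   # embed_image expects [-1, 1] range
#   emb = clipper.embed_image(imgs)             # -> [4, 257, 768] token embeddings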
class OpenClipper(torch.nn.Module):
    def __init__(self, clip_variant, norm_embs=False, device=torch.device('cpu')):
        super().__init__()
        print(clip_variant, device)
        assert clip_variant == 'ViT-H-14' # not set up for other models yet

        clip_model, _, preprocess = open_clip.create_model_and_transforms('ViT-H-14',
            pretrained='laion2b_s32b_b79k', device=device)
        clip_model.eval() # don't want to train model
        for param in clip_model.parameters():
            param.requires_grad = False # don't need to calculate gradients

        # overwrite preprocess to accept torch inputs instead of PIL Image
        preprocess = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC, antialias=None),
            transforms.CenterCrop(224),
            transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
        ])

        tokenizer = open_clip.get_tokenizer('ViT-H-14')

        self.clip = clip_model
        self.norm_embs = norm_embs
        self.preprocess = preprocess
        self.tokenizer = tokenizer
        self.device = device

    def embed_image(self, image):
        """Expects images in -1 to 1 range"""
        image = self.preprocess(image).to(self.device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            image_features = self.clip.encode_image(image)
        if self.norm_embs:
            image_features = nn.functional.normalize(image_features, dim=-1)
        return image_features

    def embed_text(self, text_samples):
        text = self.tokenizer(text_samples).to(self.device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_features = self.clip.encode_text(text)
        if self.norm_embs:
            text_features = nn.functional.normalize(text_features, dim=-1)
        return text_features

    def embed_curated_annotations(self, annots):
        for i,b in enumerate(annots):
            t = ''
            while t == '':
                rand = torch.randint(5,(1,1))[0][0]
                t = b[0,rand]
            if i==0:
                txt = np.array(t)
            else:
                txt = np.vstack((txt,t))
        txt = txt.flatten()
        return self.embed_text(txt)
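
# Usage sketch (added; not part of the original file); ViT-H-14 text and image
# features are 1024-d:
#
#   open_clipper = OpenClipper('ViT-H-14', norm_embs=True, device=torch.device('cuda'))
#   txt_emb = open_clipper.embed_text(["a photo of a dog"])   # -> [1, 1024]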
class BrainNetwork(nn.Module):
    def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, use_projector=True, drop1=.5, drop2=.15):
        super().__init__()
        norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
        act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
        act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
        # self.temp = nn.Parameter(torch.tensor(.006))
        self.lin0 = nn.Sequential(
            nn.Linear(in_dim, h),
            *[item() for item in act_and_norm],
            nn.Dropout(drop1),
        )
        self.mlp = nn.ModuleList([
            nn.Sequential(
                nn.Linear(h, h),
                *[item() for item in act_and_norm],
                nn.Dropout(drop2)
            ) for _ in range(n_blocks)
        ])
        self.lin1 = nn.Linear(h, out_dim, bias=True)
        self.n_blocks = n_blocks
        self.clip_size = clip_size

        self.use_projector = use_projector
        if use_projector:
            self.projector = nn.Sequential(
                nn.LayerNorm(clip_size),
                nn.GELU(),
                nn.Linear(clip_size, 2048),
                nn.LayerNorm(2048),
                nn.GELU(),
                nn.Linear(2048, 2048),
                nn.LayerNorm(2048),
                nn.GELU(),
                nn.Linear(2048, clip_size)
            )

    def forward(self, x):
        '''
        bs, 1, 15724 -> bs, 32, h
        bs, 32, h -> bs, 32h
        bs, 32h -> bs, 768
        '''
        if x.ndim == 4:
            # case when we passed 3D data of shape [N, 81, 104, 83]
            assert x.shape[1] == 81 and x.shape[2] == 104 and x.shape[3] == 83
            # [N, 699192]
            x = x.reshape(x.shape[0], -1)
        x = self.lin0(x)  # bs, h
        residual = x
        for res_block in range(self.n_blocks):
            x = self.mlp[res_block](x)
            x += residual
            residual = x
        x = x.reshape(len(x), -1)
        x = self.lin1(x)
        if self.use_projector:
            return x, self.projector(x.reshape(len(x), -1, self.clip_size))
        return x
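
# Usage sketch (added; not part of the original file). With out_dim=257*768 this
# maps flattened voxels onto the 257x768 CLIP token space expected by the
# hidden-state Clipper above:
#
#   voxel2clip = BrainNetwork(in_dim=15724, out_dim=257*768, clip_size=768)
#   voxels = torch.randn(4, 15724)
#   backbone, projected = voxel2clip(voxels)   # [4, 257*768] and [4, 257, 768]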
class BrainDiffusionPriorOld(DiffusionPrior):
    """
    Differences from original:
    - Allow for passing of generators to torch random functions
    - Option to include the voxel2clip model and pass voxels into forward method
    - Return predictions when computing loss
    - Load pretrained model from @nousr trained on LAION aesthetics
    """
    def __init__(self, *args, **kwargs):
        voxel2clip = kwargs.pop('voxel2clip', None)
        super().__init__(*args, **kwargs)
        self.voxel2clip = voxel2clip

    @torch.no_grad()
    def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1.,
                 generator=None):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
        if generator is None:
            noise = torch.randn_like(x)
        else:
            # noise = torch.randn_like(x)
            noise = torch.randn(x.size(), device=x.device, dtype=x.dtype, generator=generator)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
        return pred, x_start
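
    # (Added note, not from the original file:) the p_sample update above is the
    # standard DDPM ancestral step, x_{t-1} = mu_theta(x_t, t) + sigma_t * eps,
    # where sigma_t = exp(0.5 * model_log_variance); nonzero_mask zeroes the noise
    # term at t == 0 so the final step returns the posterior mean.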
|
316 |
+
|
317 |
+
@torch.no_grad()
|
318 |
+
def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1., generator=None):
|
319 |
+
batch, device = shape[0], self.device
|
320 |
+
|
321 |
+
if generator is None:
|
322 |
+
image_embed = torch.randn(shape, device = device)
|
323 |
+
else:
|
324 |
+
image_embed = torch.randn(shape, device = device, generator=generator)
|
325 |
+
x_start = None # for self-conditioning
|
326 |
+
|
327 |
+
if self.init_image_embed_l2norm:
|
328 |
+
image_embed = l2norm(image_embed) * self.image_embed_scale
|
329 |
+
|
330 |
+
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps, disable=True):
|
331 |
+
times = torch.full((batch,), i, device = device, dtype = torch.long)
|
332 |
+
|
333 |
+
self_cond = x_start if self.net.self_cond else None
|
334 |
+
image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale,
|
335 |
+
generator=generator)
|
336 |
+
|
337 |
+
if self.sampling_final_clamp_l2norm and self.predict_x_start:
|
338 |
+
image_embed = self.l2norm_clamp_embed(image_embed)
|
339 |
+
|
340 |
+
return image_embed
|
341 |
+
|
342 |
+
def p_losses(self, image_embed, times, text_cond, noise = None):
|
343 |
+
noise = default(noise, lambda: torch.randn_like(image_embed))
|
344 |
+
|
345 |
+
image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
|
346 |
+
|
347 |
+
self_cond = None
|
348 |
+
if self.net.self_cond and random.random() < 0.5:
|
349 |
+
with torch.no_grad():
|
350 |
+
self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
|
351 |
+
|
352 |
+
pred = self.net(
|
353 |
+
image_embed_noisy,
|
354 |
+
times,
|
355 |
+
self_cond = self_cond,
|
356 |
+
text_cond_drop_prob = self.text_cond_drop_prob,
|
357 |
+
image_cond_drop_prob = self.image_cond_drop_prob,
|
358 |
+
**text_cond
|
359 |
+
)
|
360 |
+
|
361 |
+
if self.predict_x_start and self.training_clamp_l2norm:
|
362 |
+
pred = self.l2norm_clamp_embed(pred)
|
363 |
+
|
364 |
+
if self.predict_v:
|
365 |
+
target = self.noise_scheduler.calculate_v(image_embed, times, noise)
|
366 |
+
elif self.predict_x_start:
|
367 |
+
target = image_embed
|
368 |
+
else:
|
369 |
+
target = noise
|
370 |
+
|
371 |
+
loss = self.noise_scheduler.loss_fn(pred, target)
|
372 |
+
return loss, pred
|
373 |
+
|
374 |
+
    def forward(
        self,
        text = None,
        image = None,
        voxel = None,
        text_embed = None,      # allow for training on preprocessed CLIP text and image embeddings
        image_embed = None,
        text_encodings = None,  # as well as CLIP text encodings
        *args,
        **kwargs
    ):
        assert exists(text) ^ exists(text_embed) ^ exists(voxel), 'either text, text embedding, or voxel must be supplied'
        assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied'
        assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on them at initialization'

        if exists(voxel):
            assert exists(self.voxel2clip), 'voxel2clip must be trained if you wish to pass in voxels'
            assert not exists(text_embed), 'cannot pass in both text and voxels'
            text_embed = self.voxel2clip(voxel)

        if exists(image):
            image_embed, _ = self.clip.embed_image(image)

        # calculate text conditionings, based on what is passed in

        if exists(text):
            text_embed, text_encodings = self.clip.embed_text(text)

        text_cond = dict(text_embed = text_embed)

        if self.condition_on_text_encodings:
            assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
            text_cond = {**text_cond, 'text_encodings': text_encodings}

        # timestep conditioning from ddpm

        batch, device = image_embed.shape[0], image_embed.device
        times = self.noise_scheduler.sample_random_times(batch)

        # scale image embed (Katherine)

        image_embed *= self.image_embed_scale

        # calculate forward loss

        loss, pred = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)

        return loss, pred  # , text_embed
    @staticmethod
    def from_pretrained(net_kwargs={}, prior_kwargs={}, voxel2clip_path=None, ckpt_dir='./checkpoints'):
        # "https://huggingface.co/nousr/conditioned-prior/raw/main/vit-l-14/aesthetic/prior_config.json"
        config_url = os.path.join(ckpt_dir, "prior_config.json")
        config = json.load(open(config_url))

        config['prior']['net']['max_text_len'] = 256
        config['prior']['net'].update(net_kwargs)
        # print('net_config', config['prior']['net'])
        net_config = DiffusionPriorNetworkConfig(**config['prior']['net'])

        kwargs = config['prior']
        kwargs.pop('clip')
        kwargs.pop('net')
        kwargs.update(prior_kwargs)
        # print('prior_config', kwargs)

        diffusion_prior_network = net_config.create()
        diffusion_prior = BrainDiffusionPriorOld(net=diffusion_prior_network, clip=None, **kwargs).to(torch.device('cpu'))

        # 'https://huggingface.co/nousr/conditioned-prior/resolve/main/vit-l-14/aesthetic/best.pth'
        ckpt_url = os.path.join(ckpt_dir, 'best.pth')
        ckpt = torch.load(ckpt_url, map_location=torch.device('cpu'))

        # Note these keys will be missing (maybe due to an update to the code since training):
        # "net.null_text_encodings", "net.null_text_embeds", "net.null_image_embed"
        # I don't think these get used if `cond_drop_prob = 0` though (which is the default here)
        diffusion_prior.load_state_dict(ckpt, strict=False)
        # keys = diffusion_prior.load_state_dict(ckpt, strict=False)
        # print("missing keys in prior checkpoint (probably ok)", keys.missing_keys)

        if voxel2clip_path:
            # load the voxel2clip weights
            checkpoint = torch.load(voxel2clip_path, map_location=torch.device('cpu'))

            state_dict = checkpoint['model_state_dict']
            for key in list(state_dict.keys()):
                if 'module.' in key:
                    state_dict[key.replace('module.', '')] = state_dict[key]
                    del state_dict[key]
            diffusion_prior.voxel2clip.load_state_dict(state_dict)

        return diffusion_prior
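
# Usage sketch for from_pretrained (a minimal example, assuming `prior_config.json`
# and `best.pth` have already been downloaded from the nousr/conditioned-prior repo
# into ./checkpoints; the voxel2clip checkpoint filename is hypothetical):
#
#     prior = BrainDiffusionPriorOld.from_pretrained(
#         prior_kwargs={'condition_on_text_encodings': False},
#         voxel2clip_path='./checkpoints/voxel2clip_subj01.pth',
#     )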
class BrainDiffusionPrior(DiffusionPrior):
    """
    Differences from original:
    - Allow for passing of generators to torch random functions
    - Option to include the voxel2clip model and pass voxels into forward method
    - Return predictions when computing loss
    - Load pretrained model from @nousr trained on LAION aesthetics
    """
    def __init__(self, *args, **kwargs):
        voxel2clip = kwargs.pop('voxel2clip', None)
        super().__init__(*args, **kwargs)
        self.voxel2clip = voxel2clip
    @torch.no_grad()
    def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1.,
                 generator=None):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
        if generator is None:
            noise = torch.randn_like(x)
        else:
            # noise = torch.randn_like(x)
            noise = torch.randn(x.size(), device=x.device, dtype=x.dtype, generator=generator)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
        return pred, x_start
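    # The only change from the upstream p_sample is the optional `generator`, which
    # makes DDPM sampling reproducible. A minimal sketch of the idea:
    #
    #     g = torch.Generator().manual_seed(42)
    #     a = torch.randn(4, 768, generator=g)
    #     g = torch.Generator().manual_seed(42)
    #     b = torch.randn(4, 768, generator=g)
    #     assert torch.equal(a, b)  # same seed -> identical sampling noise
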
    @torch.no_grad()
    def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1., generator=None):
        batch, device = shape[0], self.device

        if generator is None:
            image_embed = torch.randn(shape, device = device)
        else:
            image_embed = torch.randn(shape, device = device, generator=generator)
        x_start = None  # for self-conditioning

        if self.init_image_embed_l2norm:
            image_embed = l2norm(image_embed) * self.image_embed_scale

        for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps, disable=True):
            times = torch.full((batch,), i, device = device, dtype = torch.long)

            self_cond = x_start if self.net.self_cond else None
            image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale,
                                                 generator=generator)

        if self.sampling_final_clamp_l2norm and self.predict_x_start:
            image_embed = self.l2norm_clamp_embed(image_embed)

        return image_embed
    def p_losses(self, image_embed, times, text_cond, noise = None):
        noise = default(noise, lambda: torch.randn_like(image_embed))

        image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)

        self_cond = None
        if self.net.self_cond and random.random() < 0.5:
            with torch.no_grad():
                self_cond = self.net(image_embed_noisy, times, **text_cond).detach()

        pred = self.net(
            image_embed_noisy,
            times,
            self_cond = self_cond,
            text_cond_drop_prob = self.text_cond_drop_prob,
            image_cond_drop_prob = self.image_cond_drop_prob,
            **text_cond
        )

        if self.predict_x_start and self.training_clamp_l2norm:
            pred = self.l2norm_clamp_embed(pred)

        if self.predict_v:
            target = self.noise_scheduler.calculate_v(image_embed, times, noise)
        elif self.predict_x_start:
            target = image_embed
        else:
            target = noise

        loss = self.noise_scheduler.loss_fn(pred, target)
        return loss, pred
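    # For reference, the v-prediction target used above follows Salimans & Ho (2022),
    # "Progressive Distillation": v_t = sqrt(alphabar_t) * noise - sqrt(1 - alphabar_t) * x_start.
    # A standalone sketch for 2-D embeddings, assuming `alphas_cumprod` is the
    # scheduler's cumulative product of alphas indexed by integer timesteps:
    #
    #     def calculate_v(x_start, t, noise, alphas_cumprod):
    #         a = alphas_cumprod[t].sqrt().view(-1, 1)
    #         s = (1 - alphas_cumprod[t]).sqrt().view(-1, 1)
    #         return a * noise - s * x_start
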
    def forward(
        self,
        text = None,
        image = None,
        voxel = None,
        text_embed = None,      # allow for training on preprocessed CLIP text and image embeddings
        image_embed = None,
        text_encodings = None,  # as well as CLIP text encodings
        *args,
        **kwargs
    ):
        assert exists(text) ^ exists(text_embed) ^ exists(voxel), 'either text, text embedding, or voxel must be supplied'
        assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied'
        assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on them at initialization'

        if exists(voxel):
            assert exists(self.voxel2clip), 'voxel2clip must be trained if you wish to pass in voxels'
            assert not exists(text_embed), 'cannot pass in both text and voxels'
            if self.voxel2clip.use_projector:
                clip_voxels_mse, clip_voxels = self.voxel2clip(voxel)
                text_embed = clip_voxels_mse
            else:
                clip_voxels = self.voxel2clip(voxel)
                text_embed = clip_voxels_mse = clip_voxels
            # text_embed = self.voxel2clip(voxel)

        if exists(image):
            image_embed, _ = self.clip.embed_image(image)

        # calculate text conditionings, based on what is passed in

        if exists(text):
            text_embed, text_encodings = self.clip.embed_text(text)

        text_cond = dict(text_embed = text_embed)

        if self.condition_on_text_encodings:
            assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
            text_cond = {**text_cond, 'text_encodings': text_encodings}

        # timestep conditioning from ddpm

        batch, device = image_embed.shape[0], image_embed.device
        times = self.noise_scheduler.sample_random_times(batch)

        # PS: I don't think we need this here? Also, if uncommented, it mutates the
        # caller's tensor in place.
        # scale image embed (Katherine)
        # image_embed *= self.image_embed_scale

        # calculate forward loss

        loss, pred = self.p_losses(image_embed*self.image_embed_scale, times, text_cond = text_cond, *args, **kwargs)

        # scaling is applied only inside the p_losses call above, so the caller's
        # image_embed stays unscaled and can be used directly for the real MSE loss
        # and reconstruction
        return loss, pred
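
# Usage sketch (a minimal, hypothetical example; `prior_network` and `voxel2clip`
# stand in for a configured VersatileDiffusionPriorNetwork and BrainNetwork):
#
#     prior = BrainDiffusionPrior(net=prior_network, voxel2clip=voxel2clip, ...)
#     loss, pred = prior(text_embed=voxel_embeds, image_embed=clip_image_embeds)
#     # when the prior predicts x_start, `pred` lives on the manifold scaled by
#     # image_embed_scale; divide by prior.image_embed_scale before comparing it
#     # against raw CLIP embeddings.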
class BrainSD(StableDiffusionImageVariationPipeline):
    """
    Differences from original:
    - Keep generated images on GPU and return tensors
    - No NSFW checker
    - Can pass in image or image_embedding to generate a variation
    NOTE: requires a recent version of diffusers, otherwise the latent dims come out wrong.
    """

    def decode_latents(self, latents):
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        # image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image
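    # 0.18215 is Stable Diffusion's VAE scaling factor (exposed in newer diffusers
    # as `self.vae.config.scaling_factor`); encoding and decoding are inverses up to
    # that constant. A sketch of the convention, assuming `vae` is the pipeline's
    # AutoencoderKL:
    #
    #     latents = vae.encode(images).latent_dist.sample() * 0.18215   # to latent space
    #     recon   = vae.decode(latents / 0.18215).sample                # back to pixel space
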
    @torch.no_grad()
    def __call__(
        self,
        image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        image_embeddings: Optional[torch.FloatTensor] = None,
    ):

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        device = self._execution_device
        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier-free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        if image_embeddings is None:
            assert image is not None, "If image_embeddings is None, image must not be None"

            # resize and normalize the way that's recommended
            tform = transforms.Compose([
                # transforms.ToTensor(),  # don't need this since we've already got tensors
                transforms.Resize(
                    (224, 224),
                    interpolation=transforms.InterpolationMode.BICUBIC,
                    antialias=False,
                ),
                transforms.Normalize(
                    [0.48145466, 0.4578275, 0.40821073],
                    [0.26862954, 0.26130258, 0.27577711]),
            ])
            image = tform(image)

            # 1. Check inputs. Raise error if not correct
            self.check_inputs(image, height, width, callback_steps)

            # 2. Define call parameters
            if isinstance(image, PIL.Image.Image):
                batch_size = 1
            elif isinstance(image, list):
                batch_size = len(image)
            else:
                batch_size = image.shape[0]

            # 3. Encode input image
            image_embeddings = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance)
        else:
            batch_size = image_embeddings.shape[0] // 2

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels

        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier-free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # 8. Post-processing
        image = self.decode_latents(latents)

        # # 9. Run safety checker
        # image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)

        # # 10. Convert to PIL
        # if output_type == "pil":
        #     image = self.numpy_to_pil(image)

        # if not return_dict:
        #     return (image, has_nsfw_concept)

        # return StableDiffusionPipelineOutput(images=image)

        return image
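
# Usage sketch (hypothetical variable names): when bypassing the image encoder via
# `image_embeddings`, the embeddings are assumed to already be doubled along the
# batch dimension for classifier-free guidance, with an unconditional half of zeros:
#
#     emb = image_embeds.unsqueeze(1)                              # (n, 1, 768) CLIP image embeds
#     image_embeddings = torch.cat([torch.zeros_like(emb), emb])   # uncond + cond halves
#     imgs = sd_pipe(image_embeddings=image_embeddings, guidance_scale=7.5)  # (n, 3, H, W) in [0, 1]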
class Voxel2StableDiffusionModel(torch.nn.Module):
    def __init__(self, in_dim=15724, h=4096, n_blocks=4, use_cont=False):
        super().__init__()
        self.lin0 = nn.Sequential(
            nn.Linear(in_dim, h, bias=False),
            nn.LayerNorm(h),
            nn.SiLU(inplace=True),
            nn.Dropout(0.5),
        )

        self.mlp = nn.ModuleList([
            nn.Sequential(
                nn.Linear(h, h, bias=False),
                nn.LayerNorm(h),
                nn.SiLU(inplace=True),
                nn.Dropout(0.25)
            ) for _ in range(n_blocks)
        ])
        self.lin1 = nn.Linear(h, 16384, bias=False)
        self.norm = nn.LayerNorm(512)

        self.register_parameter('queries', nn.Parameter(torch.randn(1, 256, 512) * 0.044))
        self.transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=512, nhead=8, norm_first=True,
                                       dim_feedforward=1024, activation=nn.functional.gelu,
                                       batch_first=True, dropout=0.25),
            num_layers=n_blocks
        )

        # option 1 -> 124.56M
        # self.lin1 = nn.Linear(h, 32768, bias=True)
        # self.upsampler = Decoder(
        #     in_channels=64,
        #     out_channels=4,
        #     up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
        #     block_out_channels=[64, 128, 256, 256],
        #     layers_per_block=1,
        # )

        # option 2 -> 132.52M
        # self.lin1 = nn.Linear(h, 1024, bias=True)
        # self.upsampler = Decoder(
        #     in_channels=64,
        #     out_channels=4,
        #     up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D", "UpDecoderBlock2D"],
        #     block_out_channels=[64, 128, 256, 256, 512],
        #     layers_per_block=1,
        # )

        if use_cont:
            self.maps_projector = nn.Sequential(
                nn.LayerNorm(512),
                nn.Linear(512, 512),
                nn.LayerNorm(512),
                nn.ReLU(True),
                nn.Linear(512, 512),
                nn.LayerNorm(512),
                nn.ReLU(True),
                nn.Linear(512, 512)
            )
        else:
            self.maps_projector = nn.Identity()

        self.upsampler = nn.Sequential(
            nn.GroupNorm(1, 32),
            nn.SiLU(inplace=True),
            nn.Conv2d(32, 320, 3, padding=1),
            nn.GroupNorm(32, 320),
            nn.SiLU(inplace=True),
            nn.Conv2d(320, 320, 3, padding=1),
            nn.GroupNorm(32, 320),
            nn.SiLU(inplace=True),
            nn.Conv2d(320, 4, 3, padding=1)
        )

    def forward(self, x, return_transformer_feats=False):
        x = self.lin0(x)
        residual = x
        for res_block in self.mlp:
            x = res_block(x)
            x = x + residual
            residual = x
        x = x.reshape(len(x), -1)
        x = self.lin1(x)  # bs, 16384

        # # x = x.reshape(x.shape[0], -1, 8, 8).contiguous()  # bs, 64, 8, 8
        # x = x.reshape(x.shape[0], -1, 64, 64).contiguous()
        # return self.upsampler(x)

        # decoder
        x = self.norm(x.reshape(x.shape[0], 32, 512))
        preds = self.transformer(self.queries.expand(x.shape[0], -1, -1), x)
        sd_embeds = preds.permute(0, 2, 1).reshape(-1, 512, 16, 16)
        sd_embeds = nn.functional.pixel_shuffle(sd_embeds, 4)  # bs, 32, 64, 64

        # contrastive
        if return_transformer_feats:
            return self.upsampler(sd_embeds), self.maps_projector(preds)

        return self.upsampler(sd_embeds)
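
# Shape walkthrough for Voxel2StableDiffusionModel.forward with the defaults
# (batch size B), spelled out since several steps reshape implicitly:
#
#   voxels (B, 15724) -> lin0 + residual MLP (B, 4096) -> lin1 (B, 16384)
#   -> reshape/norm (B, 32, 512) -> 256 transformer queries (B, 256, 512)
#   -> permute/reshape (B, 512, 16, 16) -> pixel_shuffle(4) (B, 32, 64, 64)
#   -> upsampler (B, 4, 64, 64), i.e. Stable Diffusion's latent resolution.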
class BrainNetworkDETR(BrainNetwork):
    # 133M
    def __init__(self, out_dim=768, in_dim=15724, h=4096, n_blocks=4, norm_type='ln', act_first=False,
                 encoder_tokens=32, decoder_tokens=257):
        # encoder
        super().__init__(out_dim*encoder_tokens, in_dim, h, n_blocks, norm_type, act_first)
        self.norm = nn.LayerNorm(out_dim)
        self.encoder_tokens = encoder_tokens

        self.register_parameter('queries', nn.Parameter(torch.randn(1, decoder_tokens, out_dim)))
        self.transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=out_dim, nhead=8,
                                       dim_feedforward=1024,
                                       batch_first=True, dropout=0.25),
            num_layers=n_blocks
        )
        self.decoder_projector = nn.Sequential(
            nn.LayerNorm(out_dim),
            nn.Linear(out_dim, out_dim)
        )

    def forward(self, x):
        enc = super().forward(x)
        enc = self.norm(enc.reshape(enc.shape[0], self.encoder_tokens, -1))

        dec = self.transformer(self.queries.expand(x.shape[0], -1, -1), enc)
        dec = self.decoder_projector(dec)
        return dec
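
# Shape sketch: the inherited BrainNetwork MLP emits a flat (B, out_dim*encoder_tokens)
# vector that is reshaped into `encoder_tokens` memory tokens; the DETR-style decoder
# then cross-attends `decoder_tokens` learned queries against them. With the defaults
# this matches CLIP ViT-L/14's 256 patch tokens + 1 class token:
#
#     net = BrainNetworkDETR(out_dim=768, in_dim=15724, encoder_tokens=32, decoder_tokens=257)
#     out = net(torch.randn(2, 15724))   # -> (2, 257, 768)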
class VersatileDiffusionPriorNetwork(nn.Module):
    def __init__(
        self,
        dim,
        num_timesteps = None,
        num_time_embeds = 1,
        # num_image_embeds = 1,
        # num_brain_embeds = 1,
        num_tokens = 257,
        causal = True,
        learned_query_mode = 'none',
        **kwargs
    ):
        super().__init__()
        self.dim = dim
        self.num_time_embeds = num_time_embeds
        self.continuous_embedded_time = not exists(num_timesteps)
        self.learned_query_mode = learned_query_mode

        self.to_time_embeds = nn.Sequential(
            nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)),  # also offer a continuous version of timestep embeddings, with a 2-layer MLP
            Rearrange('b (n d) -> b n d', n = num_time_embeds)
        )

        if self.learned_query_mode == 'token':
            self.learned_query = nn.Parameter(torch.randn(num_tokens, dim))
        if self.learned_query_mode == 'pos_emb':
            scale = dim ** -0.5
            self.learned_query = nn.Parameter(torch.randn(num_tokens, dim) * scale)
        if self.learned_query_mode == 'all_pos_emb':
            scale = dim ** -0.5
            self.learned_query = nn.Parameter(torch.randn(num_tokens*2+1, dim) * scale)
        self.causal_transformer = FlaggedCausalTransformer(dim = dim, causal=causal, **kwargs)

        self.null_brain_embeds = nn.Parameter(torch.randn(num_tokens, dim))
        self.null_image_embed = nn.Parameter(torch.randn(num_tokens, dim))

        self.num_tokens = num_tokens
        self.self_cond = False

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 1.,
        **kwargs
    ):
        logits = self.forward(*args, **kwargs)

        if cond_scale == 1:
            return logits

        null_logits = self.forward(*args, brain_cond_drop_prob = 1., image_cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

    def forward(
        self,
        image_embed,
        diffusion_timesteps,
        *,
        self_cond=None,
        brain_embed=None,
        text_embed=None,
        brain_cond_drop_prob = 0.,
        text_cond_drop_prob = None,
        image_cond_drop_prob = 0.
    ):
        if text_embed is not None:
            brain_embed = text_embed
        if text_cond_drop_prob is not None:
            brain_cond_drop_prob = text_cond_drop_prob

        image_embed = image_embed.view(len(image_embed), -1, 768)
        # text_embed = text_embed.view(len(text_embed), -1, 768)
        brain_embed = brain_embed.view(len(brain_embed), -1, 768)
        # print(*image_embed.shape)
        # print(*image_embed.shape, image_embed.device, image_embed.dtype)

        batch, _, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
        # num_time_embeds, num_image_embeds, num_brain_embeds = self.num_time_embeds, self.num_image_embeds, self.num_brain_embeds

        # classifier-free guidance masks
        brain_keep_mask = prob_mask_like((batch,), 1 - brain_cond_drop_prob, device = device)
        brain_keep_mask = rearrange(brain_keep_mask, 'b -> b 1 1')

        image_keep_mask = prob_mask_like((batch,), 1 - image_cond_drop_prob, device = device)
        image_keep_mask = rearrange(image_keep_mask, 'b -> b 1 1')

        # mask out brain embeddings with null brain embeddings

        # import pdb; pdb.set_trace()
        null_brain_embeds = self.null_brain_embeds.to(brain_embed.dtype)
        brain_embed = torch.where(
            brain_keep_mask,
            brain_embed,
            null_brain_embeds[None]
        )

        # mask out image embeddings with null image embeddings
        null_image_embed = self.null_image_embed.to(image_embed.dtype)
        image_embed = torch.where(
            image_keep_mask,
            image_embed,
            null_image_embed[None]
        )

        # whether brain embedding is used for conditioning depends on whether brain encodings are available for attention (for classifier-free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
        # but let's just do it right
        if self.continuous_embedded_time:
            # if continuous, cast to float; else keep int for indexing embeddings
            diffusion_timesteps = diffusion_timesteps.type(dtype)
        time_embed = self.to_time_embeds(diffusion_timesteps)

        if self.learned_query_mode == 'token':
            learned_queries = repeat(self.learned_query, 'n d -> b n d', b = batch)
        elif self.learned_query_mode == 'pos_emb':
            pos_embs = repeat(self.learned_query, 'n d -> b n d', b = batch)
            image_embed = image_embed + pos_embs
            learned_queries = torch.empty((batch, 0, dim), device=brain_embed.device)
        elif self.learned_query_mode == 'all_pos_emb':
            pos_embs = repeat(self.learned_query, 'n d -> b n d', b = batch)
            learned_queries = torch.empty((batch, 0, dim), device=brain_embed.device)
        else:
            learned_queries = torch.empty((batch, 0, dim), device=brain_embed.device)

        tokens = torch.cat((
            brain_embed,     # 257
            time_embed,      # 1
            image_embed,     # 257
            learned_queries  # 257 (or 0)
        ), dim = -2)
        if self.learned_query_mode == 'all_pos_emb':
            tokens = tokens + pos_embs

        # attend
        tokens = self.causal_transformer(tokens)

        # get learned query, which should predict the image embedding (per DDPM timestep)
        pred_image_embed = tokens[..., -self.num_tokens:, :]

        return pred_image_embed
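
# Token layout fed to the causal transformer, per sample (with the defaults,
# num_tokens = 257):
#
#   [ brain_embed (257) | time_embed (1) | image_embed (257) | learned_queries (0 or 257) ]
#
# The final num_tokens positions are read out as the denoised image-embedding
# prediction, and forward_with_cond_scale applies standard classifier-free guidance:
#
#     out = null_logits + (logits - null_logits) * cond_scale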
# import math
# from collections import namedtuple
# from einops import rearrange, repeat, reduce, pack, unpack
# from einops.layers.torch import Rearrange
# from torch import einsum

# class Attention(nn.Module):
#     def __init__(
#         self,
#         dim,
#         *,
#         dim_head = 64,
#         heads = 8,
#         dropout = 0.,
#         causal = False,
#         rotary_emb = None,
#         cosine_sim = True,
#         cosine_sim_scale = 16
#     ):
#         super().__init__()
#         self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)
#         self.cosine_sim = cosine_sim
#
#         self.heads = heads
#         inner_dim = dim_head * heads
#         self.dim = dim
#         self.inner_dim = inner_dim
#
#         self.causal = causal
#         self.norm = LayerNorm(dim)
#         self.dropout = nn.Dropout(dropout)
#
#         self.null_kv = nn.Parameter(torch.randn(2, dim_head))
#         self.to_q = nn.Linear(dim, inner_dim, bias = False)
#         self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
#
#         self.rotary_emb = rotary_emb
#
#         self.to_out = nn.Sequential(
#             nn.Linear(inner_dim, dim, bias = False),
#             LayerNorm(dim)
#         )
#
#     def forward(self, x, mask = None, attn_bias = None):
#         b, n, device = *x.shape[:2], x.device
#         print("xinit", torch.any(torch.isnan(x)))
#         x = self.norm(x)
#         print("xnorm", torch.any(torch.isnan(x)))
#         print("xnorm.shape", x.shape)
#
#         q = self.to_q(x)
#         print("q0", torch.any(torch.isnan(q)))
#
#         k, v = self.to_kv(x).chunk(2, dim = -1)
#         print("k0", torch.any(torch.isnan(k)))
#         print("v0", torch.any(torch.isnan(v)))
#
#         q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
#         q = q * self.scale
#
#         # rotary embeddings
#
#         if exists(self.rotary_emb):
#             q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))
#
#         # add null key / value for classifier free guidance in prior net
#
#         nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))
#         k = torch.cat((nk, k), dim = -2)
#         v = torch.cat((nv, v), dim = -2)
#
#         # whether to use cosine sim
#
#         if self.cosine_sim:
#             q, k = map(l2norm, (q, k))
#
#         q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))
#         print("q2", torch.any(torch.isnan(q)))
#         print("k2", torch.any(torch.isnan(k)))
#
#         # calculate query / key similarities
#
#         sim = einsum('b h i d, b j d -> b h i j', q, k)
#
#         # relative positional encoding (T5 style)
#
#         if exists(attn_bias):
#             sim = sim + attn_bias
#
#         # masking
#
#         max_neg_value = -torch.finfo(sim.dtype).max
#
#         print("sim1", torch.any(torch.isnan(sim)))
#
#         if exists(mask):
#             mask = F.pad(mask, (1, 0), value = True)
#             mask = rearrange(mask, 'b j -> b 1 1 j')
#             sim = sim.masked_fill(~mask, max_neg_value)
#
#         print("sim2", torch.any(torch.isnan(sim)))
#
#         if self.causal:
#             i, j = sim.shape[-2:]
#             causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
#             sim = sim.masked_fill(causal_mask, max_neg_value)
#
#         # attention
#
#         print("simFinal", torch.any(torch.isnan(sim)))
#         attn = sim.softmax(dim = -1, dtype = torch.float32)
#         print("attn", torch.any(torch.isnan(attn)))
#         attn = attn.type(sim.dtype)
#
#         attn = self.dropout(attn)
#
#         # aggregate values
#
#         out = einsum('b h i j, b j d -> b h i d', attn, v)
#
#         out = rearrange(out, 'b h n d -> b n (h d)')
#         return self.to_out(out)
class FlaggedCausalTransformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        norm_in = False,
        norm_out = True,
        attn_dropout = 0.,
        ff_dropout = 0.,
        final_proj = True,
        normformer = False,
        rotary_emb = False,
        causal=True
    ):
        super().__init__()
        self.init_norm = LayerNorm(dim) if norm_in else nn.Identity()  # from latest BLOOM model and Yandex's YaLM

        self.rel_pos_bias = RelPosBias(heads = heads)

        rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, causal = causal, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
            ]))

        self.norm = LayerNorm(dim, stable = True) if norm_out else nn.Identity()  # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
        self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()

    def forward(self, x):
        n, device = x.shape[1], x.device

        x = self.init_norm(x)

        attn_bias = self.rel_pos_bias(n, n + 1, device = device)

        for attn, ff in self.layers:
            x = attn(x, attn_bias = attn_bias) + x
            x = ff(x) + x

        out = self.norm(x)
        return self.project_out(out)
class BrainVD(VersatileDiffusionDualGuidedPipeline):
    """
    Differences from original:
    - Keep generated images on GPU and return tensors
    - No NSFW checker
    - Can pass in image or image_embedding to generate a variation
    NOTE: requires a recent version of diffusers, otherwise the latent dims come out wrong.
    """

    def decode_latents(self, latents):
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        # image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def check_inputs(self, prompt, image, height, width, callback_steps):
        if prompt is not None and not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type None, `str` or `list` but is {type(prompt)}")
        if image is not None and not isinstance(image, PIL.Image.Image) and not isinstance(image, list):
            raise ValueError(f"`image` has to be of type None, `PIL.Image` or `list` but is {type(image)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image]]] = None,
        text_to_image_strength: float = 0.5,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        image_embeddings: Optional[torch.FloatTensor] = None,
        prompt_embeddings: Optional[torch.FloatTensor] = None,
        **kwargs,
    ):

        height = height or self.image_unet.config.sample_size * self.vae_scale_factor
        width = width or self.image_unet.config.sample_size * self.vae_scale_factor

        self.check_inputs(prompt, image, height, width, callback_steps)

        prompt = [prompt] if prompt is not None and not isinstance(prompt, list) else prompt
        image = [image] if image is not None and not isinstance(image, list) else image
        device = self._execution_device
        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier-free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        if image_embeddings is None:
            if image is not None:
                image_embeddings = self._encode_image_prompt(
                    image, device, num_images_per_prompt, do_classifier_free_guidance
                )
                batch_size = len(image)
            else:
                image_embeddings = None

        if prompt_embeddings is None:
            if prompt is not None:
                prompt_embeddings = self._encode_text_prompt(
                    prompt, device, num_images_per_prompt, do_classifier_free_guidance
                )
                batch_size = len(prompt)
            else:
                prompt_embeddings = None
        if image_embeddings is not None:
            batch_size = image_embeddings.shape[0] // 2
        elif prompt_embeddings is not None:
            batch_size = prompt_embeddings.shape[0] // 2

        if image_embeddings is not None and prompt_embeddings is not None:
            dual_prompt_embeddings = torch.cat([prompt_embeddings, image_embeddings], dim=1)
        elif image_embeddings is None:
            dual_prompt_embeddings = prompt_embeddings
            text_to_image_strength = 1.
        elif prompt_embeddings is None:
            dual_prompt_embeddings = image_embeddings
            text_to_image_strength = 0.
        else:
            raise ValueError()

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.image_unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            dual_prompt_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Combine the attention blocks of the image and text UNets
        self.set_transformer_params(text_to_image_strength, ("text", "image"))

        # 8. Denoising loop
        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier-free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=dual_prompt_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # 9. Post-processing
        image = self.decode_latents(latents)

        return image
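
# Usage sketch (hypothetical variable names): dual guidance with text and image
# embeddings that are already doubled along the batch dimension for classifier-free
# guidance, mixed at equal strength:
#
#     imgs = vd_pipe(image_embeddings=img_emb, prompt_embeddings=txt_emb,
#                    text_to_image_strength=0.5, num_inference_steps=50)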
src/setup.sh
ADDED
@@ -0,0 +1,14 @@
#!/bin/bash
# Commands to set up a new conda environment and install all the necessary packages.
# After running this, "conda env export > environment.yaml" produces the environment.yaml file.

set -e

conda create -n fmri python=3.10.8 -y
conda activate fmri

conda install numpy matplotlib tqdm scikit-image jupyterlab -y

pip install accelerate webdataset clip pandas matplotlib ftfy regex kornia umap-learn h5py
pip install torchvision==0.15.2 torch==2.0.1
pip install diffusers
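# Note (an assumption about non-interactive shells, not part of the original script):
# `conda activate` typically fails inside a bare `bash setup.sh` run unless conda's
# shell hook is sourced first, e.g.:
#
#     source "$(conda info --base)/etc/profile.d/conda.sh"
#     conda activate fmri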
src/utils.py
ADDED
@@ -0,0 +1,274 @@
import numpy as np
from torchvision import transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import PIL
import random
import os
import matplotlib.pyplot as plt
import pandas as pd
import math
import webdataset as wds
import tempfile
from torchvision.utils import make_grid
from diffusers.utils import randn_tensor

import json
from torchmetrics.image.fid import FrechetInceptionDistance
from PIL import Image
import requests
import io
import time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def is_interactive():
    import __main__ as main
    return not hasattr(main, '__file__')

def seed_everything(seed=0, cudnn_deterministic=True):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if cudnn_deterministic:
        torch.backends.cudnn.deterministic = True
    else:
        ## needs to be False to use conv3D
        print('Note: not using cudnn.deterministic')

def np_to_Image(x):
    if x.ndim == 4:
        x = x[0]
    return PIL.Image.fromarray((x.transpose(1, 2, 0)*127.5+128).clip(0, 255).astype('uint8'))

def torch_to_Image(x):
    if x.ndim == 4:
        x = x[0]
    return transforms.ToPILImage()(x)

def Image_to_torch(x):
    try:
        x = (transforms.ToTensor()(x)[:3].unsqueeze(0) - .5) / .5
    except:
        x = (transforms.ToTensor()(x[0])[:3].unsqueeze(0) - .5) / .5
    return x

def torch_to_matplotlib(x, device=device):
    if torch.mean(x) > 10:
        x = (x.permute(0, 2, 3, 1)).clamp(0, 255).to(torch.uint8)
    else:
        x = (x.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8)
    if device == 'cpu':
        return x[0]
    else:
        return x.cpu().numpy()[0]
def pairwise_cosine_similarity(A, B, dim=1, eps=1e-8):
    # https://stackoverflow.com/questions/67199317/pytorch-cosine-similarity-nxn-elements
    numerator = A @ B.T
    A_l2 = torch.mul(A, A).sum(axis=dim)
    B_l2 = torch.mul(B, B).sum(axis=dim)
    denominator = torch.max(torch.sqrt(torch.outer(A_l2, B_l2)), torch.tensor(eps))
    return torch.div(numerator, denominator)

def batchwise_pearson_correlation(Z, B):
    # Calculate means
    Z_mean = torch.mean(Z, dim=1, keepdim=True)
    B_mean = torch.mean(B, dim=1, keepdim=True)

    # Subtract means
    Z_centered = Z - Z_mean
    B_centered = B - B_mean

    # Calculate Pearson correlation coefficient
    numerator = Z_centered @ B_centered.T
    Z_centered_norm = torch.linalg.norm(Z_centered, dim=1, keepdim=True)
    B_centered_norm = torch.linalg.norm(B_centered, dim=1, keepdim=True)
    denominator = Z_centered_norm @ B_centered_norm.T

    pearson_correlation = (numerator / denominator)
    return pearson_correlation

def batchwise_cosine_similarity(Z, B):
    # https://www.h4pz.co/blog/2021/4/2/batch-cosine-similarity-in-pytorch-or-numpy-jax-cupy-etc
    B = B.T
    Z_norm = torch.linalg.norm(Z, dim=1, keepdim=True)  # Size (n, 1).
    B_norm = torch.linalg.norm(B, dim=0, keepdim=True)  # Size (1, b).
    cosine_similarity = ((Z @ B) / (Z_norm @ B_norm)).T
    return cosine_similarity

def topk(similarities, labels, k=5):
    if k > similarities.shape[0]:
        k = similarities.shape[0]
    topsum = 0
    for i in range(k):
        topsum += torch.sum(torch.argsort(similarities, axis=1)[:, -(i+1)] == labels) / len(labels)
    return topsum

def get_non_diagonals(a):
    a = torch.triu(a, diagonal=1) + torch.tril(a, diagonal=-1)
    # make diagonals -1
    a = a.fill_diagonal_(-1)
    return a
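# Example: combining the helpers above as a retrieval metric. With a perfect
# predictor (predictions identical to targets) the top-1 score is exactly 1.0:
#
#     Z = F.normalize(torch.randn(8, 768), dim=-1)
#     sims = batchwise_cosine_similarity(Z, Z)       # (8, 8), diagonal is the match
#     assert topk(sims, torch.arange(8), k=1) == 1.0
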
def gather_features(image_features, voxel_features, accelerator):
    all_image_features = accelerator.gather(image_features.contiguous())
    if voxel_features is not None:
        all_voxel_features = accelerator.gather(voxel_features.contiguous())
        return all_image_features, all_voxel_features
    return all_image_features

def soft_clip_loss(preds, targs, temp=0.125):  # , distributed=False, accelerator=None):
    # if not distributed:
    clip_clip = (targs @ targs.T) / temp
    brain_clip = (preds @ targs.T) / temp
    # else:
    #     all_targs = gather_features(targs, None, accelerator)
    #     clip_clip = (targs @ all_targs.T) / temp
    #     brain_clip = (preds @ all_targs.T) / temp

    loss1 = -(brain_clip.log_softmax(-1) * clip_clip.softmax(-1)).sum(-1).mean()
    loss2 = -(brain_clip.T.log_softmax(-1) * clip_clip.softmax(-1)).sum(-1).mean()

    loss = (loss1 + loss2) / 2
    return loss
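# soft_clip_loss replaces hard one-hot CLIP targets with the CLIP-space similarity
# matrix itself: it is the cross-entropy between softmax(targs @ targs.T / temp) and
# log_softmax(preds @ targs.T / temp), symmetrized over both retrieval directions.
# With preds == targs it reduces to the entropy of the soft target rows, i.e. the
# floor a perfect predictor would achieve:
#
#     targs = F.normalize(torch.randn(16, 768), dim=-1)
#     floor = soft_clip_loss(targs, targs)
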
def mixco(voxels, beta=0.15, s_thresh=0.5):
    perm = torch.randperm(voxels.shape[0])
    voxels_shuffle = voxels[perm].to(voxels.device, dtype=voxels.dtype)
    betas = torch.distributions.Beta(beta, beta).sample([voxels.shape[0]]).to(voxels.device, dtype=voxels.dtype)
    select = (torch.rand(voxels.shape[0]) <= s_thresh).to(voxels.device)
    betas_shape = [-1] + [1]*(len(voxels.shape)-1)
    voxels[select] = voxels[select] * betas[select].reshape(*betas_shape) + \
        voxels_shuffle[select] * (1 - betas[select]).reshape(*betas_shape)
    betas[~select] = 1
    return voxels, perm, betas, select

def mixco_clip_target(clip_target, perm, select, betas):
    clip_target_shuffle = clip_target[perm]
    clip_target[select] = clip_target[select] * betas[select].reshape(-1, 1) + \
        clip_target_shuffle[select] * (1 - betas[select]).reshape(-1, 1)
    return clip_target

def mixco_nce(preds, targs, temp=0.1, perm=None, betas=None, select=None, distributed=False,
              accelerator=None, local_rank=None, bidirectional=True):
    brain_clip = (preds @ targs.T) / temp

    if perm is not None and betas is not None and select is not None:
        probs = torch.diag(betas)
        probs[torch.arange(preds.shape[0]).to(preds.device), perm] = 1 - betas

        loss = -(brain_clip.log_softmax(-1) * probs).sum(-1).mean()
        if bidirectional:
            loss2 = -(brain_clip.T.log_softmax(-1) * probs.T).sum(-1).mean()
            loss = (loss + loss2) / 2
        return loss
    else:
        loss = F.cross_entropy(brain_clip, torch.arange(brain_clip.shape[0]).to(brain_clip.device))
        if bidirectional:
            loss2 = F.cross_entropy(brain_clip.T, torch.arange(brain_clip.shape[0]).to(brain_clip.device))
            loss = (loss + loss2) / 2
        return loss
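# MixCo sketch: mixco mixes roughly half the voxel samples with a shuffled partner
# using Beta(0.15, 0.15) coefficients, and mixco_nce splits each mixed row's target
# probability between its own label and the partner's label. A minimal training-step
# sketch (assuming `voxel2clip` maps voxels into CLIP space and `clip_targets` are
# L2-normalized CLIP embeddings):
#
#     voxels, perm, betas, select = mixco(voxels)
#     preds = F.normalize(voxel2clip(voxels), dim=-1)
#     loss = mixco_nce(preds, clip_targets, perm=perm, betas=betas, select=select)
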
def count_params(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('param counts:\n{:,} total\n{:,} trainable'.format(total, trainable))

def image_grid(imgs, rows, cols):
    w, h = imgs[0].size
    grid = PIL.Image.new('RGB', size=(cols*w, rows*h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid

def check_loss(loss):
    if loss.isnan().any():
        raise ValueError('NaN loss')

def cosine_anneal(start, end, steps):
    return end + (start - end)/2 * (1 + torch.cos(torch.pi*torch.arange(steps)/(steps-1)))
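# cosine_anneal returns the whole schedule as a tensor: index 0 equals `start`,
# index steps-1 equals `end`, with a half-cosine in between. For example:
#
#     lrs = cosine_anneal(3e-4, 1e-6, steps=100)
#     assert torch.isclose(lrs[0], torch.tensor(3e-4))
#     assert torch.isclose(lrs[-1], torch.tensor(1e-6))
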
import braceexpand
def get_dataloaders(
    batch_size,
    image_var='images',
    num_devices=None,
    num_workers=None,
    train_url=None,
    val_url=None,
    meta_url=None,
    num_train=None,
    num_val=None,
    cache_dir="/scratch/tmp/wds-cache",
    seed=0,
    voxels_key="nsdgeneral.npy",
    val_batch_size=None,
    to_tuple=["voxels", "images", "trial"],
    local_rank=0,
    world_size=1,
):
    print("Getting dataloaders...")
    assert image_var == 'images'

    def my_split_by_node(urls):
        return urls

    train_url = list(braceexpand.braceexpand(train_url))
    val_url = list(braceexpand.braceexpand(val_url))

    if num_devices is None:
        num_devices = torch.cuda.device_count()

    if num_workers is None:
        num_workers = num_devices

    if num_train is None:
        metadata = json.load(open(meta_url))
        num_train = metadata['totals']['train']
    if num_val is None:
        metadata = json.load(open(meta_url))
        num_val = metadata['totals']['val']

    if val_batch_size is None:
        val_batch_size = batch_size

    global_batch_size = batch_size * num_devices
    num_batches = math.floor(num_train / global_batch_size)
    num_worker_batches = math.floor(num_batches / num_workers)
    if num_worker_batches == 0: num_worker_batches = 1

    print("\nnum_train", num_train)
    print("global_batch_size", global_batch_size)
    print("batch_size", batch_size)
    print("num_workers", num_workers)
    print("num_batches", num_batches)
    print("num_worker_batches", num_worker_batches)

    # train_url = train_url[local_rank:world_size]
    train_data = wds.WebDataset(train_url, resampled=False, cache_dir=cache_dir, nodesplitter=my_split_by_node)\
        .shuffle(500, initial=500, rng=random.Random(42))\
        .decode("torch")\
        .rename(images="jpg;png", voxels=voxels_key, trial="trial.npy", coco="coco73k.npy", reps="num_uniques.npy")\
        .to_tuple(*to_tuple)
    # .batched(batch_size, partial=True)
    # .with_epoch(num_worker_batches)

    # BATCH SIZE SHOULD BE NONE!!! FOR TRAIN AND VAL | resampled=True for train | .batched(val_batch_size, partial=False)
    train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=1, shuffle=False)

    # Validation
    print("val_batch_size", val_batch_size)
    val_data = wds.WebDataset(val_url, resampled=False, cache_dir=cache_dir, nodesplitter=my_split_by_node)\
        .shuffle(500, initial=500, rng=random.Random(42))\
        .decode("torch")\
        .rename(images="jpg;png", voxels=voxels_key, trial="trial.npy", coco="coco73k.npy", reps="num_uniques.npy")\
        .to_tuple(*to_tuple)
    # .batched(val_batch_size, partial=True)
    val_dl = torch.utils.data.DataLoader(val_data, batch_size=val_batch_size, num_workers=1, shuffle=False, drop_last=True)

    return train_dl, val_dl, num_train, num_val
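
# Usage sketch (hypothetical shard paths following the webdataset naming used above;
# "nsdgeneral.npy" is the default voxels key):
#
#     train_dl, val_dl, num_train, num_val = get_dataloaders(
#         batch_size=32,
#         train_url="/data/nsd/train/train_subj01_{0..17}.tar",
#         val_url="/data/nsd/val/val_subj01_0.tar",
#         meta_url="/data/nsd/metadata_subj01.json",
#     )
#     for voxels, images, trial in train_dl:
#         ...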