Spaces:
Running
on
Zero
Running
on
Zero
NightRaven109
commited on
Upload 73 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- ADD/dnnlib/__init__.py +9 -0
- ADD/dnnlib/util.py +492 -0
- ADD/layers/__init__.py +11 -0
- ADD/layers/attention.py +89 -0
- ADD/layers/block.py +260 -0
- ADD/layers/dino_head.py +58 -0
- ADD/layers/drop_path.py +34 -0
- ADD/layers/layer_scale.py +27 -0
- ADD/layers/mlp.py +40 -0
- ADD/layers/patch_embed.py +88 -0
- ADD/layers/swiglu_ffn.py +72 -0
- ADD/models/discriminator.py +178 -0
- ADD/models/vit.py +373 -0
- ADD/th_utils/__init__.py +9 -0
- ADD/th_utils/custom_ops.py +157 -0
- ADD/th_utils/misc.py +284 -0
- ADD/th_utils/ops/__init__.py +9 -0
- ADD/th_utils/ops/bias_act.cpp +99 -0
- ADD/th_utils/ops/bias_act.cu +173 -0
- ADD/th_utils/ops/bias_act.h +38 -0
- ADD/th_utils/ops/bias_act.py +209 -0
- ADD/th_utils/ops/conv2d_gradfix.py +203 -0
- ADD/th_utils/ops/conv2d_resample.py +143 -0
- ADD/th_utils/ops/filtered_lrelu.cpp +300 -0
- ADD/th_utils/ops/filtered_lrelu.cu +1284 -0
- ADD/th_utils/ops/filtered_lrelu.h +90 -0
- ADD/th_utils/ops/filtered_lrelu.py +274 -0
- ADD/th_utils/ops/filtered_lrelu_ns.cu +27 -0
- ADD/th_utils/ops/filtered_lrelu_rd.cu +27 -0
- ADD/th_utils/ops/filtered_lrelu_wr.cu +27 -0
- ADD/th_utils/ops/fma.py +60 -0
- ADD/th_utils/ops/grid_sample_gradfix.py +83 -0
- ADD/th_utils/ops/upfirdn2d.cpp +107 -0
- ADD/th_utils/ops/upfirdn2d.cu +384 -0
- ADD/th_utils/ops/upfirdn2d.h +59 -0
- ADD/th_utils/ops/upfirdn2d.py +389 -0
- ADD/utils/util_net.py +182 -0
- README.md +291 -15
- dataloaders/paired_dataset_txt.py +70 -0
- dataloaders/params_ccsr.yml +42 -0
- dataloaders/realesrgan.py +303 -0
- models/DiffAugment.py +121 -0
- models/controlnet.py +850 -0
- models/losses/__init__.py +1 -0
- models/losses/contperceptual.py +154 -0
- models/losses/vqperceptual.py +180 -0
- models/shared.py +106 -0
- models/unet_2d_blocks.py +0 -0
- models/unet_2d_condition.py +1081 -0
- models/vit_utils.py +182 -0
ADD/dnnlib/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
from .util import EasyDict, make_cache_dir_path
|
ADD/dnnlib/util.py
ADDED
@@ -0,0 +1,492 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Miscellaneous utility classes and functions."""
|
10 |
+
|
11 |
+
import ctypes
|
12 |
+
import fnmatch
|
13 |
+
import importlib
|
14 |
+
import inspect
|
15 |
+
import os
|
16 |
+
import sys
|
17 |
+
import types
|
18 |
+
import io
|
19 |
+
import pickle
|
20 |
+
import re
|
21 |
+
import requests
|
22 |
+
import html
|
23 |
+
import hashlib
|
24 |
+
import glob
|
25 |
+
import tempfile
|
26 |
+
import urllib
|
27 |
+
import urllib.request
|
28 |
+
import uuid
|
29 |
+
from typing import Any, List, Tuple, Union, Optional
|
30 |
+
from distutils.util import strtobool
|
31 |
+
import shutil
|
32 |
+
|
33 |
+
import numpy as np
|
34 |
+
|
35 |
+
|
36 |
+
# Util classes
|
37 |
+
# ------------------------------------------------------------------------------------------
|
38 |
+
|
39 |
+
class EasyDict(dict):
|
40 |
+
"""Convenience class that behaves like a dict but allows access with the attribute syntax."""
|
41 |
+
|
42 |
+
def __getattr__(self, name: str) -> Any:
|
43 |
+
try:
|
44 |
+
return self[name]
|
45 |
+
except KeyError:
|
46 |
+
raise AttributeError(name)
|
47 |
+
|
48 |
+
def __setattr__(self, name: str, value: Any) -> None:
|
49 |
+
self[name] = value
|
50 |
+
|
51 |
+
def __delattr__(self, name: str) -> None:
|
52 |
+
del self[name]
|
53 |
+
|
54 |
+
|
55 |
+
class Logger(object):
|
56 |
+
"""Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
|
57 |
+
|
58 |
+
def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True):
|
59 |
+
self.file = None
|
60 |
+
|
61 |
+
if file_name is not None:
|
62 |
+
self.file = open(file_name, file_mode)
|
63 |
+
|
64 |
+
self.should_flush = should_flush
|
65 |
+
self.stdout = sys.stdout
|
66 |
+
self.stderr = sys.stderr
|
67 |
+
|
68 |
+
sys.stdout = self
|
69 |
+
sys.stderr = self
|
70 |
+
|
71 |
+
def __enter__(self) -> "Logger":
|
72 |
+
return self
|
73 |
+
|
74 |
+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
75 |
+
self.close()
|
76 |
+
|
77 |
+
def write(self, text: Union[str, bytes]) -> None:
|
78 |
+
"""Write text to stdout (and a file) and optionally flush."""
|
79 |
+
if isinstance(text, bytes):
|
80 |
+
text = text.decode()
|
81 |
+
if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
|
82 |
+
return
|
83 |
+
|
84 |
+
if self.file is not None:
|
85 |
+
self.file.write(text)
|
86 |
+
|
87 |
+
self.stdout.write(text)
|
88 |
+
|
89 |
+
if self.should_flush:
|
90 |
+
self.flush()
|
91 |
+
|
92 |
+
def flush(self) -> None:
|
93 |
+
"""Flush written text to both stdout and a file, if open."""
|
94 |
+
if self.file is not None:
|
95 |
+
self.file.flush()
|
96 |
+
|
97 |
+
self.stdout.flush()
|
98 |
+
|
99 |
+
def close(self) -> None:
|
100 |
+
"""Flush, close possible files, and remove stdout/stderr mirroring."""
|
101 |
+
self.flush()
|
102 |
+
|
103 |
+
# if using multiple loggers, prevent closing in wrong order
|
104 |
+
if sys.stdout is self:
|
105 |
+
sys.stdout = self.stdout
|
106 |
+
if sys.stderr is self:
|
107 |
+
sys.stderr = self.stderr
|
108 |
+
|
109 |
+
if self.file is not None:
|
110 |
+
self.file.close()
|
111 |
+
self.file = None
|
112 |
+
|
113 |
+
|
114 |
+
# Cache directories
|
115 |
+
# ------------------------------------------------------------------------------------------
|
116 |
+
|
117 |
+
_dnnlib_cache_dir = None
|
118 |
+
|
119 |
+
def set_cache_dir(path: str) -> None:
|
120 |
+
global _dnnlib_cache_dir
|
121 |
+
_dnnlib_cache_dir = path
|
122 |
+
|
123 |
+
|
124 |
+
def make_cache_dir_path(*paths: str) -> str:
|
125 |
+
if _dnnlib_cache_dir is not None:
|
126 |
+
return os.path.join(_dnnlib_cache_dir, *paths)
|
127 |
+
if 'DNNLIB_CACHE_DIR' in os.environ:
|
128 |
+
return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
|
129 |
+
if 'HOME' in os.environ:
|
130 |
+
return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
|
131 |
+
if 'USERPROFILE' in os.environ:
|
132 |
+
return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
|
133 |
+
return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
|
134 |
+
|
135 |
+
|
136 |
+
# Small util functions
|
137 |
+
# ------------------------------------------------------------------------------------------
|
138 |
+
|
139 |
+
def format_time(seconds: Union[int, float]) -> str:
|
140 |
+
"""Convert the seconds to human readable string with days, hours, minutes and seconds."""
|
141 |
+
s = int(np.rint(seconds))
|
142 |
+
|
143 |
+
if s < 60:
|
144 |
+
return "{0}s".format(s)
|
145 |
+
elif s < 60 * 60:
|
146 |
+
return "{0}m {1:02}s".format(s // 60, s % 60)
|
147 |
+
elif s < 24 * 60 * 60:
|
148 |
+
return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
|
149 |
+
else:
|
150 |
+
return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
|
151 |
+
|
152 |
+
|
153 |
+
def format_time_brief(seconds: Union[int, float]) -> str:
|
154 |
+
"""Convert the seconds to human readable string with days, hours, minutes and seconds."""
|
155 |
+
s = int(np.rint(seconds))
|
156 |
+
|
157 |
+
if s < 60:
|
158 |
+
return "{0}s".format(s)
|
159 |
+
elif s < 60 * 60:
|
160 |
+
return "{0}m {1:02}s".format(s // 60, s % 60)
|
161 |
+
elif s < 24 * 60 * 60:
|
162 |
+
return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
|
163 |
+
else:
|
164 |
+
return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
|
165 |
+
|
166 |
+
|
167 |
+
def ask_yes_no(question: str) -> bool:
|
168 |
+
"""Ask the user the question until the user inputs a valid answer."""
|
169 |
+
while True:
|
170 |
+
try:
|
171 |
+
print("{0} [y/n]".format(question))
|
172 |
+
return strtobool(input().lower())
|
173 |
+
except ValueError:
|
174 |
+
pass
|
175 |
+
|
176 |
+
|
177 |
+
def tuple_product(t: Tuple) -> Any:
|
178 |
+
"""Calculate the product of the tuple elements."""
|
179 |
+
result = 1
|
180 |
+
|
181 |
+
for v in t:
|
182 |
+
result *= v
|
183 |
+
|
184 |
+
return result
|
185 |
+
|
186 |
+
|
187 |
+
_str_to_ctype = {
|
188 |
+
"uint8": ctypes.c_ubyte,
|
189 |
+
"uint16": ctypes.c_uint16,
|
190 |
+
"uint32": ctypes.c_uint32,
|
191 |
+
"uint64": ctypes.c_uint64,
|
192 |
+
"int8": ctypes.c_byte,
|
193 |
+
"int16": ctypes.c_int16,
|
194 |
+
"int32": ctypes.c_int32,
|
195 |
+
"int64": ctypes.c_int64,
|
196 |
+
"float32": ctypes.c_float,
|
197 |
+
"float64": ctypes.c_double
|
198 |
+
}
|
199 |
+
|
200 |
+
|
201 |
+
def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
|
202 |
+
"""Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
|
203 |
+
type_str = None
|
204 |
+
|
205 |
+
if isinstance(type_obj, str):
|
206 |
+
type_str = type_obj
|
207 |
+
elif hasattr(type_obj, "__name__"):
|
208 |
+
type_str = type_obj.__name__
|
209 |
+
elif hasattr(type_obj, "name"):
|
210 |
+
type_str = type_obj.name
|
211 |
+
else:
|
212 |
+
raise RuntimeError("Cannot infer type name from input")
|
213 |
+
|
214 |
+
assert type_str in _str_to_ctype.keys()
|
215 |
+
|
216 |
+
my_dtype = np.dtype(type_str)
|
217 |
+
my_ctype = _str_to_ctype[type_str]
|
218 |
+
|
219 |
+
assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
|
220 |
+
|
221 |
+
return my_dtype, my_ctype
|
222 |
+
|
223 |
+
|
224 |
+
def is_pickleable(obj: Any) -> bool:
|
225 |
+
try:
|
226 |
+
with io.BytesIO() as stream:
|
227 |
+
pickle.dump(obj, stream)
|
228 |
+
return True
|
229 |
+
except:
|
230 |
+
return False
|
231 |
+
|
232 |
+
|
233 |
+
# Functionality to import modules/objects by name, and call functions by name
|
234 |
+
# ------------------------------------------------------------------------------------------
|
235 |
+
|
236 |
+
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
|
237 |
+
"""Searches for the underlying module behind the name to some python object.
|
238 |
+
Returns the module and the object name (original name with module part removed)."""
|
239 |
+
|
240 |
+
# allow convenience shorthands, substitute them by full names
|
241 |
+
obj_name = re.sub("^np.", "numpy.", obj_name)
|
242 |
+
obj_name = re.sub("^tf.", "tensorflow.", obj_name)
|
243 |
+
|
244 |
+
# list alternatives for (module_name, local_obj_name)
|
245 |
+
parts = obj_name.split(".")
|
246 |
+
name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
|
247 |
+
|
248 |
+
# try each alternative in turn
|
249 |
+
for module_name, local_obj_name in name_pairs:
|
250 |
+
try:
|
251 |
+
module = importlib.import_module(module_name) # may raise ImportError
|
252 |
+
get_obj_from_module(module, local_obj_name) # may raise AttributeError
|
253 |
+
return module, local_obj_name
|
254 |
+
except:
|
255 |
+
pass
|
256 |
+
|
257 |
+
# maybe some of the modules themselves contain errors?
|
258 |
+
for module_name, _local_obj_name in name_pairs:
|
259 |
+
try:
|
260 |
+
importlib.import_module(module_name) # may raise ImportError
|
261 |
+
except ImportError:
|
262 |
+
if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
|
263 |
+
raise
|
264 |
+
|
265 |
+
# maybe the requested attribute is missing?
|
266 |
+
for module_name, local_obj_name in name_pairs:
|
267 |
+
try:
|
268 |
+
module = importlib.import_module(module_name) # may raise ImportError
|
269 |
+
get_obj_from_module(module, local_obj_name) # may raise AttributeError
|
270 |
+
except ImportError:
|
271 |
+
pass
|
272 |
+
|
273 |
+
# we are out of luck, but we have no idea why
|
274 |
+
raise ImportError(obj_name)
|
275 |
+
|
276 |
+
|
277 |
+
def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
|
278 |
+
"""Traverses the object name and returns the last (rightmost) python object."""
|
279 |
+
if obj_name == '':
|
280 |
+
return module
|
281 |
+
obj = module
|
282 |
+
for part in obj_name.split("."):
|
283 |
+
obj = getattr(obj, part)
|
284 |
+
return obj
|
285 |
+
|
286 |
+
|
287 |
+
def get_obj_by_name(name: str) -> Any:
|
288 |
+
"""Finds the python object with the given name."""
|
289 |
+
module, obj_name = get_module_from_obj_name(name)
|
290 |
+
return get_obj_from_module(module, obj_name)
|
291 |
+
|
292 |
+
|
293 |
+
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
|
294 |
+
"""Finds the python object with the given name and calls it as a function."""
|
295 |
+
assert func_name is not None
|
296 |
+
func_obj = get_obj_by_name(func_name)
|
297 |
+
assert callable(func_obj)
|
298 |
+
return func_obj(*args, **kwargs)
|
299 |
+
|
300 |
+
|
301 |
+
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
|
302 |
+
"""Finds the python class with the given name and constructs it with the given arguments."""
|
303 |
+
return call_func_by_name(*args, func_name=class_name, **kwargs)
|
304 |
+
|
305 |
+
|
306 |
+
def get_module_dir_by_obj_name(obj_name: str) -> str:
|
307 |
+
"""Get the directory path of the module containing the given object name."""
|
308 |
+
module, _ = get_module_from_obj_name(obj_name)
|
309 |
+
return os.path.dirname(inspect.getfile(module))
|
310 |
+
|
311 |
+
|
312 |
+
def is_top_level_function(obj: Any) -> bool:
|
313 |
+
"""Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
|
314 |
+
return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
|
315 |
+
|
316 |
+
|
317 |
+
def get_top_level_function_name(obj: Any) -> str:
|
318 |
+
"""Return the fully-qualified name of a top-level function."""
|
319 |
+
assert is_top_level_function(obj)
|
320 |
+
module = obj.__module__
|
321 |
+
if module == '__main__':
|
322 |
+
module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
|
323 |
+
return module + "." + obj.__name__
|
324 |
+
|
325 |
+
|
326 |
+
# File system helpers
|
327 |
+
# ------------------------------------------------------------------------------------------
|
328 |
+
|
329 |
+
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
|
330 |
+
"""List all files recursively in a given directory while ignoring given file and directory names.
|
331 |
+
Returns list of tuples containing both absolute and relative paths."""
|
332 |
+
assert os.path.isdir(dir_path)
|
333 |
+
base_name = os.path.basename(os.path.normpath(dir_path))
|
334 |
+
|
335 |
+
if ignores is None:
|
336 |
+
ignores = []
|
337 |
+
|
338 |
+
result = []
|
339 |
+
|
340 |
+
for root, dirs, files in os.walk(dir_path, topdown=True):
|
341 |
+
for ignore_ in ignores:
|
342 |
+
dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
|
343 |
+
|
344 |
+
# dirs need to be edited in-place
|
345 |
+
for d in dirs_to_remove:
|
346 |
+
dirs.remove(d)
|
347 |
+
|
348 |
+
files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
|
349 |
+
|
350 |
+
absolute_paths = [os.path.join(root, f) for f in files]
|
351 |
+
relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
|
352 |
+
|
353 |
+
if add_base_to_relative:
|
354 |
+
relative_paths = [os.path.join(base_name, p) for p in relative_paths]
|
355 |
+
|
356 |
+
assert len(absolute_paths) == len(relative_paths)
|
357 |
+
result += zip(absolute_paths, relative_paths)
|
358 |
+
|
359 |
+
return result
|
360 |
+
|
361 |
+
|
362 |
+
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
|
363 |
+
"""Takes in a list of tuples of (src, dst) paths and copies files.
|
364 |
+
Will create all necessary directories."""
|
365 |
+
for file in files:
|
366 |
+
target_dir_name = os.path.dirname(file[1])
|
367 |
+
|
368 |
+
# will create all intermediate-level directories
|
369 |
+
if not os.path.exists(target_dir_name):
|
370 |
+
os.makedirs(target_dir_name)
|
371 |
+
|
372 |
+
shutil.copyfile(file[0], file[1])
|
373 |
+
|
374 |
+
|
375 |
+
# URL helpers
|
376 |
+
# ------------------------------------------------------------------------------------------
|
377 |
+
|
378 |
+
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
|
379 |
+
"""Determine whether the given object is a valid URL string."""
|
380 |
+
if not isinstance(obj, str) or not "://" in obj:
|
381 |
+
return False
|
382 |
+
if allow_file_urls and obj.startswith('file://'):
|
383 |
+
return True
|
384 |
+
try:
|
385 |
+
res = requests.compat.urlparse(obj)
|
386 |
+
if not res.scheme or not res.netloc or not "." in res.netloc:
|
387 |
+
return False
|
388 |
+
res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
|
389 |
+
if not res.scheme or not res.netloc or not "." in res.netloc:
|
390 |
+
return False
|
391 |
+
except:
|
392 |
+
return False
|
393 |
+
return True
|
394 |
+
|
395 |
+
|
396 |
+
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
|
397 |
+
"""Download the given URL and return a binary-mode file object to access the data."""
|
398 |
+
assert num_attempts >= 1
|
399 |
+
assert not (return_filename and (not cache))
|
400 |
+
|
401 |
+
# Doesn't look like an URL scheme so interpret it as a local filename.
|
402 |
+
if not re.match('^[a-z]+://', url):
|
403 |
+
return url if return_filename else open(url, "rb")
|
404 |
+
|
405 |
+
# Handle file URLs. This code handles unusual file:// patterns that
|
406 |
+
# arise on Windows:
|
407 |
+
#
|
408 |
+
# file:///c:/foo.txt
|
409 |
+
#
|
410 |
+
# which would translate to a local '/c:/foo.txt' filename that's
|
411 |
+
# invalid. Drop the forward slash for such pathnames.
|
412 |
+
#
|
413 |
+
# If you touch this code path, you should test it on both Linux and
|
414 |
+
# Windows.
|
415 |
+
#
|
416 |
+
# Some internet resources suggest using urllib.request.url2pathname() but
|
417 |
+
# but that converts forward slashes to backslashes and this causes
|
418 |
+
# its own set of problems.
|
419 |
+
if url.startswith('file://'):
|
420 |
+
filename = urllib.parse.urlparse(url).path
|
421 |
+
if re.match(r'^/[a-zA-Z]:', filename):
|
422 |
+
filename = filename[1:]
|
423 |
+
return filename if return_filename else open(filename, "rb")
|
424 |
+
|
425 |
+
assert is_url(url)
|
426 |
+
|
427 |
+
# Lookup from cache.
|
428 |
+
if cache_dir is None:
|
429 |
+
cache_dir = make_cache_dir_path('downloads')
|
430 |
+
|
431 |
+
url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
|
432 |
+
if cache:
|
433 |
+
cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
|
434 |
+
if len(cache_files) == 1:
|
435 |
+
filename = cache_files[0]
|
436 |
+
return filename if return_filename else open(filename, "rb")
|
437 |
+
|
438 |
+
# Download.
|
439 |
+
url_name = None
|
440 |
+
url_data = None
|
441 |
+
with requests.Session() as session:
|
442 |
+
if verbose:
|
443 |
+
print("Downloading %s ..." % url, end="", flush=True)
|
444 |
+
for attempts_left in reversed(range(num_attempts)):
|
445 |
+
try:
|
446 |
+
with session.get(url) as res:
|
447 |
+
res.raise_for_status()
|
448 |
+
if len(res.content) == 0:
|
449 |
+
raise IOError("No data received")
|
450 |
+
|
451 |
+
if len(res.content) < 8192:
|
452 |
+
content_str = res.content.decode("utf-8")
|
453 |
+
if "download_warning" in res.headers.get("Set-Cookie", ""):
|
454 |
+
links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
|
455 |
+
if len(links) == 1:
|
456 |
+
url = requests.compat.urljoin(url, links[0])
|
457 |
+
raise IOError("Google Drive virus checker nag")
|
458 |
+
if "Google Drive - Quota exceeded" in content_str:
|
459 |
+
raise IOError("Google Drive download quota exceeded -- please try again later")
|
460 |
+
|
461 |
+
match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
|
462 |
+
url_name = match[1] if match else url
|
463 |
+
url_data = res.content
|
464 |
+
if verbose:
|
465 |
+
print(" done")
|
466 |
+
break
|
467 |
+
except KeyboardInterrupt:
|
468 |
+
raise
|
469 |
+
except:
|
470 |
+
if not attempts_left:
|
471 |
+
if verbose:
|
472 |
+
print(" failed")
|
473 |
+
raise
|
474 |
+
if verbose:
|
475 |
+
print(".", end="", flush=True)
|
476 |
+
|
477 |
+
# Save to cache.
|
478 |
+
if cache:
|
479 |
+
safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
|
480 |
+
safe_name = safe_name[:min(len(safe_name), 128)]
|
481 |
+
cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
|
482 |
+
temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
|
483 |
+
os.makedirs(cache_dir, exist_ok=True)
|
484 |
+
with open(temp_file, "wb") as f:
|
485 |
+
f.write(url_data)
|
486 |
+
os.replace(temp_file, cache_file) # atomic
|
487 |
+
if return_filename:
|
488 |
+
return cache_file
|
489 |
+
|
490 |
+
# Return data as file object.
|
491 |
+
assert not return_filename
|
492 |
+
return io.BytesIO(url_data)
|
ADD/layers/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from .dino_head import DINOHead
|
7 |
+
from .mlp import Mlp
|
8 |
+
from .patch_embed import PatchEmbed
|
9 |
+
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
10 |
+
from .block import NestedTensorBlock
|
11 |
+
from .attention import MemEffAttention
|
ADD/layers/attention.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
9 |
+
|
10 |
+
import logging
|
11 |
+
import os
|
12 |
+
import warnings
|
13 |
+
|
14 |
+
from torch import Tensor
|
15 |
+
from torch import nn
|
16 |
+
|
17 |
+
|
18 |
+
logger = logging.getLogger("dinov2")
|
19 |
+
|
20 |
+
|
21 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
22 |
+
try:
|
23 |
+
if XFORMERS_ENABLED:
|
24 |
+
from xformers.ops import memory_efficient_attention, unbind
|
25 |
+
|
26 |
+
XFORMERS_AVAILABLE = True
|
27 |
+
warnings.warn("xFormers is available (Attention)")
|
28 |
+
else:
|
29 |
+
warnings.warn("xFormers is disabled (Attention)")
|
30 |
+
raise ImportError
|
31 |
+
except ImportError:
|
32 |
+
XFORMERS_AVAILABLE = False
|
33 |
+
warnings.warn("xFormers is not available (Attention)")
|
34 |
+
|
35 |
+
|
36 |
+
class Attention(nn.Module):
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
dim: int,
|
40 |
+
num_heads: int = 8,
|
41 |
+
qkv_bias: bool = False,
|
42 |
+
proj_bias: bool = True,
|
43 |
+
attn_drop: float = 0.0,
|
44 |
+
proj_drop: float = 0.0,
|
45 |
+
) -> None:
|
46 |
+
super().__init__()
|
47 |
+
self.num_heads = num_heads
|
48 |
+
head_dim = dim // num_heads
|
49 |
+
self.scale = head_dim**-0.5
|
50 |
+
|
51 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
52 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
53 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
54 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
55 |
+
|
56 |
+
def forward(self, x: Tensor) -> Tensor:
|
57 |
+
B, N, C = x.shape
|
58 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
59 |
+
|
60 |
+
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
61 |
+
attn = q @ k.transpose(-2, -1)
|
62 |
+
|
63 |
+
attn = attn.softmax(dim=-1)
|
64 |
+
attn = self.attn_drop(attn)
|
65 |
+
|
66 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
67 |
+
x = self.proj(x)
|
68 |
+
x = self.proj_drop(x)
|
69 |
+
return x
|
70 |
+
|
71 |
+
|
72 |
+
class MemEffAttention(Attention):
|
73 |
+
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
|
74 |
+
if not XFORMERS_AVAILABLE:
|
75 |
+
if attn_bias is not None:
|
76 |
+
raise AssertionError("xFormers is required for using nested tensors")
|
77 |
+
return super().forward(x)
|
78 |
+
|
79 |
+
B, N, C = x.shape
|
80 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
81 |
+
|
82 |
+
q, k, v = unbind(qkv, 2)
|
83 |
+
|
84 |
+
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
|
85 |
+
x = x.reshape([B, N, C])
|
86 |
+
|
87 |
+
x = self.proj(x)
|
88 |
+
x = self.proj_drop(x)
|
89 |
+
return x
|
ADD/layers/block.py
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
9 |
+
|
10 |
+
import logging
|
11 |
+
import os
|
12 |
+
from typing import Callable, List, Any, Tuple, Dict
|
13 |
+
import warnings
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from torch import nn, Tensor
|
17 |
+
|
18 |
+
from .attention import Attention, MemEffAttention
|
19 |
+
from .drop_path import DropPath
|
20 |
+
from .layer_scale import LayerScale
|
21 |
+
from .mlp import Mlp
|
22 |
+
|
23 |
+
|
24 |
+
logger = logging.getLogger("dinov2")
|
25 |
+
|
26 |
+
|
27 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
28 |
+
try:
|
29 |
+
if XFORMERS_ENABLED:
|
30 |
+
from xformers.ops import fmha, scaled_index_add, index_select_cat
|
31 |
+
|
32 |
+
XFORMERS_AVAILABLE = True
|
33 |
+
warnings.warn("xFormers is available (Block)")
|
34 |
+
else:
|
35 |
+
warnings.warn("xFormers is disabled (Block)")
|
36 |
+
raise ImportError
|
37 |
+
except ImportError:
|
38 |
+
XFORMERS_AVAILABLE = False
|
39 |
+
|
40 |
+
warnings.warn("xFormers is not available (Block)")
|
41 |
+
|
42 |
+
|
43 |
+
class Block(nn.Module):
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
dim: int,
|
47 |
+
num_heads: int,
|
48 |
+
mlp_ratio: float = 4.0,
|
49 |
+
qkv_bias: bool = False,
|
50 |
+
proj_bias: bool = True,
|
51 |
+
ffn_bias: bool = True,
|
52 |
+
drop: float = 0.0,
|
53 |
+
attn_drop: float = 0.0,
|
54 |
+
init_values=None,
|
55 |
+
drop_path: float = 0.0,
|
56 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
57 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
58 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
59 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
60 |
+
) -> None:
|
61 |
+
super().__init__()
|
62 |
+
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
|
63 |
+
self.norm1 = norm_layer(dim)
|
64 |
+
self.attn = attn_class(
|
65 |
+
dim,
|
66 |
+
num_heads=num_heads,
|
67 |
+
qkv_bias=qkv_bias,
|
68 |
+
proj_bias=proj_bias,
|
69 |
+
attn_drop=attn_drop,
|
70 |
+
proj_drop=drop,
|
71 |
+
)
|
72 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
73 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
74 |
+
|
75 |
+
self.norm2 = norm_layer(dim)
|
76 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
77 |
+
self.mlp = ffn_layer(
|
78 |
+
in_features=dim,
|
79 |
+
hidden_features=mlp_hidden_dim,
|
80 |
+
act_layer=act_layer,
|
81 |
+
drop=drop,
|
82 |
+
bias=ffn_bias,
|
83 |
+
)
|
84 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
85 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
86 |
+
|
87 |
+
self.sample_drop_ratio = drop_path
|
88 |
+
|
89 |
+
def forward(self, x: Tensor) -> Tensor:
|
90 |
+
def attn_residual_func(x: Tensor) -> Tensor:
|
91 |
+
return self.ls1(self.attn(self.norm1(x)))
|
92 |
+
|
93 |
+
def ffn_residual_func(x: Tensor) -> Tensor:
|
94 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
95 |
+
|
96 |
+
if self.training and self.sample_drop_ratio > 0.1:
|
97 |
+
# the overhead is compensated only for a drop path rate larger than 0.1
|
98 |
+
x = drop_add_residual_stochastic_depth(
|
99 |
+
x,
|
100 |
+
residual_func=attn_residual_func,
|
101 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
102 |
+
)
|
103 |
+
x = drop_add_residual_stochastic_depth(
|
104 |
+
x,
|
105 |
+
residual_func=ffn_residual_func,
|
106 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
107 |
+
)
|
108 |
+
elif self.training and self.sample_drop_ratio > 0.0:
|
109 |
+
x = x + self.drop_path1(attn_residual_func(x))
|
110 |
+
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
|
111 |
+
else:
|
112 |
+
x = x + attn_residual_func(x)
|
113 |
+
x = x + ffn_residual_func(x)
|
114 |
+
return x
|
115 |
+
|
116 |
+
|
117 |
+
def drop_add_residual_stochastic_depth(
|
118 |
+
x: Tensor,
|
119 |
+
residual_func: Callable[[Tensor], Tensor],
|
120 |
+
sample_drop_ratio: float = 0.0,
|
121 |
+
) -> Tensor:
|
122 |
+
# 1) extract subset using permutation
|
123 |
+
b, n, d = x.shape
|
124 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
125 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
126 |
+
x_subset = x[brange]
|
127 |
+
|
128 |
+
# 2) apply residual_func to get residual
|
129 |
+
residual = residual_func(x_subset)
|
130 |
+
|
131 |
+
x_flat = x.flatten(1)
|
132 |
+
residual = residual.flatten(1)
|
133 |
+
|
134 |
+
residual_scale_factor = b / sample_subset_size
|
135 |
+
|
136 |
+
# 3) add the residual
|
137 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
138 |
+
return x_plus_residual.view_as(x)
|
139 |
+
|
140 |
+
|
141 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
|
142 |
+
b, n, d = x.shape
|
143 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
144 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
145 |
+
residual_scale_factor = b / sample_subset_size
|
146 |
+
return brange, residual_scale_factor
|
147 |
+
|
148 |
+
|
149 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
|
150 |
+
if scaling_vector is None:
|
151 |
+
x_flat = x.flatten(1)
|
152 |
+
residual = residual.flatten(1)
|
153 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
154 |
+
else:
|
155 |
+
x_plus_residual = scaled_index_add(
|
156 |
+
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
|
157 |
+
)
|
158 |
+
return x_plus_residual
|
159 |
+
|
160 |
+
|
161 |
+
attn_bias_cache: Dict[Tuple, Any] = {}
|
162 |
+
|
163 |
+
|
164 |
+
def get_attn_bias_and_cat(x_list, branges=None):
|
165 |
+
"""
|
166 |
+
this will perform the index select, cat the tensors, and provide the attn_bias from cache
|
167 |
+
"""
|
168 |
+
batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
|
169 |
+
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
|
170 |
+
if all_shapes not in attn_bias_cache.keys():
|
171 |
+
seqlens = []
|
172 |
+
for b, x in zip(batch_sizes, x_list):
|
173 |
+
for _ in range(b):
|
174 |
+
seqlens.append(x.shape[1])
|
175 |
+
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
|
176 |
+
attn_bias._batch_sizes = batch_sizes
|
177 |
+
attn_bias_cache[all_shapes] = attn_bias
|
178 |
+
|
179 |
+
if branges is not None:
|
180 |
+
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
|
181 |
+
else:
|
182 |
+
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
|
183 |
+
cat_tensors = torch.cat(tensors_bs1, dim=1)
|
184 |
+
|
185 |
+
return attn_bias_cache[all_shapes], cat_tensors
|
186 |
+
|
187 |
+
|
188 |
+
def drop_add_residual_stochastic_depth_list(
|
189 |
+
x_list: List[Tensor],
|
190 |
+
residual_func: Callable[[Tensor, Any], Tensor],
|
191 |
+
sample_drop_ratio: float = 0.0,
|
192 |
+
scaling_vector=None,
|
193 |
+
) -> Tensor:
|
194 |
+
# 1) generate random set of indices for dropping samples in the batch
|
195 |
+
branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
|
196 |
+
branges = [s[0] for s in branges_scales]
|
197 |
+
residual_scale_factors = [s[1] for s in branges_scales]
|
198 |
+
|
199 |
+
# 2) get attention bias and index+concat the tensors
|
200 |
+
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
|
201 |
+
|
202 |
+
# 3) apply residual_func to get residual, and split the result
|
203 |
+
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
|
204 |
+
|
205 |
+
outputs = []
|
206 |
+
for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
|
207 |
+
outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
|
208 |
+
return outputs
|
209 |
+
|
210 |
+
|
211 |
+
class NestedTensorBlock(Block):
|
212 |
+
def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
|
213 |
+
"""
|
214 |
+
x_list contains a list of tensors to nest together and run
|
215 |
+
"""
|
216 |
+
assert isinstance(self.attn, MemEffAttention)
|
217 |
+
|
218 |
+
if self.training and self.sample_drop_ratio > 0.0:
|
219 |
+
|
220 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
221 |
+
return self.attn(self.norm1(x), attn_bias=attn_bias)
|
222 |
+
|
223 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
224 |
+
return self.mlp(self.norm2(x))
|
225 |
+
|
226 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
227 |
+
x_list,
|
228 |
+
residual_func=attn_residual_func,
|
229 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
230 |
+
scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
|
231 |
+
)
|
232 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
233 |
+
x_list,
|
234 |
+
residual_func=ffn_residual_func,
|
235 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
236 |
+
scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
|
237 |
+
)
|
238 |
+
return x_list
|
239 |
+
else:
|
240 |
+
|
241 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
242 |
+
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
|
243 |
+
|
244 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
245 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
246 |
+
|
247 |
+
attn_bias, x = get_attn_bias_and_cat(x_list)
|
248 |
+
x = x + attn_residual_func(x, attn_bias=attn_bias)
|
249 |
+
x = x + ffn_residual_func(x)
|
250 |
+
return attn_bias.split(x)
|
251 |
+
|
252 |
+
def forward(self, x_or_x_list):
|
253 |
+
if isinstance(x_or_x_list, Tensor):
|
254 |
+
return super().forward(x_or_x_list)
|
255 |
+
elif isinstance(x_or_x_list, list):
|
256 |
+
if not XFORMERS_AVAILABLE:
|
257 |
+
raise AssertionError("xFormers is required for using nested tensors")
|
258 |
+
return self.forward_nested(x_or_x_list)
|
259 |
+
else:
|
260 |
+
raise AssertionError
|
ADD/layers/dino_head.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from torch.nn.init import trunc_normal_
|
9 |
+
from torch.nn.utils import weight_norm
|
10 |
+
|
11 |
+
|
12 |
+
class DINOHead(nn.Module):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
in_dim,
|
16 |
+
out_dim,
|
17 |
+
use_bn=False,
|
18 |
+
nlayers=3,
|
19 |
+
hidden_dim=2048,
|
20 |
+
bottleneck_dim=256,
|
21 |
+
mlp_bias=True,
|
22 |
+
):
|
23 |
+
super().__init__()
|
24 |
+
nlayers = max(nlayers, 1)
|
25 |
+
self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
|
26 |
+
self.apply(self._init_weights)
|
27 |
+
self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
|
28 |
+
self.last_layer.weight_g.data.fill_(1)
|
29 |
+
|
30 |
+
def _init_weights(self, m):
|
31 |
+
if isinstance(m, nn.Linear):
|
32 |
+
trunc_normal_(m.weight, std=0.02)
|
33 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
34 |
+
nn.init.constant_(m.bias, 0)
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
x = self.mlp(x)
|
38 |
+
eps = 1e-6 if x.dtype == torch.float16 else 1e-12
|
39 |
+
x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
|
40 |
+
x = self.last_layer(x)
|
41 |
+
return x
|
42 |
+
|
43 |
+
|
44 |
+
def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
|
45 |
+
if nlayers == 1:
|
46 |
+
return nn.Linear(in_dim, bottleneck_dim, bias=bias)
|
47 |
+
else:
|
48 |
+
layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
|
49 |
+
if use_bn:
|
50 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
51 |
+
layers.append(nn.GELU())
|
52 |
+
for _ in range(nlayers - 2):
|
53 |
+
layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
|
54 |
+
if use_bn:
|
55 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
56 |
+
layers.append(nn.GELU())
|
57 |
+
layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
|
58 |
+
return nn.Sequential(*layers)
|
ADD/layers/drop_path.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
9 |
+
|
10 |
+
|
11 |
+
from torch import nn
|
12 |
+
|
13 |
+
|
14 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
15 |
+
if drop_prob == 0.0 or not training:
|
16 |
+
return x
|
17 |
+
keep_prob = 1 - drop_prob
|
18 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
19 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
20 |
+
if keep_prob > 0.0:
|
21 |
+
random_tensor.div_(keep_prob)
|
22 |
+
output = x * random_tensor
|
23 |
+
return output
|
24 |
+
|
25 |
+
|
26 |
+
class DropPath(nn.Module):
|
27 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
28 |
+
|
29 |
+
def __init__(self, drop_prob=None):
|
30 |
+
super(DropPath, self).__init__()
|
31 |
+
self.drop_prob = drop_prob
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
return drop_path(x, self.drop_prob, self.training)
|
ADD/layers/layer_scale.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
7 |
+
|
8 |
+
from typing import Union
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch import Tensor
|
12 |
+
from torch import nn
|
13 |
+
|
14 |
+
|
15 |
+
class LayerScale(nn.Module):
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
dim: int,
|
19 |
+
init_values: Union[float, Tensor] = 1e-5,
|
20 |
+
inplace: bool = False,
|
21 |
+
) -> None:
|
22 |
+
super().__init__()
|
23 |
+
self.inplace = inplace
|
24 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
25 |
+
|
26 |
+
def forward(self, x: Tensor) -> Tensor:
|
27 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
ADD/layers/mlp.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
9 |
+
|
10 |
+
|
11 |
+
from typing import Callable, Optional
|
12 |
+
|
13 |
+
from torch import Tensor, nn
|
14 |
+
|
15 |
+
|
16 |
+
class Mlp(nn.Module):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
in_features: int,
|
20 |
+
hidden_features: Optional[int] = None,
|
21 |
+
out_features: Optional[int] = None,
|
22 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
23 |
+
drop: float = 0.0,
|
24 |
+
bias: bool = True,
|
25 |
+
) -> None:
|
26 |
+
super().__init__()
|
27 |
+
out_features = out_features or in_features
|
28 |
+
hidden_features = hidden_features or in_features
|
29 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
|
30 |
+
self.act = act_layer()
|
31 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
|
32 |
+
self.drop = nn.Dropout(drop)
|
33 |
+
|
34 |
+
def forward(self, x: Tensor) -> Tensor:
|
35 |
+
x = self.fc1(x)
|
36 |
+
x = self.act(x)
|
37 |
+
x = self.drop(x)
|
38 |
+
x = self.fc2(x)
|
39 |
+
x = self.drop(x)
|
40 |
+
return x
|
ADD/layers/patch_embed.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
9 |
+
|
10 |
+
from typing import Callable, Optional, Tuple, Union
|
11 |
+
|
12 |
+
from torch import Tensor
|
13 |
+
import torch.nn as nn
|
14 |
+
|
15 |
+
|
16 |
+
def make_2tuple(x):
|
17 |
+
if isinstance(x, tuple):
|
18 |
+
assert len(x) == 2
|
19 |
+
return x
|
20 |
+
|
21 |
+
assert isinstance(x, int)
|
22 |
+
return (x, x)
|
23 |
+
|
24 |
+
|
25 |
+
class PatchEmbed(nn.Module):
|
26 |
+
"""
|
27 |
+
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
|
28 |
+
|
29 |
+
Args:
|
30 |
+
img_size: Image size.
|
31 |
+
patch_size: Patch token size.
|
32 |
+
in_chans: Number of input image channels.
|
33 |
+
embed_dim: Number of linear projection output channels.
|
34 |
+
norm_layer: Normalization layer.
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
img_size: Union[int, Tuple[int, int]] = 224,
|
40 |
+
patch_size: Union[int, Tuple[int, int]] = 16,
|
41 |
+
in_chans: int = 3,
|
42 |
+
embed_dim: int = 768,
|
43 |
+
norm_layer: Optional[Callable] = None,
|
44 |
+
flatten_embedding: bool = True,
|
45 |
+
) -> None:
|
46 |
+
super().__init__()
|
47 |
+
|
48 |
+
image_HW = make_2tuple(img_size)
|
49 |
+
patch_HW = make_2tuple(patch_size)
|
50 |
+
patch_grid_size = (
|
51 |
+
image_HW[0] // patch_HW[0],
|
52 |
+
image_HW[1] // patch_HW[1],
|
53 |
+
)
|
54 |
+
|
55 |
+
self.img_size = image_HW
|
56 |
+
self.patch_size = patch_HW
|
57 |
+
self.patches_resolution = patch_grid_size
|
58 |
+
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
|
59 |
+
|
60 |
+
#self.in_chans = in_chans
|
61 |
+
self.embed_dim = embed_dim
|
62 |
+
|
63 |
+
self.flatten_embedding = flatten_embedding
|
64 |
+
|
65 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
|
66 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
67 |
+
|
68 |
+
def forward(self, x: Tensor) -> Tensor:
|
69 |
+
_, _, H, W = x.shape
|
70 |
+
patch_H, patch_W = self.patch_size
|
71 |
+
|
72 |
+
assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
|
73 |
+
assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
|
74 |
+
|
75 |
+
x = self.proj(x) # B C H W
|
76 |
+
H, W = x.size(2), x.size(3)
|
77 |
+
x = x.flatten(2).transpose(1, 2) # B HW C
|
78 |
+
x = self.norm(x)
|
79 |
+
if not self.flatten_embedding:
|
80 |
+
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
|
81 |
+
return x
|
82 |
+
|
83 |
+
#def flops(self) -> float:
|
84 |
+
#Ho, Wo = self.patches_resolution
|
85 |
+
#flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
86 |
+
#if self.norm is not None:
|
87 |
+
# flops += Ho * Wo * self.embed_dim
|
88 |
+
#return flops
|
ADD/layers/swiglu_ffn.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import os
|
7 |
+
from typing import Callable, Optional
|
8 |
+
import warnings
|
9 |
+
|
10 |
+
from torch import Tensor, nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
|
14 |
+
class SwiGLUFFN(nn.Module):
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
in_features: int,
|
18 |
+
hidden_features: Optional[int] = None,
|
19 |
+
out_features: Optional[int] = None,
|
20 |
+
act_layer: Callable[..., nn.Module] = None,
|
21 |
+
drop: float = 0.0,
|
22 |
+
bias: bool = True,
|
23 |
+
) -> None:
|
24 |
+
super().__init__()
|
25 |
+
out_features = out_features or in_features
|
26 |
+
hidden_features = hidden_features or in_features
|
27 |
+
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
|
28 |
+
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
|
29 |
+
|
30 |
+
def forward(self, x: Tensor) -> Tensor:
|
31 |
+
x12 = self.w12(x)
|
32 |
+
x1, x2 = x12.chunk(2, dim=-1)
|
33 |
+
hidden = F.silu(x1) * x2
|
34 |
+
return self.w3(hidden)
|
35 |
+
|
36 |
+
|
37 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
38 |
+
try:
|
39 |
+
if XFORMERS_ENABLED:
|
40 |
+
from xformers.ops import SwiGLU
|
41 |
+
|
42 |
+
XFORMERS_AVAILABLE = True
|
43 |
+
warnings.warn("xFormers is available (SwiGLU)")
|
44 |
+
else:
|
45 |
+
warnings.warn("xFormers is disabled (SwiGLU)")
|
46 |
+
raise ImportError
|
47 |
+
except ImportError:
|
48 |
+
SwiGLU = SwiGLUFFN
|
49 |
+
XFORMERS_AVAILABLE = False
|
50 |
+
|
51 |
+
warnings.warn("xFormers is not available (SwiGLU)")
|
52 |
+
|
53 |
+
|
54 |
+
class SwiGLUFFNFused(SwiGLU):
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
in_features: int,
|
58 |
+
hidden_features: Optional[int] = None,
|
59 |
+
out_features: Optional[int] = None,
|
60 |
+
act_layer: Callable[..., nn.Module] = None,
|
61 |
+
drop: float = 0.0,
|
62 |
+
bias: bool = True,
|
63 |
+
) -> None:
|
64 |
+
out_features = out_features or in_features
|
65 |
+
hidden_features = hidden_features or in_features
|
66 |
+
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
67 |
+
super().__init__(
|
68 |
+
in_features=in_features,
|
69 |
+
hidden_features=hidden_features,
|
70 |
+
out_features=out_features,
|
71 |
+
bias=bias,
|
72 |
+
)
|
ADD/models/discriminator.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""
|
10 |
+
Projected discriminator architecture from
|
11 |
+
"StyleGAN-T: Unlocking the Power of GANs for Fast Large-Scale Text-to-Image Synthesis".
|
12 |
+
"""
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from torch.nn.utils.spectral_norm import SpectralNorm
|
19 |
+
from torchvision.transforms import RandomCrop, Normalize
|
20 |
+
import timm
|
21 |
+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
22 |
+
|
23 |
+
from ADD.th_utils import misc
|
24 |
+
from models.shared import ResidualBlock, FullyConnectedLayer
|
25 |
+
from models.vit_utils import make_vit_backbone, forward_vit, make_sd_backbone
|
26 |
+
from models.DiffAugment import DiffAugment
|
27 |
+
from ADD.utils.util_net import reload_model_
|
28 |
+
|
29 |
+
from functools import partial
|
30 |
+
|
31 |
+
class SpectralConv1d(nn.Conv1d):
|
32 |
+
def __init__(self, *args, **kwargs):
|
33 |
+
super().__init__(*args, **kwargs)
|
34 |
+
SpectralNorm.apply(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12)
|
35 |
+
|
36 |
+
|
37 |
+
class BatchNormLocal(nn.Module):
|
38 |
+
def __init__(self, num_features: int, affine: bool = True, virtual_bs: int = 3, eps: float = 1e-5):
|
39 |
+
super().__init__()
|
40 |
+
self.virtual_bs = virtual_bs
|
41 |
+
self.eps = eps
|
42 |
+
self.affine = affine
|
43 |
+
|
44 |
+
if self.affine:
|
45 |
+
self.weight = nn.Parameter(torch.ones(num_features))
|
46 |
+
self.bias = nn.Parameter(torch.zeros(num_features))
|
47 |
+
|
48 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
49 |
+
shape = x.size()
|
50 |
+
|
51 |
+
# Reshape batch into groups.
|
52 |
+
G = np.ceil(x.size(0)/self.virtual_bs).astype(int)
|
53 |
+
x = x.view(G, -1, x.size(-2), x.size(-1))
|
54 |
+
|
55 |
+
# Calculate stats.
|
56 |
+
mean = x.mean([1, 3], keepdim=True)
|
57 |
+
var = x.var([1, 3], keepdim=True, unbiased=False)
|
58 |
+
x = (x - mean) / (torch.sqrt(var + self.eps))
|
59 |
+
|
60 |
+
if self.affine:
|
61 |
+
x = x * self.weight[None, :, None] + self.bias[None, :, None]
|
62 |
+
|
63 |
+
return x.view(shape)
|
64 |
+
|
65 |
+
|
66 |
+
def make_block(channels: int, kernel_size: int) -> nn.Module:
|
67 |
+
return nn.Sequential(
|
68 |
+
SpectralConv1d(
|
69 |
+
channels,
|
70 |
+
channels,
|
71 |
+
kernel_size = kernel_size,
|
72 |
+
padding = kernel_size//2,
|
73 |
+
padding_mode = 'circular',
|
74 |
+
),
|
75 |
+
#BatchNormLocal(channels),
|
76 |
+
nn.GroupNorm(4, channels),
|
77 |
+
nn.LeakyReLU(0.2, True),
|
78 |
+
)
|
79 |
+
|
80 |
+
class DiscHead(nn.Module):
|
81 |
+
def __init__(self, channels: int, c_dim: int, cmap_dim: int = 64):
|
82 |
+
super().__init__()
|
83 |
+
self.channels = channels
|
84 |
+
self.c_dim = c_dim
|
85 |
+
self.cmap_dim = cmap_dim
|
86 |
+
|
87 |
+
self.main = nn.Sequential(
|
88 |
+
make_block(channels, kernel_size=1),
|
89 |
+
ResidualBlock(make_block(channels, kernel_size=9))
|
90 |
+
)
|
91 |
+
|
92 |
+
if self.c_dim > 0:
|
93 |
+
self.cmapper = FullyConnectedLayer(self.c_dim, cmap_dim)
|
94 |
+
self.cls = SpectralConv1d(channels, cmap_dim, kernel_size=1, padding=0)
|
95 |
+
else:
|
96 |
+
self.cls = SpectralConv1d(channels, 1, kernel_size=1, padding=0)
|
97 |
+
|
98 |
+
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
99 |
+
h = self.main(x)
|
100 |
+
out = self.cls(h)
|
101 |
+
|
102 |
+
if self.c_dim > 0:
|
103 |
+
cmap = self.cmapper(c).unsqueeze(-1)
|
104 |
+
out = (out * cmap).sum(1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
|
105 |
+
|
106 |
+
return out
|
107 |
+
|
108 |
+
class DINO(torch.nn.Module):
|
109 |
+
def __init__(self, hooks: list[int] = [2,5,8,11], hook_patch: bool = True):
|
110 |
+
super().__init__()
|
111 |
+
self.n_hooks = len(hooks) + int(hook_patch)
|
112 |
+
|
113 |
+
self.model = make_vit_backbone(
|
114 |
+
timm.create_model('vit_small_patch16_224.dino', pretrained=False),
|
115 |
+
patch_size=[16,16], hooks=hooks, hook_patch=hook_patch,
|
116 |
+
)
|
117 |
+
reload_model_(self.model, torch.load('preset/models/dino/dino_deitsmall16_pretrain.pth'))
|
118 |
+
self.model = self.model.eval().requires_grad_(False)
|
119 |
+
|
120 |
+
|
121 |
+
self.img_resolution = self.model.model.patch_embed.img_size[0]
|
122 |
+
self.embed_dim = self.model.model.embed_dim
|
123 |
+
self.norm = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
|
124 |
+
|
125 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
126 |
+
''' input: x in [0, 1]; output: dict of activations '''
|
127 |
+
x = F.interpolate(x, self.img_resolution, mode='area')
|
128 |
+
x = self.norm(x)
|
129 |
+
features = forward_vit(self.model, x)
|
130 |
+
return features
|
131 |
+
|
132 |
+
|
133 |
+
class ProjectedDiscriminator(nn.Module):
|
134 |
+
def __init__(self, c_dim: int, diffaug: bool = True, p_crop: float = 0.5):
|
135 |
+
super().__init__()
|
136 |
+
self.c_dim = c_dim
|
137 |
+
self.diffaug = diffaug
|
138 |
+
self.p_crop = p_crop
|
139 |
+
|
140 |
+
self.dino = DINO()
|
141 |
+
|
142 |
+
heads = []
|
143 |
+
for i in range(self.dino.n_hooks):
|
144 |
+
heads += [str(i), DiscHead(self.dino.embed_dim, c_dim)],
|
145 |
+
self.heads = nn.ModuleDict(heads)
|
146 |
+
|
147 |
+
def train(self, mode: bool = True):
|
148 |
+
self.dino = self.dino.train(False)
|
149 |
+
self.heads = self.heads.train(mode)
|
150 |
+
return self
|
151 |
+
|
152 |
+
def eval(self):
|
153 |
+
return self.train(False)
|
154 |
+
|
155 |
+
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
156 |
+
# Apply augmentation (x in [-1, 1]).
|
157 |
+
if self.diffaug:
|
158 |
+
x = DiffAugment(x, policy='translation,cutout')
|
159 |
+
|
160 |
+
# Transform to [0, 1].
|
161 |
+
x = x.add(1).div(2)
|
162 |
+
|
163 |
+
# Take crops with probablity p_crop if the image is larger.
|
164 |
+
if x.size(-1) > self.dino.img_resolution and np.random.random() < self.p_crop:
|
165 |
+
x = RandomCrop(self.dino.img_resolution)(x)
|
166 |
+
|
167 |
+
# Forward pass through DINO ViT.
|
168 |
+
features = self.dino(x)
|
169 |
+
|
170 |
+
# Apply discriminator heads.
|
171 |
+
logits = []
|
172 |
+
for k, head in self.heads.items():
|
173 |
+
features[k].requires_grad_(True)
|
174 |
+
logits.append(head(features[k], c).view(x.size(0), -1))
|
175 |
+
#logits = torch.cat(logits, dim=1)
|
176 |
+
|
177 |
+
return logits, features
|
178 |
+
|
ADD/models/vit.py
ADDED
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
9 |
+
|
10 |
+
from functools import partial
|
11 |
+
import math
|
12 |
+
import logging
|
13 |
+
from typing import Sequence, Tuple, Union, Callable
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.utils.checkpoint
|
18 |
+
from torch.nn.init import trunc_normal_
|
19 |
+
|
20 |
+
from ADD.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.getLogger("dinov2")
|
24 |
+
|
25 |
+
|
26 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
|
27 |
+
if not depth_first and include_root:
|
28 |
+
fn(module=module, name=name)
|
29 |
+
for child_name, child_module in module.named_children():
|
30 |
+
child_name = ".".join((name, child_name)) if name else child_name
|
31 |
+
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
32 |
+
if depth_first and include_root:
|
33 |
+
fn(module=module, name=name)
|
34 |
+
return module
|
35 |
+
|
36 |
+
|
37 |
+
class BlockChunk(nn.ModuleList):
|
38 |
+
def forward(self, x):
|
39 |
+
for b in self:
|
40 |
+
x = b(x)
|
41 |
+
return x
|
42 |
+
|
43 |
+
|
44 |
+
class DinoVisionTransformer(nn.Module):
|
45 |
+
def __init__(
|
46 |
+
self,
|
47 |
+
img_size=224,
|
48 |
+
patch_size=16,
|
49 |
+
in_chans=3,
|
50 |
+
embed_dim=768,
|
51 |
+
depth=12,
|
52 |
+
num_heads=12,
|
53 |
+
mlp_ratio=4.0,
|
54 |
+
qkv_bias=True,
|
55 |
+
ffn_bias=True,
|
56 |
+
proj_bias=True,
|
57 |
+
drop_path_rate=0.0,
|
58 |
+
drop_path_uniform=False,
|
59 |
+
init_values=None, # for layerscale: None or 0 => no layerscale
|
60 |
+
embed_layer=PatchEmbed,
|
61 |
+
act_layer=nn.GELU,
|
62 |
+
block_fn=Block,
|
63 |
+
ffn_layer="mlp",
|
64 |
+
block_chunks=1,
|
65 |
+
num_register_tokens=0,
|
66 |
+
interpolate_antialias=False,
|
67 |
+
interpolate_offset=0.1,
|
68 |
+
):
|
69 |
+
"""
|
70 |
+
Args:
|
71 |
+
img_size (int, tuple): input image size
|
72 |
+
patch_size (int, tuple): patch size
|
73 |
+
in_chans (int): number of input channels
|
74 |
+
embed_dim (int): embedding dimension
|
75 |
+
depth (int): depth of transformer
|
76 |
+
num_heads (int): number of attention heads
|
77 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
78 |
+
qkv_bias (bool): enable bias for qkv if True
|
79 |
+
proj_bias (bool): enable bias for proj in attn if True
|
80 |
+
ffn_bias (bool): enable bias for ffn if True
|
81 |
+
drop_path_rate (float): stochastic depth rate
|
82 |
+
drop_path_uniform (bool): apply uniform drop rate across blocks
|
83 |
+
weight_init (str): weight init scheme
|
84 |
+
init_values (float): layer-scale init values
|
85 |
+
embed_layer (nn.Module): patch embedding layer
|
86 |
+
act_layer (nn.Module): MLP activation layer
|
87 |
+
block_fn (nn.Module): transformer block class
|
88 |
+
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
|
89 |
+
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
|
90 |
+
num_register_tokens: (int) number of extra cls tokens (so-called "registers")
|
91 |
+
interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
|
92 |
+
interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
|
93 |
+
"""
|
94 |
+
super().__init__()
|
95 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
96 |
+
|
97 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
98 |
+
self.num_tokens = 1
|
99 |
+
self.n_blocks = depth
|
100 |
+
self.num_heads = num_heads
|
101 |
+
self.patch_size = patch_size
|
102 |
+
self.num_register_tokens = num_register_tokens
|
103 |
+
self.interpolate_antialias = interpolate_antialias
|
104 |
+
self.interpolate_offset = interpolate_offset
|
105 |
+
|
106 |
+
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
107 |
+
num_patches = self.patch_embed.num_patches
|
108 |
+
|
109 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
110 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
111 |
+
assert num_register_tokens >= 0
|
112 |
+
self.register_tokens = (
|
113 |
+
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
|
114 |
+
)
|
115 |
+
|
116 |
+
if drop_path_uniform is True:
|
117 |
+
dpr = [drop_path_rate] * depth
|
118 |
+
else:
|
119 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
120 |
+
|
121 |
+
if ffn_layer == "mlp":
|
122 |
+
logger.info("using MLP layer as FFN")
|
123 |
+
ffn_layer = Mlp
|
124 |
+
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
|
125 |
+
logger.info("using SwiGLU layer as FFN")
|
126 |
+
ffn_layer = SwiGLUFFNFused
|
127 |
+
elif ffn_layer == "identity":
|
128 |
+
logger.info("using Identity layer as FFN")
|
129 |
+
|
130 |
+
def f(*args, **kwargs):
|
131 |
+
return nn.Identity()
|
132 |
+
|
133 |
+
ffn_layer = f
|
134 |
+
else:
|
135 |
+
raise NotImplementedError
|
136 |
+
|
137 |
+
blocks_list = [
|
138 |
+
block_fn(
|
139 |
+
dim=embed_dim,
|
140 |
+
num_heads=num_heads,
|
141 |
+
mlp_ratio=mlp_ratio,
|
142 |
+
qkv_bias=qkv_bias,
|
143 |
+
proj_bias=proj_bias,
|
144 |
+
ffn_bias=ffn_bias,
|
145 |
+
drop_path=dpr[i],
|
146 |
+
norm_layer=norm_layer,
|
147 |
+
act_layer=act_layer,
|
148 |
+
ffn_layer=ffn_layer,
|
149 |
+
init_values=init_values,
|
150 |
+
)
|
151 |
+
for i in range(depth)
|
152 |
+
]
|
153 |
+
if block_chunks > 0:
|
154 |
+
self.chunked_blocks = True
|
155 |
+
chunked_blocks = []
|
156 |
+
chunksize = depth // block_chunks
|
157 |
+
for i in range(0, depth, chunksize):
|
158 |
+
# this is to keep the block index consistent if we chunk the block list
|
159 |
+
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
|
160 |
+
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
|
161 |
+
else:
|
162 |
+
self.chunked_blocks = False
|
163 |
+
self.blocks = nn.ModuleList(blocks_list)
|
164 |
+
|
165 |
+
self.norm = norm_layer(embed_dim)
|
166 |
+
self.head = nn.Identity()
|
167 |
+
|
168 |
+
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
|
169 |
+
|
170 |
+
self.init_weights()
|
171 |
+
|
172 |
+
def init_weights(self):
|
173 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
174 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
175 |
+
if self.register_tokens is not None:
|
176 |
+
nn.init.normal_(self.register_tokens, std=1e-6)
|
177 |
+
named_apply(init_weights_vit_timm, self)
|
178 |
+
|
179 |
+
def interpolate_pos_encoding(self, x, w, h):
|
180 |
+
previous_dtype = x.dtype
|
181 |
+
npatch = x.shape[1] - 1
|
182 |
+
N = self.pos_embed.shape[1] - 1
|
183 |
+
if npatch == N and w == h:
|
184 |
+
return self.pos_embed
|
185 |
+
pos_embed = self.pos_embed.float()
|
186 |
+
class_pos_embed = pos_embed[:, 0]
|
187 |
+
patch_pos_embed = pos_embed[:, 1:]
|
188 |
+
dim = x.shape[-1]
|
189 |
+
w0 = w // self.patch_size
|
190 |
+
h0 = h // self.patch_size
|
191 |
+
# we add a small number to avoid floating point error in the interpolation
|
192 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
193 |
+
w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
|
194 |
+
|
195 |
+
sqrt_N = math.sqrt(N)
|
196 |
+
sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
|
197 |
+
patch_pos_embed = nn.functional.interpolate(
|
198 |
+
patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
|
199 |
+
scale_factor=(sx, sy),
|
200 |
+
mode="bicubic",
|
201 |
+
antialias=self.interpolate_antialias,
|
202 |
+
)
|
203 |
+
|
204 |
+
assert int(w0) == patch_pos_embed.shape[-2]
|
205 |
+
assert int(h0) == patch_pos_embed.shape[-1]
|
206 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
207 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
|
208 |
+
|
209 |
+
def prepare_tokens_with_masks(self, x, masks=None):
|
210 |
+
B, nc, w, h = x.shape
|
211 |
+
x = self.patch_embed(x)
|
212 |
+
if masks is not None:
|
213 |
+
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
|
214 |
+
|
215 |
+
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
216 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
217 |
+
|
218 |
+
if self.register_tokens is not None:
|
219 |
+
x = torch.cat(
|
220 |
+
(
|
221 |
+
x[:, :1],
|
222 |
+
self.register_tokens.expand(x.shape[0], -1, -1),
|
223 |
+
x[:, 1:],
|
224 |
+
),
|
225 |
+
dim=1,
|
226 |
+
)
|
227 |
+
|
228 |
+
return x
|
229 |
+
|
230 |
+
def forward_features_list(self, x_list, masks_list):
|
231 |
+
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
|
232 |
+
for blk in self.blocks:
|
233 |
+
x = blk(x)
|
234 |
+
|
235 |
+
all_x = x
|
236 |
+
output = []
|
237 |
+
for x, masks in zip(all_x, masks_list):
|
238 |
+
x_norm = self.norm(x)
|
239 |
+
output.append(
|
240 |
+
{
|
241 |
+
"x_norm_clstoken": x_norm[:, 0],
|
242 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
243 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
244 |
+
"x_prenorm": x,
|
245 |
+
"masks": masks,
|
246 |
+
}
|
247 |
+
)
|
248 |
+
return output
|
249 |
+
|
250 |
+
def forward_features(self, x, masks=None):
|
251 |
+
fea_list = []
|
252 |
+
counter = 0
|
253 |
+
if isinstance(x, list):
|
254 |
+
return self.forward_features_list(x, masks)
|
255 |
+
|
256 |
+
x = self.prepare_tokens_with_masks(x, masks)
|
257 |
+
fea_list.append(x[:, self.num_register_tokens + 1 :].permute(0, 2, 1))
|
258 |
+
|
259 |
+
for blk in self.blocks:
|
260 |
+
x = blk(x)
|
261 |
+
counter += 1
|
262 |
+
if counter % 3 == 0:
|
263 |
+
fea_list.append(x[:, self.num_register_tokens + 1 :].permute(0, 2, 1))
|
264 |
+
|
265 |
+
x_norm = self.norm(x)
|
266 |
+
return fea_list, x_norm[:, 0]
|
267 |
+
|
268 |
+
def _get_intermediate_layers_not_chunked(self, x, n=1):
|
269 |
+
x = self.prepare_tokens_with_masks(x)
|
270 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
271 |
+
output, total_block_len = [], len(self.blocks)
|
272 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
273 |
+
for i, blk in enumerate(self.blocks):
|
274 |
+
x = blk(x)
|
275 |
+
if i in blocks_to_take:
|
276 |
+
output.append(x)
|
277 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
278 |
+
return output
|
279 |
+
|
280 |
+
def _get_intermediate_layers_chunked(self, x, n=1):
|
281 |
+
x = self.prepare_tokens_with_masks(x)
|
282 |
+
output, i, total_block_len = [], 0, len(self.blocks[-1])
|
283 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
284 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
285 |
+
for block_chunk in self.blocks:
|
286 |
+
for blk in block_chunk[i:]: # Passing the nn.Identity()
|
287 |
+
x = blk(x)
|
288 |
+
if i in blocks_to_take:
|
289 |
+
output.append(x)
|
290 |
+
i += 1
|
291 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
292 |
+
return output
|
293 |
+
|
294 |
+
def get_intermediate_layers(
|
295 |
+
self,
|
296 |
+
x: torch.Tensor,
|
297 |
+
n: Union[int, Sequence] = 1, # Layers or n last layers to take
|
298 |
+
reshape: bool = False,
|
299 |
+
return_class_token: bool = False,
|
300 |
+
norm=True,
|
301 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
302 |
+
if self.chunked_blocks:
|
303 |
+
outputs = self._get_intermediate_layers_chunked(x, n)
|
304 |
+
else:
|
305 |
+
outputs = self._get_intermediate_layers_not_chunked(x, n)
|
306 |
+
if norm:
|
307 |
+
outputs = [self.norm(out) for out in outputs]
|
308 |
+
class_tokens = [out[:, 0] for out in outputs]
|
309 |
+
outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
|
310 |
+
if reshape:
|
311 |
+
B, _, w, h = x.shape
|
312 |
+
outputs = [
|
313 |
+
out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
|
314 |
+
for out in outputs
|
315 |
+
]
|
316 |
+
if return_class_token:
|
317 |
+
return tuple(zip(outputs, class_tokens))
|
318 |
+
return tuple(outputs)
|
319 |
+
|
320 |
+
def forward(self, *args, is_training=False, **kwargs):
|
321 |
+
ret = self.forward_features(*args, **kwargs)
|
322 |
+
if is_training:
|
323 |
+
return ret
|
324 |
+
else:
|
325 |
+
return ret#self.head(ret["x_norm_clstoken"])
|
326 |
+
|
327 |
+
|
328 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
|
329 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
330 |
+
if isinstance(module, nn.Linear):
|
331 |
+
trunc_normal_(module.weight, std=0.02)
|
332 |
+
if module.bias is not None:
|
333 |
+
nn.init.zeros_(module.bias)
|
334 |
+
|
335 |
+
|
336 |
+
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
|
337 |
+
model = DinoVisionTransformer(
|
338 |
+
patch_size=patch_size,
|
339 |
+
embed_dim=384,
|
340 |
+
depth=12,
|
341 |
+
num_heads=6,
|
342 |
+
mlp_ratio=4,
|
343 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
344 |
+
num_register_tokens=num_register_tokens,
|
345 |
+
**kwargs,
|
346 |
+
)
|
347 |
+
return model
|
348 |
+
|
349 |
+
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
|
350 |
+
model = DinoVisionTransformer(
|
351 |
+
patch_size=patch_size,
|
352 |
+
embed_dim=1024,
|
353 |
+
depth=24,
|
354 |
+
num_heads=16,
|
355 |
+
mlp_ratio=4,
|
356 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
357 |
+
num_register_tokens=num_register_tokens,
|
358 |
+
**kwargs,
|
359 |
+
)
|
360 |
+
return model
|
361 |
+
|
362 |
+
|
363 |
+
# net = vit_small(patch_size=14, img_size=518, block_chunks=0, init_values=1.0)
|
364 |
+
# prefile = torch.load('../weights/dinov2_vits14_pretrain.pth')
|
365 |
+
# net.load_state_dict(prefile, True)
|
366 |
+
# out = net(torch.rand(1, 3, 518, 518))
|
367 |
+
# print(out.shape)
|
368 |
+
|
369 |
+
# net = vit_large(patch_size=14, img_size=526, block_chunks=0, init_values=1.0, num_register_tokens=4)
|
370 |
+
# prefile = torch.load('../weights/dinov2_vitl14_reg4_pretrain.pth')
|
371 |
+
# net.load_state_dict(prefile, True)
|
372 |
+
# out = net(torch.rand(1, 3, 70, 70))
|
373 |
+
# print(out.shape)
|
ADD/th_utils/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
# empty
|
ADD/th_utils/custom_ops.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
import glob
|
10 |
+
import hashlib
|
11 |
+
import importlib
|
12 |
+
import os
|
13 |
+
import re
|
14 |
+
import shutil
|
15 |
+
import uuid
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.utils.cpp_extension
|
19 |
+
from torch.utils.file_baton import FileBaton
|
20 |
+
|
21 |
+
#----------------------------------------------------------------------------
|
22 |
+
# Global options.
|
23 |
+
|
24 |
+
verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
|
25 |
+
|
26 |
+
#----------------------------------------------------------------------------
|
27 |
+
# Internal helper funcs.
|
28 |
+
|
29 |
+
def _find_compiler_bindir():
|
30 |
+
patterns = [
|
31 |
+
'C:/Program Files*/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
32 |
+
'C:/Program Files*/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
33 |
+
'C:/Program Files*/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
34 |
+
'C:/Program Files*/Microsoft Visual Studio */vc/bin',
|
35 |
+
]
|
36 |
+
for pattern in patterns:
|
37 |
+
matches = sorted(glob.glob(pattern))
|
38 |
+
if len(matches):
|
39 |
+
return matches[-1]
|
40 |
+
return None
|
41 |
+
|
42 |
+
#----------------------------------------------------------------------------
|
43 |
+
|
44 |
+
def _get_mangled_gpu_name():
|
45 |
+
name = torch.cuda.get_device_name().lower()
|
46 |
+
out = []
|
47 |
+
for c in name:
|
48 |
+
if re.match('[a-z0-9_-]+', c):
|
49 |
+
out.append(c)
|
50 |
+
else:
|
51 |
+
out.append('-')
|
52 |
+
return ''.join(out)
|
53 |
+
|
54 |
+
#----------------------------------------------------------------------------
|
55 |
+
# Main entry point for compiling and loading C++/CUDA plugins.
|
56 |
+
|
57 |
+
_cached_plugins = dict()
|
58 |
+
|
59 |
+
def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
|
60 |
+
assert verbosity in ['none', 'brief', 'full']
|
61 |
+
if headers is None:
|
62 |
+
headers = []
|
63 |
+
if source_dir is not None:
|
64 |
+
sources = [os.path.join(source_dir, fname) for fname in sources]
|
65 |
+
headers = [os.path.join(source_dir, fname) for fname in headers]
|
66 |
+
|
67 |
+
# Already cached?
|
68 |
+
if module_name in _cached_plugins:
|
69 |
+
return _cached_plugins[module_name]
|
70 |
+
|
71 |
+
# Print status.
|
72 |
+
if verbosity == 'full':
|
73 |
+
print(f'Setting up PyTorch plugin "{module_name}"...')
|
74 |
+
elif verbosity == 'brief':
|
75 |
+
print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
|
76 |
+
verbose_build = (verbosity == 'full')
|
77 |
+
|
78 |
+
# Compile and load.
|
79 |
+
try: # pylint: disable=too-many-nested-blocks
|
80 |
+
# Make sure we can find the necessary compiler binaries.
|
81 |
+
if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
|
82 |
+
compiler_bindir = _find_compiler_bindir()
|
83 |
+
if compiler_bindir is None:
|
84 |
+
raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
|
85 |
+
os.environ['PATH'] += ';' + compiler_bindir
|
86 |
+
|
87 |
+
# Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
|
88 |
+
# break the build or unnecessarily restrict what's available to nvcc.
|
89 |
+
# Unset it to let nvcc decide based on what's available on the
|
90 |
+
# machine.
|
91 |
+
os.environ['TORCH_CUDA_ARCH_LIST'] = ''
|
92 |
+
|
93 |
+
# Incremental build md5sum trickery. Copies all the input source files
|
94 |
+
# into a cached build directory under a combined md5 digest of the input
|
95 |
+
# source files. Copying is done only if the combined digest has changed.
|
96 |
+
# This keeps input file timestamps and filenames the same as in previous
|
97 |
+
# extension builds, allowing for fast incremental rebuilds.
|
98 |
+
#
|
99 |
+
# This optimization is done only in case all the source files reside in
|
100 |
+
# a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
|
101 |
+
# environment variable is set (we take this as a signal that the user
|
102 |
+
# actually cares about this.)
|
103 |
+
#
|
104 |
+
# EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work
|
105 |
+
# around the *.cu dependency bug in ninja config.
|
106 |
+
#
|
107 |
+
all_source_files = sorted(sources + headers)
|
108 |
+
all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
|
109 |
+
if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
|
110 |
+
|
111 |
+
# Compute combined hash digest for all source files.
|
112 |
+
hash_md5 = hashlib.md5()
|
113 |
+
for src in all_source_files:
|
114 |
+
with open(src, 'rb') as f:
|
115 |
+
hash_md5.update(f.read())
|
116 |
+
|
117 |
+
# Select cached build directory name.
|
118 |
+
source_digest = hash_md5.hexdigest()
|
119 |
+
build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
|
120 |
+
cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
|
121 |
+
|
122 |
+
if not os.path.isdir(cached_build_dir):
|
123 |
+
tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
|
124 |
+
os.makedirs(tmpdir)
|
125 |
+
for src in all_source_files:
|
126 |
+
shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
|
127 |
+
try:
|
128 |
+
os.replace(tmpdir, cached_build_dir) # atomic
|
129 |
+
except OSError:
|
130 |
+
# source directory already exists, delete tmpdir and its contents.
|
131 |
+
shutil.rmtree(tmpdir)
|
132 |
+
if not os.path.isdir(cached_build_dir): raise
|
133 |
+
|
134 |
+
# Compile.
|
135 |
+
cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
|
136 |
+
torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
|
137 |
+
verbose=verbose_build, sources=cached_sources, **build_kwargs)
|
138 |
+
else:
|
139 |
+
torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
|
140 |
+
|
141 |
+
# Load.
|
142 |
+
module = importlib.import_module(module_name)
|
143 |
+
|
144 |
+
except:
|
145 |
+
if verbosity == 'brief':
|
146 |
+
print('Failed!')
|
147 |
+
raise
|
148 |
+
|
149 |
+
# Print status and add to cache dict.
|
150 |
+
if verbosity == 'full':
|
151 |
+
print(f'Done setting up PyTorch plugin "{module_name}".')
|
152 |
+
elif verbosity == 'brief':
|
153 |
+
print('Done.')
|
154 |
+
_cached_plugins[module_name] = module
|
155 |
+
return module
|
156 |
+
|
157 |
+
#----------------------------------------------------------------------------
|
ADD/th_utils/misc.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
import re
|
10 |
+
import contextlib
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import warnings
|
14 |
+
import ADD.dnnlib as dnnlib
|
15 |
+
|
16 |
+
#----------------------------------------------------------------------------
|
17 |
+
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
|
18 |
+
# same constant is used multiple times.
|
19 |
+
|
20 |
+
_constant_cache = dict()
|
21 |
+
|
22 |
+
def constant(value, shape=None, dtype=None, device=None, memory_format=None):
|
23 |
+
value = np.asarray(value)
|
24 |
+
if shape is not None:
|
25 |
+
shape = tuple(shape)
|
26 |
+
if dtype is None:
|
27 |
+
dtype = torch.get_default_dtype()
|
28 |
+
if device is None:
|
29 |
+
device = torch.device('cpu')
|
30 |
+
if memory_format is None:
|
31 |
+
memory_format = torch.contiguous_format
|
32 |
+
|
33 |
+
key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
|
34 |
+
tensor = _constant_cache.get(key, None)
|
35 |
+
if tensor is None:
|
36 |
+
tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
|
37 |
+
if shape is not None:
|
38 |
+
tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
|
39 |
+
tensor = tensor.contiguous(memory_format=memory_format)
|
40 |
+
_constant_cache[key] = tensor
|
41 |
+
return tensor
|
42 |
+
|
43 |
+
#----------------------------------------------------------------------------
|
44 |
+
# Replace NaN/Inf with specified numerical values.
|
45 |
+
|
46 |
+
try:
|
47 |
+
nan_to_num = torch.nan_to_num # 1.8.0a0
|
48 |
+
except AttributeError:
|
49 |
+
def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
|
50 |
+
assert isinstance(input, torch.Tensor)
|
51 |
+
if posinf is None:
|
52 |
+
posinf = torch.finfo(input.dtype).max
|
53 |
+
if neginf is None:
|
54 |
+
neginf = torch.finfo(input.dtype).min
|
55 |
+
assert nan == 0
|
56 |
+
return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
|
57 |
+
|
58 |
+
#----------------------------------------------------------------------------
|
59 |
+
# Symbolic assert.
|
60 |
+
|
61 |
+
try:
|
62 |
+
symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
|
63 |
+
except AttributeError:
|
64 |
+
symbolic_assert = torch.Assert # 1.7.0
|
65 |
+
|
66 |
+
#----------------------------------------------------------------------------
|
67 |
+
# Context manager to temporarily suppress known warnings in torch.jit.trace().
|
68 |
+
# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
|
69 |
+
|
70 |
+
@contextlib.contextmanager
|
71 |
+
def suppress_tracer_warnings():
|
72 |
+
flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
|
73 |
+
warnings.filters.insert(0, flt)
|
74 |
+
yield
|
75 |
+
warnings.filters.remove(flt)
|
76 |
+
|
77 |
+
#----------------------------------------------------------------------------
|
78 |
+
# Assert that the shape of a tensor matches the given list of integers.
|
79 |
+
# None indicates that the size of a dimension is allowed to vary.
|
80 |
+
# Performs symbolic assertion when used in torch.jit.trace().
|
81 |
+
|
82 |
+
def assert_shape(tensor, ref_shape):
|
83 |
+
if tensor.ndim != len(ref_shape):
|
84 |
+
raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
|
85 |
+
for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
|
86 |
+
if ref_size is None:
|
87 |
+
pass
|
88 |
+
elif isinstance(ref_size, torch.Tensor):
|
89 |
+
with suppress_tracer_warnings(): # as_tensor results are registered as constants
|
90 |
+
symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
|
91 |
+
elif isinstance(size, torch.Tensor):
|
92 |
+
with suppress_tracer_warnings(): # as_tensor results are registered as constants
|
93 |
+
symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
|
94 |
+
elif size != ref_size:
|
95 |
+
raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
|
96 |
+
|
97 |
+
#----------------------------------------------------------------------------
|
98 |
+
# Function decorator that calls torch.autograd.profiler.record_function().
|
99 |
+
|
100 |
+
def profiled_function(fn):
|
101 |
+
def decorator(*args, **kwargs):
|
102 |
+
with torch.autograd.profiler.record_function(fn.__name__):
|
103 |
+
return fn(*args, **kwargs)
|
104 |
+
decorator.__name__ = fn.__name__
|
105 |
+
return decorator
|
106 |
+
|
107 |
+
#----------------------------------------------------------------------------
|
108 |
+
# Sampler for torch.utils.data.DataLoader that loops over the dataset
|
109 |
+
# indefinitely, shuffling items as it goes.
|
110 |
+
|
111 |
+
class InfiniteSampler(torch.utils.data.Sampler):
|
112 |
+
def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
|
113 |
+
assert len(dataset) > 0
|
114 |
+
assert num_replicas > 0
|
115 |
+
assert 0 <= rank < num_replicas
|
116 |
+
assert 0 <= window_size <= 1
|
117 |
+
super().__init__(dataset)
|
118 |
+
self.dataset = dataset
|
119 |
+
self.rank = rank
|
120 |
+
self.num_replicas = num_replicas
|
121 |
+
self.shuffle = shuffle
|
122 |
+
self.seed = seed
|
123 |
+
self.window_size = window_size
|
124 |
+
|
125 |
+
def __iter__(self):
|
126 |
+
order = np.arange(len(self.dataset))
|
127 |
+
rnd = None
|
128 |
+
window = 0
|
129 |
+
if self.shuffle:
|
130 |
+
rnd = np.random.RandomState(self.seed)
|
131 |
+
rnd.shuffle(order)
|
132 |
+
window = int(np.rint(order.size * self.window_size))
|
133 |
+
|
134 |
+
idx = 0
|
135 |
+
while True:
|
136 |
+
i = idx % order.size
|
137 |
+
if idx % self.num_replicas == self.rank:
|
138 |
+
yield order[i]
|
139 |
+
if window >= 2:
|
140 |
+
j = (i - rnd.randint(window)) % order.size
|
141 |
+
order[i], order[j] = order[j], order[i]
|
142 |
+
idx += 1
|
143 |
+
|
144 |
+
#----------------------------------------------------------------------------
|
145 |
+
# Utilities for operating with torch.nn.Module parameters and buffers.
|
146 |
+
def spectral_to_cpu(model: torch.nn.Module):
|
147 |
+
def wrapped_in_spectral(m): return hasattr(m, 'weight_v')
|
148 |
+
children = get_children(model)
|
149 |
+
for child in children:
|
150 |
+
if wrapped_in_spectral(child):
|
151 |
+
child.weight = child.weight.cpu()
|
152 |
+
return model
|
153 |
+
|
154 |
+
def get_children(model: torch.nn.Module):
|
155 |
+
children = list(model.children())
|
156 |
+
flatt_children = []
|
157 |
+
if children == []:
|
158 |
+
return model
|
159 |
+
else:
|
160 |
+
for child in children:
|
161 |
+
try:
|
162 |
+
flatt_children.extend(get_children(child))
|
163 |
+
except TypeError:
|
164 |
+
flatt_children.append(get_children(child))
|
165 |
+
return flatt_children
|
166 |
+
|
167 |
+
def params_and_buffers(module):
|
168 |
+
assert isinstance(module, torch.nn.Module)
|
169 |
+
return list(module.parameters()) + list(module.buffers())
|
170 |
+
|
171 |
+
def named_params_and_buffers(module):
|
172 |
+
assert isinstance(module, torch.nn.Module)
|
173 |
+
return list(module.named_parameters()) + list(module.named_buffers())
|
174 |
+
|
175 |
+
def copy_params_and_buffers(src_module, dst_module, require_all=False):
|
176 |
+
assert isinstance(src_module, torch.nn.Module)
|
177 |
+
assert isinstance(dst_module, torch.nn.Module)
|
178 |
+
src_tensors = dict(named_params_and_buffers(src_module))
|
179 |
+
for name, tensor in named_params_and_buffers(dst_module):
|
180 |
+
assert (name in src_tensors) or (not require_all)
|
181 |
+
if name in src_tensors:
|
182 |
+
tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
|
183 |
+
|
184 |
+
#----------------------------------------------------------------------------
|
185 |
+
# Context manager for easily enabling/disabling DistributedDataParallel
|
186 |
+
# synchronization.
|
187 |
+
|
188 |
+
@contextlib.contextmanager
|
189 |
+
def ddp_sync(module, sync):
|
190 |
+
assert isinstance(module, torch.nn.Module)
|
191 |
+
if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
|
192 |
+
yield
|
193 |
+
else:
|
194 |
+
with module.no_sync():
|
195 |
+
yield
|
196 |
+
|
197 |
+
#----------------------------------------------------------------------------
|
198 |
+
# Check DistributedDataParallel consistency across processes.
|
199 |
+
|
200 |
+
def check_ddp_consistency(module, ignore_regex=None):
|
201 |
+
assert isinstance(module, torch.nn.Module)
|
202 |
+
for name, tensor in named_params_and_buffers(module):
|
203 |
+
fullname = type(module).__name__ + '.' + name
|
204 |
+
if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
|
205 |
+
continue
|
206 |
+
tensor = tensor.detach()
|
207 |
+
if tensor.is_floating_point():
|
208 |
+
tensor = nan_to_num(tensor)
|
209 |
+
other = tensor.clone()
|
210 |
+
torch.distributed.broadcast(tensor=other, src=0)
|
211 |
+
assert (tensor == other).all(), fullname
|
212 |
+
|
213 |
+
#----------------------------------------------------------------------------
|
214 |
+
# Print summary table of module hierarchy.
|
215 |
+
|
216 |
+
def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
|
217 |
+
assert isinstance(module, torch.nn.Module)
|
218 |
+
assert not isinstance(module, torch.jit.ScriptModule)
|
219 |
+
assert isinstance(inputs, (tuple, list))
|
220 |
+
|
221 |
+
# Register hooks.
|
222 |
+
entries = []
|
223 |
+
nesting = [0]
|
224 |
+
def pre_hook(_mod, _inputs):
|
225 |
+
nesting[0] += 1
|
226 |
+
def post_hook(mod, _inputs, outputs):
|
227 |
+
nesting[0] -= 1
|
228 |
+
if nesting[0] <= max_nesting:
|
229 |
+
outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
|
230 |
+
outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
|
231 |
+
entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
|
232 |
+
hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
|
233 |
+
hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
|
234 |
+
|
235 |
+
# Run module.
|
236 |
+
outputs = module(*inputs)
|
237 |
+
for hook in hooks:
|
238 |
+
hook.remove()
|
239 |
+
|
240 |
+
# Identify unique outputs, parameters, and buffers.
|
241 |
+
tensors_seen = set()
|
242 |
+
for e in entries:
|
243 |
+
e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
|
244 |
+
e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
|
245 |
+
e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
|
246 |
+
tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
|
247 |
+
|
248 |
+
# Filter out redundant entries.
|
249 |
+
if skip_redundant:
|
250 |
+
entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
|
251 |
+
|
252 |
+
# Construct table.
|
253 |
+
rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
|
254 |
+
rows += [['---'] * len(rows[0])]
|
255 |
+
param_total = 0
|
256 |
+
buffer_total = 0
|
257 |
+
submodule_names = {mod: name for name, mod in module.named_modules()}
|
258 |
+
for e in entries:
|
259 |
+
name = '<top-level>' if e.mod is module else submodule_names[e.mod]
|
260 |
+
param_size = sum(t.numel() for t in e.unique_params)
|
261 |
+
buffer_size = sum(t.numel() for t in e.unique_buffers)
|
262 |
+
output_shapes = [str(list(t.shape)) for t in e.outputs]
|
263 |
+
output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
|
264 |
+
rows += [[
|
265 |
+
name + (':0' if len(e.outputs) >= 2 else ''),
|
266 |
+
str(param_size) if param_size else '-',
|
267 |
+
str(buffer_size) if buffer_size else '-',
|
268 |
+
(output_shapes + ['-'])[0],
|
269 |
+
(output_dtypes + ['-'])[0],
|
270 |
+
]]
|
271 |
+
for idx in range(1, len(e.outputs)):
|
272 |
+
rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
|
273 |
+
param_total += param_size
|
274 |
+
buffer_total += buffer_size
|
275 |
+
rows += [['---'] * len(rows[0])]
|
276 |
+
rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
|
277 |
+
|
278 |
+
# Print table.
|
279 |
+
widths = [max(len(cell) for cell in column) for column in zip(*rows)]
|
280 |
+
print()
|
281 |
+
for row in rows:
|
282 |
+
print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
|
283 |
+
print()
|
284 |
+
return outputs
|
ADD/th_utils/ops/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
# empty
|
ADD/th_utils/ops/bias_act.cpp
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include <torch/extension.h>
|
10 |
+
#include <ATen/cuda/CUDAContext.h>
|
11 |
+
#include <c10/cuda/CUDAGuard.h>
|
12 |
+
#include "bias_act.h"
|
13 |
+
|
14 |
+
//------------------------------------------------------------------------
|
15 |
+
|
16 |
+
static bool has_same_layout(torch::Tensor x, torch::Tensor y)
|
17 |
+
{
|
18 |
+
if (x.dim() != y.dim())
|
19 |
+
return false;
|
20 |
+
for (int64_t i = 0; i < x.dim(); i++)
|
21 |
+
{
|
22 |
+
if (x.size(i) != y.size(i))
|
23 |
+
return false;
|
24 |
+
if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
|
25 |
+
return false;
|
26 |
+
}
|
27 |
+
return true;
|
28 |
+
}
|
29 |
+
|
30 |
+
//------------------------------------------------------------------------
|
31 |
+
|
32 |
+
static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
|
33 |
+
{
|
34 |
+
// Validate arguments.
|
35 |
+
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
36 |
+
TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
|
37 |
+
TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
|
38 |
+
TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
|
39 |
+
TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
|
40 |
+
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
|
41 |
+
TORCH_CHECK(b.dim() == 1, "b must have rank 1");
|
42 |
+
TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
|
43 |
+
TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
|
44 |
+
TORCH_CHECK(grad >= 0, "grad must be non-negative");
|
45 |
+
|
46 |
+
// Validate layout.
|
47 |
+
TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
|
48 |
+
TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
|
49 |
+
TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
|
50 |
+
TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
|
51 |
+
TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
|
52 |
+
|
53 |
+
// Create output tensor.
|
54 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
55 |
+
torch::Tensor y = torch::empty_like(x);
|
56 |
+
TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
|
57 |
+
|
58 |
+
// Initialize CUDA kernel parameters.
|
59 |
+
bias_act_kernel_params p;
|
60 |
+
p.x = x.data_ptr();
|
61 |
+
p.b = (b.numel()) ? b.data_ptr() : NULL;
|
62 |
+
p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
|
63 |
+
p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
|
64 |
+
p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
|
65 |
+
p.y = y.data_ptr();
|
66 |
+
p.grad = grad;
|
67 |
+
p.act = act;
|
68 |
+
p.alpha = alpha;
|
69 |
+
p.gain = gain;
|
70 |
+
p.clamp = clamp;
|
71 |
+
p.sizeX = (int)x.numel();
|
72 |
+
p.sizeB = (int)b.numel();
|
73 |
+
p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
|
74 |
+
|
75 |
+
// Choose CUDA kernel.
|
76 |
+
void* kernel;
|
77 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
|
78 |
+
{
|
79 |
+
kernel = choose_bias_act_kernel<scalar_t>(p);
|
80 |
+
});
|
81 |
+
TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
|
82 |
+
|
83 |
+
// Launch CUDA kernel.
|
84 |
+
p.loopX = 4;
|
85 |
+
int blockSize = 4 * 32;
|
86 |
+
int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
|
87 |
+
void* args[] = {&p};
|
88 |
+
AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
|
89 |
+
return y;
|
90 |
+
}
|
91 |
+
|
92 |
+
//------------------------------------------------------------------------
|
93 |
+
|
94 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
95 |
+
{
|
96 |
+
m.def("bias_act", &bias_act);
|
97 |
+
}
|
98 |
+
|
99 |
+
//------------------------------------------------------------------------
|
ADD/th_utils/ops/bias_act.cu
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include <c10/util/Half.h>
|
10 |
+
#include "bias_act.h"
|
11 |
+
|
12 |
+
//------------------------------------------------------------------------
|
13 |
+
// Helpers.
|
14 |
+
|
15 |
+
template <class T> struct InternalType;
|
16 |
+
template <> struct InternalType<double> { typedef double scalar_t; };
|
17 |
+
template <> struct InternalType<float> { typedef float scalar_t; };
|
18 |
+
template <> struct InternalType<c10::Half> { typedef float scalar_t; };
|
19 |
+
|
20 |
+
//------------------------------------------------------------------------
|
21 |
+
// CUDA kernel.
|
22 |
+
|
23 |
+
template <class T, int A>
|
24 |
+
__global__ void bias_act_kernel(bias_act_kernel_params p)
|
25 |
+
{
|
26 |
+
typedef typename InternalType<T>::scalar_t scalar_t;
|
27 |
+
int G = p.grad;
|
28 |
+
scalar_t alpha = (scalar_t)p.alpha;
|
29 |
+
scalar_t gain = (scalar_t)p.gain;
|
30 |
+
scalar_t clamp = (scalar_t)p.clamp;
|
31 |
+
scalar_t one = (scalar_t)1;
|
32 |
+
scalar_t two = (scalar_t)2;
|
33 |
+
scalar_t expRange = (scalar_t)80;
|
34 |
+
scalar_t halfExpRange = (scalar_t)40;
|
35 |
+
scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
|
36 |
+
scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
|
37 |
+
|
38 |
+
// Loop over elements.
|
39 |
+
int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
|
40 |
+
for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
|
41 |
+
{
|
42 |
+
// Load.
|
43 |
+
scalar_t x = (scalar_t)((const T*)p.x)[xi];
|
44 |
+
scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
|
45 |
+
scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
|
46 |
+
scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
|
47 |
+
scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
|
48 |
+
scalar_t yy = (gain != 0) ? yref / gain : 0;
|
49 |
+
scalar_t y = 0;
|
50 |
+
|
51 |
+
// Apply bias.
|
52 |
+
((G == 0) ? x : xref) += b;
|
53 |
+
|
54 |
+
// linear
|
55 |
+
if (A == 1)
|
56 |
+
{
|
57 |
+
if (G == 0) y = x;
|
58 |
+
if (G == 1) y = x;
|
59 |
+
}
|
60 |
+
|
61 |
+
// relu
|
62 |
+
if (A == 2)
|
63 |
+
{
|
64 |
+
if (G == 0) y = (x > 0) ? x : 0;
|
65 |
+
if (G == 1) y = (yy > 0) ? x : 0;
|
66 |
+
}
|
67 |
+
|
68 |
+
// lrelu
|
69 |
+
if (A == 3)
|
70 |
+
{
|
71 |
+
if (G == 0) y = (x > 0) ? x : x * alpha;
|
72 |
+
if (G == 1) y = (yy > 0) ? x : x * alpha;
|
73 |
+
}
|
74 |
+
|
75 |
+
// tanh
|
76 |
+
if (A == 4)
|
77 |
+
{
|
78 |
+
if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
|
79 |
+
if (G == 1) y = x * (one - yy * yy);
|
80 |
+
if (G == 2) y = x * (one - yy * yy) * (-two * yy);
|
81 |
+
}
|
82 |
+
|
83 |
+
// sigmoid
|
84 |
+
if (A == 5)
|
85 |
+
{
|
86 |
+
if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
|
87 |
+
if (G == 1) y = x * yy * (one - yy);
|
88 |
+
if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
|
89 |
+
}
|
90 |
+
|
91 |
+
// elu
|
92 |
+
if (A == 6)
|
93 |
+
{
|
94 |
+
if (G == 0) y = (x >= 0) ? x : exp(x) - one;
|
95 |
+
if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
|
96 |
+
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
|
97 |
+
}
|
98 |
+
|
99 |
+
// selu
|
100 |
+
if (A == 7)
|
101 |
+
{
|
102 |
+
if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
|
103 |
+
if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
|
104 |
+
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
|
105 |
+
}
|
106 |
+
|
107 |
+
// softplus
|
108 |
+
if (A == 8)
|
109 |
+
{
|
110 |
+
if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
|
111 |
+
if (G == 1) y = x * (one - exp(-yy));
|
112 |
+
if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
|
113 |
+
}
|
114 |
+
|
115 |
+
// swish
|
116 |
+
if (A == 9)
|
117 |
+
{
|
118 |
+
if (G == 0)
|
119 |
+
y = (x < -expRange) ? 0 : x / (exp(-x) + one);
|
120 |
+
else
|
121 |
+
{
|
122 |
+
scalar_t c = exp(xref);
|
123 |
+
scalar_t d = c + one;
|
124 |
+
if (G == 1)
|
125 |
+
y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
|
126 |
+
else
|
127 |
+
y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
|
128 |
+
yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
|
129 |
+
}
|
130 |
+
}
|
131 |
+
|
132 |
+
// Apply gain.
|
133 |
+
y *= gain * dy;
|
134 |
+
|
135 |
+
// Clamp.
|
136 |
+
if (clamp >= 0)
|
137 |
+
{
|
138 |
+
if (G == 0)
|
139 |
+
y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
|
140 |
+
else
|
141 |
+
y = (yref > -clamp & yref < clamp) ? y : 0;
|
142 |
+
}
|
143 |
+
|
144 |
+
// Store.
|
145 |
+
((T*)p.y)[xi] = (T)y;
|
146 |
+
}
|
147 |
+
}
|
148 |
+
|
149 |
+
//------------------------------------------------------------------------
|
150 |
+
// CUDA kernel selection.
|
151 |
+
|
152 |
+
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
|
153 |
+
{
|
154 |
+
if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
|
155 |
+
if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
|
156 |
+
if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
|
157 |
+
if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
|
158 |
+
if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
|
159 |
+
if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
|
160 |
+
if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
|
161 |
+
if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
|
162 |
+
if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
|
163 |
+
return NULL;
|
164 |
+
}
|
165 |
+
|
166 |
+
//------------------------------------------------------------------------
|
167 |
+
// Template specializations.
|
168 |
+
|
169 |
+
template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
|
170 |
+
template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
|
171 |
+
template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
|
172 |
+
|
173 |
+
//------------------------------------------------------------------------
|
ADD/th_utils/ops/bias_act.h
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
//------------------------------------------------------------------------
|
10 |
+
// CUDA kernel parameters.
|
11 |
+
|
12 |
+
struct bias_act_kernel_params
|
13 |
+
{
|
14 |
+
const void* x; // [sizeX]
|
15 |
+
const void* b; // [sizeB] or NULL
|
16 |
+
const void* xref; // [sizeX] or NULL
|
17 |
+
const void* yref; // [sizeX] or NULL
|
18 |
+
const void* dy; // [sizeX] or NULL
|
19 |
+
void* y; // [sizeX]
|
20 |
+
|
21 |
+
int grad;
|
22 |
+
int act;
|
23 |
+
float alpha;
|
24 |
+
float gain;
|
25 |
+
float clamp;
|
26 |
+
|
27 |
+
int sizeX;
|
28 |
+
int sizeB;
|
29 |
+
int stepB;
|
30 |
+
int loopX;
|
31 |
+
};
|
32 |
+
|
33 |
+
//------------------------------------------------------------------------
|
34 |
+
// CUDA kernel selection.
|
35 |
+
|
36 |
+
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
|
37 |
+
|
38 |
+
//------------------------------------------------------------------------
|
ADD/th_utils/ops/bias_act.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Custom PyTorch ops for efficient bias and activation."""
|
10 |
+
|
11 |
+
import os
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import ADD.dnnlib as dnnlib
|
15 |
+
|
16 |
+
from .. import custom_ops
|
17 |
+
from .. import misc
|
18 |
+
|
19 |
+
#----------------------------------------------------------------------------
|
20 |
+
|
21 |
+
activation_funcs = {
|
22 |
+
'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
|
23 |
+
'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
|
24 |
+
'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
|
25 |
+
'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
|
26 |
+
'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
|
27 |
+
'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
|
28 |
+
'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
|
29 |
+
'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
|
30 |
+
'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
|
31 |
+
}
|
32 |
+
|
33 |
+
#----------------------------------------------------------------------------
|
34 |
+
|
35 |
+
_plugin = None
|
36 |
+
_null_tensor = torch.empty([0])
|
37 |
+
|
38 |
+
def _init():
|
39 |
+
global _plugin
|
40 |
+
if _plugin is None:
|
41 |
+
_plugin = custom_ops.get_plugin(
|
42 |
+
module_name='bias_act_plugin',
|
43 |
+
sources=['bias_act.cpp', 'bias_act.cu'],
|
44 |
+
headers=['bias_act.h'],
|
45 |
+
source_dir=os.path.dirname(__file__),
|
46 |
+
extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'],
|
47 |
+
)
|
48 |
+
return True
|
49 |
+
|
50 |
+
#----------------------------------------------------------------------------
|
51 |
+
|
52 |
+
def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
|
53 |
+
r"""Fused bias and activation function.
|
54 |
+
|
55 |
+
Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
|
56 |
+
and scales the result by `gain`. Each of the steps is optional. In most cases,
|
57 |
+
the fused op is considerably more efficient than performing the same calculation
|
58 |
+
using standard PyTorch ops. It supports first and second order gradients,
|
59 |
+
but not third order gradients.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
x: Input activation tensor. Can be of any shape.
|
63 |
+
b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
|
64 |
+
as `x`. The shape must be known, and it must match the dimension of `x`
|
65 |
+
corresponding to `dim`.
|
66 |
+
dim: The dimension in `x` corresponding to the elements of `b`.
|
67 |
+
The value of `dim` is ignored if `b` is not specified.
|
68 |
+
act: Name of the activation function to evaluate, or `"linear"` to disable.
|
69 |
+
Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
|
70 |
+
See `activation_funcs` for a full list. `None` is not allowed.
|
71 |
+
alpha: Shape parameter for the activation function, or `None` to use the default.
|
72 |
+
gain: Scaling factor for the output tensor, or `None` to use default.
|
73 |
+
See `activation_funcs` for the default scaling of each activation function.
|
74 |
+
If unsure, consider specifying 1.
|
75 |
+
clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
|
76 |
+
the clamping (default).
|
77 |
+
impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
Tensor of the same shape and datatype as `x`.
|
81 |
+
"""
|
82 |
+
assert isinstance(x, torch.Tensor)
|
83 |
+
assert impl in ['ref', 'cuda']
|
84 |
+
if impl == 'cuda' and x.device.type == 'cuda' and _init():
|
85 |
+
return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
|
86 |
+
return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
|
87 |
+
|
88 |
+
#----------------------------------------------------------------------------
|
89 |
+
|
90 |
+
@misc.profiled_function
|
91 |
+
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
|
92 |
+
"""Slow reference implementation of `bias_act()` using standard TensorFlow ops.
|
93 |
+
"""
|
94 |
+
assert isinstance(x, torch.Tensor)
|
95 |
+
assert clamp is None or clamp >= 0
|
96 |
+
spec = activation_funcs[act]
|
97 |
+
alpha = float(alpha if alpha is not None else spec.def_alpha)
|
98 |
+
gain = float(gain if gain is not None else spec.def_gain)
|
99 |
+
clamp = float(clamp if clamp is not None else -1)
|
100 |
+
|
101 |
+
# Add bias.
|
102 |
+
if b is not None:
|
103 |
+
assert isinstance(b, torch.Tensor) and b.ndim == 1
|
104 |
+
assert 0 <= dim < x.ndim
|
105 |
+
assert b.shape[0] == x.shape[dim]
|
106 |
+
x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
|
107 |
+
|
108 |
+
# Evaluate activation function.
|
109 |
+
alpha = float(alpha)
|
110 |
+
x = spec.func(x, alpha=alpha)
|
111 |
+
|
112 |
+
# Scale by gain.
|
113 |
+
gain = float(gain)
|
114 |
+
if gain != 1:
|
115 |
+
x = x * gain
|
116 |
+
|
117 |
+
# Clamp.
|
118 |
+
if clamp >= 0:
|
119 |
+
x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
|
120 |
+
return x
|
121 |
+
|
122 |
+
#----------------------------------------------------------------------------
|
123 |
+
|
124 |
+
_bias_act_cuda_cache = dict()
|
125 |
+
|
126 |
+
def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
|
127 |
+
"""Fast CUDA implementation of `bias_act()` using custom ops.
|
128 |
+
"""
|
129 |
+
# Parse arguments.
|
130 |
+
assert clamp is None or clamp >= 0
|
131 |
+
spec = activation_funcs[act]
|
132 |
+
alpha = float(alpha if alpha is not None else spec.def_alpha)
|
133 |
+
gain = float(gain if gain is not None else spec.def_gain)
|
134 |
+
clamp = float(clamp if clamp is not None else -1)
|
135 |
+
|
136 |
+
# Lookup from cache.
|
137 |
+
key = (dim, act, alpha, gain, clamp)
|
138 |
+
if key in _bias_act_cuda_cache:
|
139 |
+
return _bias_act_cuda_cache[key]
|
140 |
+
|
141 |
+
# Forward op.
|
142 |
+
class BiasActCuda(torch.autograd.Function):
|
143 |
+
@staticmethod
|
144 |
+
def forward(ctx, x, b): # pylint: disable=arguments-differ
|
145 |
+
ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format
|
146 |
+
x = x.contiguous(memory_format=ctx.memory_format)
|
147 |
+
b = b.contiguous() if b is not None else _null_tensor
|
148 |
+
y = x
|
149 |
+
if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
|
150 |
+
y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
|
151 |
+
ctx.save_for_backward(
|
152 |
+
x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
|
153 |
+
b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
|
154 |
+
y if 'y' in spec.ref else _null_tensor)
|
155 |
+
return y
|
156 |
+
|
157 |
+
@staticmethod
|
158 |
+
def backward(ctx, dy): # pylint: disable=arguments-differ
|
159 |
+
dy = dy.contiguous(memory_format=ctx.memory_format)
|
160 |
+
x, b, y = ctx.saved_tensors
|
161 |
+
dx = None
|
162 |
+
db = None
|
163 |
+
|
164 |
+
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
|
165 |
+
dx = dy
|
166 |
+
if act != 'linear' or gain != 1 or clamp >= 0:
|
167 |
+
dx = BiasActCudaGrad.apply(dy, x, b, y)
|
168 |
+
|
169 |
+
if ctx.needs_input_grad[1]:
|
170 |
+
db = dx.sum([i for i in range(dx.ndim) if i != dim])
|
171 |
+
|
172 |
+
return dx, db
|
173 |
+
|
174 |
+
# Backward op.
|
175 |
+
class BiasActCudaGrad(torch.autograd.Function):
|
176 |
+
@staticmethod
|
177 |
+
def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
|
178 |
+
ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format
|
179 |
+
dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
|
180 |
+
ctx.save_for_backward(
|
181 |
+
dy if spec.has_2nd_grad else _null_tensor,
|
182 |
+
x, b, y)
|
183 |
+
return dx
|
184 |
+
|
185 |
+
@staticmethod
|
186 |
+
def backward(ctx, d_dx): # pylint: disable=arguments-differ
|
187 |
+
d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
|
188 |
+
dy, x, b, y = ctx.saved_tensors
|
189 |
+
d_dy = None
|
190 |
+
d_x = None
|
191 |
+
d_b = None
|
192 |
+
d_y = None
|
193 |
+
|
194 |
+
if ctx.needs_input_grad[0]:
|
195 |
+
d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
|
196 |
+
|
197 |
+
if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
|
198 |
+
d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
|
199 |
+
|
200 |
+
if spec.has_2nd_grad and ctx.needs_input_grad[2]:
|
201 |
+
d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
|
202 |
+
|
203 |
+
return d_dy, d_x, d_b, d_y
|
204 |
+
|
205 |
+
# Add to cache.
|
206 |
+
_bias_act_cuda_cache[key] = BiasActCuda
|
207 |
+
return BiasActCuda
|
208 |
+
|
209 |
+
#----------------------------------------------------------------------------
|
ADD/th_utils/ops/conv2d_gradfix.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Custom replacement for `torch.nn.functional.conv2d` that supports
|
10 |
+
arbitrarily high order gradients with zero performance penalty."""
|
11 |
+
|
12 |
+
import contextlib
|
13 |
+
import torch
|
14 |
+
from pkg_resources import parse_version
|
15 |
+
|
16 |
+
# pylint: disable=redefined-builtin
|
17 |
+
# pylint: disable=arguments-differ
|
18 |
+
# pylint: disable=protected-access
|
19 |
+
|
20 |
+
#----------------------------------------------------------------------------
|
21 |
+
|
22 |
+
enabled = False # Enable the custom op by setting this to true.
|
23 |
+
weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
|
24 |
+
_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
|
25 |
+
|
26 |
+
@contextlib.contextmanager
|
27 |
+
def no_weight_gradients(disable=True):
|
28 |
+
global weight_gradients_disabled
|
29 |
+
old = weight_gradients_disabled
|
30 |
+
if disable:
|
31 |
+
weight_gradients_disabled = True
|
32 |
+
yield
|
33 |
+
weight_gradients_disabled = old
|
34 |
+
|
35 |
+
#----------------------------------------------------------------------------
|
36 |
+
|
37 |
+
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
38 |
+
if _should_use_custom_op(input):
|
39 |
+
return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
|
40 |
+
return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
41 |
+
|
42 |
+
def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
|
43 |
+
if _should_use_custom_op(input):
|
44 |
+
return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
|
45 |
+
return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
|
46 |
+
|
47 |
+
#----------------------------------------------------------------------------
|
48 |
+
|
49 |
+
def _should_use_custom_op(input):
|
50 |
+
assert isinstance(input, torch.Tensor)
|
51 |
+
if (not enabled) or (not torch.backends.cudnn.enabled):
|
52 |
+
return False
|
53 |
+
if _use_pytorch_1_11_api:
|
54 |
+
# The work-around code doesn't work on PyTorch 1.11.0 onwards
|
55 |
+
return False
|
56 |
+
if input.device.type != 'cuda':
|
57 |
+
return False
|
58 |
+
return True
|
59 |
+
|
60 |
+
def _tuple_of_ints(xs, ndim):
|
61 |
+
xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
|
62 |
+
assert len(xs) == ndim
|
63 |
+
assert all(isinstance(x, int) for x in xs)
|
64 |
+
return xs
|
65 |
+
|
66 |
+
#----------------------------------------------------------------------------
|
67 |
+
|
68 |
+
_conv2d_gradfix_cache = dict()
|
69 |
+
_null_tensor = torch.empty([0])
|
70 |
+
|
71 |
+
def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
|
72 |
+
# Parse arguments.
|
73 |
+
ndim = 2
|
74 |
+
weight_shape = tuple(weight_shape)
|
75 |
+
stride = _tuple_of_ints(stride, ndim)
|
76 |
+
padding = _tuple_of_ints(padding, ndim)
|
77 |
+
output_padding = _tuple_of_ints(output_padding, ndim)
|
78 |
+
dilation = _tuple_of_ints(dilation, ndim)
|
79 |
+
|
80 |
+
# Lookup from cache.
|
81 |
+
key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
|
82 |
+
if key in _conv2d_gradfix_cache:
|
83 |
+
return _conv2d_gradfix_cache[key]
|
84 |
+
|
85 |
+
# Validate arguments.
|
86 |
+
assert groups >= 1
|
87 |
+
assert len(weight_shape) == ndim + 2
|
88 |
+
assert all(stride[i] >= 1 for i in range(ndim))
|
89 |
+
assert all(padding[i] >= 0 for i in range(ndim))
|
90 |
+
assert all(dilation[i] >= 0 for i in range(ndim))
|
91 |
+
if not transpose:
|
92 |
+
assert all(output_padding[i] == 0 for i in range(ndim))
|
93 |
+
else: # transpose
|
94 |
+
assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
|
95 |
+
|
96 |
+
# Helpers.
|
97 |
+
common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
|
98 |
+
def calc_output_padding(input_shape, output_shape):
|
99 |
+
if transpose:
|
100 |
+
return [0, 0]
|
101 |
+
return [
|
102 |
+
input_shape[i + 2]
|
103 |
+
- (output_shape[i + 2] - 1) * stride[i]
|
104 |
+
- (1 - 2 * padding[i])
|
105 |
+
- dilation[i] * (weight_shape[i + 2] - 1)
|
106 |
+
for i in range(ndim)
|
107 |
+
]
|
108 |
+
|
109 |
+
# Forward & backward.
|
110 |
+
class Conv2d(torch.autograd.Function):
|
111 |
+
@staticmethod
|
112 |
+
def forward(ctx, input, weight, bias):
|
113 |
+
assert weight.shape == weight_shape
|
114 |
+
ctx.save_for_backward(
|
115 |
+
input if weight.requires_grad else _null_tensor,
|
116 |
+
weight if input.requires_grad else _null_tensor,
|
117 |
+
)
|
118 |
+
ctx.input_shape = input.shape
|
119 |
+
|
120 |
+
# Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
|
121 |
+
if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0):
|
122 |
+
a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1])
|
123 |
+
b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1)
|
124 |
+
c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2)
|
125 |
+
c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1)
|
126 |
+
c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
|
127 |
+
return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
|
128 |
+
|
129 |
+
# General case => cuDNN.
|
130 |
+
if transpose:
|
131 |
+
return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
|
132 |
+
return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
|
133 |
+
|
134 |
+
@staticmethod
|
135 |
+
def backward(ctx, grad_output):
|
136 |
+
input, weight = ctx.saved_tensors
|
137 |
+
input_shape = ctx.input_shape
|
138 |
+
grad_input = None
|
139 |
+
grad_weight = None
|
140 |
+
grad_bias = None
|
141 |
+
|
142 |
+
if ctx.needs_input_grad[0]:
|
143 |
+
p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape)
|
144 |
+
op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
|
145 |
+
grad_input = op.apply(grad_output, weight, None)
|
146 |
+
assert grad_input.shape == input_shape
|
147 |
+
|
148 |
+
if ctx.needs_input_grad[1] and not weight_gradients_disabled:
|
149 |
+
grad_weight = Conv2dGradWeight.apply(grad_output, input)
|
150 |
+
assert grad_weight.shape == weight_shape
|
151 |
+
|
152 |
+
if ctx.needs_input_grad[2]:
|
153 |
+
grad_bias = grad_output.sum([0, 2, 3])
|
154 |
+
|
155 |
+
return grad_input, grad_weight, grad_bias
|
156 |
+
|
157 |
+
# Gradient with respect to the weights.
|
158 |
+
class Conv2dGradWeight(torch.autograd.Function):
|
159 |
+
@staticmethod
|
160 |
+
def forward(ctx, grad_output, input):
|
161 |
+
ctx.save_for_backward(
|
162 |
+
grad_output if input.requires_grad else _null_tensor,
|
163 |
+
input if grad_output.requires_grad else _null_tensor,
|
164 |
+
)
|
165 |
+
ctx.grad_output_shape = grad_output.shape
|
166 |
+
ctx.input_shape = input.shape
|
167 |
+
|
168 |
+
# Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
|
169 |
+
if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0):
|
170 |
+
a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
|
171 |
+
b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
|
172 |
+
c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape)
|
173 |
+
return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
|
174 |
+
|
175 |
+
# General case => cuDNN.
|
176 |
+
name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight'
|
177 |
+
flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
|
178 |
+
return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
|
179 |
+
|
180 |
+
@staticmethod
|
181 |
+
def backward(ctx, grad2_grad_weight):
|
182 |
+
grad_output, input = ctx.saved_tensors
|
183 |
+
grad_output_shape = ctx.grad_output_shape
|
184 |
+
input_shape = ctx.input_shape
|
185 |
+
grad2_grad_output = None
|
186 |
+
grad2_input = None
|
187 |
+
|
188 |
+
if ctx.needs_input_grad[0]:
|
189 |
+
grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
|
190 |
+
assert grad2_grad_output.shape == grad_output_shape
|
191 |
+
|
192 |
+
if ctx.needs_input_grad[1]:
|
193 |
+
p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape)
|
194 |
+
op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
|
195 |
+
grad2_input = op.apply(grad_output, grad2_grad_weight, None)
|
196 |
+
assert grad2_input.shape == input_shape
|
197 |
+
|
198 |
+
return grad2_grad_output, grad2_input
|
199 |
+
|
200 |
+
_conv2d_gradfix_cache[key] = Conv2d
|
201 |
+
return Conv2d
|
202 |
+
|
203 |
+
#----------------------------------------------------------------------------
|
ADD/th_utils/ops/conv2d_resample.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""2D convolution with optional up/downsampling."""
|
10 |
+
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from .. import misc
|
14 |
+
from . import conv2d_gradfix
|
15 |
+
from . import upfirdn2d
|
16 |
+
from .upfirdn2d import _parse_padding
|
17 |
+
from .upfirdn2d import _get_filter_size
|
18 |
+
|
19 |
+
#----------------------------------------------------------------------------
|
20 |
+
|
21 |
+
def _get_weight_shape(w):
|
22 |
+
with misc.suppress_tracer_warnings(): # this value will be treated as a constant
|
23 |
+
shape = [int(sz) for sz in w.shape]
|
24 |
+
misc.assert_shape(w, shape)
|
25 |
+
return shape
|
26 |
+
|
27 |
+
#----------------------------------------------------------------------------
|
28 |
+
|
29 |
+
def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
|
30 |
+
"""Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
|
31 |
+
"""
|
32 |
+
_out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w)
|
33 |
+
|
34 |
+
# Flip weight if requested.
|
35 |
+
# Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
|
36 |
+
if not flip_weight and (kw > 1 or kh > 1):
|
37 |
+
w = w.flip([2, 3])
|
38 |
+
|
39 |
+
# Execute using conv2d_gradfix.
|
40 |
+
op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
|
41 |
+
return op(x, w, stride=stride, padding=padding, groups=groups)
|
42 |
+
|
43 |
+
#----------------------------------------------------------------------------
|
44 |
+
|
45 |
+
@misc.profiled_function
|
46 |
+
def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
|
47 |
+
r"""2D convolution with optional up/downsampling.
|
48 |
+
|
49 |
+
Padding is performed only once at the beginning, not between the operations.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
x: Input tensor of shape
|
53 |
+
`[batch_size, in_channels, in_height, in_width]`.
|
54 |
+
w: Weight tensor of shape
|
55 |
+
`[out_channels, in_channels//groups, kernel_height, kernel_width]`.
|
56 |
+
f: Low-pass filter for up/downsampling. Must be prepared beforehand by
|
57 |
+
calling upfirdn2d.setup_filter(). None = identity (default).
|
58 |
+
up: Integer upsampling factor (default: 1).
|
59 |
+
down: Integer downsampling factor (default: 1).
|
60 |
+
padding: Padding with respect to the upsampled image. Can be a single number
|
61 |
+
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
62 |
+
(default: 0).
|
63 |
+
groups: Split input channels into N groups (default: 1).
|
64 |
+
flip_weight: False = convolution, True = correlation (default: True).
|
65 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
69 |
+
"""
|
70 |
+
# Validate arguments.
|
71 |
+
assert isinstance(x, torch.Tensor) and (x.ndim == 4)
|
72 |
+
assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
|
73 |
+
assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
|
74 |
+
assert isinstance(up, int) and (up >= 1)
|
75 |
+
assert isinstance(down, int) and (down >= 1)
|
76 |
+
assert isinstance(groups, int) and (groups >= 1)
|
77 |
+
out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
|
78 |
+
fw, fh = _get_filter_size(f)
|
79 |
+
px0, px1, py0, py1 = _parse_padding(padding)
|
80 |
+
|
81 |
+
# Adjust padding to account for up/downsampling.
|
82 |
+
if up > 1:
|
83 |
+
px0 += (fw + up - 1) // 2
|
84 |
+
px1 += (fw - up) // 2
|
85 |
+
py0 += (fh + up - 1) // 2
|
86 |
+
py1 += (fh - up) // 2
|
87 |
+
if down > 1:
|
88 |
+
px0 += (fw - down + 1) // 2
|
89 |
+
px1 += (fw - down) // 2
|
90 |
+
py0 += (fh - down + 1) // 2
|
91 |
+
py1 += (fh - down) // 2
|
92 |
+
|
93 |
+
# Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
|
94 |
+
if kw == 1 and kh == 1 and (down > 1 and up == 1):
|
95 |
+
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
|
96 |
+
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
97 |
+
return x
|
98 |
+
|
99 |
+
# Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
|
100 |
+
if kw == 1 and kh == 1 and (up > 1 and down == 1):
|
101 |
+
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
102 |
+
x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
|
103 |
+
return x
|
104 |
+
|
105 |
+
# Fast path: downsampling only => use strided convolution.
|
106 |
+
if down > 1 and up == 1:
|
107 |
+
x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
|
108 |
+
x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
|
109 |
+
return x
|
110 |
+
|
111 |
+
# Fast path: upsampling with optional downsampling => use transpose strided convolution.
|
112 |
+
if up > 1:
|
113 |
+
if groups == 1:
|
114 |
+
w = w.transpose(0, 1)
|
115 |
+
else:
|
116 |
+
w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
|
117 |
+
w = w.transpose(1, 2)
|
118 |
+
w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
|
119 |
+
px0 -= kw - 1
|
120 |
+
px1 -= kw - up
|
121 |
+
py0 -= kh - 1
|
122 |
+
py1 -= kh - up
|
123 |
+
pxt = max(min(-px0, -px1), 0)
|
124 |
+
pyt = max(min(-py0, -py1), 0)
|
125 |
+
x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
|
126 |
+
x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
|
127 |
+
if down > 1:
|
128 |
+
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
|
129 |
+
return x
|
130 |
+
|
131 |
+
# Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
|
132 |
+
if up == 1 and down == 1:
|
133 |
+
if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
|
134 |
+
return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
|
135 |
+
|
136 |
+
# Fallback: Generic reference implementation.
|
137 |
+
x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
|
138 |
+
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
139 |
+
if down > 1:
|
140 |
+
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
|
141 |
+
return x
|
142 |
+
|
143 |
+
#----------------------------------------------------------------------------
|
ADD/th_utils/ops/filtered_lrelu.cpp
ADDED
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include <torch/extension.h>
|
10 |
+
#include <ATen/cuda/CUDAContext.h>
|
11 |
+
#include <c10/cuda/CUDAGuard.h>
|
12 |
+
#include "filtered_lrelu.h"
|
13 |
+
|
14 |
+
//------------------------------------------------------------------------
|
15 |
+
|
16 |
+
static std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
|
17 |
+
torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
|
18 |
+
int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
|
19 |
+
{
|
20 |
+
// Set CUDA device.
|
21 |
+
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
22 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
23 |
+
|
24 |
+
// Validate arguments.
|
25 |
+
TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
|
26 |
+
TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
|
27 |
+
TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
|
28 |
+
TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
|
29 |
+
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
|
30 |
+
TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
|
31 |
+
TORCH_CHECK(x.numel() > 0, "x is empty");
|
32 |
+
TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
|
33 |
+
TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
|
34 |
+
TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
|
35 |
+
TORCH_CHECK(fu.numel() > 0, "fu is empty");
|
36 |
+
TORCH_CHECK(fd.numel() > 0, "fd is empty");
|
37 |
+
TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
|
38 |
+
TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");
|
39 |
+
|
40 |
+
// Figure out how much shared memory is available on the device.
|
41 |
+
int maxSharedBytes = 0;
|
42 |
+
AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
|
43 |
+
int sharedKB = maxSharedBytes >> 10;
|
44 |
+
|
45 |
+
// Populate enough launch parameters to check if a CUDA kernel exists.
|
46 |
+
filtered_lrelu_kernel_params p;
|
47 |
+
p.up = up;
|
48 |
+
p.down = down;
|
49 |
+
p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
|
50 |
+
p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
|
51 |
+
filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel<float, int32_t, false, false>(p, sharedKB);
|
52 |
+
if (!test_spec.exec)
|
53 |
+
{
|
54 |
+
// No kernel found - return empty tensors and indicate missing kernel with return code of -1.
|
55 |
+
return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
|
56 |
+
}
|
57 |
+
|
58 |
+
// Input/output element size.
|
59 |
+
int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;
|
60 |
+
|
61 |
+
// Input sizes.
|
62 |
+
int64_t xw = (int)x.size(3);
|
63 |
+
int64_t xh = (int)x.size(2);
|
64 |
+
int64_t fut_w = (int)fu.size(-1) - 1;
|
65 |
+
int64_t fut_h = (int)fu.size(0) - 1;
|
66 |
+
int64_t fdt_w = (int)fd.size(-1) - 1;
|
67 |
+
int64_t fdt_h = (int)fd.size(0) - 1;
|
68 |
+
|
69 |
+
// Logical size of upsampled buffer.
|
70 |
+
int64_t cw = xw * up + (px0 + px1) - fut_w;
|
71 |
+
int64_t ch = xh * up + (py0 + py1) - fut_h;
|
72 |
+
TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter");
|
73 |
+
TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");
|
74 |
+
|
75 |
+
// Compute output size and allocate.
|
76 |
+
int64_t yw = (cw - fdt_w + (down - 1)) / down;
|
77 |
+
int64_t yh = (ch - fdt_h + (down - 1)) / down;
|
78 |
+
TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
|
79 |
+
TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
|
80 |
+
torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());
|
81 |
+
|
82 |
+
// Allocate sign tensor.
|
83 |
+
torch::Tensor so;
|
84 |
+
torch::Tensor s = si;
|
85 |
+
bool readSigns = !!s.numel();
|
86 |
+
int64_t sw_active = 0; // Active width of sign tensor.
|
87 |
+
if (writeSigns)
|
88 |
+
{
|
89 |
+
sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements.
|
90 |
+
int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
|
91 |
+
int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16.
|
92 |
+
TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
|
93 |
+
s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
|
94 |
+
}
|
95 |
+
else if (readSigns)
|
96 |
+
sw_active = s.size(3) << 2;
|
97 |
+
|
98 |
+
// Validate sign tensor if in use.
|
99 |
+
if (readSigns || writeSigns)
|
100 |
+
{
|
101 |
+
TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
|
102 |
+
TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
|
103 |
+
TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
|
104 |
+
TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
|
105 |
+
TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
|
106 |
+
TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large");
|
107 |
+
}
|
108 |
+
|
109 |
+
// Populate rest of CUDA kernel parameters.
|
110 |
+
p.x = x.data_ptr();
|
111 |
+
p.y = y.data_ptr();
|
112 |
+
p.b = b.data_ptr();
|
113 |
+
p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
|
114 |
+
p.fu = fu.data_ptr<float>();
|
115 |
+
p.fd = fd.data_ptr<float>();
|
116 |
+
p.pad0 = make_int2(px0, py0);
|
117 |
+
p.gain = gain;
|
118 |
+
p.slope = slope;
|
119 |
+
p.clamp = clamp;
|
120 |
+
p.flip = (flip_filters) ? 1 : 0;
|
121 |
+
p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
|
122 |
+
p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
|
123 |
+
p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
|
124 |
+
p.sOfs = make_int2(sx, sy);
|
125 |
+
p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.
|
126 |
+
|
127 |
+
// x, y, b strides are in bytes.
|
128 |
+
p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));
|
129 |
+
p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));
|
130 |
+
p.bStride = sz * b.stride(0);
|
131 |
+
|
132 |
+
// fu, fd strides are in elements.
|
133 |
+
p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
|
134 |
+
p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);
|
135 |
+
|
136 |
+
// Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
|
137 |
+
bool index64b = false;
|
138 |
+
if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
|
139 |
+
if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
|
140 |
+
if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true;
|
141 |
+
if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
|
142 |
+
if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true;
|
143 |
+
if (s.numel() > INT_MAX) index64b = true;
|
144 |
+
|
145 |
+
// Choose CUDA kernel.
|
146 |
+
filtered_lrelu_kernel_spec spec = { 0 };
|
147 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
|
148 |
+
{
|
149 |
+
if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
|
150 |
+
{
|
151 |
+
// Choose kernel based on index type, datatype and sign read/write modes.
|
152 |
+
if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, true, false>(p, sharedKB);
|
153 |
+
else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, true >(p, sharedKB);
|
154 |
+
else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, false>(p, sharedKB);
|
155 |
+
else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, true, false>(p, sharedKB);
|
156 |
+
else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, true >(p, sharedKB);
|
157 |
+
else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, false>(p, sharedKB);
|
158 |
+
}
|
159 |
+
});
|
160 |
+
TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found") // This should not happen because we tested earlier that kernel exists.
|
161 |
+
|
162 |
+
// Launch CUDA kernel.
|
163 |
+
void* args[] = {&p};
|
164 |
+
int bx = spec.numWarps * 32;
|
165 |
+
int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
|
166 |
+
int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
|
167 |
+
int gz = p.yShape.z * p.yShape.w;
|
168 |
+
|
169 |
+
// Repeat multiple horizontal tiles in a CTA?
|
170 |
+
if (spec.xrep)
|
171 |
+
{
|
172 |
+
p.tilesXrep = spec.xrep;
|
173 |
+
p.tilesXdim = gx;
|
174 |
+
|
175 |
+
gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
|
176 |
+
std::swap(gx, gy);
|
177 |
+
}
|
178 |
+
else
|
179 |
+
{
|
180 |
+
p.tilesXrep = 0;
|
181 |
+
p.tilesXdim = 0;
|
182 |
+
}
|
183 |
+
|
184 |
+
// Launch filter setup kernel.
|
185 |
+
AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));
|
186 |
+
|
187 |
+
// Copy kernels to constant memory.
|
188 |
+
if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<true, false>(at::cuda::getCurrentCUDAStream())));
|
189 |
+
else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters<false, true >(at::cuda::getCurrentCUDAStream())));
|
190 |
+
else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<false, false>(at::cuda::getCurrentCUDAStream())));
|
191 |
+
|
192 |
+
// Set cache and shared memory configurations for main kernel.
|
193 |
+
AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
|
194 |
+
if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
|
195 |
+
AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
|
196 |
+
AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));
|
197 |
+
|
198 |
+
// Launch main kernel.
|
199 |
+
const int maxSubGz = 65535; // CUDA maximum for block z dimension.
|
200 |
+
for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
|
201 |
+
{
|
202 |
+
p.blockZofs = zofs;
|
203 |
+
int subGz = std::min(maxSubGz, gz - zofs);
|
204 |
+
AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));
|
205 |
+
}
|
206 |
+
|
207 |
+
// Done.
|
208 |
+
return std::make_tuple(y, so, 0);
|
209 |
+
}
|
210 |
+
|
211 |
+
//------------------------------------------------------------------------
|
212 |
+
|
213 |
+
static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns)
|
214 |
+
{
|
215 |
+
// Set CUDA device.
|
216 |
+
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
217 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
218 |
+
|
219 |
+
// Validate arguments.
|
220 |
+
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
|
221 |
+
TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
|
222 |
+
TORCH_CHECK(x.numel() > 0, "x is empty");
|
223 |
+
TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64");
|
224 |
+
|
225 |
+
// Output signs if we don't have sign input.
|
226 |
+
torch::Tensor so;
|
227 |
+
torch::Tensor s = si;
|
228 |
+
bool readSigns = !!s.numel();
|
229 |
+
if (writeSigns)
|
230 |
+
{
|
231 |
+
int64_t sw = x.size(3);
|
232 |
+
sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.
|
233 |
+
s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
|
234 |
+
}
|
235 |
+
|
236 |
+
// Validate sign tensor if in use.
|
237 |
+
if (readSigns || writeSigns)
|
238 |
+
{
|
239 |
+
TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
|
240 |
+
TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
|
241 |
+
TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
|
242 |
+
TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
|
243 |
+
TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
|
244 |
+
TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large");
|
245 |
+
}
|
246 |
+
|
247 |
+
// Initialize CUDA kernel parameters.
|
248 |
+
filtered_lrelu_act_kernel_params p;
|
249 |
+
p.x = x.data_ptr();
|
250 |
+
p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
|
251 |
+
p.gain = gain;
|
252 |
+
p.slope = slope;
|
253 |
+
p.clamp = clamp;
|
254 |
+
p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
|
255 |
+
p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
|
256 |
+
p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous.
|
257 |
+
p.sOfs = make_int2(sx, sy);
|
258 |
+
|
259 |
+
// Choose CUDA kernel.
|
260 |
+
void* func = 0;
|
261 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&]
|
262 |
+
{
|
263 |
+
if (writeSigns)
|
264 |
+
func = choose_filtered_lrelu_act_kernel<scalar_t, true, false>();
|
265 |
+
else if (readSigns)
|
266 |
+
func = choose_filtered_lrelu_act_kernel<scalar_t, false, true>();
|
267 |
+
else
|
268 |
+
func = choose_filtered_lrelu_act_kernel<scalar_t, false, false>();
|
269 |
+
});
|
270 |
+
TORCH_CHECK(func, "internal error - CUDA kernel not found");
|
271 |
+
|
272 |
+
// Launch CUDA kernel.
|
273 |
+
void* args[] = {&p};
|
274 |
+
int bx = 128; // 4 warps per block.
|
275 |
+
|
276 |
+
// Logical size of launch = writeSigns ? p.s : p.x
|
277 |
+
uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;
|
278 |
+
uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;
|
279 |
+
uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.
|
280 |
+
gx = (gx - 1) / bx + 1;
|
281 |
+
|
282 |
+
// Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest.
|
283 |
+
const uint32_t gmax = 65535;
|
284 |
+
gy = std::min(gy, gmax);
|
285 |
+
gz = std::min(gz, gmax);
|
286 |
+
|
287 |
+
// Launch.
|
288 |
+
AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream()));
|
289 |
+
return so;
|
290 |
+
}
|
291 |
+
|
292 |
+
//------------------------------------------------------------------------
|
293 |
+
|
294 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
295 |
+
{
|
296 |
+
m.def("filtered_lrelu", &filtered_lrelu); // The whole thing.
|
297 |
+
m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place.
|
298 |
+
}
|
299 |
+
|
300 |
+
//------------------------------------------------------------------------
|
ADD/th_utils/ops/filtered_lrelu.cu
ADDED
@@ -0,0 +1,1284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include <c10/util/Half.h>
|
10 |
+
#include "filtered_lrelu.h"
|
11 |
+
#include <cstdint>
|
12 |
+
|
13 |
+
//------------------------------------------------------------------------
|
14 |
+
// Helpers.
|
15 |
+
|
16 |
+
enum // Filter modes.
|
17 |
+
{
|
18 |
+
MODE_SUSD = 0, // Separable upsampling, separable downsampling.
|
19 |
+
MODE_FUSD = 1, // Full upsampling, separable downsampling.
|
20 |
+
MODE_SUFD = 2, // Separable upsampling, full downsampling.
|
21 |
+
MODE_FUFD = 3, // Full upsampling, full downsampling.
|
22 |
+
};
|
23 |
+
|
24 |
+
template <class T> struct InternalType;
|
25 |
+
template <> struct InternalType<double>
|
26 |
+
{
|
27 |
+
typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t;
|
28 |
+
__device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); }
|
29 |
+
__device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); }
|
30 |
+
__device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); }
|
31 |
+
};
|
32 |
+
template <> struct InternalType<float>
|
33 |
+
{
|
34 |
+
typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
|
35 |
+
__device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
|
36 |
+
__device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
|
37 |
+
__device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
|
38 |
+
};
|
39 |
+
template <> struct InternalType<c10::Half>
|
40 |
+
{
|
41 |
+
typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
|
42 |
+
__device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
|
43 |
+
__device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
|
44 |
+
__device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
|
45 |
+
};
|
46 |
+
|
47 |
+
#define MIN(A, B) ((A) < (B) ? (A) : (B))
|
48 |
+
#define MAX(A, B) ((A) > (B) ? (A) : (B))
|
49 |
+
#define CEIL_DIV(A, B) (((B)==1) ? (A) : \
|
50 |
+
((B)==2) ? ((int)((A)+1) >> 1) : \
|
51 |
+
((B)==4) ? ((int)((A)+3) >> 2) : \
|
52 |
+
(((A) + ((A) > 0 ? (B) - 1 : 0)) / (B)))
|
53 |
+
|
54 |
+
// This works only up to blocks of size 256 x 256 and for all N that are powers of two.
|
55 |
+
template <int N> __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i)
|
56 |
+
{
|
57 |
+
if ((N & (N-1)) && N <= 256)
|
58 |
+
y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256.
|
59 |
+
else
|
60 |
+
y = i/N;
|
61 |
+
|
62 |
+
x = i - y*N;
|
63 |
+
}
|
64 |
+
|
65 |
+
// Type cast stride before reading it.
|
66 |
+
template <class T> __device__ __forceinline__ T get_stride(const int64_t& x)
|
67 |
+
{
|
68 |
+
return *reinterpret_cast<const T*>(&x);
|
69 |
+
}
|
70 |
+
|
71 |
+
//------------------------------------------------------------------------
|
72 |
+
// Filters, setup kernel, copying function.
|
73 |
+
|
74 |
+
#define MAX_FILTER_SIZE 32
|
75 |
+
|
76 |
+
// Combined up/down filter buffers so that transfer can be done with one copy.
|
77 |
+
__device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel.
|
78 |
+
__device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel.
|
79 |
+
|
80 |
+
// Accessors to combined buffers to index up/down filters individually.
|
81 |
+
#define c_fu (c_fbuf)
|
82 |
+
#define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
|
83 |
+
#define g_fu (g_fbuf)
|
84 |
+
#define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
|
85 |
+
|
86 |
+
// Set up filters into global memory buffer.
|
87 |
+
static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p)
|
88 |
+
{
|
89 |
+
for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x)
|
90 |
+
{
|
91 |
+
int x, y;
|
92 |
+
fast_div_mod<MAX_FILTER_SIZE>(x, y, idx);
|
93 |
+
|
94 |
+
int fu_x = p.flip ? x : (p.fuShape.x - 1 - x);
|
95 |
+
int fu_y = p.flip ? y : (p.fuShape.y - 1 - y);
|
96 |
+
if (p.fuShape.y > 0)
|
97 |
+
g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y];
|
98 |
+
else
|
99 |
+
g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x];
|
100 |
+
|
101 |
+
int fd_x = p.flip ? x : (p.fdShape.x - 1 - x);
|
102 |
+
int fd_y = p.flip ? y : (p.fdShape.y - 1 - y);
|
103 |
+
if (p.fdShape.y > 0)
|
104 |
+
g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y];
|
105 |
+
else
|
106 |
+
g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x];
|
107 |
+
}
|
108 |
+
}
|
109 |
+
|
110 |
+
// Host function to copy filters written by setup kernel into constant buffer for main kernel.
|
111 |
+
template <bool, bool> static cudaError_t copy_filters(cudaStream_t stream)
|
112 |
+
{
|
113 |
+
void* src = 0;
|
114 |
+
cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf);
|
115 |
+
if (err) return err;
|
116 |
+
return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream);
|
117 |
+
}
|
118 |
+
|
119 |
+
//------------------------------------------------------------------------
|
120 |
+
// Coordinate spaces:
|
121 |
+
// - Relative to input tensor: inX, inY, tileInX, tileInY
|
122 |
+
// - Relative to input tile: relInX, relInY, tileInW, tileInH
|
123 |
+
// - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH
|
124 |
+
// - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH
|
125 |
+
// - Relative to output tensor: outX, outY, tileOutX, tileOutY
|
126 |
+
//
|
127 |
+
// Relationships between coordinate spaces:
|
128 |
+
// - inX = tileInX + relInX
|
129 |
+
// - inY = tileInY + relInY
|
130 |
+
// - relUpX = relInX * up + phaseInX
|
131 |
+
// - relUpY = relInY * up + phaseInY
|
132 |
+
// - relUpX = relOutX * down
|
133 |
+
// - relUpY = relOutY * down
|
134 |
+
// - outX = tileOutX + relOutX
|
135 |
+
// - outY = tileOutY + relOutY
|
136 |
+
|
137 |
+
extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer.
|
138 |
+
|
139 |
+
template <class T, class index_t, int sharedKB, bool signWrite, bool signRead, int filterMode, int up, int fuSize, int down, int fdSize, int tileOutW, int tileOutH, int threadsPerBlock, bool enableXrep, bool enableWriteSkip>
|
140 |
+
static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p)
|
141 |
+
{
|
142 |
+
// Check that we don't try to support non-existing filter modes.
|
143 |
+
static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported");
|
144 |
+
static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported");
|
145 |
+
static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor");
|
146 |
+
static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor");
|
147 |
+
static_assert(fuSize % up == 0, "upsampling filter size must be divisible with upsampling factor");
|
148 |
+
static_assert(fdSize % down == 0, "downsampling filter size must be divisible with downsampling factor");
|
149 |
+
static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE");
|
150 |
+
static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters");
|
151 |
+
static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters");
|
152 |
+
static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4");
|
153 |
+
static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4");
|
154 |
+
|
155 |
+
// Static definitions.
|
156 |
+
typedef typename InternalType<T>::scalar_t scalar_t;
|
157 |
+
typedef typename InternalType<T>::vec2_t vec2_t;
|
158 |
+
typedef typename InternalType<T>::vec4_t vec4_t;
|
159 |
+
const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4.
|
160 |
+
const int tileUpH = tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height.
|
161 |
+
const int tileInW = CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width.
|
162 |
+
const int tileInH = CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height.
|
163 |
+
const int tileUpH_up = CEIL_DIV(tileUpH, up) * up; // Upsampled tile height rounded up to a multiple of up.
|
164 |
+
const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up); // For allocations only, to avoid shared memory read overruns with up=2 and up=4.
|
165 |
+
|
166 |
+
// Merge 1x1 downsampling into last upsampling step for upf1 and ups2.
|
167 |
+
const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD));
|
168 |
+
|
169 |
+
// Sizes of logical buffers.
|
170 |
+
const int szIn = tileInH_up * tileInW;
|
171 |
+
const int szUpX = tileInH_up * tileUpW;
|
172 |
+
const int szUpXY = downInline ? 0 : (tileUpH * tileUpW);
|
173 |
+
const int szDownX = tileUpH * tileOutW;
|
174 |
+
|
175 |
+
// Sizes for shared memory arrays.
|
176 |
+
const int s_buf0_size_base =
|
177 |
+
(filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) :
|
178 |
+
(filterMode == MODE_FUSD) ? MAX(szIn, szDownX) :
|
179 |
+
(filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) :
|
180 |
+
(filterMode == MODE_FUFD) ? szIn :
|
181 |
+
-1;
|
182 |
+
const int s_buf1_size_base =
|
183 |
+
(filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) :
|
184 |
+
(filterMode == MODE_FUSD) ? szUpXY :
|
185 |
+
(filterMode == MODE_SUFD) ? szUpX :
|
186 |
+
(filterMode == MODE_FUFD) ? szUpXY :
|
187 |
+
-1;
|
188 |
+
|
189 |
+
// Ensure U128 alignment.
|
190 |
+
const int s_buf0_size = (s_buf0_size_base + 3) & ~3;
|
191 |
+
const int s_buf1_size = (s_buf1_size_base + 3) & ~3;
|
192 |
+
|
193 |
+
// Check at compile time that we don't use too much shared memory.
|
194 |
+
static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow");
|
195 |
+
|
196 |
+
// Declare shared memory arrays.
|
197 |
+
scalar_t* s_buf0;
|
198 |
+
scalar_t* s_buf1;
|
199 |
+
if (sharedKB <= 48)
|
200 |
+
{
|
201 |
+
// Allocate shared memory arrays here.
|
202 |
+
__shared__ scalar_t s_buf0_st[(sharedKB > 48) ? (1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused.
|
203 |
+
s_buf0 = s_buf0_st;
|
204 |
+
s_buf1 = s_buf0 + s_buf0_size;
|
205 |
+
}
|
206 |
+
else
|
207 |
+
{
|
208 |
+
// Use the dynamically allocated shared memory array.
|
209 |
+
s_buf0 = (scalar_t*)s_buf_raw;
|
210 |
+
s_buf1 = s_buf0 + s_buf0_size;
|
211 |
+
}
|
212 |
+
|
213 |
+
// Pointers to the buffers.
|
214 |
+
scalar_t* s_tileIn; // Input tile: [relInX * tileInH + relInY]
|
215 |
+
scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX]
|
216 |
+
scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX]
|
217 |
+
scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX]
|
218 |
+
if (filterMode == MODE_SUSD)
|
219 |
+
{
|
220 |
+
s_tileIn = s_buf0;
|
221 |
+
s_tileUpX = s_buf1;
|
222 |
+
s_tileUpXY = s_buf0;
|
223 |
+
s_tileDownX = s_buf1;
|
224 |
+
}
|
225 |
+
else if (filterMode == MODE_FUSD)
|
226 |
+
{
|
227 |
+
s_tileIn = s_buf0;
|
228 |
+
s_tileUpXY = s_buf1;
|
229 |
+
s_tileDownX = s_buf0;
|
230 |
+
}
|
231 |
+
else if (filterMode == MODE_SUFD)
|
232 |
+
{
|
233 |
+
s_tileIn = s_buf0;
|
234 |
+
s_tileUpX = s_buf1;
|
235 |
+
s_tileUpXY = s_buf0;
|
236 |
+
}
|
237 |
+
else if (filterMode == MODE_FUFD)
|
238 |
+
{
|
239 |
+
s_tileIn = s_buf0;
|
240 |
+
s_tileUpXY = s_buf1;
|
241 |
+
}
|
242 |
+
|
243 |
+
// Allow large grids in z direction via per-launch offset.
|
244 |
+
int channelIdx = blockIdx.z + p.blockZofs;
|
245 |
+
int batchIdx = channelIdx / p.yShape.z;
|
246 |
+
channelIdx -= batchIdx * p.yShape.z;
|
247 |
+
|
248 |
+
// Offset to output feature map. In bytes.
|
249 |
+
index_t mapOfsOut = channelIdx * get_stride<index_t>(p.yStride.z) + batchIdx * get_stride<index_t>(p.yStride.w);
|
250 |
+
|
251 |
+
// Sign shift amount.
|
252 |
+
uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6;
|
253 |
+
|
254 |
+
// Inner tile loop.
|
255 |
+
#pragma unroll 1
|
256 |
+
for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++)
|
257 |
+
{
|
258 |
+
// Locate output tile.
|
259 |
+
int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x;
|
260 |
+
int tileOutX = tileX * tileOutW;
|
261 |
+
int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH;
|
262 |
+
|
263 |
+
// Locate input tile.
|
264 |
+
int tmpX = tileOutX * down - p.pad0.x;
|
265 |
+
int tmpY = tileOutY * down - p.pad0.y;
|
266 |
+
int tileInX = CEIL_DIV(tmpX, up);
|
267 |
+
int tileInY = CEIL_DIV(tmpY, up);
|
268 |
+
const int phaseInX = tileInX * up - tmpX;
|
269 |
+
const int phaseInY = tileInY * up - tmpY;
|
270 |
+
|
271 |
+
// Extra sync if input and output buffers are the same and we are not on first tile.
|
272 |
+
if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline)))
|
273 |
+
__syncthreads();
|
274 |
+
|
275 |
+
// Load input tile & apply bias. Unrolled.
|
276 |
+
scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride<index_t>(p.bStride)));
|
277 |
+
index_t mapOfsIn = channelIdx * get_stride<index_t>(p.xStride.z) + batchIdx * get_stride<index_t>(p.xStride.w);
|
278 |
+
int idx = threadIdx.x;
|
279 |
+
const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock);
|
280 |
+
#pragma unroll
|
281 |
+
for (int loop = 0; loop < loopCountIN; loop++)
|
282 |
+
{
|
283 |
+
int relInX, relInY;
|
284 |
+
fast_div_mod<tileInW>(relInX, relInY, idx);
|
285 |
+
int inX = tileInX + relInX;
|
286 |
+
int inY = tileInY + relInY;
|
287 |
+
scalar_t v = 0;
|
288 |
+
|
289 |
+
if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y)
|
290 |
+
v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride<index_t>(p.xStride.x) + inY * get_stride<index_t>(p.xStride.y) + mapOfsIn))) + b;
|
291 |
+
|
292 |
+
bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH);
|
293 |
+
if (!skip)
|
294 |
+
s_tileIn[idx] = v;
|
295 |
+
|
296 |
+
idx += threadsPerBlock;
|
297 |
+
}
|
298 |
+
|
299 |
+
if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter.
|
300 |
+
{
|
301 |
+
// Horizontal upsampling.
|
302 |
+
__syncthreads();
|
303 |
+
if (up == 4)
|
304 |
+
{
|
305 |
+
for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
|
306 |
+
{
|
307 |
+
int relUpX0, relInY;
|
308 |
+
fast_div_mod<tileUpW>(relUpX0, relInY, idx);
|
309 |
+
int relInX0 = relUpX0 / up;
|
310 |
+
int src0 = relInX0 + tileInW * relInY;
|
311 |
+
int dst = relInY * tileUpW + relUpX0;
|
312 |
+
vec4_t v = InternalType<T>::zero_vec4();
|
313 |
+
scalar_t a = s_tileIn[src0];
|
314 |
+
if (phaseInX == 0)
|
315 |
+
{
|
316 |
+
#pragma unroll
|
317 |
+
for (int step = 0; step < fuSize / up; step++)
|
318 |
+
{
|
319 |
+
v.x += a * (scalar_t)c_fu[step * up + 0];
|
320 |
+
a = s_tileIn[src0 + step + 1];
|
321 |
+
v.y += a * (scalar_t)c_fu[step * up + 3];
|
322 |
+
v.z += a * (scalar_t)c_fu[step * up + 2];
|
323 |
+
v.w += a * (scalar_t)c_fu[step * up + 1];
|
324 |
+
}
|
325 |
+
}
|
326 |
+
else if (phaseInX == 1)
|
327 |
+
{
|
328 |
+
#pragma unroll
|
329 |
+
for (int step = 0; step < fuSize / up; step++)
|
330 |
+
{
|
331 |
+
v.x += a * (scalar_t)c_fu[step * up + 1];
|
332 |
+
v.y += a * (scalar_t)c_fu[step * up + 0];
|
333 |
+
a = s_tileIn[src0 + step + 1];
|
334 |
+
v.z += a * (scalar_t)c_fu[step * up + 3];
|
335 |
+
v.w += a * (scalar_t)c_fu[step * up + 2];
|
336 |
+
}
|
337 |
+
}
|
338 |
+
else if (phaseInX == 2)
|
339 |
+
{
|
340 |
+
#pragma unroll
|
341 |
+
for (int step = 0; step < fuSize / up; step++)
|
342 |
+
{
|
343 |
+
v.x += a * (scalar_t)c_fu[step * up + 2];
|
344 |
+
v.y += a * (scalar_t)c_fu[step * up + 1];
|
345 |
+
v.z += a * (scalar_t)c_fu[step * up + 0];
|
346 |
+
a = s_tileIn[src0 + step + 1];
|
347 |
+
v.w += a * (scalar_t)c_fu[step * up + 3];
|
348 |
+
}
|
349 |
+
}
|
350 |
+
else // (phaseInX == 3)
|
351 |
+
{
|
352 |
+
#pragma unroll
|
353 |
+
for (int step = 0; step < fuSize / up; step++)
|
354 |
+
{
|
355 |
+
v.x += a * (scalar_t)c_fu[step * up + 3];
|
356 |
+
v.y += a * (scalar_t)c_fu[step * up + 2];
|
357 |
+
v.z += a * (scalar_t)c_fu[step * up + 1];
|
358 |
+
v.w += a * (scalar_t)c_fu[step * up + 0];
|
359 |
+
a = s_tileIn[src0 + step + 1];
|
360 |
+
}
|
361 |
+
}
|
362 |
+
s_tileUpX[dst+0] = v.x;
|
363 |
+
s_tileUpX[dst+1] = v.y;
|
364 |
+
s_tileUpX[dst+2] = v.z;
|
365 |
+
s_tileUpX[dst+3] = v.w;
|
366 |
+
}
|
367 |
+
}
|
368 |
+
else if (up == 2)
|
369 |
+
{
|
370 |
+
bool p0 = (phaseInX == 0);
|
371 |
+
for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
|
372 |
+
{
|
373 |
+
int relUpX0, relInY;
|
374 |
+
fast_div_mod<tileUpW>(relUpX0, relInY, idx);
|
375 |
+
int relInX0 = relUpX0 / up;
|
376 |
+
int src0 = relInX0 + tileInW * relInY;
|
377 |
+
int dst = relInY * tileUpW + relUpX0;
|
378 |
+
vec2_t v = InternalType<T>::zero_vec2();
|
379 |
+
scalar_t a = s_tileIn[src0];
|
380 |
+
if (p0) // (phaseInX == 0)
|
381 |
+
{
|
382 |
+
#pragma unroll
|
383 |
+
for (int step = 0; step < fuSize / up; step++)
|
384 |
+
{
|
385 |
+
v.x += a * (scalar_t)c_fu[step * up + 0];
|
386 |
+
a = s_tileIn[src0 + step + 1];
|
387 |
+
v.y += a * (scalar_t)c_fu[step * up + 1];
|
388 |
+
}
|
389 |
+
}
|
390 |
+
else // (phaseInX == 1)
|
391 |
+
{
|
392 |
+
#pragma unroll
|
393 |
+
for (int step = 0; step < fuSize / up; step++)
|
394 |
+
{
|
395 |
+
v.x += a * (scalar_t)c_fu[step * up + 1];
|
396 |
+
v.y += a * (scalar_t)c_fu[step * up + 0];
|
397 |
+
a = s_tileIn[src0 + step + 1];
|
398 |
+
}
|
399 |
+
}
|
400 |
+
s_tileUpX[dst+0] = v.x;
|
401 |
+
s_tileUpX[dst+1] = v.y;
|
402 |
+
}
|
403 |
+
}
|
404 |
+
|
405 |
+
// Vertical upsampling & nonlinearity.
|
406 |
+
|
407 |
+
__syncthreads();
|
408 |
+
int groupMask = 15 << ((threadIdx.x & 31) & ~3);
|
409 |
+
int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
|
410 |
+
int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes.
|
411 |
+
if (up == 4)
|
412 |
+
{
|
413 |
+
minY -= 3; // Adjust according to block height.
|
414 |
+
for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
|
415 |
+
{
|
416 |
+
int relUpX, relInY0;
|
417 |
+
fast_div_mod<tileUpW>(relUpX, relInY0, idx);
|
418 |
+
int relUpY0 = relInY0 * up;
|
419 |
+
int src0 = relInY0 * tileUpW + relUpX;
|
420 |
+
int dst = relUpY0 * tileUpW + relUpX;
|
421 |
+
vec4_t v = InternalType<T>::zero_vec4();
|
422 |
+
|
423 |
+
scalar_t a = s_tileUpX[src0];
|
424 |
+
if (phaseInY == 0)
|
425 |
+
{
|
426 |
+
#pragma unroll
|
427 |
+
for (int step = 0; step < fuSize / up; step++)
|
428 |
+
{
|
429 |
+
v.x += a * (scalar_t)c_fu[step * up + 0];
|
430 |
+
a = s_tileUpX[src0 + (step + 1) * tileUpW];
|
431 |
+
v.y += a * (scalar_t)c_fu[step * up + 3];
|
432 |
+
v.z += a * (scalar_t)c_fu[step * up + 2];
|
433 |
+
v.w += a * (scalar_t)c_fu[step * up + 1];
|
434 |
+
}
|
435 |
+
}
|
436 |
+
else if (phaseInY == 1)
|
437 |
+
{
|
438 |
+
#pragma unroll
|
439 |
+
for (int step = 0; step < fuSize / up; step++)
|
440 |
+
{
|
441 |
+
v.x += a * (scalar_t)c_fu[step * up + 1];
|
442 |
+
v.y += a * (scalar_t)c_fu[step * up + 0];
|
443 |
+
a = s_tileUpX[src0 + (step + 1) * tileUpW];
|
444 |
+
v.z += a * (scalar_t)c_fu[step * up + 3];
|
445 |
+
v.w += a * (scalar_t)c_fu[step * up + 2];
|
446 |
+
}
|
447 |
+
}
|
448 |
+
else if (phaseInY == 2)
|
449 |
+
{
|
450 |
+
#pragma unroll
|
451 |
+
for (int step = 0; step < fuSize / up; step++)
|
452 |
+
{
|
453 |
+
v.x += a * (scalar_t)c_fu[step * up + 2];
|
454 |
+
v.y += a * (scalar_t)c_fu[step * up + 1];
|
455 |
+
v.z += a * (scalar_t)c_fu[step * up + 0];
|
456 |
+
a = s_tileUpX[src0 + (step + 1) * tileUpW];
|
457 |
+
v.w += a * (scalar_t)c_fu[step * up + 3];
|
458 |
+
}
|
459 |
+
}
|
460 |
+
else // (phaseInY == 3)
|
461 |
+
{
|
462 |
+
#pragma unroll
|
463 |
+
for (int step = 0; step < fuSize / up; step++)
|
464 |
+
{
|
465 |
+
v.x += a * (scalar_t)c_fu[step * up + 3];
|
466 |
+
v.y += a * (scalar_t)c_fu[step * up + 2];
|
467 |
+
v.z += a * (scalar_t)c_fu[step * up + 1];
|
468 |
+
v.w += a * (scalar_t)c_fu[step * up + 0];
|
469 |
+
a = s_tileUpX[src0 + (step + 1) * tileUpW];
|
470 |
+
}
|
471 |
+
}
|
472 |
+
|
473 |
+
int x = tileOutX * down + relUpX;
|
474 |
+
int y = tileOutY * down + relUpY0;
|
475 |
+
int signX = x + p.sOfs.x;
|
476 |
+
int signY = y + p.sOfs.y;
|
477 |
+
int signZ = blockIdx.z + p.blockZofs;
|
478 |
+
int signXb = signX >> 2;
|
479 |
+
index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
|
480 |
+
index_t si1 = si0 + p.sShape.x;
|
481 |
+
index_t si2 = si0 + p.sShape.x * 2;
|
482 |
+
index_t si3 = si0 + p.sShape.x * 3;
|
483 |
+
|
484 |
+
v.x *= (scalar_t)((float)up * (float)up * p.gain);
|
485 |
+
v.y *= (scalar_t)((float)up * (float)up * p.gain);
|
486 |
+
v.z *= (scalar_t)((float)up * (float)up * p.gain);
|
487 |
+
v.w *= (scalar_t)((float)up * (float)up * p.gain);
|
488 |
+
|
489 |
+
if (signWrite)
|
490 |
+
{
|
491 |
+
if (!enableWriteSkip)
|
492 |
+
{
|
493 |
+
// Determine and write signs.
|
494 |
+
int sx = __float_as_uint(v.x) >> 31 << 0;
|
495 |
+
int sy = __float_as_uint(v.y) >> 31 << 8;
|
496 |
+
int sz = __float_as_uint(v.z) >> 31 << 16;
|
497 |
+
int sw = __float_as_uint(v.w) >> 31 << 24;
|
498 |
+
if (sx) v.x *= p.slope;
|
499 |
+
if (sy) v.y *= p.slope;
|
500 |
+
if (sz) v.z *= p.slope;
|
501 |
+
if (sw) v.w *= p.slope;
|
502 |
+
if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
|
503 |
+
if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
|
504 |
+
if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }
|
505 |
+
if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }
|
506 |
+
|
507 |
+
if ((uint32_t)signXb < p.swLimit && signY >= minY)
|
508 |
+
{
|
509 |
+
// Combine signs.
|
510 |
+
uint32_t s = sx + sy + sw + sz;
|
511 |
+
s <<= (signX & 3) << 1;
|
512 |
+
s |= __shfl_xor_sync(groupMask, s, 1);
|
513 |
+
s |= __shfl_xor_sync(groupMask, s, 2);
|
514 |
+
|
515 |
+
// Write signs.
|
516 |
+
if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
|
517 |
+
if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
|
518 |
+
if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
|
519 |
+
if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
|
520 |
+
}
|
521 |
+
}
|
522 |
+
else
|
523 |
+
{
|
524 |
+
// Determine and write signs.
|
525 |
+
if ((uint32_t)signXb < p.swLimit && signY >= minY)
|
526 |
+
{
|
527 |
+
int sx = __float_as_uint(v.x) >> 31 << 0;
|
528 |
+
int sy = __float_as_uint(v.y) >> 31 << 8;
|
529 |
+
int sz = __float_as_uint(v.z) >> 31 << 16;
|
530 |
+
int sw = __float_as_uint(v.w) >> 31 << 24;
|
531 |
+
if (sx) v.x *= p.slope;
|
532 |
+
if (sy) v.y *= p.slope;
|
533 |
+
if (sz) v.z *= p.slope;
|
534 |
+
if (sw) v.w *= p.slope;
|
535 |
+
if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
|
536 |
+
if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
|
537 |
+
if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }
|
538 |
+
if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }
|
539 |
+
|
540 |
+
// Combine signs.
|
541 |
+
uint32_t s = sx + sy + sw + sz;
|
542 |
+
s <<= (signX & 3) << 1;
|
543 |
+
s |= __shfl_xor_sync(groupMask, s, 1);
|
544 |
+
s |= __shfl_xor_sync(groupMask, s, 2);
|
545 |
+
|
546 |
+
// Write signs.
|
547 |
+
if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
|
548 |
+
if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
|
549 |
+
if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
|
550 |
+
if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
|
551 |
+
}
|
552 |
+
else
|
553 |
+
{
|
554 |
+
// Just compute the values.
|
555 |
+
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
|
556 |
+
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
|
557 |
+
if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
|
558 |
+
if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
|
559 |
+
}
|
560 |
+
}
|
561 |
+
}
|
562 |
+
else if (signRead) // Read signs and apply.
|
563 |
+
{
|
564 |
+
if ((uint32_t)signXb < p.swLimit)
|
565 |
+
{
|
566 |
+
int ss = (signX & 3) << 1;
|
567 |
+
if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
|
568 |
+
if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
|
569 |
+
if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; }
|
570 |
+
if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; }
|
571 |
+
}
|
572 |
+
}
|
573 |
+
else // Forward pass with no sign write.
|
574 |
+
{
|
575 |
+
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
|
576 |
+
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
|
577 |
+
if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
|
578 |
+
if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
|
579 |
+
}
|
580 |
+
|
581 |
+
s_tileUpXY[dst + 0 * tileUpW] = v.x;
|
582 |
+
if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y;
|
583 |
+
if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z;
|
584 |
+
if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w;
|
585 |
+
}
|
586 |
+
}
|
587 |
+
else if (up == 2)
|
588 |
+
{
|
589 |
+
minY -= 1; // Adjust according to block height.
|
590 |
+
for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
|
591 |
+
{
|
592 |
+
int relUpX, relInY0;
|
593 |
+
fast_div_mod<tileUpW>(relUpX, relInY0, idx);
|
594 |
+
int relUpY0 = relInY0 * up;
|
595 |
+
int src0 = relInY0 * tileUpW + relUpX;
|
596 |
+
int dst = relUpY0 * tileUpW + relUpX;
|
597 |
+
vec2_t v = InternalType<T>::zero_vec2();
|
598 |
+
|
599 |
+
scalar_t a = s_tileUpX[src0];
|
600 |
+
if (phaseInY == 0)
|
601 |
+
{
|
602 |
+
#pragma unroll
|
603 |
+
for (int step = 0; step < fuSize / up; step++)
|
604 |
+
{
|
605 |
+
v.x += a * (scalar_t)c_fu[step * up + 0];
|
606 |
+
a = s_tileUpX[src0 + (step + 1) * tileUpW];
|
607 |
+
v.y += a * (scalar_t)c_fu[step * up + 1];
|
608 |
+
}
|
609 |
+
}
|
610 |
+
else // (phaseInY == 1)
|
611 |
+
{
|
612 |
+
#pragma unroll
|
613 |
+
for (int step = 0; step < fuSize / up; step++)
|
614 |
+
{
|
615 |
+
v.x += a * (scalar_t)c_fu[step * up + 1];
|
616 |
+
v.y += a * (scalar_t)c_fu[step * up + 0];
|
617 |
+
a = s_tileUpX[src0 + (step + 1) * tileUpW];
|
618 |
+
}
|
619 |
+
}
|
620 |
+
|
621 |
+
int x = tileOutX * down + relUpX;
|
622 |
+
int y = tileOutY * down + relUpY0;
|
623 |
+
int signX = x + p.sOfs.x;
|
624 |
+
int signY = y + p.sOfs.y;
|
625 |
+
int signZ = blockIdx.z + p.blockZofs;
|
626 |
+
int signXb = signX >> 2;
|
627 |
+
index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
|
628 |
+
index_t si1 = si0 + p.sShape.x;
|
629 |
+
|
630 |
+
v.x *= (scalar_t)((float)up * (float)up * p.gain);
|
631 |
+
v.y *= (scalar_t)((float)up * (float)up * p.gain);
|
632 |
+
|
633 |
+
if (signWrite)
|
634 |
+
{
|
635 |
+
if (!enableWriteSkip)
|
636 |
+
{
|
637 |
+
// Determine and write signs.
|
638 |
+
int sx = __float_as_uint(v.x) >> 31 << 0;
|
639 |
+
int sy = __float_as_uint(v.y) >> 31 << 8;
|
640 |
+
if (sx) v.x *= p.slope;
|
641 |
+
if (sy) v.y *= p.slope;
|
642 |
+
if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
|
643 |
+
if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
|
644 |
+
|
645 |
+
if ((uint32_t)signXb < p.swLimit && signY >= minY)
|
646 |
+
{
|
647 |
+
// Combine signs.
|
648 |
+
int s = sx + sy;
|
649 |
+
s <<= signXo;
|
650 |
+
s |= __shfl_xor_sync(groupMask, s, 1);
|
651 |
+
s |= __shfl_xor_sync(groupMask, s, 2);
|
652 |
+
|
653 |
+
// Write signs.
|
654 |
+
if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
|
655 |
+
if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
|
656 |
+
}
|
657 |
+
}
|
658 |
+
else
|
659 |
+
{
|
660 |
+
// Determine and write signs.
|
661 |
+
if ((uint32_t)signXb < p.swLimit && signY >= minY)
|
662 |
+
{
|
663 |
+
int sx = __float_as_uint(v.x) >> 31 << 0;
|
664 |
+
int sy = __float_as_uint(v.y) >> 31 << 8;
|
665 |
+
if (sx) v.x *= p.slope;
|
666 |
+
if (sy) v.y *= p.slope;
|
667 |
+
if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
|
668 |
+
if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
|
669 |
+
|
670 |
+
// Combine signs.
|
671 |
+
int s = sx + sy;
|
672 |
+
s <<= signXo;
|
673 |
+
s |= __shfl_xor_sync(groupMask, s, 1);
|
674 |
+
s |= __shfl_xor_sync(groupMask, s, 2);
|
675 |
+
|
676 |
+
// Write signs.
|
677 |
+
if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
|
678 |
+
if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
|
679 |
+
}
|
680 |
+
else
|
681 |
+
{
|
682 |
+
// Just compute the values.
|
683 |
+
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
|
684 |
+
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
|
685 |
+
}
|
686 |
+
}
|
687 |
+
}
|
688 |
+
else if (signRead) // Read signs and apply.
|
689 |
+
{
|
690 |
+
if ((uint32_t)signXb < p.swLimit)
|
691 |
+
{
|
692 |
+
if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
|
693 |
+
if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
|
694 |
+
}
|
695 |
+
}
|
696 |
+
else // Forward pass with no sign write.
|
697 |
+
{
|
698 |
+
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
|
699 |
+
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
|
700 |
+
}
|
701 |
+
|
702 |
+
if (!downInline)
|
703 |
+
{
|
704 |
+
// Write into temporary buffer.
|
705 |
+
s_tileUpXY[dst] = v.x;
|
706 |
+
if (relUpY0 < tileUpH - 1)
|
707 |
+
s_tileUpXY[dst + tileUpW] = v.y;
|
708 |
+
}
|
709 |
+
else
|
710 |
+
{
|
711 |
+
// Write directly into output buffer.
|
712 |
+
if ((uint32_t)x < p.yShape.x)
|
713 |
+
{
|
714 |
+
int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down);
|
715 |
+
index_t ofs = x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut;
|
716 |
+
if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]);
|
717 |
+
if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]);
|
718 |
+
}
|
719 |
+
}
|
720 |
+
}
|
721 |
+
}
|
722 |
+
}
|
723 |
+
else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD)
|
724 |
+
{
|
725 |
+
// Full upsampling filter.
|
726 |
+
|
727 |
+
if (up == 2)
|
728 |
+
{
|
729 |
+
// 2 x 2-wide.
|
730 |
+
__syncthreads();
|
731 |
+
int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs.
|
732 |
+
for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4)
|
733 |
+
{
|
734 |
+
int relUpX0, relUpY0;
|
735 |
+
fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
|
736 |
+
int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up);
|
737 |
+
int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up);
|
738 |
+
int src0 = relInX0 + tileInW * relInY0;
|
739 |
+
int tap0y = (relInY0 * up + phaseInY - relUpY0);
|
740 |
+
|
741 |
+
#define X_LOOP(TAPY, PX) \
|
742 |
+
for (int sx = 0; sx < fuSize / up; sx++) \
|
743 |
+
{ \
|
744 |
+
v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
|
745 |
+
v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
|
746 |
+
v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
|
747 |
+
v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
|
748 |
+
}
|
749 |
+
|
750 |
+
vec4_t v = InternalType<T>::zero_vec4();
|
751 |
+
if (tap0y == 0 && phaseInX == 0)
|
752 |
+
#pragma unroll
|
753 |
+
for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
|
754 |
+
#pragma unroll
|
755 |
+
X_LOOP(0, 0) }
|
756 |
+
if (tap0y == 0 && phaseInX == 1)
|
757 |
+
#pragma unroll
|
758 |
+
for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
|
759 |
+
#pragma unroll
|
760 |
+
X_LOOP(0, 1) }
|
761 |
+
if (tap0y == 1 && phaseInX == 0)
|
762 |
+
#pragma unroll
|
763 |
+
for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
|
764 |
+
#pragma unroll
|
765 |
+
X_LOOP(1, 0) }
|
766 |
+
if (tap0y == 1 && phaseInX == 1)
|
767 |
+
#pragma unroll
|
768 |
+
for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
|
769 |
+
#pragma unroll
|
770 |
+
X_LOOP(1, 1) }
|
771 |
+
|
772 |
+
#undef X_LOOP
|
773 |
+
|
774 |
+
int x = tileOutX * down + relUpX0;
|
775 |
+
int y = tileOutY * down + relUpY0;
|
776 |
+
int signX = x + p.sOfs.x;
|
777 |
+
int signY = y + p.sOfs.y;
|
778 |
+
int signZ = blockIdx.z + p.blockZofs;
|
779 |
+
int signXb = signX >> 2;
|
780 |
+
index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
|
781 |
+
|
782 |
+
v.x *= (scalar_t)((float)up * (float)up * p.gain);
|
783 |
+
v.y *= (scalar_t)((float)up * (float)up * p.gain);
|
784 |
+
v.z *= (scalar_t)((float)up * (float)up * p.gain);
|
785 |
+
v.w *= (scalar_t)((float)up * (float)up * p.gain);
|
786 |
+
|
787 |
+
if (signWrite)
|
788 |
+
{
|
789 |
+
if (!enableWriteSkip)
|
790 |
+
{
|
791 |
+
// Determine and write signs.
|
792 |
+
int sx = __float_as_uint(v.x) >> 31;
|
793 |
+
int sy = __float_as_uint(v.y) >> 31;
|
794 |
+
int sz = __float_as_uint(v.z) >> 31;
|
795 |
+
int sw = __float_as_uint(v.w) >> 31;
|
796 |
+
if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }
|
797 |
+
if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }
|
798 |
+
if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }
|
799 |
+
if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }
|
800 |
+
|
801 |
+
if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
|
802 |
+
{
|
803 |
+
p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
|
804 |
+
}
|
805 |
+
}
|
806 |
+
else
|
807 |
+
{
|
808 |
+
// Determine and write signs.
|
809 |
+
if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
|
810 |
+
{
|
811 |
+
int sx = __float_as_uint(v.x) >> 31;
|
812 |
+
int sy = __float_as_uint(v.y) >> 31;
|
813 |
+
int sz = __float_as_uint(v.z) >> 31;
|
814 |
+
int sw = __float_as_uint(v.w) >> 31;
|
815 |
+
if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }
|
816 |
+
if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }
|
817 |
+
if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }
|
818 |
+
if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }
|
819 |
+
|
820 |
+
p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
|
821 |
+
}
|
822 |
+
else
|
823 |
+
{
|
824 |
+
// Just compute the values.
|
825 |
+
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
|
826 |
+
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
|
827 |
+
if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
|
828 |
+
if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
|
829 |
+
}
|
830 |
+
}
|
831 |
+
}
|
832 |
+
else if (signRead) // Read sign and apply.
|
833 |
+
{
|
834 |
+
if ((uint32_t)signY < p.sShape.y)
|
835 |
+
{
|
836 |
+
int s = 0;
|
837 |
+
if ((uint32_t)signXb < p.swLimit) s = p.s[si];
|
838 |
+
if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8;
|
839 |
+
s >>= (signX & 3) << 1;
|
840 |
+
if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f;
|
841 |
+
if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f;
|
842 |
+
if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f;
|
843 |
+
if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f;
|
844 |
+
}
|
845 |
+
}
|
846 |
+
else // Forward pass with no sign write.
|
847 |
+
{
|
848 |
+
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
|
849 |
+
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
|
850 |
+
if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
|
851 |
+
if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
|
852 |
+
}
|
853 |
+
|
854 |
+
s_tileUpXY[idx + 0] = v.x;
|
855 |
+
s_tileUpXY[idx + 1] = v.y;
|
856 |
+
s_tileUpXY[idx + 2] = v.z;
|
857 |
+
s_tileUpXY[idx + 3] = v.w;
|
858 |
+
}
|
859 |
+
}
|
860 |
+
else if (up == 1)
|
861 |
+
{
|
862 |
+
__syncthreads();
|
863 |
+
uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3);
|
864 |
+
int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
|
865 |
+
for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x)
|
866 |
+
{
|
867 |
+
int relUpX0, relUpY0;
|
868 |
+
fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
|
869 |
+
scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter.
|
870 |
+
|
871 |
+
int x = tileOutX * down + relUpX0;
|
872 |
+
int y = tileOutY * down + relUpY0;
|
873 |
+
int signX = x + p.sOfs.x;
|
874 |
+
int signY = y + p.sOfs.y;
|
875 |
+
int signZ = blockIdx.z + p.blockZofs;
|
876 |
+
int signXb = signX >> 2;
|
877 |
+
index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
|
878 |
+
v *= (scalar_t)((float)up * (float)up * p.gain);
|
879 |
+
|
880 |
+
if (signWrite)
|
881 |
+
{
|
882 |
+
if (!enableWriteSkip)
|
883 |
+
{
|
884 |
+
// Determine and write sign.
|
885 |
+
uint32_t s = 0;
|
886 |
+
uint32_t signXbit = (1u << signXo);
|
887 |
+
if (v < 0.f)
|
888 |
+
{
|
889 |
+
s = signXbit;
|
890 |
+
v *= p.slope;
|
891 |
+
}
|
892 |
+
if (fabsf(v) > p.clamp)
|
893 |
+
{
|
894 |
+
s = signXbit * 2;
|
895 |
+
v = InternalType<T>::clamp(v, p.clamp);
|
896 |
+
}
|
897 |
+
if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
|
898 |
+
{
|
899 |
+
s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
|
900 |
+
s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
|
901 |
+
p.s[si] = s; // Write.
|
902 |
+
}
|
903 |
+
}
|
904 |
+
else
|
905 |
+
{
|
906 |
+
// Determine and write sign.
|
907 |
+
if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
|
908 |
+
{
|
909 |
+
uint32_t s = 0;
|
910 |
+
uint32_t signXbit = (1u << signXo);
|
911 |
+
if (v < 0.f)
|
912 |
+
{
|
913 |
+
s = signXbit;
|
914 |
+
v *= p.slope;
|
915 |
+
}
|
916 |
+
if (fabsf(v) > p.clamp)
|
917 |
+
{
|
918 |
+
s = signXbit * 2;
|
919 |
+
v = InternalType<T>::clamp(v, p.clamp);
|
920 |
+
}
|
921 |
+
s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
|
922 |
+
s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
|
923 |
+
p.s[si] = s; // Write.
|
924 |
+
}
|
925 |
+
else
|
926 |
+
{
|
927 |
+
// Just compute the value.
|
928 |
+
if (v < 0.f) v *= p.slope;
|
929 |
+
v = InternalType<T>::clamp(v, p.clamp);
|
930 |
+
}
|
931 |
+
}
|
932 |
+
}
|
933 |
+
else if (signRead)
|
934 |
+
{
|
935 |
+
// Read sign and apply if within sign tensor bounds.
|
936 |
+
if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y)
|
937 |
+
{
|
938 |
+
int s = p.s[si];
|
939 |
+
s >>= signXo;
|
940 |
+
if (s & 1) v *= p.slope;
|
941 |
+
if (s & 2) v = 0.f;
|
942 |
+
}
|
943 |
+
}
|
944 |
+
else // Forward pass with no sign write.
|
945 |
+
{
|
946 |
+
if (v < 0.f) v *= p.slope;
|
947 |
+
v = InternalType<T>::clamp(v, p.clamp);
|
948 |
+
}
|
949 |
+
|
950 |
+
if (!downInline) // Write into temporary buffer.
|
951 |
+
s_tileUpXY[idx] = v;
|
952 |
+
else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer
|
953 |
+
*((T*)((char*)p.y + (x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]);
|
954 |
+
}
|
955 |
+
}
|
956 |
+
}
|
957 |
+
|
958 |
+
// Downsampling.
|
959 |
+
if (filterMode == MODE_SUSD || filterMode == MODE_FUSD)
|
960 |
+
{
|
961 |
+
// Horizontal downsampling.
|
962 |
+
__syncthreads();
|
963 |
+
if (down == 4 && tileOutW % 4 == 0)
|
964 |
+
{
|
965 |
+
// Calculate 4 pixels at a time.
|
966 |
+
for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4)
|
967 |
+
{
|
968 |
+
int relOutX0, relUpY;
|
969 |
+
fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
|
970 |
+
int relUpX0 = relOutX0 * down;
|
971 |
+
int src0 = relUpY * tileUpW + relUpX0;
|
972 |
+
vec4_t v = InternalType<T>::zero_vec4();
|
973 |
+
#pragma unroll
|
974 |
+
for (int step = 0; step < fdSize; step++)
|
975 |
+
{
|
976 |
+
v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
|
977 |
+
v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step];
|
978 |
+
v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step];
|
979 |
+
v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step];
|
980 |
+
}
|
981 |
+
s_tileDownX[idx+0] = v.x;
|
982 |
+
s_tileDownX[idx+1] = v.y;
|
983 |
+
s_tileDownX[idx+2] = v.z;
|
984 |
+
s_tileDownX[idx+3] = v.w;
|
985 |
+
}
|
986 |
+
}
|
987 |
+
else if ((down == 2 || down == 4) && (tileOutW % 2 == 0))
|
988 |
+
{
|
989 |
+
// Calculate 2 pixels at a time.
|
990 |
+
for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2)
|
991 |
+
{
|
992 |
+
int relOutX0, relUpY;
|
993 |
+
fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
|
994 |
+
int relUpX0 = relOutX0 * down;
|
995 |
+
int src0 = relUpY * tileUpW + relUpX0;
|
996 |
+
vec2_t v = InternalType<T>::zero_vec2();
|
997 |
+
#pragma unroll
|
998 |
+
for (int step = 0; step < fdSize; step++)
|
999 |
+
{
|
1000 |
+
v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
|
1001 |
+
v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step];
|
1002 |
+
}
|
1003 |
+
s_tileDownX[idx+0] = v.x;
|
1004 |
+
s_tileDownX[idx+1] = v.y;
|
1005 |
+
}
|
1006 |
+
}
|
1007 |
+
else
|
1008 |
+
{
|
1009 |
+
// Calculate 1 pixel at a time.
|
1010 |
+
for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x)
|
1011 |
+
{
|
1012 |
+
int relOutX0, relUpY;
|
1013 |
+
fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
|
1014 |
+
int relUpX0 = relOutX0 * down;
|
1015 |
+
int src = relUpY * tileUpW + relUpX0;
|
1016 |
+
scalar_t v = 0.f;
|
1017 |
+
#pragma unroll
|
1018 |
+
for (int step = 0; step < fdSize; step++)
|
1019 |
+
v += s_tileUpXY[src + step] * (scalar_t)c_fd[step];
|
1020 |
+
s_tileDownX[idx] = v;
|
1021 |
+
}
|
1022 |
+
}
|
1023 |
+
|
1024 |
+
// Vertical downsampling & store output tile.
|
1025 |
+
__syncthreads();
|
1026 |
+
for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
|
1027 |
+
{
|
1028 |
+
int relOutX, relOutY0;
|
1029 |
+
fast_div_mod<tileOutW>(relOutX, relOutY0, idx);
|
1030 |
+
int relUpY0 = relOutY0 * down;
|
1031 |
+
int src0 = relUpY0 * tileOutW + relOutX;
|
1032 |
+
scalar_t v = 0;
|
1033 |
+
#pragma unroll
|
1034 |
+
for (int step = 0; step < fdSize; step++)
|
1035 |
+
v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step];
|
1036 |
+
|
1037 |
+
int outX = tileOutX + relOutX;
|
1038 |
+
int outY = tileOutY + relOutY0;
|
1039 |
+
|
1040 |
+
if (outX < p.yShape.x & outY < p.yShape.y)
|
1041 |
+
*((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
|
1042 |
+
}
|
1043 |
+
}
|
1044 |
+
else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD)
|
1045 |
+
{
|
1046 |
+
// Full downsampling filter.
|
1047 |
+
if (down == 2)
|
1048 |
+
{
|
1049 |
+
// 2-wide.
|
1050 |
+
__syncthreads();
|
1051 |
+
for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2)
|
1052 |
+
{
|
1053 |
+
int relOutX0, relOutY0;
|
1054 |
+
fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
|
1055 |
+
int relUpX0 = relOutX0 * down;
|
1056 |
+
int relUpY0 = relOutY0 * down;
|
1057 |
+
int src0 = relUpY0 * tileUpW + relUpX0;
|
1058 |
+
vec2_t v = InternalType<T>::zero_vec2();
|
1059 |
+
#pragma unroll
|
1060 |
+
for (int sy = 0; sy < fdSize; sy++)
|
1061 |
+
#pragma unroll
|
1062 |
+
for (int sx = 0; sx < fdSize; sx++)
|
1063 |
+
{
|
1064 |
+
v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
|
1065 |
+
v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
|
1066 |
+
}
|
1067 |
+
|
1068 |
+
int outX = tileOutX + relOutX0;
|
1069 |
+
int outY = tileOutY + relOutY0;
|
1070 |
+
if ((uint32_t)outY < p.yShape.y)
|
1071 |
+
{
|
1072 |
+
index_t ofs = outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut;
|
1073 |
+
if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x;
|
1074 |
+
if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.x))) = (T)v.y;
|
1075 |
+
}
|
1076 |
+
}
|
1077 |
+
}
|
1078 |
+
else if (down == 1 && !downInline)
|
1079 |
+
{
|
1080 |
+
// Thread per pixel.
|
1081 |
+
__syncthreads();
|
1082 |
+
for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
|
1083 |
+
{
|
1084 |
+
int relOutX0, relOutY0;
|
1085 |
+
fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
|
1086 |
+
scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter.
|
1087 |
+
|
1088 |
+
int outX = tileOutX + relOutX0;
|
1089 |
+
int outY = tileOutY + relOutY0;
|
1090 |
+
if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y)
|
1091 |
+
*((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
|
1092 |
+
}
|
1093 |
+
}
|
1094 |
+
}
|
1095 |
+
|
1096 |
+
if (!enableXrep)
|
1097 |
+
break;
|
1098 |
+
}
|
1099 |
+
}
|
1100 |
+
|
1101 |
+
//------------------------------------------------------------------------
|
1102 |
+
// Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. Used for accelerating the generic variant.
|
1103 |
+
// Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 64-bit indexing is always used.
|
1104 |
+
|
1105 |
+
template <class T, bool signWrite, bool signRead>
|
1106 |
+
static __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p)
|
1107 |
+
{
|
1108 |
+
typedef typename InternalType<T>::scalar_t scalar_t;
|
1109 |
+
|
1110 |
+
// Indexing.
|
1111 |
+
int32_t x = threadIdx.x + blockIdx.x * blockDim.x;
|
1112 |
+
int32_t ymax = signWrite ? p.sShape.y : p.xShape.y;
|
1113 |
+
int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index.
|
1114 |
+
|
1115 |
+
// Loop to accommodate oversized tensors.
|
1116 |
+
for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z)
|
1117 |
+
for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y)
|
1118 |
+
{
|
1119 |
+
// Extract z and w (channel, minibatch index).
|
1120 |
+
int32_t w = q / p.xShape.z;
|
1121 |
+
int32_t z = q - w * p.xShape.z;
|
1122 |
+
|
1123 |
+
// Choose behavior based on sign read/write mode.
|
1124 |
+
if (signWrite)
|
1125 |
+
{
|
1126 |
+
// Process value if in p.x.
|
1127 |
+
uint32_t s = 0;
|
1128 |
+
if (x < p.xShape.x && y < p.xShape.y)
|
1129 |
+
{
|
1130 |
+
int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
|
1131 |
+
T* pv = ((T*)p.x) + ix;
|
1132 |
+
scalar_t v = (scalar_t)(*pv);
|
1133 |
+
|
1134 |
+
// Gain, LReLU, clamp.
|
1135 |
+
v *= p.gain;
|
1136 |
+
if (v < 0.f)
|
1137 |
+
{
|
1138 |
+
v *= p.slope;
|
1139 |
+
s = 1; // Sign.
|
1140 |
+
}
|
1141 |
+
if (fabsf(v) > p.clamp)
|
1142 |
+
{
|
1143 |
+
v = InternalType<T>::clamp(v, p.clamp);
|
1144 |
+
s = 2; // Clamp.
|
1145 |
+
}
|
1146 |
+
|
1147 |
+
*pv = (T)v; // Write value.
|
1148 |
+
}
|
1149 |
+
|
1150 |
+
// Coalesce into threads 0 and 16 of warp.
|
1151 |
+
uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu;
|
1152 |
+
s <<= ((threadIdx.x & 15) << 1); // Shift into place.
|
1153 |
+
s |= __shfl_xor_sync(m, s, 1); // Distribute.
|
1154 |
+
s |= __shfl_xor_sync(m, s, 2);
|
1155 |
+
s |= __shfl_xor_sync(m, s, 4);
|
1156 |
+
s |= __shfl_xor_sync(m, s, 8);
|
1157 |
+
|
1158 |
+
// Write signs if leader and in p.s.
|
1159 |
+
if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in.
|
1160 |
+
{
|
1161 |
+
uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous.
|
1162 |
+
((uint32_t*)p.s)[is >> 4] = s;
|
1163 |
+
}
|
1164 |
+
}
|
1165 |
+
else if (signRead)
|
1166 |
+
{
|
1167 |
+
// Process value if in p.x.
|
1168 |
+
if (x < p.xShape.x) // y is always in.
|
1169 |
+
{
|
1170 |
+
int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
|
1171 |
+
T* pv = ((T*)p.x) + ix;
|
1172 |
+
scalar_t v = (scalar_t)(*pv);
|
1173 |
+
v *= p.gain;
|
1174 |
+
|
1175 |
+
// Apply sign buffer offset.
|
1176 |
+
uint32_t sx = x + p.sOfs.x;
|
1177 |
+
uint32_t sy = y + p.sOfs.y;
|
1178 |
+
|
1179 |
+
// Read and apply signs if we land inside valid region of sign buffer.
|
1180 |
+
if (sx < p.sShape.x && sy < p.sShape.y)
|
1181 |
+
{
|
1182 |
+
uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous.
|
1183 |
+
unsigned char s = p.s[is];
|
1184 |
+
s >>= (sx & 3) << 1; // Shift into place.
|
1185 |
+
if (s & 1) // Sign?
|
1186 |
+
v *= p.slope;
|
1187 |
+
if (s & 2) // Clamp?
|
1188 |
+
v = 0.f;
|
1189 |
+
}
|
1190 |
+
|
1191 |
+
*pv = (T)v; // Write value.
|
1192 |
+
}
|
1193 |
+
}
|
1194 |
+
else
|
1195 |
+
{
|
1196 |
+
// Forward pass with no sign write. Process value if in p.x.
|
1197 |
+
if (x < p.xShape.x) // y is always in.
|
1198 |
+
{
|
1199 |
+
int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
|
1200 |
+
T* pv = ((T*)p.x) + ix;
|
1201 |
+
scalar_t v = (scalar_t)(*pv);
|
1202 |
+
v *= p.gain;
|
1203 |
+
if (v < 0.f)
|
1204 |
+
v *= p.slope;
|
1205 |
+
if (fabsf(v) > p.clamp)
|
1206 |
+
v = InternalType<T>::clamp(v, p.clamp);
|
1207 |
+
*pv = (T)v; // Write value.
|
1208 |
+
}
|
1209 |
+
}
|
1210 |
+
}
|
1211 |
+
}
|
1212 |
+
|
1213 |
+
template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void)
|
1214 |
+
{
|
1215 |
+
return (void*)filtered_lrelu_act_kernel<T, signWrite, signRead>;
|
1216 |
+
}
|
1217 |
+
|
1218 |
+
//------------------------------------------------------------------------
|
1219 |
+
// CUDA kernel selection.
|
1220 |
+
|
1221 |
+
template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB)
|
1222 |
+
{
|
1223 |
+
filtered_lrelu_kernel_spec s = { 0 };
|
1224 |
+
|
1225 |
+
// Return the first matching kernel.
|
1226 |
+
#define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \
|
1227 |
+
if (sharedKB >= SH) \
|
1228 |
+
if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \
|
1229 |
+
if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \
|
1230 |
+
if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \
|
1231 |
+
{ \
|
1232 |
+
static_assert((D*TW % 4) == 0, "down * tileWidth must be divisible by 4"); \
|
1233 |
+
static_assert(FU % U == 0, "upscaling filter size must be multiple of upscaling factor"); \
|
1234 |
+
static_assert(FD % D == 0, "downscaling filter size must be multiple of downscaling factor"); \
|
1235 |
+
s.setup = (void*)setup_filters_kernel; \
|
1236 |
+
s.exec = (void*)filtered_lrelu_kernel<T, index_t, SH, signWrite, signRead, MODE, U, FU, D, FD, TW, TH, W*32, !!XR, !!WS>; \
|
1237 |
+
s.tileOut = make_int2(TW, TH); \
|
1238 |
+
s.numWarps = W; \
|
1239 |
+
s.xrep = XR; \
|
1240 |
+
s.dynamicSharedKB = (SH == 48) ? 0 : SH; \
|
1241 |
+
return s; \
|
1242 |
+
}
|
1243 |
+
|
1244 |
+
// Launch parameters for various kernel specializations.
|
1245 |
+
// Small filters must be listed before large filters, otherwise the kernel for larger filter will always match first.
|
1246 |
+
// Kernels that use more shared memory must be listed before those that use less, for the same reason.
|
1247 |
+
|
1248 |
+
CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/1,1, /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64, 178, 32, 0, 0) // 1t-upf1-downf1
|
1249 |
+
CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95, 16, 0, 0) // 4t-ups2-downf1
|
1250 |
+
CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 22, 16, 0, 0) // 4t-upf1-downs2
|
1251 |
+
CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 29, 16, 11, 0) // 4t-ups2-downs2
|
1252 |
+
CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60, 28, 16, 0, 0) // 4t-upf2-downs2
|
1253 |
+
CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 28, 16, 0, 0) // 4t-ups2-downf2
|
1254 |
+
CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 31, 16, 11, 0) // 4t-ups4-downs2
|
1255 |
+
CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 36, 16, 0, 0) // 4t-ups4-downf2
|
1256 |
+
CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 22, 16, 12, 0) // 4t-ups2-downs4
|
1257 |
+
CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29, 15, 16, 0, 0) // 4t-upf2-downs4
|
1258 |
+
CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96, 150, 28, 0, 0) // 6t-ups2-downf1
|
1259 |
+
CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 35, 24, 0, 0) // 6t-upf1-downs2
|
1260 |
+
CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 16, 10, 0) // 6t-ups2-downs2
|
1261 |
+
CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58, 28, 24, 8, 0) // 6t-upf2-downs2
|
1262 |
+
CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52, 28, 16, 0, 0) // 6t-ups2-downf2
|
1263 |
+
CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 51, 16, 5, 0) // 6t-ups4-downs2
|
1264 |
+
CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 56, 16, 6, 0) // 6t-ups4-downf2
|
1265 |
+
CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 18, 16, 12, 0) // 6t-ups2-downs4
|
1266 |
+
CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB
|
1267 |
+
CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 13, 24, 0, 0) // 6t-upf2-downs4
|
1268 |
+
CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89, 24, 0, 0) // 8t-ups2-downf1
|
1269 |
+
CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 31, 16, 5, 0) // 8t-upf1-downs2
|
1270 |
+
CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 41, 16, 9, 0) // 8t-ups2-downs2
|
1271 |
+
CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 26, 24, 0, 0) // 8t-upf2-downs2
|
1272 |
+
CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 40, 16, 0, 0) // 8t-ups2-downf2
|
1273 |
+
CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 24, 5, 0) // 8t-ups4-downs2
|
1274 |
+
CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 50, 16, 0, 0) // 8t-ups4-downf2
|
1275 |
+
CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB
|
1276 |
+
CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 13, 16, 10, 1) // 8t-ups2-downs4
|
1277 |
+
CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB
|
1278 |
+
CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 10, 24, 0, 0) // 8t-upf2-downs4
|
1279 |
+
|
1280 |
+
#undef CASE
|
1281 |
+
return s; // No kernel found.
|
1282 |
+
}
|
1283 |
+
|
1284 |
+
//------------------------------------------------------------------------
|
ADD/th_utils/ops/filtered_lrelu.h
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include <cuda_runtime.h>
|
10 |
+
|
11 |
+
//------------------------------------------------------------------------
|
12 |
+
// CUDA kernel parameters.
|
13 |
+
|
14 |
+
struct filtered_lrelu_kernel_params
|
15 |
+
{
|
16 |
+
// These parameters decide which kernel to use.
|
17 |
+
int up; // upsampling ratio (1, 2, 4)
|
18 |
+
int down; // downsampling ratio (1, 2, 4)
|
19 |
+
int2 fuShape; // [size, 1] | [size, size]
|
20 |
+
int2 fdShape; // [size, 1] | [size, size]
|
21 |
+
|
22 |
+
int _dummy; // Alignment.
|
23 |
+
|
24 |
+
// Rest of the parameters.
|
25 |
+
const void* x; // Input tensor.
|
26 |
+
void* y; // Output tensor.
|
27 |
+
const void* b; // Bias tensor.
|
28 |
+
unsigned char* s; // Sign tensor in/out. NULL if unused.
|
29 |
+
const float* fu; // Upsampling filter.
|
30 |
+
const float* fd; // Downsampling filter.
|
31 |
+
|
32 |
+
int2 pad0; // Left/top padding.
|
33 |
+
float gain; // Additional gain factor.
|
34 |
+
float slope; // Leaky ReLU slope on negative side.
|
35 |
+
float clamp; // Clamp after nonlinearity.
|
36 |
+
int flip; // Filter kernel flip for gradient computation.
|
37 |
+
|
38 |
+
int tilesXdim; // Original number of horizontal output tiles.
|
39 |
+
int tilesXrep; // Number of horizontal tiles per CTA.
|
40 |
+
int blockZofs; // Block z offset to support large minibatch, channel dimensions.
|
41 |
+
|
42 |
+
int4 xShape; // [width, height, channel, batch]
|
43 |
+
int4 yShape; // [width, height, channel, batch]
|
44 |
+
int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
|
45 |
+
int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
|
46 |
+
int swLimit; // Active width of sign tensor in bytes.
|
47 |
+
|
48 |
+
longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
|
49 |
+
longlong4 yStride; //
|
50 |
+
int64_t bStride; //
|
51 |
+
longlong3 fuStride; //
|
52 |
+
longlong3 fdStride; //
|
53 |
+
};
|
54 |
+
|
55 |
+
struct filtered_lrelu_act_kernel_params
|
56 |
+
{
|
57 |
+
void* x; // Input/output, modified in-place.
|
58 |
+
unsigned char* s; // Sign tensor in/out. NULL if unused.
|
59 |
+
|
60 |
+
float gain; // Additional gain factor.
|
61 |
+
float slope; // Leaky ReLU slope on negative side.
|
62 |
+
float clamp; // Clamp after nonlinearity.
|
63 |
+
|
64 |
+
int4 xShape; // [width, height, channel, batch]
|
65 |
+
longlong4 xStride; // Input/output tensor strides, same order as in shape.
|
66 |
+
int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
|
67 |
+
int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
|
68 |
+
};
|
69 |
+
|
70 |
+
//------------------------------------------------------------------------
|
71 |
+
// CUDA kernel specialization.
|
72 |
+
|
73 |
+
struct filtered_lrelu_kernel_spec
|
74 |
+
{
|
75 |
+
void* setup; // Function for filter kernel setup.
|
76 |
+
void* exec; // Function for main operation.
|
77 |
+
int2 tileOut; // Width/height of launch tile.
|
78 |
+
int numWarps; // Number of warps per thread block, determines launch block size.
|
79 |
+
int xrep; // For processing multiple horizontal tiles per thread block.
|
80 |
+
int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
|
81 |
+
};
|
82 |
+
|
83 |
+
//------------------------------------------------------------------------
|
84 |
+
// CUDA kernel selection.
|
85 |
+
|
86 |
+
template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
|
87 |
+
template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
|
88 |
+
template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);
|
89 |
+
|
90 |
+
//------------------------------------------------------------------------
|
ADD/th_utils/ops/filtered_lrelu.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
import os
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import warnings
|
13 |
+
|
14 |
+
from .. import custom_ops
|
15 |
+
from .. import misc
|
16 |
+
from . import upfirdn2d
|
17 |
+
from . import bias_act
|
18 |
+
|
19 |
+
#----------------------------------------------------------------------------
|
20 |
+
|
21 |
+
_plugin = None
|
22 |
+
|
23 |
+
def _init():
|
24 |
+
global _plugin
|
25 |
+
if _plugin is None:
|
26 |
+
_plugin = custom_ops.get_plugin(
|
27 |
+
module_name='filtered_lrelu_plugin',
|
28 |
+
sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'],
|
29 |
+
headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],
|
30 |
+
source_dir=os.path.dirname(__file__),
|
31 |
+
extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'],
|
32 |
+
)
|
33 |
+
return True
|
34 |
+
|
35 |
+
def _get_filter_size(f):
|
36 |
+
if f is None:
|
37 |
+
return 1, 1
|
38 |
+
assert isinstance(f, torch.Tensor)
|
39 |
+
assert 1 <= f.ndim <= 2
|
40 |
+
return f.shape[-1], f.shape[0] # width, height
|
41 |
+
|
42 |
+
def _parse_padding(padding):
|
43 |
+
if isinstance(padding, int):
|
44 |
+
padding = [padding, padding]
|
45 |
+
assert isinstance(padding, (list, tuple))
|
46 |
+
assert all(isinstance(x, (int, np.integer)) for x in padding)
|
47 |
+
padding = [int(x) for x in padding]
|
48 |
+
if len(padding) == 2:
|
49 |
+
px, py = padding
|
50 |
+
padding = [px, px, py, py]
|
51 |
+
px0, px1, py0, py1 = padding
|
52 |
+
return px0, px1, py0, py1
|
53 |
+
|
54 |
+
#----------------------------------------------------------------------------
|
55 |
+
|
56 |
+
def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'):
|
57 |
+
r"""Filtered leaky ReLU for a batch of 2D images.
|
58 |
+
|
59 |
+
Performs the following sequence of operations for each channel:
|
60 |
+
|
61 |
+
1. Add channel-specific bias if provided (`b`).
|
62 |
+
|
63 |
+
2. Upsample the image by inserting N-1 zeros after each pixel (`up`).
|
64 |
+
|
65 |
+
3. Pad the image with the specified number of zeros on each side (`padding`).
|
66 |
+
Negative padding corresponds to cropping the image.
|
67 |
+
|
68 |
+
4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it
|
69 |
+
so that the footprint of all output pixels lies within the input image.
|
70 |
+
|
71 |
+
5. Multiply each value by the provided gain factor (`gain`).
|
72 |
+
|
73 |
+
6. Apply leaky ReLU activation function to each value.
|
74 |
+
|
75 |
+
7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided.
|
76 |
+
|
77 |
+
8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking
|
78 |
+
it so that the footprint of all output pixels lies within the input image.
|
79 |
+
|
80 |
+
9. Downsample the image by keeping every Nth pixel (`down`).
|
81 |
+
|
82 |
+
The fused op is considerably more efficient than performing the same calculation
|
83 |
+
using standard PyTorch ops. It supports gradients of arbitrary order.
|
84 |
+
|
85 |
+
Args:
|
86 |
+
x: Float32/float16/float64 input tensor of the shape
|
87 |
+
`[batch_size, num_channels, in_height, in_width]`.
|
88 |
+
fu: Float32 upsampling FIR filter of the shape
|
89 |
+
`[filter_height, filter_width]` (non-separable),
|
90 |
+
`[filter_taps]` (separable), or
|
91 |
+
`None` (identity).
|
92 |
+
fd: Float32 downsampling FIR filter of the shape
|
93 |
+
`[filter_height, filter_width]` (non-separable),
|
94 |
+
`[filter_taps]` (separable), or
|
95 |
+
`None` (identity).
|
96 |
+
b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
|
97 |
+
as `x`. The length of vector must must match the channel dimension of `x`.
|
98 |
+
up: Integer upsampling factor (default: 1).
|
99 |
+
down: Integer downsampling factor. (default: 1).
|
100 |
+
padding: Padding with respect to the upsampled image. Can be a single number
|
101 |
+
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
102 |
+
(default: 0).
|
103 |
+
gain: Overall scaling factor for signal magnitude (default: sqrt(2)).
|
104 |
+
slope: Slope on the negative side of leaky ReLU (default: 0.2).
|
105 |
+
clamp: Maximum magnitude for leaky ReLU output (default: None).
|
106 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
107 |
+
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
108 |
+
|
109 |
+
Returns:
|
110 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
111 |
+
"""
|
112 |
+
assert isinstance(x, torch.Tensor)
|
113 |
+
assert impl in ['ref', 'cuda']
|
114 |
+
if impl == 'cuda' and x.device.type == 'cuda' and _init():
|
115 |
+
return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
|
116 |
+
return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)
|
117 |
+
|
118 |
+
#----------------------------------------------------------------------------
|
119 |
+
|
120 |
+
@misc.profiled_function
|
121 |
+
def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
|
122 |
+
"""Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
|
123 |
+
existing `upfirdn2n()` and `bias_act()` ops.
|
124 |
+
"""
|
125 |
+
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
126 |
+
fu_w, fu_h = _get_filter_size(fu)
|
127 |
+
fd_w, fd_h = _get_filter_size(fd)
|
128 |
+
if b is not None:
|
129 |
+
assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
|
130 |
+
misc.assert_shape(b, [x.shape[1]])
|
131 |
+
assert isinstance(up, int) and up >= 1
|
132 |
+
assert isinstance(down, int) and down >= 1
|
133 |
+
px0, px1, py0, py1 = _parse_padding(padding)
|
134 |
+
assert gain == float(gain) and gain > 0
|
135 |
+
assert slope == float(slope) and slope >= 0
|
136 |
+
assert clamp is None or (clamp == float(clamp) and clamp >= 0)
|
137 |
+
|
138 |
+
# Calculate output size.
|
139 |
+
batch_size, channels, in_h, in_w = x.shape
|
140 |
+
in_dtype = x.dtype
|
141 |
+
out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
|
142 |
+
out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down
|
143 |
+
|
144 |
+
# Compute using existing ops.
|
145 |
+
x = bias_act.bias_act(x=x, b=b) # Apply bias.
|
146 |
+
x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
|
147 |
+
x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp.
|
148 |
+
x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample.
|
149 |
+
|
150 |
+
# Check output shape & dtype.
|
151 |
+
misc.assert_shape(x, [batch_size, channels, out_h, out_w])
|
152 |
+
assert x.dtype == in_dtype
|
153 |
+
return x
|
154 |
+
|
155 |
+
#----------------------------------------------------------------------------
|
156 |
+
|
157 |
+
_filtered_lrelu_cuda_cache = dict()
|
158 |
+
|
159 |
+
def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
|
160 |
+
"""Fast CUDA implementation of `filtered_lrelu()` using custom ops.
|
161 |
+
"""
|
162 |
+
assert isinstance(up, int) and up >= 1
|
163 |
+
assert isinstance(down, int) and down >= 1
|
164 |
+
px0, px1, py0, py1 = _parse_padding(padding)
|
165 |
+
assert gain == float(gain) and gain > 0
|
166 |
+
gain = float(gain)
|
167 |
+
assert slope == float(slope) and slope >= 0
|
168 |
+
slope = float(slope)
|
169 |
+
assert clamp is None or (clamp == float(clamp) and clamp >= 0)
|
170 |
+
clamp = float(clamp if clamp is not None else 'inf')
|
171 |
+
|
172 |
+
# Lookup from cache.
|
173 |
+
key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)
|
174 |
+
if key in _filtered_lrelu_cuda_cache:
|
175 |
+
return _filtered_lrelu_cuda_cache[key]
|
176 |
+
|
177 |
+
# Forward op.
|
178 |
+
class FilteredLReluCuda(torch.autograd.Function):
|
179 |
+
@staticmethod
|
180 |
+
def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ
|
181 |
+
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
182 |
+
|
183 |
+
# Replace empty up/downsample kernels with full 1x1 kernels (faster than separable).
|
184 |
+
if fu is None:
|
185 |
+
fu = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
186 |
+
if fd is None:
|
187 |
+
fd = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
188 |
+
assert 1 <= fu.ndim <= 2
|
189 |
+
assert 1 <= fd.ndim <= 2
|
190 |
+
|
191 |
+
# Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1.
|
192 |
+
if up == 1 and fu.ndim == 1 and fu.shape[0] == 1:
|
193 |
+
fu = fu.square()[None]
|
194 |
+
if down == 1 and fd.ndim == 1 and fd.shape[0] == 1:
|
195 |
+
fd = fd.square()[None]
|
196 |
+
|
197 |
+
# Missing sign input tensor.
|
198 |
+
if si is None:
|
199 |
+
si = torch.empty([0])
|
200 |
+
|
201 |
+
# Missing bias tensor.
|
202 |
+
if b is None:
|
203 |
+
b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device)
|
204 |
+
|
205 |
+
# Construct internal sign tensor only if gradients are needed.
|
206 |
+
write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad)
|
207 |
+
|
208 |
+
# Warn if input storage strides are not in decreasing order due to e.g. channels-last layout.
|
209 |
+
strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]
|
210 |
+
if any(a < b for a, b in zip(strides[:-1], strides[1:])):
|
211 |
+
warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning)
|
212 |
+
|
213 |
+
# Call C++/Cuda plugin if datatype is supported.
|
214 |
+
if x.dtype in [torch.float16, torch.float32]:
|
215 |
+
if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device):
|
216 |
+
warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning)
|
217 |
+
y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs)
|
218 |
+
else:
|
219 |
+
return_code = -1
|
220 |
+
|
221 |
+
# No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because
|
222 |
+
# only the bit-packed sign tensor is retained for gradient computation.
|
223 |
+
if return_code < 0:
|
224 |
+
warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning)
|
225 |
+
|
226 |
+
y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias.
|
227 |
+
y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
|
228 |
+
so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place.
|
229 |
+
y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample.
|
230 |
+
|
231 |
+
# Prepare for gradient computation.
|
232 |
+
ctx.save_for_backward(fu, fd, (si if si.numel() else so))
|
233 |
+
ctx.x_shape = x.shape
|
234 |
+
ctx.y_shape = y.shape
|
235 |
+
ctx.s_ofs = sx, sy
|
236 |
+
return y
|
237 |
+
|
238 |
+
@staticmethod
|
239 |
+
def backward(ctx, dy): # pylint: disable=arguments-differ
|
240 |
+
fu, fd, si = ctx.saved_tensors
|
241 |
+
_, _, xh, xw = ctx.x_shape
|
242 |
+
_, _, yh, yw = ctx.y_shape
|
243 |
+
sx, sy = ctx.s_ofs
|
244 |
+
dx = None # 0
|
245 |
+
dfu = None; assert not ctx.needs_input_grad[1]
|
246 |
+
dfd = None; assert not ctx.needs_input_grad[2]
|
247 |
+
db = None # 3
|
248 |
+
dsi = None; assert not ctx.needs_input_grad[4]
|
249 |
+
dsx = None; assert not ctx.needs_input_grad[5]
|
250 |
+
dsy = None; assert not ctx.needs_input_grad[6]
|
251 |
+
|
252 |
+
if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
|
253 |
+
pp = [
|
254 |
+
(fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0,
|
255 |
+
xw * up - yw * down + px0 - (up - 1),
|
256 |
+
(fu.shape[0] - 1) + (fd.shape[0] - 1) - py0,
|
257 |
+
xh * up - yh * down + py0 - (up - 1),
|
258 |
+
]
|
259 |
+
gg = gain * (up ** 2) / (down ** 2)
|
260 |
+
ff = (not flip_filter)
|
261 |
+
sx = sx - (fu.shape[-1] - 1) + px0
|
262 |
+
sy = sy - (fu.shape[0] - 1) + py0
|
263 |
+
dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy)
|
264 |
+
|
265 |
+
if ctx.needs_input_grad[3]:
|
266 |
+
db = dx.sum([0, 2, 3])
|
267 |
+
|
268 |
+
return dx, dfu, dfd, db, dsi, dsx, dsy
|
269 |
+
|
270 |
+
# Add to cache.
|
271 |
+
_filtered_lrelu_cuda_cache[key] = FilteredLReluCuda
|
272 |
+
return FilteredLReluCuda
|
273 |
+
|
274 |
+
#----------------------------------------------------------------------------
|
ADD/th_utils/ops/filtered_lrelu_ns.cu
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include "filtered_lrelu.cu"
|
10 |
+
|
11 |
+
// Template/kernel specializations for no signs mode (no gradients required).
|
12 |
+
|
13 |
+
// Full op, 32-bit indexing.
|
14 |
+
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
15 |
+
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
16 |
+
|
17 |
+
// Full op, 64-bit indexing.
|
18 |
+
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
19 |
+
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
20 |
+
|
21 |
+
// Activation/signs only for generic variant. 64-bit indexing.
|
22 |
+
template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
|
23 |
+
template void* choose_filtered_lrelu_act_kernel<float, false, false>(void);
|
24 |
+
template void* choose_filtered_lrelu_act_kernel<double, false, false>(void);
|
25 |
+
|
26 |
+
// Copy filters to constant memory.
|
27 |
+
template cudaError_t copy_filters<false, false>(cudaStream_t stream);
|
ADD/th_utils/ops/filtered_lrelu_rd.cu
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include "filtered_lrelu.cu"
|
10 |
+
|
11 |
+
// Template/kernel specializations for sign read mode.
|
12 |
+
|
13 |
+
// Full op, 32-bit indexing.
|
14 |
+
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
15 |
+
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
16 |
+
|
17 |
+
// Full op, 64-bit indexing.
|
18 |
+
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
19 |
+
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
20 |
+
|
21 |
+
// Activation/signs only for generic variant. 64-bit indexing.
|
22 |
+
template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
|
23 |
+
template void* choose_filtered_lrelu_act_kernel<float, false, true>(void);
|
24 |
+
template void* choose_filtered_lrelu_act_kernel<double, false, true>(void);
|
25 |
+
|
26 |
+
// Copy filters to constant memory.
|
27 |
+
template cudaError_t copy_filters<false, true>(cudaStream_t stream);
|
ADD/th_utils/ops/filtered_lrelu_wr.cu
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include "filtered_lrelu.cu"
|
10 |
+
|
11 |
+
// Template/kernel specializations for sign write mode.
|
12 |
+
|
13 |
+
// Full op, 32-bit indexing.
|
14 |
+
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
15 |
+
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
16 |
+
|
17 |
+
// Full op, 64-bit indexing.
|
18 |
+
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
19 |
+
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
20 |
+
|
21 |
+
// Activation/signs only for generic variant. 64-bit indexing.
|
22 |
+
template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
|
23 |
+
template void* choose_filtered_lrelu_act_kernel<float, true, false>(void);
|
24 |
+
template void* choose_filtered_lrelu_act_kernel<double, true, false>(void);
|
25 |
+
|
26 |
+
// Copy filters to constant memory.
|
27 |
+
template cudaError_t copy_filters<true, false>(cudaStream_t stream);
|
ADD/th_utils/ops/fma.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
|
10 |
+
|
11 |
+
import torch
|
12 |
+
|
13 |
+
#----------------------------------------------------------------------------
|
14 |
+
|
15 |
+
def fma(a, b, c): # => a * b + c
|
16 |
+
return _FusedMultiplyAdd.apply(a, b, c)
|
17 |
+
|
18 |
+
#----------------------------------------------------------------------------
|
19 |
+
|
20 |
+
class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
|
21 |
+
@staticmethod
|
22 |
+
def forward(ctx, a, b, c): # pylint: disable=arguments-differ
|
23 |
+
out = torch.addcmul(c, a, b)
|
24 |
+
ctx.save_for_backward(a, b)
|
25 |
+
ctx.c_shape = c.shape
|
26 |
+
return out
|
27 |
+
|
28 |
+
@staticmethod
|
29 |
+
def backward(ctx, dout): # pylint: disable=arguments-differ
|
30 |
+
a, b = ctx.saved_tensors
|
31 |
+
c_shape = ctx.c_shape
|
32 |
+
da = None
|
33 |
+
db = None
|
34 |
+
dc = None
|
35 |
+
|
36 |
+
if ctx.needs_input_grad[0]:
|
37 |
+
da = _unbroadcast(dout * b, a.shape)
|
38 |
+
|
39 |
+
if ctx.needs_input_grad[1]:
|
40 |
+
db = _unbroadcast(dout * a, b.shape)
|
41 |
+
|
42 |
+
if ctx.needs_input_grad[2]:
|
43 |
+
dc = _unbroadcast(dout, c_shape)
|
44 |
+
|
45 |
+
return da, db, dc
|
46 |
+
|
47 |
+
#----------------------------------------------------------------------------
|
48 |
+
|
49 |
+
def _unbroadcast(x, shape):
|
50 |
+
extra_dims = x.ndim - len(shape)
|
51 |
+
assert extra_dims >= 0
|
52 |
+
dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
|
53 |
+
if len(dim):
|
54 |
+
x = x.sum(dim=dim, keepdim=True)
|
55 |
+
if extra_dims:
|
56 |
+
x = x.reshape(-1, *x.shape[extra_dims+1:])
|
57 |
+
assert x.shape == shape
|
58 |
+
return x
|
59 |
+
|
60 |
+
#----------------------------------------------------------------------------
|
ADD/th_utils/ops/grid_sample_gradfix.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Custom replacement for `torch.nn.functional.grid_sample` that
|
10 |
+
supports arbitrarily high order gradients between the input and output.
|
11 |
+
Only works on 2D images and assumes
|
12 |
+
`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
|
13 |
+
|
14 |
+
import torch
|
15 |
+
from pkg_resources import parse_version
|
16 |
+
|
17 |
+
# pylint: disable=redefined-builtin
|
18 |
+
# pylint: disable=arguments-differ
|
19 |
+
# pylint: disable=protected-access
|
20 |
+
|
21 |
+
#----------------------------------------------------------------------------
|
22 |
+
|
23 |
+
enabled = False # Enable the custom op by setting this to true.
|
24 |
+
_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
|
25 |
+
|
26 |
+
#----------------------------------------------------------------------------
|
27 |
+
|
28 |
+
def grid_sample(input, grid):
|
29 |
+
if _should_use_custom_op():
|
30 |
+
return _GridSample2dForward.apply(input, grid)
|
31 |
+
return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
|
32 |
+
|
33 |
+
#----------------------------------------------------------------------------
|
34 |
+
|
35 |
+
def _should_use_custom_op():
|
36 |
+
return enabled
|
37 |
+
|
38 |
+
#----------------------------------------------------------------------------
|
39 |
+
|
40 |
+
class _GridSample2dForward(torch.autograd.Function):
|
41 |
+
@staticmethod
|
42 |
+
def forward(ctx, input, grid):
|
43 |
+
assert input.ndim == 4
|
44 |
+
assert grid.ndim == 4
|
45 |
+
output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
|
46 |
+
ctx.save_for_backward(input, grid)
|
47 |
+
return output
|
48 |
+
|
49 |
+
@staticmethod
|
50 |
+
def backward(ctx, grad_output):
|
51 |
+
input, grid = ctx.saved_tensors
|
52 |
+
grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
|
53 |
+
return grad_input, grad_grid
|
54 |
+
|
55 |
+
#----------------------------------------------------------------------------
|
56 |
+
|
57 |
+
class _GridSample2dBackward(torch.autograd.Function):
|
58 |
+
@staticmethod
|
59 |
+
def forward(ctx, grad_output, input, grid):
|
60 |
+
op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
|
61 |
+
if _use_pytorch_1_11_api:
|
62 |
+
output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
|
63 |
+
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
|
64 |
+
else:
|
65 |
+
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
|
66 |
+
ctx.save_for_backward(grid)
|
67 |
+
return grad_input, grad_grid
|
68 |
+
|
69 |
+
@staticmethod
|
70 |
+
def backward(ctx, grad2_grad_input, grad2_grad_grid):
|
71 |
+
_ = grad2_grad_grid # unused
|
72 |
+
grid, = ctx.saved_tensors
|
73 |
+
grad2_grad_output = None
|
74 |
+
grad2_input = None
|
75 |
+
grad2_grid = None
|
76 |
+
|
77 |
+
if ctx.needs_input_grad[0]:
|
78 |
+
grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
|
79 |
+
|
80 |
+
assert not ctx.needs_input_grad[2]
|
81 |
+
return grad2_grad_output, grad2_input, grad2_grid
|
82 |
+
|
83 |
+
#----------------------------------------------------------------------------
|
ADD/th_utils/ops/upfirdn2d.cpp
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include <torch/extension.h>
|
10 |
+
#include <ATen/cuda/CUDAContext.h>
|
11 |
+
#include <c10/cuda/CUDAGuard.h>
|
12 |
+
#include "upfirdn2d.h"
|
13 |
+
|
14 |
+
//------------------------------------------------------------------------
|
15 |
+
|
16 |
+
static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
|
17 |
+
{
|
18 |
+
// Validate arguments.
|
19 |
+
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
20 |
+
TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
|
21 |
+
TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
|
22 |
+
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
|
23 |
+
TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
|
24 |
+
TORCH_CHECK(x.numel() > 0, "x has zero size");
|
25 |
+
TORCH_CHECK(f.numel() > 0, "f has zero size");
|
26 |
+
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
|
27 |
+
TORCH_CHECK(f.dim() == 2, "f must be rank 2");
|
28 |
+
TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large");
|
29 |
+
TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
|
30 |
+
TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
|
31 |
+
TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
|
32 |
+
|
33 |
+
// Create output tensor.
|
34 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
35 |
+
int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
|
36 |
+
int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
|
37 |
+
TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
|
38 |
+
torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
|
39 |
+
TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
|
40 |
+
TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large");
|
41 |
+
|
42 |
+
// Initialize CUDA kernel parameters.
|
43 |
+
upfirdn2d_kernel_params p;
|
44 |
+
p.x = x.data_ptr();
|
45 |
+
p.f = f.data_ptr<float>();
|
46 |
+
p.y = y.data_ptr();
|
47 |
+
p.up = make_int2(upx, upy);
|
48 |
+
p.down = make_int2(downx, downy);
|
49 |
+
p.pad0 = make_int2(padx0, pady0);
|
50 |
+
p.flip = (flip) ? 1 : 0;
|
51 |
+
p.gain = gain;
|
52 |
+
p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
|
53 |
+
p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
|
54 |
+
p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
|
55 |
+
p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
|
56 |
+
p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
|
57 |
+
p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
|
58 |
+
p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
|
59 |
+
p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
|
60 |
+
|
61 |
+
// Choose CUDA kernel.
|
62 |
+
upfirdn2d_kernel_spec spec;
|
63 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
|
64 |
+
{
|
65 |
+
spec = choose_upfirdn2d_kernel<scalar_t>(p);
|
66 |
+
});
|
67 |
+
|
68 |
+
// Set looping options.
|
69 |
+
p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
|
70 |
+
p.loopMinor = spec.loopMinor;
|
71 |
+
p.loopX = spec.loopX;
|
72 |
+
p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
|
73 |
+
p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
|
74 |
+
|
75 |
+
// Compute grid size.
|
76 |
+
dim3 blockSize, gridSize;
|
77 |
+
if (spec.tileOutW < 0) // large
|
78 |
+
{
|
79 |
+
blockSize = dim3(4, 32, 1);
|
80 |
+
gridSize = dim3(
|
81 |
+
((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
|
82 |
+
(p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
|
83 |
+
p.launchMajor);
|
84 |
+
}
|
85 |
+
else // small
|
86 |
+
{
|
87 |
+
blockSize = dim3(256, 1, 1);
|
88 |
+
gridSize = dim3(
|
89 |
+
((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
|
90 |
+
(p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
|
91 |
+
p.launchMajor);
|
92 |
+
}
|
93 |
+
|
94 |
+
// Launch CUDA kernel.
|
95 |
+
void* args[] = {&p};
|
96 |
+
AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
|
97 |
+
return y;
|
98 |
+
}
|
99 |
+
|
100 |
+
//------------------------------------------------------------------------
|
101 |
+
|
102 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
103 |
+
{
|
104 |
+
m.def("upfirdn2d", &upfirdn2d);
|
105 |
+
}
|
106 |
+
|
107 |
+
//------------------------------------------------------------------------
|
ADD/th_utils/ops/upfirdn2d.cu
ADDED
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include <c10/util/Half.h>
|
10 |
+
#include "upfirdn2d.h"
|
11 |
+
|
12 |
+
//------------------------------------------------------------------------
|
13 |
+
// Helpers.
|
14 |
+
|
15 |
+
template <class T> struct InternalType;
|
16 |
+
template <> struct InternalType<double> { typedef double scalar_t; };
|
17 |
+
template <> struct InternalType<float> { typedef float scalar_t; };
|
18 |
+
template <> struct InternalType<c10::Half> { typedef float scalar_t; };
|
19 |
+
|
20 |
+
static __device__ __forceinline__ int floor_div(int a, int b)
|
21 |
+
{
|
22 |
+
int t = 1 - a / b;
|
23 |
+
return (a + t * b) / b - t;
|
24 |
+
}
|
25 |
+
|
26 |
+
//------------------------------------------------------------------------
|
27 |
+
// Generic CUDA implementation for large filters.
|
28 |
+
|
29 |
+
template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
|
30 |
+
{
|
31 |
+
typedef typename InternalType<T>::scalar_t scalar_t;
|
32 |
+
|
33 |
+
// Calculate thread index.
|
34 |
+
int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
|
35 |
+
int outY = minorBase / p.launchMinor;
|
36 |
+
minorBase -= outY * p.launchMinor;
|
37 |
+
int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
|
38 |
+
int majorBase = blockIdx.z * p.loopMajor;
|
39 |
+
if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
|
40 |
+
return;
|
41 |
+
|
42 |
+
// Setup Y receptive field.
|
43 |
+
int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
|
44 |
+
int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
|
45 |
+
int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
|
46 |
+
int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
|
47 |
+
if (p.flip)
|
48 |
+
filterY = p.filterSize.y - 1 - filterY;
|
49 |
+
|
50 |
+
// Loop over major, minor, and X.
|
51 |
+
for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
|
52 |
+
for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
|
53 |
+
{
|
54 |
+
int nc = major * p.sizeMinor + minor;
|
55 |
+
int n = nc / p.inSize.z;
|
56 |
+
int c = nc - n * p.inSize.z;
|
57 |
+
for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
|
58 |
+
{
|
59 |
+
// Setup X receptive field.
|
60 |
+
int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
|
61 |
+
int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
|
62 |
+
int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
|
63 |
+
int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
|
64 |
+
if (p.flip)
|
65 |
+
filterX = p.filterSize.x - 1 - filterX;
|
66 |
+
|
67 |
+
// Initialize pointers.
|
68 |
+
const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
|
69 |
+
const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
|
70 |
+
int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
|
71 |
+
int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
|
72 |
+
|
73 |
+
// Inner loop.
|
74 |
+
scalar_t v = 0;
|
75 |
+
for (int y = 0; y < h; y++)
|
76 |
+
{
|
77 |
+
for (int x = 0; x < w; x++)
|
78 |
+
{
|
79 |
+
v += (scalar_t)(*xp) * (scalar_t)(*fp);
|
80 |
+
xp += p.inStride.x;
|
81 |
+
fp += filterStepX;
|
82 |
+
}
|
83 |
+
xp += p.inStride.y - w * p.inStride.x;
|
84 |
+
fp += filterStepY - w * filterStepX;
|
85 |
+
}
|
86 |
+
|
87 |
+
// Store result.
|
88 |
+
v *= p.gain;
|
89 |
+
((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
|
90 |
+
}
|
91 |
+
}
|
92 |
+
}
|
93 |
+
|
94 |
+
//------------------------------------------------------------------------
|
95 |
+
// Specialized CUDA implementation for small filters.
|
96 |
+
|
97 |
+
template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
|
98 |
+
static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
|
99 |
+
{
|
100 |
+
typedef typename InternalType<T>::scalar_t scalar_t;
|
101 |
+
const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
|
102 |
+
const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
|
103 |
+
__shared__ volatile scalar_t sf[filterH][filterW];
|
104 |
+
__shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
|
105 |
+
|
106 |
+
// Calculate tile index.
|
107 |
+
int minorBase = blockIdx.x;
|
108 |
+
int tileOutY = minorBase / p.launchMinor;
|
109 |
+
minorBase -= tileOutY * p.launchMinor;
|
110 |
+
minorBase *= loopMinor;
|
111 |
+
tileOutY *= tileOutH;
|
112 |
+
int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
|
113 |
+
int majorBase = blockIdx.z * p.loopMajor;
|
114 |
+
if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
|
115 |
+
return;
|
116 |
+
|
117 |
+
// Load filter (flipped).
|
118 |
+
for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
|
119 |
+
{
|
120 |
+
int fy = tapIdx / filterW;
|
121 |
+
int fx = tapIdx - fy * filterW;
|
122 |
+
scalar_t v = 0;
|
123 |
+
if (fx < p.filterSize.x & fy < p.filterSize.y)
|
124 |
+
{
|
125 |
+
int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
|
126 |
+
int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
|
127 |
+
v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
|
128 |
+
}
|
129 |
+
sf[fy][fx] = v;
|
130 |
+
}
|
131 |
+
|
132 |
+
// Loop over major and X.
|
133 |
+
for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
|
134 |
+
{
|
135 |
+
int baseNC = major * p.sizeMinor + minorBase;
|
136 |
+
int n = baseNC / p.inSize.z;
|
137 |
+
int baseC = baseNC - n * p.inSize.z;
|
138 |
+
for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
|
139 |
+
{
|
140 |
+
// Load input pixels.
|
141 |
+
int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
|
142 |
+
int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
|
143 |
+
int tileInX = floor_div(tileMidX, upx);
|
144 |
+
int tileInY = floor_div(tileMidY, upy);
|
145 |
+
__syncthreads();
|
146 |
+
for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
|
147 |
+
{
|
148 |
+
int relC = inIdx;
|
149 |
+
int relInX = relC / loopMinor;
|
150 |
+
int relInY = relInX / tileInW;
|
151 |
+
relC -= relInX * loopMinor;
|
152 |
+
relInX -= relInY * tileInW;
|
153 |
+
int c = baseC + relC;
|
154 |
+
int inX = tileInX + relInX;
|
155 |
+
int inY = tileInY + relInY;
|
156 |
+
scalar_t v = 0;
|
157 |
+
if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
|
158 |
+
v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
|
159 |
+
sx[relInY][relInX][relC] = v;
|
160 |
+
}
|
161 |
+
|
162 |
+
// Loop over output pixels.
|
163 |
+
__syncthreads();
|
164 |
+
for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
|
165 |
+
{
|
166 |
+
int relC = outIdx;
|
167 |
+
int relOutX = relC / loopMinor;
|
168 |
+
int relOutY = relOutX / tileOutW;
|
169 |
+
relC -= relOutX * loopMinor;
|
170 |
+
relOutX -= relOutY * tileOutW;
|
171 |
+
int c = baseC + relC;
|
172 |
+
int outX = tileOutX + relOutX;
|
173 |
+
int outY = tileOutY + relOutY;
|
174 |
+
|
175 |
+
// Setup receptive field.
|
176 |
+
int midX = tileMidX + relOutX * downx;
|
177 |
+
int midY = tileMidY + relOutY * downy;
|
178 |
+
int inX = floor_div(midX, upx);
|
179 |
+
int inY = floor_div(midY, upy);
|
180 |
+
int relInX = inX - tileInX;
|
181 |
+
int relInY = inY - tileInY;
|
182 |
+
int filterX = (inX + 1) * upx - midX - 1; // flipped
|
183 |
+
int filterY = (inY + 1) * upy - midY - 1; // flipped
|
184 |
+
|
185 |
+
// Inner loop.
|
186 |
+
if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
|
187 |
+
{
|
188 |
+
scalar_t v = 0;
|
189 |
+
#pragma unroll
|
190 |
+
for (int y = 0; y < filterH / upy; y++)
|
191 |
+
#pragma unroll
|
192 |
+
for (int x = 0; x < filterW / upx; x++)
|
193 |
+
v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
|
194 |
+
v *= p.gain;
|
195 |
+
((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
|
196 |
+
}
|
197 |
+
}
|
198 |
+
}
|
199 |
+
}
|
200 |
+
}
|
201 |
+
|
202 |
+
//------------------------------------------------------------------------
|
203 |
+
// CUDA kernel selection.
|
204 |
+
|
205 |
+
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
|
206 |
+
{
|
207 |
+
int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
|
208 |
+
upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
|
209 |
+
if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
|
210 |
+
|
211 |
+
// No up/downsampling.
|
212 |
+
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
|
213 |
+
{
|
214 |
+
// contiguous
|
215 |
+
if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
|
216 |
+
if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
|
217 |
+
if (s != 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
|
218 |
+
if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
|
219 |
+
if (s != 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
|
220 |
+
if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
|
221 |
+
if (s != 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
|
222 |
+
if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
|
223 |
+
if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
|
224 |
+
if (s != 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
|
225 |
+
if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
|
226 |
+
if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
|
227 |
+
if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
|
228 |
+
// channels_last
|
229 |
+
if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
|
230 |
+
if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
|
231 |
+
if (s == 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
|
232 |
+
if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
|
233 |
+
if (s == 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 16,16,8>, 16,16,8, 1};
|
234 |
+
if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
235 |
+
if (s == 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 16,16,8>, 16,16,8, 1};
|
236 |
+
if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
|
237 |
+
if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
|
238 |
+
if (s == 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
|
239 |
+
if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
|
240 |
+
if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
|
241 |
+
if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
|
242 |
+
}
|
243 |
+
|
244 |
+
// 2x upsampling.
|
245 |
+
if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
|
246 |
+
{
|
247 |
+
// contiguous
|
248 |
+
if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
|
249 |
+
if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
|
250 |
+
if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
|
251 |
+
if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
|
252 |
+
if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
|
253 |
+
if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
|
254 |
+
// channels_last
|
255 |
+
if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
|
256 |
+
if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
|
257 |
+
if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
|
258 |
+
if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
|
259 |
+
if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
260 |
+
if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
|
261 |
+
}
|
262 |
+
if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
|
263 |
+
{
|
264 |
+
// contiguous
|
265 |
+
if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
|
266 |
+
if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
|
267 |
+
if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
|
268 |
+
// channels_last
|
269 |
+
if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
|
270 |
+
if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
|
271 |
+
if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
|
272 |
+
}
|
273 |
+
if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
|
274 |
+
{
|
275 |
+
// contiguous
|
276 |
+
if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
|
277 |
+
if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
|
278 |
+
if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
|
279 |
+
// channels_last
|
280 |
+
if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
|
281 |
+
if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
|
282 |
+
if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
|
283 |
+
}
|
284 |
+
|
285 |
+
// 2x downsampling.
|
286 |
+
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2)
|
287 |
+
{
|
288 |
+
// contiguous
|
289 |
+
if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 32,16,1>, 32,16,1, 1};
|
290 |
+
if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 32,16,1>, 32,16,1, 1};
|
291 |
+
if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
|
292 |
+
if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
|
293 |
+
if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
|
294 |
+
if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
|
295 |
+
// channels_last
|
296 |
+
if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 16,16,1>, 16,16,1, 1};
|
297 |
+
if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 16,16,1>, 16,16,1, 1};
|
298 |
+
if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
|
299 |
+
if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
|
300 |
+
if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
|
301 |
+
if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
|
302 |
+
}
|
303 |
+
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1)
|
304 |
+
{
|
305 |
+
// contiguous
|
306 |
+
if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
|
307 |
+
if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
|
308 |
+
if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
|
309 |
+
// channels_last
|
310 |
+
if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
|
311 |
+
if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
|
312 |
+
if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
|
313 |
+
}
|
314 |
+
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2)
|
315 |
+
{
|
316 |
+
// contiguous
|
317 |
+
if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
|
318 |
+
if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
|
319 |
+
if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
|
320 |
+
// channels_last
|
321 |
+
if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
|
322 |
+
if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
|
323 |
+
if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
|
324 |
+
}
|
325 |
+
|
326 |
+
// 4x upsampling.
|
327 |
+
if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
|
328 |
+
{
|
329 |
+
// contiguous
|
330 |
+
if (s != 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 64,32,1>, 64,32,1, 1};
|
331 |
+
if (s != 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 64,32,1>, 64,32,1, 1};
|
332 |
+
// channels_last
|
333 |
+
if (s == 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 32,32,1>, 32,32,1, 1};
|
334 |
+
if (s == 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 32,32,1>, 32,32,1, 1};
|
335 |
+
}
|
336 |
+
if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
|
337 |
+
{
|
338 |
+
// contiguous
|
339 |
+
if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,8,1>, 128,8,1, 1};
|
340 |
+
if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,8,1>, 128,8,1, 1};
|
341 |
+
// channels_last
|
342 |
+
if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,1,16>, 128,1,16, 1};
|
343 |
+
if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,1,16>, 128,1,16, 1};
|
344 |
+
}
|
345 |
+
if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
|
346 |
+
{
|
347 |
+
// contiguous
|
348 |
+
if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 32,32,1>, 32,32,1, 1};
|
349 |
+
if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 32,32,1>, 32,32,1, 1};
|
350 |
+
// channels_last
|
351 |
+
if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 1,128,16>, 1,128,16, 1};
|
352 |
+
if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 1,128,16>, 1,128,16, 1};
|
353 |
+
}
|
354 |
+
|
355 |
+
// 4x downsampling (inefficient).
|
356 |
+
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1)
|
357 |
+
{
|
358 |
+
// contiguous
|
359 |
+
if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,8,1>, 32,8,1, 1};
|
360 |
+
if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,8,1>, 32,8,1, 1};
|
361 |
+
// channels_last
|
362 |
+
if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,1,8>, 32,1,8, 1};
|
363 |
+
if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,1,8>, 32,1,8, 1};
|
364 |
+
}
|
365 |
+
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4)
|
366 |
+
{
|
367 |
+
// contiguous
|
368 |
+
if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 32,8,1>, 32,8,1, 1};
|
369 |
+
if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 32,8,1>, 32,8,1, 1};
|
370 |
+
// channels_last
|
371 |
+
if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 1,32,8>, 1,32,8, 1};
|
372 |
+
if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 1,32,8>, 1,32,8, 1};
|
373 |
+
}
|
374 |
+
return spec;
|
375 |
+
}
|
376 |
+
|
377 |
+
//------------------------------------------------------------------------
|
378 |
+
// Template specializations.
|
379 |
+
|
380 |
+
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
|
381 |
+
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
|
382 |
+
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
|
383 |
+
|
384 |
+
//------------------------------------------------------------------------
|
ADD/th_utils/ops/upfirdn2d.h
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include <cuda_runtime.h>
|
10 |
+
|
11 |
+
//------------------------------------------------------------------------
|
12 |
+
// CUDA kernel parameters.
|
13 |
+
|
14 |
+
struct upfirdn2d_kernel_params
|
15 |
+
{
|
16 |
+
const void* x;
|
17 |
+
const float* f;
|
18 |
+
void* y;
|
19 |
+
|
20 |
+
int2 up;
|
21 |
+
int2 down;
|
22 |
+
int2 pad0;
|
23 |
+
int flip;
|
24 |
+
float gain;
|
25 |
+
|
26 |
+
int4 inSize; // [width, height, channel, batch]
|
27 |
+
int4 inStride;
|
28 |
+
int2 filterSize; // [width, height]
|
29 |
+
int2 filterStride;
|
30 |
+
int4 outSize; // [width, height, channel, batch]
|
31 |
+
int4 outStride;
|
32 |
+
int sizeMinor;
|
33 |
+
int sizeMajor;
|
34 |
+
|
35 |
+
int loopMinor;
|
36 |
+
int loopMajor;
|
37 |
+
int loopX;
|
38 |
+
int launchMinor;
|
39 |
+
int launchMajor;
|
40 |
+
};
|
41 |
+
|
42 |
+
//------------------------------------------------------------------------
|
43 |
+
// CUDA kernel specialization.
|
44 |
+
|
45 |
+
struct upfirdn2d_kernel_spec
|
46 |
+
{
|
47 |
+
void* kernel;
|
48 |
+
int tileOutW;
|
49 |
+
int tileOutH;
|
50 |
+
int loopMinor;
|
51 |
+
int loopX;
|
52 |
+
};
|
53 |
+
|
54 |
+
//------------------------------------------------------------------------
|
55 |
+
// CUDA kernel selection.
|
56 |
+
|
57 |
+
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
|
58 |
+
|
59 |
+
//------------------------------------------------------------------------
|
ADD/th_utils/ops/upfirdn2d.py
ADDED
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Custom PyTorch ops for efficient resampling of 2D images."""
|
10 |
+
|
11 |
+
import os
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
|
15 |
+
from .. import custom_ops
|
16 |
+
from .. import misc
|
17 |
+
from . import conv2d_gradfix
|
18 |
+
|
19 |
+
#----------------------------------------------------------------------------
|
20 |
+
|
21 |
+
_plugin = None
|
22 |
+
|
23 |
+
def _init():
|
24 |
+
global _plugin
|
25 |
+
if _plugin is None:
|
26 |
+
_plugin = custom_ops.get_plugin(
|
27 |
+
module_name='upfirdn2d_plugin',
|
28 |
+
sources=['upfirdn2d.cpp', 'upfirdn2d.cu'],
|
29 |
+
headers=['upfirdn2d.h'],
|
30 |
+
source_dir=os.path.dirname(__file__),
|
31 |
+
extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'],
|
32 |
+
)
|
33 |
+
return True
|
34 |
+
|
35 |
+
def _parse_scaling(scaling):
|
36 |
+
if isinstance(scaling, int):
|
37 |
+
scaling = [scaling, scaling]
|
38 |
+
assert isinstance(scaling, (list, tuple))
|
39 |
+
assert all(isinstance(x, int) for x in scaling)
|
40 |
+
sx, sy = scaling
|
41 |
+
assert sx >= 1 and sy >= 1
|
42 |
+
return sx, sy
|
43 |
+
|
44 |
+
def _parse_padding(padding):
|
45 |
+
if isinstance(padding, int):
|
46 |
+
padding = [padding, padding]
|
47 |
+
assert isinstance(padding, (list, tuple))
|
48 |
+
assert all(isinstance(x, int) for x in padding)
|
49 |
+
if len(padding) == 2:
|
50 |
+
padx, pady = padding
|
51 |
+
padding = [padx, padx, pady, pady]
|
52 |
+
padx0, padx1, pady0, pady1 = padding
|
53 |
+
return padx0, padx1, pady0, pady1
|
54 |
+
|
55 |
+
def _get_filter_size(f):
|
56 |
+
if f is None:
|
57 |
+
return 1, 1
|
58 |
+
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
59 |
+
fw = f.shape[-1]
|
60 |
+
fh = f.shape[0]
|
61 |
+
with misc.suppress_tracer_warnings():
|
62 |
+
fw = int(fw)
|
63 |
+
fh = int(fh)
|
64 |
+
misc.assert_shape(f, [fh, fw][:f.ndim])
|
65 |
+
assert fw >= 1 and fh >= 1
|
66 |
+
return fw, fh
|
67 |
+
|
68 |
+
#----------------------------------------------------------------------------
|
69 |
+
|
70 |
+
def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
|
71 |
+
r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
f: Torch tensor, numpy array, or python list of the shape
|
75 |
+
`[filter_height, filter_width]` (non-separable),
|
76 |
+
`[filter_taps]` (separable),
|
77 |
+
`[]` (impulse), or
|
78 |
+
`None` (identity).
|
79 |
+
device: Result device (default: cpu).
|
80 |
+
normalize: Normalize the filter so that it retains the magnitude
|
81 |
+
for constant input signal (DC)? (default: True).
|
82 |
+
flip_filter: Flip the filter? (default: False).
|
83 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
84 |
+
separable: Return a separable filter? (default: select automatically).
|
85 |
+
|
86 |
+
Returns:
|
87 |
+
Float32 tensor of the shape
|
88 |
+
`[filter_height, filter_width]` (non-separable) or
|
89 |
+
`[filter_taps]` (separable).
|
90 |
+
"""
|
91 |
+
# Validate.
|
92 |
+
if f is None:
|
93 |
+
f = 1
|
94 |
+
f = torch.as_tensor(f, dtype=torch.float32)
|
95 |
+
assert f.ndim in [0, 1, 2]
|
96 |
+
assert f.numel() > 0
|
97 |
+
if f.ndim == 0:
|
98 |
+
f = f[np.newaxis]
|
99 |
+
|
100 |
+
# Separable?
|
101 |
+
if separable is None:
|
102 |
+
separable = (f.ndim == 1 and f.numel() >= 8)
|
103 |
+
if f.ndim == 1 and not separable:
|
104 |
+
f = f.ger(f)
|
105 |
+
assert f.ndim == (1 if separable else 2)
|
106 |
+
|
107 |
+
# Apply normalize, flip, gain, and device.
|
108 |
+
if normalize:
|
109 |
+
f /= f.sum()
|
110 |
+
if flip_filter:
|
111 |
+
f = f.flip(list(range(f.ndim)))
|
112 |
+
f = f * (gain ** (f.ndim / 2))
|
113 |
+
f = f.to(device=device)
|
114 |
+
return f
|
115 |
+
|
116 |
+
#----------------------------------------------------------------------------
|
117 |
+
|
118 |
+
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
119 |
+
r"""Pad, upsample, filter, and downsample a batch of 2D images.
|
120 |
+
|
121 |
+
Performs the following sequence of operations for each channel:
|
122 |
+
|
123 |
+
1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
|
124 |
+
|
125 |
+
2. Pad the image with the specified number of zeros on each side (`padding`).
|
126 |
+
Negative padding corresponds to cropping the image.
|
127 |
+
|
128 |
+
3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
|
129 |
+
so that the footprint of all output pixels lies within the input image.
|
130 |
+
|
131 |
+
4. Downsample the image by keeping every Nth pixel (`down`).
|
132 |
+
|
133 |
+
This sequence of operations bears close resemblance to scipy.signal.upfirdn().
|
134 |
+
The fused op is considerably more efficient than performing the same calculation
|
135 |
+
using standard PyTorch ops. It supports gradients of arbitrary order.
|
136 |
+
|
137 |
+
Args:
|
138 |
+
x: Float32/float64/float16 input tensor of the shape
|
139 |
+
`[batch_size, num_channels, in_height, in_width]`.
|
140 |
+
f: Float32 FIR filter of the shape
|
141 |
+
`[filter_height, filter_width]` (non-separable),
|
142 |
+
`[filter_taps]` (separable), or
|
143 |
+
`None` (identity).
|
144 |
+
up: Integer upsampling factor. Can be a single int or a list/tuple
|
145 |
+
`[x, y]` (default: 1).
|
146 |
+
down: Integer downsampling factor. Can be a single int or a list/tuple
|
147 |
+
`[x, y]` (default: 1).
|
148 |
+
padding: Padding with respect to the upsampled image. Can be a single number
|
149 |
+
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
150 |
+
(default: 0).
|
151 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
152 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
153 |
+
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
154 |
+
|
155 |
+
Returns:
|
156 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
157 |
+
"""
|
158 |
+
assert isinstance(x, torch.Tensor)
|
159 |
+
assert impl in ['ref', 'cuda']
|
160 |
+
if impl == 'cuda' and x.device.type == 'cuda' and _init():
|
161 |
+
return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
|
162 |
+
return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
|
163 |
+
|
164 |
+
#----------------------------------------------------------------------------
|
165 |
+
|
166 |
+
@misc.profiled_function
|
167 |
+
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
|
168 |
+
"""Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
|
169 |
+
"""
|
170 |
+
# Validate arguments.
|
171 |
+
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
172 |
+
if f is None:
|
173 |
+
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
174 |
+
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
175 |
+
assert f.dtype == torch.float32 and not f.requires_grad
|
176 |
+
batch_size, num_channels, in_height, in_width = x.shape
|
177 |
+
upx, upy = _parse_scaling(up)
|
178 |
+
downx, downy = _parse_scaling(down)
|
179 |
+
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
180 |
+
|
181 |
+
# Check that upsampled buffer is not smaller than the filter.
|
182 |
+
upW = in_width * upx + padx0 + padx1
|
183 |
+
upH = in_height * upy + pady0 + pady1
|
184 |
+
assert upW >= f.shape[-1] and upH >= f.shape[0]
|
185 |
+
|
186 |
+
# Upsample by inserting zeros.
|
187 |
+
x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
|
188 |
+
x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
|
189 |
+
x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
|
190 |
+
|
191 |
+
# Pad or crop.
|
192 |
+
x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
|
193 |
+
x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
|
194 |
+
|
195 |
+
# Setup filter.
|
196 |
+
f = f * (gain ** (f.ndim / 2))
|
197 |
+
f = f.to(x.dtype)
|
198 |
+
if not flip_filter:
|
199 |
+
f = f.flip(list(range(f.ndim)))
|
200 |
+
|
201 |
+
# Convolve with the filter.
|
202 |
+
f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
|
203 |
+
if f.ndim == 4:
|
204 |
+
x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
|
205 |
+
else:
|
206 |
+
x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
|
207 |
+
x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
|
208 |
+
|
209 |
+
# Downsample by throwing away pixels.
|
210 |
+
x = x[:, :, ::downy, ::downx]
|
211 |
+
return x
|
212 |
+
|
213 |
+
#----------------------------------------------------------------------------
|
214 |
+
|
215 |
+
_upfirdn2d_cuda_cache = dict()
|
216 |
+
|
217 |
+
def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
|
218 |
+
"""Fast CUDA implementation of `upfirdn2d()` using custom ops.
|
219 |
+
"""
|
220 |
+
# Parse arguments.
|
221 |
+
upx, upy = _parse_scaling(up)
|
222 |
+
downx, downy = _parse_scaling(down)
|
223 |
+
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
224 |
+
|
225 |
+
# Lookup from cache.
|
226 |
+
key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
|
227 |
+
if key in _upfirdn2d_cuda_cache:
|
228 |
+
return _upfirdn2d_cuda_cache[key]
|
229 |
+
|
230 |
+
# Forward op.
|
231 |
+
class Upfirdn2dCuda(torch.autograd.Function):
|
232 |
+
@staticmethod
|
233 |
+
def forward(ctx, x, f): # pylint: disable=arguments-differ
|
234 |
+
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
235 |
+
if f is None:
|
236 |
+
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
237 |
+
if f.ndim == 1 and f.shape[0] == 1:
|
238 |
+
f = f.square().unsqueeze(0) # Convert separable-1 into full-1x1.
|
239 |
+
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
240 |
+
y = x
|
241 |
+
if f.ndim == 2:
|
242 |
+
y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
|
243 |
+
else:
|
244 |
+
y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, 1.0)
|
245 |
+
y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, gain)
|
246 |
+
ctx.save_for_backward(f)
|
247 |
+
ctx.x_shape = x.shape
|
248 |
+
return y
|
249 |
+
|
250 |
+
@staticmethod
|
251 |
+
def backward(ctx, dy): # pylint: disable=arguments-differ
|
252 |
+
f, = ctx.saved_tensors
|
253 |
+
_, _, ih, iw = ctx.x_shape
|
254 |
+
_, _, oh, ow = dy.shape
|
255 |
+
fw, fh = _get_filter_size(f)
|
256 |
+
p = [
|
257 |
+
fw - padx0 - 1,
|
258 |
+
iw * upx - ow * downx + padx0 - upx + 1,
|
259 |
+
fh - pady0 - 1,
|
260 |
+
ih * upy - oh * downy + pady0 - upy + 1,
|
261 |
+
]
|
262 |
+
dx = None
|
263 |
+
df = None
|
264 |
+
|
265 |
+
if ctx.needs_input_grad[0]:
|
266 |
+
dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
|
267 |
+
|
268 |
+
assert not ctx.needs_input_grad[1]
|
269 |
+
return dx, df
|
270 |
+
|
271 |
+
# Add to cache.
|
272 |
+
_upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
|
273 |
+
return Upfirdn2dCuda
|
274 |
+
|
275 |
+
#----------------------------------------------------------------------------
|
276 |
+
|
277 |
+
def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
278 |
+
r"""Filter a batch of 2D images using the given 2D FIR filter.
|
279 |
+
|
280 |
+
By default, the result is padded so that its shape matches the input.
|
281 |
+
User-specified padding is applied on top of that, with negative values
|
282 |
+
indicating cropping. Pixels outside the image are assumed to be zero.
|
283 |
+
|
284 |
+
Args:
|
285 |
+
x: Float32/float64/float16 input tensor of the shape
|
286 |
+
`[batch_size, num_channels, in_height, in_width]`.
|
287 |
+
f: Float32 FIR filter of the shape
|
288 |
+
`[filter_height, filter_width]` (non-separable),
|
289 |
+
`[filter_taps]` (separable), or
|
290 |
+
`None` (identity).
|
291 |
+
padding: Padding with respect to the output. Can be a single number or a
|
292 |
+
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
293 |
+
(default: 0).
|
294 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
295 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
296 |
+
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
297 |
+
|
298 |
+
Returns:
|
299 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
300 |
+
"""
|
301 |
+
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
302 |
+
fw, fh = _get_filter_size(f)
|
303 |
+
p = [
|
304 |
+
padx0 + fw // 2,
|
305 |
+
padx1 + (fw - 1) // 2,
|
306 |
+
pady0 + fh // 2,
|
307 |
+
pady1 + (fh - 1) // 2,
|
308 |
+
]
|
309 |
+
return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
|
310 |
+
|
311 |
+
#----------------------------------------------------------------------------
|
312 |
+
|
313 |
+
def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
314 |
+
r"""Upsample a batch of 2D images using the given 2D FIR filter.
|
315 |
+
|
316 |
+
By default, the result is padded so that its shape is a multiple of the input.
|
317 |
+
User-specified padding is applied on top of that, with negative values
|
318 |
+
indicating cropping. Pixels outside the image are assumed to be zero.
|
319 |
+
|
320 |
+
Args:
|
321 |
+
x: Float32/float64/float16 input tensor of the shape
|
322 |
+
`[batch_size, num_channels, in_height, in_width]`.
|
323 |
+
f: Float32 FIR filter of the shape
|
324 |
+
`[filter_height, filter_width]` (non-separable),
|
325 |
+
`[filter_taps]` (separable), or
|
326 |
+
`None` (identity).
|
327 |
+
up: Integer upsampling factor. Can be a single int or a list/tuple
|
328 |
+
`[x, y]` (default: 1).
|
329 |
+
padding: Padding with respect to the output. Can be a single number or a
|
330 |
+
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
331 |
+
(default: 0).
|
332 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
333 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
334 |
+
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
335 |
+
|
336 |
+
Returns:
|
337 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
338 |
+
"""
|
339 |
+
upx, upy = _parse_scaling(up)
|
340 |
+
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
341 |
+
fw, fh = _get_filter_size(f)
|
342 |
+
p = [
|
343 |
+
padx0 + (fw + upx - 1) // 2,
|
344 |
+
padx1 + (fw - upx) // 2,
|
345 |
+
pady0 + (fh + upy - 1) // 2,
|
346 |
+
pady1 + (fh - upy) // 2,
|
347 |
+
]
|
348 |
+
return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
|
349 |
+
|
350 |
+
#----------------------------------------------------------------------------
|
351 |
+
|
352 |
+
def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
353 |
+
r"""Downsample a batch of 2D images using the given 2D FIR filter.
|
354 |
+
|
355 |
+
By default, the result is padded so that its shape is a fraction of the input.
|
356 |
+
User-specified padding is applied on top of that, with negative values
|
357 |
+
indicating cropping. Pixels outside the image are assumed to be zero.
|
358 |
+
|
359 |
+
Args:
|
360 |
+
x: Float32/float64/float16 input tensor of the shape
|
361 |
+
`[batch_size, num_channels, in_height, in_width]`.
|
362 |
+
f: Float32 FIR filter of the shape
|
363 |
+
`[filter_height, filter_width]` (non-separable),
|
364 |
+
`[filter_taps]` (separable), or
|
365 |
+
`None` (identity).
|
366 |
+
down: Integer downsampling factor. Can be a single int or a list/tuple
|
367 |
+
`[x, y]` (default: 1).
|
368 |
+
padding: Padding with respect to the input. Can be a single number or a
|
369 |
+
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
370 |
+
(default: 0).
|
371 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
372 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
373 |
+
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
374 |
+
|
375 |
+
Returns:
|
376 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
377 |
+
"""
|
378 |
+
downx, downy = _parse_scaling(down)
|
379 |
+
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
380 |
+
fw, fh = _get_filter_size(f)
|
381 |
+
p = [
|
382 |
+
padx0 + (fw - downx + 1) // 2,
|
383 |
+
padx1 + (fw - downx) // 2,
|
384 |
+
pady0 + (fh - downy + 1) // 2,
|
385 |
+
pady1 + (fh - downy) // 2,
|
386 |
+
]
|
387 |
+
return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
|
388 |
+
|
389 |
+
#----------------------------------------------------------------------------
|
ADD/utils/util_net.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
# Power by Zongsheng Yue 2021-11-24 20:29:36
|
4 |
+
|
5 |
+
import math
|
6 |
+
import torch
|
7 |
+
from pathlib import Path
|
8 |
+
from collections import OrderedDict
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from copy import deepcopy
|
11 |
+
|
12 |
+
def calculate_parameters(net):
|
13 |
+
out = 0
|
14 |
+
for param in net.parameters():
|
15 |
+
out += param.numel()
|
16 |
+
return out
|
17 |
+
|
18 |
+
def pad_input(x, mod):
|
19 |
+
h, w = x.shape[-2:]
|
20 |
+
bottom = int(math.ceil(h/mod)*mod -h)
|
21 |
+
right = int(math.ceil(w/mod)*mod - w)
|
22 |
+
x_pad = F.pad(x, pad=(0, right, 0, bottom), mode='reflect')
|
23 |
+
return x_pad
|
24 |
+
|
25 |
+
def forward_chop(net, x, net_kwargs=None, scale=1, shave=10, min_size=160000):
|
26 |
+
n_GPUs = 1
|
27 |
+
b, c, h, w = x.size()
|
28 |
+
h_half, w_half = h // 2, w // 2
|
29 |
+
h_size, w_size = h_half + shave, w_half + shave
|
30 |
+
lr_list = [
|
31 |
+
x[:, :, 0:h_size, 0:w_size],
|
32 |
+
x[:, :, 0:h_size, (w - w_size):w],
|
33 |
+
x[:, :, (h - h_size):h, 0:w_size],
|
34 |
+
x[:, :, (h - h_size):h, (w - w_size):w]]
|
35 |
+
|
36 |
+
if w_size * h_size < min_size:
|
37 |
+
sr_list = []
|
38 |
+
for i in range(0, 4, n_GPUs):
|
39 |
+
lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
|
40 |
+
if net_kwargs is None:
|
41 |
+
sr_batch = net(lr_batch)
|
42 |
+
else:
|
43 |
+
sr_batch = net(lr_batch, **net_kwargs)
|
44 |
+
sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
|
45 |
+
else:
|
46 |
+
sr_list = [
|
47 |
+
forward_chop(patch, shave=shave, min_size=min_size) \
|
48 |
+
for patch in lr_list
|
49 |
+
]
|
50 |
+
|
51 |
+
h, w = scale * h, scale * w
|
52 |
+
h_half, w_half = scale * h_half, scale * w_half
|
53 |
+
h_size, w_size = scale * h_size, scale * w_size
|
54 |
+
shave *= scale
|
55 |
+
|
56 |
+
output = x.new(b, c, h, w)
|
57 |
+
output[:, :, 0:h_half, 0:w_half] \
|
58 |
+
= sr_list[0][:, :, 0:h_half, 0:w_half]
|
59 |
+
output[:, :, 0:h_half, w_half:w] \
|
60 |
+
= sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
|
61 |
+
output[:, :, h_half:h, 0:w_half] \
|
62 |
+
= sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
|
63 |
+
output[:, :, h_half:h, w_half:w] \
|
64 |
+
= sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
|
65 |
+
|
66 |
+
return output
|
67 |
+
|
68 |
+
def measure_time(net, inputs, num_forward=100):
|
69 |
+
'''
|
70 |
+
Measuring the average runing time (seconds) for pytorch.
|
71 |
+
out = net(*inputs)
|
72 |
+
'''
|
73 |
+
start = torch.cuda.Event(enable_timing=True)
|
74 |
+
end = torch.cuda.Event(enable_timing=True)
|
75 |
+
|
76 |
+
start.record()
|
77 |
+
with torch.set_grad_enabled(False):
|
78 |
+
for _ in range(num_forward):
|
79 |
+
out = net(*inputs)
|
80 |
+
end.record()
|
81 |
+
|
82 |
+
torch.cuda.synchronize()
|
83 |
+
|
84 |
+
return start.elapsed_time(end) / 1000
|
85 |
+
|
86 |
+
def reload_model(model, ckpt):
|
87 |
+
if list(model.state_dict().keys())[0].startswith('module.'):
|
88 |
+
if list(ckpt.keys())[0].startswith('module.'):
|
89 |
+
ckpt = ckpt
|
90 |
+
else:
|
91 |
+
ckpt = OrderedDict({f'module.{key}':value for key, value in ckpt.items()})
|
92 |
+
else:
|
93 |
+
if list(ckpt.keys())[0].startswith('module.'):
|
94 |
+
ckpt = OrderedDict({key[7:]:value for key, value in ckpt.items()})
|
95 |
+
else:
|
96 |
+
ckpt = ckpt
|
97 |
+
model.load_state_dict(ckpt, True)
|
98 |
+
|
99 |
+
def compute_hinge_loss(real_output, fake_output, x_start_, r1_lambda):
|
100 |
+
if r1_lambda == 0:
|
101 |
+
real_loss_total = torch.relu(torch.ones_like(real_output) - real_output).mean()
|
102 |
+
fake_loss_total = torch.relu(torch.ones_like(fake_output) + fake_output).mean()
|
103 |
+
|
104 |
+
else:
|
105 |
+
real_loss_ = torch.relu(torch.ones_like(real_output) - real_output).mean()
|
106 |
+
|
107 |
+
# 计算真实样本的梯度
|
108 |
+
grad_real = torch.autograd.grad(outputs=real_output.sum(), inputs=x_start_, create_graph=True)[0]
|
109 |
+
|
110 |
+
# 计算梯度惩罚
|
111 |
+
grad_penalty = (grad_real.contiguous().view(grad_real.size(0), -1).norm(2, dim=1) ** 2).mean() * r1_lambda
|
112 |
+
|
113 |
+
real_loss_total = real_loss_ + grad_penalty
|
114 |
+
fake_loss_total = torch.relu(torch.ones_like(fake_output) + fake_output).mean()
|
115 |
+
|
116 |
+
real_loss = real_loss_total
|
117 |
+
fake_loss = fake_loss_total
|
118 |
+
|
119 |
+
loss_d = real_loss + fake_loss
|
120 |
+
|
121 |
+
return loss_d
|
122 |
+
|
123 |
+
|
124 |
+
|
125 |
+
def reload_model_(model, ckpt):
|
126 |
+
if list(model.state_dict().keys())[0].startswith('model.'):
|
127 |
+
if list(ckpt.keys())[0].startswith('model.'):
|
128 |
+
ckpt = ckpt
|
129 |
+
else:
|
130 |
+
ckpt = OrderedDict({f'model.{key}':value for key, value in ckpt.items()})
|
131 |
+
else:
|
132 |
+
if list(ckpt.keys())[0].startswith('model.'):
|
133 |
+
ckpt = OrderedDict({key[7:]:value for key, value in ckpt.items()})
|
134 |
+
else:
|
135 |
+
ckpt = ckpt
|
136 |
+
model.load_state_dict(ckpt, True)
|
137 |
+
|
138 |
+
|
139 |
+
|
140 |
+
def reload_model_IDE(model, ckpt):
|
141 |
+
extracted_dict = OrderedDict()
|
142 |
+
for key, value in ckpt.items():
|
143 |
+
if key.startswith('E_st'):
|
144 |
+
new_key = key.replace('E_st.', '')
|
145 |
+
extracted_dict[new_key] = value
|
146 |
+
|
147 |
+
model.load_state_dict(extracted_dict, True)
|
148 |
+
|
149 |
+
|
150 |
+
|
151 |
+
class EMA():
|
152 |
+
def __init__(self, model, decay):
|
153 |
+
self.model = model
|
154 |
+
self.decay = decay
|
155 |
+
self.shadow = {}
|
156 |
+
self.backup = {}
|
157 |
+
|
158 |
+
def register(self):
|
159 |
+
for name, param in self.model.named_parameters():
|
160 |
+
if param.requires_grad:
|
161 |
+
self.shadow[name] = param.data.clone()
|
162 |
+
|
163 |
+
def update(self):
|
164 |
+
for name, param in self.model.named_parameters():
|
165 |
+
if param.requires_grad:
|
166 |
+
assert name in self.shadow
|
167 |
+
new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
|
168 |
+
self.shadow[name] = new_average.clone()
|
169 |
+
|
170 |
+
def apply_shadow(self):
|
171 |
+
for name, param in self.model.named_parameters():
|
172 |
+
if param.requires_grad:
|
173 |
+
assert name in self.shadow
|
174 |
+
self.backup[name] = param.data
|
175 |
+
param.data = self.shadow[name]
|
176 |
+
|
177 |
+
def restore(self):
|
178 |
+
for name, param in self.model.named_parameters():
|
179 |
+
if param.requires_grad:
|
180 |
+
assert name in self.backup
|
181 |
+
param.data = self.backup[name]
|
182 |
+
self.backup = {}
|
README.md
CHANGED
@@ -1,15 +1,291 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<p align="center">
|
2 |
+
<img src="figs/logo.png" width="400">
|
3 |
+
</p>
|
4 |
+
|
5 |
+
<div align="center">
|
6 |
+
<h2>Improving the Stability and Efficiency of Diffusion Models for Content Consistent Super-Resolution</h2>
|
7 |
+
|
8 |
+
|
9 |
+
<a href='https://arxiv.org/pdf/2401.00877'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
|
10 |
+
|
11 |
+
|
12 |
+
[Lingchen Sun](https://scholar.google.com/citations?hl=zh-CN&tzom=-480&user=ZCDjTn8AAAAJ)<sup>1,2</sup>
|
13 |
+
| [Rongyuan Wu](https://scholar.google.com/citations?user=A-U8zE8AAAAJ&hl=zh-CN)<sup>1,2</sup> |
|
14 |
+
[Jie Liang](https://scholar.google.com.sg/citations?user=REWxLZsAAAAJ&hl)<sup>2</sup> |
|
15 |
+
[Zhengqiang Zhang](https://scholar.google.com/citations?hl=zh-CN&user=UX26wSMAAAAJ&view_op=list_works&sortby=pubdate)<sup>1,2</sup> |
|
16 |
+
[Hongwei Yong](https://scholar.google.com.hk/citations?user=Xii74qQAAAAJ&hl=zh-CN)<sup>1</sup> |
|
17 |
+
[Lei Zhang](https://www4.comp.polyu.edu.hk/~cslzhang)<sup>1,2</sup>
|
18 |
+
|
19 |
+
<sup>1</sup>The Hong Kong Polytechnic University, <sup>2</sup>OPPO Research Institute
|
20 |
+
</div>
|
21 |
+
|
22 |
+
:star: If CCSR is helpful to your images or projects, please help star this repo. Thanks! :hugs:
|
23 |
+
|
24 |
+
## 🧡ྀི What's New in CCSR-v2?
|
25 |
+
We have implemented the CCSR-v2 code based on the [Diffusers](https://github.com/huggingface/diffusers). Compared to CCSR-v1, CCSR-v2 brings a host of upgrades:
|
26 |
+
|
27 |
+
- 🛠️**Step Flexibility**: Offers flexibility in diffusion step selection, **allowing users to freely adjust the number of steps to suit their specific requirements**. This adaptability **requires no additional re-training**, ensuring seamless integration into diverse workflows.
|
28 |
+
- ⚡**Efficiency**: Supports highly efficient inference with **as few as 2 or even 1 diffusion step**, drastically reducing computation time without compromising quality.
|
29 |
+
- 📈**Enhanced Clarity**: With upgraded algorithms, CCSR-v2 restores images with crisper details while maintaining fidelity.
|
30 |
+
- ⚖️**Results stability**: CCSR-v2 exhibits significantly improved stability in synthesizing fine image details, ensuring higher-quality outputs.
|
31 |
+
- 🔄**Stage 2 Refinement**: In CCSR-v2, the output $\hat{x}_{0 \gets T}$ from Stage 1 is now directly fed into Stage 2, streamlining the restoration process into an efficient one-step diffusion workflow. This strategy boosts both speed and performance.
|
32 |
+
|
33 |
+
![ccsr](figs/fig.png)
|
34 |
+
Visual comparisons between the SR outputs with the same input low-quality image but two different noise samples by different DM-based
|
35 |
+
methods. `S` denotes diffusion sampling timesteps. Existing DM-based methods, including StableSR, PASD, SeeSR, SUPIR and AddSR, **show noticeable instability with the different noise samples**. OSEDiff directly takes low-quality image as input without
|
36 |
+
noise sampling. It is deterministic and stable, but **cannot perform multi-step diffusion** for high generative capacity. In contrast, **our proposed CCSR method
|
37 |
+
is flexible for both multi-step diffusion and single-step diffusion, while producing stable results with high fidelity and visual quality**.
|
38 |
+
|
39 |
+
## ⏰ Update
|
40 |
+
- **2024.12.12**: Code and models for CCSR-v2 are released. 👀 Please refer to this [branch](https://github.com/csslc/CCSR/tree/CCSR-v2.0).
|
41 |
+
- **2024.9.25**: ⭐[CCSR-v2](https://arxiv.org/pdf/2401.00877) is released, offering reduced step requirements and supporting flexible diffusion step selection (2 or even 1 step) during the inference stage without the need for re-training.
|
42 |
+
- **2023.12.23**: Code and models for [CCSR-v1](https://arxiv.org/pdf/2401.00877v1) are released. Please refer to this [branch](https://github.com/csslc/CCSR/tree/CCSR-v1.0).
|
43 |
+
|
44 |
+
|
45 |
+
## 🌟 Overview Framework
|
46 |
+
![ccsr](figs/framework.png)
|
47 |
+
|
48 |
+
## 😍 Visual Results
|
49 |
+
### Demo on Real-world SR
|
50 |
+
|
51 |
+
[<img src="figs/compare_1.png" height="213px"/>](https://imgsli.com/MzI2MTg5) [<img src="figs/compare_2.png" height="213px"/>](https://imgsli.com/MzI2MTky/1/3) [<img src="figs/compare_3.png" height="213px"/>](https://imgsli.com/MzI2MTk0/0/2) [<img src="figs/compare_4.png" height="213px"/>](https://imgsli.com/MzI2MTk1/0/2)
|
52 |
+
|
53 |
+
|
54 |
+
![ccsr](figs/compare_standard.png)
|
55 |
+
|
56 |
+
![ccsr](figs/compare_efficient.png)
|
57 |
+
For more comparisons, please refer to our paper for details.
|
58 |
+
|
59 |
+
## 📝 Quantitative comparisons
|
60 |
+
We propose new stability metrics, namely global standard deviation (G-STD) and local standard deviation (L-STD), to respectively measure the image-level and pixel-level variations of the SR results of diffusion-based methods.
|
61 |
+
|
62 |
+
More details about G-STD and L-STD can be found in our paper.
|
63 |
+
|
64 |
+
![ccsr](figs/table.png)
|
65 |
+
## ⚙ Dependencies and Installation
|
66 |
+
```shell
|
67 |
+
## git clone this repository
|
68 |
+
git clone https://github.com/csslc/CCSR.git
|
69 |
+
cd CCSR
|
70 |
+
|
71 |
+
|
72 |
+
# create an environment with python >= 3.9
|
73 |
+
conda create -n ccsr python=3.9
|
74 |
+
conda activate ccsr
|
75 |
+
pip install -r requirements.txt
|
76 |
+
```
|
77 |
+
## 🍭 Quick Inference
|
78 |
+
**For ease of comparison, we have provided the test results of CCSR-v2 on the DIV2K, RealSR, and DrealSR benchmarks with varying diffusion steps, which can be accessed via [Google Drive](https://drive.google.com/drive/folders/1xjURQZgKAlENzMnAJA2PDG9h_UxfZzio?usp=sharing).**
|
79 |
+
|
80 |
+
#### Step 1: Download the pretrained models
|
81 |
+
- Download the pretrained SD-2.1-base models from [HuggingFace](https://huggingface.co/stabilityai/stable-diffusion-2-1-base).
|
82 |
+
- Download the CCSR-v2 models from and put the models in the `preset/models`:
|
83 |
+
|
84 |
+
| Model Name | Description | GoogleDrive | BaiduNetdisk |
|
85 |
+
|:-----------------------|:---------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-----------------------------------------------------------------------------------------------------------------------------|
|
86 |
+
| Controlnet | Trained in the stage 1. | [download](https://drive.google.com/drive/folders/1aHwgodKwKYZJBKs0QlFzanSjMDhrNyRA?usp=sharing) | [download](https://pan.baidu.com/s/1SKS70iE4GhhHGxqY1KS8mw) (pwd: ccsr) |
|
87 |
+
| VAE | Trained in the stage 2. | [download](https://drive.google.com/drive/folders/1yHfMV81Md6db4StHTP5MC-eSeLFeBKm8?usp=sharing) | [download](https://pan.baidu.com/s/1fxOIeL6Hk6Muq9h8itAIKQ) (pwd: ccsr) |
|
88 |
+
| Pre-trained Controlnet | The pre-trained model of stage1. | [download](https://drive.google.com/drive/folders/1LTtBRuObITOJwbW-sTDnHtp8xIUZFDHh?usp=sharing) | [download](https://pan.baidu.com/s/1mDeuHBqNj_Iol7PCY_Xfww) (pwd: ccsr) |
|
89 |
+
| Dino models | The pre-trained models for disc. | [download](https://drive.google.com/drive/folders/1PcuZGUTJlltdPz2yk2ZIa4GCtb1yk_y6?usp=sharing) | [download](https://pan.baidu.com/s/1nPdNwgua91mDDRApWUm39Q) (pwd: ccsr) |
|
90 |
+
|
91 |
+
#### Step 2: Prepare testing data
|
92 |
+
You can put the testing images in the `preset/test_datasets`.
|
93 |
+
|
94 |
+
#### Step 3: Running testing command
|
95 |
+
For one-step diffusion process:
|
96 |
+
```
|
97 |
+
python test_ccsr_tile.py \
|
98 |
+
--pretrained_model_path preset/models/stable-diffusion-2-1-base \
|
99 |
+
--controlnet_model_path preset/models \
|
100 |
+
--vae_model_path preset/models \
|
101 |
+
--baseline_name ccsr-v2 \
|
102 |
+
--image_path preset/test_datasets \
|
103 |
+
--output_dir experiments/test \
|
104 |
+
--sample_method ddpm \
|
105 |
+
--num_inference_steps 1 \
|
106 |
+
--t_min 0.0 \
|
107 |
+
--start_point lr \
|
108 |
+
--start_steps 999 \
|
109 |
+
--process_size 512 \
|
110 |
+
--guidance_scale 1.0 \
|
111 |
+
--sample_times 1 \
|
112 |
+
--use_vae_encode_condition \
|
113 |
+
--upscale 4
|
114 |
+
```
|
115 |
+
For multi-step diffusion process:
|
116 |
+
```
|
117 |
+
python test_ccsr_tile.py \
|
118 |
+
--pretrained_model_path preset/models/stable-diffusion-2-1-base \
|
119 |
+
--controlnet_model_path preset/models \
|
120 |
+
--vae_model_path preset/models \
|
121 |
+
--baseline_name ccsr-v2 \
|
122 |
+
--image_path preset/test_datasets \
|
123 |
+
--output_dir experiments/test \
|
124 |
+
--sample_method ddpm \
|
125 |
+
--num_inference_steps 6 \
|
126 |
+
--t_max 0.6667 \
|
127 |
+
--t_min 0.5 \
|
128 |
+
--start_point lr \
|
129 |
+
--start_steps 999 \
|
130 |
+
--process_size 512 \
|
131 |
+
--guidance_scale 4.5 \
|
132 |
+
--sample_times 1 \
|
133 |
+
--use_vae_encode_condition \
|
134 |
+
--upscale 4
|
135 |
+
```
|
136 |
+
We integrate [tile_diffusion](https://github.com/albarji/mixture-of-diffusers) and [tile_vae](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/tree/main) to the [test_ccsr_tile.py](test_ccsr_tile.py) to save the GPU memory for inference.
|
137 |
+
You can change the tile size and stride according to the VRAM of your device.
|
138 |
+
```
|
139 |
+
python test_ccsr_tile.py \
|
140 |
+
--pretrained_model_path preset/models/stable-diffusion-2-1-base \
|
141 |
+
--controlnet_model_path preset/models \
|
142 |
+
--vae_model_path preset/models \
|
143 |
+
--baseline_name ccsr-v2 \
|
144 |
+
--image_path preset/test_datasets \
|
145 |
+
--output_dir experiments/test \
|
146 |
+
--sample_method ddpm \
|
147 |
+
--num_inference_steps 6 \
|
148 |
+
--t_max 0.6667 \
|
149 |
+
--t_min 0.5 \
|
150 |
+
--start_point lr \
|
151 |
+
--start_steps 999 \
|
152 |
+
--process_size 512 \
|
153 |
+
--guidance_scale 4.5 \
|
154 |
+
--sample_times 1 \
|
155 |
+
--use_vae_encode_condition \
|
156 |
+
--upscale 4 \
|
157 |
+
--tile_diffusion \
|
158 |
+
--tile_diffusion_size 512 \
|
159 |
+
--tile_diffusion_stride 256 \
|
160 |
+
--tile_vae \
|
161 |
+
--vae_decoder_tile_size 224 \
|
162 |
+
--vae_encoder_tile_size 1024 \
|
163 |
+
```
|
164 |
+
|
165 |
+
You can obtain `N` different SR results by setting `sample_times` as `N` to test the stability of CCSR. The data folder should be like this:
|
166 |
+
|
167 |
+
```
|
168 |
+
experiments/test
|
169 |
+
├── sample00 # the first group of SR results
|
170 |
+
└── sample01 # the second group of SR results
|
171 |
+
...
|
172 |
+
└── sampleN # the N-th group of SR results
|
173 |
+
```
|
174 |
+
|
175 |
+
## 📏 Evaluation
|
176 |
+
1. Calculate the Image Quality Assessment for each restored group.
|
177 |
+
|
178 |
+
Fill in the required information in [cal_iqa.py](cal_iqa/cal_iqa.py) and run, then you can obtain the evaluation results in the folder like this:
|
179 |
+
```
|
180 |
+
log_path
|
181 |
+
├── log_name_npy # save the IQA values of each restored group as the npy files
|
182 |
+
└── log_name.log # log recode
|
183 |
+
```
|
184 |
+
|
185 |
+
2. Calculate the G-STD value for the diffusion-based SR method.
|
186 |
+
|
187 |
+
Fill in the required information in [iqa_G-STD.py](cal_iqa/iqa_G-STD.py) and run, then you can obtain the mean IQA values of N restored groups and G-STD value.
|
188 |
+
|
189 |
+
3. Calculate the L-STD value for the diffusion-based SR method.
|
190 |
+
|
191 |
+
Fill in the required information in [iqa_L-STD.py](cal_iqa/iqa_L-STD.py) and run, then you can obtain the L-STD value.
|
192 |
+
|
193 |
+
|
194 |
+
## 🚋 Train
|
195 |
+
|
196 |
+
#### Step1: Prepare training data
|
197 |
+
Generate txt file for the training set.
|
198 |
+
Fill in the required information in [get_path](scripts/get_path.py) and run, then you can obtain the txt file recording the paths of ground-truth images.
|
199 |
+
You can save the txt file into `preset/gt_path.txt`.
|
200 |
+
|
201 |
+
#### Step2: Train Stage1 Model
|
202 |
+
1. Download pretrained [Stable Diffusion v2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) to provide generative capabilities.
|
203 |
+
|
204 |
+
```shell
|
205 |
+
wget https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt --no-check-certificate
|
206 |
+
```
|
207 |
+
|
208 |
+
2. Start training.
|
209 |
+
|
210 |
+
```shell
|
211 |
+
CUDA_VISIBLE_DEVICES="0,1,2,3," accelerate launch train_ccsr_stage1.py \
|
212 |
+
--pretrained_model_name_or_path="preset/models/stable-diffusion-2-1-base" \
|
213 |
+
--controlnet_model_name_or_path='preset/models/pretrained_controlnet' \
|
214 |
+
--enable_xformers_memory_efficient_attention \
|
215 |
+
--output_dir="./experiments/ccsrv2_stage1" \
|
216 |
+
--mixed_precision="fp16" \
|
217 |
+
--resolution=512 \
|
218 |
+
--learning_rate=5e-5 \
|
219 |
+
--train_batch_size=4 \
|
220 |
+
--gradient_accumulation_steps=6 \
|
221 |
+
--dataloader_num_workers=0 \
|
222 |
+
--checkpointing_steps=500 \
|
223 |
+
--t_max=0.6667 \
|
224 |
+
--max_train_steps=20000 \
|
225 |
+
--dataset_root_folders 'preset/gt_path.txt'
|
226 |
+
```
|
227 |
+
|
228 |
+
#### Step3: Train Stage2 Model
|
229 |
+
1. Put the model obtained from the stage1 into `controlnet_model_name_or_path`.
|
230 |
+
2. Start training.
|
231 |
+
```shell
|
232 |
+
CUDA_VISIBLE_DEVICES="0,1,2,3," accelerate launch train_ccsr_stage2.py \
|
233 |
+
--pretrained_model_name_or_path="preset/models/stable-diffusion-2-1-base" \
|
234 |
+
--controlnet_model_name_or_path='preset/models/model_stage1' \
|
235 |
+
--enable_xformers_memory_efficient_attention \
|
236 |
+
--output_dir="./experiments/ccsrv2_stage2" \
|
237 |
+
--mixed_precision="fp16" \
|
238 |
+
--resolution=512 \
|
239 |
+
--learning_rate=5e-6 \
|
240 |
+
--train_batch_size=2 \
|
241 |
+
--gradient_accumulation_steps=8 \
|
242 |
+
--checkpointing_steps=500 \
|
243 |
+
--is_start_lr=True \
|
244 |
+
--t_max=0.6667 \
|
245 |
+
--num_inference_steps=1 \
|
246 |
+
--is_module \
|
247 |
+
--lambda_l2=1.0 \
|
248 |
+
--lambda_lpips=1.0 \
|
249 |
+
--lambda_disc=0.05 \
|
250 |
+
--lambda_disc_train=0.5 \
|
251 |
+
--begin_disc=100 \
|
252 |
+
--max_train_steps=2000 \
|
253 |
+
--dataset_root_folders 'preset/gt_path.txt'
|
254 |
+
```
|
255 |
+
|
256 |
+
|
257 |
+
|
258 |
+
|
259 |
+
|
260 |
+
### Citations
|
261 |
+
|
262 |
+
If our code helps your research or work, please consider citing our paper.
|
263 |
+
The following are BibTeX references:
|
264 |
+
|
265 |
+
```
|
266 |
+
@article{sun2023ccsr,
|
267 |
+
title={Improving the Stability of Diffusion Models for Content Consistent Super-Resolution},
|
268 |
+
author={Sun, Lingchen and Wu, Rongyuan and Zhang, Zhengqiang and Yong, Hongwei and Zhang, Lei},
|
269 |
+
journal={arXiv preprint arXiv:2401.00877},
|
270 |
+
year={2024}
|
271 |
+
}
|
272 |
+
```
|
273 |
+
|
274 |
+
### License
|
275 |
+
This project is released under the [Apache 2.0 license](LICENSE).
|
276 |
+
|
277 |
+
### Acknowledgement
|
278 |
+
This project is based on [ControlNet](https://github.com/lllyasviel/ControlNet), [BasicSR](https://github.com/XPixelGroup/BasicSR) and [SeeSR](https://github.com/cswry/SeeSR). Some codes are brought from [ADDSR](https://github.com/NJU-PCALab/AddSR). Thanks for their awesome works.
|
279 |
+
|
280 |
+
### Contact
|
281 |
+
If you have any questions, please contact: [email protected]
|
282 |
+
|
283 |
+
|
284 |
+
<details>
|
285 |
+
<summary>statistics</summary>
|
286 |
+
|
287 |
+
![visitors](https://visitor-badge.laobi.icu/badge?page_id=csslc/CCSR)
|
288 |
+
|
289 |
+
</details>
|
290 |
+
|
291 |
+
|
dataloaders/paired_dataset_txt.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import os
|
3 |
+
from PIL import Image
|
4 |
+
import random
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from torch import nn
|
8 |
+
from torchvision import transforms
|
9 |
+
from torch.utils import data as data
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
from .realesrgan import RealESRGAN_degradation
|
13 |
+
|
14 |
+
class PairedCaptionDataset(data.Dataset):
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
root_folders=None,
|
18 |
+
tokenizer=None,
|
19 |
+
gt_ratio=0, # let lr is gt
|
20 |
+
):
|
21 |
+
super(PairedCaptionDataset, self).__init__()
|
22 |
+
|
23 |
+
self.gt_ratio = gt_ratio
|
24 |
+
with open(root_folders, 'r') as f:
|
25 |
+
self.gt_list = [line.strip() for line in f.readlines()]
|
26 |
+
|
27 |
+
self.img_preproc = transforms.Compose([
|
28 |
+
transforms.RandomCrop((512, 512)),
|
29 |
+
transforms.Resize((512, 512)),
|
30 |
+
transforms.RandomHorizontalFlip(),
|
31 |
+
])
|
32 |
+
|
33 |
+
self.degradation = RealESRGAN_degradation('dataloaders/params_ccsr.yml', device='cuda')
|
34 |
+
self.tokenizer = tokenizer
|
35 |
+
|
36 |
+
|
37 |
+
def tokenize_caption(self, caption=""):
|
38 |
+
inputs = self.tokenizer(
|
39 |
+
caption, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
|
40 |
+
)
|
41 |
+
|
42 |
+
return inputs.input_ids
|
43 |
+
|
44 |
+
def __getitem__(self, index):
|
45 |
+
|
46 |
+
gt_path = self.gt_list[index]
|
47 |
+
gt_img = Image.open(gt_path).convert('RGB')
|
48 |
+
gt_img = self.img_preproc(gt_img)
|
49 |
+
|
50 |
+
gt_img, img_t = self.degradation.degrade_process(np.asarray(gt_img)/255., resize_bak=True)
|
51 |
+
|
52 |
+
if random.random() < self.gt_ratio:
|
53 |
+
lq_img = gt_img
|
54 |
+
else:
|
55 |
+
lq_img = img_t
|
56 |
+
|
57 |
+
# no caption used
|
58 |
+
lq_caption = ''
|
59 |
+
|
60 |
+
example = dict()
|
61 |
+
example["conditioning_pixel_values"] = lq_img.squeeze(0) # [0, 1]
|
62 |
+
example["pixel_values"] = gt_img.squeeze(0) * 2.0 - 1.0 # [-1, 1]
|
63 |
+
example["input_caption"] = self.tokenize_caption(caption=lq_caption).squeeze(0)
|
64 |
+
|
65 |
+
lq_img = lq_img.squeeze()
|
66 |
+
|
67 |
+
return example
|
68 |
+
|
69 |
+
def __len__(self):
|
70 |
+
return len(self.gt_list)
|
dataloaders/params_ccsr.yml
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
scale: 4
|
2 |
+
color_jitter_prob: 0.0
|
3 |
+
gray_prob: 0.0
|
4 |
+
|
5 |
+
# the first degradation process
|
6 |
+
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
|
7 |
+
resize_range: [0.3, 1.5]
|
8 |
+
gaussian_noise_prob: 0.5
|
9 |
+
noise_range: [1, 15]
|
10 |
+
poisson_scale_range: [0.05, 2.0]
|
11 |
+
gray_noise_prob: 0.4
|
12 |
+
jpeg_range: [60, 95]
|
13 |
+
|
14 |
+
|
15 |
+
# the second degradation process
|
16 |
+
second_blur_prob: 0.5
|
17 |
+
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
|
18 |
+
resize_range2: [0.6, 1.2]
|
19 |
+
gaussian_noise_prob2: 0.5
|
20 |
+
noise_range2: [1, 12]
|
21 |
+
poisson_scale_range2: [0.05, 1.0]
|
22 |
+
gray_noise_prob2: 0.4
|
23 |
+
jpeg_range2: [60, 100]
|
24 |
+
|
25 |
+
kernel_info:
|
26 |
+
blur_kernel_size: 21
|
27 |
+
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
28 |
+
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
29 |
+
sinc_prob: 0.1
|
30 |
+
blur_sigma: [0.2, 1.5]
|
31 |
+
betag_range: [0.5, 2.0]
|
32 |
+
betap_range: [1, 1.5]
|
33 |
+
|
34 |
+
blur_kernel_size2: 11
|
35 |
+
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
36 |
+
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
37 |
+
sinc_prob2: 0.1
|
38 |
+
blur_sigma2: [0.2, 1.0]
|
39 |
+
betag_range2: [0.5, 2.0]
|
40 |
+
betap_range2: [1, 1.5]
|
41 |
+
|
42 |
+
final_sinc_prob: 0.8
|
dataloaders/realesrgan.py
ADDED
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
import glob
|
5 |
+
import math
|
6 |
+
import yaml
|
7 |
+
import random
|
8 |
+
from collections import OrderedDict
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
from basicsr.data.transforms import augment
|
13 |
+
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
|
14 |
+
from basicsr.utils import DiffJPEG, USMSharp, img2tensor, tensor2img
|
15 |
+
from basicsr.utils.img_process_util import filter2D
|
16 |
+
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
|
17 |
+
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
|
18 |
+
normalize, rgb_to_grayscale)
|
19 |
+
|
20 |
+
cur_path = os.path.dirname(os.path.abspath(__file__))
|
21 |
+
|
22 |
+
|
23 |
+
def ordered_yaml():
|
24 |
+
"""Support OrderedDict for yaml.
|
25 |
+
|
26 |
+
Returns:
|
27 |
+
yaml Loader and Dumper.
|
28 |
+
"""
|
29 |
+
try:
|
30 |
+
from yaml import CDumper as Dumper
|
31 |
+
from yaml import CLoader as Loader
|
32 |
+
except ImportError:
|
33 |
+
from yaml import Dumper, Loader
|
34 |
+
|
35 |
+
_mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
|
36 |
+
|
37 |
+
def dict_representer(dumper, data):
|
38 |
+
return dumper.represent_dict(data.items())
|
39 |
+
|
40 |
+
def dict_constructor(loader, node):
|
41 |
+
return OrderedDict(loader.construct_pairs(node))
|
42 |
+
|
43 |
+
Dumper.add_representer(OrderedDict, dict_representer)
|
44 |
+
Loader.add_constructor(_mapping_tag, dict_constructor)
|
45 |
+
return Loader, Dumper
|
46 |
+
|
47 |
+
def opt_parse(opt_path):
|
48 |
+
with open(opt_path, mode='r') as f:
|
49 |
+
Loader, _ = ordered_yaml()
|
50 |
+
opt = yaml.load(f, Loader=Loader)
|
51 |
+
|
52 |
+
return opt
|
53 |
+
|
54 |
+
class RealESRGAN_degradation(object):
|
55 |
+
def __init__(self, opt_path='', device='cpu'):
|
56 |
+
self.opt = opt_parse(opt_path)
|
57 |
+
self.device = device #torch.device('cpu')
|
58 |
+
optk = self.opt['kernel_info']
|
59 |
+
|
60 |
+
# blur settings for the first degradation
|
61 |
+
self.blur_kernel_size = optk['blur_kernel_size']
|
62 |
+
self.kernel_list = optk['kernel_list']
|
63 |
+
self.kernel_prob = optk['kernel_prob']
|
64 |
+
self.blur_sigma = optk['blur_sigma']
|
65 |
+
self.betag_range = optk['betag_range']
|
66 |
+
self.betap_range = optk['betap_range']
|
67 |
+
self.sinc_prob = optk['sinc_prob']
|
68 |
+
|
69 |
+
# blur settings for the second degradation
|
70 |
+
self.blur_kernel_size2 = optk['blur_kernel_size2']
|
71 |
+
self.kernel_list2 = optk['kernel_list2']
|
72 |
+
self.kernel_prob2 = optk['kernel_prob2']
|
73 |
+
self.blur_sigma2 = optk['blur_sigma2']
|
74 |
+
self.betag_range2 = optk['betag_range2']
|
75 |
+
self.betap_range2 = optk['betap_range2']
|
76 |
+
self.sinc_prob2 = optk['sinc_prob2']
|
77 |
+
|
78 |
+
# a final sinc filter
|
79 |
+
self.final_sinc_prob = optk['final_sinc_prob']
|
80 |
+
|
81 |
+
self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
|
82 |
+
self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
|
83 |
+
self.pulse_tensor[10, 10] = 1
|
84 |
+
|
85 |
+
self.jpeger = DiffJPEG(differentiable=False).to(self.device)
|
86 |
+
self.usm_shaper = USMSharp().to(self.device)
|
87 |
+
|
88 |
+
def color_jitter_pt(self, img, brightness, contrast, saturation, hue):
|
89 |
+
fn_idx = torch.randperm(4)
|
90 |
+
for fn_id in fn_idx:
|
91 |
+
if fn_id == 0 and brightness is not None:
|
92 |
+
brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
|
93 |
+
img = adjust_brightness(img, brightness_factor)
|
94 |
+
|
95 |
+
if fn_id == 1 and contrast is not None:
|
96 |
+
contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
|
97 |
+
img = adjust_contrast(img, contrast_factor)
|
98 |
+
|
99 |
+
if fn_id == 2 and saturation is not None:
|
100 |
+
saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
|
101 |
+
img = adjust_saturation(img, saturation_factor)
|
102 |
+
|
103 |
+
if fn_id == 3 and hue is not None:
|
104 |
+
hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
|
105 |
+
img = adjust_hue(img, hue_factor)
|
106 |
+
return img
|
107 |
+
|
108 |
+
def random_augment(self, img_gt):
|
109 |
+
# random horizontal flip
|
110 |
+
img_gt, status = augment(img_gt, hflip=True, rotation=False, return_status=True)
|
111 |
+
"""
|
112 |
+
# random color jitter
|
113 |
+
if np.random.uniform() < self.opt['color_jitter_prob']:
|
114 |
+
jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
|
115 |
+
img_gt = img_gt + jitter_val
|
116 |
+
img_gt = np.clip(img_gt, 0, 1)
|
117 |
+
|
118 |
+
# random grayscale
|
119 |
+
if np.random.uniform() < self.opt['gray_prob']:
|
120 |
+
#img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
|
121 |
+
img_gt = cv2.cvtColor(img_gt, cv2.COLOR_RGB2GRAY)
|
122 |
+
img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])
|
123 |
+
"""
|
124 |
+
# BGR to RGB, HWC to CHW, numpy to tensor
|
125 |
+
img_gt = img2tensor([img_gt], bgr2rgb=False, float32=True)[0].unsqueeze(0)
|
126 |
+
|
127 |
+
return img_gt
|
128 |
+
|
129 |
+
def random_kernels(self):
|
130 |
+
# ------------------------ Generate kernels (used in the first degradation) ------------------------ #
|
131 |
+
kernel_size = random.choice(self.kernel_range)
|
132 |
+
if np.random.uniform() < self.sinc_prob:
|
133 |
+
# this sinc filter setting is for kernels ranging from [7, 21]
|
134 |
+
if kernel_size < 13:
|
135 |
+
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
136 |
+
else:
|
137 |
+
omega_c = np.random.uniform(np.pi / 5, np.pi)
|
138 |
+
kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
|
139 |
+
else:
|
140 |
+
kernel = random_mixed_kernels(
|
141 |
+
self.kernel_list,
|
142 |
+
self.kernel_prob,
|
143 |
+
kernel_size,
|
144 |
+
self.blur_sigma,
|
145 |
+
self.blur_sigma, [-math.pi, math.pi],
|
146 |
+
self.betag_range,
|
147 |
+
self.betap_range,
|
148 |
+
noise_range=None)
|
149 |
+
# pad kernel
|
150 |
+
pad_size = (21 - kernel_size) // 2
|
151 |
+
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
|
152 |
+
|
153 |
+
# ------------------------ Generate kernels (used in the second degradation) ------------------------ #
|
154 |
+
kernel_size = random.choice(self.kernel_range)
|
155 |
+
if np.random.uniform() < self.sinc_prob2:
|
156 |
+
if kernel_size < 13:
|
157 |
+
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
158 |
+
else:
|
159 |
+
omega_c = np.random.uniform(np.pi / 5, np.pi)
|
160 |
+
kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
|
161 |
+
else:
|
162 |
+
kernel2 = random_mixed_kernels(
|
163 |
+
self.kernel_list2,
|
164 |
+
self.kernel_prob2,
|
165 |
+
kernel_size,
|
166 |
+
self.blur_sigma2,
|
167 |
+
self.blur_sigma2, [-math.pi, math.pi],
|
168 |
+
self.betag_range2,
|
169 |
+
self.betap_range2,
|
170 |
+
noise_range=None)
|
171 |
+
|
172 |
+
# pad kernel
|
173 |
+
pad_size = (21 - kernel_size) // 2
|
174 |
+
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
|
175 |
+
|
176 |
+
# ------------------------------------- sinc kernel ------------------------------------- #
|
177 |
+
if np.random.uniform() < self.final_sinc_prob:
|
178 |
+
kernel_size = random.choice(self.kernel_range)
|
179 |
+
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
180 |
+
sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
|
181 |
+
sinc_kernel = torch.FloatTensor(sinc_kernel)
|
182 |
+
else:
|
183 |
+
sinc_kernel = self.pulse_tensor
|
184 |
+
|
185 |
+
kernel = torch.FloatTensor(kernel)
|
186 |
+
kernel2 = torch.FloatTensor(kernel2)
|
187 |
+
|
188 |
+
return kernel, kernel2, sinc_kernel
|
189 |
+
|
190 |
+
@torch.no_grad()
|
191 |
+
def degrade_process(self, img_gt, resize_bak=False):
|
192 |
+
img_gt = self.random_augment(img_gt)
|
193 |
+
kernel1, kernel2, sinc_kernel = self.random_kernels()
|
194 |
+
img_gt, kernel1, kernel2, sinc_kernel = img_gt.to(self.device), kernel1.to(self.device), kernel2.to(self.device), sinc_kernel.to(self.device)
|
195 |
+
#img_gt = self.usm_shaper(img_gt) # shaper gt
|
196 |
+
ori_h, ori_w = img_gt.size()[2:4]
|
197 |
+
|
198 |
+
#scale_final = random.randint(4, 16)
|
199 |
+
scale_final = 4
|
200 |
+
|
201 |
+
# ----------------------- The first degradation process ----------------------- #
|
202 |
+
# blur
|
203 |
+
out = filter2D(img_gt, kernel1)
|
204 |
+
# random resize
|
205 |
+
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
|
206 |
+
if updown_type == 'up':
|
207 |
+
scale = np.random.uniform(1, self.opt['resize_range'][1])
|
208 |
+
elif updown_type == 'down':
|
209 |
+
scale = np.random.uniform(self.opt['resize_range'][0], 1)
|
210 |
+
else:
|
211 |
+
scale = 1
|
212 |
+
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
213 |
+
out = F.interpolate(out, scale_factor=scale, mode=mode)
|
214 |
+
# noise
|
215 |
+
gray_noise_prob = self.opt['gray_noise_prob']
|
216 |
+
if np.random.uniform() < self.opt['gaussian_noise_prob']:
|
217 |
+
out = random_add_gaussian_noise_pt(
|
218 |
+
out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
|
219 |
+
else:
|
220 |
+
out = random_add_poisson_noise_pt(
|
221 |
+
out,
|
222 |
+
scale_range=self.opt['poisson_scale_range'],
|
223 |
+
gray_prob=gray_noise_prob,
|
224 |
+
clip=True,
|
225 |
+
rounds=False)
|
226 |
+
# JPEG compression
|
227 |
+
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
|
228 |
+
out = torch.clamp(out, 0, 1)
|
229 |
+
out = self.jpeger(out, quality=jpeg_p)
|
230 |
+
|
231 |
+
# ----------------------- The second degradation process ----------------------- #
|
232 |
+
# blur
|
233 |
+
if np.random.uniform() < self.opt['second_blur_prob']:
|
234 |
+
out = filter2D(out, kernel2)
|
235 |
+
# random resize
|
236 |
+
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
|
237 |
+
if updown_type == 'up':
|
238 |
+
scale = np.random.uniform(1, self.opt['resize_range2'][1])
|
239 |
+
elif updown_type == 'down':
|
240 |
+
scale = np.random.uniform(self.opt['resize_range2'][0], 1)
|
241 |
+
else:
|
242 |
+
scale = 1
|
243 |
+
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
244 |
+
out = F.interpolate(
|
245 |
+
out, size=(int(ori_h / scale_final * scale), int(ori_w / scale_final * scale)), mode=mode)
|
246 |
+
# noise
|
247 |
+
gray_noise_prob = self.opt['gray_noise_prob2']
|
248 |
+
if np.random.uniform() < self.opt['gaussian_noise_prob2']:
|
249 |
+
out = random_add_gaussian_noise_pt(
|
250 |
+
out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
|
251 |
+
else:
|
252 |
+
out = random_add_poisson_noise_pt(
|
253 |
+
out,
|
254 |
+
scale_range=self.opt['poisson_scale_range2'],
|
255 |
+
gray_prob=gray_noise_prob,
|
256 |
+
clip=True,
|
257 |
+
rounds=False)
|
258 |
+
|
259 |
+
# JPEG compression + the final sinc filter
|
260 |
+
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
|
261 |
+
# as one operation.
|
262 |
+
# We consider two orders:
|
263 |
+
# 1. [resize back + sinc filter] + JPEG compression
|
264 |
+
# 2. JPEG compression + [resize back + sinc filter]
|
265 |
+
# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
|
266 |
+
if np.random.uniform() < 0.5:
|
267 |
+
# resize back + the final sinc filter
|
268 |
+
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
269 |
+
out = F.interpolate(out, size=(ori_h // scale_final, ori_w // scale_final), mode=mode)
|
270 |
+
out = filter2D(out, sinc_kernel)
|
271 |
+
# JPEG compression
|
272 |
+
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
|
273 |
+
out = torch.clamp(out, 0, 1)
|
274 |
+
out = self.jpeger(out, quality=jpeg_p)
|
275 |
+
else:
|
276 |
+
# JPEG compression
|
277 |
+
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
|
278 |
+
out = torch.clamp(out, 0, 1)
|
279 |
+
out = self.jpeger(out, quality=jpeg_p)
|
280 |
+
# resize back + the final sinc filter
|
281 |
+
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
282 |
+
out = F.interpolate(out, size=(ori_h // scale_final, ori_w // scale_final), mode=mode)
|
283 |
+
out = filter2D(out, sinc_kernel)
|
284 |
+
|
285 |
+
if np.random.uniform() < self.opt['gray_prob']:
|
286 |
+
out = rgb_to_grayscale(out, num_output_channels=1)
|
287 |
+
|
288 |
+
if np.random.uniform() < self.opt['color_jitter_prob']:
|
289 |
+
brightness = self.opt.get('brightness', (0.5, 1.5))
|
290 |
+
contrast = self.opt.get('contrast', (0.5, 1.5))
|
291 |
+
saturation = self.opt.get('saturation', (0, 1.5))
|
292 |
+
hue = self.opt.get('hue', (-0.1, 0.1))
|
293 |
+
out = self.color_jitter_pt(out, brightness, contrast, saturation, hue)
|
294 |
+
|
295 |
+
if resize_bak:
|
296 |
+
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
297 |
+
out = F.interpolate(out, size=(ori_h, ori_w), mode=mode)
|
298 |
+
# clamp and round
|
299 |
+
img_lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
|
300 |
+
|
301 |
+
return img_gt, img_lq
|
302 |
+
|
303 |
+
|
models/DiffAugment.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# BSD 2-Clause "Simplified" License
|
2 |
+
# Copyright (c) 2020, Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
|
3 |
+
# All rights reserved.
|
4 |
+
#
|
5 |
+
# Redistribution and use in source and binary forms, with or without
|
6 |
+
# modification, are permitted provided that the following conditions are met:
|
7 |
+
#
|
8 |
+
# * Redistributions of source code must retain the above copyright notice, this
|
9 |
+
# list of conditions and the following disclaimer.
|
10 |
+
#
|
11 |
+
# * Redistributions in binary form must reproduce the above copyright notice,
|
12 |
+
# this list of conditions and the following disclaimer in the documentation
|
13 |
+
# and/or other materials provided with the distribution.
|
14 |
+
#
|
15 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
16 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
17 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
18 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
19 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
20 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
21 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
22 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
23 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
24 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
25 |
+
#
|
26 |
+
# Code from https://github.com/mit-han-lab/data-efficient-gans
|
27 |
+
|
28 |
+
"""Training GANs with DiffAugment."""
|
29 |
+
|
30 |
+
import numpy as np
|
31 |
+
import torch
|
32 |
+
import torch.nn.functional as F
|
33 |
+
|
34 |
+
|
35 |
+
def DiffAugment(x: torch.Tensor, policy: str = '', channels_first: bool = True) -> torch.Tensor:
|
36 |
+
if policy:
|
37 |
+
if not channels_first:
|
38 |
+
x = x.permute(0, 3, 1, 2)
|
39 |
+
for p in policy.split(','):
|
40 |
+
for f in AUGMENT_FNS[p]:
|
41 |
+
x = f(x)
|
42 |
+
if not channels_first:
|
43 |
+
x = x.permute(0, 2, 3, 1)
|
44 |
+
x = x.contiguous()
|
45 |
+
return x
|
46 |
+
|
47 |
+
|
48 |
+
def rand_brightness(x: torch.Tensor) -> torch.Tensor:
|
49 |
+
x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
|
50 |
+
return x
|
51 |
+
|
52 |
+
|
53 |
+
def rand_saturation(x: torch.Tensor) -> torch.Tensor:
|
54 |
+
x_mean = x.mean(dim=1, keepdim=True)
|
55 |
+
x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
|
56 |
+
return x
|
57 |
+
|
58 |
+
|
59 |
+
def rand_contrast(x: torch.Tensor) -> torch.Tensor:
|
60 |
+
x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
|
61 |
+
x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
|
62 |
+
return x
|
63 |
+
|
64 |
+
|
65 |
+
def rand_translation(x: torch.Tensor, ratio: float = 0.125) -> torch.Tensor:
|
66 |
+
shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
|
67 |
+
translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
|
68 |
+
translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
|
69 |
+
grid_batch, grid_x, grid_y = torch.meshgrid(
|
70 |
+
torch.arange(x.size(0), dtype=torch.long, device=x.device),
|
71 |
+
torch.arange(x.size(2), dtype=torch.long, device=x.device),
|
72 |
+
torch.arange(x.size(3), dtype=torch.long, device=x.device),
|
73 |
+
)
|
74 |
+
grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
|
75 |
+
grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
|
76 |
+
x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
|
77 |
+
x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
|
78 |
+
return x
|
79 |
+
|
80 |
+
|
81 |
+
def rand_cutout(x: torch.Tensor, ratio: float = 0.2) -> torch.Tensor:
|
82 |
+
cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
|
83 |
+
offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
|
84 |
+
offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
|
85 |
+
grid_batch, grid_x, grid_y = torch.meshgrid(
|
86 |
+
torch.arange(x.size(0), dtype=torch.long, device=x.device),
|
87 |
+
torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
|
88 |
+
torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
|
89 |
+
)
|
90 |
+
grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
|
91 |
+
grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
|
92 |
+
mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
|
93 |
+
mask[grid_batch, grid_x, grid_y] = 0
|
94 |
+
x = x * mask.unsqueeze(1)
|
95 |
+
return x
|
96 |
+
|
97 |
+
|
98 |
+
def rand_resize(x: torch.Tensor, min_ratio: float = 0.8, max_ratio: float = 1.2) -> torch.Tensor:
|
99 |
+
resize_ratio = np.random.rand()*(max_ratio-min_ratio) + min_ratio
|
100 |
+
resized_img = F.interpolate(x, size=int(resize_ratio*x.shape[3]), mode='bilinear')
|
101 |
+
org_size = x.shape[3]
|
102 |
+
if int(resize_ratio*x.shape[3]) < x.shape[3]:
|
103 |
+
left_pad = (x.shape[3]-int(resize_ratio*x.shape[3]))/2.
|
104 |
+
left_pad = int(left_pad)
|
105 |
+
right_pad = x.shape[3] - left_pad - resized_img.shape[3]
|
106 |
+
x = F.pad(resized_img, (left_pad, right_pad, left_pad, right_pad), "constant", 0.)
|
107 |
+
else:
|
108 |
+
left = (int(resize_ratio*x.shape[3])-x.shape[3])/2.
|
109 |
+
left = int(left)
|
110 |
+
x = resized_img[:, :, left:(left+x.shape[3]), left:(left+x.shape[3])]
|
111 |
+
assert x.shape[2] == org_size
|
112 |
+
assert x.shape[3] == org_size
|
113 |
+
return x
|
114 |
+
|
115 |
+
|
116 |
+
AUGMENT_FNS = {
|
117 |
+
'color': [rand_brightness, rand_saturation, rand_contrast],
|
118 |
+
'translation': [rand_translation],
|
119 |
+
'resize': [rand_resize],
|
120 |
+
'cutout': [rand_cutout],
|
121 |
+
}
|
models/controlnet.py
ADDED
@@ -0,0 +1,850 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import nn
|
19 |
+
from torch.nn import functional as F
|
20 |
+
|
21 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
22 |
+
from diffusers.loaders import FromOriginalControlnetMixin
|
23 |
+
from diffusers.utils import BaseOutput, logging
|
24 |
+
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
|
25 |
+
from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
|
26 |
+
from diffusers.models.modeling_utils import ModelMixin
|
27 |
+
from .unet_2d_blocks import (
|
28 |
+
CrossAttnDownBlock2D,
|
29 |
+
DownBlock2D,
|
30 |
+
UNetMidBlock2DCrossAttn,
|
31 |
+
get_down_block,
|
32 |
+
)
|
33 |
+
from .unet_2d_condition import UNet2DConditionModel
|
34 |
+
|
35 |
+
|
36 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
37 |
+
|
38 |
+
|
39 |
+
@dataclass
|
40 |
+
class ControlNetOutput(BaseOutput):
|
41 |
+
"""
|
42 |
+
The output of [`ControlNetModel`].
|
43 |
+
|
44 |
+
Args:
|
45 |
+
down_block_res_samples (`tuple[torch.Tensor]`):
|
46 |
+
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
|
47 |
+
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
48 |
+
used to condition the original UNet's downsampling activations.
|
49 |
+
mid_down_block_re_sample (`torch.Tensor`):
|
50 |
+
The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
|
51 |
+
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
|
52 |
+
Output can be used to condition the original UNet's middle block activation.
|
53 |
+
"""
|
54 |
+
|
55 |
+
down_block_res_samples: Tuple[torch.Tensor]
|
56 |
+
mid_block_res_sample: torch.Tensor
|
57 |
+
|
58 |
+
|
59 |
+
class ControlNetConditioningEmbedding(nn.Module):
|
60 |
+
"""
|
61 |
+
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
|
62 |
+
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
|
63 |
+
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
|
64 |
+
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
|
65 |
+
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
|
66 |
+
model) to encode image-space conditions ... into feature maps ..."
|
67 |
+
"""
|
68 |
+
|
69 |
+
def __init__(
|
70 |
+
self,
|
71 |
+
conditioning_embedding_channels: int,
|
72 |
+
conditioning_channels: int = 3,
|
73 |
+
block_out_channels: Tuple[int] = (16, 32, 96, 256),
|
74 |
+
):
|
75 |
+
super().__init__()
|
76 |
+
|
77 |
+
self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
|
78 |
+
|
79 |
+
self.blocks = nn.ModuleList([])
|
80 |
+
|
81 |
+
for i in range(len(block_out_channels) - 1):
|
82 |
+
channel_in = block_out_channels[i]
|
83 |
+
channel_out = block_out_channels[i + 1]
|
84 |
+
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
|
85 |
+
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
|
86 |
+
|
87 |
+
self.conv_out = zero_module(
|
88 |
+
nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
|
89 |
+
)
|
90 |
+
|
91 |
+
def forward(self, conditioning):
|
92 |
+
embedding = self.conv_in(conditioning)
|
93 |
+
embedding = F.silu(embedding)
|
94 |
+
|
95 |
+
for block in self.blocks:
|
96 |
+
embedding = block(embedding)
|
97 |
+
embedding = F.silu(embedding)
|
98 |
+
|
99 |
+
embedding = self.conv_out(embedding)
|
100 |
+
|
101 |
+
return embedding
|
102 |
+
|
103 |
+
|
104 |
+
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
105 |
+
"""
|
106 |
+
A ControlNet model.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
in_channels (`int`, defaults to 4):
|
110 |
+
The number of channels in the input sample.
|
111 |
+
flip_sin_to_cos (`bool`, defaults to `True`):
|
112 |
+
Whether to flip the sin to cos in the time embedding.
|
113 |
+
freq_shift (`int`, defaults to 0):
|
114 |
+
The frequency shift to apply to the time embedding.
|
115 |
+
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
116 |
+
The tuple of downsample blocks to use.
|
117 |
+
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
118 |
+
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
119 |
+
The tuple of output channels for each block.
|
120 |
+
layers_per_block (`int`, defaults to 2):
|
121 |
+
The number of layers per block.
|
122 |
+
downsample_padding (`int`, defaults to 1):
|
123 |
+
The padding to use for the downsampling convolution.
|
124 |
+
mid_block_scale_factor (`float`, defaults to 1):
|
125 |
+
The scale factor to use for the mid block.
|
126 |
+
act_fn (`str`, defaults to "silu"):
|
127 |
+
The activation function to use.
|
128 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
129 |
+
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
130 |
+
in post-processing.
|
131 |
+
norm_eps (`float`, defaults to 1e-5):
|
132 |
+
The epsilon to use for the normalization.
|
133 |
+
cross_attention_dim (`int`, defaults to 1280):
|
134 |
+
The dimension of the cross attention features.
|
135 |
+
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
136 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
137 |
+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
138 |
+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
139 |
+
encoder_hid_dim (`int`, *optional*, defaults to None):
|
140 |
+
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
141 |
+
dimension to `cross_attention_dim`.
|
142 |
+
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
143 |
+
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
144 |
+
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
145 |
+
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
146 |
+
The dimension of the attention heads.
|
147 |
+
use_linear_projection (`bool`, defaults to `False`):
|
148 |
+
class_embed_type (`str`, *optional*, defaults to `None`):
|
149 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
150 |
+
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
151 |
+
addition_embed_type (`str`, *optional*, defaults to `None`):
|
152 |
+
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
153 |
+
"text". "text" will use the `TextTimeEmbedding` layer.
|
154 |
+
num_class_embeds (`int`, *optional*, defaults to 0):
|
155 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
156 |
+
class conditioning with `class_embed_type` equal to `None`.
|
157 |
+
upcast_attention (`bool`, defaults to `False`):
|
158 |
+
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
159 |
+
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
160 |
+
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
161 |
+
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
162 |
+
`class_embed_type="projection"`.
|
163 |
+
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
164 |
+
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
165 |
+
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
166 |
+
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
167 |
+
global_pool_conditions (`bool`, defaults to `False`):
|
168 |
+
"""
|
169 |
+
|
170 |
+
_supports_gradient_checkpointing = True
|
171 |
+
|
172 |
+
@register_to_config
|
173 |
+
def __init__(
|
174 |
+
self,
|
175 |
+
in_channels: int = 4,
|
176 |
+
conditioning_channels: int = 3,
|
177 |
+
flip_sin_to_cos: bool = True,
|
178 |
+
freq_shift: int = 0,
|
179 |
+
down_block_types: Tuple[str] = (
|
180 |
+
"CrossAttnDownBlock2D",
|
181 |
+
"CrossAttnDownBlock2D",
|
182 |
+
"CrossAttnDownBlock2D",
|
183 |
+
"DownBlock2D",
|
184 |
+
),
|
185 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
186 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
187 |
+
layers_per_block: int = 2,
|
188 |
+
downsample_padding: int = 1,
|
189 |
+
mid_block_scale_factor: float = 1,
|
190 |
+
act_fn: str = "silu",
|
191 |
+
norm_num_groups: Optional[int] = 32,
|
192 |
+
norm_eps: float = 1e-5,
|
193 |
+
cross_attention_dim: int = 1280,
|
194 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
195 |
+
encoder_hid_dim: Optional[int] = None,
|
196 |
+
encoder_hid_dim_type: Optional[str] = None,
|
197 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
198 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
199 |
+
use_linear_projection: bool = False,
|
200 |
+
class_embed_type: Optional[str] = None,
|
201 |
+
addition_embed_type: Optional[str] = None,
|
202 |
+
addition_time_embed_dim: Optional[int] = None,
|
203 |
+
num_class_embeds: Optional[int] = None,
|
204 |
+
upcast_attention: bool = False,
|
205 |
+
resnet_time_scale_shift: str = "default",
|
206 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
207 |
+
controlnet_conditioning_channel_order: str = "rgb",
|
208 |
+
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
209 |
+
global_pool_conditions: bool = False,
|
210 |
+
addition_embed_type_num_heads=64,
|
211 |
+
use_vae_encode_condition=False,
|
212 |
+
):
|
213 |
+
super().__init__()
|
214 |
+
|
215 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
216 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
217 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
218 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
219 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
220 |
+
# which is why we correct for the naming here.
|
221 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
222 |
+
|
223 |
+
# Check inputs
|
224 |
+
if len(block_out_channels) != len(down_block_types):
|
225 |
+
raise ValueError(
|
226 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
227 |
+
)
|
228 |
+
|
229 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
230 |
+
raise ValueError(
|
231 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
232 |
+
)
|
233 |
+
|
234 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
235 |
+
raise ValueError(
|
236 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
237 |
+
)
|
238 |
+
|
239 |
+
if isinstance(transformer_layers_per_block, int):
|
240 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
241 |
+
|
242 |
+
# input
|
243 |
+
conv_in_kernel = 3
|
244 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
245 |
+
self.conv_in = nn.Conv2d(
|
246 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
247 |
+
)
|
248 |
+
|
249 |
+
# use_vae_encode_condition
|
250 |
+
self.use_vae_encode_condition = use_vae_encode_condition
|
251 |
+
if self.use_vae_encode_condition:
|
252 |
+
print(f'============================')
|
253 |
+
print(f'use vae encode condition in CONTROLNET!!!')
|
254 |
+
print(f'============================')
|
255 |
+
self.condition_conv_in = nn.Conv2d(
|
256 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
257 |
+
)
|
258 |
+
else:
|
259 |
+
print(f'============================')
|
260 |
+
print(f'Not !!! use vae encode condition in CONTROLNET')
|
261 |
+
print(f'============================')
|
262 |
+
# control net conditioning embedding
|
263 |
+
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
|
264 |
+
conditioning_embedding_channels=block_out_channels[0],
|
265 |
+
block_out_channels=conditioning_embedding_out_channels,
|
266 |
+
conditioning_channels=conditioning_channels,
|
267 |
+
)
|
268 |
+
|
269 |
+
# time
|
270 |
+
time_embed_dim = block_out_channels[0] * 4
|
271 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
272 |
+
timestep_input_dim = block_out_channels[0]
|
273 |
+
self.time_embedding = TimestepEmbedding(
|
274 |
+
timestep_input_dim,
|
275 |
+
time_embed_dim,
|
276 |
+
act_fn=act_fn,
|
277 |
+
)
|
278 |
+
|
279 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
280 |
+
encoder_hid_dim_type = "text_proj"
|
281 |
+
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
282 |
+
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
283 |
+
|
284 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
285 |
+
raise ValueError(
|
286 |
+
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
287 |
+
)
|
288 |
+
|
289 |
+
if encoder_hid_dim_type == "text_proj":
|
290 |
+
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
291 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
292 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
293 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
294 |
+
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
295 |
+
self.encoder_hid_proj = TextImageProjection(
|
296 |
+
text_embed_dim=encoder_hid_dim,
|
297 |
+
image_embed_dim=cross_attention_dim,
|
298 |
+
cross_attention_dim=cross_attention_dim,
|
299 |
+
)
|
300 |
+
|
301 |
+
elif encoder_hid_dim_type is not None:
|
302 |
+
raise ValueError(
|
303 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
304 |
+
)
|
305 |
+
else:
|
306 |
+
self.encoder_hid_proj = None
|
307 |
+
|
308 |
+
# class embedding
|
309 |
+
if class_embed_type is None and num_class_embeds is not None:
|
310 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
311 |
+
elif class_embed_type == "timestep":
|
312 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
313 |
+
elif class_embed_type == "identity":
|
314 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
315 |
+
elif class_embed_type == "projection":
|
316 |
+
if projection_class_embeddings_input_dim is None:
|
317 |
+
raise ValueError(
|
318 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
319 |
+
)
|
320 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
321 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
322 |
+
# 2. it projects from an arbitrary input dimension.
|
323 |
+
#
|
324 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
325 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
326 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
327 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
328 |
+
else:
|
329 |
+
self.class_embedding = None
|
330 |
+
|
331 |
+
if addition_embed_type == "text":
|
332 |
+
if encoder_hid_dim is not None:
|
333 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
334 |
+
else:
|
335 |
+
text_time_embedding_from_dim = cross_attention_dim
|
336 |
+
|
337 |
+
self.add_embedding = TextTimeEmbedding(
|
338 |
+
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
339 |
+
)
|
340 |
+
elif addition_embed_type == "text_image":
|
341 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
342 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
343 |
+
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
344 |
+
self.add_embedding = TextImageTimeEmbedding(
|
345 |
+
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
346 |
+
)
|
347 |
+
elif addition_embed_type == "text_time":
|
348 |
+
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
349 |
+
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
350 |
+
|
351 |
+
elif addition_embed_type is not None:
|
352 |
+
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
353 |
+
|
354 |
+
self.down_blocks = nn.ModuleList([])
|
355 |
+
self.controlnet_down_blocks = nn.ModuleList([])
|
356 |
+
|
357 |
+
if isinstance(only_cross_attention, bool):
|
358 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
359 |
+
|
360 |
+
if isinstance(attention_head_dim, int):
|
361 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
362 |
+
|
363 |
+
if isinstance(num_attention_heads, int):
|
364 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
365 |
+
|
366 |
+
# down
|
367 |
+
output_channel = block_out_channels[0]
|
368 |
+
|
369 |
+
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
370 |
+
controlnet_block = zero_module(controlnet_block)
|
371 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
372 |
+
|
373 |
+
|
374 |
+
for i, down_block_type in enumerate(down_block_types):
|
375 |
+
input_channel = output_channel
|
376 |
+
output_channel = block_out_channels[i]
|
377 |
+
is_final_block = i == len(block_out_channels) - 1
|
378 |
+
|
379 |
+
down_block = get_down_block(
|
380 |
+
down_block_type,
|
381 |
+
num_layers=layers_per_block,
|
382 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
383 |
+
in_channels=input_channel,
|
384 |
+
out_channels=output_channel,
|
385 |
+
temb_channels=time_embed_dim,
|
386 |
+
add_downsample=not is_final_block,
|
387 |
+
resnet_eps=norm_eps,
|
388 |
+
resnet_act_fn=act_fn,
|
389 |
+
resnet_groups=norm_num_groups,
|
390 |
+
cross_attention_dim=cross_attention_dim,
|
391 |
+
num_attention_heads=num_attention_heads[i],
|
392 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
393 |
+
downsample_padding=downsample_padding,
|
394 |
+
use_linear_projection=use_linear_projection,
|
395 |
+
only_cross_attention=only_cross_attention[i],
|
396 |
+
upcast_attention=upcast_attention,
|
397 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
398 |
+
)
|
399 |
+
self.down_blocks.append(down_block)
|
400 |
+
|
401 |
+
for _ in range(layers_per_block):
|
402 |
+
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
403 |
+
controlnet_block = zero_module(controlnet_block)
|
404 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
405 |
+
|
406 |
+
if not is_final_block:
|
407 |
+
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
408 |
+
controlnet_block = zero_module(controlnet_block)
|
409 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
410 |
+
|
411 |
+
# mid
|
412 |
+
mid_block_channel = block_out_channels[-1]
|
413 |
+
|
414 |
+
controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
|
415 |
+
controlnet_block = zero_module(controlnet_block)
|
416 |
+
self.controlnet_mid_block = controlnet_block
|
417 |
+
|
418 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
419 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
420 |
+
in_channels=mid_block_channel,
|
421 |
+
temb_channels=time_embed_dim,
|
422 |
+
resnet_eps=norm_eps,
|
423 |
+
resnet_act_fn=act_fn,
|
424 |
+
output_scale_factor=mid_block_scale_factor,
|
425 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
426 |
+
cross_attention_dim=cross_attention_dim,
|
427 |
+
num_attention_heads=num_attention_heads[-1],
|
428 |
+
resnet_groups=norm_num_groups,
|
429 |
+
use_linear_projection=use_linear_projection,
|
430 |
+
upcast_attention=upcast_attention,
|
431 |
+
)
|
432 |
+
|
433 |
+
@classmethod
|
434 |
+
def from_unet(
|
435 |
+
cls,
|
436 |
+
unet: UNet2DConditionModel,
|
437 |
+
controlnet_conditioning_channel_order: str = "rgb",
|
438 |
+
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
439 |
+
load_weights_from_unet: bool = True,
|
440 |
+
use_vae_encode_condition: bool = False,
|
441 |
+
):
|
442 |
+
r"""
|
443 |
+
Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
|
444 |
+
|
445 |
+
Parameters:
|
446 |
+
unet (`UNet2DConditionModel`):
|
447 |
+
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
|
448 |
+
where applicable.
|
449 |
+
"""
|
450 |
+
transformer_layers_per_block = (
|
451 |
+
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
|
452 |
+
)
|
453 |
+
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
|
454 |
+
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
|
455 |
+
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
|
456 |
+
addition_time_embed_dim = (
|
457 |
+
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
|
458 |
+
)
|
459 |
+
|
460 |
+
controlnet = cls(
|
461 |
+
encoder_hid_dim=encoder_hid_dim,
|
462 |
+
encoder_hid_dim_type=encoder_hid_dim_type,
|
463 |
+
addition_embed_type=addition_embed_type,
|
464 |
+
addition_time_embed_dim=addition_time_embed_dim,
|
465 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
466 |
+
in_channels=unet.config.in_channels,
|
467 |
+
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
468 |
+
freq_shift=unet.config.freq_shift,
|
469 |
+
down_block_types=unet.config.down_block_types,
|
470 |
+
only_cross_attention=unet.config.only_cross_attention,
|
471 |
+
block_out_channels=unet.config.block_out_channels,
|
472 |
+
layers_per_block=unet.config.layers_per_block,
|
473 |
+
downsample_padding=unet.config.downsample_padding,
|
474 |
+
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
475 |
+
act_fn=unet.config.act_fn,
|
476 |
+
norm_num_groups=unet.config.norm_num_groups,
|
477 |
+
norm_eps=unet.config.norm_eps,
|
478 |
+
cross_attention_dim=unet.config.cross_attention_dim,
|
479 |
+
attention_head_dim=unet.config.attention_head_dim,
|
480 |
+
num_attention_heads=unet.config.num_attention_heads,
|
481 |
+
use_linear_projection=unet.config.use_linear_projection,
|
482 |
+
class_embed_type=unet.config.class_embed_type,
|
483 |
+
num_class_embeds=unet.config.num_class_embeds,
|
484 |
+
upcast_attention=unet.config.upcast_attention,
|
485 |
+
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
486 |
+
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
487 |
+
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
488 |
+
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
489 |
+
use_vae_encode_condition=use_vae_encode_condition,
|
490 |
+
)
|
491 |
+
|
492 |
+
if load_weights_from_unet:
|
493 |
+
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
|
494 |
+
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
495 |
+
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
496 |
+
|
497 |
+
if controlnet.class_embedding:
|
498 |
+
controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
|
499 |
+
|
500 |
+
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
|
501 |
+
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
|
502 |
+
|
503 |
+
return controlnet
|
504 |
+
|
505 |
+
@property
|
506 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
507 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
508 |
+
r"""
|
509 |
+
Returns:
|
510 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
511 |
+
indexed by its weight name.
|
512 |
+
"""
|
513 |
+
# set recursively
|
514 |
+
processors = {}
|
515 |
+
|
516 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
517 |
+
if hasattr(module, "get_processor"):
|
518 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
519 |
+
|
520 |
+
for sub_name, child in module.named_children():
|
521 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
522 |
+
|
523 |
+
return processors
|
524 |
+
|
525 |
+
for name, module in self.named_children():
|
526 |
+
fn_recursive_add_processors(name, module, processors)
|
527 |
+
|
528 |
+
return processors
|
529 |
+
|
530 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
531 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
532 |
+
r"""
|
533 |
+
Sets the attention processor to use to compute attention.
|
534 |
+
|
535 |
+
Parameters:
|
536 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
537 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
538 |
+
for **all** `Attention` layers.
|
539 |
+
|
540 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
541 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
542 |
+
|
543 |
+
"""
|
544 |
+
count = len(self.attn_processors.keys())
|
545 |
+
|
546 |
+
if isinstance(processor, dict) and len(processor) != count:
|
547 |
+
raise ValueError(
|
548 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
549 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
550 |
+
)
|
551 |
+
|
552 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
553 |
+
if hasattr(module, "set_processor"):
|
554 |
+
if not isinstance(processor, dict):
|
555 |
+
module.set_processor(processor)
|
556 |
+
else:
|
557 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
558 |
+
|
559 |
+
for sub_name, child in module.named_children():
|
560 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
561 |
+
|
562 |
+
for name, module in self.named_children():
|
563 |
+
fn_recursive_attn_processor(name, module, processor)
|
564 |
+
|
565 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
566 |
+
def set_default_attn_processor(self):
|
567 |
+
"""
|
568 |
+
Disables custom attention processors and sets the default attention implementation.
|
569 |
+
"""
|
570 |
+
self.set_attn_processor(AttnProcessor())
|
571 |
+
|
572 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
573 |
+
def set_attention_slice(self, slice_size):
|
574 |
+
r"""
|
575 |
+
Enable sliced attention computation.
|
576 |
+
|
577 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
578 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
579 |
+
|
580 |
+
Args:
|
581 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
582 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
583 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
584 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
585 |
+
must be a multiple of `slice_size`.
|
586 |
+
"""
|
587 |
+
sliceable_head_dims = []
|
588 |
+
|
589 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
590 |
+
if hasattr(module, "set_attention_slice"):
|
591 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
592 |
+
|
593 |
+
for child in module.children():
|
594 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
595 |
+
|
596 |
+
# retrieve number of attention layers
|
597 |
+
for module in self.children():
|
598 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
599 |
+
|
600 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
601 |
+
|
602 |
+
if slice_size == "auto":
|
603 |
+
# half the attention head size is usually a good trade-off between
|
604 |
+
# speed and memory
|
605 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
606 |
+
elif slice_size == "max":
|
607 |
+
# make smallest slice possible
|
608 |
+
slice_size = num_sliceable_layers * [1]
|
609 |
+
|
610 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
611 |
+
|
612 |
+
if len(slice_size) != len(sliceable_head_dims):
|
613 |
+
raise ValueError(
|
614 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
615 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
616 |
+
)
|
617 |
+
|
618 |
+
for i in range(len(slice_size)):
|
619 |
+
size = slice_size[i]
|
620 |
+
dim = sliceable_head_dims[i]
|
621 |
+
if size is not None and size > dim:
|
622 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
623 |
+
|
624 |
+
# Recursively walk through all the children.
|
625 |
+
# Any children which exposes the set_attention_slice method
|
626 |
+
# gets the message
|
627 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
628 |
+
if hasattr(module, "set_attention_slice"):
|
629 |
+
module.set_attention_slice(slice_size.pop())
|
630 |
+
|
631 |
+
for child in module.children():
|
632 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
633 |
+
|
634 |
+
reversed_slice_size = list(reversed(slice_size))
|
635 |
+
for module in self.children():
|
636 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
637 |
+
|
638 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
639 |
+
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
640 |
+
module.gradient_checkpointing = value
|
641 |
+
|
642 |
+
def forward(
|
643 |
+
self,
|
644 |
+
sample: torch.FloatTensor,
|
645 |
+
timestep: Union[torch.Tensor, float, int],
|
646 |
+
encoder_hidden_states: torch.Tensor,
|
647 |
+
controlnet_cond: torch.FloatTensor,
|
648 |
+
conditioning_scale: float = 1.0,
|
649 |
+
class_labels: Optional[torch.Tensor] = None,
|
650 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
651 |
+
attention_mask: Optional[torch.Tensor] = None,
|
652 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
653 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
654 |
+
guess_mode: bool = False,
|
655 |
+
return_dict: bool = True,
|
656 |
+
image_encoder_hidden_states: torch.Tensor = None,
|
657 |
+
vae_encode_condition_hidden_states: torch.Tensor = None,
|
658 |
+
use_vae_encode_condition = False,
|
659 |
+
) -> Union[ControlNetOutput, Tuple]:
|
660 |
+
"""
|
661 |
+
The [`ControlNetModel`] forward method.
|
662 |
+
|
663 |
+
Args:
|
664 |
+
sample (`torch.FloatTensor`):
|
665 |
+
The noisy input tensor.
|
666 |
+
timestep (`Union[torch.Tensor, float, int]`):
|
667 |
+
The number of timesteps to denoise an input.
|
668 |
+
encoder_hidden_states (`torch.Tensor`):
|
669 |
+
The encoder hidden states.
|
670 |
+
controlnet_cond (`torch.FloatTensor`):
|
671 |
+
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
672 |
+
conditioning_scale (`float`, defaults to `1.0`):
|
673 |
+
The scale factor for ControlNet outputs.
|
674 |
+
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
675 |
+
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
676 |
+
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
677 |
+
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
678 |
+
added_cond_kwargs (`dict`):
|
679 |
+
Additional conditions for the Stable Diffusion XL UNet.
|
680 |
+
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
681 |
+
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
682 |
+
guess_mode (`bool`, defaults to `False`):
|
683 |
+
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
|
684 |
+
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
685 |
+
return_dict (`bool`, defaults to `True`):
|
686 |
+
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
687 |
+
|
688 |
+
Returns:
|
689 |
+
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
|
690 |
+
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
691 |
+
returned where the first element is the sample tensor.
|
692 |
+
"""
|
693 |
+
# check channel order
|
694 |
+
channel_order = self.config.controlnet_conditioning_channel_order
|
695 |
+
|
696 |
+
if channel_order == "rgb":
|
697 |
+
# in rgb order by default
|
698 |
+
...
|
699 |
+
elif channel_order == "bgr":
|
700 |
+
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
|
701 |
+
else:
|
702 |
+
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
|
703 |
+
|
704 |
+
# prepare attention_mask
|
705 |
+
if attention_mask is not None:
|
706 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
707 |
+
attention_mask = attention_mask.unsqueeze(1)
|
708 |
+
|
709 |
+
# 1. time
|
710 |
+
timesteps = timestep
|
711 |
+
if not torch.is_tensor(timesteps):
|
712 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
713 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
714 |
+
is_mps = sample.device.type == "mps"
|
715 |
+
if isinstance(timestep, float):
|
716 |
+
dtype = torch.float32 if is_mps else torch.float64
|
717 |
+
else:
|
718 |
+
dtype = torch.int32 if is_mps else torch.int64
|
719 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
720 |
+
elif len(timesteps.shape) == 0:
|
721 |
+
timesteps = timesteps[None].to(sample.device)
|
722 |
+
|
723 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
724 |
+
timesteps = timesteps.expand(sample.shape[0])
|
725 |
+
|
726 |
+
t_emb = self.time_proj(timesteps)
|
727 |
+
|
728 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
729 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
730 |
+
# there might be better ways to encapsulate this.
|
731 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
732 |
+
|
733 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
734 |
+
aug_emb = None
|
735 |
+
|
736 |
+
if self.class_embedding is not None:
|
737 |
+
if class_labels is None:
|
738 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
739 |
+
|
740 |
+
if self.config.class_embed_type == "timestep":
|
741 |
+
class_labels = self.time_proj(class_labels)
|
742 |
+
|
743 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
744 |
+
emb = emb + class_emb
|
745 |
+
|
746 |
+
if self.config.addition_embed_type is not None:
|
747 |
+
if self.config.addition_embed_type == "text":
|
748 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
749 |
+
|
750 |
+
elif self.config.addition_embed_type == "text_time":
|
751 |
+
if "text_embeds" not in added_cond_kwargs:
|
752 |
+
raise ValueError(
|
753 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
754 |
+
)
|
755 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
756 |
+
if "time_ids" not in added_cond_kwargs:
|
757 |
+
raise ValueError(
|
758 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
759 |
+
)
|
760 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
761 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
762 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
763 |
+
|
764 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
765 |
+
add_embeds = add_embeds.to(emb.dtype)
|
766 |
+
aug_emb = self.add_embedding(add_embeds)
|
767 |
+
|
768 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
769 |
+
|
770 |
+
# 2. pre-process
|
771 |
+
sample = self.conv_in(sample)
|
772 |
+
|
773 |
+
if not self.use_vae_encode_condition:
|
774 |
+
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
775 |
+
else:
|
776 |
+
controlnet_cond = self.condition_conv_in(vae_encode_condition_hidden_states)
|
777 |
+
|
778 |
+
sample = sample + controlnet_cond
|
779 |
+
|
780 |
+
# 3. down
|
781 |
+
down_block_res_samples = (sample,)
|
782 |
+
for downsample_block in self.down_blocks:
|
783 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
784 |
+
sample, res_samples = downsample_block(
|
785 |
+
hidden_states=sample,
|
786 |
+
temb=emb,
|
787 |
+
encoder_hidden_states=encoder_hidden_states,
|
788 |
+
attention_mask=attention_mask,
|
789 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
790 |
+
image_encoder_hidden_states=image_encoder_hidden_states,
|
791 |
+
)
|
792 |
+
else:
|
793 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
794 |
+
|
795 |
+
down_block_res_samples += res_samples
|
796 |
+
|
797 |
+
# 4. mid
|
798 |
+
if self.mid_block is not None:
|
799 |
+
sample = self.mid_block(
|
800 |
+
sample,
|
801 |
+
emb,
|
802 |
+
encoder_hidden_states=encoder_hidden_states,
|
803 |
+
attention_mask=attention_mask,
|
804 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
805 |
+
image_encoder_hidden_states=image_encoder_hidden_states,
|
806 |
+
)
|
807 |
+
|
808 |
+
# 5. Control net blocks
|
809 |
+
|
810 |
+
controlnet_down_block_res_samples = ()
|
811 |
+
|
812 |
+
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
813 |
+
down_block_res_sample = controlnet_block(down_block_res_sample)
|
814 |
+
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
815 |
+
|
816 |
+
down_block_res_samples = controlnet_down_block_res_samples
|
817 |
+
|
818 |
+
mid_block_res_sample = self.controlnet_mid_block(sample)
|
819 |
+
|
820 |
+
# 6. scaling
|
821 |
+
if guess_mode and not self.config.global_pool_conditions:
|
822 |
+
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
823 |
+
|
824 |
+
scales = scales * conditioning_scale
|
825 |
+
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
826 |
+
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
827 |
+
else:
|
828 |
+
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
829 |
+
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
830 |
+
|
831 |
+
if self.config.global_pool_conditions:
|
832 |
+
down_block_res_samples = [
|
833 |
+
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
834 |
+
]
|
835 |
+
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
836 |
+
|
837 |
+
if not return_dict:
|
838 |
+
return (down_block_res_samples, mid_block_res_sample)
|
839 |
+
|
840 |
+
return ControlNetOutput(
|
841 |
+
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
842 |
+
)
|
843 |
+
|
844 |
+
|
845 |
+
def zero_module(module):
|
846 |
+
for p in module.parameters():
|
847 |
+
nn.init.zeros_(p)
|
848 |
+
return module
|
849 |
+
|
850 |
+
|
models/losses/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from models.losses.contperceptual import LPIPSWithDiscriminator
|
models/losses/contperceptual.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
|
5 |
+
from diffusers.models.modeling_utils import ModelMixin
|
6 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
7 |
+
from diffusers.loaders import FromOriginalControlnetMixin
|
8 |
+
|
9 |
+
class LPIPSWithDiscriminator(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
10 |
+
def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
|
11 |
+
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
|
12 |
+
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
|
13 |
+
disc_loss="hinge"):
|
14 |
+
|
15 |
+
super().__init__()
|
16 |
+
assert disc_loss in ["hinge", "vanilla"]
|
17 |
+
self.kl_weight = kl_weight
|
18 |
+
self.pixel_weight = pixelloss_weight
|
19 |
+
self.perceptual_loss = LPIPS().eval()
|
20 |
+
self.perceptual_weight = perceptual_weight
|
21 |
+
# output log variance
|
22 |
+
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
|
23 |
+
|
24 |
+
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
|
25 |
+
n_layers=disc_num_layers,
|
26 |
+
use_actnorm=use_actnorm
|
27 |
+
).apply(weights_init)
|
28 |
+
self.discriminator_iter_start = disc_start
|
29 |
+
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
|
30 |
+
self.disc_factor = disc_factor
|
31 |
+
self.discriminator_weight = disc_weight
|
32 |
+
self.disc_conditional = disc_conditional
|
33 |
+
|
34 |
+
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
|
35 |
+
if last_layer is not None:
|
36 |
+
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
37 |
+
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
38 |
+
else:
|
39 |
+
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
|
40 |
+
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
|
41 |
+
|
42 |
+
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
43 |
+
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
44 |
+
d_weight = d_weight * self.discriminator_weight
|
45 |
+
return d_weight
|
46 |
+
|
47 |
+
def forward(self, inputs, reconstructions, optimizer_idx,
|
48 |
+
global_step, posteriors=None, last_layer=None, cond=None, split="train",
|
49 |
+
weights=None, return_dic=False):
|
50 |
+
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
51 |
+
if self.perceptual_weight > 0:
|
52 |
+
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
|
53 |
+
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
54 |
+
|
55 |
+
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
|
56 |
+
weighted_nll_loss = nll_loss
|
57 |
+
if weights is not None:
|
58 |
+
weighted_nll_loss = weights*nll_loss
|
59 |
+
weighted_nll_loss = torch.mean(weighted_nll_loss) / weighted_nll_loss.shape[0]
|
60 |
+
nll_loss = torch.mean(nll_loss) / nll_loss.shape[0]
|
61 |
+
if self.kl_weight>0:
|
62 |
+
kl_loss = posteriors.kl()
|
63 |
+
kl_loss = torch.mean(kl_loss) / kl_loss.shape[0]
|
64 |
+
|
65 |
+
# now the GAN part
|
66 |
+
if optimizer_idx == 0:
|
67 |
+
# generator update
|
68 |
+
if cond is None:
|
69 |
+
assert not self.disc_conditional
|
70 |
+
logits_fake = self.discriminator(reconstructions.contiguous())
|
71 |
+
else:
|
72 |
+
assert self.disc_conditional
|
73 |
+
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
|
74 |
+
g_loss = -torch.mean(logits_fake)
|
75 |
+
|
76 |
+
if self.disc_factor > 0.0:
|
77 |
+
try:
|
78 |
+
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
|
79 |
+
except RuntimeError:
|
80 |
+
# assert not self.training
|
81 |
+
d_weight = torch.tensor(1.0) * self.discriminator_weight
|
82 |
+
else:
|
83 |
+
# d_weight = torch.tensor(0.0)
|
84 |
+
d_weight = torch.tensor(0.0)
|
85 |
+
|
86 |
+
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
|
87 |
+
if self.kl_weight>0:
|
88 |
+
loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
|
89 |
+
log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
|
90 |
+
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
|
91 |
+
"{}/rec_loss".format(split): rec_loss.detach().mean(),
|
92 |
+
"{}/d_weight".format(split): d_weight.detach(),
|
93 |
+
"{}/disc_factor".format(split): torch.tensor(disc_factor),
|
94 |
+
"{}/g_loss".format(split): g_loss.detach().mean(),
|
95 |
+
}
|
96 |
+
if return_dic:
|
97 |
+
loss_dic = {}
|
98 |
+
loss_dic['total_loss'] = loss.clone().detach().mean()
|
99 |
+
loss_dic['logvar'] = self.logvar.detach()
|
100 |
+
loss_dic['kl_loss'] = kl_loss.detach().mean()
|
101 |
+
loss_dic['nll_loss'] = nll_loss.detach().mean()
|
102 |
+
loss_dic['rec_loss'] = rec_loss.detach().mean()
|
103 |
+
loss_dic['d_weight'] = d_weight.detach()
|
104 |
+
loss_dic['disc_factor'] = torch.tensor(disc_factor)
|
105 |
+
loss_dic['g_loss'] = g_loss.detach().mean()
|
106 |
+
else:
|
107 |
+
loss = weighted_nll_loss + d_weight * disc_factor * g_loss
|
108 |
+
log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
|
109 |
+
"{}/nll_loss".format(split): nll_loss.detach().mean(),
|
110 |
+
"{}/rec_loss".format(split): rec_loss.detach().mean(),
|
111 |
+
"{}/d_weight".format(split): d_weight.detach(),
|
112 |
+
"{}/disc_factor".format(split): torch.tensor(disc_factor),
|
113 |
+
"{}/g_loss".format(split): g_loss.detach().mean(),
|
114 |
+
}
|
115 |
+
if return_dic:
|
116 |
+
loss_dic = {}
|
117 |
+
loss_dic["{}/total_loss".format(split)] = loss.clone().detach().mean()
|
118 |
+
loss_dic["{}/logvar".format(split)] = self.logvar.detach()
|
119 |
+
loss_dic['nll_loss'.format(split)] = nll_loss.detach().mean()
|
120 |
+
loss_dic['rec_loss'.format(split)] = rec_loss.detach().mean()
|
121 |
+
loss_dic['d_weight'.format(split)] = d_weight.detach()
|
122 |
+
loss_dic['disc_factor'.format(split)] = torch.tensor(disc_factor)
|
123 |
+
loss_dic['g_loss'.format(split)] = g_loss.detach().mean()
|
124 |
+
|
125 |
+
if return_dic:
|
126 |
+
return loss, log, loss_dic
|
127 |
+
return loss, log
|
128 |
+
|
129 |
+
if optimizer_idx == 1:
|
130 |
+
# second pass for discriminator update
|
131 |
+
if cond is None:
|
132 |
+
logits_real = self.discriminator(inputs.contiguous().detach())
|
133 |
+
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
134 |
+
else:
|
135 |
+
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
|
136 |
+
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
|
137 |
+
|
138 |
+
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
|
139 |
+
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
|
140 |
+
|
141 |
+
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
|
142 |
+
"{}/logits_real".format(split): logits_real.detach().mean(),
|
143 |
+
"{}/logits_fake".format(split): logits_fake.detach().mean()
|
144 |
+
}
|
145 |
+
|
146 |
+
if return_dic:
|
147 |
+
loss_dic = {}
|
148 |
+
loss_dic["{}/disc_loss".format(split)] = d_loss.clone().detach().mean()
|
149 |
+
loss_dic["{}/logits_real".format(split)] = logits_real.detach().mean()
|
150 |
+
loss_dic["{}/logits_fake".format(split)] = logits_fake.detach().mean()
|
151 |
+
return d_loss, log, loss_dic
|
152 |
+
|
153 |
+
return d_loss, log
|
154 |
+
|
models/losses/vqperceptual.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from einops import repeat
|
5 |
+
|
6 |
+
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
|
7 |
+
from taming.modules.losses.lpips import LPIPS
|
8 |
+
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
|
9 |
+
|
10 |
+
|
11 |
+
def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
|
12 |
+
assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
|
13 |
+
loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
|
14 |
+
loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
|
15 |
+
loss_real = (weights * loss_real).sum() / weights.sum()
|
16 |
+
loss_fake = (weights * loss_fake).sum() / weights.sum()
|
17 |
+
d_loss = 0.5 * (loss_real + loss_fake)
|
18 |
+
return d_loss
|
19 |
+
|
20 |
+
def adopt_weight(weight, global_step, threshold=0, value=0.):
|
21 |
+
if global_step < threshold:
|
22 |
+
weight = value
|
23 |
+
return weight
|
24 |
+
|
25 |
+
|
26 |
+
def measure_perplexity(predicted_indices, n_embed):
|
27 |
+
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
|
28 |
+
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
|
29 |
+
encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
|
30 |
+
avg_probs = encodings.mean(0)
|
31 |
+
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
|
32 |
+
cluster_use = torch.sum(avg_probs > 0)
|
33 |
+
return perplexity, cluster_use
|
34 |
+
|
35 |
+
def l1(x, y):
|
36 |
+
return torch.abs(x-y)
|
37 |
+
|
38 |
+
|
39 |
+
def l2(x, y):
|
40 |
+
return torch.pow((x-y), 2)
|
41 |
+
|
42 |
+
|
43 |
+
class VQLPIPSWithDiscriminator(nn.Module):
|
44 |
+
def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
|
45 |
+
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
|
46 |
+
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
|
47 |
+
disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
|
48 |
+
pixel_loss="l1"):
|
49 |
+
super().__init__()
|
50 |
+
assert disc_loss in ["hinge", "vanilla"]
|
51 |
+
assert perceptual_loss in ["lpips", "clips", "dists"]
|
52 |
+
assert pixel_loss in ["l1", "l2"]
|
53 |
+
self.codebook_weight = codebook_weight
|
54 |
+
self.pixel_weight = pixelloss_weight
|
55 |
+
if perceptual_loss == "lpips":
|
56 |
+
print(f"{self.__class__.__name__}: Running with LPIPS.")
|
57 |
+
self.perceptual_loss = LPIPS().eval()
|
58 |
+
else:
|
59 |
+
raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
|
60 |
+
self.perceptual_weight = perceptual_weight
|
61 |
+
|
62 |
+
if pixel_loss == "l1":
|
63 |
+
self.pixel_loss = l1
|
64 |
+
else:
|
65 |
+
self.pixel_loss = l2
|
66 |
+
|
67 |
+
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
|
68 |
+
n_layers=disc_num_layers,
|
69 |
+
use_actnorm=use_actnorm,
|
70 |
+
ndf=disc_ndf
|
71 |
+
).apply(weights_init)
|
72 |
+
self.discriminator_iter_start = disc_start
|
73 |
+
if disc_loss == "hinge":
|
74 |
+
self.disc_loss = hinge_d_loss
|
75 |
+
elif disc_loss == "vanilla":
|
76 |
+
self.disc_loss = vanilla_d_loss
|
77 |
+
else:
|
78 |
+
raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
|
79 |
+
print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
|
80 |
+
self.disc_factor = disc_factor
|
81 |
+
self.discriminator_weight = disc_weight
|
82 |
+
self.disc_conditional = disc_conditional
|
83 |
+
self.n_classes = n_classes
|
84 |
+
|
85 |
+
# def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
|
86 |
+
# if last_layer is not None:
|
87 |
+
# nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
88 |
+
# g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
89 |
+
# else:
|
90 |
+
# nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
|
91 |
+
# g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
|
92 |
+
|
93 |
+
# d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
94 |
+
# d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
95 |
+
# d_weight = d_weight * self.discriminator_weight
|
96 |
+
# return d_weight
|
97 |
+
|
98 |
+
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
|
99 |
+
# if last_layer is not None:
|
100 |
+
# nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
101 |
+
# g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
102 |
+
# else:
|
103 |
+
# nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
|
104 |
+
# g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
|
105 |
+
|
106 |
+
# d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
107 |
+
# d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
108 |
+
d_weight = 1.0 * self.discriminator_weight
|
109 |
+
return d_weight
|
110 |
+
|
111 |
+
def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
|
112 |
+
global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
|
113 |
+
if not exists(codebook_loss):
|
114 |
+
codebook_loss = torch.tensor([0.]).to(inputs.device)
|
115 |
+
#rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
116 |
+
rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
|
117 |
+
if self.perceptual_weight > 0:
|
118 |
+
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
|
119 |
+
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
120 |
+
else:
|
121 |
+
p_loss = torch.tensor([0.0])
|
122 |
+
|
123 |
+
nll_loss = rec_loss
|
124 |
+
#nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
125 |
+
nll_loss = torch.mean(nll_loss)
|
126 |
+
|
127 |
+
# now the GAN part
|
128 |
+
if optimizer_idx == 0:
|
129 |
+
# generator update
|
130 |
+
if cond is None:
|
131 |
+
assert not self.disc_conditional
|
132 |
+
logits_fake = self.discriminator(reconstructions.contiguous())
|
133 |
+
else:
|
134 |
+
assert self.disc_conditional
|
135 |
+
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
|
136 |
+
g_loss = -torch.mean(logits_fake)
|
137 |
+
|
138 |
+
try:
|
139 |
+
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
|
140 |
+
except RuntimeError:
|
141 |
+
assert not self.training
|
142 |
+
d_weight = torch.tensor(0.0)
|
143 |
+
|
144 |
+
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
|
145 |
+
loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
|
146 |
+
|
147 |
+
log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
|
148 |
+
"{}/quant_loss".format(split): codebook_loss.detach().mean(),
|
149 |
+
"{}/nll_loss".format(split): nll_loss.detach().mean(),
|
150 |
+
"{}/rec_loss".format(split): rec_loss.detach().mean(),
|
151 |
+
"{}/p_loss".format(split): p_loss.detach().mean(),
|
152 |
+
"{}/d_weight".format(split): d_weight.detach(),
|
153 |
+
"{}/disc_factor".format(split): torch.tensor(disc_factor),
|
154 |
+
"{}/g_loss".format(split): g_loss.detach().mean(),
|
155 |
+
}
|
156 |
+
if predicted_indices is not None:
|
157 |
+
assert self.n_classes is not None
|
158 |
+
with torch.no_grad():
|
159 |
+
perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
|
160 |
+
log[f"{split}/perplexity"] = perplexity
|
161 |
+
log[f"{split}/cluster_usage"] = cluster_usage
|
162 |
+
return loss, log
|
163 |
+
|
164 |
+
if optimizer_idx == 1:
|
165 |
+
# second pass for discriminator update
|
166 |
+
if cond is None:
|
167 |
+
logits_real = self.discriminator(inputs.contiguous().detach())
|
168 |
+
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
169 |
+
else:
|
170 |
+
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
|
171 |
+
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
|
172 |
+
|
173 |
+
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
|
174 |
+
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
|
175 |
+
|
176 |
+
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
|
177 |
+
"{}/logits_real".format(split): logits_real.detach().mean(),
|
178 |
+
"{}/logits_fake".format(split): logits_fake.detach().mean()
|
179 |
+
}
|
180 |
+
return d_loss, log
|
models/shared.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Shared architecture blocks."""
|
10 |
+
|
11 |
+
from typing import Callable
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
|
17 |
+
from ADD.th_utils.ops import bias_act
|
18 |
+
|
19 |
+
|
20 |
+
class ResidualBlock(nn.Module):
|
21 |
+
def __init__(self, fn: Callable):
|
22 |
+
super().__init__()
|
23 |
+
self.fn = fn
|
24 |
+
|
25 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
26 |
+
return (self.fn(x) + x) / np.sqrt(2)
|
27 |
+
|
28 |
+
|
29 |
+
class FullyConnectedLayer(nn.Module):
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
in_features: int, # Number of input features.
|
33 |
+
out_features: int, # Number of output features.
|
34 |
+
bias: bool = True, # Apply additive bias before the activation function?
|
35 |
+
activation: str = 'linear', # Activation function: 'relu', 'lrelu', etc.
|
36 |
+
lr_multiplier: float = 1.0, # Learning rate multiplier.
|
37 |
+
weight_init: float = 1.0, # Initial standard deviation of the weight tensor.
|
38 |
+
bias_init: float = 0.0, # Initial value for the additive bias.
|
39 |
+
):
|
40 |
+
|
41 |
+
super().__init__()
|
42 |
+
self.in_features = in_features
|
43 |
+
self.out_features = out_features
|
44 |
+
self.activation = activation
|
45 |
+
self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) * (weight_init / lr_multiplier))
|
46 |
+
bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_features])
|
47 |
+
self.bias = torch.nn.Parameter(torch.from_numpy(bias_init / lr_multiplier)) if bias else None
|
48 |
+
self.weight_gain = lr_multiplier / np.sqrt(in_features)
|
49 |
+
self.bias_gain = lr_multiplier
|
50 |
+
|
51 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
52 |
+
w = self.weight.to(x.dtype) * self.weight_gain
|
53 |
+
b = self.bias
|
54 |
+
if b is not None:
|
55 |
+
b = b.to(x.dtype)
|
56 |
+
if self.bias_gain != 1:
|
57 |
+
b = b * self.bias_gain
|
58 |
+
|
59 |
+
if self.activation == 'linear' and b is not None:
|
60 |
+
x = torch.addmm(b.unsqueeze(0), x, w.t())
|
61 |
+
else:
|
62 |
+
x = x.matmul(w.t())
|
63 |
+
x = bias_act.bias_act(x, b, act=self.activation)
|
64 |
+
return x
|
65 |
+
|
66 |
+
def extra_repr(self) -> str:
|
67 |
+
return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}'
|
68 |
+
|
69 |
+
|
70 |
+
class MLP(nn.Module):
|
71 |
+
def __init__(
|
72 |
+
self,
|
73 |
+
features_list: list[int], # Number of features in each layer of the MLP.
|
74 |
+
activation: str = 'linear', # Activation function: 'relu', 'lrelu', etc.
|
75 |
+
lr_multiplier: float = 1.0, # Learning rate multiplier.
|
76 |
+
linear_out: bool = False # Use the 'linear' activation function for the output layer?
|
77 |
+
):
|
78 |
+
super().__init__()
|
79 |
+
num_layers = len(features_list) - 1
|
80 |
+
self.num_layers = num_layers
|
81 |
+
self.out_dim = features_list[-1]
|
82 |
+
|
83 |
+
for idx in range(num_layers):
|
84 |
+
in_features = features_list[idx]
|
85 |
+
out_features = features_list[idx + 1]
|
86 |
+
if linear_out and idx == num_layers-1:
|
87 |
+
activation = 'linear'
|
88 |
+
layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
|
89 |
+
setattr(self, f'fc{idx}', layer)
|
90 |
+
|
91 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
92 |
+
''' if x is sequence of tokens, shift tokens to batch and apply MLP to all'''
|
93 |
+
shift2batch = (x.ndim == 3)
|
94 |
+
|
95 |
+
if shift2batch:
|
96 |
+
B, K, C = x.shape
|
97 |
+
x = x.flatten(0,1)
|
98 |
+
|
99 |
+
for idx in range(self.num_layers):
|
100 |
+
layer = getattr(self, f'fc{idx}')
|
101 |
+
x = layer(x)
|
102 |
+
|
103 |
+
if shift2batch:
|
104 |
+
x = x.reshape(B, K, -1)
|
105 |
+
|
106 |
+
return x
|
models/unet_2d_blocks.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/unet_2d_condition.py
ADDED
@@ -0,0 +1,1081 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
import torch.utils.checkpoint
|
20 |
+
|
21 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
22 |
+
from diffusers.loaders import UNet2DConditionLoadersMixin
|
23 |
+
from diffusers.utils import BaseOutput, logging
|
24 |
+
from diffusers.models.activations import get_activation
|
25 |
+
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
|
26 |
+
from diffusers.models.embeddings import (
|
27 |
+
GaussianFourierProjection,
|
28 |
+
ImageHintTimeEmbedding,
|
29 |
+
ImageProjection,
|
30 |
+
ImageTimeEmbedding,
|
31 |
+
PositionNet,
|
32 |
+
TextImageProjection,
|
33 |
+
TextImageTimeEmbedding,
|
34 |
+
TextTimeEmbedding,
|
35 |
+
TimestepEmbedding,
|
36 |
+
Timesteps,
|
37 |
+
)
|
38 |
+
from diffusers.models.modeling_utils import ModelMixin
|
39 |
+
from .unet_2d_blocks import (
|
40 |
+
UNetMidBlock2DCrossAttn,
|
41 |
+
UNetMidBlock2DSimpleCrossAttn,
|
42 |
+
get_down_block,
|
43 |
+
get_up_block,
|
44 |
+
)
|
45 |
+
|
46 |
+
import os, json
|
47 |
+
|
48 |
+
|
49 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
50 |
+
|
51 |
+
|
52 |
+
@dataclass
|
53 |
+
class UNet2DConditionOutput(BaseOutput):
|
54 |
+
"""
|
55 |
+
The output of [`UNet2DConditionModel`].
|
56 |
+
|
57 |
+
Args:
|
58 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
59 |
+
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
60 |
+
"""
|
61 |
+
|
62 |
+
sample: torch.FloatTensor = None
|
63 |
+
|
64 |
+
|
65 |
+
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
66 |
+
r"""
|
67 |
+
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
|
68 |
+
shaped output.
|
69 |
+
|
70 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
71 |
+
for all models (such as downloading or saving).
|
72 |
+
|
73 |
+
Parameters:
|
74 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
75 |
+
Height and width of input/output sample.
|
76 |
+
in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
|
77 |
+
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
|
78 |
+
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
79 |
+
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
80 |
+
Whether to flip the sin to cos in the time embedding.
|
81 |
+
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
82 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
83 |
+
The tuple of downsample blocks to use.
|
84 |
+
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
85 |
+
Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
|
86 |
+
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
|
87 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
88 |
+
The tuple of upsample blocks to use.
|
89 |
+
only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
|
90 |
+
Whether to include self-attention in the basic transformer blocks, see
|
91 |
+
[`~models.attention.BasicTransformerBlock`].
|
92 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
93 |
+
The tuple of output channels for each block.
|
94 |
+
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
95 |
+
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
96 |
+
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
97 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
98 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
99 |
+
If `None`, normalization and activation layers is skipped in post-processing.
|
100 |
+
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
101 |
+
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
102 |
+
The dimension of the cross attention features.
|
103 |
+
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
104 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
105 |
+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
106 |
+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
107 |
+
encoder_hid_dim (`int`, *optional*, defaults to None):
|
108 |
+
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
109 |
+
dimension to `cross_attention_dim`.
|
110 |
+
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
111 |
+
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
112 |
+
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
113 |
+
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
114 |
+
num_attention_heads (`int`, *optional*):
|
115 |
+
The number of attention heads. If not defined, defaults to `attention_head_dim`
|
116 |
+
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
117 |
+
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
|
118 |
+
class_embed_type (`str`, *optional*, defaults to `None`):
|
119 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
120 |
+
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
121 |
+
addition_embed_type (`str`, *optional*, defaults to `None`):
|
122 |
+
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
123 |
+
"text". "text" will use the `TextTimeEmbedding` layer.
|
124 |
+
addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
|
125 |
+
Dimension for the timestep embeddings.
|
126 |
+
num_class_embeds (`int`, *optional*, defaults to `None`):
|
127 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
128 |
+
class conditioning with `class_embed_type` equal to `None`.
|
129 |
+
time_embedding_type (`str`, *optional*, defaults to `positional`):
|
130 |
+
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
|
131 |
+
time_embedding_dim (`int`, *optional*, defaults to `None`):
|
132 |
+
An optional override for the dimension of the projected time embedding.
|
133 |
+
time_embedding_act_fn (`str`, *optional*, defaults to `None`):
|
134 |
+
Optional activation function to use only once on the time embeddings before they are passed to the rest of
|
135 |
+
the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
|
136 |
+
timestep_post_act (`str`, *optional*, defaults to `None`):
|
137 |
+
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
|
138 |
+
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
|
139 |
+
The dimension of `cond_proj` layer in the timestep embedding.
|
140 |
+
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
|
141 |
+
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
|
142 |
+
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
|
143 |
+
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
|
144 |
+
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
|
145 |
+
embeddings with the class embeddings.
|
146 |
+
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
|
147 |
+
Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
|
148 |
+
`only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
|
149 |
+
`only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
|
150 |
+
otherwise.
|
151 |
+
"""
|
152 |
+
|
153 |
+
_supports_gradient_checkpointing = True
|
154 |
+
|
155 |
+
@register_to_config
|
156 |
+
def __init__(
|
157 |
+
self,
|
158 |
+
sample_size: Optional[int] = None,
|
159 |
+
in_channels: int = 4,
|
160 |
+
out_channels: int = 4,
|
161 |
+
center_input_sample: bool = False,
|
162 |
+
flip_sin_to_cos: bool = True,
|
163 |
+
freq_shift: int = 0,
|
164 |
+
down_block_types: Tuple[str] = (
|
165 |
+
"CrossAttnDownBlock2D",
|
166 |
+
"CrossAttnDownBlock2D",
|
167 |
+
"CrossAttnDownBlock2D",
|
168 |
+
"DownBlock2D",
|
169 |
+
),
|
170 |
+
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
171 |
+
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
172 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
173 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
174 |
+
layers_per_block: Union[int, Tuple[int]] = 2,
|
175 |
+
downsample_padding: int = 1,
|
176 |
+
mid_block_scale_factor: float = 1,
|
177 |
+
act_fn: str = "silu",
|
178 |
+
norm_num_groups: Optional[int] = 32,
|
179 |
+
norm_eps: float = 1e-5,
|
180 |
+
cross_attention_dim: Union[int, Tuple[int]] = 1280,
|
181 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
182 |
+
encoder_hid_dim: Optional[int] = None,
|
183 |
+
encoder_hid_dim_type: Optional[str] = None,
|
184 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
185 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
186 |
+
dual_cross_attention: bool = False,
|
187 |
+
use_linear_projection: bool = False,
|
188 |
+
class_embed_type: Optional[str] = None,
|
189 |
+
addition_embed_type: Optional[str] = None,
|
190 |
+
addition_time_embed_dim: Optional[int] = None,
|
191 |
+
num_class_embeds: Optional[int] = None,
|
192 |
+
upcast_attention: bool = False,
|
193 |
+
resnet_time_scale_shift: str = "default",
|
194 |
+
resnet_skip_time_act: bool = False,
|
195 |
+
resnet_out_scale_factor: int = 1.0,
|
196 |
+
time_embedding_type: str = "positional",
|
197 |
+
time_embedding_dim: Optional[int] = None,
|
198 |
+
time_embedding_act_fn: Optional[str] = None,
|
199 |
+
timestep_post_act: Optional[str] = None,
|
200 |
+
time_cond_proj_dim: Optional[int] = None,
|
201 |
+
conv_in_kernel: int = 3,
|
202 |
+
conv_out_kernel: int = 3,
|
203 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
204 |
+
attention_type: str = "default",
|
205 |
+
class_embeddings_concat: bool = False,
|
206 |
+
mid_block_only_cross_attention: Optional[bool] = None,
|
207 |
+
cross_attention_norm: Optional[str] = None,
|
208 |
+
addition_embed_type_num_heads=64,
|
209 |
+
):
|
210 |
+
super().__init__()
|
211 |
+
|
212 |
+
self.sample_size = sample_size
|
213 |
+
|
214 |
+
if num_attention_heads is not None:
|
215 |
+
raise ValueError(
|
216 |
+
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
217 |
+
)
|
218 |
+
|
219 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
220 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
221 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
222 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
223 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
224 |
+
# which is why we correct for the naming here.
|
225 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
226 |
+
|
227 |
+
# Check inputs
|
228 |
+
if len(down_block_types) != len(up_block_types):
|
229 |
+
raise ValueError(
|
230 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
231 |
+
)
|
232 |
+
|
233 |
+
if len(block_out_channels) != len(down_block_types):
|
234 |
+
raise ValueError(
|
235 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
236 |
+
)
|
237 |
+
|
238 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
239 |
+
raise ValueError(
|
240 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
241 |
+
)
|
242 |
+
|
243 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
244 |
+
raise ValueError(
|
245 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
246 |
+
)
|
247 |
+
|
248 |
+
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
|
249 |
+
raise ValueError(
|
250 |
+
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
251 |
+
)
|
252 |
+
|
253 |
+
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
|
254 |
+
raise ValueError(
|
255 |
+
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
256 |
+
)
|
257 |
+
|
258 |
+
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
|
259 |
+
raise ValueError(
|
260 |
+
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
261 |
+
)
|
262 |
+
|
263 |
+
# input
|
264 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
265 |
+
self.conv_in = nn.Conv2d(
|
266 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
267 |
+
)
|
268 |
+
|
269 |
+
# time
|
270 |
+
if time_embedding_type == "fourier":
|
271 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
|
272 |
+
if time_embed_dim % 2 != 0:
|
273 |
+
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
|
274 |
+
self.time_proj = GaussianFourierProjection(
|
275 |
+
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
276 |
+
)
|
277 |
+
timestep_input_dim = time_embed_dim
|
278 |
+
elif time_embedding_type == "positional":
|
279 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
|
280 |
+
|
281 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
282 |
+
timestep_input_dim = block_out_channels[0]
|
283 |
+
else:
|
284 |
+
raise ValueError(
|
285 |
+
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
|
286 |
+
)
|
287 |
+
|
288 |
+
self.time_embedding = TimestepEmbedding(
|
289 |
+
timestep_input_dim,
|
290 |
+
time_embed_dim,
|
291 |
+
act_fn=act_fn,
|
292 |
+
post_act_fn=timestep_post_act,
|
293 |
+
cond_proj_dim=time_cond_proj_dim,
|
294 |
+
)
|
295 |
+
|
296 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
297 |
+
encoder_hid_dim_type = "text_proj"
|
298 |
+
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
299 |
+
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
300 |
+
|
301 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
302 |
+
raise ValueError(
|
303 |
+
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
304 |
+
)
|
305 |
+
|
306 |
+
if encoder_hid_dim_type == "text_proj":
|
307 |
+
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
308 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
309 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
310 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
311 |
+
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
312 |
+
self.encoder_hid_proj = TextImageProjection(
|
313 |
+
text_embed_dim=encoder_hid_dim,
|
314 |
+
image_embed_dim=cross_attention_dim,
|
315 |
+
cross_attention_dim=cross_attention_dim,
|
316 |
+
)
|
317 |
+
elif encoder_hid_dim_type == "image_proj":
|
318 |
+
# Kandinsky 2.2
|
319 |
+
self.encoder_hid_proj = ImageProjection(
|
320 |
+
image_embed_dim=encoder_hid_dim,
|
321 |
+
cross_attention_dim=cross_attention_dim,
|
322 |
+
)
|
323 |
+
elif encoder_hid_dim_type is not None:
|
324 |
+
raise ValueError(
|
325 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
326 |
+
)
|
327 |
+
else:
|
328 |
+
self.encoder_hid_proj = None
|
329 |
+
|
330 |
+
# class embedding
|
331 |
+
if class_embed_type is None and num_class_embeds is not None:
|
332 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
333 |
+
elif class_embed_type == "timestep":
|
334 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
|
335 |
+
elif class_embed_type == "identity":
|
336 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
337 |
+
elif class_embed_type == "projection":
|
338 |
+
if projection_class_embeddings_input_dim is None:
|
339 |
+
raise ValueError(
|
340 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
341 |
+
)
|
342 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
343 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
344 |
+
# 2. it projects from an arbitrary input dimension.
|
345 |
+
#
|
346 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
347 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
348 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
349 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
350 |
+
elif class_embed_type == "simple_projection":
|
351 |
+
if projection_class_embeddings_input_dim is None:
|
352 |
+
raise ValueError(
|
353 |
+
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
|
354 |
+
)
|
355 |
+
self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
|
356 |
+
else:
|
357 |
+
self.class_embedding = None
|
358 |
+
|
359 |
+
if addition_embed_type == "text":
|
360 |
+
if encoder_hid_dim is not None:
|
361 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
362 |
+
else:
|
363 |
+
text_time_embedding_from_dim = cross_attention_dim
|
364 |
+
|
365 |
+
self.add_embedding = TextTimeEmbedding(
|
366 |
+
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
367 |
+
)
|
368 |
+
elif addition_embed_type == "text_image":
|
369 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
370 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
371 |
+
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
372 |
+
self.add_embedding = TextImageTimeEmbedding(
|
373 |
+
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
374 |
+
)
|
375 |
+
elif addition_embed_type == "text_time":
|
376 |
+
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
377 |
+
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
378 |
+
elif addition_embed_type == "image":
|
379 |
+
# Kandinsky 2.2
|
380 |
+
self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
381 |
+
elif addition_embed_type == "image_hint":
|
382 |
+
# Kandinsky 2.2 ControlNet
|
383 |
+
self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
384 |
+
elif addition_embed_type is not None:
|
385 |
+
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
386 |
+
|
387 |
+
if time_embedding_act_fn is None:
|
388 |
+
self.time_embed_act = None
|
389 |
+
else:
|
390 |
+
self.time_embed_act = get_activation(time_embedding_act_fn)
|
391 |
+
|
392 |
+
self.down_blocks = nn.ModuleList([])
|
393 |
+
self.up_blocks = nn.ModuleList([])
|
394 |
+
|
395 |
+
if isinstance(only_cross_attention, bool):
|
396 |
+
if mid_block_only_cross_attention is None:
|
397 |
+
mid_block_only_cross_attention = only_cross_attention
|
398 |
+
|
399 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
400 |
+
|
401 |
+
if mid_block_only_cross_attention is None:
|
402 |
+
mid_block_only_cross_attention = False
|
403 |
+
|
404 |
+
if isinstance(num_attention_heads, int):
|
405 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
406 |
+
|
407 |
+
if isinstance(attention_head_dim, int):
|
408 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
409 |
+
|
410 |
+
if isinstance(cross_attention_dim, int):
|
411 |
+
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
|
412 |
+
|
413 |
+
if isinstance(layers_per_block, int):
|
414 |
+
layers_per_block = [layers_per_block] * len(down_block_types)
|
415 |
+
|
416 |
+
if isinstance(transformer_layers_per_block, int):
|
417 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
418 |
+
|
419 |
+
if class_embeddings_concat:
|
420 |
+
# The time embeddings are concatenated with the class embeddings. The dimension of the
|
421 |
+
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
|
422 |
+
# regular time embeddings
|
423 |
+
blocks_time_embed_dim = time_embed_dim * 2
|
424 |
+
else:
|
425 |
+
blocks_time_embed_dim = time_embed_dim
|
426 |
+
|
427 |
+
# down
|
428 |
+
output_channel = block_out_channels[0]
|
429 |
+
for i, down_block_type in enumerate(down_block_types):
|
430 |
+
input_channel = output_channel
|
431 |
+
output_channel = block_out_channels[i]
|
432 |
+
is_final_block = i == len(block_out_channels) - 1
|
433 |
+
|
434 |
+
down_block = get_down_block(
|
435 |
+
down_block_type,
|
436 |
+
num_layers=layers_per_block[i],
|
437 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
438 |
+
in_channels=input_channel,
|
439 |
+
out_channels=output_channel,
|
440 |
+
temb_channels=blocks_time_embed_dim,
|
441 |
+
add_downsample=not is_final_block,
|
442 |
+
resnet_eps=norm_eps,
|
443 |
+
resnet_act_fn=act_fn,
|
444 |
+
resnet_groups=norm_num_groups,
|
445 |
+
cross_attention_dim=cross_attention_dim[i],
|
446 |
+
num_attention_heads=num_attention_heads[i],
|
447 |
+
downsample_padding=downsample_padding,
|
448 |
+
dual_cross_attention=dual_cross_attention,
|
449 |
+
use_linear_projection=use_linear_projection,
|
450 |
+
only_cross_attention=only_cross_attention[i],
|
451 |
+
upcast_attention=upcast_attention,
|
452 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
453 |
+
attention_type=attention_type,
|
454 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
455 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
456 |
+
cross_attention_norm=cross_attention_norm,
|
457 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
458 |
+
)
|
459 |
+
self.down_blocks.append(down_block)
|
460 |
+
|
461 |
+
# mid
|
462 |
+
if mid_block_type == "UNetMidBlock2DCrossAttn":
|
463 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
464 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
465 |
+
in_channels=block_out_channels[-1],
|
466 |
+
temb_channels=blocks_time_embed_dim,
|
467 |
+
resnet_eps=norm_eps,
|
468 |
+
resnet_act_fn=act_fn,
|
469 |
+
output_scale_factor=mid_block_scale_factor,
|
470 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
471 |
+
cross_attention_dim=cross_attention_dim[-1],
|
472 |
+
num_attention_heads=num_attention_heads[-1],
|
473 |
+
resnet_groups=norm_num_groups,
|
474 |
+
dual_cross_attention=dual_cross_attention,
|
475 |
+
use_linear_projection=use_linear_projection,
|
476 |
+
upcast_attention=upcast_attention,
|
477 |
+
attention_type=attention_type,
|
478 |
+
)
|
479 |
+
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
|
480 |
+
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
|
481 |
+
in_channels=block_out_channels[-1],
|
482 |
+
temb_channels=blocks_time_embed_dim,
|
483 |
+
resnet_eps=norm_eps,
|
484 |
+
resnet_act_fn=act_fn,
|
485 |
+
output_scale_factor=mid_block_scale_factor,
|
486 |
+
cross_attention_dim=cross_attention_dim[-1],
|
487 |
+
attention_head_dim=attention_head_dim[-1],
|
488 |
+
resnet_groups=norm_num_groups,
|
489 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
490 |
+
skip_time_act=resnet_skip_time_act,
|
491 |
+
only_cross_attention=mid_block_only_cross_attention,
|
492 |
+
cross_attention_norm=cross_attention_norm,
|
493 |
+
)
|
494 |
+
elif mid_block_type is None:
|
495 |
+
self.mid_block = None
|
496 |
+
else:
|
497 |
+
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
498 |
+
|
499 |
+
# count how many layers upsample the images
|
500 |
+
self.num_upsamplers = 0
|
501 |
+
|
502 |
+
# up
|
503 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
504 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
505 |
+
reversed_layers_per_block = list(reversed(layers_per_block))
|
506 |
+
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
507 |
+
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
|
508 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
509 |
+
|
510 |
+
output_channel = reversed_block_out_channels[0]
|
511 |
+
for i, up_block_type in enumerate(up_block_types):
|
512 |
+
is_final_block = i == len(block_out_channels) - 1
|
513 |
+
|
514 |
+
prev_output_channel = output_channel
|
515 |
+
output_channel = reversed_block_out_channels[i]
|
516 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
517 |
+
|
518 |
+
# add upsample block for all BUT final layer
|
519 |
+
if not is_final_block:
|
520 |
+
add_upsample = True
|
521 |
+
self.num_upsamplers += 1
|
522 |
+
else:
|
523 |
+
add_upsample = False
|
524 |
+
|
525 |
+
up_block = get_up_block(
|
526 |
+
up_block_type,
|
527 |
+
num_layers=reversed_layers_per_block[i] + 1,
|
528 |
+
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
529 |
+
in_channels=input_channel,
|
530 |
+
out_channels=output_channel,
|
531 |
+
prev_output_channel=prev_output_channel,
|
532 |
+
temb_channels=blocks_time_embed_dim,
|
533 |
+
add_upsample=add_upsample,
|
534 |
+
resnet_eps=norm_eps,
|
535 |
+
resnet_act_fn=act_fn,
|
536 |
+
resnet_groups=norm_num_groups,
|
537 |
+
cross_attention_dim=reversed_cross_attention_dim[i],
|
538 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
539 |
+
dual_cross_attention=dual_cross_attention,
|
540 |
+
use_linear_projection=use_linear_projection,
|
541 |
+
only_cross_attention=only_cross_attention[i],
|
542 |
+
upcast_attention=upcast_attention,
|
543 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
544 |
+
attention_type=attention_type,
|
545 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
546 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
547 |
+
cross_attention_norm=cross_attention_norm,
|
548 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
549 |
+
)
|
550 |
+
self.up_blocks.append(up_block)
|
551 |
+
prev_output_channel = output_channel
|
552 |
+
|
553 |
+
# out
|
554 |
+
if norm_num_groups is not None:
|
555 |
+
self.conv_norm_out = nn.GroupNorm(
|
556 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
557 |
+
)
|
558 |
+
|
559 |
+
self.conv_act = get_activation(act_fn)
|
560 |
+
|
561 |
+
else:
|
562 |
+
self.conv_norm_out = None
|
563 |
+
self.conv_act = None
|
564 |
+
|
565 |
+
conv_out_padding = (conv_out_kernel - 1) // 2
|
566 |
+
self.conv_out = nn.Conv2d(
|
567 |
+
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
|
568 |
+
)
|
569 |
+
|
570 |
+
if attention_type == "gated":
|
571 |
+
positive_len = 768
|
572 |
+
if isinstance(cross_attention_dim, int):
|
573 |
+
positive_len = cross_attention_dim
|
574 |
+
elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
|
575 |
+
positive_len = cross_attention_dim[0]
|
576 |
+
self.position_net = PositionNet(positive_len=positive_len, out_dim=cross_attention_dim)
|
577 |
+
|
578 |
+
@property
|
579 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
580 |
+
r"""
|
581 |
+
Returns:
|
582 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
583 |
+
indexed by its weight name.
|
584 |
+
"""
|
585 |
+
# set recursively
|
586 |
+
processors = {}
|
587 |
+
|
588 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
589 |
+
if hasattr(module, "get_processor"):
|
590 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
591 |
+
|
592 |
+
for sub_name, child in module.named_children():
|
593 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
594 |
+
|
595 |
+
return processors
|
596 |
+
|
597 |
+
for name, module in self.named_children():
|
598 |
+
fn_recursive_add_processors(name, module, processors)
|
599 |
+
|
600 |
+
return processors
|
601 |
+
|
602 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
603 |
+
r"""
|
604 |
+
Sets the attention processor to use to compute attention.
|
605 |
+
|
606 |
+
Parameters:
|
607 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
608 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
609 |
+
for **all** `Attention` layers.
|
610 |
+
|
611 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
612 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
613 |
+
|
614 |
+
"""
|
615 |
+
count = len(self.attn_processors.keys())
|
616 |
+
|
617 |
+
if isinstance(processor, dict) and len(processor) != count:
|
618 |
+
raise ValueError(
|
619 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
620 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
621 |
+
)
|
622 |
+
|
623 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
624 |
+
if hasattr(module, "set_processor"):
|
625 |
+
if not isinstance(processor, dict):
|
626 |
+
module.set_processor(processor)
|
627 |
+
else:
|
628 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
629 |
+
|
630 |
+
for sub_name, child in module.named_children():
|
631 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
632 |
+
|
633 |
+
for name, module in self.named_children():
|
634 |
+
fn_recursive_attn_processor(name, module, processor)
|
635 |
+
|
636 |
+
def set_default_attn_processor(self):
|
637 |
+
"""
|
638 |
+
Disables custom attention processors and sets the default attention implementation.
|
639 |
+
"""
|
640 |
+
self.set_attn_processor(AttnProcessor())
|
641 |
+
|
642 |
+
def set_attention_slice(self, slice_size):
|
643 |
+
r"""
|
644 |
+
Enable sliced attention computation.
|
645 |
+
|
646 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
647 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
648 |
+
|
649 |
+
Args:
|
650 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
651 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
652 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
653 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
654 |
+
must be a multiple of `slice_size`.
|
655 |
+
"""
|
656 |
+
sliceable_head_dims = []
|
657 |
+
|
658 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
659 |
+
if hasattr(module, "set_attention_slice"):
|
660 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
661 |
+
|
662 |
+
for child in module.children():
|
663 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
664 |
+
|
665 |
+
# retrieve number of attention layers
|
666 |
+
for module in self.children():
|
667 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
668 |
+
|
669 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
670 |
+
|
671 |
+
if slice_size == "auto":
|
672 |
+
# half the attention head size is usually a good trade-off between
|
673 |
+
# speed and memory
|
674 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
675 |
+
elif slice_size == "max":
|
676 |
+
# make smallest slice possible
|
677 |
+
slice_size = num_sliceable_layers * [1]
|
678 |
+
|
679 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
680 |
+
|
681 |
+
if len(slice_size) != len(sliceable_head_dims):
|
682 |
+
raise ValueError(
|
683 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
684 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
685 |
+
)
|
686 |
+
|
687 |
+
for i in range(len(slice_size)):
|
688 |
+
size = slice_size[i]
|
689 |
+
dim = sliceable_head_dims[i]
|
690 |
+
if size is not None and size > dim:
|
691 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
692 |
+
|
693 |
+
# Recursively walk through all the children.
|
694 |
+
# Any children which exposes the set_attention_slice method
|
695 |
+
# gets the message
|
696 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
697 |
+
if hasattr(module, "set_attention_slice"):
|
698 |
+
module.set_attention_slice(slice_size.pop())
|
699 |
+
|
700 |
+
for child in module.children():
|
701 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
702 |
+
|
703 |
+
reversed_slice_size = list(reversed(slice_size))
|
704 |
+
for module in self.children():
|
705 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
706 |
+
|
707 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
708 |
+
if hasattr(module, "gradient_checkpointing"):
|
709 |
+
module.gradient_checkpointing = value
|
710 |
+
|
711 |
+
def forward(
|
712 |
+
self,
|
713 |
+
sample: torch.FloatTensor,
|
714 |
+
timestep: Union[torch.Tensor, float, int],
|
715 |
+
encoder_hidden_states: torch.Tensor,
|
716 |
+
class_labels: Optional[torch.Tensor] = None,
|
717 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
718 |
+
attention_mask: Optional[torch.Tensor] = None,
|
719 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
720 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
721 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
722 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
723 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
724 |
+
return_dict: bool = True,
|
725 |
+
image_encoder_hidden_states: torch.Tensor = None,
|
726 |
+
) -> Union[UNet2DConditionOutput, Tuple]:
|
727 |
+
r"""
|
728 |
+
The [`UNet2DConditionModel`] forward method.
|
729 |
+
|
730 |
+
Args:
|
731 |
+
sample (`torch.FloatTensor`):
|
732 |
+
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
733 |
+
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
734 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
735 |
+
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
|
736 |
+
encoder_attention_mask (`torch.Tensor`):
|
737 |
+
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
738 |
+
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
739 |
+
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
740 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
741 |
+
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
742 |
+
tuple.
|
743 |
+
cross_attention_kwargs (`dict`, *optional*):
|
744 |
+
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
745 |
+
added_cond_kwargs: (`dict`, *optional*):
|
746 |
+
A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
|
747 |
+
are passed along to the UNet blocks.
|
748 |
+
|
749 |
+
Returns:
|
750 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
751 |
+
If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
752 |
+
a `tuple` is returned where the first element is the sample tensor.
|
753 |
+
"""
|
754 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
755 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
756 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
757 |
+
# on the fly if necessary.
|
758 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
759 |
+
|
760 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
761 |
+
forward_upsample_size = False
|
762 |
+
upsample_size = None
|
763 |
+
|
764 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
765 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
766 |
+
forward_upsample_size = True
|
767 |
+
|
768 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
769 |
+
# expects mask of shape:
|
770 |
+
# [batch, key_tokens]
|
771 |
+
# adds singleton query_tokens dimension:
|
772 |
+
# [batch, 1, key_tokens]
|
773 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
774 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
775 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
776 |
+
if attention_mask is not None:
|
777 |
+
# assume that mask is expressed as:
|
778 |
+
# (1 = keep, 0 = discard)
|
779 |
+
# convert mask into a bias that can be added to attention scores:
|
780 |
+
# (keep = +0, discard = -10000.0)
|
781 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
782 |
+
attention_mask = attention_mask.unsqueeze(1)
|
783 |
+
|
784 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
785 |
+
if encoder_attention_mask is not None:
|
786 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
787 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
788 |
+
|
789 |
+
# 0. center input if necessary
|
790 |
+
if self.config.center_input_sample:
|
791 |
+
sample = 2 * sample - 1.0
|
792 |
+
|
793 |
+
# 1. time
|
794 |
+
timesteps = timestep
|
795 |
+
if not torch.is_tensor(timesteps):
|
796 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
797 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
798 |
+
is_mps = sample.device.type == "mps"
|
799 |
+
if isinstance(timestep, float):
|
800 |
+
dtype = torch.float32 if is_mps else torch.float64
|
801 |
+
else:
|
802 |
+
dtype = torch.int32 if is_mps else torch.int64
|
803 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
804 |
+
elif len(timesteps.shape) == 0:
|
805 |
+
timesteps = timesteps[None].to(sample.device)
|
806 |
+
|
807 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
808 |
+
timesteps = timesteps.expand(sample.shape[0])
|
809 |
+
|
810 |
+
t_emb = self.time_proj(timesteps)
|
811 |
+
|
812 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
813 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
814 |
+
# there might be better ways to encapsulate this.
|
815 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
816 |
+
|
817 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
818 |
+
aug_emb = None
|
819 |
+
|
820 |
+
if self.class_embedding is not None:
|
821 |
+
if class_labels is None:
|
822 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
823 |
+
|
824 |
+
if self.config.class_embed_type == "timestep":
|
825 |
+
class_labels = self.time_proj(class_labels)
|
826 |
+
|
827 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
828 |
+
# there might be better ways to encapsulate this.
|
829 |
+
class_labels = class_labels.to(dtype=sample.dtype)
|
830 |
+
|
831 |
+
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
|
832 |
+
|
833 |
+
if self.config.class_embeddings_concat:
|
834 |
+
emb = torch.cat([emb, class_emb], dim=-1)
|
835 |
+
else:
|
836 |
+
emb = emb + class_emb
|
837 |
+
|
838 |
+
if self.config.addition_embed_type == "text":
|
839 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
840 |
+
elif self.config.addition_embed_type == "text_image":
|
841 |
+
# Kandinsky 2.1 - style
|
842 |
+
if "image_embeds" not in added_cond_kwargs:
|
843 |
+
raise ValueError(
|
844 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
845 |
+
)
|
846 |
+
|
847 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
848 |
+
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
|
849 |
+
aug_emb = self.add_embedding(text_embs, image_embs)
|
850 |
+
elif self.config.addition_embed_type == "text_time":
|
851 |
+
# SDXL - style
|
852 |
+
if "text_embeds" not in added_cond_kwargs:
|
853 |
+
raise ValueError(
|
854 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
855 |
+
)
|
856 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
857 |
+
if "time_ids" not in added_cond_kwargs:
|
858 |
+
raise ValueError(
|
859 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
860 |
+
)
|
861 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
862 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
863 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
864 |
+
|
865 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
866 |
+
add_embeds = add_embeds.to(emb.dtype)
|
867 |
+
aug_emb = self.add_embedding(add_embeds)
|
868 |
+
elif self.config.addition_embed_type == "image":
|
869 |
+
# Kandinsky 2.2 - style
|
870 |
+
if "image_embeds" not in added_cond_kwargs:
|
871 |
+
raise ValueError(
|
872 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
873 |
+
)
|
874 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
875 |
+
aug_emb = self.add_embedding(image_embs)
|
876 |
+
elif self.config.addition_embed_type == "image_hint":
|
877 |
+
# Kandinsky 2.2 - style
|
878 |
+
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
|
879 |
+
raise ValueError(
|
880 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
|
881 |
+
)
|
882 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
883 |
+
hint = added_cond_kwargs.get("hint")
|
884 |
+
aug_emb, hint = self.add_embedding(image_embs, hint)
|
885 |
+
sample = torch.cat([sample, hint], dim=1)
|
886 |
+
|
887 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
888 |
+
|
889 |
+
if self.time_embed_act is not None:
|
890 |
+
emb = self.time_embed_act(emb)
|
891 |
+
|
892 |
+
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
|
893 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
|
894 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
|
895 |
+
# Kadinsky 2.1 - style
|
896 |
+
if "image_embeds" not in added_cond_kwargs:
|
897 |
+
raise ValueError(
|
898 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
899 |
+
)
|
900 |
+
|
901 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
902 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
|
903 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
|
904 |
+
# Kandinsky 2.2 - style
|
905 |
+
if "image_embeds" not in added_cond_kwargs:
|
906 |
+
raise ValueError(
|
907 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
908 |
+
)
|
909 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
910 |
+
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
|
911 |
+
# 2. pre-process
|
912 |
+
sample = self.conv_in(sample)
|
913 |
+
|
914 |
+
# 2.5 GLIGEN position net
|
915 |
+
if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
|
916 |
+
cross_attention_kwargs = cross_attention_kwargs.copy()
|
917 |
+
gligen_args = cross_attention_kwargs.pop("gligen")
|
918 |
+
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
|
919 |
+
|
920 |
+
# 3. down
|
921 |
+
|
922 |
+
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
|
923 |
+
is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
|
924 |
+
|
925 |
+
down_block_res_samples = (sample,)
|
926 |
+
for downsample_block in self.down_blocks:
|
927 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
928 |
+
# For t2i-adapter CrossAttnDownBlock2D
|
929 |
+
additional_residuals = {}
|
930 |
+
if is_adapter and len(down_block_additional_residuals) > 0:
|
931 |
+
additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
|
932 |
+
|
933 |
+
sample, res_samples = downsample_block(
|
934 |
+
hidden_states=sample,
|
935 |
+
temb=emb,
|
936 |
+
encoder_hidden_states=encoder_hidden_states,
|
937 |
+
attention_mask=attention_mask,
|
938 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
939 |
+
encoder_attention_mask=encoder_attention_mask,
|
940 |
+
image_encoder_hidden_states=image_encoder_hidden_states,
|
941 |
+
**additional_residuals,
|
942 |
+
)
|
943 |
+
else:
|
944 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
945 |
+
|
946 |
+
if is_adapter and len(down_block_additional_residuals) > 0:
|
947 |
+
sample += down_block_additional_residuals.pop(0)
|
948 |
+
|
949 |
+
down_block_res_samples += res_samples
|
950 |
+
|
951 |
+
if is_controlnet:
|
952 |
+
new_down_block_res_samples = ()
|
953 |
+
|
954 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
955 |
+
down_block_res_samples, down_block_additional_residuals
|
956 |
+
):
|
957 |
+
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
958 |
+
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
|
959 |
+
|
960 |
+
down_block_res_samples = new_down_block_res_samples
|
961 |
+
|
962 |
+
# 4. mid
|
963 |
+
if self.mid_block is not None:
|
964 |
+
sample = self.mid_block(
|
965 |
+
sample,
|
966 |
+
emb,
|
967 |
+
encoder_hidden_states=encoder_hidden_states,
|
968 |
+
attention_mask=attention_mask,
|
969 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
970 |
+
encoder_attention_mask=encoder_attention_mask,
|
971 |
+
image_encoder_hidden_states=image_encoder_hidden_states,
|
972 |
+
)
|
973 |
+
# To support T2I-Adapter-XL
|
974 |
+
if (
|
975 |
+
is_adapter
|
976 |
+
and len(down_block_additional_residuals) > 0
|
977 |
+
and sample.shape == down_block_additional_residuals[0].shape
|
978 |
+
):
|
979 |
+
sample += down_block_additional_residuals.pop(0)
|
980 |
+
|
981 |
+
if is_controlnet:
|
982 |
+
sample = sample + mid_block_additional_residual
|
983 |
+
|
984 |
+
# 5. up
|
985 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
986 |
+
is_final_block = i == len(self.up_blocks) - 1
|
987 |
+
|
988 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
989 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
990 |
+
|
991 |
+
# if we have not reached the final block and need to forward the
|
992 |
+
# upsample size, we do it here
|
993 |
+
if not is_final_block and forward_upsample_size:
|
994 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
995 |
+
|
996 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
997 |
+
sample = upsample_block(
|
998 |
+
hidden_states=sample,
|
999 |
+
temb=emb,
|
1000 |
+
res_hidden_states_tuple=res_samples,
|
1001 |
+
encoder_hidden_states=encoder_hidden_states,
|
1002 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1003 |
+
upsample_size=upsample_size,
|
1004 |
+
attention_mask=attention_mask,
|
1005 |
+
encoder_attention_mask=encoder_attention_mask,
|
1006 |
+
image_encoder_hidden_states=image_encoder_hidden_states,
|
1007 |
+
)
|
1008 |
+
else:
|
1009 |
+
sample = upsample_block(
|
1010 |
+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
1011 |
+
)
|
1012 |
+
|
1013 |
+
# 6. post-process
|
1014 |
+
if self.conv_norm_out:
|
1015 |
+
sample = self.conv_norm_out(sample)
|
1016 |
+
sample = self.conv_act(sample)
|
1017 |
+
sample = self.conv_out(sample)
|
1018 |
+
|
1019 |
+
if not return_dict:
|
1020 |
+
return (sample,)
|
1021 |
+
|
1022 |
+
return UNet2DConditionOutput(sample=sample)
|
1023 |
+
|
1024 |
+
@classmethod
|
1025 |
+
def from_pretrained_orig(cls, pretrained_model_path, subfolder=None, **kwargs):
|
1026 |
+
if subfolder is not None:
|
1027 |
+
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
|
1028 |
+
|
1029 |
+
config_file = os.path.join(pretrained_model_path, 'config.json')
|
1030 |
+
if not os.path.isfile(config_file):
|
1031 |
+
raise RuntimeError(f"{config_file} does not exist")
|
1032 |
+
with open(config_file, "r") as f:
|
1033 |
+
config = json.load(f)
|
1034 |
+
|
1035 |
+
from diffusers.utils import WEIGHTS_NAME
|
1036 |
+
from diffusers.utils import SAFETENSORS_WEIGHTS_NAME
|
1037 |
+
|
1038 |
+
|
1039 |
+
model = cls.from_config(config)
|
1040 |
+
|
1041 |
+
## for .bin file
|
1042 |
+
# model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
1043 |
+
# if not os.path.isfile(model_file):
|
1044 |
+
# raise RuntimeError(f"{model_file} does not exist")
|
1045 |
+
# state_dict = torch.load(model_file, map_location="cpu")
|
1046 |
+
# model.load_state_dict(state_dict, strict=False)
|
1047 |
+
|
1048 |
+
## for .safetensors file
|
1049 |
+
import safetensors
|
1050 |
+
model_file = os.path.join(pretrained_model_path, SAFETENSORS_WEIGHTS_NAME)
|
1051 |
+
if not os.path.isfile(model_file):
|
1052 |
+
raise RuntimeError(f"{model_file} does not exist")
|
1053 |
+
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
1054 |
+
model.load_state_dict(state_dict, strict=False)
|
1055 |
+
|
1056 |
+
return model
|
1057 |
+
|
1058 |
+
@classmethod
|
1059 |
+
def from_pretrained_safetensor(cls, pretrained_model_path, subfolder=None, **kwargs):
|
1060 |
+
if subfolder is not None:
|
1061 |
+
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
|
1062 |
+
|
1063 |
+
config_file = os.path.join(pretrained_model_path, 'config.json')
|
1064 |
+
if not os.path.isfile(config_file):
|
1065 |
+
raise RuntimeError(f"{config_file} does not exist")
|
1066 |
+
with open(config_file, "r") as f:
|
1067 |
+
config = json.load(f)
|
1068 |
+
|
1069 |
+
from diffusers.utils import SAFETENSORS_WEIGHTS_NAME
|
1070 |
+
model = cls.from_config(config)
|
1071 |
+
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
1072 |
+
if not os.path.isfile(model_file):
|
1073 |
+
raise RuntimeError(f"{model_file} does not exist")
|
1074 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
1075 |
+
for k, v in model.state_dict().items():
|
1076 |
+
if 'attn2_plus' in k:
|
1077 |
+
print(k)
|
1078 |
+
state_dict.update({k: v})
|
1079 |
+
model.load_state_dict(state_dict, strict=False)
|
1080 |
+
|
1081 |
+
return model
|
models/vit_utils.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MIT License
|
2 |
+
#
|
3 |
+
# Copyright (c) 2021 Intel ISL (Intel Intelligent Systems Lab)
|
4 |
+
#
|
5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
7 |
+
# in the Software without restriction, including without limitation the rights
|
8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
10 |
+
# furnished to do so, subject to the following conditions:
|
11 |
+
#
|
12 |
+
# The above copyright notice and this permission notice shall be included in all
|
13 |
+
# copies or substantial portions of the Software.
|
14 |
+
#
|
15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
# SOFTWARE.
|
22 |
+
#
|
23 |
+
# Based on code from https://github.com/isl-org/DPT
|
24 |
+
|
25 |
+
"""Flexible configuration and feature extraction of timm VisionTransformers."""
|
26 |
+
|
27 |
+
import types
|
28 |
+
import math
|
29 |
+
from typing import Callable
|
30 |
+
|
31 |
+
import torch
|
32 |
+
import torch.nn as nn
|
33 |
+
import torch.nn.functional as F
|
34 |
+
|
35 |
+
|
36 |
+
class AddReadout(nn.Module):
|
37 |
+
def __init__(self, start_index: bool = 1):
|
38 |
+
super(AddReadout, self).__init__()
|
39 |
+
self.start_index = start_index
|
40 |
+
|
41 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
42 |
+
if self.start_index == 2:
|
43 |
+
readout = (x[:, 0] + x[:, 1]) / 2
|
44 |
+
else:
|
45 |
+
readout = x[:, 0]
|
46 |
+
return x[:, self.start_index:] + readout.unsqueeze(1)
|
47 |
+
|
48 |
+
|
49 |
+
class Transpose(nn.Module):
|
50 |
+
def __init__(self, dim0: int, dim1: int):
|
51 |
+
super(Transpose, self).__init__()
|
52 |
+
self.dim0 = dim0
|
53 |
+
self.dim1 = dim1
|
54 |
+
|
55 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
56 |
+
x = x.transpose(self.dim0, self.dim1)
|
57 |
+
return x.contiguous()
|
58 |
+
|
59 |
+
|
60 |
+
def forward_vit(pretrained: nn.Module, x: torch.Tensor) -> dict:
|
61 |
+
_, _, H, W = x.size()
|
62 |
+
_ = pretrained.model.forward_flex(x)
|
63 |
+
return {k: pretrained.rearrange(v) for k, v in activations.items()}
|
64 |
+
|
65 |
+
|
66 |
+
def _resize_pos_embed(self, posemb: torch.Tensor, gs_h: int, gs_w: int) -> torch.Tensor:
|
67 |
+
posemb_tok, posemb_grid = (
|
68 |
+
posemb[:, : self.start_index],
|
69 |
+
posemb[0, self.start_index :],
|
70 |
+
)
|
71 |
+
|
72 |
+
gs_old = int(math.sqrt(len(posemb_grid)))
|
73 |
+
|
74 |
+
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
75 |
+
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear", align_corners=False)
|
76 |
+
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
|
77 |
+
|
78 |
+
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
79 |
+
|
80 |
+
return posemb
|
81 |
+
|
82 |
+
|
83 |
+
def forward_flex(self, x: torch.Tensor) -> torch.Tensor:
|
84 |
+
# patch proj and dynamically resize
|
85 |
+
B, C, H, W = x.size()
|
86 |
+
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
|
87 |
+
pos_embed = self._resize_pos_embed(
|
88 |
+
self.pos_embed, H // self.patch_size[1], W // self.patch_size[0]
|
89 |
+
)
|
90 |
+
|
91 |
+
# add cls token
|
92 |
+
cls_tokens = self.cls_token.expand(
|
93 |
+
x.size(0), -1, -1
|
94 |
+
)
|
95 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
96 |
+
|
97 |
+
# forward pass
|
98 |
+
x = x + pos_embed
|
99 |
+
x = self.pos_drop(x)
|
100 |
+
|
101 |
+
for blk in self.blocks:
|
102 |
+
x = blk(x)
|
103 |
+
|
104 |
+
x = self.norm(x)
|
105 |
+
return x
|
106 |
+
|
107 |
+
|
108 |
+
activations = {}
|
109 |
+
|
110 |
+
|
111 |
+
def get_activation(name: str) -> Callable:
|
112 |
+
def hook(model, input, output):
|
113 |
+
activations[name] = output
|
114 |
+
return hook
|
115 |
+
|
116 |
+
|
117 |
+
def make_sd_backbone(
|
118 |
+
model: nn.Module,
|
119 |
+
hooks: list[int] = [2, 5, 8, 11],
|
120 |
+
hook_patch: bool = True,
|
121 |
+
start_index: list[int] = 1,
|
122 |
+
):
|
123 |
+
assert len(hooks) == 4
|
124 |
+
|
125 |
+
pretrained = nn.Module()
|
126 |
+
pretrained.model = model
|
127 |
+
|
128 |
+
# add hooks
|
129 |
+
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation('0'))
|
130 |
+
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation('1'))
|
131 |
+
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation('2'))
|
132 |
+
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation('3'))
|
133 |
+
if hook_patch:
|
134 |
+
pretrained.model.pos_drop.register_forward_hook(get_activation('4'))
|
135 |
+
|
136 |
+
# configure readout
|
137 |
+
pretrained.rearrange = nn.Sequential(AddReadout(start_index), Transpose(1, 2))
|
138 |
+
pretrained.model.start_index = start_index
|
139 |
+
pretrained.model.patch_size = patch_size
|
140 |
+
|
141 |
+
# We inject this function into the VisionTransformer instances so that
|
142 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
143 |
+
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
144 |
+
pretrained.model._resize_pos_embed = types.MethodType(
|
145 |
+
_resize_pos_embed, pretrained.model
|
146 |
+
)
|
147 |
+
|
148 |
+
return pretrained
|
149 |
+
|
150 |
+
def make_vit_backbone(
|
151 |
+
model: nn.Module,
|
152 |
+
patch_size: list[int] = [16, 16],
|
153 |
+
hooks: list[int] = [2, 5, 8, 11],
|
154 |
+
hook_patch: bool = True,
|
155 |
+
start_index: list[int] = 1,
|
156 |
+
):
|
157 |
+
assert len(hooks) == 4
|
158 |
+
|
159 |
+
pretrained = nn.Module()
|
160 |
+
pretrained.model = model
|
161 |
+
|
162 |
+
# add hooks
|
163 |
+
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation('0'))
|
164 |
+
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation('1'))
|
165 |
+
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation('2'))
|
166 |
+
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation('3'))
|
167 |
+
if hook_patch:
|
168 |
+
pretrained.model.pos_drop.register_forward_hook(get_activation('4'))
|
169 |
+
|
170 |
+
# configure readout
|
171 |
+
pretrained.rearrange = nn.Sequential(AddReadout(start_index), Transpose(1, 2))
|
172 |
+
pretrained.model.start_index = start_index
|
173 |
+
pretrained.model.patch_size = patch_size
|
174 |
+
|
175 |
+
# We inject this function into the VisionTransformer instances so that
|
176 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
177 |
+
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
178 |
+
pretrained.model._resize_pos_embed = types.MethodType(
|
179 |
+
_resize_pos_embed, pretrained.model
|
180 |
+
)
|
181 |
+
|
182 |
+
return pretrained
|