NightRaven109 commited on
Commit
6ecc7d4
·
verified ·
1 Parent(s): efc2ec3

Upload 73 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. ADD/dnnlib/__init__.py +9 -0
  2. ADD/dnnlib/util.py +492 -0
  3. ADD/layers/__init__.py +11 -0
  4. ADD/layers/attention.py +89 -0
  5. ADD/layers/block.py +260 -0
  6. ADD/layers/dino_head.py +58 -0
  7. ADD/layers/drop_path.py +34 -0
  8. ADD/layers/layer_scale.py +27 -0
  9. ADD/layers/mlp.py +40 -0
  10. ADD/layers/patch_embed.py +88 -0
  11. ADD/layers/swiglu_ffn.py +72 -0
  12. ADD/models/discriminator.py +178 -0
  13. ADD/models/vit.py +373 -0
  14. ADD/th_utils/__init__.py +9 -0
  15. ADD/th_utils/custom_ops.py +157 -0
  16. ADD/th_utils/misc.py +284 -0
  17. ADD/th_utils/ops/__init__.py +9 -0
  18. ADD/th_utils/ops/bias_act.cpp +99 -0
  19. ADD/th_utils/ops/bias_act.cu +173 -0
  20. ADD/th_utils/ops/bias_act.h +38 -0
  21. ADD/th_utils/ops/bias_act.py +209 -0
  22. ADD/th_utils/ops/conv2d_gradfix.py +203 -0
  23. ADD/th_utils/ops/conv2d_resample.py +143 -0
  24. ADD/th_utils/ops/filtered_lrelu.cpp +300 -0
  25. ADD/th_utils/ops/filtered_lrelu.cu +1284 -0
  26. ADD/th_utils/ops/filtered_lrelu.h +90 -0
  27. ADD/th_utils/ops/filtered_lrelu.py +274 -0
  28. ADD/th_utils/ops/filtered_lrelu_ns.cu +27 -0
  29. ADD/th_utils/ops/filtered_lrelu_rd.cu +27 -0
  30. ADD/th_utils/ops/filtered_lrelu_wr.cu +27 -0
  31. ADD/th_utils/ops/fma.py +60 -0
  32. ADD/th_utils/ops/grid_sample_gradfix.py +83 -0
  33. ADD/th_utils/ops/upfirdn2d.cpp +107 -0
  34. ADD/th_utils/ops/upfirdn2d.cu +384 -0
  35. ADD/th_utils/ops/upfirdn2d.h +59 -0
  36. ADD/th_utils/ops/upfirdn2d.py +389 -0
  37. ADD/utils/util_net.py +182 -0
  38. README.md +291 -15
  39. dataloaders/paired_dataset_txt.py +70 -0
  40. dataloaders/params_ccsr.yml +42 -0
  41. dataloaders/realesrgan.py +303 -0
  42. models/DiffAugment.py +121 -0
  43. models/controlnet.py +850 -0
  44. models/losses/__init__.py +1 -0
  45. models/losses/contperceptual.py +154 -0
  46. models/losses/vqperceptual.py +180 -0
  47. models/shared.py +106 -0
  48. models/unet_2d_blocks.py +0 -0
  49. models/unet_2d_condition.py +1081 -0
  50. models/vit_utils.py +182 -0
ADD/dnnlib/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from .util import EasyDict, make_cache_dir_path
ADD/dnnlib/util.py ADDED
@@ -0,0 +1,492 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Miscellaneous utility classes and functions."""
10
+
11
+ import ctypes
12
+ import fnmatch
13
+ import importlib
14
+ import inspect
15
+ import os
16
+ import sys
17
+ import types
18
+ import io
19
+ import pickle
20
+ import re
21
+ import requests
22
+ import html
23
+ import hashlib
24
+ import glob
25
+ import tempfile
26
+ import urllib
27
+ import urllib.request
28
+ import uuid
29
+ from typing import Any, List, Tuple, Union, Optional
30
+ from distutils.util import strtobool
31
+ import shutil
32
+
33
+ import numpy as np
34
+
35
+
36
+ # Util classes
37
+ # ------------------------------------------------------------------------------------------
38
+
39
+ class EasyDict(dict):
40
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
41
+
42
+ def __getattr__(self, name: str) -> Any:
43
+ try:
44
+ return self[name]
45
+ except KeyError:
46
+ raise AttributeError(name)
47
+
48
+ def __setattr__(self, name: str, value: Any) -> None:
49
+ self[name] = value
50
+
51
+ def __delattr__(self, name: str) -> None:
52
+ del self[name]
53
+
54
+
55
+ class Logger(object):
56
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
57
+
58
+ def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True):
59
+ self.file = None
60
+
61
+ if file_name is not None:
62
+ self.file = open(file_name, file_mode)
63
+
64
+ self.should_flush = should_flush
65
+ self.stdout = sys.stdout
66
+ self.stderr = sys.stderr
67
+
68
+ sys.stdout = self
69
+ sys.stderr = self
70
+
71
+ def __enter__(self) -> "Logger":
72
+ return self
73
+
74
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
75
+ self.close()
76
+
77
+ def write(self, text: Union[str, bytes]) -> None:
78
+ """Write text to stdout (and a file) and optionally flush."""
79
+ if isinstance(text, bytes):
80
+ text = text.decode()
81
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
82
+ return
83
+
84
+ if self.file is not None:
85
+ self.file.write(text)
86
+
87
+ self.stdout.write(text)
88
+
89
+ if self.should_flush:
90
+ self.flush()
91
+
92
+ def flush(self) -> None:
93
+ """Flush written text to both stdout and a file, if open."""
94
+ if self.file is not None:
95
+ self.file.flush()
96
+
97
+ self.stdout.flush()
98
+
99
+ def close(self) -> None:
100
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
101
+ self.flush()
102
+
103
+ # if using multiple loggers, prevent closing in wrong order
104
+ if sys.stdout is self:
105
+ sys.stdout = self.stdout
106
+ if sys.stderr is self:
107
+ sys.stderr = self.stderr
108
+
109
+ if self.file is not None:
110
+ self.file.close()
111
+ self.file = None
112
+
113
+
114
+ # Cache directories
115
+ # ------------------------------------------------------------------------------------------
116
+
117
+ _dnnlib_cache_dir = None
118
+
119
+ def set_cache_dir(path: str) -> None:
120
+ global _dnnlib_cache_dir
121
+ _dnnlib_cache_dir = path
122
+
123
+
124
+ def make_cache_dir_path(*paths: str) -> str:
125
+ if _dnnlib_cache_dir is not None:
126
+ return os.path.join(_dnnlib_cache_dir, *paths)
127
+ if 'DNNLIB_CACHE_DIR' in os.environ:
128
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
129
+ if 'HOME' in os.environ:
130
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
131
+ if 'USERPROFILE' in os.environ:
132
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
133
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
134
+
135
+
136
+ # Small util functions
137
+ # ------------------------------------------------------------------------------------------
138
+
139
+ def format_time(seconds: Union[int, float]) -> str:
140
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
141
+ s = int(np.rint(seconds))
142
+
143
+ if s < 60:
144
+ return "{0}s".format(s)
145
+ elif s < 60 * 60:
146
+ return "{0}m {1:02}s".format(s // 60, s % 60)
147
+ elif s < 24 * 60 * 60:
148
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
149
+ else:
150
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
151
+
152
+
153
+ def format_time_brief(seconds: Union[int, float]) -> str:
154
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
155
+ s = int(np.rint(seconds))
156
+
157
+ if s < 60:
158
+ return "{0}s".format(s)
159
+ elif s < 60 * 60:
160
+ return "{0}m {1:02}s".format(s // 60, s % 60)
161
+ elif s < 24 * 60 * 60:
162
+ return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
163
+ else:
164
+ return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
165
+
166
+
167
+ def ask_yes_no(question: str) -> bool:
168
+ """Ask the user the question until the user inputs a valid answer."""
169
+ while True:
170
+ try:
171
+ print("{0} [y/n]".format(question))
172
+ return strtobool(input().lower())
173
+ except ValueError:
174
+ pass
175
+
176
+
177
+ def tuple_product(t: Tuple) -> Any:
178
+ """Calculate the product of the tuple elements."""
179
+ result = 1
180
+
181
+ for v in t:
182
+ result *= v
183
+
184
+ return result
185
+
186
+
187
+ _str_to_ctype = {
188
+ "uint8": ctypes.c_ubyte,
189
+ "uint16": ctypes.c_uint16,
190
+ "uint32": ctypes.c_uint32,
191
+ "uint64": ctypes.c_uint64,
192
+ "int8": ctypes.c_byte,
193
+ "int16": ctypes.c_int16,
194
+ "int32": ctypes.c_int32,
195
+ "int64": ctypes.c_int64,
196
+ "float32": ctypes.c_float,
197
+ "float64": ctypes.c_double
198
+ }
199
+
200
+
201
+ def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
202
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
203
+ type_str = None
204
+
205
+ if isinstance(type_obj, str):
206
+ type_str = type_obj
207
+ elif hasattr(type_obj, "__name__"):
208
+ type_str = type_obj.__name__
209
+ elif hasattr(type_obj, "name"):
210
+ type_str = type_obj.name
211
+ else:
212
+ raise RuntimeError("Cannot infer type name from input")
213
+
214
+ assert type_str in _str_to_ctype.keys()
215
+
216
+ my_dtype = np.dtype(type_str)
217
+ my_ctype = _str_to_ctype[type_str]
218
+
219
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
220
+
221
+ return my_dtype, my_ctype
222
+
223
+
224
+ def is_pickleable(obj: Any) -> bool:
225
+ try:
226
+ with io.BytesIO() as stream:
227
+ pickle.dump(obj, stream)
228
+ return True
229
+ except:
230
+ return False
231
+
232
+
233
+ # Functionality to import modules/objects by name, and call functions by name
234
+ # ------------------------------------------------------------------------------------------
235
+
236
+ def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
237
+ """Searches for the underlying module behind the name to some python object.
238
+ Returns the module and the object name (original name with module part removed)."""
239
+
240
+ # allow convenience shorthands, substitute them by full names
241
+ obj_name = re.sub("^np.", "numpy.", obj_name)
242
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
243
+
244
+ # list alternatives for (module_name, local_obj_name)
245
+ parts = obj_name.split(".")
246
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
247
+
248
+ # try each alternative in turn
249
+ for module_name, local_obj_name in name_pairs:
250
+ try:
251
+ module = importlib.import_module(module_name) # may raise ImportError
252
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
253
+ return module, local_obj_name
254
+ except:
255
+ pass
256
+
257
+ # maybe some of the modules themselves contain errors?
258
+ for module_name, _local_obj_name in name_pairs:
259
+ try:
260
+ importlib.import_module(module_name) # may raise ImportError
261
+ except ImportError:
262
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
263
+ raise
264
+
265
+ # maybe the requested attribute is missing?
266
+ for module_name, local_obj_name in name_pairs:
267
+ try:
268
+ module = importlib.import_module(module_name) # may raise ImportError
269
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
270
+ except ImportError:
271
+ pass
272
+
273
+ # we are out of luck, but we have no idea why
274
+ raise ImportError(obj_name)
275
+
276
+
277
+ def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
278
+ """Traverses the object name and returns the last (rightmost) python object."""
279
+ if obj_name == '':
280
+ return module
281
+ obj = module
282
+ for part in obj_name.split("."):
283
+ obj = getattr(obj, part)
284
+ return obj
285
+
286
+
287
+ def get_obj_by_name(name: str) -> Any:
288
+ """Finds the python object with the given name."""
289
+ module, obj_name = get_module_from_obj_name(name)
290
+ return get_obj_from_module(module, obj_name)
291
+
292
+
293
+ def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
294
+ """Finds the python object with the given name and calls it as a function."""
295
+ assert func_name is not None
296
+ func_obj = get_obj_by_name(func_name)
297
+ assert callable(func_obj)
298
+ return func_obj(*args, **kwargs)
299
+
300
+
301
+ def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
302
+ """Finds the python class with the given name and constructs it with the given arguments."""
303
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
304
+
305
+
306
+ def get_module_dir_by_obj_name(obj_name: str) -> str:
307
+ """Get the directory path of the module containing the given object name."""
308
+ module, _ = get_module_from_obj_name(obj_name)
309
+ return os.path.dirname(inspect.getfile(module))
310
+
311
+
312
+ def is_top_level_function(obj: Any) -> bool:
313
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
314
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
315
+
316
+
317
+ def get_top_level_function_name(obj: Any) -> str:
318
+ """Return the fully-qualified name of a top-level function."""
319
+ assert is_top_level_function(obj)
320
+ module = obj.__module__
321
+ if module == '__main__':
322
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
323
+ return module + "." + obj.__name__
324
+
325
+
326
+ # File system helpers
327
+ # ------------------------------------------------------------------------------------------
328
+
329
+ def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
330
+ """List all files recursively in a given directory while ignoring given file and directory names.
331
+ Returns list of tuples containing both absolute and relative paths."""
332
+ assert os.path.isdir(dir_path)
333
+ base_name = os.path.basename(os.path.normpath(dir_path))
334
+
335
+ if ignores is None:
336
+ ignores = []
337
+
338
+ result = []
339
+
340
+ for root, dirs, files in os.walk(dir_path, topdown=True):
341
+ for ignore_ in ignores:
342
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
343
+
344
+ # dirs need to be edited in-place
345
+ for d in dirs_to_remove:
346
+ dirs.remove(d)
347
+
348
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
349
+
350
+ absolute_paths = [os.path.join(root, f) for f in files]
351
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
352
+
353
+ if add_base_to_relative:
354
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
355
+
356
+ assert len(absolute_paths) == len(relative_paths)
357
+ result += zip(absolute_paths, relative_paths)
358
+
359
+ return result
360
+
361
+
362
+ def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
363
+ """Takes in a list of tuples of (src, dst) paths and copies files.
364
+ Will create all necessary directories."""
365
+ for file in files:
366
+ target_dir_name = os.path.dirname(file[1])
367
+
368
+ # will create all intermediate-level directories
369
+ if not os.path.exists(target_dir_name):
370
+ os.makedirs(target_dir_name)
371
+
372
+ shutil.copyfile(file[0], file[1])
373
+
374
+
375
+ # URL helpers
376
+ # ------------------------------------------------------------------------------------------
377
+
378
+ def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
379
+ """Determine whether the given object is a valid URL string."""
380
+ if not isinstance(obj, str) or not "://" in obj:
381
+ return False
382
+ if allow_file_urls and obj.startswith('file://'):
383
+ return True
384
+ try:
385
+ res = requests.compat.urlparse(obj)
386
+ if not res.scheme or not res.netloc or not "." in res.netloc:
387
+ return False
388
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
389
+ if not res.scheme or not res.netloc or not "." in res.netloc:
390
+ return False
391
+ except:
392
+ return False
393
+ return True
394
+
395
+
396
+ def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
397
+ """Download the given URL and return a binary-mode file object to access the data."""
398
+ assert num_attempts >= 1
399
+ assert not (return_filename and (not cache))
400
+
401
+ # Doesn't look like an URL scheme so interpret it as a local filename.
402
+ if not re.match('^[a-z]+://', url):
403
+ return url if return_filename else open(url, "rb")
404
+
405
+ # Handle file URLs. This code handles unusual file:// patterns that
406
+ # arise on Windows:
407
+ #
408
+ # file:///c:/foo.txt
409
+ #
410
+ # which would translate to a local '/c:/foo.txt' filename that's
411
+ # invalid. Drop the forward slash for such pathnames.
412
+ #
413
+ # If you touch this code path, you should test it on both Linux and
414
+ # Windows.
415
+ #
416
+ # Some internet resources suggest using urllib.request.url2pathname() but
417
+ # but that converts forward slashes to backslashes and this causes
418
+ # its own set of problems.
419
+ if url.startswith('file://'):
420
+ filename = urllib.parse.urlparse(url).path
421
+ if re.match(r'^/[a-zA-Z]:', filename):
422
+ filename = filename[1:]
423
+ return filename if return_filename else open(filename, "rb")
424
+
425
+ assert is_url(url)
426
+
427
+ # Lookup from cache.
428
+ if cache_dir is None:
429
+ cache_dir = make_cache_dir_path('downloads')
430
+
431
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
432
+ if cache:
433
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
434
+ if len(cache_files) == 1:
435
+ filename = cache_files[0]
436
+ return filename if return_filename else open(filename, "rb")
437
+
438
+ # Download.
439
+ url_name = None
440
+ url_data = None
441
+ with requests.Session() as session:
442
+ if verbose:
443
+ print("Downloading %s ..." % url, end="", flush=True)
444
+ for attempts_left in reversed(range(num_attempts)):
445
+ try:
446
+ with session.get(url) as res:
447
+ res.raise_for_status()
448
+ if len(res.content) == 0:
449
+ raise IOError("No data received")
450
+
451
+ if len(res.content) < 8192:
452
+ content_str = res.content.decode("utf-8")
453
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
454
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
455
+ if len(links) == 1:
456
+ url = requests.compat.urljoin(url, links[0])
457
+ raise IOError("Google Drive virus checker nag")
458
+ if "Google Drive - Quota exceeded" in content_str:
459
+ raise IOError("Google Drive download quota exceeded -- please try again later")
460
+
461
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
462
+ url_name = match[1] if match else url
463
+ url_data = res.content
464
+ if verbose:
465
+ print(" done")
466
+ break
467
+ except KeyboardInterrupt:
468
+ raise
469
+ except:
470
+ if not attempts_left:
471
+ if verbose:
472
+ print(" failed")
473
+ raise
474
+ if verbose:
475
+ print(".", end="", flush=True)
476
+
477
+ # Save to cache.
478
+ if cache:
479
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
480
+ safe_name = safe_name[:min(len(safe_name), 128)]
481
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
482
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
483
+ os.makedirs(cache_dir, exist_ok=True)
484
+ with open(temp_file, "wb") as f:
485
+ f.write(url_data)
486
+ os.replace(temp_file, cache_file) # atomic
487
+ if return_filename:
488
+ return cache_file
489
+
490
+ # Return data as file object.
491
+ assert not return_filename
492
+ return io.BytesIO(url_data)
ADD/layers/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .dino_head import DINOHead
7
+ from .mlp import Mlp
8
+ from .patch_embed import PatchEmbed
9
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10
+ from .block import NestedTensorBlock
11
+ from .attention import MemEffAttention
ADD/layers/attention.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import os
12
+ import warnings
13
+
14
+ from torch import Tensor
15
+ from torch import nn
16
+
17
+
18
+ logger = logging.getLogger("dinov2")
19
+
20
+
21
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
22
+ try:
23
+ if XFORMERS_ENABLED:
24
+ from xformers.ops import memory_efficient_attention, unbind
25
+
26
+ XFORMERS_AVAILABLE = True
27
+ warnings.warn("xFormers is available (Attention)")
28
+ else:
29
+ warnings.warn("xFormers is disabled (Attention)")
30
+ raise ImportError
31
+ except ImportError:
32
+ XFORMERS_AVAILABLE = False
33
+ warnings.warn("xFormers is not available (Attention)")
34
+
35
+
36
+ class Attention(nn.Module):
37
+ def __init__(
38
+ self,
39
+ dim: int,
40
+ num_heads: int = 8,
41
+ qkv_bias: bool = False,
42
+ proj_bias: bool = True,
43
+ attn_drop: float = 0.0,
44
+ proj_drop: float = 0.0,
45
+ ) -> None:
46
+ super().__init__()
47
+ self.num_heads = num_heads
48
+ head_dim = dim // num_heads
49
+ self.scale = head_dim**-0.5
50
+
51
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52
+ self.attn_drop = nn.Dropout(attn_drop)
53
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
54
+ self.proj_drop = nn.Dropout(proj_drop)
55
+
56
+ def forward(self, x: Tensor) -> Tensor:
57
+ B, N, C = x.shape
58
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
59
+
60
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
61
+ attn = q @ k.transpose(-2, -1)
62
+
63
+ attn = attn.softmax(dim=-1)
64
+ attn = self.attn_drop(attn)
65
+
66
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
67
+ x = self.proj(x)
68
+ x = self.proj_drop(x)
69
+ return x
70
+
71
+
72
+ class MemEffAttention(Attention):
73
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
74
+ if not XFORMERS_AVAILABLE:
75
+ if attn_bias is not None:
76
+ raise AssertionError("xFormers is required for using nested tensors")
77
+ return super().forward(x)
78
+
79
+ B, N, C = x.shape
80
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
81
+
82
+ q, k, v = unbind(qkv, 2)
83
+
84
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
85
+ x = x.reshape([B, N, C])
86
+
87
+ x = self.proj(x)
88
+ x = self.proj_drop(x)
89
+ return x
ADD/layers/block.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ import logging
11
+ import os
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+ import warnings
14
+
15
+ import torch
16
+ from torch import nn, Tensor
17
+
18
+ from .attention import Attention, MemEffAttention
19
+ from .drop_path import DropPath
20
+ from .layer_scale import LayerScale
21
+ from .mlp import Mlp
22
+
23
+
24
+ logger = logging.getLogger("dinov2")
25
+
26
+
27
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
28
+ try:
29
+ if XFORMERS_ENABLED:
30
+ from xformers.ops import fmha, scaled_index_add, index_select_cat
31
+
32
+ XFORMERS_AVAILABLE = True
33
+ warnings.warn("xFormers is available (Block)")
34
+ else:
35
+ warnings.warn("xFormers is disabled (Block)")
36
+ raise ImportError
37
+ except ImportError:
38
+ XFORMERS_AVAILABLE = False
39
+
40
+ warnings.warn("xFormers is not available (Block)")
41
+
42
+
43
+ class Block(nn.Module):
44
+ def __init__(
45
+ self,
46
+ dim: int,
47
+ num_heads: int,
48
+ mlp_ratio: float = 4.0,
49
+ qkv_bias: bool = False,
50
+ proj_bias: bool = True,
51
+ ffn_bias: bool = True,
52
+ drop: float = 0.0,
53
+ attn_drop: float = 0.0,
54
+ init_values=None,
55
+ drop_path: float = 0.0,
56
+ act_layer: Callable[..., nn.Module] = nn.GELU,
57
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
58
+ attn_class: Callable[..., nn.Module] = Attention,
59
+ ffn_layer: Callable[..., nn.Module] = Mlp,
60
+ ) -> None:
61
+ super().__init__()
62
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
63
+ self.norm1 = norm_layer(dim)
64
+ self.attn = attn_class(
65
+ dim,
66
+ num_heads=num_heads,
67
+ qkv_bias=qkv_bias,
68
+ proj_bias=proj_bias,
69
+ attn_drop=attn_drop,
70
+ proj_drop=drop,
71
+ )
72
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
73
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
74
+
75
+ self.norm2 = norm_layer(dim)
76
+ mlp_hidden_dim = int(dim * mlp_ratio)
77
+ self.mlp = ffn_layer(
78
+ in_features=dim,
79
+ hidden_features=mlp_hidden_dim,
80
+ act_layer=act_layer,
81
+ drop=drop,
82
+ bias=ffn_bias,
83
+ )
84
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
85
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
86
+
87
+ self.sample_drop_ratio = drop_path
88
+
89
+ def forward(self, x: Tensor) -> Tensor:
90
+ def attn_residual_func(x: Tensor) -> Tensor:
91
+ return self.ls1(self.attn(self.norm1(x)))
92
+
93
+ def ffn_residual_func(x: Tensor) -> Tensor:
94
+ return self.ls2(self.mlp(self.norm2(x)))
95
+
96
+ if self.training and self.sample_drop_ratio > 0.1:
97
+ # the overhead is compensated only for a drop path rate larger than 0.1
98
+ x = drop_add_residual_stochastic_depth(
99
+ x,
100
+ residual_func=attn_residual_func,
101
+ sample_drop_ratio=self.sample_drop_ratio,
102
+ )
103
+ x = drop_add_residual_stochastic_depth(
104
+ x,
105
+ residual_func=ffn_residual_func,
106
+ sample_drop_ratio=self.sample_drop_ratio,
107
+ )
108
+ elif self.training and self.sample_drop_ratio > 0.0:
109
+ x = x + self.drop_path1(attn_residual_func(x))
110
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
111
+ else:
112
+ x = x + attn_residual_func(x)
113
+ x = x + ffn_residual_func(x)
114
+ return x
115
+
116
+
117
+ def drop_add_residual_stochastic_depth(
118
+ x: Tensor,
119
+ residual_func: Callable[[Tensor], Tensor],
120
+ sample_drop_ratio: float = 0.0,
121
+ ) -> Tensor:
122
+ # 1) extract subset using permutation
123
+ b, n, d = x.shape
124
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
125
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
126
+ x_subset = x[brange]
127
+
128
+ # 2) apply residual_func to get residual
129
+ residual = residual_func(x_subset)
130
+
131
+ x_flat = x.flatten(1)
132
+ residual = residual.flatten(1)
133
+
134
+ residual_scale_factor = b / sample_subset_size
135
+
136
+ # 3) add the residual
137
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
138
+ return x_plus_residual.view_as(x)
139
+
140
+
141
+ def get_branges_scales(x, sample_drop_ratio=0.0):
142
+ b, n, d = x.shape
143
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
144
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
145
+ residual_scale_factor = b / sample_subset_size
146
+ return brange, residual_scale_factor
147
+
148
+
149
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
150
+ if scaling_vector is None:
151
+ x_flat = x.flatten(1)
152
+ residual = residual.flatten(1)
153
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
154
+ else:
155
+ x_plus_residual = scaled_index_add(
156
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
157
+ )
158
+ return x_plus_residual
159
+
160
+
161
+ attn_bias_cache: Dict[Tuple, Any] = {}
162
+
163
+
164
+ def get_attn_bias_and_cat(x_list, branges=None):
165
+ """
166
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
167
+ """
168
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
169
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
170
+ if all_shapes not in attn_bias_cache.keys():
171
+ seqlens = []
172
+ for b, x in zip(batch_sizes, x_list):
173
+ for _ in range(b):
174
+ seqlens.append(x.shape[1])
175
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
176
+ attn_bias._batch_sizes = batch_sizes
177
+ attn_bias_cache[all_shapes] = attn_bias
178
+
179
+ if branges is not None:
180
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
181
+ else:
182
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
183
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
184
+
185
+ return attn_bias_cache[all_shapes], cat_tensors
186
+
187
+
188
+ def drop_add_residual_stochastic_depth_list(
189
+ x_list: List[Tensor],
190
+ residual_func: Callable[[Tensor, Any], Tensor],
191
+ sample_drop_ratio: float = 0.0,
192
+ scaling_vector=None,
193
+ ) -> Tensor:
194
+ # 1) generate random set of indices for dropping samples in the batch
195
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
196
+ branges = [s[0] for s in branges_scales]
197
+ residual_scale_factors = [s[1] for s in branges_scales]
198
+
199
+ # 2) get attention bias and index+concat the tensors
200
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
201
+
202
+ # 3) apply residual_func to get residual, and split the result
203
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
204
+
205
+ outputs = []
206
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
207
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
208
+ return outputs
209
+
210
+
211
+ class NestedTensorBlock(Block):
212
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
213
+ """
214
+ x_list contains a list of tensors to nest together and run
215
+ """
216
+ assert isinstance(self.attn, MemEffAttention)
217
+
218
+ if self.training and self.sample_drop_ratio > 0.0:
219
+
220
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
221
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
222
+
223
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
224
+ return self.mlp(self.norm2(x))
225
+
226
+ x_list = drop_add_residual_stochastic_depth_list(
227
+ x_list,
228
+ residual_func=attn_residual_func,
229
+ sample_drop_ratio=self.sample_drop_ratio,
230
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
231
+ )
232
+ x_list = drop_add_residual_stochastic_depth_list(
233
+ x_list,
234
+ residual_func=ffn_residual_func,
235
+ sample_drop_ratio=self.sample_drop_ratio,
236
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
237
+ )
238
+ return x_list
239
+ else:
240
+
241
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
242
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
243
+
244
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
245
+ return self.ls2(self.mlp(self.norm2(x)))
246
+
247
+ attn_bias, x = get_attn_bias_and_cat(x_list)
248
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
249
+ x = x + ffn_residual_func(x)
250
+ return attn_bias.split(x)
251
+
252
+ def forward(self, x_or_x_list):
253
+ if isinstance(x_or_x_list, Tensor):
254
+ return super().forward(x_or_x_list)
255
+ elif isinstance(x_or_x_list, list):
256
+ if not XFORMERS_AVAILABLE:
257
+ raise AssertionError("xFormers is required for using nested tensors")
258
+ return self.forward_nested(x_or_x_list)
259
+ else:
260
+ raise AssertionError
ADD/layers/dino_head.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.nn.init import trunc_normal_
9
+ from torch.nn.utils import weight_norm
10
+
11
+
12
+ class DINOHead(nn.Module):
13
+ def __init__(
14
+ self,
15
+ in_dim,
16
+ out_dim,
17
+ use_bn=False,
18
+ nlayers=3,
19
+ hidden_dim=2048,
20
+ bottleneck_dim=256,
21
+ mlp_bias=True,
22
+ ):
23
+ super().__init__()
24
+ nlayers = max(nlayers, 1)
25
+ self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
26
+ self.apply(self._init_weights)
27
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
28
+ self.last_layer.weight_g.data.fill_(1)
29
+
30
+ def _init_weights(self, m):
31
+ if isinstance(m, nn.Linear):
32
+ trunc_normal_(m.weight, std=0.02)
33
+ if isinstance(m, nn.Linear) and m.bias is not None:
34
+ nn.init.constant_(m.bias, 0)
35
+
36
+ def forward(self, x):
37
+ x = self.mlp(x)
38
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
39
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
40
+ x = self.last_layer(x)
41
+ return x
42
+
43
+
44
+ def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
45
+ if nlayers == 1:
46
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
47
+ else:
48
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
49
+ if use_bn:
50
+ layers.append(nn.BatchNorm1d(hidden_dim))
51
+ layers.append(nn.GELU())
52
+ for _ in range(nlayers - 2):
53
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
54
+ if use_bn:
55
+ layers.append(nn.BatchNorm1d(hidden_dim))
56
+ layers.append(nn.GELU())
57
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
58
+ return nn.Sequential(*layers)
ADD/layers/drop_path.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
9
+
10
+
11
+ from torch import nn
12
+
13
+
14
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
15
+ if drop_prob == 0.0 or not training:
16
+ return x
17
+ keep_prob = 1 - drop_prob
18
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
19
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
20
+ if keep_prob > 0.0:
21
+ random_tensor.div_(keep_prob)
22
+ output = x * random_tensor
23
+ return output
24
+
25
+
26
+ class DropPath(nn.Module):
27
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
28
+
29
+ def __init__(self, drop_prob=None):
30
+ super(DropPath, self).__init__()
31
+ self.drop_prob = drop_prob
32
+
33
+ def forward(self, x):
34
+ return drop_path(x, self.drop_prob, self.training)
ADD/layers/layer_scale.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
7
+
8
+ from typing import Union
9
+
10
+ import torch
11
+ from torch import Tensor
12
+ from torch import nn
13
+
14
+
15
+ class LayerScale(nn.Module):
16
+ def __init__(
17
+ self,
18
+ dim: int,
19
+ init_values: Union[float, Tensor] = 1e-5,
20
+ inplace: bool = False,
21
+ ) -> None:
22
+ super().__init__()
23
+ self.inplace = inplace
24
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
25
+
26
+ def forward(self, x: Tensor) -> Tensor:
27
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
ADD/layers/mlp.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
9
+
10
+
11
+ from typing import Callable, Optional
12
+
13
+ from torch import Tensor, nn
14
+
15
+
16
+ class Mlp(nn.Module):
17
+ def __init__(
18
+ self,
19
+ in_features: int,
20
+ hidden_features: Optional[int] = None,
21
+ out_features: Optional[int] = None,
22
+ act_layer: Callable[..., nn.Module] = nn.GELU,
23
+ drop: float = 0.0,
24
+ bias: bool = True,
25
+ ) -> None:
26
+ super().__init__()
27
+ out_features = out_features or in_features
28
+ hidden_features = hidden_features or in_features
29
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
30
+ self.act = act_layer()
31
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
32
+ self.drop = nn.Dropout(drop)
33
+
34
+ def forward(self, x: Tensor) -> Tensor:
35
+ x = self.fc1(x)
36
+ x = self.act(x)
37
+ x = self.drop(x)
38
+ x = self.fc2(x)
39
+ x = self.drop(x)
40
+ return x
ADD/layers/patch_embed.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ from typing import Callable, Optional, Tuple, Union
11
+
12
+ from torch import Tensor
13
+ import torch.nn as nn
14
+
15
+
16
+ def make_2tuple(x):
17
+ if isinstance(x, tuple):
18
+ assert len(x) == 2
19
+ return x
20
+
21
+ assert isinstance(x, int)
22
+ return (x, x)
23
+
24
+
25
+ class PatchEmbed(nn.Module):
26
+ """
27
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
28
+
29
+ Args:
30
+ img_size: Image size.
31
+ patch_size: Patch token size.
32
+ in_chans: Number of input image channels.
33
+ embed_dim: Number of linear projection output channels.
34
+ norm_layer: Normalization layer.
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ img_size: Union[int, Tuple[int, int]] = 224,
40
+ patch_size: Union[int, Tuple[int, int]] = 16,
41
+ in_chans: int = 3,
42
+ embed_dim: int = 768,
43
+ norm_layer: Optional[Callable] = None,
44
+ flatten_embedding: bool = True,
45
+ ) -> None:
46
+ super().__init__()
47
+
48
+ image_HW = make_2tuple(img_size)
49
+ patch_HW = make_2tuple(patch_size)
50
+ patch_grid_size = (
51
+ image_HW[0] // patch_HW[0],
52
+ image_HW[1] // patch_HW[1],
53
+ )
54
+
55
+ self.img_size = image_HW
56
+ self.patch_size = patch_HW
57
+ self.patches_resolution = patch_grid_size
58
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
59
+
60
+ #self.in_chans = in_chans
61
+ self.embed_dim = embed_dim
62
+
63
+ self.flatten_embedding = flatten_embedding
64
+
65
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
66
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
67
+
68
+ def forward(self, x: Tensor) -> Tensor:
69
+ _, _, H, W = x.shape
70
+ patch_H, patch_W = self.patch_size
71
+
72
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
73
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
74
+
75
+ x = self.proj(x) # B C H W
76
+ H, W = x.size(2), x.size(3)
77
+ x = x.flatten(2).transpose(1, 2) # B HW C
78
+ x = self.norm(x)
79
+ if not self.flatten_embedding:
80
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
81
+ return x
82
+
83
+ #def flops(self) -> float:
84
+ #Ho, Wo = self.patches_resolution
85
+ #flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
86
+ #if self.norm is not None:
87
+ # flops += Ho * Wo * self.embed_dim
88
+ #return flops
ADD/layers/swiglu_ffn.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ from typing import Callable, Optional
8
+ import warnings
9
+
10
+ from torch import Tensor, nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class SwiGLUFFN(nn.Module):
15
+ def __init__(
16
+ self,
17
+ in_features: int,
18
+ hidden_features: Optional[int] = None,
19
+ out_features: Optional[int] = None,
20
+ act_layer: Callable[..., nn.Module] = None,
21
+ drop: float = 0.0,
22
+ bias: bool = True,
23
+ ) -> None:
24
+ super().__init__()
25
+ out_features = out_features or in_features
26
+ hidden_features = hidden_features or in_features
27
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
28
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
29
+
30
+ def forward(self, x: Tensor) -> Tensor:
31
+ x12 = self.w12(x)
32
+ x1, x2 = x12.chunk(2, dim=-1)
33
+ hidden = F.silu(x1) * x2
34
+ return self.w3(hidden)
35
+
36
+
37
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
38
+ try:
39
+ if XFORMERS_ENABLED:
40
+ from xformers.ops import SwiGLU
41
+
42
+ XFORMERS_AVAILABLE = True
43
+ warnings.warn("xFormers is available (SwiGLU)")
44
+ else:
45
+ warnings.warn("xFormers is disabled (SwiGLU)")
46
+ raise ImportError
47
+ except ImportError:
48
+ SwiGLU = SwiGLUFFN
49
+ XFORMERS_AVAILABLE = False
50
+
51
+ warnings.warn("xFormers is not available (SwiGLU)")
52
+
53
+
54
+ class SwiGLUFFNFused(SwiGLU):
55
+ def __init__(
56
+ self,
57
+ in_features: int,
58
+ hidden_features: Optional[int] = None,
59
+ out_features: Optional[int] = None,
60
+ act_layer: Callable[..., nn.Module] = None,
61
+ drop: float = 0.0,
62
+ bias: bool = True,
63
+ ) -> None:
64
+ out_features = out_features or in_features
65
+ hidden_features = hidden_features or in_features
66
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
67
+ super().__init__(
68
+ in_features=in_features,
69
+ hidden_features=hidden_features,
70
+ out_features=out_features,
71
+ bias=bias,
72
+ )
ADD/models/discriminator.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """
10
+ Projected discriminator architecture from
11
+ "StyleGAN-T: Unlocking the Power of GANs for Fast Large-Scale Text-to-Image Synthesis".
12
+ """
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from torch.nn.utils.spectral_norm import SpectralNorm
19
+ from torchvision.transforms import RandomCrop, Normalize
20
+ import timm
21
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
22
+
23
+ from ADD.th_utils import misc
24
+ from models.shared import ResidualBlock, FullyConnectedLayer
25
+ from models.vit_utils import make_vit_backbone, forward_vit, make_sd_backbone
26
+ from models.DiffAugment import DiffAugment
27
+ from ADD.utils.util_net import reload_model_
28
+
29
+ from functools import partial
30
+
31
+ class SpectralConv1d(nn.Conv1d):
32
+ def __init__(self, *args, **kwargs):
33
+ super().__init__(*args, **kwargs)
34
+ SpectralNorm.apply(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12)
35
+
36
+
37
+ class BatchNormLocal(nn.Module):
38
+ def __init__(self, num_features: int, affine: bool = True, virtual_bs: int = 3, eps: float = 1e-5):
39
+ super().__init__()
40
+ self.virtual_bs = virtual_bs
41
+ self.eps = eps
42
+ self.affine = affine
43
+
44
+ if self.affine:
45
+ self.weight = nn.Parameter(torch.ones(num_features))
46
+ self.bias = nn.Parameter(torch.zeros(num_features))
47
+
48
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
49
+ shape = x.size()
50
+
51
+ # Reshape batch into groups.
52
+ G = np.ceil(x.size(0)/self.virtual_bs).astype(int)
53
+ x = x.view(G, -1, x.size(-2), x.size(-1))
54
+
55
+ # Calculate stats.
56
+ mean = x.mean([1, 3], keepdim=True)
57
+ var = x.var([1, 3], keepdim=True, unbiased=False)
58
+ x = (x - mean) / (torch.sqrt(var + self.eps))
59
+
60
+ if self.affine:
61
+ x = x * self.weight[None, :, None] + self.bias[None, :, None]
62
+
63
+ return x.view(shape)
64
+
65
+
66
+ def make_block(channels: int, kernel_size: int) -> nn.Module:
67
+ return nn.Sequential(
68
+ SpectralConv1d(
69
+ channels,
70
+ channels,
71
+ kernel_size = kernel_size,
72
+ padding = kernel_size//2,
73
+ padding_mode = 'circular',
74
+ ),
75
+ #BatchNormLocal(channels),
76
+ nn.GroupNorm(4, channels),
77
+ nn.LeakyReLU(0.2, True),
78
+ )
79
+
80
+ class DiscHead(nn.Module):
81
+ def __init__(self, channels: int, c_dim: int, cmap_dim: int = 64):
82
+ super().__init__()
83
+ self.channels = channels
84
+ self.c_dim = c_dim
85
+ self.cmap_dim = cmap_dim
86
+
87
+ self.main = nn.Sequential(
88
+ make_block(channels, kernel_size=1),
89
+ ResidualBlock(make_block(channels, kernel_size=9))
90
+ )
91
+
92
+ if self.c_dim > 0:
93
+ self.cmapper = FullyConnectedLayer(self.c_dim, cmap_dim)
94
+ self.cls = SpectralConv1d(channels, cmap_dim, kernel_size=1, padding=0)
95
+ else:
96
+ self.cls = SpectralConv1d(channels, 1, kernel_size=1, padding=0)
97
+
98
+ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
99
+ h = self.main(x)
100
+ out = self.cls(h)
101
+
102
+ if self.c_dim > 0:
103
+ cmap = self.cmapper(c).unsqueeze(-1)
104
+ out = (out * cmap).sum(1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
105
+
106
+ return out
107
+
108
+ class DINO(torch.nn.Module):
109
+ def __init__(self, hooks: list[int] = [2,5,8,11], hook_patch: bool = True):
110
+ super().__init__()
111
+ self.n_hooks = len(hooks) + int(hook_patch)
112
+
113
+ self.model = make_vit_backbone(
114
+ timm.create_model('vit_small_patch16_224.dino', pretrained=False),
115
+ patch_size=[16,16], hooks=hooks, hook_patch=hook_patch,
116
+ )
117
+ reload_model_(self.model, torch.load('preset/models/dino/dino_deitsmall16_pretrain.pth'))
118
+ self.model = self.model.eval().requires_grad_(False)
119
+
120
+
121
+ self.img_resolution = self.model.model.patch_embed.img_size[0]
122
+ self.embed_dim = self.model.model.embed_dim
123
+ self.norm = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
124
+
125
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
126
+ ''' input: x in [0, 1]; output: dict of activations '''
127
+ x = F.interpolate(x, self.img_resolution, mode='area')
128
+ x = self.norm(x)
129
+ features = forward_vit(self.model, x)
130
+ return features
131
+
132
+
133
+ class ProjectedDiscriminator(nn.Module):
134
+ def __init__(self, c_dim: int, diffaug: bool = True, p_crop: float = 0.5):
135
+ super().__init__()
136
+ self.c_dim = c_dim
137
+ self.diffaug = diffaug
138
+ self.p_crop = p_crop
139
+
140
+ self.dino = DINO()
141
+
142
+ heads = []
143
+ for i in range(self.dino.n_hooks):
144
+ heads += [str(i), DiscHead(self.dino.embed_dim, c_dim)],
145
+ self.heads = nn.ModuleDict(heads)
146
+
147
+ def train(self, mode: bool = True):
148
+ self.dino = self.dino.train(False)
149
+ self.heads = self.heads.train(mode)
150
+ return self
151
+
152
+ def eval(self):
153
+ return self.train(False)
154
+
155
+ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
156
+ # Apply augmentation (x in [-1, 1]).
157
+ if self.diffaug:
158
+ x = DiffAugment(x, policy='translation,cutout')
159
+
160
+ # Transform to [0, 1].
161
+ x = x.add(1).div(2)
162
+
163
+ # Take crops with probablity p_crop if the image is larger.
164
+ if x.size(-1) > self.dino.img_resolution and np.random.random() < self.p_crop:
165
+ x = RandomCrop(self.dino.img_resolution)(x)
166
+
167
+ # Forward pass through DINO ViT.
168
+ features = self.dino(x)
169
+
170
+ # Apply discriminator heads.
171
+ logits = []
172
+ for k, head in self.heads.items():
173
+ features[k].requires_grad_(True)
174
+ logits.append(head(features[k], c).view(x.size(0), -1))
175
+ #logits = torch.cat(logits, dim=1)
176
+
177
+ return logits, features
178
+
ADD/models/vit.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ from functools import partial
11
+ import math
12
+ import logging
13
+ from typing import Sequence, Tuple, Union, Callable
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.utils.checkpoint
18
+ from torch.nn.init import trunc_normal_
19
+
20
+ from ADD.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
21
+
22
+
23
+ logger = logging.getLogger("dinov2")
24
+
25
+
26
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
27
+ if not depth_first and include_root:
28
+ fn(module=module, name=name)
29
+ for child_name, child_module in module.named_children():
30
+ child_name = ".".join((name, child_name)) if name else child_name
31
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
32
+ if depth_first and include_root:
33
+ fn(module=module, name=name)
34
+ return module
35
+
36
+
37
+ class BlockChunk(nn.ModuleList):
38
+ def forward(self, x):
39
+ for b in self:
40
+ x = b(x)
41
+ return x
42
+
43
+
44
+ class DinoVisionTransformer(nn.Module):
45
+ def __init__(
46
+ self,
47
+ img_size=224,
48
+ patch_size=16,
49
+ in_chans=3,
50
+ embed_dim=768,
51
+ depth=12,
52
+ num_heads=12,
53
+ mlp_ratio=4.0,
54
+ qkv_bias=True,
55
+ ffn_bias=True,
56
+ proj_bias=True,
57
+ drop_path_rate=0.0,
58
+ drop_path_uniform=False,
59
+ init_values=None, # for layerscale: None or 0 => no layerscale
60
+ embed_layer=PatchEmbed,
61
+ act_layer=nn.GELU,
62
+ block_fn=Block,
63
+ ffn_layer="mlp",
64
+ block_chunks=1,
65
+ num_register_tokens=0,
66
+ interpolate_antialias=False,
67
+ interpolate_offset=0.1,
68
+ ):
69
+ """
70
+ Args:
71
+ img_size (int, tuple): input image size
72
+ patch_size (int, tuple): patch size
73
+ in_chans (int): number of input channels
74
+ embed_dim (int): embedding dimension
75
+ depth (int): depth of transformer
76
+ num_heads (int): number of attention heads
77
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
78
+ qkv_bias (bool): enable bias for qkv if True
79
+ proj_bias (bool): enable bias for proj in attn if True
80
+ ffn_bias (bool): enable bias for ffn if True
81
+ drop_path_rate (float): stochastic depth rate
82
+ drop_path_uniform (bool): apply uniform drop rate across blocks
83
+ weight_init (str): weight init scheme
84
+ init_values (float): layer-scale init values
85
+ embed_layer (nn.Module): patch embedding layer
86
+ act_layer (nn.Module): MLP activation layer
87
+ block_fn (nn.Module): transformer block class
88
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
89
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
90
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
91
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
92
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
93
+ """
94
+ super().__init__()
95
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
96
+
97
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
98
+ self.num_tokens = 1
99
+ self.n_blocks = depth
100
+ self.num_heads = num_heads
101
+ self.patch_size = patch_size
102
+ self.num_register_tokens = num_register_tokens
103
+ self.interpolate_antialias = interpolate_antialias
104
+ self.interpolate_offset = interpolate_offset
105
+
106
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
107
+ num_patches = self.patch_embed.num_patches
108
+
109
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
110
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
111
+ assert num_register_tokens >= 0
112
+ self.register_tokens = (
113
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
114
+ )
115
+
116
+ if drop_path_uniform is True:
117
+ dpr = [drop_path_rate] * depth
118
+ else:
119
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
120
+
121
+ if ffn_layer == "mlp":
122
+ logger.info("using MLP layer as FFN")
123
+ ffn_layer = Mlp
124
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
125
+ logger.info("using SwiGLU layer as FFN")
126
+ ffn_layer = SwiGLUFFNFused
127
+ elif ffn_layer == "identity":
128
+ logger.info("using Identity layer as FFN")
129
+
130
+ def f(*args, **kwargs):
131
+ return nn.Identity()
132
+
133
+ ffn_layer = f
134
+ else:
135
+ raise NotImplementedError
136
+
137
+ blocks_list = [
138
+ block_fn(
139
+ dim=embed_dim,
140
+ num_heads=num_heads,
141
+ mlp_ratio=mlp_ratio,
142
+ qkv_bias=qkv_bias,
143
+ proj_bias=proj_bias,
144
+ ffn_bias=ffn_bias,
145
+ drop_path=dpr[i],
146
+ norm_layer=norm_layer,
147
+ act_layer=act_layer,
148
+ ffn_layer=ffn_layer,
149
+ init_values=init_values,
150
+ )
151
+ for i in range(depth)
152
+ ]
153
+ if block_chunks > 0:
154
+ self.chunked_blocks = True
155
+ chunked_blocks = []
156
+ chunksize = depth // block_chunks
157
+ for i in range(0, depth, chunksize):
158
+ # this is to keep the block index consistent if we chunk the block list
159
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
160
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
161
+ else:
162
+ self.chunked_blocks = False
163
+ self.blocks = nn.ModuleList(blocks_list)
164
+
165
+ self.norm = norm_layer(embed_dim)
166
+ self.head = nn.Identity()
167
+
168
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
169
+
170
+ self.init_weights()
171
+
172
+ def init_weights(self):
173
+ trunc_normal_(self.pos_embed, std=0.02)
174
+ nn.init.normal_(self.cls_token, std=1e-6)
175
+ if self.register_tokens is not None:
176
+ nn.init.normal_(self.register_tokens, std=1e-6)
177
+ named_apply(init_weights_vit_timm, self)
178
+
179
+ def interpolate_pos_encoding(self, x, w, h):
180
+ previous_dtype = x.dtype
181
+ npatch = x.shape[1] - 1
182
+ N = self.pos_embed.shape[1] - 1
183
+ if npatch == N and w == h:
184
+ return self.pos_embed
185
+ pos_embed = self.pos_embed.float()
186
+ class_pos_embed = pos_embed[:, 0]
187
+ patch_pos_embed = pos_embed[:, 1:]
188
+ dim = x.shape[-1]
189
+ w0 = w // self.patch_size
190
+ h0 = h // self.patch_size
191
+ # we add a small number to avoid floating point error in the interpolation
192
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
193
+ w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
194
+
195
+ sqrt_N = math.sqrt(N)
196
+ sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
197
+ patch_pos_embed = nn.functional.interpolate(
198
+ patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
199
+ scale_factor=(sx, sy),
200
+ mode="bicubic",
201
+ antialias=self.interpolate_antialias,
202
+ )
203
+
204
+ assert int(w0) == patch_pos_embed.shape[-2]
205
+ assert int(h0) == patch_pos_embed.shape[-1]
206
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
207
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
208
+
209
+ def prepare_tokens_with_masks(self, x, masks=None):
210
+ B, nc, w, h = x.shape
211
+ x = self.patch_embed(x)
212
+ if masks is not None:
213
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
214
+
215
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
216
+ x = x + self.interpolate_pos_encoding(x, w, h)
217
+
218
+ if self.register_tokens is not None:
219
+ x = torch.cat(
220
+ (
221
+ x[:, :1],
222
+ self.register_tokens.expand(x.shape[0], -1, -1),
223
+ x[:, 1:],
224
+ ),
225
+ dim=1,
226
+ )
227
+
228
+ return x
229
+
230
+ def forward_features_list(self, x_list, masks_list):
231
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
232
+ for blk in self.blocks:
233
+ x = blk(x)
234
+
235
+ all_x = x
236
+ output = []
237
+ for x, masks in zip(all_x, masks_list):
238
+ x_norm = self.norm(x)
239
+ output.append(
240
+ {
241
+ "x_norm_clstoken": x_norm[:, 0],
242
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
243
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
244
+ "x_prenorm": x,
245
+ "masks": masks,
246
+ }
247
+ )
248
+ return output
249
+
250
+ def forward_features(self, x, masks=None):
251
+ fea_list = []
252
+ counter = 0
253
+ if isinstance(x, list):
254
+ return self.forward_features_list(x, masks)
255
+
256
+ x = self.prepare_tokens_with_masks(x, masks)
257
+ fea_list.append(x[:, self.num_register_tokens + 1 :].permute(0, 2, 1))
258
+
259
+ for blk in self.blocks:
260
+ x = blk(x)
261
+ counter += 1
262
+ if counter % 3 == 0:
263
+ fea_list.append(x[:, self.num_register_tokens + 1 :].permute(0, 2, 1))
264
+
265
+ x_norm = self.norm(x)
266
+ return fea_list, x_norm[:, 0]
267
+
268
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
269
+ x = self.prepare_tokens_with_masks(x)
270
+ # If n is an int, take the n last blocks. If it's a list, take them
271
+ output, total_block_len = [], len(self.blocks)
272
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
273
+ for i, blk in enumerate(self.blocks):
274
+ x = blk(x)
275
+ if i in blocks_to_take:
276
+ output.append(x)
277
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
278
+ return output
279
+
280
+ def _get_intermediate_layers_chunked(self, x, n=1):
281
+ x = self.prepare_tokens_with_masks(x)
282
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
283
+ # If n is an int, take the n last blocks. If it's a list, take them
284
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
285
+ for block_chunk in self.blocks:
286
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
287
+ x = blk(x)
288
+ if i in blocks_to_take:
289
+ output.append(x)
290
+ i += 1
291
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
292
+ return output
293
+
294
+ def get_intermediate_layers(
295
+ self,
296
+ x: torch.Tensor,
297
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
298
+ reshape: bool = False,
299
+ return_class_token: bool = False,
300
+ norm=True,
301
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
302
+ if self.chunked_blocks:
303
+ outputs = self._get_intermediate_layers_chunked(x, n)
304
+ else:
305
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
306
+ if norm:
307
+ outputs = [self.norm(out) for out in outputs]
308
+ class_tokens = [out[:, 0] for out in outputs]
309
+ outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
310
+ if reshape:
311
+ B, _, w, h = x.shape
312
+ outputs = [
313
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
314
+ for out in outputs
315
+ ]
316
+ if return_class_token:
317
+ return tuple(zip(outputs, class_tokens))
318
+ return tuple(outputs)
319
+
320
+ def forward(self, *args, is_training=False, **kwargs):
321
+ ret = self.forward_features(*args, **kwargs)
322
+ if is_training:
323
+ return ret
324
+ else:
325
+ return ret#self.head(ret["x_norm_clstoken"])
326
+
327
+
328
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
329
+ """ViT weight initialization, original timm impl (for reproducibility)"""
330
+ if isinstance(module, nn.Linear):
331
+ trunc_normal_(module.weight, std=0.02)
332
+ if module.bias is not None:
333
+ nn.init.zeros_(module.bias)
334
+
335
+
336
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
337
+ model = DinoVisionTransformer(
338
+ patch_size=patch_size,
339
+ embed_dim=384,
340
+ depth=12,
341
+ num_heads=6,
342
+ mlp_ratio=4,
343
+ block_fn=partial(Block, attn_class=MemEffAttention),
344
+ num_register_tokens=num_register_tokens,
345
+ **kwargs,
346
+ )
347
+ return model
348
+
349
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
350
+ model = DinoVisionTransformer(
351
+ patch_size=patch_size,
352
+ embed_dim=1024,
353
+ depth=24,
354
+ num_heads=16,
355
+ mlp_ratio=4,
356
+ block_fn=partial(Block, attn_class=MemEffAttention),
357
+ num_register_tokens=num_register_tokens,
358
+ **kwargs,
359
+ )
360
+ return model
361
+
362
+
363
+ # net = vit_small(patch_size=14, img_size=518, block_chunks=0, init_values=1.0)
364
+ # prefile = torch.load('../weights/dinov2_vits14_pretrain.pth')
365
+ # net.load_state_dict(prefile, True)
366
+ # out = net(torch.rand(1, 3, 518, 518))
367
+ # print(out.shape)
368
+
369
+ # net = vit_large(patch_size=14, img_size=526, block_chunks=0, init_values=1.0, num_register_tokens=4)
370
+ # prefile = torch.load('../weights/dinov2_vitl14_reg4_pretrain.pth')
371
+ # net.load_state_dict(prefile, True)
372
+ # out = net(torch.rand(1, 3, 70, 70))
373
+ # print(out.shape)
ADD/th_utils/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty
ADD/th_utils/custom_ops.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import glob
10
+ import hashlib
11
+ import importlib
12
+ import os
13
+ import re
14
+ import shutil
15
+ import uuid
16
+
17
+ import torch
18
+ import torch.utils.cpp_extension
19
+ from torch.utils.file_baton import FileBaton
20
+
21
+ #----------------------------------------------------------------------------
22
+ # Global options.
23
+
24
+ verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
25
+
26
+ #----------------------------------------------------------------------------
27
+ # Internal helper funcs.
28
+
29
+ def _find_compiler_bindir():
30
+ patterns = [
31
+ 'C:/Program Files*/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
32
+ 'C:/Program Files*/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
33
+ 'C:/Program Files*/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
34
+ 'C:/Program Files*/Microsoft Visual Studio */vc/bin',
35
+ ]
36
+ for pattern in patterns:
37
+ matches = sorted(glob.glob(pattern))
38
+ if len(matches):
39
+ return matches[-1]
40
+ return None
41
+
42
+ #----------------------------------------------------------------------------
43
+
44
+ def _get_mangled_gpu_name():
45
+ name = torch.cuda.get_device_name().lower()
46
+ out = []
47
+ for c in name:
48
+ if re.match('[a-z0-9_-]+', c):
49
+ out.append(c)
50
+ else:
51
+ out.append('-')
52
+ return ''.join(out)
53
+
54
+ #----------------------------------------------------------------------------
55
+ # Main entry point for compiling and loading C++/CUDA plugins.
56
+
57
+ _cached_plugins = dict()
58
+
59
+ def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
60
+ assert verbosity in ['none', 'brief', 'full']
61
+ if headers is None:
62
+ headers = []
63
+ if source_dir is not None:
64
+ sources = [os.path.join(source_dir, fname) for fname in sources]
65
+ headers = [os.path.join(source_dir, fname) for fname in headers]
66
+
67
+ # Already cached?
68
+ if module_name in _cached_plugins:
69
+ return _cached_plugins[module_name]
70
+
71
+ # Print status.
72
+ if verbosity == 'full':
73
+ print(f'Setting up PyTorch plugin "{module_name}"...')
74
+ elif verbosity == 'brief':
75
+ print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
76
+ verbose_build = (verbosity == 'full')
77
+
78
+ # Compile and load.
79
+ try: # pylint: disable=too-many-nested-blocks
80
+ # Make sure we can find the necessary compiler binaries.
81
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
82
+ compiler_bindir = _find_compiler_bindir()
83
+ if compiler_bindir is None:
84
+ raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
85
+ os.environ['PATH'] += ';' + compiler_bindir
86
+
87
+ # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
88
+ # break the build or unnecessarily restrict what's available to nvcc.
89
+ # Unset it to let nvcc decide based on what's available on the
90
+ # machine.
91
+ os.environ['TORCH_CUDA_ARCH_LIST'] = ''
92
+
93
+ # Incremental build md5sum trickery. Copies all the input source files
94
+ # into a cached build directory under a combined md5 digest of the input
95
+ # source files. Copying is done only if the combined digest has changed.
96
+ # This keeps input file timestamps and filenames the same as in previous
97
+ # extension builds, allowing for fast incremental rebuilds.
98
+ #
99
+ # This optimization is done only in case all the source files reside in
100
+ # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
101
+ # environment variable is set (we take this as a signal that the user
102
+ # actually cares about this.)
103
+ #
104
+ # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work
105
+ # around the *.cu dependency bug in ninja config.
106
+ #
107
+ all_source_files = sorted(sources + headers)
108
+ all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
109
+ if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
110
+
111
+ # Compute combined hash digest for all source files.
112
+ hash_md5 = hashlib.md5()
113
+ for src in all_source_files:
114
+ with open(src, 'rb') as f:
115
+ hash_md5.update(f.read())
116
+
117
+ # Select cached build directory name.
118
+ source_digest = hash_md5.hexdigest()
119
+ build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
120
+ cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
121
+
122
+ if not os.path.isdir(cached_build_dir):
123
+ tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
124
+ os.makedirs(tmpdir)
125
+ for src in all_source_files:
126
+ shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
127
+ try:
128
+ os.replace(tmpdir, cached_build_dir) # atomic
129
+ except OSError:
130
+ # source directory already exists, delete tmpdir and its contents.
131
+ shutil.rmtree(tmpdir)
132
+ if not os.path.isdir(cached_build_dir): raise
133
+
134
+ # Compile.
135
+ cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
136
+ torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
137
+ verbose=verbose_build, sources=cached_sources, **build_kwargs)
138
+ else:
139
+ torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
140
+
141
+ # Load.
142
+ module = importlib.import_module(module_name)
143
+
144
+ except:
145
+ if verbosity == 'brief':
146
+ print('Failed!')
147
+ raise
148
+
149
+ # Print status and add to cache dict.
150
+ if verbosity == 'full':
151
+ print(f'Done setting up PyTorch plugin "{module_name}".')
152
+ elif verbosity == 'brief':
153
+ print('Done.')
154
+ _cached_plugins[module_name] = module
155
+ return module
156
+
157
+ #----------------------------------------------------------------------------
ADD/th_utils/misc.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import re
10
+ import contextlib
11
+ import numpy as np
12
+ import torch
13
+ import warnings
14
+ import ADD.dnnlib as dnnlib
15
+
16
+ #----------------------------------------------------------------------------
17
+ # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
18
+ # same constant is used multiple times.
19
+
20
+ _constant_cache = dict()
21
+
22
+ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
23
+ value = np.asarray(value)
24
+ if shape is not None:
25
+ shape = tuple(shape)
26
+ if dtype is None:
27
+ dtype = torch.get_default_dtype()
28
+ if device is None:
29
+ device = torch.device('cpu')
30
+ if memory_format is None:
31
+ memory_format = torch.contiguous_format
32
+
33
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
34
+ tensor = _constant_cache.get(key, None)
35
+ if tensor is None:
36
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
37
+ if shape is not None:
38
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
39
+ tensor = tensor.contiguous(memory_format=memory_format)
40
+ _constant_cache[key] = tensor
41
+ return tensor
42
+
43
+ #----------------------------------------------------------------------------
44
+ # Replace NaN/Inf with specified numerical values.
45
+
46
+ try:
47
+ nan_to_num = torch.nan_to_num # 1.8.0a0
48
+ except AttributeError:
49
+ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
50
+ assert isinstance(input, torch.Tensor)
51
+ if posinf is None:
52
+ posinf = torch.finfo(input.dtype).max
53
+ if neginf is None:
54
+ neginf = torch.finfo(input.dtype).min
55
+ assert nan == 0
56
+ return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
57
+
58
+ #----------------------------------------------------------------------------
59
+ # Symbolic assert.
60
+
61
+ try:
62
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
63
+ except AttributeError:
64
+ symbolic_assert = torch.Assert # 1.7.0
65
+
66
+ #----------------------------------------------------------------------------
67
+ # Context manager to temporarily suppress known warnings in torch.jit.trace().
68
+ # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
69
+
70
+ @contextlib.contextmanager
71
+ def suppress_tracer_warnings():
72
+ flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
73
+ warnings.filters.insert(0, flt)
74
+ yield
75
+ warnings.filters.remove(flt)
76
+
77
+ #----------------------------------------------------------------------------
78
+ # Assert that the shape of a tensor matches the given list of integers.
79
+ # None indicates that the size of a dimension is allowed to vary.
80
+ # Performs symbolic assertion when used in torch.jit.trace().
81
+
82
+ def assert_shape(tensor, ref_shape):
83
+ if tensor.ndim != len(ref_shape):
84
+ raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
85
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
86
+ if ref_size is None:
87
+ pass
88
+ elif isinstance(ref_size, torch.Tensor):
89
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
90
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
91
+ elif isinstance(size, torch.Tensor):
92
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
93
+ symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
94
+ elif size != ref_size:
95
+ raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
96
+
97
+ #----------------------------------------------------------------------------
98
+ # Function decorator that calls torch.autograd.profiler.record_function().
99
+
100
+ def profiled_function(fn):
101
+ def decorator(*args, **kwargs):
102
+ with torch.autograd.profiler.record_function(fn.__name__):
103
+ return fn(*args, **kwargs)
104
+ decorator.__name__ = fn.__name__
105
+ return decorator
106
+
107
+ #----------------------------------------------------------------------------
108
+ # Sampler for torch.utils.data.DataLoader that loops over the dataset
109
+ # indefinitely, shuffling items as it goes.
110
+
111
+ class InfiniteSampler(torch.utils.data.Sampler):
112
+ def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
113
+ assert len(dataset) > 0
114
+ assert num_replicas > 0
115
+ assert 0 <= rank < num_replicas
116
+ assert 0 <= window_size <= 1
117
+ super().__init__(dataset)
118
+ self.dataset = dataset
119
+ self.rank = rank
120
+ self.num_replicas = num_replicas
121
+ self.shuffle = shuffle
122
+ self.seed = seed
123
+ self.window_size = window_size
124
+
125
+ def __iter__(self):
126
+ order = np.arange(len(self.dataset))
127
+ rnd = None
128
+ window = 0
129
+ if self.shuffle:
130
+ rnd = np.random.RandomState(self.seed)
131
+ rnd.shuffle(order)
132
+ window = int(np.rint(order.size * self.window_size))
133
+
134
+ idx = 0
135
+ while True:
136
+ i = idx % order.size
137
+ if idx % self.num_replicas == self.rank:
138
+ yield order[i]
139
+ if window >= 2:
140
+ j = (i - rnd.randint(window)) % order.size
141
+ order[i], order[j] = order[j], order[i]
142
+ idx += 1
143
+
144
+ #----------------------------------------------------------------------------
145
+ # Utilities for operating with torch.nn.Module parameters and buffers.
146
+ def spectral_to_cpu(model: torch.nn.Module):
147
+ def wrapped_in_spectral(m): return hasattr(m, 'weight_v')
148
+ children = get_children(model)
149
+ for child in children:
150
+ if wrapped_in_spectral(child):
151
+ child.weight = child.weight.cpu()
152
+ return model
153
+
154
+ def get_children(model: torch.nn.Module):
155
+ children = list(model.children())
156
+ flatt_children = []
157
+ if children == []:
158
+ return model
159
+ else:
160
+ for child in children:
161
+ try:
162
+ flatt_children.extend(get_children(child))
163
+ except TypeError:
164
+ flatt_children.append(get_children(child))
165
+ return flatt_children
166
+
167
+ def params_and_buffers(module):
168
+ assert isinstance(module, torch.nn.Module)
169
+ return list(module.parameters()) + list(module.buffers())
170
+
171
+ def named_params_and_buffers(module):
172
+ assert isinstance(module, torch.nn.Module)
173
+ return list(module.named_parameters()) + list(module.named_buffers())
174
+
175
+ def copy_params_and_buffers(src_module, dst_module, require_all=False):
176
+ assert isinstance(src_module, torch.nn.Module)
177
+ assert isinstance(dst_module, torch.nn.Module)
178
+ src_tensors = dict(named_params_and_buffers(src_module))
179
+ for name, tensor in named_params_and_buffers(dst_module):
180
+ assert (name in src_tensors) or (not require_all)
181
+ if name in src_tensors:
182
+ tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
183
+
184
+ #----------------------------------------------------------------------------
185
+ # Context manager for easily enabling/disabling DistributedDataParallel
186
+ # synchronization.
187
+
188
+ @contextlib.contextmanager
189
+ def ddp_sync(module, sync):
190
+ assert isinstance(module, torch.nn.Module)
191
+ if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
192
+ yield
193
+ else:
194
+ with module.no_sync():
195
+ yield
196
+
197
+ #----------------------------------------------------------------------------
198
+ # Check DistributedDataParallel consistency across processes.
199
+
200
+ def check_ddp_consistency(module, ignore_regex=None):
201
+ assert isinstance(module, torch.nn.Module)
202
+ for name, tensor in named_params_and_buffers(module):
203
+ fullname = type(module).__name__ + '.' + name
204
+ if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
205
+ continue
206
+ tensor = tensor.detach()
207
+ if tensor.is_floating_point():
208
+ tensor = nan_to_num(tensor)
209
+ other = tensor.clone()
210
+ torch.distributed.broadcast(tensor=other, src=0)
211
+ assert (tensor == other).all(), fullname
212
+
213
+ #----------------------------------------------------------------------------
214
+ # Print summary table of module hierarchy.
215
+
216
+ def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
217
+ assert isinstance(module, torch.nn.Module)
218
+ assert not isinstance(module, torch.jit.ScriptModule)
219
+ assert isinstance(inputs, (tuple, list))
220
+
221
+ # Register hooks.
222
+ entries = []
223
+ nesting = [0]
224
+ def pre_hook(_mod, _inputs):
225
+ nesting[0] += 1
226
+ def post_hook(mod, _inputs, outputs):
227
+ nesting[0] -= 1
228
+ if nesting[0] <= max_nesting:
229
+ outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
230
+ outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
231
+ entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
232
+ hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
233
+ hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
234
+
235
+ # Run module.
236
+ outputs = module(*inputs)
237
+ for hook in hooks:
238
+ hook.remove()
239
+
240
+ # Identify unique outputs, parameters, and buffers.
241
+ tensors_seen = set()
242
+ for e in entries:
243
+ e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
244
+ e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
245
+ e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
246
+ tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
247
+
248
+ # Filter out redundant entries.
249
+ if skip_redundant:
250
+ entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
251
+
252
+ # Construct table.
253
+ rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
254
+ rows += [['---'] * len(rows[0])]
255
+ param_total = 0
256
+ buffer_total = 0
257
+ submodule_names = {mod: name for name, mod in module.named_modules()}
258
+ for e in entries:
259
+ name = '<top-level>' if e.mod is module else submodule_names[e.mod]
260
+ param_size = sum(t.numel() for t in e.unique_params)
261
+ buffer_size = sum(t.numel() for t in e.unique_buffers)
262
+ output_shapes = [str(list(t.shape)) for t in e.outputs]
263
+ output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
264
+ rows += [[
265
+ name + (':0' if len(e.outputs) >= 2 else ''),
266
+ str(param_size) if param_size else '-',
267
+ str(buffer_size) if buffer_size else '-',
268
+ (output_shapes + ['-'])[0],
269
+ (output_dtypes + ['-'])[0],
270
+ ]]
271
+ for idx in range(1, len(e.outputs)):
272
+ rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
273
+ param_total += param_size
274
+ buffer_total += buffer_size
275
+ rows += [['---'] * len(rows[0])]
276
+ rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
277
+
278
+ # Print table.
279
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
280
+ print()
281
+ for row in rows:
282
+ print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
283
+ print()
284
+ return outputs
ADD/th_utils/ops/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty
ADD/th_utils/ops/bias_act.cpp ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "bias_act.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static bool has_same_layout(torch::Tensor x, torch::Tensor y)
17
+ {
18
+ if (x.dim() != y.dim())
19
+ return false;
20
+ for (int64_t i = 0; i < x.dim(); i++)
21
+ {
22
+ if (x.size(i) != y.size(i))
23
+ return false;
24
+ if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
25
+ return false;
26
+ }
27
+ return true;
28
+ }
29
+
30
+ //------------------------------------------------------------------------
31
+
32
+ static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
33
+ {
34
+ // Validate arguments.
35
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
36
+ TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
37
+ TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
38
+ TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
39
+ TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
40
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
41
+ TORCH_CHECK(b.dim() == 1, "b must have rank 1");
42
+ TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
43
+ TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
44
+ TORCH_CHECK(grad >= 0, "grad must be non-negative");
45
+
46
+ // Validate layout.
47
+ TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
48
+ TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
49
+ TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
50
+ TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
51
+ TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
52
+
53
+ // Create output tensor.
54
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
55
+ torch::Tensor y = torch::empty_like(x);
56
+ TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
57
+
58
+ // Initialize CUDA kernel parameters.
59
+ bias_act_kernel_params p;
60
+ p.x = x.data_ptr();
61
+ p.b = (b.numel()) ? b.data_ptr() : NULL;
62
+ p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
63
+ p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
64
+ p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
65
+ p.y = y.data_ptr();
66
+ p.grad = grad;
67
+ p.act = act;
68
+ p.alpha = alpha;
69
+ p.gain = gain;
70
+ p.clamp = clamp;
71
+ p.sizeX = (int)x.numel();
72
+ p.sizeB = (int)b.numel();
73
+ p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
74
+
75
+ // Choose CUDA kernel.
76
+ void* kernel;
77
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
78
+ {
79
+ kernel = choose_bias_act_kernel<scalar_t>(p);
80
+ });
81
+ TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
82
+
83
+ // Launch CUDA kernel.
84
+ p.loopX = 4;
85
+ int blockSize = 4 * 32;
86
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
87
+ void* args[] = {&p};
88
+ AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
89
+ return y;
90
+ }
91
+
92
+ //------------------------------------------------------------------------
93
+
94
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
95
+ {
96
+ m.def("bias_act", &bias_act);
97
+ }
98
+
99
+ //------------------------------------------------------------------------
ADD/th_utils/ops/bias_act.cu ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <c10/util/Half.h>
10
+ #include "bias_act.h"
11
+
12
+ //------------------------------------------------------------------------
13
+ // Helpers.
14
+
15
+ template <class T> struct InternalType;
16
+ template <> struct InternalType<double> { typedef double scalar_t; };
17
+ template <> struct InternalType<float> { typedef float scalar_t; };
18
+ template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19
+
20
+ //------------------------------------------------------------------------
21
+ // CUDA kernel.
22
+
23
+ template <class T, int A>
24
+ __global__ void bias_act_kernel(bias_act_kernel_params p)
25
+ {
26
+ typedef typename InternalType<T>::scalar_t scalar_t;
27
+ int G = p.grad;
28
+ scalar_t alpha = (scalar_t)p.alpha;
29
+ scalar_t gain = (scalar_t)p.gain;
30
+ scalar_t clamp = (scalar_t)p.clamp;
31
+ scalar_t one = (scalar_t)1;
32
+ scalar_t two = (scalar_t)2;
33
+ scalar_t expRange = (scalar_t)80;
34
+ scalar_t halfExpRange = (scalar_t)40;
35
+ scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
36
+ scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
37
+
38
+ // Loop over elements.
39
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
40
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
41
+ {
42
+ // Load.
43
+ scalar_t x = (scalar_t)((const T*)p.x)[xi];
44
+ scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
45
+ scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
46
+ scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
47
+ scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
48
+ scalar_t yy = (gain != 0) ? yref / gain : 0;
49
+ scalar_t y = 0;
50
+
51
+ // Apply bias.
52
+ ((G == 0) ? x : xref) += b;
53
+
54
+ // linear
55
+ if (A == 1)
56
+ {
57
+ if (G == 0) y = x;
58
+ if (G == 1) y = x;
59
+ }
60
+
61
+ // relu
62
+ if (A == 2)
63
+ {
64
+ if (G == 0) y = (x > 0) ? x : 0;
65
+ if (G == 1) y = (yy > 0) ? x : 0;
66
+ }
67
+
68
+ // lrelu
69
+ if (A == 3)
70
+ {
71
+ if (G == 0) y = (x > 0) ? x : x * alpha;
72
+ if (G == 1) y = (yy > 0) ? x : x * alpha;
73
+ }
74
+
75
+ // tanh
76
+ if (A == 4)
77
+ {
78
+ if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
79
+ if (G == 1) y = x * (one - yy * yy);
80
+ if (G == 2) y = x * (one - yy * yy) * (-two * yy);
81
+ }
82
+
83
+ // sigmoid
84
+ if (A == 5)
85
+ {
86
+ if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
87
+ if (G == 1) y = x * yy * (one - yy);
88
+ if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
89
+ }
90
+
91
+ // elu
92
+ if (A == 6)
93
+ {
94
+ if (G == 0) y = (x >= 0) ? x : exp(x) - one;
95
+ if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
96
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
97
+ }
98
+
99
+ // selu
100
+ if (A == 7)
101
+ {
102
+ if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
103
+ if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
104
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
105
+ }
106
+
107
+ // softplus
108
+ if (A == 8)
109
+ {
110
+ if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
111
+ if (G == 1) y = x * (one - exp(-yy));
112
+ if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
113
+ }
114
+
115
+ // swish
116
+ if (A == 9)
117
+ {
118
+ if (G == 0)
119
+ y = (x < -expRange) ? 0 : x / (exp(-x) + one);
120
+ else
121
+ {
122
+ scalar_t c = exp(xref);
123
+ scalar_t d = c + one;
124
+ if (G == 1)
125
+ y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
126
+ else
127
+ y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
128
+ yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
129
+ }
130
+ }
131
+
132
+ // Apply gain.
133
+ y *= gain * dy;
134
+
135
+ // Clamp.
136
+ if (clamp >= 0)
137
+ {
138
+ if (G == 0)
139
+ y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
140
+ else
141
+ y = (yref > -clamp & yref < clamp) ? y : 0;
142
+ }
143
+
144
+ // Store.
145
+ ((T*)p.y)[xi] = (T)y;
146
+ }
147
+ }
148
+
149
+ //------------------------------------------------------------------------
150
+ // CUDA kernel selection.
151
+
152
+ template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
153
+ {
154
+ if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
155
+ if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
156
+ if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
157
+ if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
158
+ if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
159
+ if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
160
+ if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
161
+ if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
162
+ if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
163
+ return NULL;
164
+ }
165
+
166
+ //------------------------------------------------------------------------
167
+ // Template specializations.
168
+
169
+ template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
170
+ template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
171
+ template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
172
+
173
+ //------------------------------------------------------------------------
ADD/th_utils/ops/bias_act.h ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ //------------------------------------------------------------------------
10
+ // CUDA kernel parameters.
11
+
12
+ struct bias_act_kernel_params
13
+ {
14
+ const void* x; // [sizeX]
15
+ const void* b; // [sizeB] or NULL
16
+ const void* xref; // [sizeX] or NULL
17
+ const void* yref; // [sizeX] or NULL
18
+ const void* dy; // [sizeX] or NULL
19
+ void* y; // [sizeX]
20
+
21
+ int grad;
22
+ int act;
23
+ float alpha;
24
+ float gain;
25
+ float clamp;
26
+
27
+ int sizeX;
28
+ int sizeB;
29
+ int stepB;
30
+ int loopX;
31
+ };
32
+
33
+ //------------------------------------------------------------------------
34
+ // CUDA kernel selection.
35
+
36
+ template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37
+
38
+ //------------------------------------------------------------------------
ADD/th_utils/ops/bias_act.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom PyTorch ops for efficient bias and activation."""
10
+
11
+ import os
12
+ import numpy as np
13
+ import torch
14
+ import ADD.dnnlib as dnnlib
15
+
16
+ from .. import custom_ops
17
+ from .. import misc
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ activation_funcs = {
22
+ 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
23
+ 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
24
+ 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
25
+ 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
26
+ 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
27
+ 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
28
+ 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
29
+ 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
30
+ 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
31
+ }
32
+
33
+ #----------------------------------------------------------------------------
34
+
35
+ _plugin = None
36
+ _null_tensor = torch.empty([0])
37
+
38
+ def _init():
39
+ global _plugin
40
+ if _plugin is None:
41
+ _plugin = custom_ops.get_plugin(
42
+ module_name='bias_act_plugin',
43
+ sources=['bias_act.cpp', 'bias_act.cu'],
44
+ headers=['bias_act.h'],
45
+ source_dir=os.path.dirname(__file__),
46
+ extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'],
47
+ )
48
+ return True
49
+
50
+ #----------------------------------------------------------------------------
51
+
52
+ def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
53
+ r"""Fused bias and activation function.
54
+
55
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
56
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
57
+ the fused op is considerably more efficient than performing the same calculation
58
+ using standard PyTorch ops. It supports first and second order gradients,
59
+ but not third order gradients.
60
+
61
+ Args:
62
+ x: Input activation tensor. Can be of any shape.
63
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
64
+ as `x`. The shape must be known, and it must match the dimension of `x`
65
+ corresponding to `dim`.
66
+ dim: The dimension in `x` corresponding to the elements of `b`.
67
+ The value of `dim` is ignored if `b` is not specified.
68
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
69
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
70
+ See `activation_funcs` for a full list. `None` is not allowed.
71
+ alpha: Shape parameter for the activation function, or `None` to use the default.
72
+ gain: Scaling factor for the output tensor, or `None` to use default.
73
+ See `activation_funcs` for the default scaling of each activation function.
74
+ If unsure, consider specifying 1.
75
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
76
+ the clamping (default).
77
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
78
+
79
+ Returns:
80
+ Tensor of the same shape and datatype as `x`.
81
+ """
82
+ assert isinstance(x, torch.Tensor)
83
+ assert impl in ['ref', 'cuda']
84
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
85
+ return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
86
+ return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
87
+
88
+ #----------------------------------------------------------------------------
89
+
90
+ @misc.profiled_function
91
+ def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
92
+ """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
93
+ """
94
+ assert isinstance(x, torch.Tensor)
95
+ assert clamp is None or clamp >= 0
96
+ spec = activation_funcs[act]
97
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
98
+ gain = float(gain if gain is not None else spec.def_gain)
99
+ clamp = float(clamp if clamp is not None else -1)
100
+
101
+ # Add bias.
102
+ if b is not None:
103
+ assert isinstance(b, torch.Tensor) and b.ndim == 1
104
+ assert 0 <= dim < x.ndim
105
+ assert b.shape[0] == x.shape[dim]
106
+ x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
107
+
108
+ # Evaluate activation function.
109
+ alpha = float(alpha)
110
+ x = spec.func(x, alpha=alpha)
111
+
112
+ # Scale by gain.
113
+ gain = float(gain)
114
+ if gain != 1:
115
+ x = x * gain
116
+
117
+ # Clamp.
118
+ if clamp >= 0:
119
+ x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
120
+ return x
121
+
122
+ #----------------------------------------------------------------------------
123
+
124
+ _bias_act_cuda_cache = dict()
125
+
126
+ def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
127
+ """Fast CUDA implementation of `bias_act()` using custom ops.
128
+ """
129
+ # Parse arguments.
130
+ assert clamp is None or clamp >= 0
131
+ spec = activation_funcs[act]
132
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
133
+ gain = float(gain if gain is not None else spec.def_gain)
134
+ clamp = float(clamp if clamp is not None else -1)
135
+
136
+ # Lookup from cache.
137
+ key = (dim, act, alpha, gain, clamp)
138
+ if key in _bias_act_cuda_cache:
139
+ return _bias_act_cuda_cache[key]
140
+
141
+ # Forward op.
142
+ class BiasActCuda(torch.autograd.Function):
143
+ @staticmethod
144
+ def forward(ctx, x, b): # pylint: disable=arguments-differ
145
+ ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format
146
+ x = x.contiguous(memory_format=ctx.memory_format)
147
+ b = b.contiguous() if b is not None else _null_tensor
148
+ y = x
149
+ if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
150
+ y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
151
+ ctx.save_for_backward(
152
+ x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
153
+ b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
154
+ y if 'y' in spec.ref else _null_tensor)
155
+ return y
156
+
157
+ @staticmethod
158
+ def backward(ctx, dy): # pylint: disable=arguments-differ
159
+ dy = dy.contiguous(memory_format=ctx.memory_format)
160
+ x, b, y = ctx.saved_tensors
161
+ dx = None
162
+ db = None
163
+
164
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
165
+ dx = dy
166
+ if act != 'linear' or gain != 1 or clamp >= 0:
167
+ dx = BiasActCudaGrad.apply(dy, x, b, y)
168
+
169
+ if ctx.needs_input_grad[1]:
170
+ db = dx.sum([i for i in range(dx.ndim) if i != dim])
171
+
172
+ return dx, db
173
+
174
+ # Backward op.
175
+ class BiasActCudaGrad(torch.autograd.Function):
176
+ @staticmethod
177
+ def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
178
+ ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format
179
+ dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
180
+ ctx.save_for_backward(
181
+ dy if spec.has_2nd_grad else _null_tensor,
182
+ x, b, y)
183
+ return dx
184
+
185
+ @staticmethod
186
+ def backward(ctx, d_dx): # pylint: disable=arguments-differ
187
+ d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
188
+ dy, x, b, y = ctx.saved_tensors
189
+ d_dy = None
190
+ d_x = None
191
+ d_b = None
192
+ d_y = None
193
+
194
+ if ctx.needs_input_grad[0]:
195
+ d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
196
+
197
+ if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
198
+ d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
199
+
200
+ if spec.has_2nd_grad and ctx.needs_input_grad[2]:
201
+ d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
202
+
203
+ return d_dy, d_x, d_b, d_y
204
+
205
+ # Add to cache.
206
+ _bias_act_cuda_cache[key] = BiasActCuda
207
+ return BiasActCuda
208
+
209
+ #----------------------------------------------------------------------------
ADD/th_utils/ops/conv2d_gradfix.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom replacement for `torch.nn.functional.conv2d` that supports
10
+ arbitrarily high order gradients with zero performance penalty."""
11
+
12
+ import contextlib
13
+ import torch
14
+ from pkg_resources import parse_version
15
+
16
+ # pylint: disable=redefined-builtin
17
+ # pylint: disable=arguments-differ
18
+ # pylint: disable=protected-access
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
+ enabled = False # Enable the custom op by setting this to true.
23
+ weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
24
+ _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
25
+
26
+ @contextlib.contextmanager
27
+ def no_weight_gradients(disable=True):
28
+ global weight_gradients_disabled
29
+ old = weight_gradients_disabled
30
+ if disable:
31
+ weight_gradients_disabled = True
32
+ yield
33
+ weight_gradients_disabled = old
34
+
35
+ #----------------------------------------------------------------------------
36
+
37
+ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
38
+ if _should_use_custom_op(input):
39
+ return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
40
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
41
+
42
+ def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
43
+ if _should_use_custom_op(input):
44
+ return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
45
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
46
+
47
+ #----------------------------------------------------------------------------
48
+
49
+ def _should_use_custom_op(input):
50
+ assert isinstance(input, torch.Tensor)
51
+ if (not enabled) or (not torch.backends.cudnn.enabled):
52
+ return False
53
+ if _use_pytorch_1_11_api:
54
+ # The work-around code doesn't work on PyTorch 1.11.0 onwards
55
+ return False
56
+ if input.device.type != 'cuda':
57
+ return False
58
+ return True
59
+
60
+ def _tuple_of_ints(xs, ndim):
61
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
62
+ assert len(xs) == ndim
63
+ assert all(isinstance(x, int) for x in xs)
64
+ return xs
65
+
66
+ #----------------------------------------------------------------------------
67
+
68
+ _conv2d_gradfix_cache = dict()
69
+ _null_tensor = torch.empty([0])
70
+
71
+ def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
72
+ # Parse arguments.
73
+ ndim = 2
74
+ weight_shape = tuple(weight_shape)
75
+ stride = _tuple_of_ints(stride, ndim)
76
+ padding = _tuple_of_ints(padding, ndim)
77
+ output_padding = _tuple_of_ints(output_padding, ndim)
78
+ dilation = _tuple_of_ints(dilation, ndim)
79
+
80
+ # Lookup from cache.
81
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
82
+ if key in _conv2d_gradfix_cache:
83
+ return _conv2d_gradfix_cache[key]
84
+
85
+ # Validate arguments.
86
+ assert groups >= 1
87
+ assert len(weight_shape) == ndim + 2
88
+ assert all(stride[i] >= 1 for i in range(ndim))
89
+ assert all(padding[i] >= 0 for i in range(ndim))
90
+ assert all(dilation[i] >= 0 for i in range(ndim))
91
+ if not transpose:
92
+ assert all(output_padding[i] == 0 for i in range(ndim))
93
+ else: # transpose
94
+ assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
95
+
96
+ # Helpers.
97
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
98
+ def calc_output_padding(input_shape, output_shape):
99
+ if transpose:
100
+ return [0, 0]
101
+ return [
102
+ input_shape[i + 2]
103
+ - (output_shape[i + 2] - 1) * stride[i]
104
+ - (1 - 2 * padding[i])
105
+ - dilation[i] * (weight_shape[i + 2] - 1)
106
+ for i in range(ndim)
107
+ ]
108
+
109
+ # Forward & backward.
110
+ class Conv2d(torch.autograd.Function):
111
+ @staticmethod
112
+ def forward(ctx, input, weight, bias):
113
+ assert weight.shape == weight_shape
114
+ ctx.save_for_backward(
115
+ input if weight.requires_grad else _null_tensor,
116
+ weight if input.requires_grad else _null_tensor,
117
+ )
118
+ ctx.input_shape = input.shape
119
+
120
+ # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
121
+ if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0):
122
+ a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1])
123
+ b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1)
124
+ c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2)
125
+ c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1)
126
+ c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
127
+ return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
128
+
129
+ # General case => cuDNN.
130
+ if transpose:
131
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
132
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
133
+
134
+ @staticmethod
135
+ def backward(ctx, grad_output):
136
+ input, weight = ctx.saved_tensors
137
+ input_shape = ctx.input_shape
138
+ grad_input = None
139
+ grad_weight = None
140
+ grad_bias = None
141
+
142
+ if ctx.needs_input_grad[0]:
143
+ p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape)
144
+ op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
145
+ grad_input = op.apply(grad_output, weight, None)
146
+ assert grad_input.shape == input_shape
147
+
148
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
149
+ grad_weight = Conv2dGradWeight.apply(grad_output, input)
150
+ assert grad_weight.shape == weight_shape
151
+
152
+ if ctx.needs_input_grad[2]:
153
+ grad_bias = grad_output.sum([0, 2, 3])
154
+
155
+ return grad_input, grad_weight, grad_bias
156
+
157
+ # Gradient with respect to the weights.
158
+ class Conv2dGradWeight(torch.autograd.Function):
159
+ @staticmethod
160
+ def forward(ctx, grad_output, input):
161
+ ctx.save_for_backward(
162
+ grad_output if input.requires_grad else _null_tensor,
163
+ input if grad_output.requires_grad else _null_tensor,
164
+ )
165
+ ctx.grad_output_shape = grad_output.shape
166
+ ctx.input_shape = input.shape
167
+
168
+ # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
169
+ if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0):
170
+ a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
171
+ b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
172
+ c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape)
173
+ return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
174
+
175
+ # General case => cuDNN.
176
+ name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight'
177
+ flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
178
+ return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
179
+
180
+ @staticmethod
181
+ def backward(ctx, grad2_grad_weight):
182
+ grad_output, input = ctx.saved_tensors
183
+ grad_output_shape = ctx.grad_output_shape
184
+ input_shape = ctx.input_shape
185
+ grad2_grad_output = None
186
+ grad2_input = None
187
+
188
+ if ctx.needs_input_grad[0]:
189
+ grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
190
+ assert grad2_grad_output.shape == grad_output_shape
191
+
192
+ if ctx.needs_input_grad[1]:
193
+ p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape)
194
+ op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
195
+ grad2_input = op.apply(grad_output, grad2_grad_weight, None)
196
+ assert grad2_input.shape == input_shape
197
+
198
+ return grad2_grad_output, grad2_input
199
+
200
+ _conv2d_gradfix_cache[key] = Conv2d
201
+ return Conv2d
202
+
203
+ #----------------------------------------------------------------------------
ADD/th_utils/ops/conv2d_resample.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """2D convolution with optional up/downsampling."""
10
+
11
+ import torch
12
+
13
+ from .. import misc
14
+ from . import conv2d_gradfix
15
+ from . import upfirdn2d
16
+ from .upfirdn2d import _parse_padding
17
+ from .upfirdn2d import _get_filter_size
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ def _get_weight_shape(w):
22
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
23
+ shape = [int(sz) for sz in w.shape]
24
+ misc.assert_shape(w, shape)
25
+ return shape
26
+
27
+ #----------------------------------------------------------------------------
28
+
29
+ def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
30
+ """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
31
+ """
32
+ _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w)
33
+
34
+ # Flip weight if requested.
35
+ # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
36
+ if not flip_weight and (kw > 1 or kh > 1):
37
+ w = w.flip([2, 3])
38
+
39
+ # Execute using conv2d_gradfix.
40
+ op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
41
+ return op(x, w, stride=stride, padding=padding, groups=groups)
42
+
43
+ #----------------------------------------------------------------------------
44
+
45
+ @misc.profiled_function
46
+ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
47
+ r"""2D convolution with optional up/downsampling.
48
+
49
+ Padding is performed only once at the beginning, not between the operations.
50
+
51
+ Args:
52
+ x: Input tensor of shape
53
+ `[batch_size, in_channels, in_height, in_width]`.
54
+ w: Weight tensor of shape
55
+ `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
56
+ f: Low-pass filter for up/downsampling. Must be prepared beforehand by
57
+ calling upfirdn2d.setup_filter(). None = identity (default).
58
+ up: Integer upsampling factor (default: 1).
59
+ down: Integer downsampling factor (default: 1).
60
+ padding: Padding with respect to the upsampled image. Can be a single number
61
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
62
+ (default: 0).
63
+ groups: Split input channels into N groups (default: 1).
64
+ flip_weight: False = convolution, True = correlation (default: True).
65
+ flip_filter: False = convolution, True = correlation (default: False).
66
+
67
+ Returns:
68
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
69
+ """
70
+ # Validate arguments.
71
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4)
72
+ assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
73
+ assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
74
+ assert isinstance(up, int) and (up >= 1)
75
+ assert isinstance(down, int) and (down >= 1)
76
+ assert isinstance(groups, int) and (groups >= 1)
77
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
78
+ fw, fh = _get_filter_size(f)
79
+ px0, px1, py0, py1 = _parse_padding(padding)
80
+
81
+ # Adjust padding to account for up/downsampling.
82
+ if up > 1:
83
+ px0 += (fw + up - 1) // 2
84
+ px1 += (fw - up) // 2
85
+ py0 += (fh + up - 1) // 2
86
+ py1 += (fh - up) // 2
87
+ if down > 1:
88
+ px0 += (fw - down + 1) // 2
89
+ px1 += (fw - down) // 2
90
+ py0 += (fh - down + 1) // 2
91
+ py1 += (fh - down) // 2
92
+
93
+ # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
94
+ if kw == 1 and kh == 1 and (down > 1 and up == 1):
95
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
96
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
97
+ return x
98
+
99
+ # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
100
+ if kw == 1 and kh == 1 and (up > 1 and down == 1):
101
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
102
+ x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
103
+ return x
104
+
105
+ # Fast path: downsampling only => use strided convolution.
106
+ if down > 1 and up == 1:
107
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
108
+ x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
109
+ return x
110
+
111
+ # Fast path: upsampling with optional downsampling => use transpose strided convolution.
112
+ if up > 1:
113
+ if groups == 1:
114
+ w = w.transpose(0, 1)
115
+ else:
116
+ w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
117
+ w = w.transpose(1, 2)
118
+ w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
119
+ px0 -= kw - 1
120
+ px1 -= kw - up
121
+ py0 -= kh - 1
122
+ py1 -= kh - up
123
+ pxt = max(min(-px0, -px1), 0)
124
+ pyt = max(min(-py0, -py1), 0)
125
+ x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
126
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
127
+ if down > 1:
128
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
129
+ return x
130
+
131
+ # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
132
+ if up == 1 and down == 1:
133
+ if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
134
+ return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
135
+
136
+ # Fallback: Generic reference implementation.
137
+ x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
138
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
139
+ if down > 1:
140
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
141
+ return x
142
+
143
+ #----------------------------------------------------------------------------
ADD/th_utils/ops/filtered_lrelu.cpp ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "filtered_lrelu.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
17
+ torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
18
+ int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
19
+ {
20
+ // Set CUDA device.
21
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
22
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
23
+
24
+ // Validate arguments.
25
+ TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
26
+ TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
27
+ TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
28
+ TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
29
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
30
+ TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
31
+ TORCH_CHECK(x.numel() > 0, "x is empty");
32
+ TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
33
+ TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
34
+ TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
35
+ TORCH_CHECK(fu.numel() > 0, "fu is empty");
36
+ TORCH_CHECK(fd.numel() > 0, "fd is empty");
37
+ TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
38
+ TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");
39
+
40
+ // Figure out how much shared memory is available on the device.
41
+ int maxSharedBytes = 0;
42
+ AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
43
+ int sharedKB = maxSharedBytes >> 10;
44
+
45
+ // Populate enough launch parameters to check if a CUDA kernel exists.
46
+ filtered_lrelu_kernel_params p;
47
+ p.up = up;
48
+ p.down = down;
49
+ p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
50
+ p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
51
+ filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel<float, int32_t, false, false>(p, sharedKB);
52
+ if (!test_spec.exec)
53
+ {
54
+ // No kernel found - return empty tensors and indicate missing kernel with return code of -1.
55
+ return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
56
+ }
57
+
58
+ // Input/output element size.
59
+ int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;
60
+
61
+ // Input sizes.
62
+ int64_t xw = (int)x.size(3);
63
+ int64_t xh = (int)x.size(2);
64
+ int64_t fut_w = (int)fu.size(-1) - 1;
65
+ int64_t fut_h = (int)fu.size(0) - 1;
66
+ int64_t fdt_w = (int)fd.size(-1) - 1;
67
+ int64_t fdt_h = (int)fd.size(0) - 1;
68
+
69
+ // Logical size of upsampled buffer.
70
+ int64_t cw = xw * up + (px0 + px1) - fut_w;
71
+ int64_t ch = xh * up + (py0 + py1) - fut_h;
72
+ TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter");
73
+ TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");
74
+
75
+ // Compute output size and allocate.
76
+ int64_t yw = (cw - fdt_w + (down - 1)) / down;
77
+ int64_t yh = (ch - fdt_h + (down - 1)) / down;
78
+ TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
79
+ TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
80
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());
81
+
82
+ // Allocate sign tensor.
83
+ torch::Tensor so;
84
+ torch::Tensor s = si;
85
+ bool readSigns = !!s.numel();
86
+ int64_t sw_active = 0; // Active width of sign tensor.
87
+ if (writeSigns)
88
+ {
89
+ sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements.
90
+ int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
91
+ int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16.
92
+ TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
93
+ s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
94
+ }
95
+ else if (readSigns)
96
+ sw_active = s.size(3) << 2;
97
+
98
+ // Validate sign tensor if in use.
99
+ if (readSigns || writeSigns)
100
+ {
101
+ TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
102
+ TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
103
+ TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
104
+ TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
105
+ TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
106
+ TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large");
107
+ }
108
+
109
+ // Populate rest of CUDA kernel parameters.
110
+ p.x = x.data_ptr();
111
+ p.y = y.data_ptr();
112
+ p.b = b.data_ptr();
113
+ p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
114
+ p.fu = fu.data_ptr<float>();
115
+ p.fd = fd.data_ptr<float>();
116
+ p.pad0 = make_int2(px0, py0);
117
+ p.gain = gain;
118
+ p.slope = slope;
119
+ p.clamp = clamp;
120
+ p.flip = (flip_filters) ? 1 : 0;
121
+ p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
122
+ p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
123
+ p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
124
+ p.sOfs = make_int2(sx, sy);
125
+ p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.
126
+
127
+ // x, y, b strides are in bytes.
128
+ p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));
129
+ p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));
130
+ p.bStride = sz * b.stride(0);
131
+
132
+ // fu, fd strides are in elements.
133
+ p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
134
+ p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);
135
+
136
+ // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
137
+ bool index64b = false;
138
+ if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
139
+ if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
140
+ if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true;
141
+ if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
142
+ if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true;
143
+ if (s.numel() > INT_MAX) index64b = true;
144
+
145
+ // Choose CUDA kernel.
146
+ filtered_lrelu_kernel_spec spec = { 0 };
147
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
148
+ {
149
+ if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
150
+ {
151
+ // Choose kernel based on index type, datatype and sign read/write modes.
152
+ if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, true, false>(p, sharedKB);
153
+ else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, true >(p, sharedKB);
154
+ else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, false>(p, sharedKB);
155
+ else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, true, false>(p, sharedKB);
156
+ else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, true >(p, sharedKB);
157
+ else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, false>(p, sharedKB);
158
+ }
159
+ });
160
+ TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found") // This should not happen because we tested earlier that kernel exists.
161
+
162
+ // Launch CUDA kernel.
163
+ void* args[] = {&p};
164
+ int bx = spec.numWarps * 32;
165
+ int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
166
+ int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
167
+ int gz = p.yShape.z * p.yShape.w;
168
+
169
+ // Repeat multiple horizontal tiles in a CTA?
170
+ if (spec.xrep)
171
+ {
172
+ p.tilesXrep = spec.xrep;
173
+ p.tilesXdim = gx;
174
+
175
+ gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
176
+ std::swap(gx, gy);
177
+ }
178
+ else
179
+ {
180
+ p.tilesXrep = 0;
181
+ p.tilesXdim = 0;
182
+ }
183
+
184
+ // Launch filter setup kernel.
185
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));
186
+
187
+ // Copy kernels to constant memory.
188
+ if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<true, false>(at::cuda::getCurrentCUDAStream())));
189
+ else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters<false, true >(at::cuda::getCurrentCUDAStream())));
190
+ else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<false, false>(at::cuda::getCurrentCUDAStream())));
191
+
192
+ // Set cache and shared memory configurations for main kernel.
193
+ AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
194
+ if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
195
+ AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
196
+ AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));
197
+
198
+ // Launch main kernel.
199
+ const int maxSubGz = 65535; // CUDA maximum for block z dimension.
200
+ for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
201
+ {
202
+ p.blockZofs = zofs;
203
+ int subGz = std::min(maxSubGz, gz - zofs);
204
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));
205
+ }
206
+
207
+ // Done.
208
+ return std::make_tuple(y, so, 0);
209
+ }
210
+
211
+ //------------------------------------------------------------------------
212
+
213
+ static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns)
214
+ {
215
+ // Set CUDA device.
216
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
217
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
218
+
219
+ // Validate arguments.
220
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
221
+ TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
222
+ TORCH_CHECK(x.numel() > 0, "x is empty");
223
+ TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64");
224
+
225
+ // Output signs if we don't have sign input.
226
+ torch::Tensor so;
227
+ torch::Tensor s = si;
228
+ bool readSigns = !!s.numel();
229
+ if (writeSigns)
230
+ {
231
+ int64_t sw = x.size(3);
232
+ sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.
233
+ s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
234
+ }
235
+
236
+ // Validate sign tensor if in use.
237
+ if (readSigns || writeSigns)
238
+ {
239
+ TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
240
+ TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
241
+ TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
242
+ TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
243
+ TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
244
+ TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large");
245
+ }
246
+
247
+ // Initialize CUDA kernel parameters.
248
+ filtered_lrelu_act_kernel_params p;
249
+ p.x = x.data_ptr();
250
+ p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
251
+ p.gain = gain;
252
+ p.slope = slope;
253
+ p.clamp = clamp;
254
+ p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
255
+ p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
256
+ p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous.
257
+ p.sOfs = make_int2(sx, sy);
258
+
259
+ // Choose CUDA kernel.
260
+ void* func = 0;
261
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&]
262
+ {
263
+ if (writeSigns)
264
+ func = choose_filtered_lrelu_act_kernel<scalar_t, true, false>();
265
+ else if (readSigns)
266
+ func = choose_filtered_lrelu_act_kernel<scalar_t, false, true>();
267
+ else
268
+ func = choose_filtered_lrelu_act_kernel<scalar_t, false, false>();
269
+ });
270
+ TORCH_CHECK(func, "internal error - CUDA kernel not found");
271
+
272
+ // Launch CUDA kernel.
273
+ void* args[] = {&p};
274
+ int bx = 128; // 4 warps per block.
275
+
276
+ // Logical size of launch = writeSigns ? p.s : p.x
277
+ uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;
278
+ uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;
279
+ uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.
280
+ gx = (gx - 1) / bx + 1;
281
+
282
+ // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest.
283
+ const uint32_t gmax = 65535;
284
+ gy = std::min(gy, gmax);
285
+ gz = std::min(gz, gmax);
286
+
287
+ // Launch.
288
+ AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream()));
289
+ return so;
290
+ }
291
+
292
+ //------------------------------------------------------------------------
293
+
294
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
295
+ {
296
+ m.def("filtered_lrelu", &filtered_lrelu); // The whole thing.
297
+ m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place.
298
+ }
299
+
300
+ //------------------------------------------------------------------------
ADD/th_utils/ops/filtered_lrelu.cu ADDED
@@ -0,0 +1,1284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <c10/util/Half.h>
10
+ #include "filtered_lrelu.h"
11
+ #include <cstdint>
12
+
13
+ //------------------------------------------------------------------------
14
+ // Helpers.
15
+
16
+ enum // Filter modes.
17
+ {
18
+ MODE_SUSD = 0, // Separable upsampling, separable downsampling.
19
+ MODE_FUSD = 1, // Full upsampling, separable downsampling.
20
+ MODE_SUFD = 2, // Separable upsampling, full downsampling.
21
+ MODE_FUFD = 3, // Full upsampling, full downsampling.
22
+ };
23
+
24
+ template <class T> struct InternalType;
25
+ template <> struct InternalType<double>
26
+ {
27
+ typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t;
28
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); }
29
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); }
30
+ __device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); }
31
+ };
32
+ template <> struct InternalType<float>
33
+ {
34
+ typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
35
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
36
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
37
+ __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
38
+ };
39
+ template <> struct InternalType<c10::Half>
40
+ {
41
+ typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
42
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
43
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
44
+ __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
45
+ };
46
+
47
+ #define MIN(A, B) ((A) < (B) ? (A) : (B))
48
+ #define MAX(A, B) ((A) > (B) ? (A) : (B))
49
+ #define CEIL_DIV(A, B) (((B)==1) ? (A) : \
50
+ ((B)==2) ? ((int)((A)+1) >> 1) : \
51
+ ((B)==4) ? ((int)((A)+3) >> 2) : \
52
+ (((A) + ((A) > 0 ? (B) - 1 : 0)) / (B)))
53
+
54
+ // This works only up to blocks of size 256 x 256 and for all N that are powers of two.
55
+ template <int N> __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i)
56
+ {
57
+ if ((N & (N-1)) && N <= 256)
58
+ y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256.
59
+ else
60
+ y = i/N;
61
+
62
+ x = i - y*N;
63
+ }
64
+
65
+ // Type cast stride before reading it.
66
+ template <class T> __device__ __forceinline__ T get_stride(const int64_t& x)
67
+ {
68
+ return *reinterpret_cast<const T*>(&x);
69
+ }
70
+
71
+ //------------------------------------------------------------------------
72
+ // Filters, setup kernel, copying function.
73
+
74
+ #define MAX_FILTER_SIZE 32
75
+
76
+ // Combined up/down filter buffers so that transfer can be done with one copy.
77
+ __device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel.
78
+ __device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel.
79
+
80
+ // Accessors to combined buffers to index up/down filters individually.
81
+ #define c_fu (c_fbuf)
82
+ #define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
83
+ #define g_fu (g_fbuf)
84
+ #define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
85
+
86
+ // Set up filters into global memory buffer.
87
+ static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p)
88
+ {
89
+ for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x)
90
+ {
91
+ int x, y;
92
+ fast_div_mod<MAX_FILTER_SIZE>(x, y, idx);
93
+
94
+ int fu_x = p.flip ? x : (p.fuShape.x - 1 - x);
95
+ int fu_y = p.flip ? y : (p.fuShape.y - 1 - y);
96
+ if (p.fuShape.y > 0)
97
+ g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y];
98
+ else
99
+ g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x];
100
+
101
+ int fd_x = p.flip ? x : (p.fdShape.x - 1 - x);
102
+ int fd_y = p.flip ? y : (p.fdShape.y - 1 - y);
103
+ if (p.fdShape.y > 0)
104
+ g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y];
105
+ else
106
+ g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x];
107
+ }
108
+ }
109
+
110
+ // Host function to copy filters written by setup kernel into constant buffer for main kernel.
111
+ template <bool, bool> static cudaError_t copy_filters(cudaStream_t stream)
112
+ {
113
+ void* src = 0;
114
+ cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf);
115
+ if (err) return err;
116
+ return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream);
117
+ }
118
+
119
+ //------------------------------------------------------------------------
120
+ // Coordinate spaces:
121
+ // - Relative to input tensor: inX, inY, tileInX, tileInY
122
+ // - Relative to input tile: relInX, relInY, tileInW, tileInH
123
+ // - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH
124
+ // - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH
125
+ // - Relative to output tensor: outX, outY, tileOutX, tileOutY
126
+ //
127
+ // Relationships between coordinate spaces:
128
+ // - inX = tileInX + relInX
129
+ // - inY = tileInY + relInY
130
+ // - relUpX = relInX * up + phaseInX
131
+ // - relUpY = relInY * up + phaseInY
132
+ // - relUpX = relOutX * down
133
+ // - relUpY = relOutY * down
134
+ // - outX = tileOutX + relOutX
135
+ // - outY = tileOutY + relOutY
136
+
137
+ extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer.
138
+
139
+ template <class T, class index_t, int sharedKB, bool signWrite, bool signRead, int filterMode, int up, int fuSize, int down, int fdSize, int tileOutW, int tileOutH, int threadsPerBlock, bool enableXrep, bool enableWriteSkip>
140
+ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p)
141
+ {
142
+ // Check that we don't try to support non-existing filter modes.
143
+ static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported");
144
+ static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported");
145
+ static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor");
146
+ static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor");
147
+ static_assert(fuSize % up == 0, "upsampling filter size must be divisible with upsampling factor");
148
+ static_assert(fdSize % down == 0, "downsampling filter size must be divisible with downsampling factor");
149
+ static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE");
150
+ static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters");
151
+ static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters");
152
+ static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4");
153
+ static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4");
154
+
155
+ // Static definitions.
156
+ typedef typename InternalType<T>::scalar_t scalar_t;
157
+ typedef typename InternalType<T>::vec2_t vec2_t;
158
+ typedef typename InternalType<T>::vec4_t vec4_t;
159
+ const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4.
160
+ const int tileUpH = tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height.
161
+ const int tileInW = CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width.
162
+ const int tileInH = CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height.
163
+ const int tileUpH_up = CEIL_DIV(tileUpH, up) * up; // Upsampled tile height rounded up to a multiple of up.
164
+ const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up); // For allocations only, to avoid shared memory read overruns with up=2 and up=4.
165
+
166
+ // Merge 1x1 downsampling into last upsampling step for upf1 and ups2.
167
+ const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD));
168
+
169
+ // Sizes of logical buffers.
170
+ const int szIn = tileInH_up * tileInW;
171
+ const int szUpX = tileInH_up * tileUpW;
172
+ const int szUpXY = downInline ? 0 : (tileUpH * tileUpW);
173
+ const int szDownX = tileUpH * tileOutW;
174
+
175
+ // Sizes for shared memory arrays.
176
+ const int s_buf0_size_base =
177
+ (filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) :
178
+ (filterMode == MODE_FUSD) ? MAX(szIn, szDownX) :
179
+ (filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) :
180
+ (filterMode == MODE_FUFD) ? szIn :
181
+ -1;
182
+ const int s_buf1_size_base =
183
+ (filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) :
184
+ (filterMode == MODE_FUSD) ? szUpXY :
185
+ (filterMode == MODE_SUFD) ? szUpX :
186
+ (filterMode == MODE_FUFD) ? szUpXY :
187
+ -1;
188
+
189
+ // Ensure U128 alignment.
190
+ const int s_buf0_size = (s_buf0_size_base + 3) & ~3;
191
+ const int s_buf1_size = (s_buf1_size_base + 3) & ~3;
192
+
193
+ // Check at compile time that we don't use too much shared memory.
194
+ static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow");
195
+
196
+ // Declare shared memory arrays.
197
+ scalar_t* s_buf0;
198
+ scalar_t* s_buf1;
199
+ if (sharedKB <= 48)
200
+ {
201
+ // Allocate shared memory arrays here.
202
+ __shared__ scalar_t s_buf0_st[(sharedKB > 48) ? (1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused.
203
+ s_buf0 = s_buf0_st;
204
+ s_buf1 = s_buf0 + s_buf0_size;
205
+ }
206
+ else
207
+ {
208
+ // Use the dynamically allocated shared memory array.
209
+ s_buf0 = (scalar_t*)s_buf_raw;
210
+ s_buf1 = s_buf0 + s_buf0_size;
211
+ }
212
+
213
+ // Pointers to the buffers.
214
+ scalar_t* s_tileIn; // Input tile: [relInX * tileInH + relInY]
215
+ scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX]
216
+ scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX]
217
+ scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX]
218
+ if (filterMode == MODE_SUSD)
219
+ {
220
+ s_tileIn = s_buf0;
221
+ s_tileUpX = s_buf1;
222
+ s_tileUpXY = s_buf0;
223
+ s_tileDownX = s_buf1;
224
+ }
225
+ else if (filterMode == MODE_FUSD)
226
+ {
227
+ s_tileIn = s_buf0;
228
+ s_tileUpXY = s_buf1;
229
+ s_tileDownX = s_buf0;
230
+ }
231
+ else if (filterMode == MODE_SUFD)
232
+ {
233
+ s_tileIn = s_buf0;
234
+ s_tileUpX = s_buf1;
235
+ s_tileUpXY = s_buf0;
236
+ }
237
+ else if (filterMode == MODE_FUFD)
238
+ {
239
+ s_tileIn = s_buf0;
240
+ s_tileUpXY = s_buf1;
241
+ }
242
+
243
+ // Allow large grids in z direction via per-launch offset.
244
+ int channelIdx = blockIdx.z + p.blockZofs;
245
+ int batchIdx = channelIdx / p.yShape.z;
246
+ channelIdx -= batchIdx * p.yShape.z;
247
+
248
+ // Offset to output feature map. In bytes.
249
+ index_t mapOfsOut = channelIdx * get_stride<index_t>(p.yStride.z) + batchIdx * get_stride<index_t>(p.yStride.w);
250
+
251
+ // Sign shift amount.
252
+ uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6;
253
+
254
+ // Inner tile loop.
255
+ #pragma unroll 1
256
+ for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++)
257
+ {
258
+ // Locate output tile.
259
+ int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x;
260
+ int tileOutX = tileX * tileOutW;
261
+ int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH;
262
+
263
+ // Locate input tile.
264
+ int tmpX = tileOutX * down - p.pad0.x;
265
+ int tmpY = tileOutY * down - p.pad0.y;
266
+ int tileInX = CEIL_DIV(tmpX, up);
267
+ int tileInY = CEIL_DIV(tmpY, up);
268
+ const int phaseInX = tileInX * up - tmpX;
269
+ const int phaseInY = tileInY * up - tmpY;
270
+
271
+ // Extra sync if input and output buffers are the same and we are not on first tile.
272
+ if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline)))
273
+ __syncthreads();
274
+
275
+ // Load input tile & apply bias. Unrolled.
276
+ scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride<index_t>(p.bStride)));
277
+ index_t mapOfsIn = channelIdx * get_stride<index_t>(p.xStride.z) + batchIdx * get_stride<index_t>(p.xStride.w);
278
+ int idx = threadIdx.x;
279
+ const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock);
280
+ #pragma unroll
281
+ for (int loop = 0; loop < loopCountIN; loop++)
282
+ {
283
+ int relInX, relInY;
284
+ fast_div_mod<tileInW>(relInX, relInY, idx);
285
+ int inX = tileInX + relInX;
286
+ int inY = tileInY + relInY;
287
+ scalar_t v = 0;
288
+
289
+ if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y)
290
+ v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride<index_t>(p.xStride.x) + inY * get_stride<index_t>(p.xStride.y) + mapOfsIn))) + b;
291
+
292
+ bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH);
293
+ if (!skip)
294
+ s_tileIn[idx] = v;
295
+
296
+ idx += threadsPerBlock;
297
+ }
298
+
299
+ if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter.
300
+ {
301
+ // Horizontal upsampling.
302
+ __syncthreads();
303
+ if (up == 4)
304
+ {
305
+ for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
306
+ {
307
+ int relUpX0, relInY;
308
+ fast_div_mod<tileUpW>(relUpX0, relInY, idx);
309
+ int relInX0 = relUpX0 / up;
310
+ int src0 = relInX0 + tileInW * relInY;
311
+ int dst = relInY * tileUpW + relUpX0;
312
+ vec4_t v = InternalType<T>::zero_vec4();
313
+ scalar_t a = s_tileIn[src0];
314
+ if (phaseInX == 0)
315
+ {
316
+ #pragma unroll
317
+ for (int step = 0; step < fuSize / up; step++)
318
+ {
319
+ v.x += a * (scalar_t)c_fu[step * up + 0];
320
+ a = s_tileIn[src0 + step + 1];
321
+ v.y += a * (scalar_t)c_fu[step * up + 3];
322
+ v.z += a * (scalar_t)c_fu[step * up + 2];
323
+ v.w += a * (scalar_t)c_fu[step * up + 1];
324
+ }
325
+ }
326
+ else if (phaseInX == 1)
327
+ {
328
+ #pragma unroll
329
+ for (int step = 0; step < fuSize / up; step++)
330
+ {
331
+ v.x += a * (scalar_t)c_fu[step * up + 1];
332
+ v.y += a * (scalar_t)c_fu[step * up + 0];
333
+ a = s_tileIn[src0 + step + 1];
334
+ v.z += a * (scalar_t)c_fu[step * up + 3];
335
+ v.w += a * (scalar_t)c_fu[step * up + 2];
336
+ }
337
+ }
338
+ else if (phaseInX == 2)
339
+ {
340
+ #pragma unroll
341
+ for (int step = 0; step < fuSize / up; step++)
342
+ {
343
+ v.x += a * (scalar_t)c_fu[step * up + 2];
344
+ v.y += a * (scalar_t)c_fu[step * up + 1];
345
+ v.z += a * (scalar_t)c_fu[step * up + 0];
346
+ a = s_tileIn[src0 + step + 1];
347
+ v.w += a * (scalar_t)c_fu[step * up + 3];
348
+ }
349
+ }
350
+ else // (phaseInX == 3)
351
+ {
352
+ #pragma unroll
353
+ for (int step = 0; step < fuSize / up; step++)
354
+ {
355
+ v.x += a * (scalar_t)c_fu[step * up + 3];
356
+ v.y += a * (scalar_t)c_fu[step * up + 2];
357
+ v.z += a * (scalar_t)c_fu[step * up + 1];
358
+ v.w += a * (scalar_t)c_fu[step * up + 0];
359
+ a = s_tileIn[src0 + step + 1];
360
+ }
361
+ }
362
+ s_tileUpX[dst+0] = v.x;
363
+ s_tileUpX[dst+1] = v.y;
364
+ s_tileUpX[dst+2] = v.z;
365
+ s_tileUpX[dst+3] = v.w;
366
+ }
367
+ }
368
+ else if (up == 2)
369
+ {
370
+ bool p0 = (phaseInX == 0);
371
+ for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
372
+ {
373
+ int relUpX0, relInY;
374
+ fast_div_mod<tileUpW>(relUpX0, relInY, idx);
375
+ int relInX0 = relUpX0 / up;
376
+ int src0 = relInX0 + tileInW * relInY;
377
+ int dst = relInY * tileUpW + relUpX0;
378
+ vec2_t v = InternalType<T>::zero_vec2();
379
+ scalar_t a = s_tileIn[src0];
380
+ if (p0) // (phaseInX == 0)
381
+ {
382
+ #pragma unroll
383
+ for (int step = 0; step < fuSize / up; step++)
384
+ {
385
+ v.x += a * (scalar_t)c_fu[step * up + 0];
386
+ a = s_tileIn[src0 + step + 1];
387
+ v.y += a * (scalar_t)c_fu[step * up + 1];
388
+ }
389
+ }
390
+ else // (phaseInX == 1)
391
+ {
392
+ #pragma unroll
393
+ for (int step = 0; step < fuSize / up; step++)
394
+ {
395
+ v.x += a * (scalar_t)c_fu[step * up + 1];
396
+ v.y += a * (scalar_t)c_fu[step * up + 0];
397
+ a = s_tileIn[src0 + step + 1];
398
+ }
399
+ }
400
+ s_tileUpX[dst+0] = v.x;
401
+ s_tileUpX[dst+1] = v.y;
402
+ }
403
+ }
404
+
405
+ // Vertical upsampling & nonlinearity.
406
+
407
+ __syncthreads();
408
+ int groupMask = 15 << ((threadIdx.x & 31) & ~3);
409
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
410
+ int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes.
411
+ if (up == 4)
412
+ {
413
+ minY -= 3; // Adjust according to block height.
414
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
415
+ {
416
+ int relUpX, relInY0;
417
+ fast_div_mod<tileUpW>(relUpX, relInY0, idx);
418
+ int relUpY0 = relInY0 * up;
419
+ int src0 = relInY0 * tileUpW + relUpX;
420
+ int dst = relUpY0 * tileUpW + relUpX;
421
+ vec4_t v = InternalType<T>::zero_vec4();
422
+
423
+ scalar_t a = s_tileUpX[src0];
424
+ if (phaseInY == 0)
425
+ {
426
+ #pragma unroll
427
+ for (int step = 0; step < fuSize / up; step++)
428
+ {
429
+ v.x += a * (scalar_t)c_fu[step * up + 0];
430
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
431
+ v.y += a * (scalar_t)c_fu[step * up + 3];
432
+ v.z += a * (scalar_t)c_fu[step * up + 2];
433
+ v.w += a * (scalar_t)c_fu[step * up + 1];
434
+ }
435
+ }
436
+ else if (phaseInY == 1)
437
+ {
438
+ #pragma unroll
439
+ for (int step = 0; step < fuSize / up; step++)
440
+ {
441
+ v.x += a * (scalar_t)c_fu[step * up + 1];
442
+ v.y += a * (scalar_t)c_fu[step * up + 0];
443
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
444
+ v.z += a * (scalar_t)c_fu[step * up + 3];
445
+ v.w += a * (scalar_t)c_fu[step * up + 2];
446
+ }
447
+ }
448
+ else if (phaseInY == 2)
449
+ {
450
+ #pragma unroll
451
+ for (int step = 0; step < fuSize / up; step++)
452
+ {
453
+ v.x += a * (scalar_t)c_fu[step * up + 2];
454
+ v.y += a * (scalar_t)c_fu[step * up + 1];
455
+ v.z += a * (scalar_t)c_fu[step * up + 0];
456
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
457
+ v.w += a * (scalar_t)c_fu[step * up + 3];
458
+ }
459
+ }
460
+ else // (phaseInY == 3)
461
+ {
462
+ #pragma unroll
463
+ for (int step = 0; step < fuSize / up; step++)
464
+ {
465
+ v.x += a * (scalar_t)c_fu[step * up + 3];
466
+ v.y += a * (scalar_t)c_fu[step * up + 2];
467
+ v.z += a * (scalar_t)c_fu[step * up + 1];
468
+ v.w += a * (scalar_t)c_fu[step * up + 0];
469
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
470
+ }
471
+ }
472
+
473
+ int x = tileOutX * down + relUpX;
474
+ int y = tileOutY * down + relUpY0;
475
+ int signX = x + p.sOfs.x;
476
+ int signY = y + p.sOfs.y;
477
+ int signZ = blockIdx.z + p.blockZofs;
478
+ int signXb = signX >> 2;
479
+ index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
480
+ index_t si1 = si0 + p.sShape.x;
481
+ index_t si2 = si0 + p.sShape.x * 2;
482
+ index_t si3 = si0 + p.sShape.x * 3;
483
+
484
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
485
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
486
+ v.z *= (scalar_t)((float)up * (float)up * p.gain);
487
+ v.w *= (scalar_t)((float)up * (float)up * p.gain);
488
+
489
+ if (signWrite)
490
+ {
491
+ if (!enableWriteSkip)
492
+ {
493
+ // Determine and write signs.
494
+ int sx = __float_as_uint(v.x) >> 31 << 0;
495
+ int sy = __float_as_uint(v.y) >> 31 << 8;
496
+ int sz = __float_as_uint(v.z) >> 31 << 16;
497
+ int sw = __float_as_uint(v.w) >> 31 << 24;
498
+ if (sx) v.x *= p.slope;
499
+ if (sy) v.y *= p.slope;
500
+ if (sz) v.z *= p.slope;
501
+ if (sw) v.w *= p.slope;
502
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
503
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
504
+ if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }
505
+ if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }
506
+
507
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
508
+ {
509
+ // Combine signs.
510
+ uint32_t s = sx + sy + sw + sz;
511
+ s <<= (signX & 3) << 1;
512
+ s |= __shfl_xor_sync(groupMask, s, 1);
513
+ s |= __shfl_xor_sync(groupMask, s, 2);
514
+
515
+ // Write signs.
516
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
517
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
518
+ if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
519
+ if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
520
+ }
521
+ }
522
+ else
523
+ {
524
+ // Determine and write signs.
525
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
526
+ {
527
+ int sx = __float_as_uint(v.x) >> 31 << 0;
528
+ int sy = __float_as_uint(v.y) >> 31 << 8;
529
+ int sz = __float_as_uint(v.z) >> 31 << 16;
530
+ int sw = __float_as_uint(v.w) >> 31 << 24;
531
+ if (sx) v.x *= p.slope;
532
+ if (sy) v.y *= p.slope;
533
+ if (sz) v.z *= p.slope;
534
+ if (sw) v.w *= p.slope;
535
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
536
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
537
+ if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }
538
+ if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }
539
+
540
+ // Combine signs.
541
+ uint32_t s = sx + sy + sw + sz;
542
+ s <<= (signX & 3) << 1;
543
+ s |= __shfl_xor_sync(groupMask, s, 1);
544
+ s |= __shfl_xor_sync(groupMask, s, 2);
545
+
546
+ // Write signs.
547
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
548
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
549
+ if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
550
+ if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
551
+ }
552
+ else
553
+ {
554
+ // Just compute the values.
555
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
556
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
557
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
558
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
559
+ }
560
+ }
561
+ }
562
+ else if (signRead) // Read signs and apply.
563
+ {
564
+ if ((uint32_t)signXb < p.swLimit)
565
+ {
566
+ int ss = (signX & 3) << 1;
567
+ if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
568
+ if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
569
+ if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; }
570
+ if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; }
571
+ }
572
+ }
573
+ else // Forward pass with no sign write.
574
+ {
575
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
576
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
577
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
578
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
579
+ }
580
+
581
+ s_tileUpXY[dst + 0 * tileUpW] = v.x;
582
+ if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y;
583
+ if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z;
584
+ if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w;
585
+ }
586
+ }
587
+ else if (up == 2)
588
+ {
589
+ minY -= 1; // Adjust according to block height.
590
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
591
+ {
592
+ int relUpX, relInY0;
593
+ fast_div_mod<tileUpW>(relUpX, relInY0, idx);
594
+ int relUpY0 = relInY0 * up;
595
+ int src0 = relInY0 * tileUpW + relUpX;
596
+ int dst = relUpY0 * tileUpW + relUpX;
597
+ vec2_t v = InternalType<T>::zero_vec2();
598
+
599
+ scalar_t a = s_tileUpX[src0];
600
+ if (phaseInY == 0)
601
+ {
602
+ #pragma unroll
603
+ for (int step = 0; step < fuSize / up; step++)
604
+ {
605
+ v.x += a * (scalar_t)c_fu[step * up + 0];
606
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
607
+ v.y += a * (scalar_t)c_fu[step * up + 1];
608
+ }
609
+ }
610
+ else // (phaseInY == 1)
611
+ {
612
+ #pragma unroll
613
+ for (int step = 0; step < fuSize / up; step++)
614
+ {
615
+ v.x += a * (scalar_t)c_fu[step * up + 1];
616
+ v.y += a * (scalar_t)c_fu[step * up + 0];
617
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
618
+ }
619
+ }
620
+
621
+ int x = tileOutX * down + relUpX;
622
+ int y = tileOutY * down + relUpY0;
623
+ int signX = x + p.sOfs.x;
624
+ int signY = y + p.sOfs.y;
625
+ int signZ = blockIdx.z + p.blockZofs;
626
+ int signXb = signX >> 2;
627
+ index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
628
+ index_t si1 = si0 + p.sShape.x;
629
+
630
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
631
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
632
+
633
+ if (signWrite)
634
+ {
635
+ if (!enableWriteSkip)
636
+ {
637
+ // Determine and write signs.
638
+ int sx = __float_as_uint(v.x) >> 31 << 0;
639
+ int sy = __float_as_uint(v.y) >> 31 << 8;
640
+ if (sx) v.x *= p.slope;
641
+ if (sy) v.y *= p.slope;
642
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
643
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
644
+
645
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
646
+ {
647
+ // Combine signs.
648
+ int s = sx + sy;
649
+ s <<= signXo;
650
+ s |= __shfl_xor_sync(groupMask, s, 1);
651
+ s |= __shfl_xor_sync(groupMask, s, 2);
652
+
653
+ // Write signs.
654
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
655
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
656
+ }
657
+ }
658
+ else
659
+ {
660
+ // Determine and write signs.
661
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
662
+ {
663
+ int sx = __float_as_uint(v.x) >> 31 << 0;
664
+ int sy = __float_as_uint(v.y) >> 31 << 8;
665
+ if (sx) v.x *= p.slope;
666
+ if (sy) v.y *= p.slope;
667
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
668
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
669
+
670
+ // Combine signs.
671
+ int s = sx + sy;
672
+ s <<= signXo;
673
+ s |= __shfl_xor_sync(groupMask, s, 1);
674
+ s |= __shfl_xor_sync(groupMask, s, 2);
675
+
676
+ // Write signs.
677
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
678
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
679
+ }
680
+ else
681
+ {
682
+ // Just compute the values.
683
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
684
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
685
+ }
686
+ }
687
+ }
688
+ else if (signRead) // Read signs and apply.
689
+ {
690
+ if ((uint32_t)signXb < p.swLimit)
691
+ {
692
+ if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
693
+ if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
694
+ }
695
+ }
696
+ else // Forward pass with no sign write.
697
+ {
698
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
699
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
700
+ }
701
+
702
+ if (!downInline)
703
+ {
704
+ // Write into temporary buffer.
705
+ s_tileUpXY[dst] = v.x;
706
+ if (relUpY0 < tileUpH - 1)
707
+ s_tileUpXY[dst + tileUpW] = v.y;
708
+ }
709
+ else
710
+ {
711
+ // Write directly into output buffer.
712
+ if ((uint32_t)x < p.yShape.x)
713
+ {
714
+ int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down);
715
+ index_t ofs = x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut;
716
+ if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]);
717
+ if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]);
718
+ }
719
+ }
720
+ }
721
+ }
722
+ }
723
+ else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD)
724
+ {
725
+ // Full upsampling filter.
726
+
727
+ if (up == 2)
728
+ {
729
+ // 2 x 2-wide.
730
+ __syncthreads();
731
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs.
732
+ for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4)
733
+ {
734
+ int relUpX0, relUpY0;
735
+ fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
736
+ int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up);
737
+ int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up);
738
+ int src0 = relInX0 + tileInW * relInY0;
739
+ int tap0y = (relInY0 * up + phaseInY - relUpY0);
740
+
741
+ #define X_LOOP(TAPY, PX) \
742
+ for (int sx = 0; sx < fuSize / up; sx++) \
743
+ { \
744
+ v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
745
+ v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
746
+ v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
747
+ v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
748
+ }
749
+
750
+ vec4_t v = InternalType<T>::zero_vec4();
751
+ if (tap0y == 0 && phaseInX == 0)
752
+ #pragma unroll
753
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
754
+ #pragma unroll
755
+ X_LOOP(0, 0) }
756
+ if (tap0y == 0 && phaseInX == 1)
757
+ #pragma unroll
758
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
759
+ #pragma unroll
760
+ X_LOOP(0, 1) }
761
+ if (tap0y == 1 && phaseInX == 0)
762
+ #pragma unroll
763
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
764
+ #pragma unroll
765
+ X_LOOP(1, 0) }
766
+ if (tap0y == 1 && phaseInX == 1)
767
+ #pragma unroll
768
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
769
+ #pragma unroll
770
+ X_LOOP(1, 1) }
771
+
772
+ #undef X_LOOP
773
+
774
+ int x = tileOutX * down + relUpX0;
775
+ int y = tileOutY * down + relUpY0;
776
+ int signX = x + p.sOfs.x;
777
+ int signY = y + p.sOfs.y;
778
+ int signZ = blockIdx.z + p.blockZofs;
779
+ int signXb = signX >> 2;
780
+ index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
781
+
782
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
783
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
784
+ v.z *= (scalar_t)((float)up * (float)up * p.gain);
785
+ v.w *= (scalar_t)((float)up * (float)up * p.gain);
786
+
787
+ if (signWrite)
788
+ {
789
+ if (!enableWriteSkip)
790
+ {
791
+ // Determine and write signs.
792
+ int sx = __float_as_uint(v.x) >> 31;
793
+ int sy = __float_as_uint(v.y) >> 31;
794
+ int sz = __float_as_uint(v.z) >> 31;
795
+ int sw = __float_as_uint(v.w) >> 31;
796
+ if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }
797
+ if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }
798
+ if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }
799
+ if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }
800
+
801
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
802
+ {
803
+ p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
804
+ }
805
+ }
806
+ else
807
+ {
808
+ // Determine and write signs.
809
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
810
+ {
811
+ int sx = __float_as_uint(v.x) >> 31;
812
+ int sy = __float_as_uint(v.y) >> 31;
813
+ int sz = __float_as_uint(v.z) >> 31;
814
+ int sw = __float_as_uint(v.w) >> 31;
815
+ if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }
816
+ if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }
817
+ if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }
818
+ if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }
819
+
820
+ p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
821
+ }
822
+ else
823
+ {
824
+ // Just compute the values.
825
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
826
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
827
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
828
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
829
+ }
830
+ }
831
+ }
832
+ else if (signRead) // Read sign and apply.
833
+ {
834
+ if ((uint32_t)signY < p.sShape.y)
835
+ {
836
+ int s = 0;
837
+ if ((uint32_t)signXb < p.swLimit) s = p.s[si];
838
+ if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8;
839
+ s >>= (signX & 3) << 1;
840
+ if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f;
841
+ if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f;
842
+ if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f;
843
+ if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f;
844
+ }
845
+ }
846
+ else // Forward pass with no sign write.
847
+ {
848
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
849
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
850
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
851
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
852
+ }
853
+
854
+ s_tileUpXY[idx + 0] = v.x;
855
+ s_tileUpXY[idx + 1] = v.y;
856
+ s_tileUpXY[idx + 2] = v.z;
857
+ s_tileUpXY[idx + 3] = v.w;
858
+ }
859
+ }
860
+ else if (up == 1)
861
+ {
862
+ __syncthreads();
863
+ uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3);
864
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
865
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x)
866
+ {
867
+ int relUpX0, relUpY0;
868
+ fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
869
+ scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter.
870
+
871
+ int x = tileOutX * down + relUpX0;
872
+ int y = tileOutY * down + relUpY0;
873
+ int signX = x + p.sOfs.x;
874
+ int signY = y + p.sOfs.y;
875
+ int signZ = blockIdx.z + p.blockZofs;
876
+ int signXb = signX >> 2;
877
+ index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
878
+ v *= (scalar_t)((float)up * (float)up * p.gain);
879
+
880
+ if (signWrite)
881
+ {
882
+ if (!enableWriteSkip)
883
+ {
884
+ // Determine and write sign.
885
+ uint32_t s = 0;
886
+ uint32_t signXbit = (1u << signXo);
887
+ if (v < 0.f)
888
+ {
889
+ s = signXbit;
890
+ v *= p.slope;
891
+ }
892
+ if (fabsf(v) > p.clamp)
893
+ {
894
+ s = signXbit * 2;
895
+ v = InternalType<T>::clamp(v, p.clamp);
896
+ }
897
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
898
+ {
899
+ s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
900
+ s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
901
+ p.s[si] = s; // Write.
902
+ }
903
+ }
904
+ else
905
+ {
906
+ // Determine and write sign.
907
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
908
+ {
909
+ uint32_t s = 0;
910
+ uint32_t signXbit = (1u << signXo);
911
+ if (v < 0.f)
912
+ {
913
+ s = signXbit;
914
+ v *= p.slope;
915
+ }
916
+ if (fabsf(v) > p.clamp)
917
+ {
918
+ s = signXbit * 2;
919
+ v = InternalType<T>::clamp(v, p.clamp);
920
+ }
921
+ s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
922
+ s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
923
+ p.s[si] = s; // Write.
924
+ }
925
+ else
926
+ {
927
+ // Just compute the value.
928
+ if (v < 0.f) v *= p.slope;
929
+ v = InternalType<T>::clamp(v, p.clamp);
930
+ }
931
+ }
932
+ }
933
+ else if (signRead)
934
+ {
935
+ // Read sign and apply if within sign tensor bounds.
936
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y)
937
+ {
938
+ int s = p.s[si];
939
+ s >>= signXo;
940
+ if (s & 1) v *= p.slope;
941
+ if (s & 2) v = 0.f;
942
+ }
943
+ }
944
+ else // Forward pass with no sign write.
945
+ {
946
+ if (v < 0.f) v *= p.slope;
947
+ v = InternalType<T>::clamp(v, p.clamp);
948
+ }
949
+
950
+ if (!downInline) // Write into temporary buffer.
951
+ s_tileUpXY[idx] = v;
952
+ else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer
953
+ *((T*)((char*)p.y + (x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]);
954
+ }
955
+ }
956
+ }
957
+
958
+ // Downsampling.
959
+ if (filterMode == MODE_SUSD || filterMode == MODE_FUSD)
960
+ {
961
+ // Horizontal downsampling.
962
+ __syncthreads();
963
+ if (down == 4 && tileOutW % 4 == 0)
964
+ {
965
+ // Calculate 4 pixels at a time.
966
+ for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4)
967
+ {
968
+ int relOutX0, relUpY;
969
+ fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
970
+ int relUpX0 = relOutX0 * down;
971
+ int src0 = relUpY * tileUpW + relUpX0;
972
+ vec4_t v = InternalType<T>::zero_vec4();
973
+ #pragma unroll
974
+ for (int step = 0; step < fdSize; step++)
975
+ {
976
+ v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
977
+ v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step];
978
+ v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step];
979
+ v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step];
980
+ }
981
+ s_tileDownX[idx+0] = v.x;
982
+ s_tileDownX[idx+1] = v.y;
983
+ s_tileDownX[idx+2] = v.z;
984
+ s_tileDownX[idx+3] = v.w;
985
+ }
986
+ }
987
+ else if ((down == 2 || down == 4) && (tileOutW % 2 == 0))
988
+ {
989
+ // Calculate 2 pixels at a time.
990
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2)
991
+ {
992
+ int relOutX0, relUpY;
993
+ fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
994
+ int relUpX0 = relOutX0 * down;
995
+ int src0 = relUpY * tileUpW + relUpX0;
996
+ vec2_t v = InternalType<T>::zero_vec2();
997
+ #pragma unroll
998
+ for (int step = 0; step < fdSize; step++)
999
+ {
1000
+ v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
1001
+ v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step];
1002
+ }
1003
+ s_tileDownX[idx+0] = v.x;
1004
+ s_tileDownX[idx+1] = v.y;
1005
+ }
1006
+ }
1007
+ else
1008
+ {
1009
+ // Calculate 1 pixel at a time.
1010
+ for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x)
1011
+ {
1012
+ int relOutX0, relUpY;
1013
+ fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
1014
+ int relUpX0 = relOutX0 * down;
1015
+ int src = relUpY * tileUpW + relUpX0;
1016
+ scalar_t v = 0.f;
1017
+ #pragma unroll
1018
+ for (int step = 0; step < fdSize; step++)
1019
+ v += s_tileUpXY[src + step] * (scalar_t)c_fd[step];
1020
+ s_tileDownX[idx] = v;
1021
+ }
1022
+ }
1023
+
1024
+ // Vertical downsampling & store output tile.
1025
+ __syncthreads();
1026
+ for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
1027
+ {
1028
+ int relOutX, relOutY0;
1029
+ fast_div_mod<tileOutW>(relOutX, relOutY0, idx);
1030
+ int relUpY0 = relOutY0 * down;
1031
+ int src0 = relUpY0 * tileOutW + relOutX;
1032
+ scalar_t v = 0;
1033
+ #pragma unroll
1034
+ for (int step = 0; step < fdSize; step++)
1035
+ v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step];
1036
+
1037
+ int outX = tileOutX + relOutX;
1038
+ int outY = tileOutY + relOutY0;
1039
+
1040
+ if (outX < p.yShape.x & outY < p.yShape.y)
1041
+ *((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
1042
+ }
1043
+ }
1044
+ else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD)
1045
+ {
1046
+ // Full downsampling filter.
1047
+ if (down == 2)
1048
+ {
1049
+ // 2-wide.
1050
+ __syncthreads();
1051
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2)
1052
+ {
1053
+ int relOutX0, relOutY0;
1054
+ fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
1055
+ int relUpX0 = relOutX0 * down;
1056
+ int relUpY0 = relOutY0 * down;
1057
+ int src0 = relUpY0 * tileUpW + relUpX0;
1058
+ vec2_t v = InternalType<T>::zero_vec2();
1059
+ #pragma unroll
1060
+ for (int sy = 0; sy < fdSize; sy++)
1061
+ #pragma unroll
1062
+ for (int sx = 0; sx < fdSize; sx++)
1063
+ {
1064
+ v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
1065
+ v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
1066
+ }
1067
+
1068
+ int outX = tileOutX + relOutX0;
1069
+ int outY = tileOutY + relOutY0;
1070
+ if ((uint32_t)outY < p.yShape.y)
1071
+ {
1072
+ index_t ofs = outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut;
1073
+ if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x;
1074
+ if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.x))) = (T)v.y;
1075
+ }
1076
+ }
1077
+ }
1078
+ else if (down == 1 && !downInline)
1079
+ {
1080
+ // Thread per pixel.
1081
+ __syncthreads();
1082
+ for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
1083
+ {
1084
+ int relOutX0, relOutY0;
1085
+ fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
1086
+ scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter.
1087
+
1088
+ int outX = tileOutX + relOutX0;
1089
+ int outY = tileOutY + relOutY0;
1090
+ if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y)
1091
+ *((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
1092
+ }
1093
+ }
1094
+ }
1095
+
1096
+ if (!enableXrep)
1097
+ break;
1098
+ }
1099
+ }
1100
+
1101
+ //------------------------------------------------------------------------
1102
+ // Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. Used for accelerating the generic variant.
1103
+ // Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 64-bit indexing is always used.
1104
+
1105
+ template <class T, bool signWrite, bool signRead>
1106
+ static __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p)
1107
+ {
1108
+ typedef typename InternalType<T>::scalar_t scalar_t;
1109
+
1110
+ // Indexing.
1111
+ int32_t x = threadIdx.x + blockIdx.x * blockDim.x;
1112
+ int32_t ymax = signWrite ? p.sShape.y : p.xShape.y;
1113
+ int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index.
1114
+
1115
+ // Loop to accommodate oversized tensors.
1116
+ for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z)
1117
+ for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y)
1118
+ {
1119
+ // Extract z and w (channel, minibatch index).
1120
+ int32_t w = q / p.xShape.z;
1121
+ int32_t z = q - w * p.xShape.z;
1122
+
1123
+ // Choose behavior based on sign read/write mode.
1124
+ if (signWrite)
1125
+ {
1126
+ // Process value if in p.x.
1127
+ uint32_t s = 0;
1128
+ if (x < p.xShape.x && y < p.xShape.y)
1129
+ {
1130
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
1131
+ T* pv = ((T*)p.x) + ix;
1132
+ scalar_t v = (scalar_t)(*pv);
1133
+
1134
+ // Gain, LReLU, clamp.
1135
+ v *= p.gain;
1136
+ if (v < 0.f)
1137
+ {
1138
+ v *= p.slope;
1139
+ s = 1; // Sign.
1140
+ }
1141
+ if (fabsf(v) > p.clamp)
1142
+ {
1143
+ v = InternalType<T>::clamp(v, p.clamp);
1144
+ s = 2; // Clamp.
1145
+ }
1146
+
1147
+ *pv = (T)v; // Write value.
1148
+ }
1149
+
1150
+ // Coalesce into threads 0 and 16 of warp.
1151
+ uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu;
1152
+ s <<= ((threadIdx.x & 15) << 1); // Shift into place.
1153
+ s |= __shfl_xor_sync(m, s, 1); // Distribute.
1154
+ s |= __shfl_xor_sync(m, s, 2);
1155
+ s |= __shfl_xor_sync(m, s, 4);
1156
+ s |= __shfl_xor_sync(m, s, 8);
1157
+
1158
+ // Write signs if leader and in p.s.
1159
+ if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in.
1160
+ {
1161
+ uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous.
1162
+ ((uint32_t*)p.s)[is >> 4] = s;
1163
+ }
1164
+ }
1165
+ else if (signRead)
1166
+ {
1167
+ // Process value if in p.x.
1168
+ if (x < p.xShape.x) // y is always in.
1169
+ {
1170
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
1171
+ T* pv = ((T*)p.x) + ix;
1172
+ scalar_t v = (scalar_t)(*pv);
1173
+ v *= p.gain;
1174
+
1175
+ // Apply sign buffer offset.
1176
+ uint32_t sx = x + p.sOfs.x;
1177
+ uint32_t sy = y + p.sOfs.y;
1178
+
1179
+ // Read and apply signs if we land inside valid region of sign buffer.
1180
+ if (sx < p.sShape.x && sy < p.sShape.y)
1181
+ {
1182
+ uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous.
1183
+ unsigned char s = p.s[is];
1184
+ s >>= (sx & 3) << 1; // Shift into place.
1185
+ if (s & 1) // Sign?
1186
+ v *= p.slope;
1187
+ if (s & 2) // Clamp?
1188
+ v = 0.f;
1189
+ }
1190
+
1191
+ *pv = (T)v; // Write value.
1192
+ }
1193
+ }
1194
+ else
1195
+ {
1196
+ // Forward pass with no sign write. Process value if in p.x.
1197
+ if (x < p.xShape.x) // y is always in.
1198
+ {
1199
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
1200
+ T* pv = ((T*)p.x) + ix;
1201
+ scalar_t v = (scalar_t)(*pv);
1202
+ v *= p.gain;
1203
+ if (v < 0.f)
1204
+ v *= p.slope;
1205
+ if (fabsf(v) > p.clamp)
1206
+ v = InternalType<T>::clamp(v, p.clamp);
1207
+ *pv = (T)v; // Write value.
1208
+ }
1209
+ }
1210
+ }
1211
+ }
1212
+
1213
+ template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void)
1214
+ {
1215
+ return (void*)filtered_lrelu_act_kernel<T, signWrite, signRead>;
1216
+ }
1217
+
1218
+ //------------------------------------------------------------------------
1219
+ // CUDA kernel selection.
1220
+
1221
+ template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB)
1222
+ {
1223
+ filtered_lrelu_kernel_spec s = { 0 };
1224
+
1225
+ // Return the first matching kernel.
1226
+ #define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \
1227
+ if (sharedKB >= SH) \
1228
+ if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \
1229
+ if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \
1230
+ if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \
1231
+ { \
1232
+ static_assert((D*TW % 4) == 0, "down * tileWidth must be divisible by 4"); \
1233
+ static_assert(FU % U == 0, "upscaling filter size must be multiple of upscaling factor"); \
1234
+ static_assert(FD % D == 0, "downscaling filter size must be multiple of downscaling factor"); \
1235
+ s.setup = (void*)setup_filters_kernel; \
1236
+ s.exec = (void*)filtered_lrelu_kernel<T, index_t, SH, signWrite, signRead, MODE, U, FU, D, FD, TW, TH, W*32, !!XR, !!WS>; \
1237
+ s.tileOut = make_int2(TW, TH); \
1238
+ s.numWarps = W; \
1239
+ s.xrep = XR; \
1240
+ s.dynamicSharedKB = (SH == 48) ? 0 : SH; \
1241
+ return s; \
1242
+ }
1243
+
1244
+ // Launch parameters for various kernel specializations.
1245
+ // Small filters must be listed before large filters, otherwise the kernel for larger filter will always match first.
1246
+ // Kernels that use more shared memory must be listed before those that use less, for the same reason.
1247
+
1248
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/1,1, /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64, 178, 32, 0, 0) // 1t-upf1-downf1
1249
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95, 16, 0, 0) // 4t-ups2-downf1
1250
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 22, 16, 0, 0) // 4t-upf1-downs2
1251
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 29, 16, 11, 0) // 4t-ups2-downs2
1252
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60, 28, 16, 0, 0) // 4t-upf2-downs2
1253
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 28, 16, 0, 0) // 4t-ups2-downf2
1254
+ CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 31, 16, 11, 0) // 4t-ups4-downs2
1255
+ CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 36, 16, 0, 0) // 4t-ups4-downf2
1256
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 22, 16, 12, 0) // 4t-ups2-downs4
1257
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29, 15, 16, 0, 0) // 4t-upf2-downs4
1258
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96, 150, 28, 0, 0) // 6t-ups2-downf1
1259
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 35, 24, 0, 0) // 6t-upf1-downs2
1260
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 16, 10, 0) // 6t-ups2-downs2
1261
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58, 28, 24, 8, 0) // 6t-upf2-downs2
1262
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52, 28, 16, 0, 0) // 6t-ups2-downf2
1263
+ CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 51, 16, 5, 0) // 6t-ups4-downs2
1264
+ CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 56, 16, 6, 0) // 6t-ups4-downf2
1265
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 18, 16, 12, 0) // 6t-ups2-downs4
1266
+ CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB
1267
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 13, 24, 0, 0) // 6t-upf2-downs4
1268
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89, 24, 0, 0) // 8t-ups2-downf1
1269
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 31, 16, 5, 0) // 8t-upf1-downs2
1270
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 41, 16, 9, 0) // 8t-ups2-downs2
1271
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 26, 24, 0, 0) // 8t-upf2-downs2
1272
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 40, 16, 0, 0) // 8t-ups2-downf2
1273
+ CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 24, 5, 0) // 8t-ups4-downs2
1274
+ CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 50, 16, 0, 0) // 8t-ups4-downf2
1275
+ CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB
1276
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 13, 16, 10, 1) // 8t-ups2-downs4
1277
+ CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB
1278
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 10, 24, 0, 0) // 8t-upf2-downs4
1279
+
1280
+ #undef CASE
1281
+ return s; // No kernel found.
1282
+ }
1283
+
1284
+ //------------------------------------------------------------------------
ADD/th_utils/ops/filtered_lrelu.h ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <cuda_runtime.h>
10
+
11
+ //------------------------------------------------------------------------
12
+ // CUDA kernel parameters.
13
+
14
+ struct filtered_lrelu_kernel_params
15
+ {
16
+ // These parameters decide which kernel to use.
17
+ int up; // upsampling ratio (1, 2, 4)
18
+ int down; // downsampling ratio (1, 2, 4)
19
+ int2 fuShape; // [size, 1] | [size, size]
20
+ int2 fdShape; // [size, 1] | [size, size]
21
+
22
+ int _dummy; // Alignment.
23
+
24
+ // Rest of the parameters.
25
+ const void* x; // Input tensor.
26
+ void* y; // Output tensor.
27
+ const void* b; // Bias tensor.
28
+ unsigned char* s; // Sign tensor in/out. NULL if unused.
29
+ const float* fu; // Upsampling filter.
30
+ const float* fd; // Downsampling filter.
31
+
32
+ int2 pad0; // Left/top padding.
33
+ float gain; // Additional gain factor.
34
+ float slope; // Leaky ReLU slope on negative side.
35
+ float clamp; // Clamp after nonlinearity.
36
+ int flip; // Filter kernel flip for gradient computation.
37
+
38
+ int tilesXdim; // Original number of horizontal output tiles.
39
+ int tilesXrep; // Number of horizontal tiles per CTA.
40
+ int blockZofs; // Block z offset to support large minibatch, channel dimensions.
41
+
42
+ int4 xShape; // [width, height, channel, batch]
43
+ int4 yShape; // [width, height, channel, batch]
44
+ int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
45
+ int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
46
+ int swLimit; // Active width of sign tensor in bytes.
47
+
48
+ longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
49
+ longlong4 yStride; //
50
+ int64_t bStride; //
51
+ longlong3 fuStride; //
52
+ longlong3 fdStride; //
53
+ };
54
+
55
+ struct filtered_lrelu_act_kernel_params
56
+ {
57
+ void* x; // Input/output, modified in-place.
58
+ unsigned char* s; // Sign tensor in/out. NULL if unused.
59
+
60
+ float gain; // Additional gain factor.
61
+ float slope; // Leaky ReLU slope on negative side.
62
+ float clamp; // Clamp after nonlinearity.
63
+
64
+ int4 xShape; // [width, height, channel, batch]
65
+ longlong4 xStride; // Input/output tensor strides, same order as in shape.
66
+ int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
67
+ int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
68
+ };
69
+
70
+ //------------------------------------------------------------------------
71
+ // CUDA kernel specialization.
72
+
73
+ struct filtered_lrelu_kernel_spec
74
+ {
75
+ void* setup; // Function for filter kernel setup.
76
+ void* exec; // Function for main operation.
77
+ int2 tileOut; // Width/height of launch tile.
78
+ int numWarps; // Number of warps per thread block, determines launch block size.
79
+ int xrep; // For processing multiple horizontal tiles per thread block.
80
+ int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
81
+ };
82
+
83
+ //------------------------------------------------------------------------
84
+ // CUDA kernel selection.
85
+
86
+ template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
87
+ template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
88
+ template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);
89
+
90
+ //------------------------------------------------------------------------
ADD/th_utils/ops/filtered_lrelu.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import numpy as np
11
+ import torch
12
+ import warnings
13
+
14
+ from .. import custom_ops
15
+ from .. import misc
16
+ from . import upfirdn2d
17
+ from . import bias_act
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ _plugin = None
22
+
23
+ def _init():
24
+ global _plugin
25
+ if _plugin is None:
26
+ _plugin = custom_ops.get_plugin(
27
+ module_name='filtered_lrelu_plugin',
28
+ sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'],
29
+ headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],
30
+ source_dir=os.path.dirname(__file__),
31
+ extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'],
32
+ )
33
+ return True
34
+
35
+ def _get_filter_size(f):
36
+ if f is None:
37
+ return 1, 1
38
+ assert isinstance(f, torch.Tensor)
39
+ assert 1 <= f.ndim <= 2
40
+ return f.shape[-1], f.shape[0] # width, height
41
+
42
+ def _parse_padding(padding):
43
+ if isinstance(padding, int):
44
+ padding = [padding, padding]
45
+ assert isinstance(padding, (list, tuple))
46
+ assert all(isinstance(x, (int, np.integer)) for x in padding)
47
+ padding = [int(x) for x in padding]
48
+ if len(padding) == 2:
49
+ px, py = padding
50
+ padding = [px, px, py, py]
51
+ px0, px1, py0, py1 = padding
52
+ return px0, px1, py0, py1
53
+
54
+ #----------------------------------------------------------------------------
55
+
56
+ def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'):
57
+ r"""Filtered leaky ReLU for a batch of 2D images.
58
+
59
+ Performs the following sequence of operations for each channel:
60
+
61
+ 1. Add channel-specific bias if provided (`b`).
62
+
63
+ 2. Upsample the image by inserting N-1 zeros after each pixel (`up`).
64
+
65
+ 3. Pad the image with the specified number of zeros on each side (`padding`).
66
+ Negative padding corresponds to cropping the image.
67
+
68
+ 4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it
69
+ so that the footprint of all output pixels lies within the input image.
70
+
71
+ 5. Multiply each value by the provided gain factor (`gain`).
72
+
73
+ 6. Apply leaky ReLU activation function to each value.
74
+
75
+ 7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided.
76
+
77
+ 8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking
78
+ it so that the footprint of all output pixels lies within the input image.
79
+
80
+ 9. Downsample the image by keeping every Nth pixel (`down`).
81
+
82
+ The fused op is considerably more efficient than performing the same calculation
83
+ using standard PyTorch ops. It supports gradients of arbitrary order.
84
+
85
+ Args:
86
+ x: Float32/float16/float64 input tensor of the shape
87
+ `[batch_size, num_channels, in_height, in_width]`.
88
+ fu: Float32 upsampling FIR filter of the shape
89
+ `[filter_height, filter_width]` (non-separable),
90
+ `[filter_taps]` (separable), or
91
+ `None` (identity).
92
+ fd: Float32 downsampling FIR filter of the shape
93
+ `[filter_height, filter_width]` (non-separable),
94
+ `[filter_taps]` (separable), or
95
+ `None` (identity).
96
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
97
+ as `x`. The length of vector must must match the channel dimension of `x`.
98
+ up: Integer upsampling factor (default: 1).
99
+ down: Integer downsampling factor. (default: 1).
100
+ padding: Padding with respect to the upsampled image. Can be a single number
101
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
102
+ (default: 0).
103
+ gain: Overall scaling factor for signal magnitude (default: sqrt(2)).
104
+ slope: Slope on the negative side of leaky ReLU (default: 0.2).
105
+ clamp: Maximum magnitude for leaky ReLU output (default: None).
106
+ flip_filter: False = convolution, True = correlation (default: False).
107
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
108
+
109
+ Returns:
110
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
111
+ """
112
+ assert isinstance(x, torch.Tensor)
113
+ assert impl in ['ref', 'cuda']
114
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
115
+ return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
116
+ return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)
117
+
118
+ #----------------------------------------------------------------------------
119
+
120
+ @misc.profiled_function
121
+ def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
122
+ """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
123
+ existing `upfirdn2n()` and `bias_act()` ops.
124
+ """
125
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
126
+ fu_w, fu_h = _get_filter_size(fu)
127
+ fd_w, fd_h = _get_filter_size(fd)
128
+ if b is not None:
129
+ assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
130
+ misc.assert_shape(b, [x.shape[1]])
131
+ assert isinstance(up, int) and up >= 1
132
+ assert isinstance(down, int) and down >= 1
133
+ px0, px1, py0, py1 = _parse_padding(padding)
134
+ assert gain == float(gain) and gain > 0
135
+ assert slope == float(slope) and slope >= 0
136
+ assert clamp is None or (clamp == float(clamp) and clamp >= 0)
137
+
138
+ # Calculate output size.
139
+ batch_size, channels, in_h, in_w = x.shape
140
+ in_dtype = x.dtype
141
+ out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
142
+ out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down
143
+
144
+ # Compute using existing ops.
145
+ x = bias_act.bias_act(x=x, b=b) # Apply bias.
146
+ x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
147
+ x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp.
148
+ x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample.
149
+
150
+ # Check output shape & dtype.
151
+ misc.assert_shape(x, [batch_size, channels, out_h, out_w])
152
+ assert x.dtype == in_dtype
153
+ return x
154
+
155
+ #----------------------------------------------------------------------------
156
+
157
+ _filtered_lrelu_cuda_cache = dict()
158
+
159
+ def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
160
+ """Fast CUDA implementation of `filtered_lrelu()` using custom ops.
161
+ """
162
+ assert isinstance(up, int) and up >= 1
163
+ assert isinstance(down, int) and down >= 1
164
+ px0, px1, py0, py1 = _parse_padding(padding)
165
+ assert gain == float(gain) and gain > 0
166
+ gain = float(gain)
167
+ assert slope == float(slope) and slope >= 0
168
+ slope = float(slope)
169
+ assert clamp is None or (clamp == float(clamp) and clamp >= 0)
170
+ clamp = float(clamp if clamp is not None else 'inf')
171
+
172
+ # Lookup from cache.
173
+ key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)
174
+ if key in _filtered_lrelu_cuda_cache:
175
+ return _filtered_lrelu_cuda_cache[key]
176
+
177
+ # Forward op.
178
+ class FilteredLReluCuda(torch.autograd.Function):
179
+ @staticmethod
180
+ def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ
181
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
182
+
183
+ # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable).
184
+ if fu is None:
185
+ fu = torch.ones([1, 1], dtype=torch.float32, device=x.device)
186
+ if fd is None:
187
+ fd = torch.ones([1, 1], dtype=torch.float32, device=x.device)
188
+ assert 1 <= fu.ndim <= 2
189
+ assert 1 <= fd.ndim <= 2
190
+
191
+ # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1.
192
+ if up == 1 and fu.ndim == 1 and fu.shape[0] == 1:
193
+ fu = fu.square()[None]
194
+ if down == 1 and fd.ndim == 1 and fd.shape[0] == 1:
195
+ fd = fd.square()[None]
196
+
197
+ # Missing sign input tensor.
198
+ if si is None:
199
+ si = torch.empty([0])
200
+
201
+ # Missing bias tensor.
202
+ if b is None:
203
+ b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device)
204
+
205
+ # Construct internal sign tensor only if gradients are needed.
206
+ write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad)
207
+
208
+ # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout.
209
+ strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]
210
+ if any(a < b for a, b in zip(strides[:-1], strides[1:])):
211
+ warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning)
212
+
213
+ # Call C++/Cuda plugin if datatype is supported.
214
+ if x.dtype in [torch.float16, torch.float32]:
215
+ if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device):
216
+ warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning)
217
+ y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs)
218
+ else:
219
+ return_code = -1
220
+
221
+ # No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because
222
+ # only the bit-packed sign tensor is retained for gradient computation.
223
+ if return_code < 0:
224
+ warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning)
225
+
226
+ y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias.
227
+ y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
228
+ so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place.
229
+ y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample.
230
+
231
+ # Prepare for gradient computation.
232
+ ctx.save_for_backward(fu, fd, (si if si.numel() else so))
233
+ ctx.x_shape = x.shape
234
+ ctx.y_shape = y.shape
235
+ ctx.s_ofs = sx, sy
236
+ return y
237
+
238
+ @staticmethod
239
+ def backward(ctx, dy): # pylint: disable=arguments-differ
240
+ fu, fd, si = ctx.saved_tensors
241
+ _, _, xh, xw = ctx.x_shape
242
+ _, _, yh, yw = ctx.y_shape
243
+ sx, sy = ctx.s_ofs
244
+ dx = None # 0
245
+ dfu = None; assert not ctx.needs_input_grad[1]
246
+ dfd = None; assert not ctx.needs_input_grad[2]
247
+ db = None # 3
248
+ dsi = None; assert not ctx.needs_input_grad[4]
249
+ dsx = None; assert not ctx.needs_input_grad[5]
250
+ dsy = None; assert not ctx.needs_input_grad[6]
251
+
252
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
253
+ pp = [
254
+ (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0,
255
+ xw * up - yw * down + px0 - (up - 1),
256
+ (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0,
257
+ xh * up - yh * down + py0 - (up - 1),
258
+ ]
259
+ gg = gain * (up ** 2) / (down ** 2)
260
+ ff = (not flip_filter)
261
+ sx = sx - (fu.shape[-1] - 1) + px0
262
+ sy = sy - (fu.shape[0] - 1) + py0
263
+ dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy)
264
+
265
+ if ctx.needs_input_grad[3]:
266
+ db = dx.sum([0, 2, 3])
267
+
268
+ return dx, dfu, dfd, db, dsi, dsx, dsy
269
+
270
+ # Add to cache.
271
+ _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda
272
+ return FilteredLReluCuda
273
+
274
+ #----------------------------------------------------------------------------
ADD/th_utils/ops/filtered_lrelu_ns.cu ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include "filtered_lrelu.cu"
10
+
11
+ // Template/kernel specializations for no signs mode (no gradients required).
12
+
13
+ // Full op, 32-bit indexing.
14
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16
+
17
+ // Full op, 64-bit indexing.
18
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20
+
21
+ // Activation/signs only for generic variant. 64-bit indexing.
22
+ template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
23
+ template void* choose_filtered_lrelu_act_kernel<float, false, false>(void);
24
+ template void* choose_filtered_lrelu_act_kernel<double, false, false>(void);
25
+
26
+ // Copy filters to constant memory.
27
+ template cudaError_t copy_filters<false, false>(cudaStream_t stream);
ADD/th_utils/ops/filtered_lrelu_rd.cu ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include "filtered_lrelu.cu"
10
+
11
+ // Template/kernel specializations for sign read mode.
12
+
13
+ // Full op, 32-bit indexing.
14
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
15
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
16
+
17
+ // Full op, 64-bit indexing.
18
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
19
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
20
+
21
+ // Activation/signs only for generic variant. 64-bit indexing.
22
+ template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
23
+ template void* choose_filtered_lrelu_act_kernel<float, false, true>(void);
24
+ template void* choose_filtered_lrelu_act_kernel<double, false, true>(void);
25
+
26
+ // Copy filters to constant memory.
27
+ template cudaError_t copy_filters<false, true>(cudaStream_t stream);
ADD/th_utils/ops/filtered_lrelu_wr.cu ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include "filtered_lrelu.cu"
10
+
11
+ // Template/kernel specializations for sign write mode.
12
+
13
+ // Full op, 32-bit indexing.
14
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16
+
17
+ // Full op, 64-bit indexing.
18
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20
+
21
+ // Activation/signs only for generic variant. 64-bit indexing.
22
+ template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
23
+ template void* choose_filtered_lrelu_act_kernel<float, true, false>(void);
24
+ template void* choose_filtered_lrelu_act_kernel<double, true, false>(void);
25
+
26
+ // Copy filters to constant memory.
27
+ template cudaError_t copy_filters<true, false>(cudaStream_t stream);
ADD/th_utils/ops/fma.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
10
+
11
+ import torch
12
+
13
+ #----------------------------------------------------------------------------
14
+
15
+ def fma(a, b, c): # => a * b + c
16
+ return _FusedMultiplyAdd.apply(a, b, c)
17
+
18
+ #----------------------------------------------------------------------------
19
+
20
+ class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
21
+ @staticmethod
22
+ def forward(ctx, a, b, c): # pylint: disable=arguments-differ
23
+ out = torch.addcmul(c, a, b)
24
+ ctx.save_for_backward(a, b)
25
+ ctx.c_shape = c.shape
26
+ return out
27
+
28
+ @staticmethod
29
+ def backward(ctx, dout): # pylint: disable=arguments-differ
30
+ a, b = ctx.saved_tensors
31
+ c_shape = ctx.c_shape
32
+ da = None
33
+ db = None
34
+ dc = None
35
+
36
+ if ctx.needs_input_grad[0]:
37
+ da = _unbroadcast(dout * b, a.shape)
38
+
39
+ if ctx.needs_input_grad[1]:
40
+ db = _unbroadcast(dout * a, b.shape)
41
+
42
+ if ctx.needs_input_grad[2]:
43
+ dc = _unbroadcast(dout, c_shape)
44
+
45
+ return da, db, dc
46
+
47
+ #----------------------------------------------------------------------------
48
+
49
+ def _unbroadcast(x, shape):
50
+ extra_dims = x.ndim - len(shape)
51
+ assert extra_dims >= 0
52
+ dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
53
+ if len(dim):
54
+ x = x.sum(dim=dim, keepdim=True)
55
+ if extra_dims:
56
+ x = x.reshape(-1, *x.shape[extra_dims+1:])
57
+ assert x.shape == shape
58
+ return x
59
+
60
+ #----------------------------------------------------------------------------
ADD/th_utils/ops/grid_sample_gradfix.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom replacement for `torch.nn.functional.grid_sample` that
10
+ supports arbitrarily high order gradients between the input and output.
11
+ Only works on 2D images and assumes
12
+ `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
13
+
14
+ import torch
15
+ from pkg_resources import parse_version
16
+
17
+ # pylint: disable=redefined-builtin
18
+ # pylint: disable=arguments-differ
19
+ # pylint: disable=protected-access
20
+
21
+ #----------------------------------------------------------------------------
22
+
23
+ enabled = False # Enable the custom op by setting this to true.
24
+ _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
25
+
26
+ #----------------------------------------------------------------------------
27
+
28
+ def grid_sample(input, grid):
29
+ if _should_use_custom_op():
30
+ return _GridSample2dForward.apply(input, grid)
31
+ return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
32
+
33
+ #----------------------------------------------------------------------------
34
+
35
+ def _should_use_custom_op():
36
+ return enabled
37
+
38
+ #----------------------------------------------------------------------------
39
+
40
+ class _GridSample2dForward(torch.autograd.Function):
41
+ @staticmethod
42
+ def forward(ctx, input, grid):
43
+ assert input.ndim == 4
44
+ assert grid.ndim == 4
45
+ output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
46
+ ctx.save_for_backward(input, grid)
47
+ return output
48
+
49
+ @staticmethod
50
+ def backward(ctx, grad_output):
51
+ input, grid = ctx.saved_tensors
52
+ grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
53
+ return grad_input, grad_grid
54
+
55
+ #----------------------------------------------------------------------------
56
+
57
+ class _GridSample2dBackward(torch.autograd.Function):
58
+ @staticmethod
59
+ def forward(ctx, grad_output, input, grid):
60
+ op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
61
+ if _use_pytorch_1_11_api:
62
+ output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
63
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
64
+ else:
65
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
66
+ ctx.save_for_backward(grid)
67
+ return grad_input, grad_grid
68
+
69
+ @staticmethod
70
+ def backward(ctx, grad2_grad_input, grad2_grad_grid):
71
+ _ = grad2_grad_grid # unused
72
+ grid, = ctx.saved_tensors
73
+ grad2_grad_output = None
74
+ grad2_input = None
75
+ grad2_grid = None
76
+
77
+ if ctx.needs_input_grad[0]:
78
+ grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
79
+
80
+ assert not ctx.needs_input_grad[2]
81
+ return grad2_grad_output, grad2_input, grad2_grid
82
+
83
+ #----------------------------------------------------------------------------
ADD/th_utils/ops/upfirdn2d.cpp ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "upfirdn2d.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
17
+ {
18
+ // Validate arguments.
19
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
20
+ TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
21
+ TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
22
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
23
+ TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
24
+ TORCH_CHECK(x.numel() > 0, "x has zero size");
25
+ TORCH_CHECK(f.numel() > 0, "f has zero size");
26
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
27
+ TORCH_CHECK(f.dim() == 2, "f must be rank 2");
28
+ TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large");
29
+ TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
30
+ TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
31
+ TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
32
+
33
+ // Create output tensor.
34
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
35
+ int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
36
+ int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
37
+ TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
38
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
39
+ TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
40
+ TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large");
41
+
42
+ // Initialize CUDA kernel parameters.
43
+ upfirdn2d_kernel_params p;
44
+ p.x = x.data_ptr();
45
+ p.f = f.data_ptr<float>();
46
+ p.y = y.data_ptr();
47
+ p.up = make_int2(upx, upy);
48
+ p.down = make_int2(downx, downy);
49
+ p.pad0 = make_int2(padx0, pady0);
50
+ p.flip = (flip) ? 1 : 0;
51
+ p.gain = gain;
52
+ p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
53
+ p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
54
+ p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
55
+ p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
56
+ p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
57
+ p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
58
+ p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
59
+ p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
60
+
61
+ // Choose CUDA kernel.
62
+ upfirdn2d_kernel_spec spec;
63
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
64
+ {
65
+ spec = choose_upfirdn2d_kernel<scalar_t>(p);
66
+ });
67
+
68
+ // Set looping options.
69
+ p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
70
+ p.loopMinor = spec.loopMinor;
71
+ p.loopX = spec.loopX;
72
+ p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
73
+ p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
74
+
75
+ // Compute grid size.
76
+ dim3 blockSize, gridSize;
77
+ if (spec.tileOutW < 0) // large
78
+ {
79
+ blockSize = dim3(4, 32, 1);
80
+ gridSize = dim3(
81
+ ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
82
+ (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
83
+ p.launchMajor);
84
+ }
85
+ else // small
86
+ {
87
+ blockSize = dim3(256, 1, 1);
88
+ gridSize = dim3(
89
+ ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
90
+ (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
91
+ p.launchMajor);
92
+ }
93
+
94
+ // Launch CUDA kernel.
95
+ void* args[] = {&p};
96
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
97
+ return y;
98
+ }
99
+
100
+ //------------------------------------------------------------------------
101
+
102
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
103
+ {
104
+ m.def("upfirdn2d", &upfirdn2d);
105
+ }
106
+
107
+ //------------------------------------------------------------------------
ADD/th_utils/ops/upfirdn2d.cu ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <c10/util/Half.h>
10
+ #include "upfirdn2d.h"
11
+
12
+ //------------------------------------------------------------------------
13
+ // Helpers.
14
+
15
+ template <class T> struct InternalType;
16
+ template <> struct InternalType<double> { typedef double scalar_t; };
17
+ template <> struct InternalType<float> { typedef float scalar_t; };
18
+ template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19
+
20
+ static __device__ __forceinline__ int floor_div(int a, int b)
21
+ {
22
+ int t = 1 - a / b;
23
+ return (a + t * b) / b - t;
24
+ }
25
+
26
+ //------------------------------------------------------------------------
27
+ // Generic CUDA implementation for large filters.
28
+
29
+ template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
30
+ {
31
+ typedef typename InternalType<T>::scalar_t scalar_t;
32
+
33
+ // Calculate thread index.
34
+ int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
35
+ int outY = minorBase / p.launchMinor;
36
+ minorBase -= outY * p.launchMinor;
37
+ int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
38
+ int majorBase = blockIdx.z * p.loopMajor;
39
+ if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
40
+ return;
41
+
42
+ // Setup Y receptive field.
43
+ int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
44
+ int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
45
+ int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
46
+ int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
47
+ if (p.flip)
48
+ filterY = p.filterSize.y - 1 - filterY;
49
+
50
+ // Loop over major, minor, and X.
51
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
52
+ for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
53
+ {
54
+ int nc = major * p.sizeMinor + minor;
55
+ int n = nc / p.inSize.z;
56
+ int c = nc - n * p.inSize.z;
57
+ for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
58
+ {
59
+ // Setup X receptive field.
60
+ int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
61
+ int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
62
+ int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
63
+ int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
64
+ if (p.flip)
65
+ filterX = p.filterSize.x - 1 - filterX;
66
+
67
+ // Initialize pointers.
68
+ const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
69
+ const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
70
+ int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
71
+ int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
72
+
73
+ // Inner loop.
74
+ scalar_t v = 0;
75
+ for (int y = 0; y < h; y++)
76
+ {
77
+ for (int x = 0; x < w; x++)
78
+ {
79
+ v += (scalar_t)(*xp) * (scalar_t)(*fp);
80
+ xp += p.inStride.x;
81
+ fp += filterStepX;
82
+ }
83
+ xp += p.inStride.y - w * p.inStride.x;
84
+ fp += filterStepY - w * filterStepX;
85
+ }
86
+
87
+ // Store result.
88
+ v *= p.gain;
89
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
90
+ }
91
+ }
92
+ }
93
+
94
+ //------------------------------------------------------------------------
95
+ // Specialized CUDA implementation for small filters.
96
+
97
+ template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
98
+ static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
99
+ {
100
+ typedef typename InternalType<T>::scalar_t scalar_t;
101
+ const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
102
+ const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
103
+ __shared__ volatile scalar_t sf[filterH][filterW];
104
+ __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
105
+
106
+ // Calculate tile index.
107
+ int minorBase = blockIdx.x;
108
+ int tileOutY = minorBase / p.launchMinor;
109
+ minorBase -= tileOutY * p.launchMinor;
110
+ minorBase *= loopMinor;
111
+ tileOutY *= tileOutH;
112
+ int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
113
+ int majorBase = blockIdx.z * p.loopMajor;
114
+ if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
115
+ return;
116
+
117
+ // Load filter (flipped).
118
+ for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
119
+ {
120
+ int fy = tapIdx / filterW;
121
+ int fx = tapIdx - fy * filterW;
122
+ scalar_t v = 0;
123
+ if (fx < p.filterSize.x & fy < p.filterSize.y)
124
+ {
125
+ int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
126
+ int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
127
+ v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
128
+ }
129
+ sf[fy][fx] = v;
130
+ }
131
+
132
+ // Loop over major and X.
133
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
134
+ {
135
+ int baseNC = major * p.sizeMinor + minorBase;
136
+ int n = baseNC / p.inSize.z;
137
+ int baseC = baseNC - n * p.inSize.z;
138
+ for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
139
+ {
140
+ // Load input pixels.
141
+ int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
142
+ int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
143
+ int tileInX = floor_div(tileMidX, upx);
144
+ int tileInY = floor_div(tileMidY, upy);
145
+ __syncthreads();
146
+ for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
147
+ {
148
+ int relC = inIdx;
149
+ int relInX = relC / loopMinor;
150
+ int relInY = relInX / tileInW;
151
+ relC -= relInX * loopMinor;
152
+ relInX -= relInY * tileInW;
153
+ int c = baseC + relC;
154
+ int inX = tileInX + relInX;
155
+ int inY = tileInY + relInY;
156
+ scalar_t v = 0;
157
+ if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
158
+ v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
159
+ sx[relInY][relInX][relC] = v;
160
+ }
161
+
162
+ // Loop over output pixels.
163
+ __syncthreads();
164
+ for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
165
+ {
166
+ int relC = outIdx;
167
+ int relOutX = relC / loopMinor;
168
+ int relOutY = relOutX / tileOutW;
169
+ relC -= relOutX * loopMinor;
170
+ relOutX -= relOutY * tileOutW;
171
+ int c = baseC + relC;
172
+ int outX = tileOutX + relOutX;
173
+ int outY = tileOutY + relOutY;
174
+
175
+ // Setup receptive field.
176
+ int midX = tileMidX + relOutX * downx;
177
+ int midY = tileMidY + relOutY * downy;
178
+ int inX = floor_div(midX, upx);
179
+ int inY = floor_div(midY, upy);
180
+ int relInX = inX - tileInX;
181
+ int relInY = inY - tileInY;
182
+ int filterX = (inX + 1) * upx - midX - 1; // flipped
183
+ int filterY = (inY + 1) * upy - midY - 1; // flipped
184
+
185
+ // Inner loop.
186
+ if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
187
+ {
188
+ scalar_t v = 0;
189
+ #pragma unroll
190
+ for (int y = 0; y < filterH / upy; y++)
191
+ #pragma unroll
192
+ for (int x = 0; x < filterW / upx; x++)
193
+ v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
194
+ v *= p.gain;
195
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
196
+ }
197
+ }
198
+ }
199
+ }
200
+ }
201
+
202
+ //------------------------------------------------------------------------
203
+ // CUDA kernel selection.
204
+
205
+ template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
206
+ {
207
+ int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
208
+ upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
209
+ if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
210
+
211
+ // No up/downsampling.
212
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
213
+ {
214
+ // contiguous
215
+ if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
216
+ if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
217
+ if (s != 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
218
+ if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
219
+ if (s != 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
220
+ if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
221
+ if (s != 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
222
+ if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
223
+ if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
224
+ if (s != 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
225
+ if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
226
+ if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
227
+ if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
228
+ // channels_last
229
+ if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
230
+ if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
231
+ if (s == 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
232
+ if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
233
+ if (s == 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 16,16,8>, 16,16,8, 1};
234
+ if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
235
+ if (s == 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 16,16,8>, 16,16,8, 1};
236
+ if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
237
+ if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
238
+ if (s == 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
239
+ if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
240
+ if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
241
+ if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
242
+ }
243
+
244
+ // 2x upsampling.
245
+ if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
246
+ {
247
+ // contiguous
248
+ if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
249
+ if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
250
+ if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
251
+ if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
252
+ if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
253
+ if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
254
+ // channels_last
255
+ if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
256
+ if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
257
+ if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
258
+ if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
259
+ if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
260
+ if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
261
+ }
262
+ if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
263
+ {
264
+ // contiguous
265
+ if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
266
+ if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
267
+ if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
268
+ // channels_last
269
+ if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
270
+ if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
271
+ if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
272
+ }
273
+ if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
274
+ {
275
+ // contiguous
276
+ if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
277
+ if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
278
+ if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
279
+ // channels_last
280
+ if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
281
+ if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
282
+ if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
283
+ }
284
+
285
+ // 2x downsampling.
286
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2)
287
+ {
288
+ // contiguous
289
+ if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 32,16,1>, 32,16,1, 1};
290
+ if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 32,16,1>, 32,16,1, 1};
291
+ if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
292
+ if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
293
+ if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
294
+ if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
295
+ // channels_last
296
+ if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 16,16,1>, 16,16,1, 1};
297
+ if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 16,16,1>, 16,16,1, 1};
298
+ if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
299
+ if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
300
+ if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
301
+ if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
302
+ }
303
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1)
304
+ {
305
+ // contiguous
306
+ if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
307
+ if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
308
+ if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
309
+ // channels_last
310
+ if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
311
+ if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
312
+ if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
313
+ }
314
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2)
315
+ {
316
+ // contiguous
317
+ if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
318
+ if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
319
+ if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
320
+ // channels_last
321
+ if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
322
+ if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
323
+ if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
324
+ }
325
+
326
+ // 4x upsampling.
327
+ if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
328
+ {
329
+ // contiguous
330
+ if (s != 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 64,32,1>, 64,32,1, 1};
331
+ if (s != 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 64,32,1>, 64,32,1, 1};
332
+ // channels_last
333
+ if (s == 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 32,32,1>, 32,32,1, 1};
334
+ if (s == 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 32,32,1>, 32,32,1, 1};
335
+ }
336
+ if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
337
+ {
338
+ // contiguous
339
+ if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,8,1>, 128,8,1, 1};
340
+ if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,8,1>, 128,8,1, 1};
341
+ // channels_last
342
+ if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,1,16>, 128,1,16, 1};
343
+ if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,1,16>, 128,1,16, 1};
344
+ }
345
+ if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
346
+ {
347
+ // contiguous
348
+ if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 32,32,1>, 32,32,1, 1};
349
+ if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 32,32,1>, 32,32,1, 1};
350
+ // channels_last
351
+ if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 1,128,16>, 1,128,16, 1};
352
+ if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 1,128,16>, 1,128,16, 1};
353
+ }
354
+
355
+ // 4x downsampling (inefficient).
356
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1)
357
+ {
358
+ // contiguous
359
+ if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,8,1>, 32,8,1, 1};
360
+ if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,8,1>, 32,8,1, 1};
361
+ // channels_last
362
+ if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,1,8>, 32,1,8, 1};
363
+ if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,1,8>, 32,1,8, 1};
364
+ }
365
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4)
366
+ {
367
+ // contiguous
368
+ if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 32,8,1>, 32,8,1, 1};
369
+ if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 32,8,1>, 32,8,1, 1};
370
+ // channels_last
371
+ if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 1,32,8>, 1,32,8, 1};
372
+ if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 1,32,8>, 1,32,8, 1};
373
+ }
374
+ return spec;
375
+ }
376
+
377
+ //------------------------------------------------------------------------
378
+ // Template specializations.
379
+
380
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
381
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
382
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
383
+
384
+ //------------------------------------------------------------------------
ADD/th_utils/ops/upfirdn2d.h ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <cuda_runtime.h>
10
+
11
+ //------------------------------------------------------------------------
12
+ // CUDA kernel parameters.
13
+
14
+ struct upfirdn2d_kernel_params
15
+ {
16
+ const void* x;
17
+ const float* f;
18
+ void* y;
19
+
20
+ int2 up;
21
+ int2 down;
22
+ int2 pad0;
23
+ int flip;
24
+ float gain;
25
+
26
+ int4 inSize; // [width, height, channel, batch]
27
+ int4 inStride;
28
+ int2 filterSize; // [width, height]
29
+ int2 filterStride;
30
+ int4 outSize; // [width, height, channel, batch]
31
+ int4 outStride;
32
+ int sizeMinor;
33
+ int sizeMajor;
34
+
35
+ int loopMinor;
36
+ int loopMajor;
37
+ int loopX;
38
+ int launchMinor;
39
+ int launchMajor;
40
+ };
41
+
42
+ //------------------------------------------------------------------------
43
+ // CUDA kernel specialization.
44
+
45
+ struct upfirdn2d_kernel_spec
46
+ {
47
+ void* kernel;
48
+ int tileOutW;
49
+ int tileOutH;
50
+ int loopMinor;
51
+ int loopX;
52
+ };
53
+
54
+ //------------------------------------------------------------------------
55
+ // CUDA kernel selection.
56
+
57
+ template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58
+
59
+ //------------------------------------------------------------------------
ADD/th_utils/ops/upfirdn2d.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom PyTorch ops for efficient resampling of 2D images."""
10
+
11
+ import os
12
+ import numpy as np
13
+ import torch
14
+
15
+ from .. import custom_ops
16
+ from .. import misc
17
+ from . import conv2d_gradfix
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ _plugin = None
22
+
23
+ def _init():
24
+ global _plugin
25
+ if _plugin is None:
26
+ _plugin = custom_ops.get_plugin(
27
+ module_name='upfirdn2d_plugin',
28
+ sources=['upfirdn2d.cpp', 'upfirdn2d.cu'],
29
+ headers=['upfirdn2d.h'],
30
+ source_dir=os.path.dirname(__file__),
31
+ extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'],
32
+ )
33
+ return True
34
+
35
+ def _parse_scaling(scaling):
36
+ if isinstance(scaling, int):
37
+ scaling = [scaling, scaling]
38
+ assert isinstance(scaling, (list, tuple))
39
+ assert all(isinstance(x, int) for x in scaling)
40
+ sx, sy = scaling
41
+ assert sx >= 1 and sy >= 1
42
+ return sx, sy
43
+
44
+ def _parse_padding(padding):
45
+ if isinstance(padding, int):
46
+ padding = [padding, padding]
47
+ assert isinstance(padding, (list, tuple))
48
+ assert all(isinstance(x, int) for x in padding)
49
+ if len(padding) == 2:
50
+ padx, pady = padding
51
+ padding = [padx, padx, pady, pady]
52
+ padx0, padx1, pady0, pady1 = padding
53
+ return padx0, padx1, pady0, pady1
54
+
55
+ def _get_filter_size(f):
56
+ if f is None:
57
+ return 1, 1
58
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
59
+ fw = f.shape[-1]
60
+ fh = f.shape[0]
61
+ with misc.suppress_tracer_warnings():
62
+ fw = int(fw)
63
+ fh = int(fh)
64
+ misc.assert_shape(f, [fh, fw][:f.ndim])
65
+ assert fw >= 1 and fh >= 1
66
+ return fw, fh
67
+
68
+ #----------------------------------------------------------------------------
69
+
70
+ def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
71
+ r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
72
+
73
+ Args:
74
+ f: Torch tensor, numpy array, or python list of the shape
75
+ `[filter_height, filter_width]` (non-separable),
76
+ `[filter_taps]` (separable),
77
+ `[]` (impulse), or
78
+ `None` (identity).
79
+ device: Result device (default: cpu).
80
+ normalize: Normalize the filter so that it retains the magnitude
81
+ for constant input signal (DC)? (default: True).
82
+ flip_filter: Flip the filter? (default: False).
83
+ gain: Overall scaling factor for signal magnitude (default: 1).
84
+ separable: Return a separable filter? (default: select automatically).
85
+
86
+ Returns:
87
+ Float32 tensor of the shape
88
+ `[filter_height, filter_width]` (non-separable) or
89
+ `[filter_taps]` (separable).
90
+ """
91
+ # Validate.
92
+ if f is None:
93
+ f = 1
94
+ f = torch.as_tensor(f, dtype=torch.float32)
95
+ assert f.ndim in [0, 1, 2]
96
+ assert f.numel() > 0
97
+ if f.ndim == 0:
98
+ f = f[np.newaxis]
99
+
100
+ # Separable?
101
+ if separable is None:
102
+ separable = (f.ndim == 1 and f.numel() >= 8)
103
+ if f.ndim == 1 and not separable:
104
+ f = f.ger(f)
105
+ assert f.ndim == (1 if separable else 2)
106
+
107
+ # Apply normalize, flip, gain, and device.
108
+ if normalize:
109
+ f /= f.sum()
110
+ if flip_filter:
111
+ f = f.flip(list(range(f.ndim)))
112
+ f = f * (gain ** (f.ndim / 2))
113
+ f = f.to(device=device)
114
+ return f
115
+
116
+ #----------------------------------------------------------------------------
117
+
118
+ def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
119
+ r"""Pad, upsample, filter, and downsample a batch of 2D images.
120
+
121
+ Performs the following sequence of operations for each channel:
122
+
123
+ 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
124
+
125
+ 2. Pad the image with the specified number of zeros on each side (`padding`).
126
+ Negative padding corresponds to cropping the image.
127
+
128
+ 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
129
+ so that the footprint of all output pixels lies within the input image.
130
+
131
+ 4. Downsample the image by keeping every Nth pixel (`down`).
132
+
133
+ This sequence of operations bears close resemblance to scipy.signal.upfirdn().
134
+ The fused op is considerably more efficient than performing the same calculation
135
+ using standard PyTorch ops. It supports gradients of arbitrary order.
136
+
137
+ Args:
138
+ x: Float32/float64/float16 input tensor of the shape
139
+ `[batch_size, num_channels, in_height, in_width]`.
140
+ f: Float32 FIR filter of the shape
141
+ `[filter_height, filter_width]` (non-separable),
142
+ `[filter_taps]` (separable), or
143
+ `None` (identity).
144
+ up: Integer upsampling factor. Can be a single int or a list/tuple
145
+ `[x, y]` (default: 1).
146
+ down: Integer downsampling factor. Can be a single int or a list/tuple
147
+ `[x, y]` (default: 1).
148
+ padding: Padding with respect to the upsampled image. Can be a single number
149
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
150
+ (default: 0).
151
+ flip_filter: False = convolution, True = correlation (default: False).
152
+ gain: Overall scaling factor for signal magnitude (default: 1).
153
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
154
+
155
+ Returns:
156
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
157
+ """
158
+ assert isinstance(x, torch.Tensor)
159
+ assert impl in ['ref', 'cuda']
160
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
161
+ return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
162
+ return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
163
+
164
+ #----------------------------------------------------------------------------
165
+
166
+ @misc.profiled_function
167
+ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
168
+ """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
169
+ """
170
+ # Validate arguments.
171
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
172
+ if f is None:
173
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
174
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
175
+ assert f.dtype == torch.float32 and not f.requires_grad
176
+ batch_size, num_channels, in_height, in_width = x.shape
177
+ upx, upy = _parse_scaling(up)
178
+ downx, downy = _parse_scaling(down)
179
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
180
+
181
+ # Check that upsampled buffer is not smaller than the filter.
182
+ upW = in_width * upx + padx0 + padx1
183
+ upH = in_height * upy + pady0 + pady1
184
+ assert upW >= f.shape[-1] and upH >= f.shape[0]
185
+
186
+ # Upsample by inserting zeros.
187
+ x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
188
+ x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
189
+ x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
190
+
191
+ # Pad or crop.
192
+ x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
193
+ x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
194
+
195
+ # Setup filter.
196
+ f = f * (gain ** (f.ndim / 2))
197
+ f = f.to(x.dtype)
198
+ if not flip_filter:
199
+ f = f.flip(list(range(f.ndim)))
200
+
201
+ # Convolve with the filter.
202
+ f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
203
+ if f.ndim == 4:
204
+ x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
205
+ else:
206
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
207
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
208
+
209
+ # Downsample by throwing away pixels.
210
+ x = x[:, :, ::downy, ::downx]
211
+ return x
212
+
213
+ #----------------------------------------------------------------------------
214
+
215
+ _upfirdn2d_cuda_cache = dict()
216
+
217
+ def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
218
+ """Fast CUDA implementation of `upfirdn2d()` using custom ops.
219
+ """
220
+ # Parse arguments.
221
+ upx, upy = _parse_scaling(up)
222
+ downx, downy = _parse_scaling(down)
223
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
224
+
225
+ # Lookup from cache.
226
+ key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
227
+ if key in _upfirdn2d_cuda_cache:
228
+ return _upfirdn2d_cuda_cache[key]
229
+
230
+ # Forward op.
231
+ class Upfirdn2dCuda(torch.autograd.Function):
232
+ @staticmethod
233
+ def forward(ctx, x, f): # pylint: disable=arguments-differ
234
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
235
+ if f is None:
236
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
237
+ if f.ndim == 1 and f.shape[0] == 1:
238
+ f = f.square().unsqueeze(0) # Convert separable-1 into full-1x1.
239
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
240
+ y = x
241
+ if f.ndim == 2:
242
+ y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
243
+ else:
244
+ y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, 1.0)
245
+ y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, gain)
246
+ ctx.save_for_backward(f)
247
+ ctx.x_shape = x.shape
248
+ return y
249
+
250
+ @staticmethod
251
+ def backward(ctx, dy): # pylint: disable=arguments-differ
252
+ f, = ctx.saved_tensors
253
+ _, _, ih, iw = ctx.x_shape
254
+ _, _, oh, ow = dy.shape
255
+ fw, fh = _get_filter_size(f)
256
+ p = [
257
+ fw - padx0 - 1,
258
+ iw * upx - ow * downx + padx0 - upx + 1,
259
+ fh - pady0 - 1,
260
+ ih * upy - oh * downy + pady0 - upy + 1,
261
+ ]
262
+ dx = None
263
+ df = None
264
+
265
+ if ctx.needs_input_grad[0]:
266
+ dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
267
+
268
+ assert not ctx.needs_input_grad[1]
269
+ return dx, df
270
+
271
+ # Add to cache.
272
+ _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
273
+ return Upfirdn2dCuda
274
+
275
+ #----------------------------------------------------------------------------
276
+
277
+ def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
278
+ r"""Filter a batch of 2D images using the given 2D FIR filter.
279
+
280
+ By default, the result is padded so that its shape matches the input.
281
+ User-specified padding is applied on top of that, with negative values
282
+ indicating cropping. Pixels outside the image are assumed to be zero.
283
+
284
+ Args:
285
+ x: Float32/float64/float16 input tensor of the shape
286
+ `[batch_size, num_channels, in_height, in_width]`.
287
+ f: Float32 FIR filter of the shape
288
+ `[filter_height, filter_width]` (non-separable),
289
+ `[filter_taps]` (separable), or
290
+ `None` (identity).
291
+ padding: Padding with respect to the output. Can be a single number or a
292
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
293
+ (default: 0).
294
+ flip_filter: False = convolution, True = correlation (default: False).
295
+ gain: Overall scaling factor for signal magnitude (default: 1).
296
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
297
+
298
+ Returns:
299
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
300
+ """
301
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
302
+ fw, fh = _get_filter_size(f)
303
+ p = [
304
+ padx0 + fw // 2,
305
+ padx1 + (fw - 1) // 2,
306
+ pady0 + fh // 2,
307
+ pady1 + (fh - 1) // 2,
308
+ ]
309
+ return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
310
+
311
+ #----------------------------------------------------------------------------
312
+
313
+ def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
314
+ r"""Upsample a batch of 2D images using the given 2D FIR filter.
315
+
316
+ By default, the result is padded so that its shape is a multiple of the input.
317
+ User-specified padding is applied on top of that, with negative values
318
+ indicating cropping. Pixels outside the image are assumed to be zero.
319
+
320
+ Args:
321
+ x: Float32/float64/float16 input tensor of the shape
322
+ `[batch_size, num_channels, in_height, in_width]`.
323
+ f: Float32 FIR filter of the shape
324
+ `[filter_height, filter_width]` (non-separable),
325
+ `[filter_taps]` (separable), or
326
+ `None` (identity).
327
+ up: Integer upsampling factor. Can be a single int or a list/tuple
328
+ `[x, y]` (default: 1).
329
+ padding: Padding with respect to the output. Can be a single number or a
330
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
331
+ (default: 0).
332
+ flip_filter: False = convolution, True = correlation (default: False).
333
+ gain: Overall scaling factor for signal magnitude (default: 1).
334
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
335
+
336
+ Returns:
337
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
338
+ """
339
+ upx, upy = _parse_scaling(up)
340
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
341
+ fw, fh = _get_filter_size(f)
342
+ p = [
343
+ padx0 + (fw + upx - 1) // 2,
344
+ padx1 + (fw - upx) // 2,
345
+ pady0 + (fh + upy - 1) // 2,
346
+ pady1 + (fh - upy) // 2,
347
+ ]
348
+ return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
349
+
350
+ #----------------------------------------------------------------------------
351
+
352
+ def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
353
+ r"""Downsample a batch of 2D images using the given 2D FIR filter.
354
+
355
+ By default, the result is padded so that its shape is a fraction of the input.
356
+ User-specified padding is applied on top of that, with negative values
357
+ indicating cropping. Pixels outside the image are assumed to be zero.
358
+
359
+ Args:
360
+ x: Float32/float64/float16 input tensor of the shape
361
+ `[batch_size, num_channels, in_height, in_width]`.
362
+ f: Float32 FIR filter of the shape
363
+ `[filter_height, filter_width]` (non-separable),
364
+ `[filter_taps]` (separable), or
365
+ `None` (identity).
366
+ down: Integer downsampling factor. Can be a single int or a list/tuple
367
+ `[x, y]` (default: 1).
368
+ padding: Padding with respect to the input. Can be a single number or a
369
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
370
+ (default: 0).
371
+ flip_filter: False = convolution, True = correlation (default: False).
372
+ gain: Overall scaling factor for signal magnitude (default: 1).
373
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
374
+
375
+ Returns:
376
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
377
+ """
378
+ downx, downy = _parse_scaling(down)
379
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
380
+ fw, fh = _get_filter_size(f)
381
+ p = [
382
+ padx0 + (fw - downx + 1) // 2,
383
+ padx1 + (fw - downx) // 2,
384
+ pady0 + (fh - downy + 1) // 2,
385
+ pady1 + (fh - downy) // 2,
386
+ ]
387
+ return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
388
+
389
+ #----------------------------------------------------------------------------
ADD/utils/util_net.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding:utf-8 -*-
3
+ # Power by Zongsheng Yue 2021-11-24 20:29:36
4
+
5
+ import math
6
+ import torch
7
+ from pathlib import Path
8
+ from collections import OrderedDict
9
+ import torch.nn.functional as F
10
+ from copy import deepcopy
11
+
12
+ def calculate_parameters(net):
13
+ out = 0
14
+ for param in net.parameters():
15
+ out += param.numel()
16
+ return out
17
+
18
+ def pad_input(x, mod):
19
+ h, w = x.shape[-2:]
20
+ bottom = int(math.ceil(h/mod)*mod -h)
21
+ right = int(math.ceil(w/mod)*mod - w)
22
+ x_pad = F.pad(x, pad=(0, right, 0, bottom), mode='reflect')
23
+ return x_pad
24
+
25
+ def forward_chop(net, x, net_kwargs=None, scale=1, shave=10, min_size=160000):
26
+ n_GPUs = 1
27
+ b, c, h, w = x.size()
28
+ h_half, w_half = h // 2, w // 2
29
+ h_size, w_size = h_half + shave, w_half + shave
30
+ lr_list = [
31
+ x[:, :, 0:h_size, 0:w_size],
32
+ x[:, :, 0:h_size, (w - w_size):w],
33
+ x[:, :, (h - h_size):h, 0:w_size],
34
+ x[:, :, (h - h_size):h, (w - w_size):w]]
35
+
36
+ if w_size * h_size < min_size:
37
+ sr_list = []
38
+ for i in range(0, 4, n_GPUs):
39
+ lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
40
+ if net_kwargs is None:
41
+ sr_batch = net(lr_batch)
42
+ else:
43
+ sr_batch = net(lr_batch, **net_kwargs)
44
+ sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
45
+ else:
46
+ sr_list = [
47
+ forward_chop(patch, shave=shave, min_size=min_size) \
48
+ for patch in lr_list
49
+ ]
50
+
51
+ h, w = scale * h, scale * w
52
+ h_half, w_half = scale * h_half, scale * w_half
53
+ h_size, w_size = scale * h_size, scale * w_size
54
+ shave *= scale
55
+
56
+ output = x.new(b, c, h, w)
57
+ output[:, :, 0:h_half, 0:w_half] \
58
+ = sr_list[0][:, :, 0:h_half, 0:w_half]
59
+ output[:, :, 0:h_half, w_half:w] \
60
+ = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
61
+ output[:, :, h_half:h, 0:w_half] \
62
+ = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
63
+ output[:, :, h_half:h, w_half:w] \
64
+ = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
65
+
66
+ return output
67
+
68
+ def measure_time(net, inputs, num_forward=100):
69
+ '''
70
+ Measuring the average runing time (seconds) for pytorch.
71
+ out = net(*inputs)
72
+ '''
73
+ start = torch.cuda.Event(enable_timing=True)
74
+ end = torch.cuda.Event(enable_timing=True)
75
+
76
+ start.record()
77
+ with torch.set_grad_enabled(False):
78
+ for _ in range(num_forward):
79
+ out = net(*inputs)
80
+ end.record()
81
+
82
+ torch.cuda.synchronize()
83
+
84
+ return start.elapsed_time(end) / 1000
85
+
86
+ def reload_model(model, ckpt):
87
+ if list(model.state_dict().keys())[0].startswith('module.'):
88
+ if list(ckpt.keys())[0].startswith('module.'):
89
+ ckpt = ckpt
90
+ else:
91
+ ckpt = OrderedDict({f'module.{key}':value for key, value in ckpt.items()})
92
+ else:
93
+ if list(ckpt.keys())[0].startswith('module.'):
94
+ ckpt = OrderedDict({key[7:]:value for key, value in ckpt.items()})
95
+ else:
96
+ ckpt = ckpt
97
+ model.load_state_dict(ckpt, True)
98
+
99
+ def compute_hinge_loss(real_output, fake_output, x_start_, r1_lambda):
100
+ if r1_lambda == 0:
101
+ real_loss_total = torch.relu(torch.ones_like(real_output) - real_output).mean()
102
+ fake_loss_total = torch.relu(torch.ones_like(fake_output) + fake_output).mean()
103
+
104
+ else:
105
+ real_loss_ = torch.relu(torch.ones_like(real_output) - real_output).mean()
106
+
107
+ # 计算真实样本的梯度
108
+ grad_real = torch.autograd.grad(outputs=real_output.sum(), inputs=x_start_, create_graph=True)[0]
109
+
110
+ # 计算梯度惩罚
111
+ grad_penalty = (grad_real.contiguous().view(grad_real.size(0), -1).norm(2, dim=1) ** 2).mean() * r1_lambda
112
+
113
+ real_loss_total = real_loss_ + grad_penalty
114
+ fake_loss_total = torch.relu(torch.ones_like(fake_output) + fake_output).mean()
115
+
116
+ real_loss = real_loss_total
117
+ fake_loss = fake_loss_total
118
+
119
+ loss_d = real_loss + fake_loss
120
+
121
+ return loss_d
122
+
123
+
124
+
125
+ def reload_model_(model, ckpt):
126
+ if list(model.state_dict().keys())[0].startswith('model.'):
127
+ if list(ckpt.keys())[0].startswith('model.'):
128
+ ckpt = ckpt
129
+ else:
130
+ ckpt = OrderedDict({f'model.{key}':value for key, value in ckpt.items()})
131
+ else:
132
+ if list(ckpt.keys())[0].startswith('model.'):
133
+ ckpt = OrderedDict({key[7:]:value for key, value in ckpt.items()})
134
+ else:
135
+ ckpt = ckpt
136
+ model.load_state_dict(ckpt, True)
137
+
138
+
139
+
140
+ def reload_model_IDE(model, ckpt):
141
+ extracted_dict = OrderedDict()
142
+ for key, value in ckpt.items():
143
+ if key.startswith('E_st'):
144
+ new_key = key.replace('E_st.', '')
145
+ extracted_dict[new_key] = value
146
+
147
+ model.load_state_dict(extracted_dict, True)
148
+
149
+
150
+
151
+ class EMA():
152
+ def __init__(self, model, decay):
153
+ self.model = model
154
+ self.decay = decay
155
+ self.shadow = {}
156
+ self.backup = {}
157
+
158
+ def register(self):
159
+ for name, param in self.model.named_parameters():
160
+ if param.requires_grad:
161
+ self.shadow[name] = param.data.clone()
162
+
163
+ def update(self):
164
+ for name, param in self.model.named_parameters():
165
+ if param.requires_grad:
166
+ assert name in self.shadow
167
+ new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
168
+ self.shadow[name] = new_average.clone()
169
+
170
+ def apply_shadow(self):
171
+ for name, param in self.model.named_parameters():
172
+ if param.requires_grad:
173
+ assert name in self.shadow
174
+ self.backup[name] = param.data
175
+ param.data = self.shadow[name]
176
+
177
+ def restore(self):
178
+ for name, param in self.model.named_parameters():
179
+ if param.requires_grad:
180
+ assert name in self.backup
181
+ param.data = self.backup[name]
182
+ self.backup = {}
README.md CHANGED
@@ -1,15 +1,291 @@
1
- ---
2
- python_version: 3.9
3
- title: TextureUpscaleBeta
4
- emoji: ⚡
5
- colorFrom: red
6
- colorTo: blue
7
- sdk: gradio
8
- sdk_version: 5.9.1
9
- app_file: app.py
10
- pinned: false
11
- preload_from_hub:
12
- - NightRaven109/CCSRModels
13
- ---
14
-
15
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="center">
2
+ <img src="figs/logo.png" width="400">
3
+ </p>
4
+
5
+ <div align="center">
6
+ <h2>Improving the Stability and Efficiency of Diffusion Models for Content Consistent Super-Resolution</h2>
7
+
8
+
9
+ <a href='https://arxiv.org/pdf/2401.00877'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
10
+
11
+
12
+ [Lingchen Sun](https://scholar.google.com/citations?hl=zh-CN&tzom=-480&user=ZCDjTn8AAAAJ)<sup>1,2</sup>
13
+ | [Rongyuan Wu](https://scholar.google.com/citations?user=A-U8zE8AAAAJ&hl=zh-CN)<sup>1,2</sup> |
14
+ [Jie Liang](https://scholar.google.com.sg/citations?user=REWxLZsAAAAJ&hl)<sup>2</sup> |
15
+ [Zhengqiang Zhang](https://scholar.google.com/citations?hl=zh-CN&user=UX26wSMAAAAJ&view_op=list_works&sortby=pubdate)<sup>1,2</sup> |
16
+ [Hongwei Yong](https://scholar.google.com.hk/citations?user=Xii74qQAAAAJ&hl=zh-CN)<sup>1</sup> |
17
+ [Lei Zhang](https://www4.comp.polyu.edu.hk/~cslzhang)<sup>1,2</sup>
18
+
19
+ <sup>1</sup>The Hong Kong Polytechnic University, <sup>2</sup>OPPO Research Institute
20
+ </div>
21
+
22
+ :star: If CCSR is helpful to your images or projects, please help star this repo. Thanks! :hugs:
23
+
24
+ ## 🧡ྀི What's New in CCSR-v2?
25
+ We have implemented the CCSR-v2 code based on the [Diffusers](https://github.com/huggingface/diffusers). Compared to CCSR-v1, CCSR-v2 brings a host of upgrades:
26
+
27
+ - 🛠️**Step Flexibility**: Offers flexibility in diffusion step selection, **allowing users to freely adjust the number of steps to suit their specific requirements**. This adaptability **requires no additional re-training**, ensuring seamless integration into diverse workflows.
28
+ - ⚡**Efficiency**: Supports highly efficient inference with **as few as 2 or even 1 diffusion step**, drastically reducing computation time without compromising quality.
29
+ - 📈**Enhanced Clarity**: With upgraded algorithms, CCSR-v2 restores images with crisper details while maintaining fidelity.
30
+ - ⚖️**Results stability**: CCSR-v2 exhibits significantly improved stability in synthesizing fine image details, ensuring higher-quality outputs.
31
+ - 🔄**Stage 2 Refinement**: In CCSR-v2, the output $\hat{x}_{0 \gets T}$ from Stage 1 is now directly fed into Stage 2, streamlining the restoration process into an efficient one-step diffusion workflow. This strategy boosts both speed and performance.
32
+
33
+ ![ccsr](figs/fig.png)
34
+ Visual comparisons between the SR outputs with the same input low-quality image but two different noise samples by different DM-based
35
+ methods. `S` denotes diffusion sampling timesteps. Existing DM-based methods, including StableSR, PASD, SeeSR, SUPIR and AddSR, **show noticeable instability with the different noise samples**. OSEDiff directly takes low-quality image as input without
36
+ noise sampling. It is deterministic and stable, but **cannot perform multi-step diffusion** for high generative capacity. In contrast, **our proposed CCSR method
37
+ is flexible for both multi-step diffusion and single-step diffusion, while producing stable results with high fidelity and visual quality**.
38
+
39
+ ## ⏰ Update
40
+ - **2024.12.12**: Code and models for CCSR-v2 are released. 👀 Please refer to this [branch](https://github.com/csslc/CCSR/tree/CCSR-v2.0).
41
+ - **2024.9.25**: ⭐[CCSR-v2](https://arxiv.org/pdf/2401.00877) is released, offering reduced step requirements and supporting flexible diffusion step selection (2 or even 1 step) during the inference stage without the need for re-training.
42
+ - **2023.12.23**: Code and models for [CCSR-v1](https://arxiv.org/pdf/2401.00877v1) are released. Please refer to this [branch](https://github.com/csslc/CCSR/tree/CCSR-v1.0).
43
+
44
+
45
+ ## 🌟 Overview Framework
46
+ ![ccsr](figs/framework.png)
47
+
48
+ ## 😍 Visual Results
49
+ ### Demo on Real-world SR
50
+
51
+ [<img src="figs/compare_1.png" height="213px"/>](https://imgsli.com/MzI2MTg5) [<img src="figs/compare_2.png" height="213px"/>](https://imgsli.com/MzI2MTky/1/3) [<img src="figs/compare_3.png" height="213px"/>](https://imgsli.com/MzI2MTk0/0/2) [<img src="figs/compare_4.png" height="213px"/>](https://imgsli.com/MzI2MTk1/0/2)
52
+
53
+
54
+ ![ccsr](figs/compare_standard.png)
55
+
56
+ ![ccsr](figs/compare_efficient.png)
57
+ For more comparisons, please refer to our paper for details.
58
+
59
+ ## 📝 Quantitative comparisons
60
+ We propose new stability metrics, namely global standard deviation (G-STD) and local standard deviation (L-STD), to respectively measure the image-level and pixel-level variations of the SR results of diffusion-based methods.
61
+
62
+ More details about G-STD and L-STD can be found in our paper.
63
+
64
+ ![ccsr](figs/table.png)
65
+ ## ⚙ Dependencies and Installation
66
+ ```shell
67
+ ## git clone this repository
68
+ git clone https://github.com/csslc/CCSR.git
69
+ cd CCSR
70
+
71
+
72
+ # create an environment with python >= 3.9
73
+ conda create -n ccsr python=3.9
74
+ conda activate ccsr
75
+ pip install -r requirements.txt
76
+ ```
77
+ ## 🍭 Quick Inference
78
+ **For ease of comparison, we have provided the test results of CCSR-v2 on the DIV2K, RealSR, and DrealSR benchmarks with varying diffusion steps, which can be accessed via [Google Drive](https://drive.google.com/drive/folders/1xjURQZgKAlENzMnAJA2PDG9h_UxfZzio?usp=sharing).**
79
+
80
+ #### Step 1: Download the pretrained models
81
+ - Download the pretrained SD-2.1-base models from [HuggingFace](https://huggingface.co/stabilityai/stable-diffusion-2-1-base).
82
+ - Download the CCSR-v2 models from and put the models in the `preset/models`:
83
+
84
+ | Model Name | Description | GoogleDrive | BaiduNetdisk |
85
+ |:-----------------------|:---------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-----------------------------------------------------------------------------------------------------------------------------|
86
+ | Controlnet | Trained in the stage 1. | [download](https://drive.google.com/drive/folders/1aHwgodKwKYZJBKs0QlFzanSjMDhrNyRA?usp=sharing) | [download](https://pan.baidu.com/s/1SKS70iE4GhhHGxqY1KS8mw) (pwd: ccsr) |
87
+ | VAE | Trained in the stage 2. | [download](https://drive.google.com/drive/folders/1yHfMV81Md6db4StHTP5MC-eSeLFeBKm8?usp=sharing) | [download](https://pan.baidu.com/s/1fxOIeL6Hk6Muq9h8itAIKQ) (pwd: ccsr) |
88
+ | Pre-trained Controlnet | The pre-trained model of stage1. | [download](https://drive.google.com/drive/folders/1LTtBRuObITOJwbW-sTDnHtp8xIUZFDHh?usp=sharing) | [download](https://pan.baidu.com/s/1mDeuHBqNj_Iol7PCY_Xfww) (pwd: ccsr) |
89
+ | Dino models | The pre-trained models for disc. | [download](https://drive.google.com/drive/folders/1PcuZGUTJlltdPz2yk2ZIa4GCtb1yk_y6?usp=sharing) | [download](https://pan.baidu.com/s/1nPdNwgua91mDDRApWUm39Q) (pwd: ccsr) |
90
+
91
+ #### Step 2: Prepare testing data
92
+ You can put the testing images in the `preset/test_datasets`.
93
+
94
+ #### Step 3: Running testing command
95
+ For one-step diffusion process:
96
+ ```
97
+ python test_ccsr_tile.py \
98
+ --pretrained_model_path preset/models/stable-diffusion-2-1-base \
99
+ --controlnet_model_path preset/models \
100
+ --vae_model_path preset/models \
101
+ --baseline_name ccsr-v2 \
102
+ --image_path preset/test_datasets \
103
+ --output_dir experiments/test \
104
+ --sample_method ddpm \
105
+ --num_inference_steps 1 \
106
+ --t_min 0.0 \
107
+ --start_point lr \
108
+ --start_steps 999 \
109
+ --process_size 512 \
110
+ --guidance_scale 1.0 \
111
+ --sample_times 1 \
112
+ --use_vae_encode_condition \
113
+ --upscale 4
114
+ ```
115
+ For multi-step diffusion process:
116
+ ```
117
+ python test_ccsr_tile.py \
118
+ --pretrained_model_path preset/models/stable-diffusion-2-1-base \
119
+ --controlnet_model_path preset/models \
120
+ --vae_model_path preset/models \
121
+ --baseline_name ccsr-v2 \
122
+ --image_path preset/test_datasets \
123
+ --output_dir experiments/test \
124
+ --sample_method ddpm \
125
+ --num_inference_steps 6 \
126
+ --t_max 0.6667 \
127
+ --t_min 0.5 \
128
+ --start_point lr \
129
+ --start_steps 999 \
130
+ --process_size 512 \
131
+ --guidance_scale 4.5 \
132
+ --sample_times 1 \
133
+ --use_vae_encode_condition \
134
+ --upscale 4
135
+ ```
136
+ We integrate [tile_diffusion](https://github.com/albarji/mixture-of-diffusers) and [tile_vae](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/tree/main) to the [test_ccsr_tile.py](test_ccsr_tile.py) to save the GPU memory for inference.
137
+ You can change the tile size and stride according to the VRAM of your device.
138
+ ```
139
+ python test_ccsr_tile.py \
140
+ --pretrained_model_path preset/models/stable-diffusion-2-1-base \
141
+ --controlnet_model_path preset/models \
142
+ --vae_model_path preset/models \
143
+ --baseline_name ccsr-v2 \
144
+ --image_path preset/test_datasets \
145
+ --output_dir experiments/test \
146
+ --sample_method ddpm \
147
+ --num_inference_steps 6 \
148
+ --t_max 0.6667 \
149
+ --t_min 0.5 \
150
+ --start_point lr \
151
+ --start_steps 999 \
152
+ --process_size 512 \
153
+ --guidance_scale 4.5 \
154
+ --sample_times 1 \
155
+ --use_vae_encode_condition \
156
+ --upscale 4 \
157
+ --tile_diffusion \
158
+ --tile_diffusion_size 512 \
159
+ --tile_diffusion_stride 256 \
160
+ --tile_vae \
161
+ --vae_decoder_tile_size 224 \
162
+ --vae_encoder_tile_size 1024 \
163
+ ```
164
+
165
+ You can obtain `N` different SR results by setting `sample_times` as `N` to test the stability of CCSR. The data folder should be like this:
166
+
167
+ ```
168
+ experiments/test
169
+ ├── sample00 # the first group of SR results
170
+ └── sample01 # the second group of SR results
171
+ ...
172
+ └── sampleN # the N-th group of SR results
173
+ ```
174
+
175
+ ## 📏 Evaluation
176
+ 1. Calculate the Image Quality Assessment for each restored group.
177
+
178
+ Fill in the required information in [cal_iqa.py](cal_iqa/cal_iqa.py) and run, then you can obtain the evaluation results in the folder like this:
179
+ ```
180
+ log_path
181
+ ├── log_name_npy # save the IQA values of each restored group as the npy files
182
+ └── log_name.log # log recode
183
+ ```
184
+
185
+ 2. Calculate the G-STD value for the diffusion-based SR method.
186
+
187
+ Fill in the required information in [iqa_G-STD.py](cal_iqa/iqa_G-STD.py) and run, then you can obtain the mean IQA values of N restored groups and G-STD value.
188
+
189
+ 3. Calculate the L-STD value for the diffusion-based SR method.
190
+
191
+ Fill in the required information in [iqa_L-STD.py](cal_iqa/iqa_L-STD.py) and run, then you can obtain the L-STD value.
192
+
193
+
194
+ ## 🚋 Train
195
+
196
+ #### Step1: Prepare training data
197
+ Generate txt file for the training set.
198
+ Fill in the required information in [get_path](scripts/get_path.py) and run, then you can obtain the txt file recording the paths of ground-truth images.
199
+ You can save the txt file into `preset/gt_path.txt`.
200
+
201
+ #### Step2: Train Stage1 Model
202
+ 1. Download pretrained [Stable Diffusion v2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) to provide generative capabilities.
203
+
204
+ ```shell
205
+ wget https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt --no-check-certificate
206
+ ```
207
+
208
+ 2. Start training.
209
+
210
+ ```shell
211
+ CUDA_VISIBLE_DEVICES="0,1,2,3," accelerate launch train_ccsr_stage1.py \
212
+ --pretrained_model_name_or_path="preset/models/stable-diffusion-2-1-base" \
213
+ --controlnet_model_name_or_path='preset/models/pretrained_controlnet' \
214
+ --enable_xformers_memory_efficient_attention \
215
+ --output_dir="./experiments/ccsrv2_stage1" \
216
+ --mixed_precision="fp16" \
217
+ --resolution=512 \
218
+ --learning_rate=5e-5 \
219
+ --train_batch_size=4 \
220
+ --gradient_accumulation_steps=6 \
221
+ --dataloader_num_workers=0 \
222
+ --checkpointing_steps=500 \
223
+ --t_max=0.6667 \
224
+ --max_train_steps=20000 \
225
+ --dataset_root_folders 'preset/gt_path.txt'
226
+ ```
227
+
228
+ #### Step3: Train Stage2 Model
229
+ 1. Put the model obtained from the stage1 into `controlnet_model_name_or_path`.
230
+ 2. Start training.
231
+ ```shell
232
+ CUDA_VISIBLE_DEVICES="0,1,2,3," accelerate launch train_ccsr_stage2.py \
233
+ --pretrained_model_name_or_path="preset/models/stable-diffusion-2-1-base" \
234
+ --controlnet_model_name_or_path='preset/models/model_stage1' \
235
+ --enable_xformers_memory_efficient_attention \
236
+ --output_dir="./experiments/ccsrv2_stage2" \
237
+ --mixed_precision="fp16" \
238
+ --resolution=512 \
239
+ --learning_rate=5e-6 \
240
+ --train_batch_size=2 \
241
+ --gradient_accumulation_steps=8 \
242
+ --checkpointing_steps=500 \
243
+ --is_start_lr=True \
244
+ --t_max=0.6667 \
245
+ --num_inference_steps=1 \
246
+ --is_module \
247
+ --lambda_l2=1.0 \
248
+ --lambda_lpips=1.0 \
249
+ --lambda_disc=0.05 \
250
+ --lambda_disc_train=0.5 \
251
+ --begin_disc=100 \
252
+ --max_train_steps=2000 \
253
+ --dataset_root_folders 'preset/gt_path.txt'
254
+ ```
255
+
256
+
257
+
258
+
259
+
260
+ ### Citations
261
+
262
+ If our code helps your research or work, please consider citing our paper.
263
+ The following are BibTeX references:
264
+
265
+ ```
266
+ @article{sun2023ccsr,
267
+ title={Improving the Stability of Diffusion Models for Content Consistent Super-Resolution},
268
+ author={Sun, Lingchen and Wu, Rongyuan and Zhang, Zhengqiang and Yong, Hongwei and Zhang, Lei},
269
+ journal={arXiv preprint arXiv:2401.00877},
270
+ year={2024}
271
+ }
272
+ ```
273
+
274
+ ### License
275
+ This project is released under the [Apache 2.0 license](LICENSE).
276
+
277
+ ### Acknowledgement
278
+ This project is based on [ControlNet](https://github.com/lllyasviel/ControlNet), [BasicSR](https://github.com/XPixelGroup/BasicSR) and [SeeSR](https://github.com/cswry/SeeSR). Some codes are brought from [ADDSR](https://github.com/NJU-PCALab/AddSR). Thanks for their awesome works.
279
+
280
+ ### Contact
281
+ If you have any questions, please contact: [email protected]
282
+
283
+
284
+ <details>
285
+ <summary>statistics</summary>
286
+
287
+ ![visitors](https://visitor-badge.laobi.icu/badge?page_id=csslc/CCSR)
288
+
289
+ </details>
290
+
291
+
dataloaders/paired_dataset_txt.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ from PIL import Image
4
+ import random
5
+ import numpy as np
6
+
7
+ from torch import nn
8
+ from torchvision import transforms
9
+ from torch.utils import data as data
10
+ import torch.nn.functional as F
11
+
12
+ from .realesrgan import RealESRGAN_degradation
13
+
14
+ class PairedCaptionDataset(data.Dataset):
15
+ def __init__(
16
+ self,
17
+ root_folders=None,
18
+ tokenizer=None,
19
+ gt_ratio=0, # let lr is gt
20
+ ):
21
+ super(PairedCaptionDataset, self).__init__()
22
+
23
+ self.gt_ratio = gt_ratio
24
+ with open(root_folders, 'r') as f:
25
+ self.gt_list = [line.strip() for line in f.readlines()]
26
+
27
+ self.img_preproc = transforms.Compose([
28
+ transforms.RandomCrop((512, 512)),
29
+ transforms.Resize((512, 512)),
30
+ transforms.RandomHorizontalFlip(),
31
+ ])
32
+
33
+ self.degradation = RealESRGAN_degradation('dataloaders/params_ccsr.yml', device='cuda')
34
+ self.tokenizer = tokenizer
35
+
36
+
37
+ def tokenize_caption(self, caption=""):
38
+ inputs = self.tokenizer(
39
+ caption, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
40
+ )
41
+
42
+ return inputs.input_ids
43
+
44
+ def __getitem__(self, index):
45
+
46
+ gt_path = self.gt_list[index]
47
+ gt_img = Image.open(gt_path).convert('RGB')
48
+ gt_img = self.img_preproc(gt_img)
49
+
50
+ gt_img, img_t = self.degradation.degrade_process(np.asarray(gt_img)/255., resize_bak=True)
51
+
52
+ if random.random() < self.gt_ratio:
53
+ lq_img = gt_img
54
+ else:
55
+ lq_img = img_t
56
+
57
+ # no caption used
58
+ lq_caption = ''
59
+
60
+ example = dict()
61
+ example["conditioning_pixel_values"] = lq_img.squeeze(0) # [0, 1]
62
+ example["pixel_values"] = gt_img.squeeze(0) * 2.0 - 1.0 # [-1, 1]
63
+ example["input_caption"] = self.tokenize_caption(caption=lq_caption).squeeze(0)
64
+
65
+ lq_img = lq_img.squeeze()
66
+
67
+ return example
68
+
69
+ def __len__(self):
70
+ return len(self.gt_list)
dataloaders/params_ccsr.yml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ scale: 4
2
+ color_jitter_prob: 0.0
3
+ gray_prob: 0.0
4
+
5
+ # the first degradation process
6
+ resize_prob: [0.2, 0.7, 0.1] # up, down, keep
7
+ resize_range: [0.3, 1.5]
8
+ gaussian_noise_prob: 0.5
9
+ noise_range: [1, 15]
10
+ poisson_scale_range: [0.05, 2.0]
11
+ gray_noise_prob: 0.4
12
+ jpeg_range: [60, 95]
13
+
14
+
15
+ # the second degradation process
16
+ second_blur_prob: 0.5
17
+ resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
18
+ resize_range2: [0.6, 1.2]
19
+ gaussian_noise_prob2: 0.5
20
+ noise_range2: [1, 12]
21
+ poisson_scale_range2: [0.05, 1.0]
22
+ gray_noise_prob2: 0.4
23
+ jpeg_range2: [60, 100]
24
+
25
+ kernel_info:
26
+ blur_kernel_size: 21
27
+ kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
28
+ kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
29
+ sinc_prob: 0.1
30
+ blur_sigma: [0.2, 1.5]
31
+ betag_range: [0.5, 2.0]
32
+ betap_range: [1, 1.5]
33
+
34
+ blur_kernel_size2: 11
35
+ kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
36
+ kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
37
+ sinc_prob2: 0.1
38
+ blur_sigma2: [0.2, 1.0]
39
+ betag_range2: [0.5, 2.0]
40
+ betap_range2: [1, 1.5]
41
+
42
+ final_sinc_prob: 0.8
dataloaders/realesrgan.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import cv2
4
+ import glob
5
+ import math
6
+ import yaml
7
+ import random
8
+ from collections import OrderedDict
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+ from basicsr.data.transforms import augment
13
+ from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
14
+ from basicsr.utils import DiffJPEG, USMSharp, img2tensor, tensor2img
15
+ from basicsr.utils.img_process_util import filter2D
16
+ from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
17
+ from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
18
+ normalize, rgb_to_grayscale)
19
+
20
+ cur_path = os.path.dirname(os.path.abspath(__file__))
21
+
22
+
23
+ def ordered_yaml():
24
+ """Support OrderedDict for yaml.
25
+
26
+ Returns:
27
+ yaml Loader and Dumper.
28
+ """
29
+ try:
30
+ from yaml import CDumper as Dumper
31
+ from yaml import CLoader as Loader
32
+ except ImportError:
33
+ from yaml import Dumper, Loader
34
+
35
+ _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
36
+
37
+ def dict_representer(dumper, data):
38
+ return dumper.represent_dict(data.items())
39
+
40
+ def dict_constructor(loader, node):
41
+ return OrderedDict(loader.construct_pairs(node))
42
+
43
+ Dumper.add_representer(OrderedDict, dict_representer)
44
+ Loader.add_constructor(_mapping_tag, dict_constructor)
45
+ return Loader, Dumper
46
+
47
+ def opt_parse(opt_path):
48
+ with open(opt_path, mode='r') as f:
49
+ Loader, _ = ordered_yaml()
50
+ opt = yaml.load(f, Loader=Loader)
51
+
52
+ return opt
53
+
54
+ class RealESRGAN_degradation(object):
55
+ def __init__(self, opt_path='', device='cpu'):
56
+ self.opt = opt_parse(opt_path)
57
+ self.device = device #torch.device('cpu')
58
+ optk = self.opt['kernel_info']
59
+
60
+ # blur settings for the first degradation
61
+ self.blur_kernel_size = optk['blur_kernel_size']
62
+ self.kernel_list = optk['kernel_list']
63
+ self.kernel_prob = optk['kernel_prob']
64
+ self.blur_sigma = optk['blur_sigma']
65
+ self.betag_range = optk['betag_range']
66
+ self.betap_range = optk['betap_range']
67
+ self.sinc_prob = optk['sinc_prob']
68
+
69
+ # blur settings for the second degradation
70
+ self.blur_kernel_size2 = optk['blur_kernel_size2']
71
+ self.kernel_list2 = optk['kernel_list2']
72
+ self.kernel_prob2 = optk['kernel_prob2']
73
+ self.blur_sigma2 = optk['blur_sigma2']
74
+ self.betag_range2 = optk['betag_range2']
75
+ self.betap_range2 = optk['betap_range2']
76
+ self.sinc_prob2 = optk['sinc_prob2']
77
+
78
+ # a final sinc filter
79
+ self.final_sinc_prob = optk['final_sinc_prob']
80
+
81
+ self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
82
+ self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
83
+ self.pulse_tensor[10, 10] = 1
84
+
85
+ self.jpeger = DiffJPEG(differentiable=False).to(self.device)
86
+ self.usm_shaper = USMSharp().to(self.device)
87
+
88
+ def color_jitter_pt(self, img, brightness, contrast, saturation, hue):
89
+ fn_idx = torch.randperm(4)
90
+ for fn_id in fn_idx:
91
+ if fn_id == 0 and brightness is not None:
92
+ brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
93
+ img = adjust_brightness(img, brightness_factor)
94
+
95
+ if fn_id == 1 and contrast is not None:
96
+ contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
97
+ img = adjust_contrast(img, contrast_factor)
98
+
99
+ if fn_id == 2 and saturation is not None:
100
+ saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
101
+ img = adjust_saturation(img, saturation_factor)
102
+
103
+ if fn_id == 3 and hue is not None:
104
+ hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
105
+ img = adjust_hue(img, hue_factor)
106
+ return img
107
+
108
+ def random_augment(self, img_gt):
109
+ # random horizontal flip
110
+ img_gt, status = augment(img_gt, hflip=True, rotation=False, return_status=True)
111
+ """
112
+ # random color jitter
113
+ if np.random.uniform() < self.opt['color_jitter_prob']:
114
+ jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
115
+ img_gt = img_gt + jitter_val
116
+ img_gt = np.clip(img_gt, 0, 1)
117
+
118
+ # random grayscale
119
+ if np.random.uniform() < self.opt['gray_prob']:
120
+ #img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
121
+ img_gt = cv2.cvtColor(img_gt, cv2.COLOR_RGB2GRAY)
122
+ img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])
123
+ """
124
+ # BGR to RGB, HWC to CHW, numpy to tensor
125
+ img_gt = img2tensor([img_gt], bgr2rgb=False, float32=True)[0].unsqueeze(0)
126
+
127
+ return img_gt
128
+
129
+ def random_kernels(self):
130
+ # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
131
+ kernel_size = random.choice(self.kernel_range)
132
+ if np.random.uniform() < self.sinc_prob:
133
+ # this sinc filter setting is for kernels ranging from [7, 21]
134
+ if kernel_size < 13:
135
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
136
+ else:
137
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
138
+ kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
139
+ else:
140
+ kernel = random_mixed_kernels(
141
+ self.kernel_list,
142
+ self.kernel_prob,
143
+ kernel_size,
144
+ self.blur_sigma,
145
+ self.blur_sigma, [-math.pi, math.pi],
146
+ self.betag_range,
147
+ self.betap_range,
148
+ noise_range=None)
149
+ # pad kernel
150
+ pad_size = (21 - kernel_size) // 2
151
+ kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
152
+
153
+ # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
154
+ kernel_size = random.choice(self.kernel_range)
155
+ if np.random.uniform() < self.sinc_prob2:
156
+ if kernel_size < 13:
157
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
158
+ else:
159
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
160
+ kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
161
+ else:
162
+ kernel2 = random_mixed_kernels(
163
+ self.kernel_list2,
164
+ self.kernel_prob2,
165
+ kernel_size,
166
+ self.blur_sigma2,
167
+ self.blur_sigma2, [-math.pi, math.pi],
168
+ self.betag_range2,
169
+ self.betap_range2,
170
+ noise_range=None)
171
+
172
+ # pad kernel
173
+ pad_size = (21 - kernel_size) // 2
174
+ kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
175
+
176
+ # ------------------------------------- sinc kernel ------------------------------------- #
177
+ if np.random.uniform() < self.final_sinc_prob:
178
+ kernel_size = random.choice(self.kernel_range)
179
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
180
+ sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
181
+ sinc_kernel = torch.FloatTensor(sinc_kernel)
182
+ else:
183
+ sinc_kernel = self.pulse_tensor
184
+
185
+ kernel = torch.FloatTensor(kernel)
186
+ kernel2 = torch.FloatTensor(kernel2)
187
+
188
+ return kernel, kernel2, sinc_kernel
189
+
190
+ @torch.no_grad()
191
+ def degrade_process(self, img_gt, resize_bak=False):
192
+ img_gt = self.random_augment(img_gt)
193
+ kernel1, kernel2, sinc_kernel = self.random_kernels()
194
+ img_gt, kernel1, kernel2, sinc_kernel = img_gt.to(self.device), kernel1.to(self.device), kernel2.to(self.device), sinc_kernel.to(self.device)
195
+ #img_gt = self.usm_shaper(img_gt) # shaper gt
196
+ ori_h, ori_w = img_gt.size()[2:4]
197
+
198
+ #scale_final = random.randint(4, 16)
199
+ scale_final = 4
200
+
201
+ # ----------------------- The first degradation process ----------------------- #
202
+ # blur
203
+ out = filter2D(img_gt, kernel1)
204
+ # random resize
205
+ updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
206
+ if updown_type == 'up':
207
+ scale = np.random.uniform(1, self.opt['resize_range'][1])
208
+ elif updown_type == 'down':
209
+ scale = np.random.uniform(self.opt['resize_range'][0], 1)
210
+ else:
211
+ scale = 1
212
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
213
+ out = F.interpolate(out, scale_factor=scale, mode=mode)
214
+ # noise
215
+ gray_noise_prob = self.opt['gray_noise_prob']
216
+ if np.random.uniform() < self.opt['gaussian_noise_prob']:
217
+ out = random_add_gaussian_noise_pt(
218
+ out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
219
+ else:
220
+ out = random_add_poisson_noise_pt(
221
+ out,
222
+ scale_range=self.opt['poisson_scale_range'],
223
+ gray_prob=gray_noise_prob,
224
+ clip=True,
225
+ rounds=False)
226
+ # JPEG compression
227
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
228
+ out = torch.clamp(out, 0, 1)
229
+ out = self.jpeger(out, quality=jpeg_p)
230
+
231
+ # ----------------------- The second degradation process ----------------------- #
232
+ # blur
233
+ if np.random.uniform() < self.opt['second_blur_prob']:
234
+ out = filter2D(out, kernel2)
235
+ # random resize
236
+ updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
237
+ if updown_type == 'up':
238
+ scale = np.random.uniform(1, self.opt['resize_range2'][1])
239
+ elif updown_type == 'down':
240
+ scale = np.random.uniform(self.opt['resize_range2'][0], 1)
241
+ else:
242
+ scale = 1
243
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
244
+ out = F.interpolate(
245
+ out, size=(int(ori_h / scale_final * scale), int(ori_w / scale_final * scale)), mode=mode)
246
+ # noise
247
+ gray_noise_prob = self.opt['gray_noise_prob2']
248
+ if np.random.uniform() < self.opt['gaussian_noise_prob2']:
249
+ out = random_add_gaussian_noise_pt(
250
+ out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
251
+ else:
252
+ out = random_add_poisson_noise_pt(
253
+ out,
254
+ scale_range=self.opt['poisson_scale_range2'],
255
+ gray_prob=gray_noise_prob,
256
+ clip=True,
257
+ rounds=False)
258
+
259
+ # JPEG compression + the final sinc filter
260
+ # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
261
+ # as one operation.
262
+ # We consider two orders:
263
+ # 1. [resize back + sinc filter] + JPEG compression
264
+ # 2. JPEG compression + [resize back + sinc filter]
265
+ # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
266
+ if np.random.uniform() < 0.5:
267
+ # resize back + the final sinc filter
268
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
269
+ out = F.interpolate(out, size=(ori_h // scale_final, ori_w // scale_final), mode=mode)
270
+ out = filter2D(out, sinc_kernel)
271
+ # JPEG compression
272
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
273
+ out = torch.clamp(out, 0, 1)
274
+ out = self.jpeger(out, quality=jpeg_p)
275
+ else:
276
+ # JPEG compression
277
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
278
+ out = torch.clamp(out, 0, 1)
279
+ out = self.jpeger(out, quality=jpeg_p)
280
+ # resize back + the final sinc filter
281
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
282
+ out = F.interpolate(out, size=(ori_h // scale_final, ori_w // scale_final), mode=mode)
283
+ out = filter2D(out, sinc_kernel)
284
+
285
+ if np.random.uniform() < self.opt['gray_prob']:
286
+ out = rgb_to_grayscale(out, num_output_channels=1)
287
+
288
+ if np.random.uniform() < self.opt['color_jitter_prob']:
289
+ brightness = self.opt.get('brightness', (0.5, 1.5))
290
+ contrast = self.opt.get('contrast', (0.5, 1.5))
291
+ saturation = self.opt.get('saturation', (0, 1.5))
292
+ hue = self.opt.get('hue', (-0.1, 0.1))
293
+ out = self.color_jitter_pt(out, brightness, contrast, saturation, hue)
294
+
295
+ if resize_bak:
296
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
297
+ out = F.interpolate(out, size=(ori_h, ori_w), mode=mode)
298
+ # clamp and round
299
+ img_lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
300
+
301
+ return img_gt, img_lq
302
+
303
+
models/DiffAugment.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # BSD 2-Clause "Simplified" License
2
+ # Copyright (c) 2020, Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
3
+ # All rights reserved.
4
+ #
5
+ # Redistribution and use in source and binary forms, with or without
6
+ # modification, are permitted provided that the following conditions are met:
7
+ #
8
+ # * Redistributions of source code must retain the above copyright notice, this
9
+ # list of conditions and the following disclaimer.
10
+ #
11
+ # * Redistributions in binary form must reproduce the above copyright notice,
12
+ # this list of conditions and the following disclaimer in the documentation
13
+ # and/or other materials provided with the distribution.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
19
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25
+ #
26
+ # Code from https://github.com/mit-han-lab/data-efficient-gans
27
+
28
+ """Training GANs with DiffAugment."""
29
+
30
+ import numpy as np
31
+ import torch
32
+ import torch.nn.functional as F
33
+
34
+
35
+ def DiffAugment(x: torch.Tensor, policy: str = '', channels_first: bool = True) -> torch.Tensor:
36
+ if policy:
37
+ if not channels_first:
38
+ x = x.permute(0, 3, 1, 2)
39
+ for p in policy.split(','):
40
+ for f in AUGMENT_FNS[p]:
41
+ x = f(x)
42
+ if not channels_first:
43
+ x = x.permute(0, 2, 3, 1)
44
+ x = x.contiguous()
45
+ return x
46
+
47
+
48
+ def rand_brightness(x: torch.Tensor) -> torch.Tensor:
49
+ x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
50
+ return x
51
+
52
+
53
+ def rand_saturation(x: torch.Tensor) -> torch.Tensor:
54
+ x_mean = x.mean(dim=1, keepdim=True)
55
+ x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
56
+ return x
57
+
58
+
59
+ def rand_contrast(x: torch.Tensor) -> torch.Tensor:
60
+ x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
61
+ x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
62
+ return x
63
+
64
+
65
+ def rand_translation(x: torch.Tensor, ratio: float = 0.125) -> torch.Tensor:
66
+ shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
67
+ translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
68
+ translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
69
+ grid_batch, grid_x, grid_y = torch.meshgrid(
70
+ torch.arange(x.size(0), dtype=torch.long, device=x.device),
71
+ torch.arange(x.size(2), dtype=torch.long, device=x.device),
72
+ torch.arange(x.size(3), dtype=torch.long, device=x.device),
73
+ )
74
+ grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
75
+ grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
76
+ x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
77
+ x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
78
+ return x
79
+
80
+
81
+ def rand_cutout(x: torch.Tensor, ratio: float = 0.2) -> torch.Tensor:
82
+ cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
83
+ offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
84
+ offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
85
+ grid_batch, grid_x, grid_y = torch.meshgrid(
86
+ torch.arange(x.size(0), dtype=torch.long, device=x.device),
87
+ torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
88
+ torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
89
+ )
90
+ grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
91
+ grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
92
+ mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
93
+ mask[grid_batch, grid_x, grid_y] = 0
94
+ x = x * mask.unsqueeze(1)
95
+ return x
96
+
97
+
98
+ def rand_resize(x: torch.Tensor, min_ratio: float = 0.8, max_ratio: float = 1.2) -> torch.Tensor:
99
+ resize_ratio = np.random.rand()*(max_ratio-min_ratio) + min_ratio
100
+ resized_img = F.interpolate(x, size=int(resize_ratio*x.shape[3]), mode='bilinear')
101
+ org_size = x.shape[3]
102
+ if int(resize_ratio*x.shape[3]) < x.shape[3]:
103
+ left_pad = (x.shape[3]-int(resize_ratio*x.shape[3]))/2.
104
+ left_pad = int(left_pad)
105
+ right_pad = x.shape[3] - left_pad - resized_img.shape[3]
106
+ x = F.pad(resized_img, (left_pad, right_pad, left_pad, right_pad), "constant", 0.)
107
+ else:
108
+ left = (int(resize_ratio*x.shape[3])-x.shape[3])/2.
109
+ left = int(left)
110
+ x = resized_img[:, :, left:(left+x.shape[3]), left:(left+x.shape[3])]
111
+ assert x.shape[2] == org_size
112
+ assert x.shape[3] == org_size
113
+ return x
114
+
115
+
116
+ AUGMENT_FNS = {
117
+ 'color': [rand_brightness, rand_saturation, rand_contrast],
118
+ 'translation': [rand_translation],
119
+ 'resize': [rand_resize],
120
+ 'cutout': [rand_cutout],
121
+ }
models/controlnet.py ADDED
@@ -0,0 +1,850 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ from torch import nn
19
+ from torch.nn import functional as F
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import FromOriginalControlnetMixin
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
25
+ from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
26
+ from diffusers.models.modeling_utils import ModelMixin
27
+ from .unet_2d_blocks import (
28
+ CrossAttnDownBlock2D,
29
+ DownBlock2D,
30
+ UNetMidBlock2DCrossAttn,
31
+ get_down_block,
32
+ )
33
+ from .unet_2d_condition import UNet2DConditionModel
34
+
35
+
36
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
37
+
38
+
39
+ @dataclass
40
+ class ControlNetOutput(BaseOutput):
41
+ """
42
+ The output of [`ControlNetModel`].
43
+
44
+ Args:
45
+ down_block_res_samples (`tuple[torch.Tensor]`):
46
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
47
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
48
+ used to condition the original UNet's downsampling activations.
49
+ mid_down_block_re_sample (`torch.Tensor`):
50
+ The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
51
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
52
+ Output can be used to condition the original UNet's middle block activation.
53
+ """
54
+
55
+ down_block_res_samples: Tuple[torch.Tensor]
56
+ mid_block_res_sample: torch.Tensor
57
+
58
+
59
+ class ControlNetConditioningEmbedding(nn.Module):
60
+ """
61
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
62
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
63
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
64
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
65
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
66
+ model) to encode image-space conditions ... into feature maps ..."
67
+ """
68
+
69
+ def __init__(
70
+ self,
71
+ conditioning_embedding_channels: int,
72
+ conditioning_channels: int = 3,
73
+ block_out_channels: Tuple[int] = (16, 32, 96, 256),
74
+ ):
75
+ super().__init__()
76
+
77
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
78
+
79
+ self.blocks = nn.ModuleList([])
80
+
81
+ for i in range(len(block_out_channels) - 1):
82
+ channel_in = block_out_channels[i]
83
+ channel_out = block_out_channels[i + 1]
84
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
85
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
86
+
87
+ self.conv_out = zero_module(
88
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
89
+ )
90
+
91
+ def forward(self, conditioning):
92
+ embedding = self.conv_in(conditioning)
93
+ embedding = F.silu(embedding)
94
+
95
+ for block in self.blocks:
96
+ embedding = block(embedding)
97
+ embedding = F.silu(embedding)
98
+
99
+ embedding = self.conv_out(embedding)
100
+
101
+ return embedding
102
+
103
+
104
+ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
105
+ """
106
+ A ControlNet model.
107
+
108
+ Args:
109
+ in_channels (`int`, defaults to 4):
110
+ The number of channels in the input sample.
111
+ flip_sin_to_cos (`bool`, defaults to `True`):
112
+ Whether to flip the sin to cos in the time embedding.
113
+ freq_shift (`int`, defaults to 0):
114
+ The frequency shift to apply to the time embedding.
115
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
116
+ The tuple of downsample blocks to use.
117
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
118
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
119
+ The tuple of output channels for each block.
120
+ layers_per_block (`int`, defaults to 2):
121
+ The number of layers per block.
122
+ downsample_padding (`int`, defaults to 1):
123
+ The padding to use for the downsampling convolution.
124
+ mid_block_scale_factor (`float`, defaults to 1):
125
+ The scale factor to use for the mid block.
126
+ act_fn (`str`, defaults to "silu"):
127
+ The activation function to use.
128
+ norm_num_groups (`int`, *optional*, defaults to 32):
129
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
130
+ in post-processing.
131
+ norm_eps (`float`, defaults to 1e-5):
132
+ The epsilon to use for the normalization.
133
+ cross_attention_dim (`int`, defaults to 1280):
134
+ The dimension of the cross attention features.
135
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
136
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
137
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
138
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
139
+ encoder_hid_dim (`int`, *optional*, defaults to None):
140
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
141
+ dimension to `cross_attention_dim`.
142
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
143
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
144
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
145
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
146
+ The dimension of the attention heads.
147
+ use_linear_projection (`bool`, defaults to `False`):
148
+ class_embed_type (`str`, *optional*, defaults to `None`):
149
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
150
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
151
+ addition_embed_type (`str`, *optional*, defaults to `None`):
152
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
153
+ "text". "text" will use the `TextTimeEmbedding` layer.
154
+ num_class_embeds (`int`, *optional*, defaults to 0):
155
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
156
+ class conditioning with `class_embed_type` equal to `None`.
157
+ upcast_attention (`bool`, defaults to `False`):
158
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
159
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
160
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
161
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
162
+ `class_embed_type="projection"`.
163
+ controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
164
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
165
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
166
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
167
+ global_pool_conditions (`bool`, defaults to `False`):
168
+ """
169
+
170
+ _supports_gradient_checkpointing = True
171
+
172
+ @register_to_config
173
+ def __init__(
174
+ self,
175
+ in_channels: int = 4,
176
+ conditioning_channels: int = 3,
177
+ flip_sin_to_cos: bool = True,
178
+ freq_shift: int = 0,
179
+ down_block_types: Tuple[str] = (
180
+ "CrossAttnDownBlock2D",
181
+ "CrossAttnDownBlock2D",
182
+ "CrossAttnDownBlock2D",
183
+ "DownBlock2D",
184
+ ),
185
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
186
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
187
+ layers_per_block: int = 2,
188
+ downsample_padding: int = 1,
189
+ mid_block_scale_factor: float = 1,
190
+ act_fn: str = "silu",
191
+ norm_num_groups: Optional[int] = 32,
192
+ norm_eps: float = 1e-5,
193
+ cross_attention_dim: int = 1280,
194
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
195
+ encoder_hid_dim: Optional[int] = None,
196
+ encoder_hid_dim_type: Optional[str] = None,
197
+ attention_head_dim: Union[int, Tuple[int]] = 8,
198
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
199
+ use_linear_projection: bool = False,
200
+ class_embed_type: Optional[str] = None,
201
+ addition_embed_type: Optional[str] = None,
202
+ addition_time_embed_dim: Optional[int] = None,
203
+ num_class_embeds: Optional[int] = None,
204
+ upcast_attention: bool = False,
205
+ resnet_time_scale_shift: str = "default",
206
+ projection_class_embeddings_input_dim: Optional[int] = None,
207
+ controlnet_conditioning_channel_order: str = "rgb",
208
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
209
+ global_pool_conditions: bool = False,
210
+ addition_embed_type_num_heads=64,
211
+ use_vae_encode_condition=False,
212
+ ):
213
+ super().__init__()
214
+
215
+ # If `num_attention_heads` is not defined (which is the case for most models)
216
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
217
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
218
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
219
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
220
+ # which is why we correct for the naming here.
221
+ num_attention_heads = num_attention_heads or attention_head_dim
222
+
223
+ # Check inputs
224
+ if len(block_out_channels) != len(down_block_types):
225
+ raise ValueError(
226
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
227
+ )
228
+
229
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
230
+ raise ValueError(
231
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
232
+ )
233
+
234
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
235
+ raise ValueError(
236
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
237
+ )
238
+
239
+ if isinstance(transformer_layers_per_block, int):
240
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
241
+
242
+ # input
243
+ conv_in_kernel = 3
244
+ conv_in_padding = (conv_in_kernel - 1) // 2
245
+ self.conv_in = nn.Conv2d(
246
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
247
+ )
248
+
249
+ # use_vae_encode_condition
250
+ self.use_vae_encode_condition = use_vae_encode_condition
251
+ if self.use_vae_encode_condition:
252
+ print(f'============================')
253
+ print(f'use vae encode condition in CONTROLNET!!!')
254
+ print(f'============================')
255
+ self.condition_conv_in = nn.Conv2d(
256
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
257
+ )
258
+ else:
259
+ print(f'============================')
260
+ print(f'Not !!! use vae encode condition in CONTROLNET')
261
+ print(f'============================')
262
+ # control net conditioning embedding
263
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
264
+ conditioning_embedding_channels=block_out_channels[0],
265
+ block_out_channels=conditioning_embedding_out_channels,
266
+ conditioning_channels=conditioning_channels,
267
+ )
268
+
269
+ # time
270
+ time_embed_dim = block_out_channels[0] * 4
271
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
272
+ timestep_input_dim = block_out_channels[0]
273
+ self.time_embedding = TimestepEmbedding(
274
+ timestep_input_dim,
275
+ time_embed_dim,
276
+ act_fn=act_fn,
277
+ )
278
+
279
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
280
+ encoder_hid_dim_type = "text_proj"
281
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
282
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
283
+
284
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
285
+ raise ValueError(
286
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
287
+ )
288
+
289
+ if encoder_hid_dim_type == "text_proj":
290
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
291
+ elif encoder_hid_dim_type == "text_image_proj":
292
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
293
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
294
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
295
+ self.encoder_hid_proj = TextImageProjection(
296
+ text_embed_dim=encoder_hid_dim,
297
+ image_embed_dim=cross_attention_dim,
298
+ cross_attention_dim=cross_attention_dim,
299
+ )
300
+
301
+ elif encoder_hid_dim_type is not None:
302
+ raise ValueError(
303
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
304
+ )
305
+ else:
306
+ self.encoder_hid_proj = None
307
+
308
+ # class embedding
309
+ if class_embed_type is None and num_class_embeds is not None:
310
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
311
+ elif class_embed_type == "timestep":
312
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
313
+ elif class_embed_type == "identity":
314
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
315
+ elif class_embed_type == "projection":
316
+ if projection_class_embeddings_input_dim is None:
317
+ raise ValueError(
318
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
319
+ )
320
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
321
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
322
+ # 2. it projects from an arbitrary input dimension.
323
+ #
324
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
325
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
326
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
327
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
328
+ else:
329
+ self.class_embedding = None
330
+
331
+ if addition_embed_type == "text":
332
+ if encoder_hid_dim is not None:
333
+ text_time_embedding_from_dim = encoder_hid_dim
334
+ else:
335
+ text_time_embedding_from_dim = cross_attention_dim
336
+
337
+ self.add_embedding = TextTimeEmbedding(
338
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
339
+ )
340
+ elif addition_embed_type == "text_image":
341
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
342
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
343
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
344
+ self.add_embedding = TextImageTimeEmbedding(
345
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
346
+ )
347
+ elif addition_embed_type == "text_time":
348
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
349
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
350
+
351
+ elif addition_embed_type is not None:
352
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
353
+
354
+ self.down_blocks = nn.ModuleList([])
355
+ self.controlnet_down_blocks = nn.ModuleList([])
356
+
357
+ if isinstance(only_cross_attention, bool):
358
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
359
+
360
+ if isinstance(attention_head_dim, int):
361
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
362
+
363
+ if isinstance(num_attention_heads, int):
364
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
365
+
366
+ # down
367
+ output_channel = block_out_channels[0]
368
+
369
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
370
+ controlnet_block = zero_module(controlnet_block)
371
+ self.controlnet_down_blocks.append(controlnet_block)
372
+
373
+
374
+ for i, down_block_type in enumerate(down_block_types):
375
+ input_channel = output_channel
376
+ output_channel = block_out_channels[i]
377
+ is_final_block = i == len(block_out_channels) - 1
378
+
379
+ down_block = get_down_block(
380
+ down_block_type,
381
+ num_layers=layers_per_block,
382
+ transformer_layers_per_block=transformer_layers_per_block[i],
383
+ in_channels=input_channel,
384
+ out_channels=output_channel,
385
+ temb_channels=time_embed_dim,
386
+ add_downsample=not is_final_block,
387
+ resnet_eps=norm_eps,
388
+ resnet_act_fn=act_fn,
389
+ resnet_groups=norm_num_groups,
390
+ cross_attention_dim=cross_attention_dim,
391
+ num_attention_heads=num_attention_heads[i],
392
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
393
+ downsample_padding=downsample_padding,
394
+ use_linear_projection=use_linear_projection,
395
+ only_cross_attention=only_cross_attention[i],
396
+ upcast_attention=upcast_attention,
397
+ resnet_time_scale_shift=resnet_time_scale_shift,
398
+ )
399
+ self.down_blocks.append(down_block)
400
+
401
+ for _ in range(layers_per_block):
402
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
403
+ controlnet_block = zero_module(controlnet_block)
404
+ self.controlnet_down_blocks.append(controlnet_block)
405
+
406
+ if not is_final_block:
407
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
408
+ controlnet_block = zero_module(controlnet_block)
409
+ self.controlnet_down_blocks.append(controlnet_block)
410
+
411
+ # mid
412
+ mid_block_channel = block_out_channels[-1]
413
+
414
+ controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
415
+ controlnet_block = zero_module(controlnet_block)
416
+ self.controlnet_mid_block = controlnet_block
417
+
418
+ self.mid_block = UNetMidBlock2DCrossAttn(
419
+ transformer_layers_per_block=transformer_layers_per_block[-1],
420
+ in_channels=mid_block_channel,
421
+ temb_channels=time_embed_dim,
422
+ resnet_eps=norm_eps,
423
+ resnet_act_fn=act_fn,
424
+ output_scale_factor=mid_block_scale_factor,
425
+ resnet_time_scale_shift=resnet_time_scale_shift,
426
+ cross_attention_dim=cross_attention_dim,
427
+ num_attention_heads=num_attention_heads[-1],
428
+ resnet_groups=norm_num_groups,
429
+ use_linear_projection=use_linear_projection,
430
+ upcast_attention=upcast_attention,
431
+ )
432
+
433
+ @classmethod
434
+ def from_unet(
435
+ cls,
436
+ unet: UNet2DConditionModel,
437
+ controlnet_conditioning_channel_order: str = "rgb",
438
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
439
+ load_weights_from_unet: bool = True,
440
+ use_vae_encode_condition: bool = False,
441
+ ):
442
+ r"""
443
+ Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
444
+
445
+ Parameters:
446
+ unet (`UNet2DConditionModel`):
447
+ The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
448
+ where applicable.
449
+ """
450
+ transformer_layers_per_block = (
451
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
452
+ )
453
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
454
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
455
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
456
+ addition_time_embed_dim = (
457
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
458
+ )
459
+
460
+ controlnet = cls(
461
+ encoder_hid_dim=encoder_hid_dim,
462
+ encoder_hid_dim_type=encoder_hid_dim_type,
463
+ addition_embed_type=addition_embed_type,
464
+ addition_time_embed_dim=addition_time_embed_dim,
465
+ transformer_layers_per_block=transformer_layers_per_block,
466
+ in_channels=unet.config.in_channels,
467
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
468
+ freq_shift=unet.config.freq_shift,
469
+ down_block_types=unet.config.down_block_types,
470
+ only_cross_attention=unet.config.only_cross_attention,
471
+ block_out_channels=unet.config.block_out_channels,
472
+ layers_per_block=unet.config.layers_per_block,
473
+ downsample_padding=unet.config.downsample_padding,
474
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
475
+ act_fn=unet.config.act_fn,
476
+ norm_num_groups=unet.config.norm_num_groups,
477
+ norm_eps=unet.config.norm_eps,
478
+ cross_attention_dim=unet.config.cross_attention_dim,
479
+ attention_head_dim=unet.config.attention_head_dim,
480
+ num_attention_heads=unet.config.num_attention_heads,
481
+ use_linear_projection=unet.config.use_linear_projection,
482
+ class_embed_type=unet.config.class_embed_type,
483
+ num_class_embeds=unet.config.num_class_embeds,
484
+ upcast_attention=unet.config.upcast_attention,
485
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
486
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
487
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
488
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
489
+ use_vae_encode_condition=use_vae_encode_condition,
490
+ )
491
+
492
+ if load_weights_from_unet:
493
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
494
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
495
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
496
+
497
+ if controlnet.class_embedding:
498
+ controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
499
+
500
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
501
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
502
+
503
+ return controlnet
504
+
505
+ @property
506
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
507
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
508
+ r"""
509
+ Returns:
510
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
511
+ indexed by its weight name.
512
+ """
513
+ # set recursively
514
+ processors = {}
515
+
516
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
517
+ if hasattr(module, "get_processor"):
518
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
519
+
520
+ for sub_name, child in module.named_children():
521
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
522
+
523
+ return processors
524
+
525
+ for name, module in self.named_children():
526
+ fn_recursive_add_processors(name, module, processors)
527
+
528
+ return processors
529
+
530
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
531
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
532
+ r"""
533
+ Sets the attention processor to use to compute attention.
534
+
535
+ Parameters:
536
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
537
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
538
+ for **all** `Attention` layers.
539
+
540
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
541
+ processor. This is strongly recommended when setting trainable attention processors.
542
+
543
+ """
544
+ count = len(self.attn_processors.keys())
545
+
546
+ if isinstance(processor, dict) and len(processor) != count:
547
+ raise ValueError(
548
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
549
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
550
+ )
551
+
552
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
553
+ if hasattr(module, "set_processor"):
554
+ if not isinstance(processor, dict):
555
+ module.set_processor(processor)
556
+ else:
557
+ module.set_processor(processor.pop(f"{name}.processor"))
558
+
559
+ for sub_name, child in module.named_children():
560
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
561
+
562
+ for name, module in self.named_children():
563
+ fn_recursive_attn_processor(name, module, processor)
564
+
565
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
566
+ def set_default_attn_processor(self):
567
+ """
568
+ Disables custom attention processors and sets the default attention implementation.
569
+ """
570
+ self.set_attn_processor(AttnProcessor())
571
+
572
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
573
+ def set_attention_slice(self, slice_size):
574
+ r"""
575
+ Enable sliced attention computation.
576
+
577
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
578
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
579
+
580
+ Args:
581
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
582
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
583
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
584
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
585
+ must be a multiple of `slice_size`.
586
+ """
587
+ sliceable_head_dims = []
588
+
589
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
590
+ if hasattr(module, "set_attention_slice"):
591
+ sliceable_head_dims.append(module.sliceable_head_dim)
592
+
593
+ for child in module.children():
594
+ fn_recursive_retrieve_sliceable_dims(child)
595
+
596
+ # retrieve number of attention layers
597
+ for module in self.children():
598
+ fn_recursive_retrieve_sliceable_dims(module)
599
+
600
+ num_sliceable_layers = len(sliceable_head_dims)
601
+
602
+ if slice_size == "auto":
603
+ # half the attention head size is usually a good trade-off between
604
+ # speed and memory
605
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
606
+ elif slice_size == "max":
607
+ # make smallest slice possible
608
+ slice_size = num_sliceable_layers * [1]
609
+
610
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
611
+
612
+ if len(slice_size) != len(sliceable_head_dims):
613
+ raise ValueError(
614
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
615
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
616
+ )
617
+
618
+ for i in range(len(slice_size)):
619
+ size = slice_size[i]
620
+ dim = sliceable_head_dims[i]
621
+ if size is not None and size > dim:
622
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
623
+
624
+ # Recursively walk through all the children.
625
+ # Any children which exposes the set_attention_slice method
626
+ # gets the message
627
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
628
+ if hasattr(module, "set_attention_slice"):
629
+ module.set_attention_slice(slice_size.pop())
630
+
631
+ for child in module.children():
632
+ fn_recursive_set_attention_slice(child, slice_size)
633
+
634
+ reversed_slice_size = list(reversed(slice_size))
635
+ for module in self.children():
636
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
637
+
638
+ def _set_gradient_checkpointing(self, module, value=False):
639
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
640
+ module.gradient_checkpointing = value
641
+
642
+ def forward(
643
+ self,
644
+ sample: torch.FloatTensor,
645
+ timestep: Union[torch.Tensor, float, int],
646
+ encoder_hidden_states: torch.Tensor,
647
+ controlnet_cond: torch.FloatTensor,
648
+ conditioning_scale: float = 1.0,
649
+ class_labels: Optional[torch.Tensor] = None,
650
+ timestep_cond: Optional[torch.Tensor] = None,
651
+ attention_mask: Optional[torch.Tensor] = None,
652
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
653
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
654
+ guess_mode: bool = False,
655
+ return_dict: bool = True,
656
+ image_encoder_hidden_states: torch.Tensor = None,
657
+ vae_encode_condition_hidden_states: torch.Tensor = None,
658
+ use_vae_encode_condition = False,
659
+ ) -> Union[ControlNetOutput, Tuple]:
660
+ """
661
+ The [`ControlNetModel`] forward method.
662
+
663
+ Args:
664
+ sample (`torch.FloatTensor`):
665
+ The noisy input tensor.
666
+ timestep (`Union[torch.Tensor, float, int]`):
667
+ The number of timesteps to denoise an input.
668
+ encoder_hidden_states (`torch.Tensor`):
669
+ The encoder hidden states.
670
+ controlnet_cond (`torch.FloatTensor`):
671
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
672
+ conditioning_scale (`float`, defaults to `1.0`):
673
+ The scale factor for ControlNet outputs.
674
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
675
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
676
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
677
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
678
+ added_cond_kwargs (`dict`):
679
+ Additional conditions for the Stable Diffusion XL UNet.
680
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
681
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
682
+ guess_mode (`bool`, defaults to `False`):
683
+ In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
684
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
685
+ return_dict (`bool`, defaults to `True`):
686
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
687
+
688
+ Returns:
689
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
690
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
691
+ returned where the first element is the sample tensor.
692
+ """
693
+ # check channel order
694
+ channel_order = self.config.controlnet_conditioning_channel_order
695
+
696
+ if channel_order == "rgb":
697
+ # in rgb order by default
698
+ ...
699
+ elif channel_order == "bgr":
700
+ controlnet_cond = torch.flip(controlnet_cond, dims=[1])
701
+ else:
702
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
703
+
704
+ # prepare attention_mask
705
+ if attention_mask is not None:
706
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
707
+ attention_mask = attention_mask.unsqueeze(1)
708
+
709
+ # 1. time
710
+ timesteps = timestep
711
+ if not torch.is_tensor(timesteps):
712
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
713
+ # This would be a good case for the `match` statement (Python 3.10+)
714
+ is_mps = sample.device.type == "mps"
715
+ if isinstance(timestep, float):
716
+ dtype = torch.float32 if is_mps else torch.float64
717
+ else:
718
+ dtype = torch.int32 if is_mps else torch.int64
719
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
720
+ elif len(timesteps.shape) == 0:
721
+ timesteps = timesteps[None].to(sample.device)
722
+
723
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
724
+ timesteps = timesteps.expand(sample.shape[0])
725
+
726
+ t_emb = self.time_proj(timesteps)
727
+
728
+ # timesteps does not contain any weights and will always return f32 tensors
729
+ # but time_embedding might actually be running in fp16. so we need to cast here.
730
+ # there might be better ways to encapsulate this.
731
+ t_emb = t_emb.to(dtype=sample.dtype)
732
+
733
+ emb = self.time_embedding(t_emb, timestep_cond)
734
+ aug_emb = None
735
+
736
+ if self.class_embedding is not None:
737
+ if class_labels is None:
738
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
739
+
740
+ if self.config.class_embed_type == "timestep":
741
+ class_labels = self.time_proj(class_labels)
742
+
743
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
744
+ emb = emb + class_emb
745
+
746
+ if self.config.addition_embed_type is not None:
747
+ if self.config.addition_embed_type == "text":
748
+ aug_emb = self.add_embedding(encoder_hidden_states)
749
+
750
+ elif self.config.addition_embed_type == "text_time":
751
+ if "text_embeds" not in added_cond_kwargs:
752
+ raise ValueError(
753
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
754
+ )
755
+ text_embeds = added_cond_kwargs.get("text_embeds")
756
+ if "time_ids" not in added_cond_kwargs:
757
+ raise ValueError(
758
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
759
+ )
760
+ time_ids = added_cond_kwargs.get("time_ids")
761
+ time_embeds = self.add_time_proj(time_ids.flatten())
762
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
763
+
764
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
765
+ add_embeds = add_embeds.to(emb.dtype)
766
+ aug_emb = self.add_embedding(add_embeds)
767
+
768
+ emb = emb + aug_emb if aug_emb is not None else emb
769
+
770
+ # 2. pre-process
771
+ sample = self.conv_in(sample)
772
+
773
+ if not self.use_vae_encode_condition:
774
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
775
+ else:
776
+ controlnet_cond = self.condition_conv_in(vae_encode_condition_hidden_states)
777
+
778
+ sample = sample + controlnet_cond
779
+
780
+ # 3. down
781
+ down_block_res_samples = (sample,)
782
+ for downsample_block in self.down_blocks:
783
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
784
+ sample, res_samples = downsample_block(
785
+ hidden_states=sample,
786
+ temb=emb,
787
+ encoder_hidden_states=encoder_hidden_states,
788
+ attention_mask=attention_mask,
789
+ cross_attention_kwargs=cross_attention_kwargs,
790
+ image_encoder_hidden_states=image_encoder_hidden_states,
791
+ )
792
+ else:
793
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
794
+
795
+ down_block_res_samples += res_samples
796
+
797
+ # 4. mid
798
+ if self.mid_block is not None:
799
+ sample = self.mid_block(
800
+ sample,
801
+ emb,
802
+ encoder_hidden_states=encoder_hidden_states,
803
+ attention_mask=attention_mask,
804
+ cross_attention_kwargs=cross_attention_kwargs,
805
+ image_encoder_hidden_states=image_encoder_hidden_states,
806
+ )
807
+
808
+ # 5. Control net blocks
809
+
810
+ controlnet_down_block_res_samples = ()
811
+
812
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
813
+ down_block_res_sample = controlnet_block(down_block_res_sample)
814
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
815
+
816
+ down_block_res_samples = controlnet_down_block_res_samples
817
+
818
+ mid_block_res_sample = self.controlnet_mid_block(sample)
819
+
820
+ # 6. scaling
821
+ if guess_mode and not self.config.global_pool_conditions:
822
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
823
+
824
+ scales = scales * conditioning_scale
825
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
826
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
827
+ else:
828
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
829
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
830
+
831
+ if self.config.global_pool_conditions:
832
+ down_block_res_samples = [
833
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
834
+ ]
835
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
836
+
837
+ if not return_dict:
838
+ return (down_block_res_samples, mid_block_res_sample)
839
+
840
+ return ControlNetOutput(
841
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
842
+ )
843
+
844
+
845
+ def zero_module(module):
846
+ for p in module.parameters():
847
+ nn.init.zeros_(p)
848
+ return module
849
+
850
+
models/losses/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from models.losses.contperceptual import LPIPSWithDiscriminator
models/losses/contperceptual.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5
+ from diffusers.models.modeling_utils import ModelMixin
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ from diffusers.loaders import FromOriginalControlnetMixin
8
+
9
+ class LPIPSWithDiscriminator(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
10
+ def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
11
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
12
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
13
+ disc_loss="hinge"):
14
+
15
+ super().__init__()
16
+ assert disc_loss in ["hinge", "vanilla"]
17
+ self.kl_weight = kl_weight
18
+ self.pixel_weight = pixelloss_weight
19
+ self.perceptual_loss = LPIPS().eval()
20
+ self.perceptual_weight = perceptual_weight
21
+ # output log variance
22
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
23
+
24
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
25
+ n_layers=disc_num_layers,
26
+ use_actnorm=use_actnorm
27
+ ).apply(weights_init)
28
+ self.discriminator_iter_start = disc_start
29
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
30
+ self.disc_factor = disc_factor
31
+ self.discriminator_weight = disc_weight
32
+ self.disc_conditional = disc_conditional
33
+
34
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
35
+ if last_layer is not None:
36
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
37
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
38
+ else:
39
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
40
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
41
+
42
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
43
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
44
+ d_weight = d_weight * self.discriminator_weight
45
+ return d_weight
46
+
47
+ def forward(self, inputs, reconstructions, optimizer_idx,
48
+ global_step, posteriors=None, last_layer=None, cond=None, split="train",
49
+ weights=None, return_dic=False):
50
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
51
+ if self.perceptual_weight > 0:
52
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
53
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
54
+
55
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
56
+ weighted_nll_loss = nll_loss
57
+ if weights is not None:
58
+ weighted_nll_loss = weights*nll_loss
59
+ weighted_nll_loss = torch.mean(weighted_nll_loss) / weighted_nll_loss.shape[0]
60
+ nll_loss = torch.mean(nll_loss) / nll_loss.shape[0]
61
+ if self.kl_weight>0:
62
+ kl_loss = posteriors.kl()
63
+ kl_loss = torch.mean(kl_loss) / kl_loss.shape[0]
64
+
65
+ # now the GAN part
66
+ if optimizer_idx == 0:
67
+ # generator update
68
+ if cond is None:
69
+ assert not self.disc_conditional
70
+ logits_fake = self.discriminator(reconstructions.contiguous())
71
+ else:
72
+ assert self.disc_conditional
73
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
74
+ g_loss = -torch.mean(logits_fake)
75
+
76
+ if self.disc_factor > 0.0:
77
+ try:
78
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
79
+ except RuntimeError:
80
+ # assert not self.training
81
+ d_weight = torch.tensor(1.0) * self.discriminator_weight
82
+ else:
83
+ # d_weight = torch.tensor(0.0)
84
+ d_weight = torch.tensor(0.0)
85
+
86
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
87
+ if self.kl_weight>0:
88
+ loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
89
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
90
+ "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
91
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
92
+ "{}/d_weight".format(split): d_weight.detach(),
93
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
94
+ "{}/g_loss".format(split): g_loss.detach().mean(),
95
+ }
96
+ if return_dic:
97
+ loss_dic = {}
98
+ loss_dic['total_loss'] = loss.clone().detach().mean()
99
+ loss_dic['logvar'] = self.logvar.detach()
100
+ loss_dic['kl_loss'] = kl_loss.detach().mean()
101
+ loss_dic['nll_loss'] = nll_loss.detach().mean()
102
+ loss_dic['rec_loss'] = rec_loss.detach().mean()
103
+ loss_dic['d_weight'] = d_weight.detach()
104
+ loss_dic['disc_factor'] = torch.tensor(disc_factor)
105
+ loss_dic['g_loss'] = g_loss.detach().mean()
106
+ else:
107
+ loss = weighted_nll_loss + d_weight * disc_factor * g_loss
108
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
109
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
110
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
111
+ "{}/d_weight".format(split): d_weight.detach(),
112
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
113
+ "{}/g_loss".format(split): g_loss.detach().mean(),
114
+ }
115
+ if return_dic:
116
+ loss_dic = {}
117
+ loss_dic["{}/total_loss".format(split)] = loss.clone().detach().mean()
118
+ loss_dic["{}/logvar".format(split)] = self.logvar.detach()
119
+ loss_dic['nll_loss'.format(split)] = nll_loss.detach().mean()
120
+ loss_dic['rec_loss'.format(split)] = rec_loss.detach().mean()
121
+ loss_dic['d_weight'.format(split)] = d_weight.detach()
122
+ loss_dic['disc_factor'.format(split)] = torch.tensor(disc_factor)
123
+ loss_dic['g_loss'.format(split)] = g_loss.detach().mean()
124
+
125
+ if return_dic:
126
+ return loss, log, loss_dic
127
+ return loss, log
128
+
129
+ if optimizer_idx == 1:
130
+ # second pass for discriminator update
131
+ if cond is None:
132
+ logits_real = self.discriminator(inputs.contiguous().detach())
133
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
134
+ else:
135
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
136
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
137
+
138
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
139
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
140
+
141
+ log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
142
+ "{}/logits_real".format(split): logits_real.detach().mean(),
143
+ "{}/logits_fake".format(split): logits_fake.detach().mean()
144
+ }
145
+
146
+ if return_dic:
147
+ loss_dic = {}
148
+ loss_dic["{}/disc_loss".format(split)] = d_loss.clone().detach().mean()
149
+ loss_dic["{}/logits_real".format(split)] = logits_real.detach().mean()
150
+ loss_dic["{}/logits_fake".format(split)] = logits_fake.detach().mean()
151
+ return d_loss, log, loss_dic
152
+
153
+ return d_loss, log
154
+
models/losses/vqperceptual.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ from einops import repeat
5
+
6
+ from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7
+ from taming.modules.losses.lpips import LPIPS
8
+ from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9
+
10
+
11
+ def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
12
+ assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
13
+ loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
14
+ loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
15
+ loss_real = (weights * loss_real).sum() / weights.sum()
16
+ loss_fake = (weights * loss_fake).sum() / weights.sum()
17
+ d_loss = 0.5 * (loss_real + loss_fake)
18
+ return d_loss
19
+
20
+ def adopt_weight(weight, global_step, threshold=0, value=0.):
21
+ if global_step < threshold:
22
+ weight = value
23
+ return weight
24
+
25
+
26
+ def measure_perplexity(predicted_indices, n_embed):
27
+ # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
28
+ # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
29
+ encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
30
+ avg_probs = encodings.mean(0)
31
+ perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
32
+ cluster_use = torch.sum(avg_probs > 0)
33
+ return perplexity, cluster_use
34
+
35
+ def l1(x, y):
36
+ return torch.abs(x-y)
37
+
38
+
39
+ def l2(x, y):
40
+ return torch.pow((x-y), 2)
41
+
42
+
43
+ class VQLPIPSWithDiscriminator(nn.Module):
44
+ def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
45
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
46
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
47
+ disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
48
+ pixel_loss="l1"):
49
+ super().__init__()
50
+ assert disc_loss in ["hinge", "vanilla"]
51
+ assert perceptual_loss in ["lpips", "clips", "dists"]
52
+ assert pixel_loss in ["l1", "l2"]
53
+ self.codebook_weight = codebook_weight
54
+ self.pixel_weight = pixelloss_weight
55
+ if perceptual_loss == "lpips":
56
+ print(f"{self.__class__.__name__}: Running with LPIPS.")
57
+ self.perceptual_loss = LPIPS().eval()
58
+ else:
59
+ raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
60
+ self.perceptual_weight = perceptual_weight
61
+
62
+ if pixel_loss == "l1":
63
+ self.pixel_loss = l1
64
+ else:
65
+ self.pixel_loss = l2
66
+
67
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
68
+ n_layers=disc_num_layers,
69
+ use_actnorm=use_actnorm,
70
+ ndf=disc_ndf
71
+ ).apply(weights_init)
72
+ self.discriminator_iter_start = disc_start
73
+ if disc_loss == "hinge":
74
+ self.disc_loss = hinge_d_loss
75
+ elif disc_loss == "vanilla":
76
+ self.disc_loss = vanilla_d_loss
77
+ else:
78
+ raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
79
+ print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
80
+ self.disc_factor = disc_factor
81
+ self.discriminator_weight = disc_weight
82
+ self.disc_conditional = disc_conditional
83
+ self.n_classes = n_classes
84
+
85
+ # def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
86
+ # if last_layer is not None:
87
+ # nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
88
+ # g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
89
+ # else:
90
+ # nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
91
+ # g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
92
+
93
+ # d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
94
+ # d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
95
+ # d_weight = d_weight * self.discriminator_weight
96
+ # return d_weight
97
+
98
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
99
+ # if last_layer is not None:
100
+ # nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
101
+ # g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
102
+ # else:
103
+ # nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
104
+ # g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
105
+
106
+ # d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
107
+ # d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
108
+ d_weight = 1.0 * self.discriminator_weight
109
+ return d_weight
110
+
111
+ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
112
+ global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
113
+ if not exists(codebook_loss):
114
+ codebook_loss = torch.tensor([0.]).to(inputs.device)
115
+ #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
116
+ rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
117
+ if self.perceptual_weight > 0:
118
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
119
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
120
+ else:
121
+ p_loss = torch.tensor([0.0])
122
+
123
+ nll_loss = rec_loss
124
+ #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
125
+ nll_loss = torch.mean(nll_loss)
126
+
127
+ # now the GAN part
128
+ if optimizer_idx == 0:
129
+ # generator update
130
+ if cond is None:
131
+ assert not self.disc_conditional
132
+ logits_fake = self.discriminator(reconstructions.contiguous())
133
+ else:
134
+ assert self.disc_conditional
135
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
136
+ g_loss = -torch.mean(logits_fake)
137
+
138
+ try:
139
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
140
+ except RuntimeError:
141
+ assert not self.training
142
+ d_weight = torch.tensor(0.0)
143
+
144
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
145
+ loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
146
+
147
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
148
+ "{}/quant_loss".format(split): codebook_loss.detach().mean(),
149
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
150
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
151
+ "{}/p_loss".format(split): p_loss.detach().mean(),
152
+ "{}/d_weight".format(split): d_weight.detach(),
153
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
154
+ "{}/g_loss".format(split): g_loss.detach().mean(),
155
+ }
156
+ if predicted_indices is not None:
157
+ assert self.n_classes is not None
158
+ with torch.no_grad():
159
+ perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
160
+ log[f"{split}/perplexity"] = perplexity
161
+ log[f"{split}/cluster_usage"] = cluster_usage
162
+ return loss, log
163
+
164
+ if optimizer_idx == 1:
165
+ # second pass for discriminator update
166
+ if cond is None:
167
+ logits_real = self.discriminator(inputs.contiguous().detach())
168
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
169
+ else:
170
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
171
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
172
+
173
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
174
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
175
+
176
+ log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
177
+ "{}/logits_real".format(split): logits_real.detach().mean(),
178
+ "{}/logits_fake".format(split): logits_fake.detach().mean()
179
+ }
180
+ return d_loss, log
models/shared.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Shared architecture blocks."""
10
+
11
+ from typing import Callable
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+ from ADD.th_utils.ops import bias_act
18
+
19
+
20
+ class ResidualBlock(nn.Module):
21
+ def __init__(self, fn: Callable):
22
+ super().__init__()
23
+ self.fn = fn
24
+
25
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
26
+ return (self.fn(x) + x) / np.sqrt(2)
27
+
28
+
29
+ class FullyConnectedLayer(nn.Module):
30
+ def __init__(
31
+ self,
32
+ in_features: int, # Number of input features.
33
+ out_features: int, # Number of output features.
34
+ bias: bool = True, # Apply additive bias before the activation function?
35
+ activation: str = 'linear', # Activation function: 'relu', 'lrelu', etc.
36
+ lr_multiplier: float = 1.0, # Learning rate multiplier.
37
+ weight_init: float = 1.0, # Initial standard deviation of the weight tensor.
38
+ bias_init: float = 0.0, # Initial value for the additive bias.
39
+ ):
40
+
41
+ super().__init__()
42
+ self.in_features = in_features
43
+ self.out_features = out_features
44
+ self.activation = activation
45
+ self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) * (weight_init / lr_multiplier))
46
+ bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_features])
47
+ self.bias = torch.nn.Parameter(torch.from_numpy(bias_init / lr_multiplier)) if bias else None
48
+ self.weight_gain = lr_multiplier / np.sqrt(in_features)
49
+ self.bias_gain = lr_multiplier
50
+
51
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
52
+ w = self.weight.to(x.dtype) * self.weight_gain
53
+ b = self.bias
54
+ if b is not None:
55
+ b = b.to(x.dtype)
56
+ if self.bias_gain != 1:
57
+ b = b * self.bias_gain
58
+
59
+ if self.activation == 'linear' and b is not None:
60
+ x = torch.addmm(b.unsqueeze(0), x, w.t())
61
+ else:
62
+ x = x.matmul(w.t())
63
+ x = bias_act.bias_act(x, b, act=self.activation)
64
+ return x
65
+
66
+ def extra_repr(self) -> str:
67
+ return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}'
68
+
69
+
70
+ class MLP(nn.Module):
71
+ def __init__(
72
+ self,
73
+ features_list: list[int], # Number of features in each layer of the MLP.
74
+ activation: str = 'linear', # Activation function: 'relu', 'lrelu', etc.
75
+ lr_multiplier: float = 1.0, # Learning rate multiplier.
76
+ linear_out: bool = False # Use the 'linear' activation function for the output layer?
77
+ ):
78
+ super().__init__()
79
+ num_layers = len(features_list) - 1
80
+ self.num_layers = num_layers
81
+ self.out_dim = features_list[-1]
82
+
83
+ for idx in range(num_layers):
84
+ in_features = features_list[idx]
85
+ out_features = features_list[idx + 1]
86
+ if linear_out and idx == num_layers-1:
87
+ activation = 'linear'
88
+ layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
89
+ setattr(self, f'fc{idx}', layer)
90
+
91
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
92
+ ''' if x is sequence of tokens, shift tokens to batch and apply MLP to all'''
93
+ shift2batch = (x.ndim == 3)
94
+
95
+ if shift2batch:
96
+ B, K, C = x.shape
97
+ x = x.flatten(0,1)
98
+
99
+ for idx in range(self.num_layers):
100
+ layer = getattr(self, f'fc{idx}')
101
+ x = layer(x)
102
+
103
+ if shift2batch:
104
+ x = x.reshape(B, K, -1)
105
+
106
+ return x
models/unet_2d_blocks.py ADDED
The diff for this file is too large to render. See raw diff
 
models/unet_2d_condition.py ADDED
@@ -0,0 +1,1081 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import UNet2DConditionLoadersMixin
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.models.activations import get_activation
25
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
26
+ from diffusers.models.embeddings import (
27
+ GaussianFourierProjection,
28
+ ImageHintTimeEmbedding,
29
+ ImageProjection,
30
+ ImageTimeEmbedding,
31
+ PositionNet,
32
+ TextImageProjection,
33
+ TextImageTimeEmbedding,
34
+ TextTimeEmbedding,
35
+ TimestepEmbedding,
36
+ Timesteps,
37
+ )
38
+ from diffusers.models.modeling_utils import ModelMixin
39
+ from .unet_2d_blocks import (
40
+ UNetMidBlock2DCrossAttn,
41
+ UNetMidBlock2DSimpleCrossAttn,
42
+ get_down_block,
43
+ get_up_block,
44
+ )
45
+
46
+ import os, json
47
+
48
+
49
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
50
+
51
+
52
+ @dataclass
53
+ class UNet2DConditionOutput(BaseOutput):
54
+ """
55
+ The output of [`UNet2DConditionModel`].
56
+
57
+ Args:
58
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
59
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
60
+ """
61
+
62
+ sample: torch.FloatTensor = None
63
+
64
+
65
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
66
+ r"""
67
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
68
+ shaped output.
69
+
70
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
71
+ for all models (such as downloading or saving).
72
+
73
+ Parameters:
74
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
75
+ Height and width of input/output sample.
76
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
77
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
78
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
79
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
80
+ Whether to flip the sin to cos in the time embedding.
81
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
82
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
83
+ The tuple of downsample blocks to use.
84
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
85
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
86
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
87
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
88
+ The tuple of upsample blocks to use.
89
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
90
+ Whether to include self-attention in the basic transformer blocks, see
91
+ [`~models.attention.BasicTransformerBlock`].
92
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
93
+ The tuple of output channels for each block.
94
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
95
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
96
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
97
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
98
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
99
+ If `None`, normalization and activation layers is skipped in post-processing.
100
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
101
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
102
+ The dimension of the cross attention features.
103
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
104
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
105
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
106
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
107
+ encoder_hid_dim (`int`, *optional*, defaults to None):
108
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
109
+ dimension to `cross_attention_dim`.
110
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
111
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
112
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
113
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
114
+ num_attention_heads (`int`, *optional*):
115
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
116
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
117
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
118
+ class_embed_type (`str`, *optional*, defaults to `None`):
119
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
120
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
121
+ addition_embed_type (`str`, *optional*, defaults to `None`):
122
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
123
+ "text". "text" will use the `TextTimeEmbedding` layer.
124
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
125
+ Dimension for the timestep embeddings.
126
+ num_class_embeds (`int`, *optional*, defaults to `None`):
127
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
128
+ class conditioning with `class_embed_type` equal to `None`.
129
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
130
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
131
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
132
+ An optional override for the dimension of the projected time embedding.
133
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
134
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
135
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
136
+ timestep_post_act (`str`, *optional*, defaults to `None`):
137
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
138
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
139
+ The dimension of `cond_proj` layer in the timestep embedding.
140
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
141
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
142
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
143
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
144
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
145
+ embeddings with the class embeddings.
146
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
147
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
148
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
149
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
150
+ otherwise.
151
+ """
152
+
153
+ _supports_gradient_checkpointing = True
154
+
155
+ @register_to_config
156
+ def __init__(
157
+ self,
158
+ sample_size: Optional[int] = None,
159
+ in_channels: int = 4,
160
+ out_channels: int = 4,
161
+ center_input_sample: bool = False,
162
+ flip_sin_to_cos: bool = True,
163
+ freq_shift: int = 0,
164
+ down_block_types: Tuple[str] = (
165
+ "CrossAttnDownBlock2D",
166
+ "CrossAttnDownBlock2D",
167
+ "CrossAttnDownBlock2D",
168
+ "DownBlock2D",
169
+ ),
170
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
171
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
172
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
173
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
174
+ layers_per_block: Union[int, Tuple[int]] = 2,
175
+ downsample_padding: int = 1,
176
+ mid_block_scale_factor: float = 1,
177
+ act_fn: str = "silu",
178
+ norm_num_groups: Optional[int] = 32,
179
+ norm_eps: float = 1e-5,
180
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
181
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
182
+ encoder_hid_dim: Optional[int] = None,
183
+ encoder_hid_dim_type: Optional[str] = None,
184
+ attention_head_dim: Union[int, Tuple[int]] = 8,
185
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
186
+ dual_cross_attention: bool = False,
187
+ use_linear_projection: bool = False,
188
+ class_embed_type: Optional[str] = None,
189
+ addition_embed_type: Optional[str] = None,
190
+ addition_time_embed_dim: Optional[int] = None,
191
+ num_class_embeds: Optional[int] = None,
192
+ upcast_attention: bool = False,
193
+ resnet_time_scale_shift: str = "default",
194
+ resnet_skip_time_act: bool = False,
195
+ resnet_out_scale_factor: int = 1.0,
196
+ time_embedding_type: str = "positional",
197
+ time_embedding_dim: Optional[int] = None,
198
+ time_embedding_act_fn: Optional[str] = None,
199
+ timestep_post_act: Optional[str] = None,
200
+ time_cond_proj_dim: Optional[int] = None,
201
+ conv_in_kernel: int = 3,
202
+ conv_out_kernel: int = 3,
203
+ projection_class_embeddings_input_dim: Optional[int] = None,
204
+ attention_type: str = "default",
205
+ class_embeddings_concat: bool = False,
206
+ mid_block_only_cross_attention: Optional[bool] = None,
207
+ cross_attention_norm: Optional[str] = None,
208
+ addition_embed_type_num_heads=64,
209
+ ):
210
+ super().__init__()
211
+
212
+ self.sample_size = sample_size
213
+
214
+ if num_attention_heads is not None:
215
+ raise ValueError(
216
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
217
+ )
218
+
219
+ # If `num_attention_heads` is not defined (which is the case for most models)
220
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
221
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
222
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
223
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
224
+ # which is why we correct for the naming here.
225
+ num_attention_heads = num_attention_heads or attention_head_dim
226
+
227
+ # Check inputs
228
+ if len(down_block_types) != len(up_block_types):
229
+ raise ValueError(
230
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
231
+ )
232
+
233
+ if len(block_out_channels) != len(down_block_types):
234
+ raise ValueError(
235
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
236
+ )
237
+
238
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
239
+ raise ValueError(
240
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
241
+ )
242
+
243
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
244
+ raise ValueError(
245
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
246
+ )
247
+
248
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
249
+ raise ValueError(
250
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
251
+ )
252
+
253
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
254
+ raise ValueError(
255
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
256
+ )
257
+
258
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
259
+ raise ValueError(
260
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
261
+ )
262
+
263
+ # input
264
+ conv_in_padding = (conv_in_kernel - 1) // 2
265
+ self.conv_in = nn.Conv2d(
266
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
267
+ )
268
+
269
+ # time
270
+ if time_embedding_type == "fourier":
271
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
272
+ if time_embed_dim % 2 != 0:
273
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
274
+ self.time_proj = GaussianFourierProjection(
275
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
276
+ )
277
+ timestep_input_dim = time_embed_dim
278
+ elif time_embedding_type == "positional":
279
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
280
+
281
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
282
+ timestep_input_dim = block_out_channels[0]
283
+ else:
284
+ raise ValueError(
285
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
286
+ )
287
+
288
+ self.time_embedding = TimestepEmbedding(
289
+ timestep_input_dim,
290
+ time_embed_dim,
291
+ act_fn=act_fn,
292
+ post_act_fn=timestep_post_act,
293
+ cond_proj_dim=time_cond_proj_dim,
294
+ )
295
+
296
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
297
+ encoder_hid_dim_type = "text_proj"
298
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
299
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
300
+
301
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
302
+ raise ValueError(
303
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
304
+ )
305
+
306
+ if encoder_hid_dim_type == "text_proj":
307
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
308
+ elif encoder_hid_dim_type == "text_image_proj":
309
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
310
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
311
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
312
+ self.encoder_hid_proj = TextImageProjection(
313
+ text_embed_dim=encoder_hid_dim,
314
+ image_embed_dim=cross_attention_dim,
315
+ cross_attention_dim=cross_attention_dim,
316
+ )
317
+ elif encoder_hid_dim_type == "image_proj":
318
+ # Kandinsky 2.2
319
+ self.encoder_hid_proj = ImageProjection(
320
+ image_embed_dim=encoder_hid_dim,
321
+ cross_attention_dim=cross_attention_dim,
322
+ )
323
+ elif encoder_hid_dim_type is not None:
324
+ raise ValueError(
325
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
326
+ )
327
+ else:
328
+ self.encoder_hid_proj = None
329
+
330
+ # class embedding
331
+ if class_embed_type is None and num_class_embeds is not None:
332
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
333
+ elif class_embed_type == "timestep":
334
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
335
+ elif class_embed_type == "identity":
336
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
337
+ elif class_embed_type == "projection":
338
+ if projection_class_embeddings_input_dim is None:
339
+ raise ValueError(
340
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
341
+ )
342
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
343
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
344
+ # 2. it projects from an arbitrary input dimension.
345
+ #
346
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
347
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
348
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
349
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
350
+ elif class_embed_type == "simple_projection":
351
+ if projection_class_embeddings_input_dim is None:
352
+ raise ValueError(
353
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
354
+ )
355
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
356
+ else:
357
+ self.class_embedding = None
358
+
359
+ if addition_embed_type == "text":
360
+ if encoder_hid_dim is not None:
361
+ text_time_embedding_from_dim = encoder_hid_dim
362
+ else:
363
+ text_time_embedding_from_dim = cross_attention_dim
364
+
365
+ self.add_embedding = TextTimeEmbedding(
366
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
367
+ )
368
+ elif addition_embed_type == "text_image":
369
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
370
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
371
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
372
+ self.add_embedding = TextImageTimeEmbedding(
373
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
374
+ )
375
+ elif addition_embed_type == "text_time":
376
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
377
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
378
+ elif addition_embed_type == "image":
379
+ # Kandinsky 2.2
380
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
381
+ elif addition_embed_type == "image_hint":
382
+ # Kandinsky 2.2 ControlNet
383
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
384
+ elif addition_embed_type is not None:
385
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
386
+
387
+ if time_embedding_act_fn is None:
388
+ self.time_embed_act = None
389
+ else:
390
+ self.time_embed_act = get_activation(time_embedding_act_fn)
391
+
392
+ self.down_blocks = nn.ModuleList([])
393
+ self.up_blocks = nn.ModuleList([])
394
+
395
+ if isinstance(only_cross_attention, bool):
396
+ if mid_block_only_cross_attention is None:
397
+ mid_block_only_cross_attention = only_cross_attention
398
+
399
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
400
+
401
+ if mid_block_only_cross_attention is None:
402
+ mid_block_only_cross_attention = False
403
+
404
+ if isinstance(num_attention_heads, int):
405
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
406
+
407
+ if isinstance(attention_head_dim, int):
408
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
409
+
410
+ if isinstance(cross_attention_dim, int):
411
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
412
+
413
+ if isinstance(layers_per_block, int):
414
+ layers_per_block = [layers_per_block] * len(down_block_types)
415
+
416
+ if isinstance(transformer_layers_per_block, int):
417
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
418
+
419
+ if class_embeddings_concat:
420
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
421
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
422
+ # regular time embeddings
423
+ blocks_time_embed_dim = time_embed_dim * 2
424
+ else:
425
+ blocks_time_embed_dim = time_embed_dim
426
+
427
+ # down
428
+ output_channel = block_out_channels[0]
429
+ for i, down_block_type in enumerate(down_block_types):
430
+ input_channel = output_channel
431
+ output_channel = block_out_channels[i]
432
+ is_final_block = i == len(block_out_channels) - 1
433
+
434
+ down_block = get_down_block(
435
+ down_block_type,
436
+ num_layers=layers_per_block[i],
437
+ transformer_layers_per_block=transformer_layers_per_block[i],
438
+ in_channels=input_channel,
439
+ out_channels=output_channel,
440
+ temb_channels=blocks_time_embed_dim,
441
+ add_downsample=not is_final_block,
442
+ resnet_eps=norm_eps,
443
+ resnet_act_fn=act_fn,
444
+ resnet_groups=norm_num_groups,
445
+ cross_attention_dim=cross_attention_dim[i],
446
+ num_attention_heads=num_attention_heads[i],
447
+ downsample_padding=downsample_padding,
448
+ dual_cross_attention=dual_cross_attention,
449
+ use_linear_projection=use_linear_projection,
450
+ only_cross_attention=only_cross_attention[i],
451
+ upcast_attention=upcast_attention,
452
+ resnet_time_scale_shift=resnet_time_scale_shift,
453
+ attention_type=attention_type,
454
+ resnet_skip_time_act=resnet_skip_time_act,
455
+ resnet_out_scale_factor=resnet_out_scale_factor,
456
+ cross_attention_norm=cross_attention_norm,
457
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
458
+ )
459
+ self.down_blocks.append(down_block)
460
+
461
+ # mid
462
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
463
+ self.mid_block = UNetMidBlock2DCrossAttn(
464
+ transformer_layers_per_block=transformer_layers_per_block[-1],
465
+ in_channels=block_out_channels[-1],
466
+ temb_channels=blocks_time_embed_dim,
467
+ resnet_eps=norm_eps,
468
+ resnet_act_fn=act_fn,
469
+ output_scale_factor=mid_block_scale_factor,
470
+ resnet_time_scale_shift=resnet_time_scale_shift,
471
+ cross_attention_dim=cross_attention_dim[-1],
472
+ num_attention_heads=num_attention_heads[-1],
473
+ resnet_groups=norm_num_groups,
474
+ dual_cross_attention=dual_cross_attention,
475
+ use_linear_projection=use_linear_projection,
476
+ upcast_attention=upcast_attention,
477
+ attention_type=attention_type,
478
+ )
479
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
480
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
481
+ in_channels=block_out_channels[-1],
482
+ temb_channels=blocks_time_embed_dim,
483
+ resnet_eps=norm_eps,
484
+ resnet_act_fn=act_fn,
485
+ output_scale_factor=mid_block_scale_factor,
486
+ cross_attention_dim=cross_attention_dim[-1],
487
+ attention_head_dim=attention_head_dim[-1],
488
+ resnet_groups=norm_num_groups,
489
+ resnet_time_scale_shift=resnet_time_scale_shift,
490
+ skip_time_act=resnet_skip_time_act,
491
+ only_cross_attention=mid_block_only_cross_attention,
492
+ cross_attention_norm=cross_attention_norm,
493
+ )
494
+ elif mid_block_type is None:
495
+ self.mid_block = None
496
+ else:
497
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
498
+
499
+ # count how many layers upsample the images
500
+ self.num_upsamplers = 0
501
+
502
+ # up
503
+ reversed_block_out_channels = list(reversed(block_out_channels))
504
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
505
+ reversed_layers_per_block = list(reversed(layers_per_block))
506
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
507
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
508
+ only_cross_attention = list(reversed(only_cross_attention))
509
+
510
+ output_channel = reversed_block_out_channels[0]
511
+ for i, up_block_type in enumerate(up_block_types):
512
+ is_final_block = i == len(block_out_channels) - 1
513
+
514
+ prev_output_channel = output_channel
515
+ output_channel = reversed_block_out_channels[i]
516
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
517
+
518
+ # add upsample block for all BUT final layer
519
+ if not is_final_block:
520
+ add_upsample = True
521
+ self.num_upsamplers += 1
522
+ else:
523
+ add_upsample = False
524
+
525
+ up_block = get_up_block(
526
+ up_block_type,
527
+ num_layers=reversed_layers_per_block[i] + 1,
528
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
529
+ in_channels=input_channel,
530
+ out_channels=output_channel,
531
+ prev_output_channel=prev_output_channel,
532
+ temb_channels=blocks_time_embed_dim,
533
+ add_upsample=add_upsample,
534
+ resnet_eps=norm_eps,
535
+ resnet_act_fn=act_fn,
536
+ resnet_groups=norm_num_groups,
537
+ cross_attention_dim=reversed_cross_attention_dim[i],
538
+ num_attention_heads=reversed_num_attention_heads[i],
539
+ dual_cross_attention=dual_cross_attention,
540
+ use_linear_projection=use_linear_projection,
541
+ only_cross_attention=only_cross_attention[i],
542
+ upcast_attention=upcast_attention,
543
+ resnet_time_scale_shift=resnet_time_scale_shift,
544
+ attention_type=attention_type,
545
+ resnet_skip_time_act=resnet_skip_time_act,
546
+ resnet_out_scale_factor=resnet_out_scale_factor,
547
+ cross_attention_norm=cross_attention_norm,
548
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
549
+ )
550
+ self.up_blocks.append(up_block)
551
+ prev_output_channel = output_channel
552
+
553
+ # out
554
+ if norm_num_groups is not None:
555
+ self.conv_norm_out = nn.GroupNorm(
556
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
557
+ )
558
+
559
+ self.conv_act = get_activation(act_fn)
560
+
561
+ else:
562
+ self.conv_norm_out = None
563
+ self.conv_act = None
564
+
565
+ conv_out_padding = (conv_out_kernel - 1) // 2
566
+ self.conv_out = nn.Conv2d(
567
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
568
+ )
569
+
570
+ if attention_type == "gated":
571
+ positive_len = 768
572
+ if isinstance(cross_attention_dim, int):
573
+ positive_len = cross_attention_dim
574
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
575
+ positive_len = cross_attention_dim[0]
576
+ self.position_net = PositionNet(positive_len=positive_len, out_dim=cross_attention_dim)
577
+
578
+ @property
579
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
580
+ r"""
581
+ Returns:
582
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
583
+ indexed by its weight name.
584
+ """
585
+ # set recursively
586
+ processors = {}
587
+
588
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
589
+ if hasattr(module, "get_processor"):
590
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
591
+
592
+ for sub_name, child in module.named_children():
593
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
594
+
595
+ return processors
596
+
597
+ for name, module in self.named_children():
598
+ fn_recursive_add_processors(name, module, processors)
599
+
600
+ return processors
601
+
602
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
603
+ r"""
604
+ Sets the attention processor to use to compute attention.
605
+
606
+ Parameters:
607
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
608
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
609
+ for **all** `Attention` layers.
610
+
611
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
612
+ processor. This is strongly recommended when setting trainable attention processors.
613
+
614
+ """
615
+ count = len(self.attn_processors.keys())
616
+
617
+ if isinstance(processor, dict) and len(processor) != count:
618
+ raise ValueError(
619
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
620
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
621
+ )
622
+
623
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
624
+ if hasattr(module, "set_processor"):
625
+ if not isinstance(processor, dict):
626
+ module.set_processor(processor)
627
+ else:
628
+ module.set_processor(processor.pop(f"{name}.processor"))
629
+
630
+ for sub_name, child in module.named_children():
631
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
632
+
633
+ for name, module in self.named_children():
634
+ fn_recursive_attn_processor(name, module, processor)
635
+
636
+ def set_default_attn_processor(self):
637
+ """
638
+ Disables custom attention processors and sets the default attention implementation.
639
+ """
640
+ self.set_attn_processor(AttnProcessor())
641
+
642
+ def set_attention_slice(self, slice_size):
643
+ r"""
644
+ Enable sliced attention computation.
645
+
646
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
647
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
648
+
649
+ Args:
650
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
651
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
652
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
653
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
654
+ must be a multiple of `slice_size`.
655
+ """
656
+ sliceable_head_dims = []
657
+
658
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
659
+ if hasattr(module, "set_attention_slice"):
660
+ sliceable_head_dims.append(module.sliceable_head_dim)
661
+
662
+ for child in module.children():
663
+ fn_recursive_retrieve_sliceable_dims(child)
664
+
665
+ # retrieve number of attention layers
666
+ for module in self.children():
667
+ fn_recursive_retrieve_sliceable_dims(module)
668
+
669
+ num_sliceable_layers = len(sliceable_head_dims)
670
+
671
+ if slice_size == "auto":
672
+ # half the attention head size is usually a good trade-off between
673
+ # speed and memory
674
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
675
+ elif slice_size == "max":
676
+ # make smallest slice possible
677
+ slice_size = num_sliceable_layers * [1]
678
+
679
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
680
+
681
+ if len(slice_size) != len(sliceable_head_dims):
682
+ raise ValueError(
683
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
684
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
685
+ )
686
+
687
+ for i in range(len(slice_size)):
688
+ size = slice_size[i]
689
+ dim = sliceable_head_dims[i]
690
+ if size is not None and size > dim:
691
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
692
+
693
+ # Recursively walk through all the children.
694
+ # Any children which exposes the set_attention_slice method
695
+ # gets the message
696
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
697
+ if hasattr(module, "set_attention_slice"):
698
+ module.set_attention_slice(slice_size.pop())
699
+
700
+ for child in module.children():
701
+ fn_recursive_set_attention_slice(child, slice_size)
702
+
703
+ reversed_slice_size = list(reversed(slice_size))
704
+ for module in self.children():
705
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
706
+
707
+ def _set_gradient_checkpointing(self, module, value=False):
708
+ if hasattr(module, "gradient_checkpointing"):
709
+ module.gradient_checkpointing = value
710
+
711
+ def forward(
712
+ self,
713
+ sample: torch.FloatTensor,
714
+ timestep: Union[torch.Tensor, float, int],
715
+ encoder_hidden_states: torch.Tensor,
716
+ class_labels: Optional[torch.Tensor] = None,
717
+ timestep_cond: Optional[torch.Tensor] = None,
718
+ attention_mask: Optional[torch.Tensor] = None,
719
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
720
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
721
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
722
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
723
+ encoder_attention_mask: Optional[torch.Tensor] = None,
724
+ return_dict: bool = True,
725
+ image_encoder_hidden_states: torch.Tensor = None,
726
+ ) -> Union[UNet2DConditionOutput, Tuple]:
727
+ r"""
728
+ The [`UNet2DConditionModel`] forward method.
729
+
730
+ Args:
731
+ sample (`torch.FloatTensor`):
732
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
733
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
734
+ encoder_hidden_states (`torch.FloatTensor`):
735
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
736
+ encoder_attention_mask (`torch.Tensor`):
737
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
738
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
739
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
740
+ return_dict (`bool`, *optional*, defaults to `True`):
741
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
742
+ tuple.
743
+ cross_attention_kwargs (`dict`, *optional*):
744
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
745
+ added_cond_kwargs: (`dict`, *optional*):
746
+ A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
747
+ are passed along to the UNet blocks.
748
+
749
+ Returns:
750
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
751
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
752
+ a `tuple` is returned where the first element is the sample tensor.
753
+ """
754
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
755
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
756
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
757
+ # on the fly if necessary.
758
+ default_overall_up_factor = 2**self.num_upsamplers
759
+
760
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
761
+ forward_upsample_size = False
762
+ upsample_size = None
763
+
764
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
765
+ logger.info("Forward upsample size to force interpolation output size.")
766
+ forward_upsample_size = True
767
+
768
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
769
+ # expects mask of shape:
770
+ # [batch, key_tokens]
771
+ # adds singleton query_tokens dimension:
772
+ # [batch, 1, key_tokens]
773
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
774
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
775
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
776
+ if attention_mask is not None:
777
+ # assume that mask is expressed as:
778
+ # (1 = keep, 0 = discard)
779
+ # convert mask into a bias that can be added to attention scores:
780
+ # (keep = +0, discard = -10000.0)
781
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
782
+ attention_mask = attention_mask.unsqueeze(1)
783
+
784
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
785
+ if encoder_attention_mask is not None:
786
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
787
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
788
+
789
+ # 0. center input if necessary
790
+ if self.config.center_input_sample:
791
+ sample = 2 * sample - 1.0
792
+
793
+ # 1. time
794
+ timesteps = timestep
795
+ if not torch.is_tensor(timesteps):
796
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
797
+ # This would be a good case for the `match` statement (Python 3.10+)
798
+ is_mps = sample.device.type == "mps"
799
+ if isinstance(timestep, float):
800
+ dtype = torch.float32 if is_mps else torch.float64
801
+ else:
802
+ dtype = torch.int32 if is_mps else torch.int64
803
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
804
+ elif len(timesteps.shape) == 0:
805
+ timesteps = timesteps[None].to(sample.device)
806
+
807
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
808
+ timesteps = timesteps.expand(sample.shape[0])
809
+
810
+ t_emb = self.time_proj(timesteps)
811
+
812
+ # `Timesteps` does not contain any weights and will always return f32 tensors
813
+ # but time_embedding might actually be running in fp16. so we need to cast here.
814
+ # there might be better ways to encapsulate this.
815
+ t_emb = t_emb.to(dtype=sample.dtype)
816
+
817
+ emb = self.time_embedding(t_emb, timestep_cond)
818
+ aug_emb = None
819
+
820
+ if self.class_embedding is not None:
821
+ if class_labels is None:
822
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
823
+
824
+ if self.config.class_embed_type == "timestep":
825
+ class_labels = self.time_proj(class_labels)
826
+
827
+ # `Timesteps` does not contain any weights and will always return f32 tensors
828
+ # there might be better ways to encapsulate this.
829
+ class_labels = class_labels.to(dtype=sample.dtype)
830
+
831
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
832
+
833
+ if self.config.class_embeddings_concat:
834
+ emb = torch.cat([emb, class_emb], dim=-1)
835
+ else:
836
+ emb = emb + class_emb
837
+
838
+ if self.config.addition_embed_type == "text":
839
+ aug_emb = self.add_embedding(encoder_hidden_states)
840
+ elif self.config.addition_embed_type == "text_image":
841
+ # Kandinsky 2.1 - style
842
+ if "image_embeds" not in added_cond_kwargs:
843
+ raise ValueError(
844
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
845
+ )
846
+
847
+ image_embs = added_cond_kwargs.get("image_embeds")
848
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
849
+ aug_emb = self.add_embedding(text_embs, image_embs)
850
+ elif self.config.addition_embed_type == "text_time":
851
+ # SDXL - style
852
+ if "text_embeds" not in added_cond_kwargs:
853
+ raise ValueError(
854
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
855
+ )
856
+ text_embeds = added_cond_kwargs.get("text_embeds")
857
+ if "time_ids" not in added_cond_kwargs:
858
+ raise ValueError(
859
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
860
+ )
861
+ time_ids = added_cond_kwargs.get("time_ids")
862
+ time_embeds = self.add_time_proj(time_ids.flatten())
863
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
864
+
865
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
866
+ add_embeds = add_embeds.to(emb.dtype)
867
+ aug_emb = self.add_embedding(add_embeds)
868
+ elif self.config.addition_embed_type == "image":
869
+ # Kandinsky 2.2 - style
870
+ if "image_embeds" not in added_cond_kwargs:
871
+ raise ValueError(
872
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
873
+ )
874
+ image_embs = added_cond_kwargs.get("image_embeds")
875
+ aug_emb = self.add_embedding(image_embs)
876
+ elif self.config.addition_embed_type == "image_hint":
877
+ # Kandinsky 2.2 - style
878
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
879
+ raise ValueError(
880
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
881
+ )
882
+ image_embs = added_cond_kwargs.get("image_embeds")
883
+ hint = added_cond_kwargs.get("hint")
884
+ aug_emb, hint = self.add_embedding(image_embs, hint)
885
+ sample = torch.cat([sample, hint], dim=1)
886
+
887
+ emb = emb + aug_emb if aug_emb is not None else emb
888
+
889
+ if self.time_embed_act is not None:
890
+ emb = self.time_embed_act(emb)
891
+
892
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
893
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
894
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
895
+ # Kadinsky 2.1 - style
896
+ if "image_embeds" not in added_cond_kwargs:
897
+ raise ValueError(
898
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
899
+ )
900
+
901
+ image_embeds = added_cond_kwargs.get("image_embeds")
902
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
903
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
904
+ # Kandinsky 2.2 - style
905
+ if "image_embeds" not in added_cond_kwargs:
906
+ raise ValueError(
907
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
908
+ )
909
+ image_embeds = added_cond_kwargs.get("image_embeds")
910
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
911
+ # 2. pre-process
912
+ sample = self.conv_in(sample)
913
+
914
+ # 2.5 GLIGEN position net
915
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
916
+ cross_attention_kwargs = cross_attention_kwargs.copy()
917
+ gligen_args = cross_attention_kwargs.pop("gligen")
918
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
919
+
920
+ # 3. down
921
+
922
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
923
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
924
+
925
+ down_block_res_samples = (sample,)
926
+ for downsample_block in self.down_blocks:
927
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
928
+ # For t2i-adapter CrossAttnDownBlock2D
929
+ additional_residuals = {}
930
+ if is_adapter and len(down_block_additional_residuals) > 0:
931
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
932
+
933
+ sample, res_samples = downsample_block(
934
+ hidden_states=sample,
935
+ temb=emb,
936
+ encoder_hidden_states=encoder_hidden_states,
937
+ attention_mask=attention_mask,
938
+ cross_attention_kwargs=cross_attention_kwargs,
939
+ encoder_attention_mask=encoder_attention_mask,
940
+ image_encoder_hidden_states=image_encoder_hidden_states,
941
+ **additional_residuals,
942
+ )
943
+ else:
944
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
945
+
946
+ if is_adapter and len(down_block_additional_residuals) > 0:
947
+ sample += down_block_additional_residuals.pop(0)
948
+
949
+ down_block_res_samples += res_samples
950
+
951
+ if is_controlnet:
952
+ new_down_block_res_samples = ()
953
+
954
+ for down_block_res_sample, down_block_additional_residual in zip(
955
+ down_block_res_samples, down_block_additional_residuals
956
+ ):
957
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
958
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
959
+
960
+ down_block_res_samples = new_down_block_res_samples
961
+
962
+ # 4. mid
963
+ if self.mid_block is not None:
964
+ sample = self.mid_block(
965
+ sample,
966
+ emb,
967
+ encoder_hidden_states=encoder_hidden_states,
968
+ attention_mask=attention_mask,
969
+ cross_attention_kwargs=cross_attention_kwargs,
970
+ encoder_attention_mask=encoder_attention_mask,
971
+ image_encoder_hidden_states=image_encoder_hidden_states,
972
+ )
973
+ # To support T2I-Adapter-XL
974
+ if (
975
+ is_adapter
976
+ and len(down_block_additional_residuals) > 0
977
+ and sample.shape == down_block_additional_residuals[0].shape
978
+ ):
979
+ sample += down_block_additional_residuals.pop(0)
980
+
981
+ if is_controlnet:
982
+ sample = sample + mid_block_additional_residual
983
+
984
+ # 5. up
985
+ for i, upsample_block in enumerate(self.up_blocks):
986
+ is_final_block = i == len(self.up_blocks) - 1
987
+
988
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
989
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
990
+
991
+ # if we have not reached the final block and need to forward the
992
+ # upsample size, we do it here
993
+ if not is_final_block and forward_upsample_size:
994
+ upsample_size = down_block_res_samples[-1].shape[2:]
995
+
996
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
997
+ sample = upsample_block(
998
+ hidden_states=sample,
999
+ temb=emb,
1000
+ res_hidden_states_tuple=res_samples,
1001
+ encoder_hidden_states=encoder_hidden_states,
1002
+ cross_attention_kwargs=cross_attention_kwargs,
1003
+ upsample_size=upsample_size,
1004
+ attention_mask=attention_mask,
1005
+ encoder_attention_mask=encoder_attention_mask,
1006
+ image_encoder_hidden_states=image_encoder_hidden_states,
1007
+ )
1008
+ else:
1009
+ sample = upsample_block(
1010
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1011
+ )
1012
+
1013
+ # 6. post-process
1014
+ if self.conv_norm_out:
1015
+ sample = self.conv_norm_out(sample)
1016
+ sample = self.conv_act(sample)
1017
+ sample = self.conv_out(sample)
1018
+
1019
+ if not return_dict:
1020
+ return (sample,)
1021
+
1022
+ return UNet2DConditionOutput(sample=sample)
1023
+
1024
+ @classmethod
1025
+ def from_pretrained_orig(cls, pretrained_model_path, subfolder=None, **kwargs):
1026
+ if subfolder is not None:
1027
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
1028
+
1029
+ config_file = os.path.join(pretrained_model_path, 'config.json')
1030
+ if not os.path.isfile(config_file):
1031
+ raise RuntimeError(f"{config_file} does not exist")
1032
+ with open(config_file, "r") as f:
1033
+ config = json.load(f)
1034
+
1035
+ from diffusers.utils import WEIGHTS_NAME
1036
+ from diffusers.utils import SAFETENSORS_WEIGHTS_NAME
1037
+
1038
+
1039
+ model = cls.from_config(config)
1040
+
1041
+ ## for .bin file
1042
+ # model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1043
+ # if not os.path.isfile(model_file):
1044
+ # raise RuntimeError(f"{model_file} does not exist")
1045
+ # state_dict = torch.load(model_file, map_location="cpu")
1046
+ # model.load_state_dict(state_dict, strict=False)
1047
+
1048
+ ## for .safetensors file
1049
+ import safetensors
1050
+ model_file = os.path.join(pretrained_model_path, SAFETENSORS_WEIGHTS_NAME)
1051
+ if not os.path.isfile(model_file):
1052
+ raise RuntimeError(f"{model_file} does not exist")
1053
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
1054
+ model.load_state_dict(state_dict, strict=False)
1055
+
1056
+ return model
1057
+
1058
+ @classmethod
1059
+ def from_pretrained_safetensor(cls, pretrained_model_path, subfolder=None, **kwargs):
1060
+ if subfolder is not None:
1061
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
1062
+
1063
+ config_file = os.path.join(pretrained_model_path, 'config.json')
1064
+ if not os.path.isfile(config_file):
1065
+ raise RuntimeError(f"{config_file} does not exist")
1066
+ with open(config_file, "r") as f:
1067
+ config = json.load(f)
1068
+
1069
+ from diffusers.utils import SAFETENSORS_WEIGHTS_NAME
1070
+ model = cls.from_config(config)
1071
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1072
+ if not os.path.isfile(model_file):
1073
+ raise RuntimeError(f"{model_file} does not exist")
1074
+ state_dict = torch.load(model_file, map_location="cpu")
1075
+ for k, v in model.state_dict().items():
1076
+ if 'attn2_plus' in k:
1077
+ print(k)
1078
+ state_dict.update({k: v})
1079
+ model.load_state_dict(state_dict, strict=False)
1080
+
1081
+ return model
models/vit_utils.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+ #
3
+ # Copyright (c) 2021 Intel ISL (Intel Intelligent Systems Lab)
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+ #
23
+ # Based on code from https://github.com/isl-org/DPT
24
+
25
+ """Flexible configuration and feature extraction of timm VisionTransformers."""
26
+
27
+ import types
28
+ import math
29
+ from typing import Callable
30
+
31
+ import torch
32
+ import torch.nn as nn
33
+ import torch.nn.functional as F
34
+
35
+
36
+ class AddReadout(nn.Module):
37
+ def __init__(self, start_index: bool = 1):
38
+ super(AddReadout, self).__init__()
39
+ self.start_index = start_index
40
+
41
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
42
+ if self.start_index == 2:
43
+ readout = (x[:, 0] + x[:, 1]) / 2
44
+ else:
45
+ readout = x[:, 0]
46
+ return x[:, self.start_index:] + readout.unsqueeze(1)
47
+
48
+
49
+ class Transpose(nn.Module):
50
+ def __init__(self, dim0: int, dim1: int):
51
+ super(Transpose, self).__init__()
52
+ self.dim0 = dim0
53
+ self.dim1 = dim1
54
+
55
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
56
+ x = x.transpose(self.dim0, self.dim1)
57
+ return x.contiguous()
58
+
59
+
60
+ def forward_vit(pretrained: nn.Module, x: torch.Tensor) -> dict:
61
+ _, _, H, W = x.size()
62
+ _ = pretrained.model.forward_flex(x)
63
+ return {k: pretrained.rearrange(v) for k, v in activations.items()}
64
+
65
+
66
+ def _resize_pos_embed(self, posemb: torch.Tensor, gs_h: int, gs_w: int) -> torch.Tensor:
67
+ posemb_tok, posemb_grid = (
68
+ posemb[:, : self.start_index],
69
+ posemb[0, self.start_index :],
70
+ )
71
+
72
+ gs_old = int(math.sqrt(len(posemb_grid)))
73
+
74
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
75
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear", align_corners=False)
76
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
77
+
78
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
79
+
80
+ return posemb
81
+
82
+
83
+ def forward_flex(self, x: torch.Tensor) -> torch.Tensor:
84
+ # patch proj and dynamically resize
85
+ B, C, H, W = x.size()
86
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
87
+ pos_embed = self._resize_pos_embed(
88
+ self.pos_embed, H // self.patch_size[1], W // self.patch_size[0]
89
+ )
90
+
91
+ # add cls token
92
+ cls_tokens = self.cls_token.expand(
93
+ x.size(0), -1, -1
94
+ )
95
+ x = torch.cat((cls_tokens, x), dim=1)
96
+
97
+ # forward pass
98
+ x = x + pos_embed
99
+ x = self.pos_drop(x)
100
+
101
+ for blk in self.blocks:
102
+ x = blk(x)
103
+
104
+ x = self.norm(x)
105
+ return x
106
+
107
+
108
+ activations = {}
109
+
110
+
111
+ def get_activation(name: str) -> Callable:
112
+ def hook(model, input, output):
113
+ activations[name] = output
114
+ return hook
115
+
116
+
117
+ def make_sd_backbone(
118
+ model: nn.Module,
119
+ hooks: list[int] = [2, 5, 8, 11],
120
+ hook_patch: bool = True,
121
+ start_index: list[int] = 1,
122
+ ):
123
+ assert len(hooks) == 4
124
+
125
+ pretrained = nn.Module()
126
+ pretrained.model = model
127
+
128
+ # add hooks
129
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation('0'))
130
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation('1'))
131
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation('2'))
132
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation('3'))
133
+ if hook_patch:
134
+ pretrained.model.pos_drop.register_forward_hook(get_activation('4'))
135
+
136
+ # configure readout
137
+ pretrained.rearrange = nn.Sequential(AddReadout(start_index), Transpose(1, 2))
138
+ pretrained.model.start_index = start_index
139
+ pretrained.model.patch_size = patch_size
140
+
141
+ # We inject this function into the VisionTransformer instances so that
142
+ # we can use it with interpolated position embeddings without modifying the library source.
143
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
144
+ pretrained.model._resize_pos_embed = types.MethodType(
145
+ _resize_pos_embed, pretrained.model
146
+ )
147
+
148
+ return pretrained
149
+
150
+ def make_vit_backbone(
151
+ model: nn.Module,
152
+ patch_size: list[int] = [16, 16],
153
+ hooks: list[int] = [2, 5, 8, 11],
154
+ hook_patch: bool = True,
155
+ start_index: list[int] = 1,
156
+ ):
157
+ assert len(hooks) == 4
158
+
159
+ pretrained = nn.Module()
160
+ pretrained.model = model
161
+
162
+ # add hooks
163
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation('0'))
164
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation('1'))
165
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation('2'))
166
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation('3'))
167
+ if hook_patch:
168
+ pretrained.model.pos_drop.register_forward_hook(get_activation('4'))
169
+
170
+ # configure readout
171
+ pretrained.rearrange = nn.Sequential(AddReadout(start_index), Transpose(1, 2))
172
+ pretrained.model.start_index = start_index
173
+ pretrained.model.patch_size = patch_size
174
+
175
+ # We inject this function into the VisionTransformer instances so that
176
+ # we can use it with interpolated position embeddings without modifying the library source.
177
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
178
+ pretrained.model._resize_pos_embed = types.MethodType(
179
+ _resize_pos_embed, pretrained.model
180
+ )
181
+
182
+ return pretrained