adirik committed on
Commit
f4c3c2b
1 Parent(s): 04bc7a6

update repo

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .DS_Store +0 -0
  2. .gitattributes +2 -0
  3. LICENSE.txt +97 -0
  4. dnnlib/__init__.py +9 -0
  5. dnnlib/__pycache__/__init__.cpython-38.pyc +0 -0
  6. dnnlib/__pycache__/util.cpython-38.pyc +0 -0
  7. dnnlib/util.py +473 -0
  8. encoder4editing/LICENSE +21 -0
  9. encoder4editing/configs/__init__.py +0 -0
  10. encoder4editing/configs/data_configs.py +41 -0
  11. encoder4editing/configs/paths_config.py +28 -0
  12. encoder4editing/configs/transforms_config.py +62 -0
  13. encoder4editing/criteria/__init__.py +0 -0
  14. encoder4editing/criteria/id_loss.py +47 -0
  15. encoder4editing/criteria/lpips/__init__.py +0 -0
  16. encoder4editing/criteria/lpips/lpips.py +35 -0
  17. encoder4editing/criteria/lpips/networks.py +96 -0
  18. encoder4editing/criteria/lpips/utils.py +30 -0
  19. encoder4editing/criteria/moco_loss.py +71 -0
  20. encoder4editing/criteria/w_norm.py +14 -0
  21. encoder4editing/datasets/__init__.py +0 -0
  22. encoder4editing/datasets/gt_res_dataset.py +32 -0
  23. encoder4editing/datasets/images_dataset.py +33 -0
  24. encoder4editing/datasets/inference_dataset.py +25 -0
  25. encoder4editing/editings/ganspace.py +22 -0
  26. encoder4editing/editings/ganspace_pca/cars_pca.pt +3 -0
  27. encoder4editing/editings/ganspace_pca/ffhq_pca.pt +3 -0
  28. encoder4editing/editings/interfacegan_directions/age.pt +3 -0
  29. encoder4editing/editings/interfacegan_directions/pose.pt +3 -0
  30. encoder4editing/editings/interfacegan_directions/smile.pt +3 -0
  31. encoder4editing/editings/latent_editor.py +45 -0
  32. encoder4editing/editings/sefa.py +46 -0
  33. encoder4editing/environment/e4e_env.yaml +73 -0
  34. encoder4editing/infer.py +134 -0
  35. encoder4editing/metrics/LEC.py +134 -0
  36. encoder4editing/models/__init__.py +0 -0
  37. encoder4editing/models/discriminator.py +20 -0
  38. encoder4editing/models/encoders/__init__.py +0 -0
  39. encoder4editing/models/encoders/helpers.py +140 -0
  40. encoder4editing/models/encoders/model_irse.py +84 -0
  41. encoder4editing/models/encoders/psp_encoders.py +235 -0
  42. encoder4editing/models/latent_codes_pool.py +55 -0
  43. encoder4editing/models/psp.py +100 -0
  44. encoder4editing/models/stylegan2/__init__.py +0 -0
  45. encoder4editing/models/stylegan2/model.py +673 -0
  46. encoder4editing/models/stylegan2/op/__init__.py +2 -0
  47. encoder4editing/models/stylegan2/op/fused_act.py +85 -0
  48. encoder4editing/models/stylegan2/op/fused_bias_act.cpp +21 -0
  49. encoder4editing/models/stylegan2/op/fused_bias_act_kernel.cu +99 -0
  50. encoder4editing/models/stylegan2/op/upfirdn2d.cpp +23 -0
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
.gitattributes CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.pth* filter=lfs diff=lfs merge=lfs -text
+filter=lfs diff=lfs merge=lfs -text
LICENSE.txt ADDED
@@ -0,0 +1,97 @@
+Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
+
+
+NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA)
+
+
+=======================================================================
+
+1. Definitions
+
+"Licensor" means any person or entity that distributes its Work.
+
+"Software" means the original work of authorship made available under
+this License.
+
+"Work" means the Software and any additions to or derivative works of
+the Software that are made available under this License.
+
+The terms "reproduce," "reproduction," "derivative works," and
+"distribution" have the meaning as provided under U.S. copyright law;
+provided, however, that for the purposes of this License, derivative
+works shall not include works that remain separable from, or merely
+link (or bind by name) to the interfaces of, the Work.
+
+Works, including the Software, are "made available" under this License
+by including in or with the Work either (a) a copyright notice
+referencing the applicability of this License to the Work, or (b) a
+copy of this License.
+
+2. License Grants
+
+    2.1 Copyright Grant. Subject to the terms and conditions of this
+    License, each Licensor grants to you a perpetual, worldwide,
+    non-exclusive, royalty-free, copyright license to reproduce,
+    prepare derivative works of, publicly display, publicly perform,
+    sublicense and distribute its Work and any resulting derivative
+    works in any form.
+
+3. Limitations
+
+    3.1 Redistribution. You may reproduce or distribute the Work only
+    if (a) you do so under this License, (b) you include a complete
+    copy of this License with your distribution, and (c) you retain
+    without modification any copyright, patent, trademark, or
+    attribution notices that are present in the Work.
+
+    3.2 Derivative Works. You may specify that additional or different
+    terms apply to the use, reproduction, and distribution of your
+    derivative works of the Work ("Your Terms") only if (a) Your Terms
+    provide that the use limitation in Section 3.3 applies to your
+    derivative works, and (b) you identify the specific derivative
+    works that are subject to Your Terms. Notwithstanding Your Terms,
+    this License (including the redistribution requirements in Section
+    3.1) will continue to apply to the Work itself.
+
+    3.3 Use Limitation. The Work and any derivative works thereof only
+    may be used or intended for use non-commercially. Notwithstanding
+    the foregoing, NVIDIA and its affiliates may use the Work and any
+    derivative works commercially. As used herein, "non-commercially"
+    means for research or evaluation purposes only.
+
+    3.4 Patent Claims. If you bring or threaten to bring a patent claim
+    against any Licensor (including any claim, cross-claim or
+    counterclaim in a lawsuit) to enforce any patents that you allege
+    are infringed by any Work, then your rights under this License from
+    such Licensor (including the grant in Section 2.1) will terminate
+    immediately.
+
+    3.5 Trademarks. This License does not grant any rights to use any
+    Licensor’s or its affiliates’ names, logos, or trademarks, except
+    as necessary to reproduce the notices described in this License.
+
+    3.6 Termination. If you violate any term of this License, then your
+    rights under this License (including the grant in Section 2.1) will
+    terminate immediately.
+
+4. Disclaimer of Warranty.
+
+THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
+NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
+THIS LICENSE.
+
+5. Limitation of Liability.
+
+EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
+THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
+SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
+INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
+OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
+(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
+LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
+COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
+THE POSSIBILITY OF SUCH DAMAGES.
+
+=======================================================================
dnnlib/__init__.py ADDED
@@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+from .util import EasyDict, make_cache_dir_path
dnnlib/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (206 Bytes).
 
dnnlib/__pycache__/util.cpython-38.pyc ADDED
Binary file (13.7 kB).
 
dnnlib/util.py ADDED
@@ -0,0 +1,473 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Miscellaneous utility classes and functions."""
+
+import ctypes
+import fnmatch
+import importlib
+import inspect
+import numpy as np
+import os
+import shutil
+import sys
+import types
+import io
+import pickle
+import re
+import requests
+import html
+import hashlib
+import glob
+import tempfile
+import urllib
+import urllib.request
+import uuid
+
+from distutils.util import strtobool
+from typing import Any, List, Tuple, Union
+
+
+class EasyDict(dict):
+    """Convenience class that behaves like a dict but allows access with the attribute syntax."""
+
+    def __getattr__(self, name: str) -> Any:
+        try:
+            return self[name]
+        except KeyError:
+            raise AttributeError(name)
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        self[name] = value
+
+    def __delattr__(self, name: str) -> None:
+        del self[name]
+
+
+class Logger(object):
+    """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
+
+    def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
+        self.file = None
+
+        if file_name is not None:
+            self.file = open(file_name, file_mode)
+
+        self.should_flush = should_flush
+        self.stdout = sys.stdout
+        self.stderr = sys.stderr
+
+        sys.stdout = self
+        sys.stderr = self
+
+    def __enter__(self) -> "Logger":
+        return self
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        self.close()
+
+    def write(self, text: Union[str, bytes]) -> None:
+        """Write text to stdout (and a file) and optionally flush."""
+        if isinstance(text, bytes):
+            text = text.decode()
+        if len(text) == 0:  # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
+            return
+
+        if self.file is not None:
+            self.file.write(text)
+
+        self.stdout.write(text)
+
+        if self.should_flush:
+            self.flush()
+
+    def flush(self) -> None:
+        """Flush written text to both stdout and a file, if open."""
+        if self.file is not None:
+            self.file.flush()
+
+        self.stdout.flush()
+
+    def close(self) -> None:
+        """Flush, close possible files, and remove stdout/stderr mirroring."""
+        self.flush()
+
+        # if using multiple loggers, prevent closing in wrong order
+        if sys.stdout is self:
+            sys.stdout = self.stdout
+        if sys.stderr is self:
+            sys.stderr = self.stderr
+
+        if self.file is not None:
+            self.file.close()
+            self.file = None
+
+
+# Cache directories
+# ------------------------------------------------------------------------------------------
+
+_dnnlib_cache_dir = None
+
+def set_cache_dir(path: str) -> None:
+    global _dnnlib_cache_dir
+    _dnnlib_cache_dir = path
+
+def make_cache_dir_path(*paths: str) -> str:
+    if _dnnlib_cache_dir is not None:
+        return os.path.join(_dnnlib_cache_dir, *paths)
+    if 'DNNLIB_CACHE_DIR' in os.environ:
+        return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
+    if 'HOME' in os.environ:
+        return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
+    if 'USERPROFILE' in os.environ:
+        return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
+    return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
+
+# Small util functions
+# ------------------------------------------------------------------------------------------
+
+
+def format_time(seconds: Union[int, float]) -> str:
+    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+    s = int(np.rint(seconds))
+
+    if s < 60:
+        return "{0}s".format(s)
+    elif s < 60 * 60:
+        return "{0}m {1:02}s".format(s // 60, s % 60)
+    elif s < 24 * 60 * 60:
+        return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
+    else:
+        return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
+
+
+def ask_yes_no(question: str) -> bool:
+    """Ask the user the question until the user inputs a valid answer."""
+    while True:
+        try:
+            print("{0} [y/n]".format(question))
+            return strtobool(input().lower())
+        except ValueError:
+            pass
+
+
+def tuple_product(t: Tuple) -> Any:
+    """Calculate the product of the tuple elements."""
+    result = 1
+
+    for v in t:
+        result *= v
+
+    return result
+
+
+_str_to_ctype = {
+    "uint8": ctypes.c_ubyte,
+    "uint16": ctypes.c_uint16,
+    "uint32": ctypes.c_uint32,
+    "uint64": ctypes.c_uint64,
+    "int8": ctypes.c_byte,
+    "int16": ctypes.c_int16,
+    "int32": ctypes.c_int32,
+    "int64": ctypes.c_int64,
+    "float32": ctypes.c_float,
+    "float64": ctypes.c_double
+}
+
+
+def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
+    """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
+    type_str = None
+
+    if isinstance(type_obj, str):
+        type_str = type_obj
+    elif hasattr(type_obj, "__name__"):
+        type_str = type_obj.__name__
+    elif hasattr(type_obj, "name"):
+        type_str = type_obj.name
+    else:
+        raise RuntimeError("Cannot infer type name from input")
+
+    assert type_str in _str_to_ctype.keys()
+
+    my_dtype = np.dtype(type_str)
+    my_ctype = _str_to_ctype[type_str]
+
+    assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
+
+    return my_dtype, my_ctype
+
+
+def is_pickleable(obj: Any) -> bool:
+    try:
+        with io.BytesIO() as stream:
+            pickle.dump(obj, stream)
+        return True
+    except:
+        return False
+
+
+# Functionality to import modules/objects by name, and call functions by name
+# ------------------------------------------------------------------------------------------
+
+def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
+    """Searches for the underlying module behind the name to some python object.
+    Returns the module and the object name (original name with module part removed)."""
+
+    # allow convenience shorthands, substitute them by full names
+    obj_name = re.sub("^np.", "numpy.", obj_name)
+    obj_name = re.sub("^tf.", "tensorflow.", obj_name)
+
+    # list alternatives for (module_name, local_obj_name)
+    parts = obj_name.split(".")
+    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
+
+    # try each alternative in turn
+    for module_name, local_obj_name in name_pairs:
+        try:
+            module = importlib.import_module(module_name)  # may raise ImportError
+            get_obj_from_module(module, local_obj_name)  # may raise AttributeError
+            return module, local_obj_name
+        except:
+            pass
+
+    # maybe some of the modules themselves contain errors?
+    for module_name, _local_obj_name in name_pairs:
+        try:
+            importlib.import_module(module_name)  # may raise ImportError
+        except ImportError:
+            if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
+                raise
+
+    # maybe the requested attribute is missing?
+    for module_name, local_obj_name in name_pairs:
+        try:
+            module = importlib.import_module(module_name)  # may raise ImportError
+            get_obj_from_module(module, local_obj_name)  # may raise AttributeError
+        except ImportError:
+            pass
+
+    # we are out of luck, but we have no idea why
+    raise ImportError(obj_name)
+
+
+def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
+    """Traverses the object name and returns the last (rightmost) python object."""
+    if obj_name == '':
+        return module
+    obj = module
+    for part in obj_name.split("."):
+        obj = getattr(obj, part)
+    return obj
+
+
+def get_obj_by_name(name: str) -> Any:
+    """Finds the python object with the given name."""
+    module, obj_name = get_module_from_obj_name(name)
+    return get_obj_from_module(module, obj_name)
+
+
+def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
+    """Finds the python object with the given name and calls it as a function."""
+    assert func_name is not None
+    func_obj = get_obj_by_name(func_name)
+    assert callable(func_obj)
+    return func_obj(*args, **kwargs)
+
+
+def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
+    """Finds the python class with the given name and constructs it with the given arguments."""
+    return call_func_by_name(*args, func_name=class_name, **kwargs)
+
+
+def get_module_dir_by_obj_name(obj_name: str) -> str:
+    """Get the directory path of the module containing the given object name."""
+    module, _ = get_module_from_obj_name(obj_name)
+    return os.path.dirname(inspect.getfile(module))
+
+
+def is_top_level_function(obj: Any) -> bool:
+    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
+    return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
+
+
+def get_top_level_function_name(obj: Any) -> str:
+    """Return the fully-qualified name of a top-level function."""
+    assert is_top_level_function(obj)
+    module = obj.__module__
+    if module == '__main__':
+        module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
+    return module + "." + obj.__name__
+
+
+# File system helpers
+# ------------------------------------------------------------------------------------------
+
+def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
+    """List all files recursively in a given directory while ignoring given file and directory names.
+    Returns list of tuples containing both absolute and relative paths."""
+    assert os.path.isdir(dir_path)
+    base_name = os.path.basename(os.path.normpath(dir_path))
+
+    if ignores is None:
+        ignores = []
+
+    result = []
+
+    for root, dirs, files in os.walk(dir_path, topdown=True):
+        for ignore_ in ignores:
+            dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
+
+            # dirs need to be edited in-place
+            for d in dirs_to_remove:
+                dirs.remove(d)
+
+            files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
+
+        absolute_paths = [os.path.join(root, f) for f in files]
+        relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
+
+        if add_base_to_relative:
+            relative_paths = [os.path.join(base_name, p) for p in relative_paths]
+
+        assert len(absolute_paths) == len(relative_paths)
+        result += zip(absolute_paths, relative_paths)
+
+    return result
+
+
+def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
+    """Takes in a list of tuples of (src, dst) paths and copies files.
+    Will create all necessary directories."""
+    for file in files:
+        target_dir_name = os.path.dirname(file[1])
+
+        # will create all intermediate-level directories
+        if not os.path.exists(target_dir_name):
+            os.makedirs(target_dir_name)
+
+        shutil.copyfile(file[0], file[1])
+
+
+# URL helpers
+# ------------------------------------------------------------------------------------------
+
+def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
+    """Determine whether the given object is a valid URL string."""
+    if not isinstance(obj, str) or not "://" in obj:
+        return False
+    if allow_file_urls and obj.startswith('file://'):
+        return True
+    try:
+        res = requests.compat.urlparse(obj)
+        if not res.scheme or not res.netloc or not "." in res.netloc:
+            return False
+        res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
+        if not res.scheme or not res.netloc or not "." in res.netloc:
+            return False
+    except:
+        return False
+    return True
+
+
+def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
+    """Download the given URL and return a binary-mode file object to access the data."""
+    assert num_attempts >= 1
+    assert not (return_filename and (not cache))
+
+    # Doesn't look like an URL scheme so interpret it as a local filename.
+    if not re.match('^[a-z]+://', url):
+        return url if return_filename else open(url, "rb")
+
+    # Handle file URLs. This code handles unusual file:// patterns that
+    # arise on Windows:
+    #
+    # file:///c:/foo.txt
+    #
+    # which would translate to a local '/c:/foo.txt' filename that's
+    # invalid. Drop the forward slash for such pathnames.
+    #
+    # If you touch this code path, you should test it on both Linux and
+    # Windows.
+    #
+    # Some internet resources suggest using urllib.request.url2pathname(),
+    # but that converts forward slashes to backslashes and this causes
+    # its own set of problems.
+    if url.startswith('file://'):
+        filename = urllib.parse.urlparse(url).path
+        if re.match(r'^/[a-zA-Z]:', filename):
+            filename = filename[1:]
+        return filename if return_filename else open(filename, "rb")
+
+    assert is_url(url)
+
+    # Lookup from cache.
+    if cache_dir is None:
+        cache_dir = make_cache_dir_path('downloads')
+
+    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
+    if cache:
+        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
+        if len(cache_files) == 1:
+            filename = cache_files[0]
+            return filename if return_filename else open(filename, "rb")
+
+    # Download.
+    url_name = None
+    url_data = None
+    with requests.Session() as session:
+        if verbose:
+            print("Downloading %s ..." % url, end="", flush=True)
+        for attempts_left in reversed(range(num_attempts)):
+            try:
+                with session.get(url) as res:
+                    res.raise_for_status()
+                    if len(res.content) == 0:
+                        raise IOError("No data received")
+
+                    if len(res.content) < 8192:
+                        content_str = res.content.decode("utf-8")
+                        if "download_warning" in res.headers.get("Set-Cookie", ""):
+                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
+                            if len(links) == 1:
+                                url = requests.compat.urljoin(url, links[0])
+                                raise IOError("Google Drive virus checker nag")
+                        if "Google Drive - Quota exceeded" in content_str:
+                            raise IOError("Google Drive download quota exceeded -- please try again later")
+
+                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
+                    url_name = match[1] if match else url
+                    url_data = res.content
+                    if verbose:
+                        print(" done")
+                    break
+            except KeyboardInterrupt:
+                raise
+            except:
+                if not attempts_left:
+                    if verbose:
+                        print(" failed")
+                    raise
+                if verbose:
+                    print(".", end="", flush=True)
+
+    # Save to cache.
+    if cache:
+        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
+        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
+        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
+        os.makedirs(cache_dir, exist_ok=True)
+        with open(temp_file, "wb") as f:
+            f.write(url_data)
+        os.replace(temp_file, cache_file)  # atomic
+        if return_filename:
+            return cache_file
+
+    # Return data as file object.
+    assert not return_filename
+    return io.BytesIO(url_data)
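For orientation, a minimal usage sketch of this module, assuming the repo root is on sys.path; the download URL is a placeholder, not a file shipped with this commit:

    import dnnlib
    from dnnlib.util import open_url

    opts = dnnlib.EasyDict(lr=0.001, batch_size=8)  # attribute-style access to dict keys
    print(opts.lr, opts["batch_size"])              # 0.001 8
    print(dnnlib.make_cache_dir_path("downloads"))  # e.g. ~/.cache/dnnlib/downloads

    # open_url accepts plain local paths, file:// URLs, and HTTP(S) URLs;
    # HTTP downloads are retried and cached on disk under the cache directory.
    with open_url("https://example.com/model.pkl") as f:  # placeholder URL
        data = f.read()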
encoder4editing/LICENSE ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 omertov
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
encoder4editing/configs/__init__.py ADDED
File without changes
encoder4editing/configs/data_configs.py ADDED
@@ -0,0 +1,41 @@
+from configs import transforms_config
+from configs.paths_config import dataset_paths
+
+
+DATASETS = {
+    'ffhq_encode': {
+        'transforms': transforms_config.EncodeTransforms,
+        'train_source_root': dataset_paths['ffhq'],
+        'train_target_root': dataset_paths['ffhq'],
+        'test_source_root': dataset_paths['celeba_test'],
+        'test_target_root': dataset_paths['celeba_test'],
+    },
+    'cars_encode': {
+        'transforms': transforms_config.CarsEncodeTransforms,
+        'train_source_root': dataset_paths['cars_train'],
+        'train_target_root': dataset_paths['cars_train'],
+        'test_source_root': dataset_paths['cars_test'],
+        'test_target_root': dataset_paths['cars_test'],
+    },
+    'horse_encode': {
+        'transforms': transforms_config.EncodeTransforms,
+        'train_source_root': dataset_paths['horse_train'],
+        'train_target_root': dataset_paths['horse_train'],
+        'test_source_root': dataset_paths['horse_test'],
+        'test_target_root': dataset_paths['horse_test'],
+    },
+    'church_encode': {
+        'transforms': transforms_config.EncodeTransforms,
+        'train_source_root': dataset_paths['church_train'],
+        'train_target_root': dataset_paths['church_train'],
+        'test_source_root': dataset_paths['church_test'],
+        'test_target_root': dataset_paths['church_test'],
+    },
+    'cats_encode': {
+        'transforms': transforms_config.EncodeTransforms,
+        'train_source_root': dataset_paths['cats_train'],
+        'train_target_root': dataset_paths['cats_train'],
+        'test_source_root': dataset_paths['cats_test'],
+        'test_target_root': dataset_paths['cats_test'],
+    }
+}
encoder4editing/configs/paths_config.py ADDED
@@ -0,0 +1,28 @@
+dataset_paths = {
+    # Face Datasets (In the paper: FFHQ - train, CelebAHQ - test)
+    'ffhq': '',
+    'celeba_test': '',
+
+    # Cars Dataset (In the paper: Stanford cars)
+    'cars_train': '',
+    'cars_test': '',
+
+    # Horse Dataset (In the paper: LSUN Horse)
+    'horse_train': '',
+    'horse_test': '',
+
+    # Church Dataset (In the paper: LSUN Church)
+    'church_train': '',
+    'church_test': '',
+
+    # Cats Dataset (In the paper: LSUN Cat)
+    'cats_train': '',
+    'cats_test': ''
+}
+
+model_paths = {
+    'stylegan_ffhq': 'pretrained_models/stylegan2-ffhq-config-f.pt',
+    'ir_se50': 'pretrained_models/model_ir_se50.pth',
+    'shape_predictor': 'pretrained_models/shape_predictor_68_face_landmarks.dat',
+    'moco': 'pretrained_models/moco_v2_800ep_pretrain.pth'
+}
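The empty dataset paths are deliberate placeholders. Since these are plain module-level dicts, a training script can either edit this file or fill them in at runtime; a hedged sketch with made-up local paths:

    from configs.paths_config import dataset_paths, model_paths

    dataset_paths['ffhq'] = '/data/ffhq'              # placeholder; point at your local copy
    dataset_paths['celeba_test'] = '/data/celeba_hq'  # placeholder
    print(model_paths['ir_se50'])                     # pretrained_models/model_ir_se50.pth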
encoder4editing/configs/transforms_config.py ADDED
@@ -0,0 +1,62 @@
+from abc import abstractmethod
+import torchvision.transforms as transforms
+
+
+class TransformsConfig(object):
+
+    def __init__(self, opts):
+        self.opts = opts
+
+    @abstractmethod
+    def get_transforms(self):
+        pass
+
+
+class EncodeTransforms(TransformsConfig):
+
+    def __init__(self, opts):
+        super(EncodeTransforms, self).__init__(opts)
+
+    def get_transforms(self):
+        transforms_dict = {
+            'transform_gt_train': transforms.Compose([
+                transforms.Resize((256, 256)),
+                transforms.RandomHorizontalFlip(0.5),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+            'transform_source': None,
+            'transform_test': transforms.Compose([
+                transforms.Resize((256, 256)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+            'transform_inference': transforms.Compose([
+                transforms.Resize((256, 256)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
+        }
+        return transforms_dict
+
+
+class CarsEncodeTransforms(TransformsConfig):
+
+    def __init__(self, opts):
+        super(CarsEncodeTransforms, self).__init__(opts)
+
+    def get_transforms(self):
+        transforms_dict = {
+            'transform_gt_train': transforms.Compose([
+                transforms.Resize((192, 256)),
+                transforms.RandomHorizontalFlip(0.5),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+            'transform_source': None,
+            'transform_test': transforms.Compose([
+                transforms.Resize((192, 256)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+            'transform_inference': transforms.Compose([
+                transforms.Resize((192, 256)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
+        }
+        return transforms_dict
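A sketch of how these transform configs are consumed together with the DATASETS registry above (the opts namespace is a hypothetical stand-in for the training options object):

    from argparse import Namespace
    from configs.data_configs import DATASETS

    opts = Namespace(dataset_type='ffhq_encode')  # hypothetical options object
    dataset_args = DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()
    train_transform = transforms_dict['transform_gt_train']  # Resize + flip + ToTensor + Normalize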
encoder4editing/criteria/__init__.py ADDED
File without changes
encoder4editing/criteria/id_loss.py ADDED
@@ -0,0 +1,47 @@
+import torch
+from torch import nn
+from configs.paths_config import model_paths
+from models.encoders.model_irse import Backbone
+
+
+class IDLoss(nn.Module):
+    def __init__(self):
+        super(IDLoss, self).__init__()
+        print('Loading ResNet ArcFace')
+        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
+        self.facenet.load_state_dict(torch.load(model_paths['ir_se50']))
+        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
+        self.facenet.eval()
+        for module in [self.facenet, self.face_pool]:
+            for param in module.parameters():
+                param.requires_grad = False
+
+    def extract_feats(self, x):
+        x = x[:, :, 35:223, 32:220]  # Crop interesting region
+        x = self.face_pool(x)
+        x_feats = self.facenet(x)
+        return x_feats
+
+    def forward(self, y_hat, y, x):
+        n_samples = x.shape[0]
+        x_feats = self.extract_feats(x)
+        y_feats = self.extract_feats(y)  # Otherwise use the feature from there
+        y_hat_feats = self.extract_feats(y_hat)
+        y_feats = y_feats.detach()
+        loss = 0
+        sim_improvement = 0
+        id_logs = []
+        count = 0
+        for i in range(n_samples):
+            diff_target = y_hat_feats[i].dot(y_feats[i])
+            diff_input = y_hat_feats[i].dot(x_feats[i])
+            diff_views = y_feats[i].dot(x_feats[i])
+            id_logs.append({'diff_target': float(diff_target),
+                            'diff_input': float(diff_input),
+                            'diff_views': float(diff_views)})
+            loss += 1 - diff_target
+            id_diff = float(diff_target) - float(diff_views)
+            sim_improvement += id_diff
+            count += 1
+
+        return loss / count, sim_improvement / count, id_logs
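The dot products above act as cosine similarities because the ArcFace backbone L2-normalizes its embeddings, so the per-sample loss term is 1 - cos(y_hat, y). A toy numerical check with random stand-in embeddings rather than real face crops:

    import torch
    import torch.nn.functional as F

    y_hat_feats = F.normalize(torch.randn(4, 512), dim=1)   # stand-ins for ArcFace embeddings
    y_feats = F.normalize(torch.randn(4, 512), dim=1)
    loss = (1 - (y_hat_feats * y_feats).sum(dim=1)).mean()  # batch mean of 1 - cosine similarity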
encoder4editing/criteria/lpips/__init__.py ADDED
File without changes
encoder4editing/criteria/lpips/lpips.py ADDED
@@ -0,0 +1,35 @@
+import torch
+import torch.nn as nn
+
+from criteria.lpips.networks import get_network, LinLayers
+from criteria.lpips.utils import get_state_dict
+
+
+class LPIPS(nn.Module):
+    r"""Creates a criterion that measures
+    Learned Perceptual Image Patch Similarity (LPIPS).
+    Arguments:
+        net_type (str): the network type to compare the features:
+                        'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
+        version (str): the version of LPIPS. Default: 0.1.
+    """
+    def __init__(self, net_type: str = 'alex', version: str = '0.1'):
+
+        assert version in ['0.1'], 'v0.1 is only supported now'
+
+        super(LPIPS, self).__init__()
+
+        # pretrained network
+        self.net = get_network(net_type).to("cuda")
+
+        # linear layers
+        self.lin = LinLayers(self.net.n_channels_list).to("cuda")
+        self.lin.load_state_dict(get_state_dict(net_type, version))
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor):
+        feat_x, feat_y = self.net(x), self.net(y)
+
+        diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
+        res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
+
+        return torch.sum(torch.cat(res, 0)) / x.shape[0]
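A usage sketch: inputs are expected in [-1, 1], matching the Normalize(0.5, 0.5) transforms used elsewhere in this repo, and CUDA is required since the networks are moved to "cuda" in __init__. The batches here are random stand-ins:

    import torch
    from criteria.lpips.lpips import LPIPS

    lpips_loss = LPIPS(net_type='alex')            # downloads AlexNet + linear-head weights
    x = torch.rand(2, 3, 256, 256).cuda() * 2 - 1  # random stand-in batch in [-1, 1]
    y = torch.rand(2, 3, 256, 256).cuda() * 2 - 1
    with torch.no_grad():
        d = lpips_loss(x, y)                       # scalar: mean LPIPS distance over the batch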
encoder4editing/criteria/lpips/networks.py ADDED
@@ -0,0 +1,96 @@
+from typing import Sequence
+
+from itertools import chain
+
+import torch
+import torch.nn as nn
+from torchvision import models
+
+from criteria.lpips.utils import normalize_activation
+
+
+def get_network(net_type: str):
+    if net_type == 'alex':
+        return AlexNet()
+    elif net_type == 'squeeze':
+        return SqueezeNet()
+    elif net_type == 'vgg':
+        return VGG16()
+    else:
+        raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
+
+
+class LinLayers(nn.ModuleList):
+    def __init__(self, n_channels_list: Sequence[int]):
+        super(LinLayers, self).__init__([
+            nn.Sequential(
+                nn.Identity(),
+                nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
+            ) for nc in n_channels_list
+        ])
+
+        for param in self.parameters():
+            param.requires_grad = False
+
+
+class BaseNet(nn.Module):
+    def __init__(self):
+        super(BaseNet, self).__init__()
+
+        # register buffer
+        self.register_buffer(
+            'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
+        self.register_buffer(
+            'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
+
+    def set_requires_grad(self, state: bool):
+        for param in chain(self.parameters(), self.buffers()):
+            param.requires_grad = state
+
+    def z_score(self, x: torch.Tensor):
+        return (x - self.mean) / self.std
+
+    def forward(self, x: torch.Tensor):
+        x = self.z_score(x)
+
+        output = []
+        for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
+            x = layer(x)
+            if i in self.target_layers:
+                output.append(normalize_activation(x))
+            if len(output) == len(self.target_layers):
+                break
+        return output
+
+
+class SqueezeNet(BaseNet):
+    def __init__(self):
+        super(SqueezeNet, self).__init__()
+
+        self.layers = models.squeezenet1_1(True).features
+        self.target_layers = [2, 5, 8, 10, 11, 12, 13]
+        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
+
+        self.set_requires_grad(False)
+
+
+class AlexNet(BaseNet):
+    def __init__(self):
+        super(AlexNet, self).__init__()
+
+        self.layers = models.alexnet(True).features
+        self.target_layers = [2, 5, 8, 10, 12]
+        self.n_channels_list = [64, 192, 384, 256, 256]
+
+        self.set_requires_grad(False)
+
+
+class VGG16(BaseNet):
+    def __init__(self):
+        super(VGG16, self).__init__()
+
+        self.layers = models.vgg16(True).features
+        self.target_layers = [4, 9, 16, 23, 30]
+        self.n_channels_list = [64, 128, 256, 512, 512]
+
+        self.set_requires_grad(False)
encoder4editing/criteria/lpips/utils.py ADDED
@@ -0,0 +1,30 @@
+from collections import OrderedDict
+
+import torch
+
+
+def normalize_activation(x, eps=1e-10):
+    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
+    return x / (norm_factor + eps)
+
+
+def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
+    # build url
+    url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
+        + f'master/lpips/weights/v{version}/{net_type}.pth'
+
+    # download
+    old_state_dict = torch.hub.load_state_dict_from_url(
+        url, progress=True,
+        map_location=None if torch.cuda.is_available() else torch.device('cpu')
+    )
+
+    # rename keys
+    new_state_dict = OrderedDict()
+    for key, val in old_state_dict.items():
+        new_key = key
+        new_key = new_key.replace('lin', '')
+        new_key = new_key.replace('model.', '')
+        new_state_dict[new_key] = val
+
+    return new_state_dict
encoder4editing/criteria/moco_loss.py ADDED
@@ -0,0 +1,71 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from configs.paths_config import model_paths
+
+
+class MocoLoss(nn.Module):
+
+    def __init__(self, opts):
+        super(MocoLoss, self).__init__()
+        print("Loading MOCO model from path: {}".format(model_paths["moco"]))
+        self.model = self.__load_model()
+        self.model.eval()
+        for param in self.model.parameters():
+            param.requires_grad = False
+
+    @staticmethod
+    def __load_model():
+        import torchvision.models as models
+        model = models.__dict__["resnet50"]()
+        # freeze all layers but the last fc
+        for name, param in model.named_parameters():
+            if name not in ['fc.weight', 'fc.bias']:
+                param.requires_grad = False
+        checkpoint = torch.load(model_paths['moco'], map_location="cpu")
+        state_dict = checkpoint['state_dict']
+        # rename moco pre-trained keys
+        for k in list(state_dict.keys()):
+            # retain only encoder_q up to before the embedding layer
+            if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
+                # remove prefix
+                state_dict[k[len("module.encoder_q."):]] = state_dict[k]
+            # delete renamed or unused k
+            del state_dict[k]
+        msg = model.load_state_dict(state_dict, strict=False)
+        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
+        # remove output layer
+        model = nn.Sequential(*list(model.children())[:-1]).cuda()
+        return model
+
+    def extract_feats(self, x):
+        x = F.interpolate(x, size=224)
+        x_feats = self.model(x)
+        x_feats = nn.functional.normalize(x_feats, dim=1)
+        x_feats = x_feats.squeeze()
+        return x_feats
+
+    def forward(self, y_hat, y, x):
+        n_samples = x.shape[0]
+        x_feats = self.extract_feats(x)
+        y_feats = self.extract_feats(y)
+        y_hat_feats = self.extract_feats(y_hat)
+        y_feats = y_feats.detach()
+        loss = 0
+        sim_improvement = 0
+        sim_logs = []
+        count = 0
+        for i in range(n_samples):
+            diff_target = y_hat_feats[i].dot(y_feats[i])
+            diff_input = y_hat_feats[i].dot(x_feats[i])
+            diff_views = y_feats[i].dot(x_feats[i])
+            sim_logs.append({'diff_target': float(diff_target),
+                             'diff_input': float(diff_input),
+                             'diff_views': float(diff_views)})
+            loss += 1 - diff_target
+            sim_diff = float(diff_target) - float(diff_views)
+            sim_improvement += sim_diff
+            count += 1
+
+        return loss / count, sim_improvement / count, sim_logs
encoder4editing/criteria/w_norm.py ADDED
@@ -0,0 +1,14 @@
+import torch
+from torch import nn
+
+
+class WNormLoss(nn.Module):
+
+    def __init__(self, start_from_latent_avg=True):
+        super(WNormLoss, self).__init__()
+        self.start_from_latent_avg = start_from_latent_avg
+
+    def forward(self, latent, latent_avg=None):
+        if self.start_from_latent_avg:
+            latent = latent - latent_avg
+        return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0]
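Shape-wise, latent is a (batch, n_styles, 512) W+ code, and the loss is the batch mean of the L2 norm taken over the style and channel dimensions of the (optionally mean-centered) code. A self-contained sketch with placeholder tensors:

    import torch
    from criteria.w_norm import WNormLoss

    w_norm = WNormLoss(start_from_latent_avg=True)
    latent = torch.randn(4, 18, 512)   # hypothetical W+ batch (18 styles for a 1024x1024 StyleGAN2)
    latent_avg = torch.zeros(18, 512)  # placeholder for the generator's mean latent
    loss = w_norm(latent, latent_avg)  # batch mean of ||latent - latent_avg|| over the (18, 512) dims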
encoder4editing/datasets/__init__.py ADDED
File without changes
encoder4editing/datasets/gt_res_dataset.py ADDED
@@ -0,0 +1,32 @@
+#!/usr/bin/python
+# encoding: utf-8
+import os
+from torch.utils.data import Dataset
+from PIL import Image
+import torch
+
+
+class GTResDataset(Dataset):
+
+    def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None):
+        self.pairs = []
+        for f in os.listdir(root_path):
+            image_path = os.path.join(root_path, f)
+            gt_path = os.path.join(gt_dir, f)
+            if f.endswith(".jpg") or f.endswith(".png"):
+                self.pairs.append([image_path, gt_path.replace('.png', '.jpg'), None])
+        self.transform = transform
+        self.transform_train = transform_train
+
+    def __len__(self):
+        return len(self.pairs)
+
+    def __getitem__(self, index):
+        from_path, to_path, _ = self.pairs[index]
+        from_im = Image.open(from_path).convert('RGB')
+        to_im = Image.open(to_path).convert('RGB')
+
+        if self.transform:
+            to_im = self.transform(to_im)
+            from_im = self.transform(from_im)
+
+        return from_im, to_im
encoder4editing/datasets/images_dataset.py ADDED
@@ -0,0 +1,33 @@
+from torch.utils.data import Dataset
+from PIL import Image
+from utils import data_utils
+
+
+class ImagesDataset(Dataset):
+
+    def __init__(self, source_root, target_root, opts, target_transform=None, source_transform=None):
+        self.source_paths = sorted(data_utils.make_dataset(source_root))
+        self.target_paths = sorted(data_utils.make_dataset(target_root))
+        self.source_transform = source_transform
+        self.target_transform = target_transform
+        self.opts = opts
+
+    def __len__(self):
+        return len(self.source_paths)
+
+    def __getitem__(self, index):
+        from_path = self.source_paths[index]
+        from_im = Image.open(from_path)
+        from_im = from_im.convert('RGB')
+
+        to_path = self.target_paths[index]
+        to_im = Image.open(to_path).convert('RGB')
+        if self.target_transform:
+            to_im = self.target_transform(to_im)
+
+        if self.source_transform:
+            from_im = self.source_transform(from_im)
+        else:
+            from_im = to_im
+
+        return from_im, to_im
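A hedged sketch of wrapping this dataset in a DataLoader; the paths are placeholders and the empty Namespace stands in for the real training options, which this class stores but does not require here:

    from argparse import Namespace
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader
    from datasets.images_dataset import ImagesDataset

    t = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor(),
                            transforms.Normalize([0.5] * 3, [0.5] * 3)])
    dataset = ImagesDataset(source_root='path/to/images', target_root='path/to/images',  # placeholders
                            opts=Namespace(), source_transform=t, target_transform=t)
    loader = DataLoader(dataset, batch_size=8, shuffle=True)  # yields (from_im, to_im) batches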
encoder4editing/datasets/inference_dataset.py ADDED
@@ -0,0 +1,25 @@
+from torch.utils.data import Dataset
+from PIL import Image
+from utils import data_utils
+
+
+class InferenceDataset(Dataset):
+
+    def __init__(self, root, opts, transform=None, preprocess=None):
+        self.paths = sorted(data_utils.make_dataset(root))
+        self.transform = transform
+        self.preprocess = preprocess
+        self.opts = opts
+
+    def __len__(self):
+        return len(self.paths)
+
+    def __getitem__(self, index):
+        from_path = self.paths[index]
+        if self.preprocess is not None:
+            from_im = self.preprocess(from_path)
+        else:
+            from_im = Image.open(from_path).convert('RGB')
+        if self.transform:
+            from_im = self.transform(from_im)
+        return from_im
encoder4editing/editings/ganspace.py ADDED
@@ -0,0 +1,22 @@
+import torch
+
+
+def edit(latents, pca, edit_directions):
+    edit_latents = []
+    for latent in latents:
+        for pca_idx, start, end, strength in edit_directions:
+            delta = get_delta(pca, latent, pca_idx, strength)
+            delta_padded = torch.zeros(latent.shape).to('cuda')
+            delta_padded[start:end] += delta.repeat(end - start, 1)
+            edit_latents.append(latent + delta_padded)
+    return torch.stack(edit_latents)
+
+
+def get_delta(pca, latent, idx, strength):
+    # pca: ganspace checkpoint. latent: (16, 512) w+
+    w_centered = latent - pca['mean'].to('cuda')
+    lat_comp = pca['comp'].to('cuda')
+    lat_std = pca['std'].to('cuda')
+    w_coord = torch.sum(w_centered[0].reshape(-1)*lat_comp[idx].reshape(-1)) / lat_std[idx]
+    delta = (strength - w_coord)*lat_comp[idx]*lat_std[idx]
+    return delta
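Each edit direction is a (pca_idx, start_layer, end_layer, strength) tuple: the chosen PCA component is pushed to the given strength on that slice of W+ layers only. A hedged call sketch; the latent is a random stand-in (real codes come from the encoder) and the direction tuple uses illustrative values, and CUDA is required since the module hard-codes 'cuda':

    import torch
    from editings import ganspace

    pca = torch.load('encoder4editing/editings/ganspace_pca/ffhq_pca.pt')  # dict with 'mean', 'comp', 'std'
    latents = torch.randn(1, 18, 512).cuda()  # stand-in W+ code
    directions = [(54, 7, 8, 20)]             # (pca_idx, start_layer, end_layer, strength); illustrative values
    edited = ganspace.edit(latents, pca, directions)  # one edited latent per (input, direction) pair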
encoder4editing/editings/ganspace_pca/cars_pca.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a5c3bae61ecd85de077fbbf103f5f30cf4b7676fe23a8508166eaf2ce73c8392
+size 167562
encoder4editing/editings/ganspace_pca/ffhq_pca.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4d7f9df1c96180d9026b9cb8d04753579fbf385f321a9d0e263641601c5e5d36
+size 167562
encoder4editing/editings/interfacegan_directions/age.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:50074516b1629707d89b5e43d6b8abd1792212fa3b961a87a11323d6a5222ae0
+size 2808
encoder4editing/editings/interfacegan_directions/pose.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:736e0eacc8488fa0b020a2e7bd256b957284c364191dfea693705e5d06d43e7d
+size 37624
encoder4editing/editings/interfacegan_directions/smile.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:817a7e732b59dee9eba862bec8bd7e8373568443bc9f9731a21cf9b0356f0653
+size 2808
encoder4editing/editings/latent_editor.py ADDED
@@ -0,0 +1,45 @@
+import torch
+import sys
+sys.path.append(".")
+sys.path.append("..")
+from editings import ganspace, sefa
+from utils.common import tensor2im
+
+
+class LatentEditor(object):
+    def __init__(self, stylegan_generator, is_cars=False):
+        self.generator = stylegan_generator
+        self.is_cars = is_cars  # Since the cars StyleGAN output is 384x512, there is a need to crop the 512x512 output.
+
+    def apply_ganspace(self, latent, ganspace_pca, edit_directions):
+        edit_latents = ganspace.edit(latent, ganspace_pca, edit_directions)
+        return self._latents_to_image(edit_latents)
+
+    def apply_interfacegan(self, latent, direction, factor=1, factor_range=None):
+        edit_latents = []
+        if factor_range is not None:  # Apply a range of editing factors, for example (-5, 5)
+            for f in range(*factor_range):
+                edit_latent = latent + f * direction
+                edit_latents.append(edit_latent)
+            edit_latents = torch.cat(edit_latents)
+        else:
+            edit_latents = latent + factor * direction
+        return self._latents_to_image(edit_latents)
+
+    def apply_sefa(self, latent, indices=[2, 3, 4, 5], **kwargs):
+        edit_latents = sefa.edit(self.generator, latent, indices, **kwargs)
+        return self._latents_to_image(edit_latents)
+
+    # Currently, in order to apply StyleFlow editings, one should run inference,
+    # save the latent codes and load them from the official StyleFlow repository.
+    # def apply_styleflow(self):
+    #     pass
+
+    def _latents_to_image(self, latents):
+        with torch.no_grad():
+            images, _ = self.generator([latents], randomize_noise=False, input_is_latent=True)
+            if self.is_cars:
+                images = images[:, :, 64:448, :]  # 512x512 -> 384x512
+        horizontal_concat_image = torch.cat(list(images), 2)
+        final_image = tensor2im(horizontal_concat_image)
+        return final_image
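A hedged sketch of driving the editor with the InterFaceGAN age direction shipped in this commit. The generator is assumed to be an already-loaded StyleGAN2 Generator (e.g. net.decoder from a loaded models.psp.pSp), and the latent is a random stand-in for a real inverted code:

    import torch
    from editings.latent_editor import LatentEditor

    # generator: an already-loaded StyleGAN2 Generator (assumption; loading is elided here)
    editor = LatentEditor(generator)
    age = torch.load('encoder4editing/editings/interfacegan_directions/age.pt').cuda()
    latent = torch.randn(1, 18, 512).cuda()  # stand-in for an inverted W+ code
    strip = editor.apply_interfacegan(latent, age, factor_range=(-3, 3))  # 6 edits concatenated horizontally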
encoder4editing/editings/sefa.py ADDED
@@ -0,0 +1,46 @@
+import torch
+import numpy as np
+from tqdm import tqdm
+
+
+def edit(generator, latents, indices, semantics=1, start_distance=-15.0, end_distance=15.0, num_samples=1, step=11):
+
+    layers, boundaries, values = factorize_weight(generator, indices)
+    codes = latents.detach().cpu().numpy()  # (1,18,512)
+
+    # Generate visualization pages.
+    distances = np.linspace(start_distance, end_distance, step)
+    num_sam = num_samples
+    num_sem = semantics
+
+    edited_latents = []
+    for sem_id in tqdm(range(num_sem), desc='Semantic ', leave=False):
+        boundary = boundaries[sem_id:sem_id + 1]
+        for sam_id in tqdm(range(num_sam), desc='Sample ', leave=False):
+            code = codes[sam_id:sam_id + 1]
+            for col_id, d in enumerate(distances, start=1):
+                temp_code = code.copy()
+                temp_code[:, layers, :] += boundary * d
+                edited_latents.append(torch.from_numpy(temp_code).float().cuda())
+    return torch.cat(edited_latents)
+
+
+def factorize_weight(g_ema, layers='all'):
+
+    weights = []
+    if layers == 'all' or 0 in layers:
+        weight = g_ema.conv1.conv.modulation.weight.T
+        weights.append(weight.cpu().detach().numpy())
+
+    if layers == 'all':
+        layers = list(range(g_ema.num_layers - 1))
+    else:
+        layers = [l - 1 for l in layers if l != 0]
+
+    for idx in layers:
+        weight = g_ema.convs[idx].conv.modulation.weight.T
+        weights.append(weight.cpu().detach().numpy())
+    weight = np.concatenate(weights, axis=1).astype(np.float32)
+    weight = weight / np.linalg.norm(weight, axis=0, keepdims=True)
+    eigen_values, eigen_vectors = np.linalg.eig(weight.dot(weight.T))
+    return layers, eigen_vectors.T, eigen_values
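factorize_weight implements the closed-form SeFa idea: it concatenates the style-modulation weights of the selected layers and eigendecomposes them, so the top eigenvectors act as edit directions without any training. A hedged call sketch; g_ema is assumed to be an already-loaded StyleGAN2 Generator and the latent is a random stand-in:

    import torch
    from editings import sefa

    # g_ema: an already-loaded StyleGAN2 Generator (see models/stylegan2/model.py); loading is elided
    latent = torch.randn(1, 18, 512).cuda()  # stand-in W+ code
    edited = sefa.edit(g_ema, latent, indices=[2, 3, 4, 5], semantics=1, step=5)
    # edited: 5 latents swept along the top eigen-direction of the layer 2-5 modulation weights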
encoder4editing/environment/e4e_env.yaml ADDED
@@ -0,0 +1,73 @@
+name: e4e_env
+channels:
+  - conda-forge
+  - defaults
+dependencies:
+  - _libgcc_mutex=0.1=main
+  - ca-certificates=2020.4.5.1=hecc5488_0
+  - certifi=2020.4.5.1=py36h9f0ad1d_0
+  - libedit=3.1.20181209=hc058e9b_0
+  - libffi=3.2.1=hd88cf55_4
+  - libgcc-ng=9.1.0=hdf63c60_0
+  - libstdcxx-ng=9.1.0=hdf63c60_0
+  - ncurses=6.2=he6710b0_1
+  - ninja=1.10.0=hc9558a2_0
+  - openssl=1.1.1g=h516909a_0
+  - pip=20.0.2=py36_3
+  - python=3.6.7=h0371630_0
+  - python_abi=3.6=1_cp36m
+  - readline=7.0=h7b6447c_5
+  - setuptools=46.4.0=py36_0
+  - sqlite=3.31.1=h62c20be_1
+  - tk=8.6.8=hbc83047_0
+  - wheel=0.34.2=py36_0
+  - xz=5.2.5=h7b6447c_0
+  - zlib=1.2.11=h7b6447c_3
+  - pip:
+    - absl-py==0.9.0
+    - cachetools==4.1.0
+    - chardet==3.0.4
+    - cycler==0.10.0
+    - decorator==4.4.2
+    - future==0.18.2
+    - google-auth==1.15.0
+    - google-auth-oauthlib==0.4.1
+    - grpcio==1.29.0
+    - idna==2.9
+    - imageio==2.8.0
+    - importlib-metadata==1.6.0
+    - kiwisolver==1.2.0
+    - markdown==3.2.2
+    - matplotlib==3.2.1
+    - mxnet==1.6.0
+    - networkx==2.4
+    - numpy==1.18.4
+    - oauthlib==3.1.0
+    - opencv-python==4.2.0.34
+    - pillow==7.1.2
+    - protobuf==3.12.1
+    - pyasn1==0.4.8
+    - pyasn1-modules==0.2.8
+    - pyparsing==2.4.7
+    - python-dateutil==2.8.1
+    - pytorch-lightning==0.7.1
+    - pywavelets==1.1.1
+    - requests==2.23.0
+    - requests-oauthlib==1.3.0
+    - rsa==4.0
+    - scikit-image==0.17.2
+    - scipy==1.4.1
+    - six==1.15.0
+    - tensorboard==2.2.1
+    - tensorboard-plugin-wit==1.6.0.post3
+    - tensorboardx==1.9
+    - tifffile==2020.5.25
+    - torch==1.6.0
+    - torchvision==0.7.1
+    - tqdm==4.46.0
+    - urllib3==1.25.9
+    - werkzeug==1.0.1
+    - zipp==3.1.0
+    - pyaml
+prefix: ~/anaconda3/envs/e4e_env
+
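With conda installed, this environment can be created with "conda env create -f encoder4editing/environment/e4e_env.yaml" and activated with "conda activate e4e_env". Note that the pinned torch/torchvision builds target the Python 3.6 toolchain listed above.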
encoder4editing/infer.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import os
+ import argparse
+ from argparse import Namespace
+ import time
+ import sys
+ import numpy as np
+ from PIL import Image
+ import torch
+ import torchvision.transforms as transforms
+
+ sys.path.append(".")
+ sys.path.append("..")
+
+ from utils.common import tensor2im
+ from models.psp import pSp  # we use the pSp framework to load the e4e encoder.
+
+ experiment_type = 'ffhq_encode'
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--input_image', type=str, default="", help='input image path')
+ args = parser.parse_args()
+ opts = vars(args)
+ print(opts)
+ image_path = opts["input_image"]
+
+
+ def get_download_model_command(file_id, file_name):
+     """ Get the wget command for downloading the desired model and saving it under encoder4editing/saves. """
+     save_path = "encoder4editing/saves"
+     if not os.path.exists(save_path):
+         os.makedirs(save_path)
+     url = r"""wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id={FILE_ID}" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt""".format(FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=save_path)
+     return url
+
+
+ MODEL_PATHS = {
+     "ffhq_encode": {"id": "1cUv_reLE6k3604or78EranS7XzuVMWeO", "name": "e4e_ffhq_encode.pt"},
+     "cars_encode": {"id": "17faPqBce2m1AQeLCLHUVXaDfxMRU2QcV", "name": "e4e_cars_encode.pt"},
+     "horse_encode": {"id": "1TkLLnuX86B_BMo2ocYD0kX9kWh53rUVX", "name": "e4e_horse_encode.pt"},
+     "church_encode": {"id": "1-L0ZdnQLwtdy6-A_Ccgq5uNJGTqE7qBa", "name": "e4e_church_encode.pt"}
+ }
+
+ path = MODEL_PATHS[experiment_type]
+ # The command is only constructed here; run it manually if the checkpoint is not yet present.
+ download_command = get_download_model_command(file_id=path["id"], file_name=path["name"])
+
+ EXPERIMENT_DATA_ARGS = {
+     "ffhq_encode": {
+         "model_path": "encoder4editing/e4e_ffhq_encode.pt",
+         "image_path": "notebooks/images/input_img.jpg"
+     },
+     "cars_encode": {
+         "model_path": "pretrained_models/e4e_cars_encode.pt",
+         "image_path": "notebooks/images/car_img.jpg"
+     },
+     "horse_encode": {
+         "model_path": "pretrained_models/e4e_horse_encode.pt",
+         "image_path": "notebooks/images/horse_img.jpg"
+     },
+     "church_encode": {
+         "model_path": "pretrained_models/e4e_church_encode.pt",
+         "image_path": "notebooks/images/church_img.jpg"
+     }
+ }
+
+ # Setup the required image transformations
+ EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[experiment_type]
+ if experiment_type == 'cars_encode':
+     EXPERIMENT_ARGS['transform'] = transforms.Compose([
+         transforms.Resize((192, 256)),
+         transforms.ToTensor(),
+         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
+     resize_dims = (256, 192)
+ else:
+     EXPERIMENT_ARGS['transform'] = transforms.Compose([
+         transforms.Resize((256, 256)),
+         transforms.ToTensor(),
+         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
+     resize_dims = (256, 256)
+
+ model_path = EXPERIMENT_ARGS['model_path']
+ ckpt = torch.load(model_path, map_location='cpu')
+ opts = ckpt['opts']
+
+ # update the training options
+ opts['checkpoint_path'] = model_path
+ opts = Namespace(**opts)
+ net = pSp(opts)
+ net.eval()
+ net.cuda()
+ print('Model successfully loaded!')
+
+ original_image = Image.open(image_path)
+ original_image = original_image.convert("RGB")
+
+
+ def run_alignment(image_path):
+     import dlib
+     from utils.alignment import align_face
+     predictor = dlib.shape_predictor("encoder4editing/shape_predictor_68_face_landmarks.dat")
+     aligned_image = align_face(filepath=image_path, predictor=predictor)
+     print("Aligned image has shape: {}".format(aligned_image.size))
+     return aligned_image
+
+
+ if experiment_type == "ffhq_encode":
+     input_image = run_alignment(image_path)
+ else:
+     input_image = original_image
+
+ # PIL's resize is not in-place; keep the returned image.
+ input_image = input_image.resize(resize_dims)
+
+ img_transforms = EXPERIMENT_ARGS['transform']
+ transformed_image = img_transforms(input_image)
+
+
+ def display_alongside_source_image(result_image, source_image):
+     res = np.concatenate([np.array(source_image.resize(resize_dims)),
+                           np.array(result_image.resize(resize_dims))], axis=1)
+     return Image.fromarray(res)
+
+
+ def run_on_batch(inputs, net):
+     images, latents = net(inputs.to("cuda").float(), randomize_noise=False, return_latents=True)
+     if experiment_type == 'cars_encode':
+         images = images[:, :, 32:224, :]
+     return images, latents
+
+
+ with torch.no_grad():
+     tic = time.time()
+     images, latents = run_on_batch(transformed_image.unsqueeze(0), net)
+     result_image, latent = images[0], latents[0]
+     toc = time.time()
+     print('Inference took {:.4f} seconds.'.format(toc - tic))
+
+ # Display inversion:
+ display_alongside_source_image(tensor2im(result_image), input_image)
+ np.savez('encoder4editing/projected_w.npz', w=latents.cpu().numpy())
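Review note: a minimal sketch of consuming the inversion saved above, assuming the script was run from the repo root (the `w` key matches the np.savez call):

    import numpy as np
    latents = np.load("encoder4editing/projected_w.npz")["w"]
    print(latents.shape)  # (1, 18, 512) for the FFHQ e4e encoder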
encoder4editing/metrics/LEC.py ADDED
@@ -0,0 +1,134 @@
+ import sys
+ import argparse
+ import torch
+ import numpy as np
+ from torch.utils.data import DataLoader
+
+ sys.path.append(".")
+ sys.path.append("..")
+
+ from configs import data_configs
+ from datasets.images_dataset import ImagesDataset
+ from utils.model_utils import setup_model
+
+
+ class LEC:
+     def __init__(self, net, is_cars=False):
+         """
+         Latent Editing Consistency metric as proposed in the main paper.
+         :param net: e4e model loaded over the pSp framework.
+         :param is_cars: An indication as to whether or not to crop the middle of the StyleGAN's output images.
+         """
+         self.net = net
+         self.is_cars = is_cars
+
+     def _encode(self, images):
+         """
+         Encodes the given images into StyleGAN's latent space.
+         :param images: Tensor of shape NxCxHxW representing the images to be encoded.
+         :return: Tensor of shape NxKx512 representing the latent space embeddings of the given image (in W(K, *) space).
+         """
+         codes = self.net.encoder(images)
+         assert codes.ndim == 3, f"Invalid latent codes shape, should be NxKx512 but is {codes.shape}"
+         # normalize with respect to the center of an average face
+         if self.net.opts.start_from_latent_avg:
+             codes = codes + self.net.latent_avg.repeat(codes.shape[0], 1, 1)
+         return codes
+
+     def _generate(self, codes):
+         """
+         Generate the StyleGAN2 images of the given codes
+         :param codes: Tensor of shape NxKx512 representing the StyleGAN's latent codes (in W(K, *) space).
+         :return: Tensor of shape NxCxHxW representing the generated images.
+         """
+         images, _ = self.net.decoder([codes], input_is_latent=True, randomize_noise=False, return_latents=True)
+         images = self.net.face_pool(images)
+         if self.is_cars:
+             images = images[:, :, 32:224, :]
+         return images
+
+     @staticmethod
+     def _filter_outliers(arr):
+         arr = np.array(arr)
+
+         lo = np.percentile(arr, 1, interpolation="lower")
+         hi = np.percentile(arr, 99, interpolation="higher")
+         return np.extract(
+             np.logical_and(lo <= arr, arr <= hi), arr
+         )
+
+     def calculate_metric(self, data_loader, edit_function, inverse_edit_function):
+         """
+         Calculate the LEC metric score.
+         :param data_loader: An iterable that returns a tuple of (images, _), similar to the training data loader.
+         :param edit_function: A function that receives latent codes and performs a semantically meaningful edit in the
+                               latent space.
+         :param inverse_edit_function: A function that receives latent codes and performs the inverse edit of the
+                                       `edit_function` parameter.
+         :return: The LEC metric score.
+         """
+         distances = []
+         with torch.no_grad():
+             for batch in data_loader:
+                 x, _ = batch
+                 inputs = x.to(device).float()
+
+                 codes = self._encode(inputs)
+                 edited_codes = edit_function(codes)
+                 edited_image = self._generate(edited_codes)
+                 edited_image_inversion_codes = self._encode(edited_image)
+                 inverse_edit_codes = inverse_edit_function(edited_image_inversion_codes)
+
+                 dist = (codes - inverse_edit_codes).norm(2, dim=(1, 2)).mean()
+                 distances.append(dist.to("cpu").numpy())
+
+         distances = self._filter_outliers(distances)
+         return distances.mean()
+
+
+ if __name__ == "__main__":
+     device = "cuda"
+
+     parser = argparse.ArgumentParser(description="LEC metric calculator")
+
+     parser.add_argument("--batch", type=int, default=8, help="batch size for the models")
+     parser.add_argument("--images_dir", type=str, default=None,
+                         help="Path to the images directory on which we calculate the LEC score")
+     parser.add_argument("ckpt", metavar="CHECKPOINT", help="path to the model checkpoints")
+
+     args = parser.parse_args()
+     print(args)
+
+     net, opts = setup_model(args.ckpt, device)
+     dataset_args = data_configs.DATASETS[opts.dataset_type]
+     transforms_dict = dataset_args['transforms'](opts).get_transforms()
+
+     images_directory = dataset_args['test_source_root'] if args.images_dir is None else args.images_dir
+     test_dataset = ImagesDataset(source_root=images_directory,
+                                  target_root=images_directory,
+                                  source_transform=transforms_dict['transform_source'],
+                                  target_transform=transforms_dict['transform_test'],
+                                  opts=opts)
+
+     data_loader = DataLoader(test_dataset,
+                              batch_size=args.batch,
+                              shuffle=False,
+                              num_workers=2,
+                              drop_last=True)
+
+     print(f'dataset length: {len(test_dataset)}')
+
+     # In the following example, we are using an InterfaceGAN based editing to calculate the LEC metric.
+     # Change the provided example according to your domain and needs.
+     direction = torch.load('../editings/interfacegan_directions/age.pt').to(device)
+
+     def edit_func_example(codes):
+         return codes + 3 * direction
+
+     def inverse_edit_func_example(codes):
+         return codes - 3 * direction
+
+     lec = LEC(net, is_cars='car' in opts.dataset_type)
+     result = lec.calculate_metric(data_loader, edit_func_example, inverse_edit_func_example)
+     print(f"LEC: {result}")
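Review note: the __main__ block above is wired to the age direction; swapping in another InterFaceGAN direction is a one-line change. A sketch reusing the names defined above (the smile.pt path mirrors the interfacegan_directions folder added in this commit):

    direction = torch.load('../editings/interfacegan_directions/smile.pt').to(device)
    lec = LEC(net, is_cars='car' in opts.dataset_type)
    score = lec.calculate_metric(data_loader,
                                 lambda codes: codes + 3 * direction,
                                 lambda codes: codes - 3 * direction)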
encoder4editing/models/__init__.py ADDED
File without changes
encoder4editing/models/discriminator.py ADDED
@@ -0,0 +1,20 @@
+ from torch import nn
+
+
+ class LatentCodesDiscriminator(nn.Module):
+     def __init__(self, style_dim, n_mlp):
+         super().__init__()
+
+         self.style_dim = style_dim
+
+         layers = []
+         for i in range(n_mlp - 1):
+             layers.append(
+                 nn.Linear(style_dim, style_dim)
+             )
+             layers.append(nn.LeakyReLU(0.2))
+         # Note: the output layer hard-codes 512 input features, so style_dim is implicitly assumed to be 512.
+         layers.append(nn.Linear(512, 1))
+         self.mlp = nn.Sequential(*layers)
+
+     def forward(self, w):
+         return self.mlp(w)
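Review note: a shape sketch for the latent discriminator above (the n_mlp value is illustrative):

    import torch
    D = LatentCodesDiscriminator(style_dim=512, n_mlp=4)
    logits = D(torch.randn(8, 512))  # -> (8, 1), one realness logit per w vector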
encoder4editing/models/encoders/__init__.py ADDED
File without changes
encoder4editing/models/encoders/helpers.py ADDED
@@ -0,0 +1,140 @@
+ from collections import namedtuple
+ import torch
+ import torch.nn.functional as F
+ from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
+
+ """
+ ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
+ """
+
+
+ class Flatten(Module):
+     def forward(self, input):
+         return input.view(input.size(0), -1)
+
+
+ def l2_norm(input, axis=1):
+     norm = torch.norm(input, 2, axis, True)
+     output = torch.div(input, norm)
+     return output
+
+
+ class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
+     """ A named tuple describing a ResNet block. """
+
+
+ def get_block(in_channel, depth, num_units, stride=2):
+     return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
+
+
+ def get_blocks(num_layers):
+     if num_layers == 50:
+         blocks = [
+             get_block(in_channel=64, depth=64, num_units=3),
+             get_block(in_channel=64, depth=128, num_units=4),
+             get_block(in_channel=128, depth=256, num_units=14),
+             get_block(in_channel=256, depth=512, num_units=3)
+         ]
+     elif num_layers == 100:
+         blocks = [
+             get_block(in_channel=64, depth=64, num_units=3),
+             get_block(in_channel=64, depth=128, num_units=13),
+             get_block(in_channel=128, depth=256, num_units=30),
+             get_block(in_channel=256, depth=512, num_units=3)
+         ]
+     elif num_layers == 152:
+         blocks = [
+             get_block(in_channel=64, depth=64, num_units=3),
+             get_block(in_channel=64, depth=128, num_units=8),
+             get_block(in_channel=128, depth=256, num_units=36),
+             get_block(in_channel=256, depth=512, num_units=3)
+         ]
+     else:
+         raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
+     return blocks
+
+
+ class SEModule(Module):
+     def __init__(self, channels, reduction):
+         super(SEModule, self).__init__()
+         self.avg_pool = AdaptiveAvgPool2d(1)
+         self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
+         self.relu = ReLU(inplace=True)
+         self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
+         self.sigmoid = Sigmoid()
+
+     def forward(self, x):
+         module_input = x
+         x = self.avg_pool(x)
+         x = self.fc1(x)
+         x = self.relu(x)
+         x = self.fc2(x)
+         x = self.sigmoid(x)
+         return module_input * x
+
+
+ class bottleneck_IR(Module):
+     def __init__(self, in_channel, depth, stride):
+         super(bottleneck_IR, self).__init__()
+         if in_channel == depth:
+             self.shortcut_layer = MaxPool2d(1, stride)
+         else:
+             self.shortcut_layer = Sequential(
+                 Conv2d(in_channel, depth, (1, 1), stride, bias=False),
+                 BatchNorm2d(depth)
+             )
+         self.res_layer = Sequential(
+             BatchNorm2d(in_channel),
+             Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
+             Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
+         )
+
+     def forward(self, x):
+         shortcut = self.shortcut_layer(x)
+         res = self.res_layer(x)
+         return res + shortcut
+
+
+ class bottleneck_IR_SE(Module):
+     def __init__(self, in_channel, depth, stride):
+         super(bottleneck_IR_SE, self).__init__()
+         if in_channel == depth:
+             self.shortcut_layer = MaxPool2d(1, stride)
+         else:
+             self.shortcut_layer = Sequential(
+                 Conv2d(in_channel, depth, (1, 1), stride, bias=False),
+                 BatchNorm2d(depth)
+             )
+         self.res_layer = Sequential(
+             BatchNorm2d(in_channel),
+             Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
+             PReLU(depth),
+             Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
+             BatchNorm2d(depth),
+             SEModule(depth, 16)
+         )
+
+     def forward(self, x):
+         shortcut = self.shortcut_layer(x)
+         res = self.res_layer(x)
+         return res + shortcut
+
+
+ def _upsample_add(x, y):
+     """Upsample and add two feature maps.
+     Args:
+         x: (Variable) top feature map to be upsampled.
+         y: (Variable) lateral feature map.
+     Returns:
+         (Variable) added feature map.
+     Note that in PyTorch, when the input size is odd, the upsampled feature map
+     produced with `F.upsample(..., scale_factor=2, mode='nearest')`
+     may not equal the lateral feature map size.
+     e.g.
+     original input size: [N,_,15,15] ->
+     conv2d feature map size: [N,_,8,8] ->
+     upsampled feature map size: [N,_,16,16]
+     So we choose bilinear upsampling, which supports arbitrary output sizes.
+     """
+     _, _, H, W = y.size()
+     return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y
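Review note: a quick shape sketch of _upsample_add, which the FPN-style encoders in psp_encoders.py rely on:

    import torch
    x = torch.randn(1, 512, 16, 16)  # coarse feature map
    y = torch.randn(1, 512, 32, 32)  # lateral feature map
    out = _upsample_add(x, y)        # x is bilinearly resized to 32x32, then added -> (1, 512, 32, 32)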
encoder4editing/models/encoders/model_irse.py ADDED
@@ -0,0 +1,84 @@
+ from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
+ from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
+
+ """
+ Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
+ """
+
+
+ class Backbone(Module):
+     def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
+         super(Backbone, self).__init__()
+         assert input_size in [112, 224], "input_size should be 112 or 224"
+         assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
+         assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
+         blocks = get_blocks(num_layers)
+         if mode == 'ir':
+             unit_module = bottleneck_IR
+         elif mode == 'ir_se':
+             unit_module = bottleneck_IR_SE
+         self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
+                                       BatchNorm2d(64),
+                                       PReLU(64))
+         if input_size == 112:
+             self.output_layer = Sequential(BatchNorm2d(512),
+                                            Dropout(drop_ratio),
+                                            Flatten(),
+                                            Linear(512 * 7 * 7, 512),
+                                            BatchNorm1d(512, affine=affine))
+         else:
+             self.output_layer = Sequential(BatchNorm2d(512),
+                                            Dropout(drop_ratio),
+                                            Flatten(),
+                                            Linear(512 * 14 * 14, 512),
+                                            BatchNorm1d(512, affine=affine))
+
+         modules = []
+         for block in blocks:
+             for bottleneck in block:
+                 modules.append(unit_module(bottleneck.in_channel,
+                                            bottleneck.depth,
+                                            bottleneck.stride))
+         self.body = Sequential(*modules)
+
+     def forward(self, x):
+         x = self.input_layer(x)
+         x = self.body(x)
+         x = self.output_layer(x)
+         return l2_norm(x)
+
+
+ def IR_50(input_size):
+     """Constructs an ir-50 model."""
+     model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
+     return model
+
+
+ def IR_101(input_size):
+     """Constructs an ir-101 model."""
+     model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
+     return model
+
+
+ def IR_152(input_size):
+     """Constructs an ir-152 model."""
+     model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
+     return model
+
+
+ def IR_SE_50(input_size):
+     """Constructs an ir_se-50 model."""
+     model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
+     return model
+
+
+ def IR_SE_101(input_size):
+     """Constructs an ir_se-101 model."""
+     model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
+     return model
+
+
+ def IR_SE_152(input_size):
+     """Constructs an ir_se-152 model."""
+     model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
+     return model
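Review note: a minimal instantiation sketch; per the asserts, input_size must be 112 or 224, and the output is an L2-normalized 512-d embedding:

    import torch
    backbone = IR_SE_50(112)
    feats = backbone(torch.randn(2, 3, 112, 112))  # -> (2, 512), unit-norm rows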
encoder4editing/models/encoders/psp_encoders.py ADDED
@@ -0,0 +1,235 @@
+ from enum import Enum
+ import math
+ import numpy as np
+ import torch
+ from torch import nn
+ from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module
+
+ from models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, _upsample_add
+ from models.stylegan2.model import EqualLinear
+
+
+ class ProgressiveStage(Enum):
+     WTraining = 0
+     Delta1Training = 1
+     Delta2Training = 2
+     Delta3Training = 3
+     Delta4Training = 4
+     Delta5Training = 5
+     Delta6Training = 6
+     Delta7Training = 7
+     Delta8Training = 8
+     Delta9Training = 9
+     Delta10Training = 10
+     Delta11Training = 11
+     Delta12Training = 12
+     Delta13Training = 13
+     Delta14Training = 14
+     Delta15Training = 15
+     Delta16Training = 16
+     Delta17Training = 17
+     Inference = 18
+
+
+ class GradualStyleBlock(Module):
+     def __init__(self, in_c, out_c, spatial):
+         super(GradualStyleBlock, self).__init__()
+         self.out_c = out_c
+         self.spatial = spatial
+         num_pools = int(np.log2(spatial))
+         modules = []
+         modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
+                     nn.LeakyReLU()]
+         for i in range(num_pools - 1):
+             modules += [
+                 Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),
+                 nn.LeakyReLU()
+             ]
+         self.convs = nn.Sequential(*modules)
+         self.linear = EqualLinear(out_c, out_c, lr_mul=1)
+
+     def forward(self, x):
+         x = self.convs(x)
+         x = x.view(-1, self.out_c)
+         x = self.linear(x)
+         return x
+
+
+ class GradualStyleEncoder(Module):
+     def __init__(self, num_layers, mode='ir', opts=None):
+         super(GradualStyleEncoder, self).__init__()
+         assert num_layers in [50, 100, 152], 'num_layers should be 50, 100, or 152'
+         assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
+         blocks = get_blocks(num_layers)
+         if mode == 'ir':
+             unit_module = bottleneck_IR
+         elif mode == 'ir_se':
+             unit_module = bottleneck_IR_SE
+         self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
+                                       BatchNorm2d(64),
+                                       PReLU(64))
+         modules = []
+         for block in blocks:
+             for bottleneck in block:
+                 modules.append(unit_module(bottleneck.in_channel,
+                                            bottleneck.depth,
+                                            bottleneck.stride))
+         self.body = Sequential(*modules)
+
+         self.styles = nn.ModuleList()
+         log_size = int(math.log(opts.stylegan_size, 2))
+         self.style_count = 2 * log_size - 2
+         self.coarse_ind = 3
+         self.middle_ind = 7
+         for i in range(self.style_count):
+             if i < self.coarse_ind:
+                 style = GradualStyleBlock(512, 512, 16)
+             elif i < self.middle_ind:
+                 style = GradualStyleBlock(512, 512, 32)
+             else:
+                 style = GradualStyleBlock(512, 512, 64)
+             self.styles.append(style)
+         self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
+         self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
+
+     def forward(self, x):
+         x = self.input_layer(x)
+
+         latents = []
+         modulelist = list(self.body._modules.values())
+         for i, l in enumerate(modulelist):
+             x = l(x)
+             if i == 6:
+                 c1 = x
+             elif i == 20:
+                 c2 = x
+             elif i == 23:
+                 c3 = x
+
+         for j in range(self.coarse_ind):
+             latents.append(self.styles[j](c3))
+
+         p2 = _upsample_add(c3, self.latlayer1(c2))
+         for j in range(self.coarse_ind, self.middle_ind):
+             latents.append(self.styles[j](p2))
+
+         p1 = _upsample_add(p2, self.latlayer2(c1))
+         for j in range(self.middle_ind, self.style_count):
+             latents.append(self.styles[j](p1))
+
+         out = torch.stack(latents, dim=1)
+         return out
+
+
+ class Encoder4Editing(Module):
+     def __init__(self, num_layers, mode='ir', opts=None):
+         super(Encoder4Editing, self).__init__()
+         assert num_layers in [50, 100, 152], 'num_layers should be 50, 100, or 152'
+         assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
+         blocks = get_blocks(num_layers)
+         if mode == 'ir':
+             unit_module = bottleneck_IR
+         elif mode == 'ir_se':
+             unit_module = bottleneck_IR_SE
+         self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
+                                       BatchNorm2d(64),
+                                       PReLU(64))
+         modules = []
+         for block in blocks:
+             for bottleneck in block:
+                 modules.append(unit_module(bottleneck.in_channel,
+                                            bottleneck.depth,
+                                            bottleneck.stride))
+         self.body = Sequential(*modules)
+
+         self.styles = nn.ModuleList()
+         log_size = int(math.log(opts.stylegan_size, 2))
+         self.style_count = 2 * log_size - 2
+         self.coarse_ind = 3
+         self.middle_ind = 7
+
+         for i in range(self.style_count):
+             if i < self.coarse_ind:
+                 style = GradualStyleBlock(512, 512, 16)
+             elif i < self.middle_ind:
+                 style = GradualStyleBlock(512, 512, 32)
+             else:
+                 style = GradualStyleBlock(512, 512, 64)
+             self.styles.append(style)
+
+         self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
+         self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
+
+         self.progressive_stage = ProgressiveStage.Inference
+
+     def get_deltas_starting_dimensions(self):
+         ''' Get a list of the initial dimension of every delta from which it is applied '''
+         return list(range(self.style_count))  # Each dimension has a delta applied to it
+
+     def set_progressive_stage(self, new_stage: ProgressiveStage):
+         self.progressive_stage = new_stage
+         print('Changed progressive stage to: ', new_stage)
+
+     def forward(self, x):
+         x = self.input_layer(x)
+
+         modulelist = list(self.body._modules.values())
+         for i, l in enumerate(modulelist):
+             x = l(x)
+             if i == 6:
+                 c1 = x
+             elif i == 20:
+                 c2 = x
+             elif i == 23:
+                 c3 = x
+
+         # Infer main W and duplicate it
+         w0 = self.styles[0](c3)
+         w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2)
+         stage = self.progressive_stage.value
+         features = c3
+         for i in range(1, min(stage + 1, self.style_count)):  # Infer additional deltas
+             if i == self.coarse_ind:
+                 p2 = _upsample_add(c3, self.latlayer1(c2))  # FPN's middle features
+                 features = p2
+             elif i == self.middle_ind:
+                 p1 = _upsample_add(p2, self.latlayer2(c1))  # FPN's fine features
+                 features = p1
+             delta_i = self.styles[i](features)
+             w[:, i] += delta_i
+         return w
+
+
+ class BackboneEncoderUsingLastLayerIntoW(Module):
+     def __init__(self, num_layers, mode='ir', opts=None):
+         super(BackboneEncoderUsingLastLayerIntoW, self).__init__()
+         print('Using BackboneEncoderUsingLastLayerIntoW')
+         assert num_layers in [50, 100, 152], 'num_layers should be 50, 100, or 152'
+         assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
+         blocks = get_blocks(num_layers)
+         if mode == 'ir':
+             unit_module = bottleneck_IR
+         elif mode == 'ir_se':
+             unit_module = bottleneck_IR_SE
+         self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
+                                       BatchNorm2d(64),
+                                       PReLU(64))
+         self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
+         self.linear = EqualLinear(512, 512, lr_mul=1)
+         modules = []
+         for block in blocks:
+             for bottleneck in block:
+                 modules.append(unit_module(bottleneck.in_channel,
+                                            bottleneck.depth,
+                                            bottleneck.stride))
+         self.body = Sequential(*modules)
+         log_size = int(math.log(opts.stylegan_size, 2))
+         self.style_count = 2 * log_size - 2
+
+     def forward(self, x):
+         x = self.input_layer(x)
+         x = self.body(x)
+         x = self.output_pool(x)
+         x = x.view(-1, 512)
+         x = self.linear(x)
+         return x.repeat(self.style_count, 1, 1).permute(1, 0, 2)
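Review note: the style_count arithmetic above mirrors the StyleGAN2 layer count, e.g. for opts.stylegan_size = 1024 the encoders emit 18 style vectors, matching the 18 delta stages of ProgressiveStage. A sketch:

    import math
    stylegan_size = 1024                                    # assumed generator resolution
    style_count = 2 * int(math.log(stylegan_size, 2)) - 2   # -> 18 w vectors in W+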
encoder4editing/models/latent_codes_pool.py ADDED
@@ -0,0 +1,55 @@
+ import random
+ import torch
+
+
+ class LatentCodesPool:
+     """This class implements a latent codes buffer that stores previously generated w latent codes.
+     This buffer enables us to update discriminators using a history of generated w's
+     rather than the ones produced by the latest encoder.
+     """
+
+     def __init__(self, pool_size):
+         """Initialize the LatentCodesPool class
+         Parameters:
+             pool_size (int) -- the size of the latent codes buffer; if pool_size=0, no buffer is created
+         """
+         self.pool_size = pool_size
+         if self.pool_size > 0:  # create an empty pool
+             self.num_ws = 0
+             self.ws = []
+
+     def query(self, ws):
+         """Return w's from the pool.
+         Parameters:
+             ws: the latest generated w's from the generator
+         Returns w's from the buffer.
+         With 50% probability, the buffer returns the input w's.
+         With 50% probability, the buffer returns w's previously stored in it,
+         and inserts the current w's into the buffer.
+         """
+         if self.pool_size == 0:  # if the buffer size is 0, do nothing
+             return ws
+         return_ws = []
+         for w in ws:  # ws.shape: (batch, 512) or (batch, n_latent, 512)
+             if w.ndim == 2:
+                 i = random.randint(0, len(w) - 1)  # pick a random latent index as a candidate
+                 w = w[i]
+             self.handle_w(w, return_ws)
+         return_ws = torch.stack(return_ws, 0)  # collect all the latent codes and return
+         return return_ws
+
+     def handle_w(self, w, return_ws):
+         if self.num_ws < self.pool_size:  # if the buffer is not full, keep inserting current codes into the buffer
+             self.num_ws = self.num_ws + 1
+             self.ws.append(w)
+             return_ws.append(w)
+         else:
+             p = random.uniform(0, 1)
+             if p > 0.5:  # with 50% chance, return a previously stored latent code and insert the current code into the buffer
+                 random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive
+                 tmp = self.ws[random_id].clone()
+                 self.ws[random_id] = w
+                 return_ws.append(tmp)
+             else:  # with the other 50% chance, return the current code
+                 return_ws.append(w)
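Review note: a usage sketch of the pool during adversarial training (shapes are illustrative). Note that for W+ input of shape (batch, n_latent, 512) the pool keeps one randomly chosen layer code per sample:

    import torch
    pool = LatentCodesPool(pool_size=50)
    w_fake = torch.randn(4, 18, 512)  # latest encoder outputs in W+
    w_for_d = pool.query(w_fake)      # -> (4, 512): one sampled layer code each, mixed with history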
encoder4editing/models/psp.py ADDED
@@ -0,0 +1,100 @@
+ import matplotlib
+
+ matplotlib.use('Agg')
+ import torch
+ from torch import nn
+ from models.encoders import psp_encoders
+ from models.stylegan2.model import Generator
+ from configs.paths_config import model_paths
+
+
+ def get_keys(d, name):
+     if 'state_dict' in d:
+         d = d['state_dict']
+     d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
+     return d_filt
+
+
+ class pSp(nn.Module):
+
+     def __init__(self, opts):
+         super(pSp, self).__init__()
+         self.opts = opts
+         # Define architecture
+         self.encoder = self.set_encoder()
+         self.decoder = Generator(opts.stylegan_size, 512, 8, channel_multiplier=2)
+         self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
+         # Load weights if needed
+         self.load_weights()
+
+     def set_encoder(self):
+         if self.opts.encoder_type == 'GradualStyleEncoder':
+             encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts)
+         elif self.opts.encoder_type == 'Encoder4Editing':
+             encoder = psp_encoders.Encoder4Editing(50, 'ir_se', self.opts)
+         elif self.opts.encoder_type == 'SingleStyleCodeEncoder':
+             encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts)
+         else:
+             raise Exception('{} is not a valid encoder type'.format(self.opts.encoder_type))
+         return encoder
+
+     def load_weights(self):
+         if self.opts.checkpoint_path is not None:
+             print('Loading e4e over the pSp framework from checkpoint: {}'.format(self.opts.checkpoint_path))
+             ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
+             self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
+             self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
+             self.__load_latent_avg(ckpt)
+         else:
+             print('Loading encoders weights from irse50!')
+             encoder_ckpt = torch.load(model_paths['ir_se50'])
+             self.encoder.load_state_dict(encoder_ckpt, strict=False)
+             print('Loading decoder weights from pretrained!')
+             ckpt = torch.load(self.opts.stylegan_weights)
+             self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
+             self.__load_latent_avg(ckpt, repeat=self.encoder.style_count)
+
+     def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
+                 inject_latent=None, return_latents=False, alpha=None):
+         if input_code:
+             codes = x
+         else:
+             codes = self.encoder(x)
+             # normalize with respect to the center of an average face
+             if self.opts.start_from_latent_avg:
+                 if codes.ndim == 2:
+                     codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :]
+                 else:
+                     codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)
+
+         if latent_mask is not None:
+             for i in latent_mask:
+                 if inject_latent is not None:
+                     if alpha is not None:
+                         codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
+                     else:
+                         codes[:, i] = inject_latent[:, i]
+                 else:
+                     codes[:, i] = 0
+
+         input_is_latent = not input_code
+         images, result_latent = self.decoder([codes],
+                                              input_is_latent=input_is_latent,
+                                              randomize_noise=randomize_noise,
+                                              return_latents=return_latents)
+
+         if resize:
+             images = self.face_pool(images)
+
+         if return_latents:
+             return images, result_latent
+         else:
+             return images
+
+     def __load_latent_avg(self, ckpt, repeat=None):
+         if 'latent_avg' in ckpt:
+             self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
+             if repeat is not None:
+                 self.latent_avg = self.latent_avg.repeat(repeat, 1)
+         else:
+             self.latent_avg = None
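Review note: a minimal inference sketch through the wrapper above, mirroring how infer.py calls it (x stands in for a normalized image batch; a CUDA device is assumed):

    import torch
    with torch.no_grad():
        x = torch.randn(1, 3, 256, 256).cuda()
        images, latents = net(x, randomize_noise=False, return_latents=True)
        # images: (1, 3, 256, 256) after face_pool; latents: (1, style_count, 512)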
encoder4editing/models/stylegan2/__init__.py ADDED
File without changes
encoder4editing/models/stylegan2/model.py ADDED
@@ -0,0 +1,673 @@
+ import math
+ import random
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
+
+
+ class PixelNorm(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, input):
+         return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
+
+
+ def make_kernel(k):
+     k = torch.tensor(k, dtype=torch.float32)
+
+     if k.ndim == 1:
+         k = k[None, :] * k[:, None]
+
+     k /= k.sum()
+
+     return k
+
+
+ class Upsample(nn.Module):
+     def __init__(self, kernel, factor=2):
+         super().__init__()
+
+         self.factor = factor
+         kernel = make_kernel(kernel) * (factor ** 2)
+         self.register_buffer('kernel', kernel)
+
+         p = kernel.shape[0] - factor
+
+         pad0 = (p + 1) // 2 + factor - 1
+         pad1 = p // 2
+
+         self.pad = (pad0, pad1)
+
+     def forward(self, input):
+         out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
+
+         return out
+
+
+ class Downsample(nn.Module):
+     def __init__(self, kernel, factor=2):
+         super().__init__()
+
+         self.factor = factor
+         kernel = make_kernel(kernel)
+         self.register_buffer('kernel', kernel)
+
+         p = kernel.shape[0] - factor
+
+         pad0 = (p + 1) // 2
+         pad1 = p // 2
+
+         self.pad = (pad0, pad1)
+
+     def forward(self, input):
+         out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
+
+         return out
+
+
+ class Blur(nn.Module):
+     def __init__(self, kernel, pad, upsample_factor=1):
+         super().__init__()
+
+         kernel = make_kernel(kernel)
+
+         if upsample_factor > 1:
+             kernel = kernel * (upsample_factor ** 2)
+
+         self.register_buffer('kernel', kernel)
+
+         self.pad = pad
+
+     def forward(self, input):
+         out = upfirdn2d(input, self.kernel, pad=self.pad)
+
+         return out
+
+
+ class EqualConv2d(nn.Module):
+     def __init__(
+             self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
+     ):
+         super().__init__()
+
+         self.weight = nn.Parameter(
+             torch.randn(out_channel, in_channel, kernel_size, kernel_size)
+         )
+         self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
+
+         self.stride = stride
+         self.padding = padding
+
+         if bias:
+             self.bias = nn.Parameter(torch.zeros(out_channel))
+
+         else:
+             self.bias = None
+
+     def forward(self, input):
+         out = F.conv2d(
+             input,
+             self.weight * self.scale,
+             bias=self.bias,
+             stride=self.stride,
+             padding=self.padding,
+         )
+
+         return out
+
+     def __repr__(self):
+         return (
+             f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
+             f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
+         )
+
+
+ class EqualLinear(nn.Module):
+     def __init__(
+             self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
+     ):
+         super().__init__()
+
+         self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
+
+         if bias:
+             self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
+
+         else:
+             self.bias = None
+
+         self.activation = activation
+
+         self.scale = (1 / math.sqrt(in_dim)) * lr_mul
+         self.lr_mul = lr_mul
+
+     def forward(self, input):
+         if self.activation:
+             out = F.linear(input, self.weight * self.scale)
+             out = fused_leaky_relu(out, self.bias * self.lr_mul)
+
+         else:
+             out = F.linear(
+                 input, self.weight * self.scale, bias=self.bias * self.lr_mul
+             )
+
+         return out
+
+     def __repr__(self):
+         return (
+             f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
+         )
+
+
+ class ScaledLeakyReLU(nn.Module):
+     def __init__(self, negative_slope=0.2):
+         super().__init__()
+
+         self.negative_slope = negative_slope
+
+     def forward(self, input):
+         out = F.leaky_relu(input, negative_slope=self.negative_slope)
+
+         return out * math.sqrt(2)
+
+
+ class ModulatedConv2d(nn.Module):
+     def __init__(
+             self,
+             in_channel,
+             out_channel,
+             kernel_size,
+             style_dim,
+             demodulate=True,
+             upsample=False,
+             downsample=False,
+             blur_kernel=[1, 3, 3, 1],
+     ):
+         super().__init__()
+
+         self.eps = 1e-8
+         self.kernel_size = kernel_size
+         self.in_channel = in_channel
+         self.out_channel = out_channel
+         self.upsample = upsample
+         self.downsample = downsample
+
+         if upsample:
+             factor = 2
+             p = (len(blur_kernel) - factor) - (kernel_size - 1)
+             pad0 = (p + 1) // 2 + factor - 1
+             pad1 = p // 2 + 1
+
+             self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
+
+         if downsample:
+             factor = 2
+             p = (len(blur_kernel) - factor) + (kernel_size - 1)
+             pad0 = (p + 1) // 2
+             pad1 = p // 2
+
+             self.blur = Blur(blur_kernel, pad=(pad0, pad1))
+
+         fan_in = in_channel * kernel_size ** 2
+         self.scale = 1 / math.sqrt(fan_in)
+         self.padding = kernel_size // 2
+
+         self.weight = nn.Parameter(
+             torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
+         )
+
+         self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
+
+         self.demodulate = demodulate
+
+     def __repr__(self):
+         return (
+             f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
+             f'upsample={self.upsample}, downsample={self.downsample})'
+         )
+
+     def forward(self, input, style):
+         batch, in_channel, height, width = input.shape
+
+         style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
+         weight = self.scale * self.weight * style
+
+         if self.demodulate:
+             demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
+             weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
+
+         weight = weight.view(
+             batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
+         )
+
+         if self.upsample:
+             input = input.view(1, batch * in_channel, height, width)
+             weight = weight.view(
+                 batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
+             )
+             weight = weight.transpose(1, 2).reshape(
+                 batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
+             )
+             out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
+             _, _, height, width = out.shape
+             out = out.view(batch, self.out_channel, height, width)
+             out = self.blur(out)
+
+         elif self.downsample:
+             input = self.blur(input)
+             _, _, height, width = input.shape
+             input = input.view(1, batch * in_channel, height, width)
+             out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
+             _, _, height, width = out.shape
+             out = out.view(batch, self.out_channel, height, width)
+
+         else:
+             input = input.view(1, batch * in_channel, height, width)
+             out = F.conv2d(input, weight, padding=self.padding, groups=batch)
+             _, _, height, width = out.shape
+             out = out.view(batch, self.out_channel, height, width)
+
+         return out
+
+
+ class NoiseInjection(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+         self.weight = nn.Parameter(torch.zeros(1))
+
+     def forward(self, image, noise=None):
+         if noise is None:
+             batch, _, height, width = image.shape
+             noise = image.new_empty(batch, 1, height, width).normal_()
+
+         return image + self.weight * noise
+
+
+ class ConstantInput(nn.Module):
+     def __init__(self, channel, size=4):
+         super().__init__()
+
+         self.input = nn.Parameter(torch.randn(1, channel, size, size))
+
+     def forward(self, input):
+         batch = input.shape[0]
+         out = self.input.repeat(batch, 1, 1, 1)
+
+         return out
+
+
+ class StyledConv(nn.Module):
+     def __init__(
+             self,
+             in_channel,
+             out_channel,
+             kernel_size,
+             style_dim,
+             upsample=False,
+             blur_kernel=[1, 3, 3, 1],
+             demodulate=True,
+     ):
+         super().__init__()
+
+         self.conv = ModulatedConv2d(
+             in_channel,
+             out_channel,
+             kernel_size,
+             style_dim,
+             upsample=upsample,
+             blur_kernel=blur_kernel,
+             demodulate=demodulate,
+         )
+
+         self.noise = NoiseInjection()
+         # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
+         # self.activate = ScaledLeakyReLU(0.2)
+         self.activate = FusedLeakyReLU(out_channel)
+
+     def forward(self, input, style, noise=None):
+         out = self.conv(input, style)
+         out = self.noise(out, noise=noise)
+         # out = out + self.bias
+         out = self.activate(out)
+
+         return out
+
+
+ class ToRGB(nn.Module):
+     def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
+         super().__init__()
+
+         if upsample:
+             self.upsample = Upsample(blur_kernel)
+
+         self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
+         self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
+
+     def forward(self, input, style, skip=None):
+         out = self.conv(input, style)
+         out = out + self.bias
+
+         if skip is not None:
+             skip = self.upsample(skip)
+
+             out = out + skip
+
+         return out
+
+
+ class Generator(nn.Module):
+     def __init__(
+             self,
+             size,
+             style_dim,
+             n_mlp,
+             channel_multiplier=2,
+             blur_kernel=[1, 3, 3, 1],
+             lr_mlp=0.01,
+     ):
+         super().__init__()
+
+         self.size = size
+
+         self.style_dim = style_dim
+
+         layers = [PixelNorm()]
+
+         for i in range(n_mlp):
+             layers.append(
+                 EqualLinear(
+                     style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
+                 )
+             )
+
+         self.style = nn.Sequential(*layers)
+
+         self.channels = {
+             4: 512,
+             8: 512,
+             16: 512,
+             32: 512,
+             64: 256 * channel_multiplier,
+             128: 128 * channel_multiplier,
+             256: 64 * channel_multiplier,
+             512: 32 * channel_multiplier,
+             1024: 16 * channel_multiplier,
+         }
+
+         self.input = ConstantInput(self.channels[4])
+         self.conv1 = StyledConv(
+             self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
+         )
+         self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
+
+         self.log_size = int(math.log(size, 2))
+         self.num_layers = (self.log_size - 2) * 2 + 1
+
+         self.convs = nn.ModuleList()
+         self.upsamples = nn.ModuleList()
+         self.to_rgbs = nn.ModuleList()
+         self.noises = nn.Module()
+
+         in_channel = self.channels[4]
+
+         for layer_idx in range(self.num_layers):
+             res = (layer_idx + 5) // 2
+             shape = [1, 1, 2 ** res, 2 ** res]
+             self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
+
+         for i in range(3, self.log_size + 1):
+             out_channel = self.channels[2 ** i]
+
+             self.convs.append(
+                 StyledConv(
+                     in_channel,
+                     out_channel,
+                     3,
+                     style_dim,
+                     upsample=True,
+                     blur_kernel=blur_kernel,
+                 )
+             )
+
+             self.convs.append(
+                 StyledConv(
+                     out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
+                 )
+             )
+
+             self.to_rgbs.append(ToRGB(out_channel, style_dim))
+
+             in_channel = out_channel
+
+         self.n_latent = self.log_size * 2 - 2
+
+     def make_noise(self):
+         device = self.input.input.device
+
+         noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
+
+         for i in range(3, self.log_size + 1):
+             for _ in range(2):
+                 noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
+
+         return noises
+
+     def mean_latent(self, n_latent):
+         latent_in = torch.randn(
+             n_latent, self.style_dim, device=self.input.input.device
+         )
+         latent = self.style(latent_in).mean(0, keepdim=True)
+
+         return latent
+
+     def get_latent(self, input):
+         return self.style(input)
+
+     def forward(
+             self,
+             styles,
+             return_latents=False,
+             return_features=False,
+             inject_index=None,
+             truncation=1,
+             truncation_latent=None,
+             input_is_latent=False,
+             noise=None,
+             randomize_noise=True,
+     ):
+         if not input_is_latent:
+             styles = [self.style(s) for s in styles]
+
+         if noise is None:
+             if randomize_noise:
+                 noise = [None] * self.num_layers
+             else:
+                 noise = [
+                     getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
+                 ]
+
+         if truncation < 1:
+             style_t = []
+
+             for style in styles:
+                 style_t.append(
+                     truncation_latent + truncation * (style - truncation_latent)
+                 )
+
+             styles = style_t
+
+         if len(styles) < 2:
+             inject_index = self.n_latent
+
+             if styles[0].ndim < 3:
+                 latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+             else:
+                 latent = styles[0]
+
+         else:
+             if inject_index is None:
+                 inject_index = random.randint(1, self.n_latent - 1)
+
+             latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+             latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
+
+             latent = torch.cat([latent, latent2], 1)
+
+         out = self.input(latent)
+         out = self.conv1(out, latent[:, 0], noise=noise[0])
+
+         skip = self.to_rgb1(out, latent[:, 1])
+
+         i = 1
+         for conv1, conv2, noise1, noise2, to_rgb in zip(
+                 self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
+         ):
+             out = conv1(out, latent[:, i], noise=noise1)
+             out = conv2(out, latent[:, i + 1], noise=noise2)
+             skip = to_rgb(out, latent[:, i + 2], skip)
+
+             i += 2
+
+         image = skip
+
+         if return_latents:
+             return image, latent
+         elif return_features:
+             return image, out
+         else:
+             return image, None
+
+
+ class ConvLayer(nn.Sequential):
+     def __init__(
+             self,
+             in_channel,
+             out_channel,
+             kernel_size,
+             downsample=False,
+             blur_kernel=[1, 3, 3, 1],
+             bias=True,
+             activate=True,
+     ):
+         layers = []
+
+         if downsample:
+             factor = 2
+             p = (len(blur_kernel) - factor) + (kernel_size - 1)
+             pad0 = (p + 1) // 2
+             pad1 = p // 2
+
+             layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
+
+             stride = 2
+             self.padding = 0
+
+         else:
+             stride = 1
+             self.padding = kernel_size // 2
+
+         layers.append(
+             EqualConv2d(
+                 in_channel,
+                 out_channel,
+                 kernel_size,
+                 padding=self.padding,
+                 stride=stride,
+                 bias=bias and not activate,
+             )
+         )
+
+         if activate:
+             if bias:
+                 layers.append(FusedLeakyReLU(out_channel))
+
+             else:
+                 layers.append(ScaledLeakyReLU(0.2))
+
+         super().__init__(*layers)
+
+
+ class ResBlock(nn.Module):
+     def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
+         super().__init__()
+
+         self.conv1 = ConvLayer(in_channel, in_channel, 3)
+         self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
+
+         self.skip = ConvLayer(
+             in_channel, out_channel, 1, downsample=True, activate=False, bias=False
+         )
+
+     def forward(self, input):
+         out = self.conv1(input)
+         out = self.conv2(out)
+
+         skip = self.skip(input)
+         out = (out + skip) / math.sqrt(2)
+
+         return out
+
+
+ class Discriminator(nn.Module):
+     def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
+         super().__init__()
+
+         channels = {
+             4: 512,
+             8: 512,
+             16: 512,
+             32: 512,
+             64: 256 * channel_multiplier,
+             128: 128 * channel_multiplier,
+             256: 64 * channel_multiplier,
+             512: 32 * channel_multiplier,
+             1024: 16 * channel_multiplier,
+         }
+
+         convs = [ConvLayer(3, channels[size], 1)]
+
+         log_size = int(math.log(size, 2))
+
+         in_channel = channels[size]
+
+         for i in range(log_size, 2, -1):
+             out_channel = channels[2 ** (i - 1)]
+
+             convs.append(ResBlock(in_channel, out_channel, blur_kernel))
+
+             in_channel = out_channel
+
+         self.convs = nn.Sequential(*convs)
+
+         self.stddev_group = 4
+         self.stddev_feat = 1
+
+         self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
+         self.final_linear = nn.Sequential(
+             EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
+             EqualLinear(channels[4], 1),
+         )
+
+     def forward(self, input):
+         out = self.convs(input)
+
+         batch, channel, height, width = out.shape
+         group = min(batch, self.stddev_group)
+         stddev = out.view(
+             group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
+         )
+         stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
+         stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
+         stddev = stddev.repeat(group, 1, height, width)
+         out = torch.cat([out, stddev], 1)
+
+         out = self.final_conv(out)
+
+         out = out.view(batch, -1)
+         out = self.final_linear(out)
+
+         return out
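Review note: a minimal generation sketch for the Generator above at the 1024 (FFHQ) configuration; a CUDA device is assumed, since the custom ops below JIT-compile against one:

    import torch
    g = Generator(1024, 512, 8).cuda()
    z = torch.randn(1, 512).cuda()
    img, _ = g([z])                     # z is mapped to w by the style MLP, then synthesized
    w = g.get_latent(z)                 # (1, 512) w vector
    img2, w_plus = g([w], input_is_latent=True, return_latents=True)  # w broadcast to all 18 layers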
encoder4editing/models/stylegan2/op/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu
+ from .upfirdn2d import upfirdn2d
encoder4editing/models/stylegan2/op/fused_act.py ADDED
@@ -0,0 +1,85 @@
+ import os
+
+ import torch
+ from torch import nn
+ from torch.autograd import Function
+ from torch.utils.cpp_extension import load
+
+ module_path = os.path.dirname(__file__)
+ fused = load(
+     'fused',
+     sources=[
+         os.path.join(module_path, 'fused_bias_act.cpp'),
+         os.path.join(module_path, 'fused_bias_act_kernel.cu'),
+     ],
+ )
+
+
+ class FusedLeakyReLUFunctionBackward(Function):
+     @staticmethod
+     def forward(ctx, grad_output, out, negative_slope, scale):
+         ctx.save_for_backward(out)
+         ctx.negative_slope = negative_slope
+         ctx.scale = scale
+
+         empty = grad_output.new_empty(0)
+
+         grad_input = fused.fused_bias_act(
+             grad_output, empty, out, 3, 1, negative_slope, scale
+         )
+
+         dim = [0]
+
+         if grad_input.ndim > 2:
+             dim += list(range(2, grad_input.ndim))
+
+         grad_bias = grad_input.sum(dim).detach()
+
+         return grad_input, grad_bias
+
+     @staticmethod
+     def backward(ctx, gradgrad_input, gradgrad_bias):
+         out, = ctx.saved_tensors
+         gradgrad_out = fused.fused_bias_act(
+             gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
+         )
+
+         return gradgrad_out, None, None, None
+
+
+ class FusedLeakyReLUFunction(Function):
+     @staticmethod
+     def forward(ctx, input, bias, negative_slope, scale):
+         empty = input.new_empty(0)
+         out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
+         ctx.save_for_backward(out)
+         ctx.negative_slope = negative_slope
+         ctx.scale = scale
+
+         return out
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         out, = ctx.saved_tensors
+
+         grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
+             grad_output, out, ctx.negative_slope, ctx.scale
+         )
+
+         return grad_input, grad_bias, None, None
+
+
+ class FusedLeakyReLU(nn.Module):
+     def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
+         super().__init__()
+
+         self.bias = nn.Parameter(torch.zeros(channel))
+         self.negative_slope = negative_slope
+         self.scale = scale
+
+     def forward(self, input):
+         return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
+     return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
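Review note: the load() call above JIT-compiles the CUDA extension on first import, so this module needs nvcc and a GPU. The fused op's reference semantics, as a sketch (act=3 with the default slope and scale):

    import torch
    import torch.nn.functional as F
    x = torch.randn(2, 8, 4, 4, device='cuda')
    b = torch.randn(8, device='cuda')
    ref = F.leaky_relu(x + b.view(1, -1, 1, 1), 0.2) * (2 ** 0.5)
    assert torch.allclose(fused_leaky_relu(x, b), ref, atol=1e-5)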
encoder4editing/models/stylegan2/op/fused_bias_act.cpp ADDED
@@ -0,0 +1,21 @@
+ #include <torch/extension.h>
+
+
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+     int act, int grad, float alpha, float scale);
+
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+ torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+     int act, int grad, float alpha, float scale) {
+     CHECK_CUDA(input);
+     CHECK_CUDA(bias);
+
+     return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
+ }
encoder4editing/models/stylegan2/op/fused_bias_act_kernel.cu ADDED
@@ -0,0 +1,99 @@
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+ //
+ // This work is made available under the Nvidia Source Code License-NC.
+ // To view a copy of this license, visit
+ // https://nvlabs.github.io/stylegan2/license.html
+
+ #include <torch/types.h>
+
+ #include <ATen/ATen.h>
+ #include <ATen/AccumulateType.h>
+ #include <ATen/cuda/CUDAContext.h>
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
+
+ #include <cuda.h>
+ #include <cuda_runtime.h>
+
+
+ template <typename scalar_t>
+ static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
+     int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
+     int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
+
+     scalar_t zero = 0.0;
+
+     for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
+         scalar_t x = p_x[xi];
+
+         if (use_bias) {
+             x += p_b[(xi / step_b) % size_b];
+         }
+
+         scalar_t ref = use_ref ? p_ref[xi] : zero;
+
+         scalar_t y;
+
+         // The switch key is act * 10 + grad: act 1 = linear, act 3 = leaky ReLU;
+         // grad 0 = forward value, grad 1 = first derivative (gated by the saved
+         // reference activation), grad 2 = second derivative.
+         switch (act * 10 + grad) {
+             default:
+             case 10: y = x; break;
+             case 11: y = x; break;
+             case 12: y = 0.0; break;
+
+             case 30: y = (x > 0.0) ? x : x * alpha; break;
+             case 31: y = (ref > 0.0) ? x : x * alpha; break;
+             case 32: y = 0.0; break;
+         }
+
+         out[xi] = y * scale;
+     }
+ }
+
+
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+     int act, int grad, float alpha, float scale) {
+     int curDevice = -1;
+     cudaGetDevice(&curDevice);
+     cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+     auto x = input.contiguous();
+     auto b = bias.contiguous();
+     auto ref = refer.contiguous();
+
+     int use_bias = b.numel() ? 1 : 0;
+     int use_ref = ref.numel() ? 1 : 0;
+
+     int size_x = x.numel();
+     int size_b = b.numel();
+     int step_b = 1;
+
+     // The bias is broadcast along dim 1 (channels): step_b = product of the trailing dims.
+     for (int i = 1 + 1; i < x.dim(); i++) {
+         step_b *= x.size(i);
+     }
+
+     int loop_x = 4;
+     int block_size = 4 * 32;
+     int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
+
+     auto y = torch::empty_like(x);
+
+     AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
+         fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
+             y.data_ptr<scalar_t>(),
+             x.data_ptr<scalar_t>(),
+             b.data_ptr<scalar_t>(),
+             ref.data_ptr<scalar_t>(),
+             act,
+             grad,
+             alpha,
+             scale,
+             loop_x,
+             size_x,
+             step_b,
+             size_b,
+             use_bias,
+             use_ref
+         );
+     });
+
+     return y;
+ }
encoder4editing/models/stylegan2/op/upfirdn2d.cpp ADDED
@@ -0,0 +1,23 @@
+ #include <torch/extension.h>
+
+
+ torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
+     int up_x, int up_y, int down_x, int down_y,
+     int pad_x0, int pad_x1, int pad_y0, int pad_y1);
+
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+ torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
+     int up_x, int up_y, int down_x, int down_y,
+     int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
+     CHECK_CUDA(input);
+     CHECK_CUDA(kernel);
+
+     return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
+ }
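Review note: upfirdn2d is the classic upsample -> FIR filter -> downsample primitive. A conceptual (slow) plain-PyTorch reference, assuming non-negative pads; it is a sketch for reading the CUDA op, not a drop-in replacement:

    import torch.nn.functional as F

    def upfirdn2d_ref(x, k, up=1, down=1, pad=(0, 0)):
        b, c, h, w = x.shape
        x = x.view(b, c, h, 1, w, 1)                     # zero-stuff upsample by `up`
        x = F.pad(x, [0, up - 1, 0, 0, 0, up - 1]).view(b, c, h * up, w * up)
        x = F.pad(x, [pad[0], pad[1], pad[0], pad[1]])   # symmetric pad0/pad1 on both axes
        weight = k.flip([0, 1]).view(1, 1, *k.shape).repeat(c, 1, 1, 1)
        x = F.conv2d(x, weight, groups=c)                # per-channel FIR convolution
        return x[:, :, ::down, ::down]                   # stride-`down` decimation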