d. nye commited on
Commit
67f64a6
1 Parent(s): d0919c2

Initial release

Browse files
Files changed (2) hide show
  1. app.py +624 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,624 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Huggingface.co UI Script
2
+ # Using gradio to present a simple UI to select a random seed and generate an NFT
3
+
4
+ import sys
5
+ from subprocess import call
6
+ def run_cmd(command):
7
+ try:
8
+ print(command)
9
+ call(command, shell=True)
10
+ except KeyboardInterrupt:
11
+ print("Process interrupted")
12
+ sys.exit(1)
13
+
14
+ print("⬇️ Installing latest gradio==2.4.7b9")
15
+ run_cmd("pip install --upgrade pip")
16
+ run_cmd('pip install gradio==2.4.7b9')
17
+
18
+ import gradio as gr
19
+ import os
20
+ import random
21
+ import math
22
+ import numpy as np
23
+ import matplotlib.pyplot as plt
24
+
25
+ from enum import Enum
26
+ from glob import glob
27
+ from functools import partial
28
+
29
+ import tensorflow as tf
30
+ from tensorflow import keras
31
+ from tensorflow.keras import layers
32
+ from tensorflow.keras.models import Sequential
33
+ from tensorflow_addons.layers import InstanceNormalization
34
+
35
+ import tensorflow_datasets as tfds
36
+
37
+ # Model Definition
38
+
39
+ def log2(x):
40
+ return int(np.log2(x))
41
+
42
+
43
+ def resize_image(res, sample):
44
+ print("Call resize_image...")
45
+ image = sample["image"]
46
+ # only donwsampling, so use nearest neighbor that is faster to run
47
+ image = tf.image.resize(
48
+ image, (res, res), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
49
+ )
50
+ image = tf.cast(image, tf.float32) / 127.5 - 1.0
51
+ return image
52
+
53
+
54
+ def create_dataloader(res):
55
+ batch_size = batch_sizes[log2(res)]
56
+ dl = ds_train.map(partial(resize_image, res), num_parallel_calls=tf.data.AUTOTUNE)
57
+ dl = dl.shuffle(200).batch(batch_size, drop_remainder=True).prefetch(1).repeat()
58
+ return dl
59
+
60
+ def fade_in(alpha, a, b):
61
+ return alpha * a + (1.0 - alpha) * b
62
+
63
+
64
+ def wasserstein_loss(y_true, y_pred):
65
+ return -tf.reduce_mean(y_true * y_pred)
66
+
67
+
68
+ def pixel_norm(x, epsilon=1e-8):
69
+ return x / tf.math.sqrt(tf.reduce_mean(x ** 2, axis=-1, keepdims=True) + epsilon)
70
+
71
+
72
+ def minibatch_std(input_tensor, epsilon=1e-8):
73
+ n, h, w, c = tf.shape(input_tensor)
74
+ group_size = tf.minimum(4, n)
75
+ x = tf.reshape(input_tensor, [group_size, -1, h, w, c])
76
+ group_mean, group_var = tf.nn.moments(x, axes=(0), keepdims=False)
77
+ group_std = tf.sqrt(group_var + epsilon)
78
+ avg_std = tf.reduce_mean(group_std, axis=[1, 2, 3], keepdims=True)
79
+ x = tf.tile(avg_std, [group_size, h, w, 1])
80
+ return tf.concat([input_tensor, x], axis=-1)
81
+
82
+
83
+ class EqualizedConv(layers.Layer):
84
+ def __init__(self, out_channels, kernel=3, gain=2, **kwargs):
85
+ super(EqualizedConv, self).__init__(**kwargs)
86
+ self.kernel = kernel
87
+ self.out_channels = out_channels
88
+ self.gain = gain
89
+ self.pad = kernel != 1
90
+
91
+ def build(self, input_shape):
92
+ self.in_channels = input_shape[-1]
93
+ initializer = keras.initializers.RandomNormal(mean=0.0, stddev=1.0)
94
+ self.w = self.add_weight(
95
+ shape=[self.kernel, self.kernel, self.in_channels, self.out_channels],
96
+ initializer=initializer,
97
+ trainable=True,
98
+ name="kernel",
99
+ )
100
+ self.b = self.add_weight(
101
+ shape=(self.out_channels,), initializer="zeros", trainable=True, name="bias"
102
+ )
103
+ fan_in = self.kernel * self.kernel * self.in_channels
104
+ self.scale = tf.sqrt(self.gain / fan_in)
105
+
106
+ def call(self, inputs):
107
+ if self.pad:
108
+ x = tf.pad(inputs, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="REFLECT")
109
+ else:
110
+ x = inputs
111
+ output = (
112
+ tf.nn.conv2d(x, self.scale * self.w, strides=1, padding="VALID") + self.b
113
+ )
114
+ return output
115
+
116
+
117
+ class EqualizedDense(layers.Layer):
118
+ def __init__(self, units, gain=2, learning_rate_multiplier=1, **kwargs):
119
+ super(EqualizedDense, self).__init__(**kwargs)
120
+ self.units = units
121
+ self.gain = gain
122
+ self.learning_rate_multiplier = learning_rate_multiplier
123
+
124
+ def build(self, input_shape):
125
+ self.in_channels = input_shape[-1]
126
+ initializer = keras.initializers.RandomNormal(
127
+ mean=0.0, stddev=1.0 / self.learning_rate_multiplier
128
+ )
129
+ self.w = self.add_weight(
130
+ shape=[self.in_channels, self.units],
131
+ initializer=initializer,
132
+ trainable=True,
133
+ name="kernel",
134
+ )
135
+ self.b = self.add_weight(
136
+ shape=(self.units,), initializer="zeros", trainable=True, name="bias"
137
+ )
138
+ fan_in = self.in_channels
139
+ self.scale = tf.sqrt(self.gain / fan_in)
140
+
141
+ def call(self, inputs):
142
+ output = tf.add(tf.matmul(inputs, self.scale * self.w), self.b)
143
+ return output * self.learning_rate_multiplier
144
+
145
+
146
+ class AddNoise(layers.Layer):
147
+ def build(self, input_shape):
148
+ n, h, w, c = input_shape[0]
149
+ initializer = keras.initializers.RandomNormal(mean=0.0, stddev=1.0)
150
+ self.b = self.add_weight(
151
+ shape=[1, 1, 1, c], initializer=initializer, trainable=True, name="kernel"
152
+ )
153
+
154
+ def call(self, inputs):
155
+ x, noise = inputs
156
+ output = x + self.b * noise
157
+ return output
158
+
159
+
160
+ class AdaIN(layers.Layer):
161
+ def __init__(self, gain=1, **kwargs):
162
+ super(AdaIN, self).__init__(**kwargs)
163
+ self.gain = gain
164
+
165
+ def build(self, input_shapes):
166
+ x_shape = input_shapes[0]
167
+ w_shape = input_shapes[1]
168
+
169
+ self.w_channels = w_shape[-1]
170
+ self.x_channels = x_shape[-1]
171
+
172
+ self.dense_1 = EqualizedDense(self.x_channels, gain=1)
173
+ self.dense_2 = EqualizedDense(self.x_channels, gain=1)
174
+
175
+ def call(self, inputs):
176
+ x, w = inputs
177
+ ys = tf.reshape(self.dense_1(w), (-1, 1, 1, self.x_channels))
178
+ yb = tf.reshape(self.dense_2(w), (-1, 1, 1, self.x_channels))
179
+ return ys * x + yb
180
+
181
+ def Mapping(num_stages, input_shape=512):
182
+ z = layers.Input(shape=(input_shape))
183
+ w = pixel_norm(z)
184
+ for i in range(8):
185
+ w = EqualizedDense(512, learning_rate_multiplier=0.01)(w)
186
+ w = layers.LeakyReLU(0.2)(w)
187
+ w = tf.tile(tf.expand_dims(w, 1), (1, num_stages, 1))
188
+ return keras.Model(z, w, name="mapping")
189
+
190
+
191
+ class Generator:
192
+ def __init__(self, start_res_log2, target_res_log2):
193
+ self.start_res_log2 = start_res_log2
194
+ self.target_res_log2 = target_res_log2
195
+ self.num_stages = target_res_log2 - start_res_log2 + 1
196
+ # list of generator blocks at increasing resolution
197
+ self.g_blocks = []
198
+ # list of layers to convert g_block activation to RGB
199
+ self.to_rgb = []
200
+ # list of noise input of different resolutions into g_blocks
201
+ self.noise_inputs = []
202
+ # filter size to use at each stage, keys are log2(resolution)
203
+ self.filter_nums = {
204
+ 0: 512,
205
+ 1: 512,
206
+ 2: 512, # 4x4
207
+ 3: 512, # 8x8
208
+ 4: 512, # 16x16
209
+ 5: 512, # 32x32
210
+ 6: 256, # 64x64
211
+ 7: 128, # 128x128
212
+ 8: 64, # 256x256
213
+ 9: 32, # 512x512
214
+ 10: 16,
215
+ } # 1024x1024
216
+
217
+ start_res = 2 ** start_res_log2
218
+ self.input_shape = (start_res, start_res, self.filter_nums[start_res_log2])
219
+ self.g_input = layers.Input(self.input_shape, name="generator_input")
220
+
221
+ for i in range(start_res_log2, target_res_log2 + 1):
222
+ filter_num = self.filter_nums[i]
223
+ res = 2 ** i
224
+ self.noise_inputs.append(
225
+ layers.Input(shape=(res, res, 1), name=f"noise_{res}x{res}")
226
+ )
227
+ to_rgb = Sequential(
228
+ [
229
+ layers.InputLayer(input_shape=(res, res, filter_num)),
230
+ EqualizedConv(3, 1, gain=1),
231
+ ],
232
+ name=f"to_rgb_{res}x{res}",
233
+ )
234
+ self.to_rgb.append(to_rgb)
235
+ is_base = i == self.start_res_log2
236
+ if is_base:
237
+ input_shape = (res, res, self.filter_nums[i - 1])
238
+ else:
239
+ input_shape = (2 ** (i - 1), 2 ** (i - 1), self.filter_nums[i - 1])
240
+ g_block = self.build_block(
241
+ filter_num, res=res, input_shape=input_shape, is_base=is_base
242
+ )
243
+ self.g_blocks.append(g_block)
244
+
245
+ def build_block(self, filter_num, res, input_shape, is_base):
246
+ input_tensor = layers.Input(shape=input_shape, name=f"g_{res}")
247
+ noise = layers.Input(shape=(res, res, 1), name=f"noise_{res}")
248
+ w = layers.Input(shape=512)
249
+ x = input_tensor
250
+
251
+ if not is_base:
252
+ x = layers.UpSampling2D((2, 2))(x)
253
+ x = EqualizedConv(filter_num, 3)(x)
254
+
255
+ x = AddNoise()([x, noise])
256
+ x = layers.LeakyReLU(0.2)(x)
257
+ x = InstanceNormalization()(x)
258
+ x = AdaIN()([x, w])
259
+
260
+ x = EqualizedConv(filter_num, 3)(x)
261
+ x = AddNoise()([x, noise])
262
+ x = layers.LeakyReLU(0.2)(x)
263
+ x = InstanceNormalization()(x)
264
+ x = AdaIN()([x, w])
265
+ return keras.Model([input_tensor, w, noise], x, name=f"genblock_{res}x{res}")
266
+
267
+ def grow(self, res_log2):
268
+ res = 2 ** res_log2
269
+
270
+ num_stages = res_log2 - self.start_res_log2 + 1
271
+ w = layers.Input(shape=(self.num_stages, 512), name="w")
272
+
273
+ alpha = layers.Input(shape=(1), name="g_alpha")
274
+ x = self.g_blocks[0]([self.g_input, w[:, 0], self.noise_inputs[0]])
275
+
276
+ if num_stages == 1:
277
+ rgb = self.to_rgb[0](x)
278
+ else:
279
+ for i in range(1, num_stages - 1):
280
+
281
+ x = self.g_blocks[i]([x, w[:, i], self.noise_inputs[i]])
282
+
283
+ old_rgb = self.to_rgb[num_stages - 2](x)
284
+ old_rgb = layers.UpSampling2D((2, 2))(old_rgb)
285
+
286
+ i = num_stages - 1
287
+ x = self.g_blocks[i]([x, w[:, i], self.noise_inputs[i]])
288
+
289
+ new_rgb = self.to_rgb[i](x)
290
+
291
+ rgb = fade_in(alpha[0], new_rgb, old_rgb)
292
+
293
+ return keras.Model(
294
+ [self.g_input, w, self.noise_inputs, alpha],
295
+ rgb,
296
+ name=f"generator_{res}_x_{res}",
297
+ )
298
+
299
+
300
+ class Discriminator:
301
+ def __init__(self, start_res_log2, target_res_log2):
302
+ self.start_res_log2 = start_res_log2
303
+ self.target_res_log2 = target_res_log2
304
+ self.num_stages = target_res_log2 - start_res_log2 + 1
305
+ # filter size to use at each stage, keys are log2(resolution)
306
+ self.filter_nums = {
307
+ 0: 512,
308
+ 1: 512,
309
+ 2: 512, # 4x4
310
+ 3: 512, # 8x8
311
+ 4: 512, # 16x16
312
+ 5: 512, # 32x32
313
+ 6: 256, # 64x64
314
+ 7: 128, # 128x128
315
+ 8: 64, # 256x256
316
+ 9: 32, # 512x512
317
+ 10: 16,
318
+ } # 1024x1024
319
+ # list of discriminator blocks at increasing resolution
320
+ self.d_blocks = []
321
+ # list of layers to convert RGB into activation for d_blocks inputs
322
+ self.from_rgb = []
323
+
324
+ for res_log2 in range(self.start_res_log2, self.target_res_log2 + 1):
325
+ res = 2 ** res_log2
326
+ filter_num = self.filter_nums[res_log2]
327
+ from_rgb = Sequential(
328
+ [
329
+ layers.InputLayer(
330
+ input_shape=(res, res, 3), name=f"from_rgb_input_{res}"
331
+ ),
332
+ EqualizedConv(filter_num, 1),
333
+ layers.LeakyReLU(0.2),
334
+ ],
335
+ name=f"from_rgb_{res}",
336
+ )
337
+
338
+ self.from_rgb.append(from_rgb)
339
+
340
+ input_shape = (res, res, filter_num)
341
+ if len(self.d_blocks) == 0:
342
+ d_block = self.build_base(filter_num, res)
343
+ else:
344
+ d_block = self.build_block(
345
+ filter_num, self.filter_nums[res_log2 - 1], res
346
+ )
347
+
348
+ self.d_blocks.append(d_block)
349
+
350
+ def build_base(self, filter_num, res):
351
+ input_tensor = layers.Input(shape=(res, res, filter_num), name=f"d_{res}")
352
+ x = minibatch_std(input_tensor)
353
+ x = EqualizedConv(filter_num, 3)(x)
354
+ x = layers.LeakyReLU(0.2)(x)
355
+ x = layers.Flatten()(x)
356
+ x = EqualizedDense(filter_num)(x)
357
+ x = layers.LeakyReLU(0.2)(x)
358
+ x = EqualizedDense(1)(x)
359
+ return keras.Model(input_tensor, x, name=f"d_{res}")
360
+
361
+ def build_block(self, filter_num_1, filter_num_2, res):
362
+ input_tensor = layers.Input(shape=(res, res, filter_num_1), name=f"d_{res}")
363
+ x = EqualizedConv(filter_num_1, 3)(input_tensor)
364
+ x = layers.LeakyReLU(0.2)(x)
365
+ x = EqualizedConv(filter_num_2)(x)
366
+ x = layers.LeakyReLU(0.2)(x)
367
+ x = layers.AveragePooling2D((2, 2))(x)
368
+ return keras.Model(input_tensor, x, name=f"d_{res}")
369
+
370
+ def grow(self, res_log2):
371
+ res = 2 ** res_log2
372
+ idx = res_log2 - self.start_res_log2
373
+ alpha = layers.Input(shape=(1), name="d_alpha")
374
+ input_image = layers.Input(shape=(res, res, 3), name="input_image")
375
+ x = self.from_rgb[idx](input_image)
376
+ x = self.d_blocks[idx](x)
377
+ if idx > 0:
378
+ idx -= 1
379
+ downsized_image = layers.AveragePooling2D((2, 2))(input_image)
380
+ y = self.from_rgb[idx](downsized_image)
381
+ x = fade_in(alpha[0], x, y)
382
+
383
+ for i in range(idx, -1, -1):
384
+ x = self.d_blocks[i](x)
385
+ return keras.Model([input_image, alpha], x, name=f"discriminator_{res}_x_{res}")
386
+
387
+ class StyleGAN(tf.keras.Model):
388
+ def __init__(self, z_dim=512, target_res=64, start_res=4):
389
+ super(StyleGAN, self).__init__()
390
+ self.z_dim = z_dim
391
+
392
+ self.target_res_log2 = log2(target_res)
393
+ self.start_res_log2 = log2(start_res)
394
+ self.current_res_log2 = self.target_res_log2
395
+ self.num_stages = self.target_res_log2 - self.start_res_log2 + 1
396
+
397
+ self.alpha = tf.Variable(1.0, dtype=tf.float32, trainable=False, name="alpha")
398
+
399
+ self.mapping = Mapping(num_stages=self.num_stages)
400
+ self.d_builder = Discriminator(self.start_res_log2, self.target_res_log2)
401
+ self.g_builder = Generator(self.start_res_log2, self.target_res_log2)
402
+ self.g_input_shape = self.g_builder.input_shape
403
+
404
+ self.phase = None
405
+ self.train_step_counter = tf.Variable(0, dtype=tf.int32, trainable=False)
406
+
407
+ self.loss_weights = {"gradient_penalty": 10, "drift": 0.001}
408
+
409
+ def grow_model(self, res):
410
+ tf.keras.backend.clear_session()
411
+ res_log2 = log2(res)
412
+ self.generator = self.g_builder.grow(res_log2)
413
+ self.discriminator = self.d_builder.grow(res_log2)
414
+ self.current_res_log2 = res_log2
415
+ print(f"\nModel resolution:{res}x{res}")
416
+
417
+ def compile(
418
+ self, steps_per_epoch, phase, res, d_optimizer, g_optimizer, *args, **kwargs
419
+ ):
420
+ self.loss_weights = kwargs.pop("loss_weights", self.loss_weights)
421
+ self.steps_per_epoch = steps_per_epoch
422
+ if res != 2 ** self.current_res_log2:
423
+ self.grow_model(res)
424
+ self.d_optimizer = d_optimizer
425
+ self.g_optimizer = g_optimizer
426
+
427
+ self.train_step_counter.assign(0)
428
+ self.phase = phase
429
+ self.d_loss_metric = keras.metrics.Mean(name="d_loss")
430
+ self.g_loss_metric = keras.metrics.Mean(name="g_loss")
431
+ super(StyleGAN, self).compile(*args, **kwargs)
432
+
433
+ @property
434
+ def metrics(self):
435
+ return [self.d_loss_metric, self.g_loss_metric]
436
+
437
+ def generate_noise(self, batch_size):
438
+ noise = [
439
+ tf.random.normal((batch_size, 2 ** res, 2 ** res, 1))
440
+ for res in range(self.start_res_log2, self.target_res_log2 + 1)
441
+ ]
442
+ return noise
443
+
444
+ def gradient_loss(self, grad):
445
+ loss = tf.square(grad)
446
+ loss = tf.reduce_sum(loss, axis=tf.range(1, tf.size(tf.shape(loss))))
447
+ loss = tf.sqrt(loss)
448
+ loss = tf.reduce_mean(tf.square(loss - 1))
449
+ return loss
450
+
451
+ def train_step(self, real_images):
452
+
453
+ self.train_step_counter.assign_add(1)
454
+
455
+ if self.phase == "TRANSITION":
456
+ self.alpha.assign(
457
+ tf.cast(self.train_step_counter / self.steps_per_epoch, tf.float32)
458
+ )
459
+ elif self.phase == "STABLE":
460
+ self.alpha.assign(1.0)
461
+ else:
462
+ raise NotImplementedError
463
+ alpha = tf.expand_dims(self.alpha, 0)
464
+ batch_size = tf.shape(real_images)[0]
465
+ real_labels = tf.ones(batch_size)
466
+ fake_labels = -tf.ones(batch_size)
467
+
468
+ z = tf.random.normal((batch_size, self.z_dim))
469
+ const_input = tf.ones(tuple([batch_size] + list(self.g_input_shape)))
470
+ noise = self.generate_noise(batch_size)
471
+
472
+ # generator
473
+ with tf.GradientTape() as g_tape:
474
+ w = self.mapping(z)
475
+ fake_images = self.generator([const_input, w, noise, alpha])
476
+ pred_fake = self.discriminator([fake_images, alpha])
477
+ g_loss = wasserstein_loss(real_labels, pred_fake)
478
+
479
+ trainable_weights = (
480
+ self.mapping.trainable_weights + self.generator.trainable_weights
481
+ )
482
+ gradients = g_tape.gradient(g_loss, trainable_weights)
483
+ self.g_optimizer.apply_gradients(zip(gradients, trainable_weights))
484
+
485
+ # discriminator
486
+ with tf.GradientTape() as gradient_tape, tf.GradientTape() as total_tape:
487
+ # forward pass
488
+ pred_fake = self.discriminator([fake_images, alpha])
489
+ pred_real = self.discriminator([real_images, alpha])
490
+
491
+ epsilon = tf.random.uniform((batch_size, 1, 1, 1))
492
+ interpolates = epsilon * real_images + (1 - epsilon) * fake_images
493
+ gradient_tape.watch(interpolates)
494
+ pred_fake_grad = self.discriminator([interpolates, alpha])
495
+
496
+ # calculate losses
497
+ loss_fake = wasserstein_loss(fake_labels, pred_fake)
498
+ loss_real = wasserstein_loss(real_labels, pred_real)
499
+ loss_fake_grad = wasserstein_loss(fake_labels, pred_fake_grad)
500
+
501
+ # gradient penalty
502
+ gradients_fake = gradient_tape.gradient(loss_fake_grad, [interpolates])
503
+ gradient_penalty = self.loss_weights[
504
+ "gradient_penalty"
505
+ ] * self.gradient_loss(gradients_fake)
506
+
507
+ # drift loss
508
+ all_pred = tf.concat([pred_fake, pred_real], axis=0)
509
+ drift_loss = self.loss_weights["drift"] * tf.reduce_mean(all_pred ** 2)
510
+
511
+ d_loss = loss_fake + loss_real + gradient_penalty + drift_loss
512
+
513
+ gradients = total_tape.gradient(
514
+ d_loss, self.discriminator.trainable_weights
515
+ )
516
+ self.d_optimizer.apply_gradients(
517
+ zip(gradients, self.discriminator.trainable_weights)
518
+ )
519
+
520
+ # Update metrics
521
+ self.d_loss_metric.update_state(d_loss)
522
+ self.g_loss_metric.update_state(g_loss)
523
+ return {
524
+ "d_loss": self.d_loss_metric.result(),
525
+ "g_loss": self.g_loss_metric.result(),
526
+ }
527
+
528
+ def call(self, inputs: dict()):
529
+ style_code = inputs.get("style_code", None)
530
+ z = inputs.get("z", None)
531
+ noise = inputs.get("noise", None)
532
+ batch_size = inputs.get("batch_size", 1)
533
+ alpha = inputs.get("alpha", 1.0)
534
+ alpha = tf.expand_dims(alpha, 0)
535
+ if style_code is None:
536
+ if z is None:
537
+ z = tf.random.normal((batch_size, self.z_dim))
538
+ style_code = self.mapping(z)
539
+
540
+ if noise is None:
541
+ noise = self.generate_noise(batch_size)
542
+
543
+ # self.alpha.assign(alpha)
544
+
545
+ const_input = tf.ones(tuple([batch_size] + list(self.g_input_shape)))
546
+ images = self.generator([const_input, style_code, noise, alpha])
547
+ images = np.clip((images * 0.5 + 0.5) * 255, 0, 255).astype(np.uint8)
548
+
549
+ return images
550
+
551
+ # Set up GAN
552
+
553
+ batch_sizes = {2: 16, 3: 16, 4: 16, 5: 16, 6: 16, 7: 8, 8: 4, 9: 2, 10: 1}
554
+ train_step_ratio = {k: batch_sizes[2] / v for k, v in batch_sizes.items()}
555
+
556
+ START_RES = 4
557
+ TARGET_RES = 128
558
+
559
+ # style_gan = StyleGAN(start_res=START_RES, target_res=TARGET_RES)
560
+
561
+ print("Loading...")
562
+
563
+ url = "https://github.com/soon-yau/stylegan_keras/releases/download/keras_example_v1.0/stylegan_128x128.ckpt.zip"
564
+
565
+ weights_path = keras.utils.get_file(
566
+ "stylegan_128x128.ckpt.zip",
567
+ url,
568
+ extract=True,
569
+ cache_dir=os.path.abspath("."),
570
+ cache_subdir="pretrained",
571
+ )
572
+
573
+ # style_gan.grow_model(128)
574
+ # style_gan.load_weights(os.path.join("pretrained/stylegan_128x128.ckpt"))
575
+
576
+ # tf.random.set_seed(196)
577
+ # batch_size = 2
578
+ # z = tf.random.normal((batch_size, style_gan.z_dim))
579
+ # w = style_gan.mapping(z)
580
+ # noise = style_gan.generate_noise(batch_size=batch_size)
581
+ # images = style_gan({"style_code": w, "noise": noise, "alpha": 1.0})
582
+
583
+ # plot_images(images, 5)
584
+
585
+ class InferenceWrapper:
586
+ def __init__(self, model):
587
+ self.model = model
588
+ self.style_gan = StyleGAN(start_res=START_RES, target_res=TARGET_RES)
589
+ self.style_gan.grow_model(128)
590
+ self.style_gan.load_weights(os.path.join("pretrained/stylegan_128x128.ckpt"))
591
+ self.seed = -1
592
+
593
+ def __call__(self, seed, feature):
594
+ if seed != self.seed:
595
+ print(f"Loading model: {self.model}")
596
+ tf.random.set_seed(seed)
597
+ batch_size = 1
598
+ self.z = tf.random.normal((batch_size, self.style_gan.z_dim))
599
+ self.w = self.style_gan.mapping(self.z)
600
+ self.noise = self.style_gan.generate_noise(batch_size=batch_size)
601
+ else:
602
+ print(f"Model '{self.model}' already loaded, reusing it.")
603
+ return self.style_gan({"style_code": self.w, "noise": self.noise, "alpha": 1.0})[0]
604
+
605
+
606
+ wrapper = InferenceWrapper('celeba')
607
+
608
+ def fn(seed, feature):
609
+ return wrapper(seed, feature)
610
+
611
+ gr.Interface(
612
+ fn,
613
+ inputs=[
614
+ gr.inputs.Slider(minimum=0, maximum=999999999, step=1, default=0, label='Random Seed'),
615
+ gr.inputs.Radio(list({"test1","test2"}), type="value", default='test1', label='Feature Type')
616
+ ],
617
+ outputs='image',
618
+ examples=[[343, 'test1'], [456, 'test2']],
619
+ enable_queue=True,
620
+ title="NFT GAN",
621
+ description="Select random seed and selct Submit to generate a new image",
622
+ article="<p>Face image generation with StyleGAN using tf.keras. The code is from the Keras.io <a class='moflo-link' href='https://keras.io/examples/generative/stylegan/'>exmple</a> by Soon-Yau Cheong</p>",
623
+ css=".panel { padding: 5px } .moflo-link { color: #999 }"
624
+ ).launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ tensorflow
2
+ tensorflow-datasets
3
+ tensorflow-addons
4
+ matplotlib