jayparmr commited on
Commit
35575bb
·
verified ·
1 Parent(s): c95142c

update : inference

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +3 -1
  2. external/briarmbg.py +460 -0
  3. external/llite/library/custom_train_functions.py +529 -529
  4. external/midas/__init__.py +0 -39
  5. external/midas/base_model.py +16 -0
  6. external/midas/blocks.py +342 -0
  7. external/midas/dpt_depth.py +109 -0
  8. external/midas/midas_net.py +76 -0
  9. external/midas/midas_net_custom.py +128 -0
  10. external/midas/transforms.py +234 -0
  11. external/midas/vit.py +491 -0
  12. external/realesrgan/__init__.py +6 -0
  13. external/realesrgan/archs/__init__.py +10 -0
  14. external/realesrgan/archs/discriminator_arch.py +67 -0
  15. external/realesrgan/archs/srvgg_arch.py +69 -0
  16. external/realesrgan/data/__init__.py +10 -0
  17. external/realesrgan/data/realesrgan_dataset.py +192 -0
  18. external/realesrgan/data/realesrgan_paired_dataset.py +117 -0
  19. external/realesrgan/models/__init__.py +10 -0
  20. external/realesrgan/models/realesrgan_model.py +258 -0
  21. external/realesrgan/models/realesrnet_model.py +188 -0
  22. external/realesrgan/train.py +11 -0
  23. external/realesrgan/utils.py +302 -0
  24. handler.py +6 -1
  25. inference.py +241 -96
  26. internals/data/task.py +17 -1
  27. internals/pipelines/commons.py +23 -4
  28. internals/pipelines/controlnets.py +277 -61
  29. internals/pipelines/high_res.py +32 -3
  30. internals/pipelines/inpaint_imageprocessor.py +976 -0
  31. internals/pipelines/inpainter.py +35 -4
  32. internals/pipelines/prompt_modifier.py +3 -1
  33. internals/pipelines/realtime_draw.py +13 -3
  34. internals/pipelines/remove_background.py +55 -5
  35. internals/pipelines/replace_background.py +8 -8
  36. internals/pipelines/safety_checker.py +3 -2
  37. internals/pipelines/sdxl_llite_pipeline.py +3 -1
  38. internals/pipelines/sdxl_tile_upscale.py +85 -14
  39. internals/pipelines/upscaler.py +25 -8
  40. internals/util/__init__.py +6 -0
  41. internals/util/cache.py +2 -0
  42. internals/util/commons.py +4 -4
  43. internals/util/config.py +19 -5
  44. internals/util/failure_hander.py +7 -4
  45. internals/util/image.py +18 -0
  46. internals/util/lora_style.py +29 -5
  47. internals/util/model_loader.py +8 -0
  48. internals/util/prompt.py +6 -1
  49. internals/util/sdxl_lightning.py +74 -0
  50. internals/util/slack.py +3 -0
.gitignore CHANGED
@@ -3,6 +3,8 @@
3
  .ipynb_checkpoints █
4
  .vscode
5
  env
6
- test.py
7
  *.jpeg
8
  __pycache__
 
 
 
3
  .ipynb_checkpoints █
4
  .vscode
5
  env
6
+ test*.py
7
  *.jpeg
8
  __pycache__
9
+ sample_task.txt
10
+ .idea
external/briarmbg.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from huggingface_hub import PyTorchModelHubMixin
5
+
6
+
7
+ class REBNCONV(nn.Module):
8
+ def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
9
+ super(REBNCONV, self).__init__()
10
+
11
+ self.conv_s1 = nn.Conv2d(
12
+ in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
13
+ )
14
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
15
+ self.relu_s1 = nn.ReLU(inplace=True)
16
+
17
+ def forward(self, x):
18
+ hx = x
19
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
20
+
21
+ return xout
22
+
23
+
24
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
25
+ def _upsample_like(src, tar):
26
+ src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")
27
+
28
+ return src
29
+
30
+
31
+ ### RSU-7 ###
32
+ class RSU7(nn.Module):
33
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
34
+ super(RSU7, self).__init__()
35
+
36
+ self.in_ch = in_ch
37
+ self.mid_ch = mid_ch
38
+ self.out_ch = out_ch
39
+
40
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
41
+
42
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
43
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
44
+
45
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
46
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
47
+
48
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
49
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
50
+
51
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
52
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
53
+
54
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
55
+ self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
56
+
57
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
58
+
59
+ self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
60
+
61
+ self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
62
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
63
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
64
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
65
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
66
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
67
+
68
+ def forward(self, x):
69
+ b, c, h, w = x.shape
70
+
71
+ hx = x
72
+ hxin = self.rebnconvin(hx)
73
+
74
+ hx1 = self.rebnconv1(hxin)
75
+ hx = self.pool1(hx1)
76
+
77
+ hx2 = self.rebnconv2(hx)
78
+ hx = self.pool2(hx2)
79
+
80
+ hx3 = self.rebnconv3(hx)
81
+ hx = self.pool3(hx3)
82
+
83
+ hx4 = self.rebnconv4(hx)
84
+ hx = self.pool4(hx4)
85
+
86
+ hx5 = self.rebnconv5(hx)
87
+ hx = self.pool5(hx5)
88
+
89
+ hx6 = self.rebnconv6(hx)
90
+
91
+ hx7 = self.rebnconv7(hx6)
92
+
93
+ hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
94
+ hx6dup = _upsample_like(hx6d, hx5)
95
+
96
+ hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
97
+ hx5dup = _upsample_like(hx5d, hx4)
98
+
99
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
100
+ hx4dup = _upsample_like(hx4d, hx3)
101
+
102
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
103
+ hx3dup = _upsample_like(hx3d, hx2)
104
+
105
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
106
+ hx2dup = _upsample_like(hx2d, hx1)
107
+
108
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
109
+
110
+ return hx1d + hxin
111
+
112
+
113
+ ### RSU-6 ###
114
+ class RSU6(nn.Module):
115
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
116
+ super(RSU6, self).__init__()
117
+
118
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
119
+
120
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
121
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
122
+
123
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
124
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
125
+
126
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
127
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
128
+
129
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
130
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
131
+
132
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
133
+
134
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
135
+
136
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
137
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
138
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
139
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
140
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
141
+
142
+ def forward(self, x):
143
+ hx = x
144
+
145
+ hxin = self.rebnconvin(hx)
146
+
147
+ hx1 = self.rebnconv1(hxin)
148
+ hx = self.pool1(hx1)
149
+
150
+ hx2 = self.rebnconv2(hx)
151
+ hx = self.pool2(hx2)
152
+
153
+ hx3 = self.rebnconv3(hx)
154
+ hx = self.pool3(hx3)
155
+
156
+ hx4 = self.rebnconv4(hx)
157
+ hx = self.pool4(hx4)
158
+
159
+ hx5 = self.rebnconv5(hx)
160
+
161
+ hx6 = self.rebnconv6(hx5)
162
+
163
+ hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
164
+ hx5dup = _upsample_like(hx5d, hx4)
165
+
166
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
167
+ hx4dup = _upsample_like(hx4d, hx3)
168
+
169
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
170
+ hx3dup = _upsample_like(hx3d, hx2)
171
+
172
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
173
+ hx2dup = _upsample_like(hx2d, hx1)
174
+
175
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
176
+
177
+ return hx1d + hxin
178
+
179
+
180
+ ### RSU-5 ###
181
+ class RSU5(nn.Module):
182
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
183
+ super(RSU5, self).__init__()
184
+
185
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
186
+
187
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
188
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
189
+
190
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
191
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
192
+
193
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
194
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
195
+
196
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
197
+
198
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
199
+
200
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
201
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
202
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
203
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
204
+
205
+ def forward(self, x):
206
+ hx = x
207
+
208
+ hxin = self.rebnconvin(hx)
209
+
210
+ hx1 = self.rebnconv1(hxin)
211
+ hx = self.pool1(hx1)
212
+
213
+ hx2 = self.rebnconv2(hx)
214
+ hx = self.pool2(hx2)
215
+
216
+ hx3 = self.rebnconv3(hx)
217
+ hx = self.pool3(hx3)
218
+
219
+ hx4 = self.rebnconv4(hx)
220
+
221
+ hx5 = self.rebnconv5(hx4)
222
+
223
+ hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
224
+ hx4dup = _upsample_like(hx4d, hx3)
225
+
226
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
227
+ hx3dup = _upsample_like(hx3d, hx2)
228
+
229
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
230
+ hx2dup = _upsample_like(hx2d, hx1)
231
+
232
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
233
+
234
+ return hx1d + hxin
235
+
236
+
237
+ ### RSU-4 ###
238
+ class RSU4(nn.Module):
239
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
240
+ super(RSU4, self).__init__()
241
+
242
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
243
+
244
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
245
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
246
+
247
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
248
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
249
+
250
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
251
+
252
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
253
+
254
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
255
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
256
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
257
+
258
+ def forward(self, x):
259
+ hx = x
260
+
261
+ hxin = self.rebnconvin(hx)
262
+
263
+ hx1 = self.rebnconv1(hxin)
264
+ hx = self.pool1(hx1)
265
+
266
+ hx2 = self.rebnconv2(hx)
267
+ hx = self.pool2(hx2)
268
+
269
+ hx3 = self.rebnconv3(hx)
270
+
271
+ hx4 = self.rebnconv4(hx3)
272
+
273
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
274
+ hx3dup = _upsample_like(hx3d, hx2)
275
+
276
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
277
+ hx2dup = _upsample_like(hx2d, hx1)
278
+
279
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
280
+
281
+ return hx1d + hxin
282
+
283
+
284
+ ### RSU-4F ###
285
+ class RSU4F(nn.Module):
286
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
287
+ super(RSU4F, self).__init__()
288
+
289
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
290
+
291
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
292
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
293
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
294
+
295
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
296
+
297
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
298
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
299
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
300
+
301
+ def forward(self, x):
302
+ hx = x
303
+
304
+ hxin = self.rebnconvin(hx)
305
+
306
+ hx1 = self.rebnconv1(hxin)
307
+ hx2 = self.rebnconv2(hx1)
308
+ hx3 = self.rebnconv3(hx2)
309
+
310
+ hx4 = self.rebnconv4(hx3)
311
+
312
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
313
+ hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
314
+ hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
315
+
316
+ return hx1d + hxin
317
+
318
+
319
+ class myrebnconv(nn.Module):
320
+ def __init__(
321
+ self,
322
+ in_ch=3,
323
+ out_ch=1,
324
+ kernel_size=3,
325
+ stride=1,
326
+ padding=1,
327
+ dilation=1,
328
+ groups=1,
329
+ ):
330
+ super(myrebnconv, self).__init__()
331
+
332
+ self.conv = nn.Conv2d(
333
+ in_ch,
334
+ out_ch,
335
+ kernel_size=kernel_size,
336
+ stride=stride,
337
+ padding=padding,
338
+ dilation=dilation,
339
+ groups=groups,
340
+ )
341
+ self.bn = nn.BatchNorm2d(out_ch)
342
+ self.rl = nn.ReLU(inplace=True)
343
+
344
+ def forward(self, x):
345
+ return self.rl(self.bn(self.conv(x)))
346
+
347
+
348
+ class BriaRMBG(nn.Module, PyTorchModelHubMixin):
349
+ def __init__(self, config: dict = {"in_ch": 3, "out_ch": 1}):
350
+ super(BriaRMBG, self).__init__()
351
+ in_ch = config["in_ch"]
352
+ out_ch = config["out_ch"]
353
+ self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
354
+ self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
355
+
356
+ self.stage1 = RSU7(64, 32, 64)
357
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
358
+
359
+ self.stage2 = RSU6(64, 32, 128)
360
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
361
+
362
+ self.stage3 = RSU5(128, 64, 256)
363
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
364
+
365
+ self.stage4 = RSU4(256, 128, 512)
366
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
367
+
368
+ self.stage5 = RSU4F(512, 256, 512)
369
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
370
+
371
+ self.stage6 = RSU4F(512, 256, 512)
372
+
373
+ # decoder
374
+ self.stage5d = RSU4F(1024, 256, 512)
375
+ self.stage4d = RSU4(1024, 128, 256)
376
+ self.stage3d = RSU5(512, 64, 128)
377
+ self.stage2d = RSU6(256, 32, 64)
378
+ self.stage1d = RSU7(128, 16, 64)
379
+
380
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
381
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
382
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
383
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
384
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
385
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
386
+
387
+ # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
388
+
389
+ def forward(self, x):
390
+ hx = x
391
+
392
+ hxin = self.conv_in(hx)
393
+ # hx = self.pool_in(hxin)
394
+
395
+ # stage 1
396
+ hx1 = self.stage1(hxin)
397
+ hx = self.pool12(hx1)
398
+
399
+ # stage 2
400
+ hx2 = self.stage2(hx)
401
+ hx = self.pool23(hx2)
402
+
403
+ # stage 3
404
+ hx3 = self.stage3(hx)
405
+ hx = self.pool34(hx3)
406
+
407
+ # stage 4
408
+ hx4 = self.stage4(hx)
409
+ hx = self.pool45(hx4)
410
+
411
+ # stage 5
412
+ hx5 = self.stage5(hx)
413
+ hx = self.pool56(hx5)
414
+
415
+ # stage 6
416
+ hx6 = self.stage6(hx)
417
+ hx6up = _upsample_like(hx6, hx5)
418
+
419
+ # -------------------- decoder --------------------
420
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
421
+ hx5dup = _upsample_like(hx5d, hx4)
422
+
423
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
424
+ hx4dup = _upsample_like(hx4d, hx3)
425
+
426
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
427
+ hx3dup = _upsample_like(hx3d, hx2)
428
+
429
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
430
+ hx2dup = _upsample_like(hx2d, hx1)
431
+
432
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
433
+
434
+ # side output
435
+ d1 = self.side1(hx1d)
436
+ d1 = _upsample_like(d1, x)
437
+
438
+ d2 = self.side2(hx2d)
439
+ d2 = _upsample_like(d2, x)
440
+
441
+ d3 = self.side3(hx3d)
442
+ d3 = _upsample_like(d3, x)
443
+
444
+ d4 = self.side4(hx4d)
445
+ d4 = _upsample_like(d4, x)
446
+
447
+ d5 = self.side5(hx5d)
448
+ d5 = _upsample_like(d5, x)
449
+
450
+ d6 = self.side6(hx6)
451
+ d6 = _upsample_like(d6, x)
452
+
453
+ return [
454
+ F.sigmoid(d1),
455
+ F.sigmoid(d2),
456
+ F.sigmoid(d3),
457
+ F.sigmoid(d4),
458
+ F.sigmoid(d5),
459
+ F.sigmoid(d6),
460
+ ], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
external/llite/library/custom_train_functions.py CHANGED
@@ -1,529 +1,529 @@
1
- import torch
2
- import argparse
3
- import random
4
- import re
5
- from typing import List, Optional, Union
6
-
7
-
8
- def prepare_scheduler_for_custom_training(noise_scheduler, device):
9
- if hasattr(noise_scheduler, "all_snr"):
10
- return
11
-
12
- alphas_cumprod = noise_scheduler.alphas_cumprod
13
- sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
14
- sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
15
- alpha = sqrt_alphas_cumprod
16
- sigma = sqrt_one_minus_alphas_cumprod
17
- all_snr = (alpha / sigma) ** 2
18
-
19
- noise_scheduler.all_snr = all_snr.to(device)
20
-
21
-
22
- def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
23
- # fix beta: zero terminal SNR
24
- print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
25
-
26
- def enforce_zero_terminal_snr(betas):
27
- # Convert betas to alphas_bar_sqrt
28
- alphas = 1 - betas
29
- alphas_bar = alphas.cumprod(0)
30
- alphas_bar_sqrt = alphas_bar.sqrt()
31
-
32
- # Store old values.
33
- alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
34
- alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
35
- # Shift so last timestep is zero.
36
- alphas_bar_sqrt -= alphas_bar_sqrt_T
37
- # Scale so first timestep is back to old value.
38
- alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
39
-
40
- # Convert alphas_bar_sqrt to betas
41
- alphas_bar = alphas_bar_sqrt**2
42
- alphas = alphas_bar[1:] / alphas_bar[:-1]
43
- alphas = torch.cat([alphas_bar[0:1], alphas])
44
- betas = 1 - alphas
45
- return betas
46
-
47
- betas = noise_scheduler.betas
48
- betas = enforce_zero_terminal_snr(betas)
49
- alphas = 1.0 - betas
50
- alphas_cumprod = torch.cumprod(alphas, dim=0)
51
-
52
- # print("original:", noise_scheduler.betas)
53
- # print("fixed:", betas)
54
-
55
- noise_scheduler.betas = betas
56
- noise_scheduler.alphas = alphas
57
- noise_scheduler.alphas_cumprod = alphas_cumprod
58
-
59
-
60
- def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
61
- snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
62
- min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
63
- if v_prediction:
64
- snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device)
65
- else:
66
- snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
67
- loss = loss * snr_weight
68
- return loss
69
-
70
-
71
- def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
72
- scale = get_snr_scale(timesteps, noise_scheduler)
73
- loss = loss * scale
74
- return loss
75
-
76
-
77
- def get_snr_scale(timesteps, noise_scheduler):
78
- snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
79
- snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
80
- scale = snr_t / (snr_t + 1)
81
- # # show debug info
82
- # print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
83
- return scale
84
-
85
-
86
- def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
87
- scale = get_snr_scale(timesteps, noise_scheduler)
88
- # print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
89
- loss = loss + loss / scale * v_pred_like_loss
90
- return loss
91
-
92
- def apply_debiased_estimation(loss, timesteps, noise_scheduler):
93
- snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
94
- snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
95
- weight = 1/torch.sqrt(snr_t)
96
- loss = weight * loss
97
- return loss
98
-
99
- # TODO train_utilと分散しているのでどちらかに寄せる
100
-
101
-
102
- def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
103
- parser.add_argument(
104
- "--min_snr_gamma",
105
- type=float,
106
- default=None,
107
- help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
108
- )
109
- parser.add_argument(
110
- "--scale_v_pred_loss_like_noise_pred",
111
- action="store_true",
112
- help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
113
- )
114
- parser.add_argument(
115
- "--v_pred_like_loss",
116
- type=float,
117
- default=None,
118
- help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
119
- )
120
- parser.add_argument(
121
- "--debiased_estimation_loss",
122
- action="store_true",
123
- help="debiased estimation loss / debiased estimation loss",
124
- )
125
- if support_weighted_captions:
126
- parser.add_argument(
127
- "--weighted_captions",
128
- action="store_true",
129
- default=False,
130
- help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
131
- )
132
-
133
-
134
- re_attention = re.compile(
135
- r"""
136
- \\\(|
137
- \\\)|
138
- \\\[|
139
- \\]|
140
- \\\\|
141
- \\|
142
- \(|
143
- \[|
144
- :([+-]?[.\d]+)\)|
145
- \)|
146
- ]|
147
- [^\\()\[\]:]+|
148
- :
149
- """,
150
- re.X,
151
- )
152
-
153
-
154
- def parse_prompt_attention(text):
155
- """
156
- Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
157
- Accepted tokens are:
158
- (abc) - increases attention to abc by a multiplier of 1.1
159
- (abc:3.12) - increases attention to abc by a multiplier of 3.12
160
- [abc] - decreases attention to abc by a multiplier of 1.1
161
- \( - literal character '('
162
- \[ - literal character '['
163
- \) - literal character ')'
164
- \] - literal character ']'
165
- \\ - literal character '\'
166
- anything else - just text
167
- >>> parse_prompt_attention('normal text')
168
- [['normal text', 1.0]]
169
- >>> parse_prompt_attention('an (important) word')
170
- [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
171
- >>> parse_prompt_attention('(unbalanced')
172
- [['unbalanced', 1.1]]
173
- >>> parse_prompt_attention('\(literal\]')
174
- [['(literal]', 1.0]]
175
- >>> parse_prompt_attention('(unnecessary)(parens)')
176
- [['unnecessaryparens', 1.1]]
177
- >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
178
- [['a ', 1.0],
179
- ['house', 1.5730000000000004],
180
- [' ', 1.1],
181
- ['on', 1.0],
182
- [' a ', 1.1],
183
- ['hill', 0.55],
184
- [', sun, ', 1.1],
185
- ['sky', 1.4641000000000006],
186
- ['.', 1.1]]
187
- """
188
-
189
- res = []
190
- round_brackets = []
191
- square_brackets = []
192
-
193
- round_bracket_multiplier = 1.1
194
- square_bracket_multiplier = 1 / 1.1
195
-
196
- def multiply_range(start_position, multiplier):
197
- for p in range(start_position, len(res)):
198
- res[p][1] *= multiplier
199
-
200
- for m in re_attention.finditer(text):
201
- text = m.group(0)
202
- weight = m.group(1)
203
-
204
- if text.startswith("\\"):
205
- res.append([text[1:], 1.0])
206
- elif text == "(":
207
- round_brackets.append(len(res))
208
- elif text == "[":
209
- square_brackets.append(len(res))
210
- elif weight is not None and len(round_brackets) > 0:
211
- multiply_range(round_brackets.pop(), float(weight))
212
- elif text == ")" and len(round_brackets) > 0:
213
- multiply_range(round_brackets.pop(), round_bracket_multiplier)
214
- elif text == "]" and len(square_brackets) > 0:
215
- multiply_range(square_brackets.pop(), square_bracket_multiplier)
216
- else:
217
- res.append([text, 1.0])
218
-
219
- for pos in round_brackets:
220
- multiply_range(pos, round_bracket_multiplier)
221
-
222
- for pos in square_brackets:
223
- multiply_range(pos, square_bracket_multiplier)
224
-
225
- if len(res) == 0:
226
- res = [["", 1.0]]
227
-
228
- # merge runs of identical weights
229
- i = 0
230
- while i + 1 < len(res):
231
- if res[i][1] == res[i + 1][1]:
232
- res[i][0] += res[i + 1][0]
233
- res.pop(i + 1)
234
- else:
235
- i += 1
236
-
237
- return res
238
-
239
-
240
- def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
241
- r"""
242
- Tokenize a list of prompts and return its tokens with weights of each token.
243
-
244
- No padding, starting or ending token is included.
245
- """
246
- tokens = []
247
- weights = []
248
- truncated = False
249
- for text in prompt:
250
- texts_and_weights = parse_prompt_attention(text)
251
- text_token = []
252
- text_weight = []
253
- for word, weight in texts_and_weights:
254
- # tokenize and discard the starting and the ending token
255
- token = tokenizer(word).input_ids[1:-1]
256
- text_token += token
257
- # copy the weight by length of token
258
- text_weight += [weight] * len(token)
259
- # stop if the text is too long (longer than truncation limit)
260
- if len(text_token) > max_length:
261
- truncated = True
262
- break
263
- # truncate
264
- if len(text_token) > max_length:
265
- truncated = True
266
- text_token = text_token[:max_length]
267
- text_weight = text_weight[:max_length]
268
- tokens.append(text_token)
269
- weights.append(text_weight)
270
- if truncated:
271
- print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
272
- return tokens, weights
273
-
274
-
275
- def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
276
- r"""
277
- Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
278
- """
279
- max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
280
- weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
281
- for i in range(len(tokens)):
282
- tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
283
- if no_boseos_middle:
284
- weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
285
- else:
286
- w = []
287
- if len(weights[i]) == 0:
288
- w = [1.0] * weights_length
289
- else:
290
- for j in range(max_embeddings_multiples):
291
- w.append(1.0) # weight for starting token in this chunk
292
- w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
293
- w.append(1.0) # weight for ending token in this chunk
294
- w += [1.0] * (weights_length - len(w))
295
- weights[i] = w[:]
296
-
297
- return tokens, weights
298
-
299
-
300
- def get_unweighted_text_embeddings(
301
- tokenizer,
302
- text_encoder,
303
- text_input: torch.Tensor,
304
- chunk_length: int,
305
- clip_skip: int,
306
- eos: int,
307
- pad: int,
308
- no_boseos_middle: Optional[bool] = True,
309
- ):
310
- """
311
- When the length of tokens is a multiple of the capacity of the text encoder,
312
- it should be split into chunks and sent to the text encoder individually.
313
- """
314
- max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
315
- if max_embeddings_multiples > 1:
316
- text_embeddings = []
317
- for i in range(max_embeddings_multiples):
318
- # extract the i-th chunk
319
- text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
320
-
321
- # cover the head and the tail by the starting and the ending tokens
322
- text_input_chunk[:, 0] = text_input[0, 0]
323
- if pad == eos: # v1
324
- text_input_chunk[:, -1] = text_input[0, -1]
325
- else: # v2
326
- for j in range(len(text_input_chunk)):
327
- if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
328
- text_input_chunk[j, -1] = eos
329
- if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
330
- text_input_chunk[j, 1] = eos
331
-
332
- if clip_skip is None or clip_skip == 1:
333
- text_embedding = text_encoder(text_input_chunk)[0]
334
- else:
335
- enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
336
- text_embedding = enc_out["hidden_states"][-clip_skip]
337
- text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
338
-
339
- if no_boseos_middle:
340
- if i == 0:
341
- # discard the ending token
342
- text_embedding = text_embedding[:, :-1]
343
- elif i == max_embeddings_multiples - 1:
344
- # discard the starting token
345
- text_embedding = text_embedding[:, 1:]
346
- else:
347
- # discard both starting and ending tokens
348
- text_embedding = text_embedding[:, 1:-1]
349
-
350
- text_embeddings.append(text_embedding)
351
- text_embeddings = torch.concat(text_embeddings, axis=1)
352
- else:
353
- if clip_skip is None or clip_skip == 1:
354
- text_embeddings = text_encoder(text_input)[0]
355
- else:
356
- enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
357
- text_embeddings = enc_out["hidden_states"][-clip_skip]
358
- text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
359
- return text_embeddings
360
-
361
-
362
- def get_weighted_text_embeddings(
363
- tokenizer,
364
- text_encoder,
365
- prompt: Union[str, List[str]],
366
- device,
367
- max_embeddings_multiples: Optional[int] = 3,
368
- no_boseos_middle: Optional[bool] = False,
369
- clip_skip=None,
370
- ):
371
- r"""
372
- Prompts can be assigned with local weights using brackets. For example,
373
- prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
374
- and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
375
-
376
- Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
377
-
378
- Args:
379
- prompt (`str` or `List[str]`):
380
- The prompt or prompts to guide the image generation.
381
- max_embeddings_multiples (`int`, *optional*, defaults to `3`):
382
- The max multiple length of prompt embeddings compared to the max output length of text encoder.
383
- no_boseos_middle (`bool`, *optional*, defaults to `False`):
384
- If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
385
- ending token in each of the chunk in the middle.
386
- skip_parsing (`bool`, *optional*, defaults to `False`):
387
- Skip the parsing of brackets.
388
- skip_weighting (`bool`, *optional*, defaults to `False`):
389
- Skip the weighting. When the parsing is skipped, it is forced True.
390
- """
391
- max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
392
- if isinstance(prompt, str):
393
- prompt = [prompt]
394
-
395
- prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)
396
-
397
- # round up the longest length of tokens to a multiple of (model_max_length - 2)
398
- max_length = max([len(token) for token in prompt_tokens])
399
-
400
- max_embeddings_multiples = min(
401
- max_embeddings_multiples,
402
- (max_length - 1) // (tokenizer.model_max_length - 2) + 1,
403
- )
404
- max_embeddings_multiples = max(1, max_embeddings_multiples)
405
- max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
406
-
407
- # pad the length of tokens and weights
408
- bos = tokenizer.bos_token_id
409
- eos = tokenizer.eos_token_id
410
- pad = tokenizer.pad_token_id
411
- prompt_tokens, prompt_weights = pad_tokens_and_weights(
412
- prompt_tokens,
413
- prompt_weights,
414
- max_length,
415
- bos,
416
- eos,
417
- no_boseos_middle=no_boseos_middle,
418
- chunk_length=tokenizer.model_max_length,
419
- )
420
- prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)
421
-
422
- # get the embeddings
423
- text_embeddings = get_unweighted_text_embeddings(
424
- tokenizer,
425
- text_encoder,
426
- prompt_tokens,
427
- tokenizer.model_max_length,
428
- clip_skip,
429
- eos,
430
- pad,
431
- no_boseos_middle=no_boseos_middle,
432
- )
433
- prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)
434
-
435
- # assign weights to the prompts and normalize in the sense of mean
436
- previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
437
- text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
438
- current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
439
- text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
440
-
441
- return text_embeddings
442
-
443
-
444
- # https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
445
- def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
446
- b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
447
- u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
448
- for i in range(iterations):
449
- r = random.random() * 2 + 2 # Rather than always going 2x,
450
- wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
451
- noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
452
- if wn == 1 or hn == 1:
453
- break # Lowest resolution is 1x1
454
- return noise / noise.std() # Scaled back to roughly unit variance
455
-
456
-
457
- # https://www.crosslabs.org//blog/diffusion-with-offset-noise
458
- def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
459
- if noise_offset is None:
460
- return noise
461
- if adaptive_noise_scale is not None:
462
- # latent shape: (batch_size, channels, height, width)
463
- # abs mean value for each channel
464
- latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
465
-
466
- # multiply adaptive noise scale to the mean value and add it to the noise offset
467
- noise_offset = noise_offset + adaptive_noise_scale * latent_mean
468
- noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
469
-
470
- noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
471
- return noise
472
-
473
-
474
- """
475
- ##########################################
476
- # Perlin Noise
477
- def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
478
- delta = (res[0] / shape[0], res[1] / shape[1])
479
- d = (shape[0] // res[0], shape[1] // res[1])
480
-
481
- grid = (
482
- torch.stack(
483
- torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
484
- dim=-1,
485
- )
486
- % 1
487
- )
488
- angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
489
- gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
490
-
491
- tile_grads = (
492
- lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
493
- .repeat_interleave(d[0], 0)
494
- .repeat_interleave(d[1], 1)
495
- )
496
- dot = lambda grad, shift: (
497
- torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
498
- * grad[: shape[0], : shape[1]]
499
- ).sum(dim=-1)
500
-
501
- n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
502
- n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
503
- n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
504
- n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
505
- t = fade(grid[: shape[0], : shape[1]])
506
- return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
507
-
508
-
509
- def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
510
- noise = torch.zeros(shape, device=device)
511
- frequency = 1
512
- amplitude = 1
513
- for _ in range(octaves):
514
- noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
515
- frequency *= 2
516
- amplitude *= persistence
517
- return noise
518
-
519
-
520
- def perlin_noise(noise, device, octaves):
521
- _, c, w, h = noise.shape
522
- perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
523
- noise_perlin = []
524
- for _ in range(c):
525
- noise_perlin.append(perlin())
526
- noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
527
- noise += noise_perlin # broadcast for each batch
528
- return noise / noise.std() # Scaled back to roughly unit variance
529
- """
 
1
+ import torch
2
+ import argparse
3
+ import random
4
+ import re
5
+ from typing import List, Optional, Union
6
+
7
+
8
+ def prepare_scheduler_for_custom_training(noise_scheduler, device):
9
+ if hasattr(noise_scheduler, "all_snr"):
10
+ return
11
+
12
+ alphas_cumprod = noise_scheduler.alphas_cumprod
13
+ sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
14
+ sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
15
+ alpha = sqrt_alphas_cumprod
16
+ sigma = sqrt_one_minus_alphas_cumprod
17
+ all_snr = (alpha / sigma) ** 2
18
+
19
+ noise_scheduler.all_snr = all_snr.to(device)
20
+
21
+
22
+ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
23
+ # fix beta: zero terminal SNR
24
+ print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
25
+
26
+ def enforce_zero_terminal_snr(betas):
27
+ # Convert betas to alphas_bar_sqrt
28
+ alphas = 1 - betas
29
+ alphas_bar = alphas.cumprod(0)
30
+ alphas_bar_sqrt = alphas_bar.sqrt()
31
+
32
+ # Store old values.
33
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
34
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
35
+ # Shift so last timestep is zero.
36
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
37
+ # Scale so first timestep is back to old value.
38
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
39
+
40
+ # Convert alphas_bar_sqrt to betas
41
+ alphas_bar = alphas_bar_sqrt**2
42
+ alphas = alphas_bar[1:] / alphas_bar[:-1]
43
+ alphas = torch.cat([alphas_bar[0:1], alphas])
44
+ betas = 1 - alphas
45
+ return betas
46
+
47
+ betas = noise_scheduler.betas
48
+ betas = enforce_zero_terminal_snr(betas)
49
+ alphas = 1.0 - betas
50
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
51
+
52
+ # print("original:", noise_scheduler.betas)
53
+ # print("fixed:", betas)
54
+
55
+ noise_scheduler.betas = betas
56
+ noise_scheduler.alphas = alphas
57
+ noise_scheduler.alphas_cumprod = alphas_cumprod
58
+
59
+
60
+ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
61
+ snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
62
+ min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
63
+ if v_prediction:
64
+ snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device)
65
+ else:
66
+ snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
67
+ loss = loss * snr_weight
68
+ return loss
69
+
70
+
71
+ def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
72
+ scale = get_snr_scale(timesteps, noise_scheduler)
73
+ loss = loss * scale
74
+ return loss
75
+
76
+
77
+ def get_snr_scale(timesteps, noise_scheduler):
78
+ snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
79
+ snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
80
+ scale = snr_t / (snr_t + 1)
81
+ # # show debug info
82
+ # print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
83
+ return scale
84
+
85
+
86
+ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
87
+ scale = get_snr_scale(timesteps, noise_scheduler)
88
+ # print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
89
+ loss = loss + loss / scale * v_pred_like_loss
90
+ return loss
91
+
92
+ def apply_debiased_estimation(loss, timesteps, noise_scheduler):
93
+ snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
94
+ snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
95
+ weight = 1/torch.sqrt(snr_t)
96
+ loss = weight * loss
97
+ return loss
98
+
99
+ # TODO train_utilと分散しているのでどちらかに寄せる
100
+
101
+
102
+ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
103
+ parser.add_argument(
104
+ "--min_snr_gamma",
105
+ type=float,
106
+ default=None,
107
+ help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
108
+ )
109
+ parser.add_argument(
110
+ "--scale_v_pred_loss_like_noise_pred",
111
+ action="store_true",
112
+ help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
113
+ )
114
+ parser.add_argument(
115
+ "--v_pred_like_loss",
116
+ type=float,
117
+ default=None,
118
+ help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
119
+ )
120
+ parser.add_argument(
121
+ "--debiased_estimation_loss",
122
+ action="store_true",
123
+ help="debiased estimation loss / debiased estimation loss",
124
+ )
125
+ if support_weighted_captions:
126
+ parser.add_argument(
127
+ "--weighted_captions",
128
+ action="store_true",
129
+ default=False,
130
+ help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
131
+ )
132
+
133
+
134
+ re_attention = re.compile(
135
+ r"""
136
+ \\\(|
137
+ \\\)|
138
+ \\\[|
139
+ \\]|
140
+ \\\\|
141
+ \\|
142
+ \(|
143
+ \[|
144
+ :([+-]?[.\d]+)\)|
145
+ \)|
146
+ ]|
147
+ [^\\()\[\]:]+|
148
+ :
149
+ """,
150
+ re.X,
151
+ )
152
+
153
+
154
+ def parse_prompt_attention(text):
155
+ """
156
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
157
+ Accepted tokens are:
158
+ (abc) - increases attention to abc by a multiplier of 1.1
159
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
160
+ [abc] - decreases attention to abc by a multiplier of 1.1
161
+ \( - literal character '('
162
+ \[ - literal character '['
163
+ \) - literal character ')'
164
+ \] - literal character ']'
165
+ \\ - literal character '\'
166
+ anything else - just text
167
+ >>> parse_prompt_attention('normal text')
168
+ [['normal text', 1.0]]
169
+ >>> parse_prompt_attention('an (important) word')
170
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
171
+ >>> parse_prompt_attention('(unbalanced')
172
+ [['unbalanced', 1.1]]
173
+ >>> parse_prompt_attention('\(literal\]')
174
+ [['(literal]', 1.0]]
175
+ >>> parse_prompt_attention('(unnecessary)(parens)')
176
+ [['unnecessaryparens', 1.1]]
177
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
178
+ [['a ', 1.0],
179
+ ['house', 1.5730000000000004],
180
+ [' ', 1.1],
181
+ ['on', 1.0],
182
+ [' a ', 1.1],
183
+ ['hill', 0.55],
184
+ [', sun, ', 1.1],
185
+ ['sky', 1.4641000000000006],
186
+ ['.', 1.1]]
187
+ """
188
+
189
+ res = []
190
+ round_brackets = []
191
+ square_brackets = []
192
+
193
+ round_bracket_multiplier = 1.1
194
+ square_bracket_multiplier = 1 / 1.1
195
+
196
+ def multiply_range(start_position, multiplier):
197
+ for p in range(start_position, len(res)):
198
+ res[p][1] *= multiplier
199
+
200
+ for m in re_attention.finditer(text):
201
+ text = m.group(0)
202
+ weight = m.group(1)
203
+
204
+ if text.startswith("\\"):
205
+ res.append([text[1:], 1.0])
206
+ elif text == "(":
207
+ round_brackets.append(len(res))
208
+ elif text == "[":
209
+ square_brackets.append(len(res))
210
+ elif weight is not None and len(round_brackets) > 0:
211
+ multiply_range(round_brackets.pop(), float(weight))
212
+ elif text == ")" and len(round_brackets) > 0:
213
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
214
+ elif text == "]" and len(square_brackets) > 0:
215
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
216
+ else:
217
+ res.append([text, 1.0])
218
+
219
+ for pos in round_brackets:
220
+ multiply_range(pos, round_bracket_multiplier)
221
+
222
+ for pos in square_brackets:
223
+ multiply_range(pos, square_bracket_multiplier)
224
+
225
+ if len(res) == 0:
226
+ res = [["", 1.0]]
227
+
228
+ # merge runs of identical weights
229
+ i = 0
230
+ while i + 1 < len(res):
231
+ if res[i][1] == res[i + 1][1]:
232
+ res[i][0] += res[i + 1][0]
233
+ res.pop(i + 1)
234
+ else:
235
+ i += 1
236
+
237
+ return res
238
+
239
+
240
+ def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
241
+ r"""
242
+ Tokenize a list of prompts and return its tokens with weights of each token.
243
+
244
+ No padding, starting or ending token is included.
245
+ """
246
+ tokens = []
247
+ weights = []
248
+ truncated = False
249
+ for text in prompt:
250
+ texts_and_weights = parse_prompt_attention(text)
251
+ text_token = []
252
+ text_weight = []
253
+ for word, weight in texts_and_weights:
254
+ # tokenize and discard the starting and the ending token
255
+ token = tokenizer(word).input_ids[1:-1]
256
+ text_token += token
257
+ # copy the weight by length of token
258
+ text_weight += [weight] * len(token)
259
+ # stop if the text is too long (longer than truncation limit)
260
+ if len(text_token) > max_length:
261
+ truncated = True
262
+ break
263
+ # truncate
264
+ if len(text_token) > max_length:
265
+ truncated = True
266
+ text_token = text_token[:max_length]
267
+ text_weight = text_weight[:max_length]
268
+ tokens.append(text_token)
269
+ weights.append(text_weight)
270
+ if truncated:
271
+ print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
272
+ return tokens, weights
273
+
274
+
275
+ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
276
+ r"""
277
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
278
+ """
279
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
280
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
281
+ for i in range(len(tokens)):
282
+ tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
283
+ if no_boseos_middle:
284
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
285
+ else:
286
+ w = []
287
+ if len(weights[i]) == 0:
288
+ w = [1.0] * weights_length
289
+ else:
290
+ for j in range(max_embeddings_multiples):
291
+ w.append(1.0) # weight for starting token in this chunk
292
+ w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
293
+ w.append(1.0) # weight for ending token in this chunk
294
+ w += [1.0] * (weights_length - len(w))
295
+ weights[i] = w[:]
296
+
297
+ return tokens, weights
298
+
299
+
300
+ def get_unweighted_text_embeddings(
301
+ tokenizer,
302
+ text_encoder,
303
+ text_input: torch.Tensor,
304
+ chunk_length: int,
305
+ clip_skip: int,
306
+ eos: int,
307
+ pad: int,
308
+ no_boseos_middle: Optional[bool] = True,
309
+ ):
310
+ """
311
+ When the length of tokens is a multiple of the capacity of the text encoder,
312
+ it should be split into chunks and sent to the text encoder individually.
313
+ """
314
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
315
+ if max_embeddings_multiples > 1:
316
+ text_embeddings = []
317
+ for i in range(max_embeddings_multiples):
318
+ # extract the i-th chunk
319
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
320
+
321
+ # cover the head and the tail by the starting and the ending tokens
322
+ text_input_chunk[:, 0] = text_input[0, 0]
323
+ if pad == eos: # v1
324
+ text_input_chunk[:, -1] = text_input[0, -1]
325
+ else: # v2
326
+ for j in range(len(text_input_chunk)):
327
+ if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
328
+ text_input_chunk[j, -1] = eos
329
+ if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
330
+ text_input_chunk[j, 1] = eos
331
+
332
+ if clip_skip is None or clip_skip == 1:
333
+ text_embedding = text_encoder(text_input_chunk)[0]
334
+ else:
335
+ enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
336
+ text_embedding = enc_out["hidden_states"][-clip_skip]
337
+ text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
338
+
339
+ if no_boseos_middle:
340
+ if i == 0:
341
+ # discard the ending token
342
+ text_embedding = text_embedding[:, :-1]
343
+ elif i == max_embeddings_multiples - 1:
344
+ # discard the starting token
345
+ text_embedding = text_embedding[:, 1:]
346
+ else:
347
+ # discard both starting and ending tokens
348
+ text_embedding = text_embedding[:, 1:-1]
349
+
350
+ text_embeddings.append(text_embedding)
351
+ text_embeddings = torch.concat(text_embeddings, axis=1)
352
+ else:
353
+ if clip_skip is None or clip_skip == 1:
354
+ text_embeddings = text_encoder(text_input)[0]
355
+ else:
356
+ enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
357
+ text_embeddings = enc_out["hidden_states"][-clip_skip]
358
+ text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
359
+ return text_embeddings
360
+
361
+
362
+ def get_weighted_text_embeddings(
363
+ tokenizer,
364
+ text_encoder,
365
+ prompt: Union[str, List[str]],
366
+ device,
367
+ max_embeddings_multiples: Optional[int] = 3,
368
+ no_boseos_middle: Optional[bool] = False,
369
+ clip_skip=None,
370
+ ):
371
+ r"""
372
+ Prompts can be assigned with local weights using brackets. For example,
373
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
374
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
375
+
376
+ Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
377
+
378
+ Args:
379
+ prompt (`str` or `List[str]`):
380
+ The prompt or prompts to guide the image generation.
381
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
382
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
383
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
384
+ If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
385
+ ending token in each of the chunk in the middle.
386
+ skip_parsing (`bool`, *optional*, defaults to `False`):
387
+ Skip the parsing of brackets.
388
+ skip_weighting (`bool`, *optional*, defaults to `False`):
389
+ Skip the weighting. When the parsing is skipped, it is forced True.
390
+ """
391
+ max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
392
+ if isinstance(prompt, str):
393
+ prompt = [prompt]
394
+
395
+ prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)
396
+
397
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
398
+ max_length = max([len(token) for token in prompt_tokens])
399
+
400
+ max_embeddings_multiples = min(
401
+ max_embeddings_multiples,
402
+ (max_length - 1) // (tokenizer.model_max_length - 2) + 1,
403
+ )
404
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
405
+ max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
406
+
407
+ # pad the length of tokens and weights
408
+ bos = tokenizer.bos_token_id
409
+ eos = tokenizer.eos_token_id
410
+ pad = tokenizer.pad_token_id
411
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
412
+ prompt_tokens,
413
+ prompt_weights,
414
+ max_length,
415
+ bos,
416
+ eos,
417
+ no_boseos_middle=no_boseos_middle,
418
+ chunk_length=tokenizer.model_max_length,
419
+ )
420
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)
421
+
422
+ # get the embeddings
423
+ text_embeddings = get_unweighted_text_embeddings(
424
+ tokenizer,
425
+ text_encoder,
426
+ prompt_tokens,
427
+ tokenizer.model_max_length,
428
+ clip_skip,
429
+ eos,
430
+ pad,
431
+ no_boseos_middle=no_boseos_middle,
432
+ )
433
+ prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)
434
+
435
+ # assign weights to the prompts and normalize in the sense of mean
436
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
437
+ text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
438
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
439
+ text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
440
+
441
+ return text_embeddings
442
+
443
+
444
+ # https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
445
+ def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
446
+ b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
447
+ u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
448
+ for i in range(iterations):
449
+ r = random.random() * 2 + 2 # Rather than always going 2x,
450
+ wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
451
+ noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
452
+ if wn == 1 or hn == 1:
453
+ break # Lowest resolution is 1x1
454
+ return noise / noise.std() # Scaled back to roughly unit variance
455
+
456
+
457
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
458
+ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
459
+ if noise_offset is None:
460
+ return noise
461
+ if adaptive_noise_scale is not None:
462
+ # latent shape: (batch_size, channels, height, width)
463
+ # abs mean value for each channel
464
+ latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
465
+
466
+ # multiply adaptive noise scale to the mean value and add it to the noise offset
467
+ noise_offset = noise_offset + adaptive_noise_scale * latent_mean
468
+ noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
469
+
470
+ noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
471
+ return noise
472
+
473
+
474
+ """
475
+ ##########################################
476
+ # Perlin Noise
477
+ def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
478
+ delta = (res[0] / shape[0], res[1] / shape[1])
479
+ d = (shape[0] // res[0], shape[1] // res[1])
480
+
481
+ grid = (
482
+ torch.stack(
483
+ torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
484
+ dim=-1,
485
+ )
486
+ % 1
487
+ )
488
+ angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
489
+ gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
490
+
491
+ tile_grads = (
492
+ lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
493
+ .repeat_interleave(d[0], 0)
494
+ .repeat_interleave(d[1], 1)
495
+ )
496
+ dot = lambda grad, shift: (
497
+ torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
498
+ * grad[: shape[0], : shape[1]]
499
+ ).sum(dim=-1)
500
+
501
+ n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
502
+ n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
503
+ n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
504
+ n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
505
+ t = fade(grid[: shape[0], : shape[1]])
506
+ return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
507
+
508
+
509
+ def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
510
+ noise = torch.zeros(shape, device=device)
511
+ frequency = 1
512
+ amplitude = 1
513
+ for _ in range(octaves):
514
+ noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
515
+ frequency *= 2
516
+ amplitude *= persistence
517
+ return noise
518
+
519
+
520
+ def perlin_noise(noise, device, octaves):
521
+ _, c, w, h = noise.shape
522
+ perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
523
+ noise_perlin = []
524
+ for _ in range(c):
525
+ noise_perlin.append(perlin())
526
+ noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
527
+ noise += noise_perlin # broadcast for each batch
528
+ return noise / noise.std() # Scaled back to roughly unit variance
529
+ """
external/midas/__init__.py CHANGED
@@ -1,39 +0,0 @@
1
- import cv2
2
- import numpy as np
3
- import torch
4
- from einops import rearrange
5
-
6
- from .api import MiDaSInference
7
-
8
- model = None
9
-
10
-
11
- def apply_midas(input_image, a=np.pi * 2.0, bg_th=0.1):
12
- global model
13
- if not model:
14
- model = MiDaSInference(model_type="dpt_hybrid").cuda()
15
- assert input_image.ndim == 3
16
- image_depth = input_image
17
- with torch.no_grad():
18
- image_depth = torch.from_numpy(image_depth).float().cuda()
19
- image_depth = image_depth / 127.5 - 1.0
20
- image_depth = rearrange(image_depth, "h w c -> 1 c h w")
21
- depth = model(image_depth)[0]
22
-
23
- depth_pt = depth.clone()
24
- depth_pt -= torch.min(depth_pt)
25
- depth_pt /= torch.max(depth_pt)
26
- depth_pt = depth_pt.cpu().numpy()
27
- depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
28
-
29
- depth_np = depth.cpu().numpy()
30
- x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
31
- y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
32
- z = np.ones_like(x) * a
33
- x[depth_pt < bg_th] = 0
34
- y[depth_pt < bg_th] = 0
35
- normal = np.stack([x, y, z], axis=2)
36
- normal /= np.sum(normal**2.0, axis=2, keepdims=True) ** 0.5
37
- normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
38
-
39
- return depth_image, normal_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
external/midas/base_model.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class BaseModel(torch.nn.Module):
5
+ def load(self, path):
6
+ """Load model from file.
7
+
8
+ Args:
9
+ path (str): file path
10
+ """
11
+ parameters = torch.load(path, map_location=torch.device('cpu'))
12
+
13
+ if "optimizer" in parameters:
14
+ parameters = parameters["model"]
15
+
16
+ self.load_state_dict(parameters)
external/midas/blocks.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .vit import (
5
+ _make_pretrained_vitb_rn50_384,
6
+ _make_pretrained_vitl16_384,
7
+ _make_pretrained_vitb16_384,
8
+ forward_vit,
9
+ )
10
+
11
+ def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
12
+ if backbone == "vitl16_384":
13
+ pretrained = _make_pretrained_vitl16_384(
14
+ use_pretrained, hooks=hooks, use_readout=use_readout
15
+ )
16
+ scratch = _make_scratch(
17
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
18
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
19
+ elif backbone == "vitb_rn50_384":
20
+ pretrained = _make_pretrained_vitb_rn50_384(
21
+ use_pretrained,
22
+ hooks=hooks,
23
+ use_vit_only=use_vit_only,
24
+ use_readout=use_readout,
25
+ )
26
+ scratch = _make_scratch(
27
+ [256, 512, 768, 768], features, groups=groups, expand=expand
28
+ ) # ViT-H/16 - 85.0% Top1 (backbone)
29
+ elif backbone == "vitb16_384":
30
+ pretrained = _make_pretrained_vitb16_384(
31
+ use_pretrained, hooks=hooks, use_readout=use_readout
32
+ )
33
+ scratch = _make_scratch(
34
+ [96, 192, 384, 768], features, groups=groups, expand=expand
35
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
36
+ elif backbone == "resnext101_wsl":
37
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
38
+ scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
39
+ elif backbone == "efficientnet_lite3":
40
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
41
+ scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
42
+ else:
43
+ print(f"Backbone '{backbone}' not implemented")
44
+ assert False
45
+
46
+ return pretrained, scratch
47
+
48
+
49
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
50
+ scratch = nn.Module()
51
+
52
+ out_shape1 = out_shape
53
+ out_shape2 = out_shape
54
+ out_shape3 = out_shape
55
+ out_shape4 = out_shape
56
+ if expand==True:
57
+ out_shape1 = out_shape
58
+ out_shape2 = out_shape*2
59
+ out_shape3 = out_shape*4
60
+ out_shape4 = out_shape*8
61
+
62
+ scratch.layer1_rn = nn.Conv2d(
63
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
64
+ )
65
+ scratch.layer2_rn = nn.Conv2d(
66
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
67
+ )
68
+ scratch.layer3_rn = nn.Conv2d(
69
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
70
+ )
71
+ scratch.layer4_rn = nn.Conv2d(
72
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
73
+ )
74
+
75
+ return scratch
76
+
77
+
78
+ def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
79
+ efficientnet = torch.hub.load(
80
+ "rwightman/gen-efficientnet-pytorch",
81
+ "tf_efficientnet_lite3",
82
+ pretrained=use_pretrained,
83
+ exportable=exportable
84
+ )
85
+ return _make_efficientnet_backbone(efficientnet)
86
+
87
+
88
+ def _make_efficientnet_backbone(effnet):
89
+ pretrained = nn.Module()
90
+
91
+ pretrained.layer1 = nn.Sequential(
92
+ effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
93
+ )
94
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
95
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
96
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
97
+
98
+ return pretrained
99
+
100
+
101
+ def _make_resnet_backbone(resnet):
102
+ pretrained = nn.Module()
103
+ pretrained.layer1 = nn.Sequential(
104
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
105
+ )
106
+
107
+ pretrained.layer2 = resnet.layer2
108
+ pretrained.layer3 = resnet.layer3
109
+ pretrained.layer4 = resnet.layer4
110
+
111
+ return pretrained
112
+
113
+
114
+ def _make_pretrained_resnext101_wsl(use_pretrained):
115
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
116
+ return _make_resnet_backbone(resnet)
117
+
118
+
119
+
120
+ class Interpolate(nn.Module):
121
+ """Interpolation module.
122
+ """
123
+
124
+ def __init__(self, scale_factor, mode, align_corners=False):
125
+ """Init.
126
+
127
+ Args:
128
+ scale_factor (float): scaling
129
+ mode (str): interpolation mode
130
+ """
131
+ super(Interpolate, self).__init__()
132
+
133
+ self.interp = nn.functional.interpolate
134
+ self.scale_factor = scale_factor
135
+ self.mode = mode
136
+ self.align_corners = align_corners
137
+
138
+ def forward(self, x):
139
+ """Forward pass.
140
+
141
+ Args:
142
+ x (tensor): input
143
+
144
+ Returns:
145
+ tensor: interpolated data
146
+ """
147
+
148
+ x = self.interp(
149
+ x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
150
+ )
151
+
152
+ return x
153
+
154
+
155
+ class ResidualConvUnit(nn.Module):
156
+ """Residual convolution module.
157
+ """
158
+
159
+ def __init__(self, features):
160
+ """Init.
161
+
162
+ Args:
163
+ features (int): number of features
164
+ """
165
+ super().__init__()
166
+
167
+ self.conv1 = nn.Conv2d(
168
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
169
+ )
170
+
171
+ self.conv2 = nn.Conv2d(
172
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
173
+ )
174
+
175
+ self.relu = nn.ReLU(inplace=True)
176
+
177
+ def forward(self, x):
178
+ """Forward pass.
179
+
180
+ Args:
181
+ x (tensor): input
182
+
183
+ Returns:
184
+ tensor: output
185
+ """
186
+ out = self.relu(x)
187
+ out = self.conv1(out)
188
+ out = self.relu(out)
189
+ out = self.conv2(out)
190
+
191
+ return out + x
192
+
193
+
194
+ class FeatureFusionBlock(nn.Module):
195
+ """Feature fusion block.
196
+ """
197
+
198
+ def __init__(self, features):
199
+ """Init.
200
+
201
+ Args:
202
+ features (int): number of features
203
+ """
204
+ super(FeatureFusionBlock, self).__init__()
205
+
206
+ self.resConfUnit1 = ResidualConvUnit(features)
207
+ self.resConfUnit2 = ResidualConvUnit(features)
208
+
209
+ def forward(self, *xs):
210
+ """Forward pass.
211
+
212
+ Returns:
213
+ tensor: output
214
+ """
215
+ output = xs[0]
216
+
217
+ if len(xs) == 2:
218
+ output += self.resConfUnit1(xs[1])
219
+
220
+ output = self.resConfUnit2(output)
221
+
222
+ output = nn.functional.interpolate(
223
+ output, scale_factor=2, mode="bilinear", align_corners=True
224
+ )
225
+
226
+ return output
227
+
228
+
229
+
230
+
231
+ class ResidualConvUnit_custom(nn.Module):
232
+ """Residual convolution module.
233
+ """
234
+
235
+ def __init__(self, features, activation, bn):
236
+ """Init.
237
+
238
+ Args:
239
+ features (int): number of features
240
+ """
241
+ super().__init__()
242
+
243
+ self.bn = bn
244
+
245
+ self.groups=1
246
+
247
+ self.conv1 = nn.Conv2d(
248
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
249
+ )
250
+
251
+ self.conv2 = nn.Conv2d(
252
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
253
+ )
254
+
255
+ if self.bn==True:
256
+ self.bn1 = nn.BatchNorm2d(features)
257
+ self.bn2 = nn.BatchNorm2d(features)
258
+
259
+ self.activation = activation
260
+
261
+ self.skip_add = nn.quantized.FloatFunctional()
262
+
263
+ def forward(self, x):
264
+ """Forward pass.
265
+
266
+ Args:
267
+ x (tensor): input
268
+
269
+ Returns:
270
+ tensor: output
271
+ """
272
+
273
+ out = self.activation(x)
274
+ out = self.conv1(out)
275
+ if self.bn==True:
276
+ out = self.bn1(out)
277
+
278
+ out = self.activation(out)
279
+ out = self.conv2(out)
280
+ if self.bn==True:
281
+ out = self.bn2(out)
282
+
283
+ if self.groups > 1:
284
+ out = self.conv_merge(out)
285
+
286
+ return self.skip_add.add(out, x)
287
+
288
+ # return out + x
289
+
290
+
291
+ class FeatureFusionBlock_custom(nn.Module):
292
+ """Feature fusion block.
293
+ """
294
+
295
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
296
+ """Init.
297
+
298
+ Args:
299
+ features (int): number of features
300
+ """
301
+ super(FeatureFusionBlock_custom, self).__init__()
302
+
303
+ self.deconv = deconv
304
+ self.align_corners = align_corners
305
+
306
+ self.groups=1
307
+
308
+ self.expand = expand
309
+ out_features = features
310
+ if self.expand==True:
311
+ out_features = features//2
312
+
313
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
314
+
315
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
316
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
317
+
318
+ self.skip_add = nn.quantized.FloatFunctional()
319
+
320
+ def forward(self, *xs):
321
+ """Forward pass.
322
+
323
+ Returns:
324
+ tensor: output
325
+ """
326
+ output = xs[0]
327
+
328
+ if len(xs) == 2:
329
+ res = self.resConfUnit1(xs[1])
330
+ output = self.skip_add.add(output, res)
331
+ # output += res
332
+
333
+ output = self.resConfUnit2(output)
334
+
335
+ output = nn.functional.interpolate(
336
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
337
+ )
338
+
339
+ output = self.out_conv(output)
340
+
341
+ return output
342
+
external/midas/dpt_depth.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .base_model import BaseModel
6
+ from .blocks import (
7
+ FeatureFusionBlock,
8
+ FeatureFusionBlock_custom,
9
+ Interpolate,
10
+ _make_encoder,
11
+ forward_vit,
12
+ )
13
+
14
+
15
+ def _make_fusion_block(features, use_bn):
16
+ return FeatureFusionBlock_custom(
17
+ features,
18
+ nn.ReLU(False),
19
+ deconv=False,
20
+ bn=use_bn,
21
+ expand=False,
22
+ align_corners=True,
23
+ )
24
+
25
+
26
+ class DPT(BaseModel):
27
+ def __init__(
28
+ self,
29
+ head,
30
+ features=256,
31
+ backbone="vitb_rn50_384",
32
+ readout="project",
33
+ channels_last=False,
34
+ use_bn=False,
35
+ ):
36
+
37
+ super(DPT, self).__init__()
38
+
39
+ self.channels_last = channels_last
40
+
41
+ hooks = {
42
+ "vitb_rn50_384": [0, 1, 8, 11],
43
+ "vitb16_384": [2, 5, 8, 11],
44
+ "vitl16_384": [5, 11, 17, 23],
45
+ }
46
+
47
+ # Instantiate backbone and reassemble blocks
48
+ self.pretrained, self.scratch = _make_encoder(
49
+ backbone,
50
+ features,
51
+ False, # Set to true of you want to train from scratch, uses ImageNet weights
52
+ groups=1,
53
+ expand=False,
54
+ exportable=False,
55
+ hooks=hooks[backbone],
56
+ use_readout=readout,
57
+ )
58
+
59
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
60
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
61
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
62
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
63
+
64
+ self.scratch.output_conv = head
65
+
66
+
67
+ def forward(self, x):
68
+ if self.channels_last == True:
69
+ x.contiguous(memory_format=torch.channels_last)
70
+
71
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
72
+
73
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
74
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
75
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
76
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
77
+
78
+ path_4 = self.scratch.refinenet4(layer_4_rn)
79
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
80
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
81
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
82
+
83
+ out = self.scratch.output_conv(path_1)
84
+
85
+ return out
86
+
87
+
88
+ class DPTDepthModel(DPT):
89
+ def __init__(self, path=None, non_negative=True, **kwargs):
90
+ features = kwargs["features"] if "features" in kwargs else 256
91
+
92
+ head = nn.Sequential(
93
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
94
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
95
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
96
+ nn.ReLU(True),
97
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
98
+ nn.ReLU(True) if non_negative else nn.Identity(),
99
+ nn.Identity(),
100
+ )
101
+
102
+ super().__init__(head, **kwargs)
103
+
104
+ if path is not None:
105
+ self.load(path)
106
+
107
+ def forward(self, x):
108
+ return super().forward(x).squeeze(dim=1)
109
+
external/midas/midas_net.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2
+ This file contains code that is adapted from
3
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .base_model import BaseModel
9
+ from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10
+
11
+
12
+ class MidasNet(BaseModel):
13
+ """Network for monocular depth estimation.
14
+ """
15
+
16
+ def __init__(self, path=None, features=256, non_negative=True):
17
+ """Init.
18
+
19
+ Args:
20
+ path (str, optional): Path to saved model. Defaults to None.
21
+ features (int, optional): Number of features. Defaults to 256.
22
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
23
+ """
24
+ print("Loading weights: ", path)
25
+
26
+ super(MidasNet, self).__init__()
27
+
28
+ use_pretrained = False if path is None else True
29
+
30
+ self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
31
+
32
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
33
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
34
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
35
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
36
+
37
+ self.scratch.output_conv = nn.Sequential(
38
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
39
+ Interpolate(scale_factor=2, mode="bilinear"),
40
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
41
+ nn.ReLU(True),
42
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
43
+ nn.ReLU(True) if non_negative else nn.Identity(),
44
+ )
45
+
46
+ if path:
47
+ self.load(path)
48
+
49
+ def forward(self, x):
50
+ """Forward pass.
51
+
52
+ Args:
53
+ x (tensor): input data (image)
54
+
55
+ Returns:
56
+ tensor: depth
57
+ """
58
+
59
+ layer_1 = self.pretrained.layer1(x)
60
+ layer_2 = self.pretrained.layer2(layer_1)
61
+ layer_3 = self.pretrained.layer3(layer_2)
62
+ layer_4 = self.pretrained.layer4(layer_3)
63
+
64
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
65
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
66
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
67
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
68
+
69
+ path_4 = self.scratch.refinenet4(layer_4_rn)
70
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
71
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
72
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
73
+
74
+ out = self.scratch.output_conv(path_1)
75
+
76
+ return torch.squeeze(out, dim=1)
external/midas/midas_net_custom.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2
+ This file contains code that is adapted from
3
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .base_model import BaseModel
9
+ from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
10
+
11
+
12
+ class MidasNet_small(BaseModel):
13
+ """Network for monocular depth estimation.
14
+ """
15
+
16
+ def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
17
+ blocks={'expand': True}):
18
+ """Init.
19
+
20
+ Args:
21
+ path (str, optional): Path to saved model. Defaults to None.
22
+ features (int, optional): Number of features. Defaults to 256.
23
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
24
+ """
25
+ print("Loading weights: ", path)
26
+
27
+ super(MidasNet_small, self).__init__()
28
+
29
+ use_pretrained = False if path else True
30
+
31
+ self.channels_last = channels_last
32
+ self.blocks = blocks
33
+ self.backbone = backbone
34
+
35
+ self.groups = 1
36
+
37
+ features1=features
38
+ features2=features
39
+ features3=features
40
+ features4=features
41
+ self.expand = False
42
+ if "expand" in self.blocks and self.blocks['expand'] == True:
43
+ self.expand = True
44
+ features1=features
45
+ features2=features*2
46
+ features3=features*4
47
+ features4=features*8
48
+
49
+ self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
50
+
51
+ self.scratch.activation = nn.ReLU(False)
52
+
53
+ self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
54
+ self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
55
+ self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
56
+ self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
57
+
58
+
59
+ self.scratch.output_conv = nn.Sequential(
60
+ nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
61
+ Interpolate(scale_factor=2, mode="bilinear"),
62
+ nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
63
+ self.scratch.activation,
64
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
65
+ nn.ReLU(True) if non_negative else nn.Identity(),
66
+ nn.Identity(),
67
+ )
68
+
69
+ if path:
70
+ self.load(path)
71
+
72
+
73
+ def forward(self, x):
74
+ """Forward pass.
75
+
76
+ Args:
77
+ x (tensor): input data (image)
78
+
79
+ Returns:
80
+ tensor: depth
81
+ """
82
+ if self.channels_last==True:
83
+ print("self.channels_last = ", self.channels_last)
84
+ x.contiguous(memory_format=torch.channels_last)
85
+
86
+
87
+ layer_1 = self.pretrained.layer1(x)
88
+ layer_2 = self.pretrained.layer2(layer_1)
89
+ layer_3 = self.pretrained.layer3(layer_2)
90
+ layer_4 = self.pretrained.layer4(layer_3)
91
+
92
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
93
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
94
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
95
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
96
+
97
+
98
+ path_4 = self.scratch.refinenet4(layer_4_rn)
99
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
100
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
101
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
102
+
103
+ out = self.scratch.output_conv(path_1)
104
+
105
+ return torch.squeeze(out, dim=1)
106
+
107
+
108
+
109
+ def fuse_model(m):
110
+ prev_previous_type = nn.Identity()
111
+ prev_previous_name = ''
112
+ previous_type = nn.Identity()
113
+ previous_name = ''
114
+ for name, module in m.named_modules():
115
+ if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
116
+ # print("FUSED ", prev_previous_name, previous_name, name)
117
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
118
+ elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
119
+ # print("FUSED ", prev_previous_name, previous_name)
120
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
121
+ # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
122
+ # print("FUSED ", previous_name, name)
123
+ # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
124
+
125
+ prev_previous_type = previous_type
126
+ prev_previous_name = previous_name
127
+ previous_type = type(module)
128
+ previous_name = name
external/midas/transforms.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import math
4
+
5
+
6
+ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
8
+
9
+ Args:
10
+ sample (dict): sample
11
+ size (tuple): image size
12
+
13
+ Returns:
14
+ tuple: new size
15
+ """
16
+ shape = list(sample["disparity"].shape)
17
+
18
+ if shape[0] >= size[0] and shape[1] >= size[1]:
19
+ return sample
20
+
21
+ scale = [0, 0]
22
+ scale[0] = size[0] / shape[0]
23
+ scale[1] = size[1] / shape[1]
24
+
25
+ scale = max(scale)
26
+
27
+ shape[0] = math.ceil(scale * shape[0])
28
+ shape[1] = math.ceil(scale * shape[1])
29
+
30
+ # resize
31
+ sample["image"] = cv2.resize(
32
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33
+ )
34
+
35
+ sample["disparity"] = cv2.resize(
36
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37
+ )
38
+ sample["mask"] = cv2.resize(
39
+ sample["mask"].astype(np.float32),
40
+ tuple(shape[::-1]),
41
+ interpolation=cv2.INTER_NEAREST,
42
+ )
43
+ sample["mask"] = sample["mask"].astype(bool)
44
+
45
+ return tuple(shape)
46
+
47
+
48
+ class Resize(object):
49
+ """Resize sample to given size (width, height).
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ width,
55
+ height,
56
+ resize_target=True,
57
+ keep_aspect_ratio=False,
58
+ ensure_multiple_of=1,
59
+ resize_method="lower_bound",
60
+ image_interpolation_method=cv2.INTER_AREA,
61
+ ):
62
+ """Init.
63
+
64
+ Args:
65
+ width (int): desired output width
66
+ height (int): desired output height
67
+ resize_target (bool, optional):
68
+ True: Resize the full sample (image, mask, target).
69
+ False: Resize image only.
70
+ Defaults to True.
71
+ keep_aspect_ratio (bool, optional):
72
+ True: Keep the aspect ratio of the input sample.
73
+ Output sample might not have the given width and height, and
74
+ resize behaviour depends on the parameter 'resize_method'.
75
+ Defaults to False.
76
+ ensure_multiple_of (int, optional):
77
+ Output width and height is constrained to be multiple of this parameter.
78
+ Defaults to 1.
79
+ resize_method (str, optional):
80
+ "lower_bound": Output will be at least as large as the given size.
81
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
82
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
83
+ Defaults to "lower_bound".
84
+ """
85
+ self.__width = width
86
+ self.__height = height
87
+
88
+ self.__resize_target = resize_target
89
+ self.__keep_aspect_ratio = keep_aspect_ratio
90
+ self.__multiple_of = ensure_multiple_of
91
+ self.__resize_method = resize_method
92
+ self.__image_interpolation_method = image_interpolation_method
93
+
94
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
95
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
96
+
97
+ if max_val is not None and y > max_val:
98
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
99
+
100
+ if y < min_val:
101
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
102
+
103
+ return y
104
+
105
+ def get_size(self, width, height):
106
+ # determine new height and width
107
+ scale_height = self.__height / height
108
+ scale_width = self.__width / width
109
+
110
+ if self.__keep_aspect_ratio:
111
+ if self.__resize_method == "lower_bound":
112
+ # scale such that output size is lower bound
113
+ if scale_width > scale_height:
114
+ # fit width
115
+ scale_height = scale_width
116
+ else:
117
+ # fit height
118
+ scale_width = scale_height
119
+ elif self.__resize_method == "upper_bound":
120
+ # scale such that output size is upper bound
121
+ if scale_width < scale_height:
122
+ # fit width
123
+ scale_height = scale_width
124
+ else:
125
+ # fit height
126
+ scale_width = scale_height
127
+ elif self.__resize_method == "minimal":
128
+ # scale as least as possbile
129
+ if abs(1 - scale_width) < abs(1 - scale_height):
130
+ # fit width
131
+ scale_height = scale_width
132
+ else:
133
+ # fit height
134
+ scale_width = scale_height
135
+ else:
136
+ raise ValueError(
137
+ f"resize_method {self.__resize_method} not implemented"
138
+ )
139
+
140
+ if self.__resize_method == "lower_bound":
141
+ new_height = self.constrain_to_multiple_of(
142
+ scale_height * height, min_val=self.__height
143
+ )
144
+ new_width = self.constrain_to_multiple_of(
145
+ scale_width * width, min_val=self.__width
146
+ )
147
+ elif self.__resize_method == "upper_bound":
148
+ new_height = self.constrain_to_multiple_of(
149
+ scale_height * height, max_val=self.__height
150
+ )
151
+ new_width = self.constrain_to_multiple_of(
152
+ scale_width * width, max_val=self.__width
153
+ )
154
+ elif self.__resize_method == "minimal":
155
+ new_height = self.constrain_to_multiple_of(scale_height * height)
156
+ new_width = self.constrain_to_multiple_of(scale_width * width)
157
+ else:
158
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
159
+
160
+ return (new_width, new_height)
161
+
162
+ def __call__(self, sample):
163
+ width, height = self.get_size(
164
+ sample["image"].shape[1], sample["image"].shape[0]
165
+ )
166
+
167
+ # resize sample
168
+ sample["image"] = cv2.resize(
169
+ sample["image"],
170
+ (width, height),
171
+ interpolation=self.__image_interpolation_method,
172
+ )
173
+
174
+ if self.__resize_target:
175
+ if "disparity" in sample:
176
+ sample["disparity"] = cv2.resize(
177
+ sample["disparity"],
178
+ (width, height),
179
+ interpolation=cv2.INTER_NEAREST,
180
+ )
181
+
182
+ if "depth" in sample:
183
+ sample["depth"] = cv2.resize(
184
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
185
+ )
186
+
187
+ sample["mask"] = cv2.resize(
188
+ sample["mask"].astype(np.float32),
189
+ (width, height),
190
+ interpolation=cv2.INTER_NEAREST,
191
+ )
192
+ sample["mask"] = sample["mask"].astype(bool)
193
+
194
+ return sample
195
+
196
+
197
+ class NormalizeImage(object):
198
+ """Normlize image by given mean and std.
199
+ """
200
+
201
+ def __init__(self, mean, std):
202
+ self.__mean = mean
203
+ self.__std = std
204
+
205
+ def __call__(self, sample):
206
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
207
+
208
+ return sample
209
+
210
+
211
+ class PrepareForNet(object):
212
+ """Prepare sample for usage as network input.
213
+ """
214
+
215
+ def __init__(self):
216
+ pass
217
+
218
+ def __call__(self, sample):
219
+ image = np.transpose(sample["image"], (2, 0, 1))
220
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
221
+
222
+ if "mask" in sample:
223
+ sample["mask"] = sample["mask"].astype(np.float32)
224
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
225
+
226
+ if "disparity" in sample:
227
+ disparity = sample["disparity"].astype(np.float32)
228
+ sample["disparity"] = np.ascontiguousarray(disparity)
229
+
230
+ if "depth" in sample:
231
+ depth = sample["depth"].astype(np.float32)
232
+ sample["depth"] = np.ascontiguousarray(depth)
233
+
234
+ return sample
external/midas/vit.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+ import types
5
+ import math
6
+ import torch.nn.functional as F
7
+
8
+
9
+ class Slice(nn.Module):
10
+ def __init__(self, start_index=1):
11
+ super(Slice, self).__init__()
12
+ self.start_index = start_index
13
+
14
+ def forward(self, x):
15
+ return x[:, self.start_index :]
16
+
17
+
18
+ class AddReadout(nn.Module):
19
+ def __init__(self, start_index=1):
20
+ super(AddReadout, self).__init__()
21
+ self.start_index = start_index
22
+
23
+ def forward(self, x):
24
+ if self.start_index == 2:
25
+ readout = (x[:, 0] + x[:, 1]) / 2
26
+ else:
27
+ readout = x[:, 0]
28
+ return x[:, self.start_index :] + readout.unsqueeze(1)
29
+
30
+
31
+ class ProjectReadout(nn.Module):
32
+ def __init__(self, in_features, start_index=1):
33
+ super(ProjectReadout, self).__init__()
34
+ self.start_index = start_index
35
+
36
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
37
+
38
+ def forward(self, x):
39
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
40
+ features = torch.cat((x[:, self.start_index :], readout), -1)
41
+
42
+ return self.project(features)
43
+
44
+
45
+ class Transpose(nn.Module):
46
+ def __init__(self, dim0, dim1):
47
+ super(Transpose, self).__init__()
48
+ self.dim0 = dim0
49
+ self.dim1 = dim1
50
+
51
+ def forward(self, x):
52
+ x = x.transpose(self.dim0, self.dim1)
53
+ return x
54
+
55
+
56
+ def forward_vit(pretrained, x):
57
+ b, c, h, w = x.shape
58
+
59
+ glob = pretrained.model.forward_flex(x)
60
+
61
+ layer_1 = pretrained.activations["1"]
62
+ layer_2 = pretrained.activations["2"]
63
+ layer_3 = pretrained.activations["3"]
64
+ layer_4 = pretrained.activations["4"]
65
+
66
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
67
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
68
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
69
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
70
+
71
+ unflatten = nn.Sequential(
72
+ nn.Unflatten(
73
+ 2,
74
+ torch.Size(
75
+ [
76
+ h // pretrained.model.patch_size[1],
77
+ w // pretrained.model.patch_size[0],
78
+ ]
79
+ ),
80
+ )
81
+ )
82
+
83
+ if layer_1.ndim == 3:
84
+ layer_1 = unflatten(layer_1)
85
+ if layer_2.ndim == 3:
86
+ layer_2 = unflatten(layer_2)
87
+ if layer_3.ndim == 3:
88
+ layer_3 = unflatten(layer_3)
89
+ if layer_4.ndim == 3:
90
+ layer_4 = unflatten(layer_4)
91
+
92
+ layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
93
+ layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
94
+ layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
95
+ layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
96
+
97
+ return layer_1, layer_2, layer_3, layer_4
98
+
99
+
100
+ def _resize_pos_embed(self, posemb, gs_h, gs_w):
101
+ posemb_tok, posemb_grid = (
102
+ posemb[:, : self.start_index],
103
+ posemb[0, self.start_index :],
104
+ )
105
+
106
+ gs_old = int(math.sqrt(len(posemb_grid)))
107
+
108
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
109
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
110
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
111
+
112
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
113
+
114
+ return posemb
115
+
116
+
117
+ def forward_flex(self, x):
118
+ b, c, h, w = x.shape
119
+
120
+ pos_embed = self._resize_pos_embed(
121
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
122
+ )
123
+
124
+ B = x.shape[0]
125
+
126
+ if hasattr(self.patch_embed, "backbone"):
127
+ x = self.patch_embed.backbone(x)
128
+ if isinstance(x, (list, tuple)):
129
+ x = x[-1] # last feature if backbone outputs list/tuple of features
130
+
131
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
132
+
133
+ if getattr(self, "dist_token", None) is not None:
134
+ cls_tokens = self.cls_token.expand(
135
+ B, -1, -1
136
+ ) # stole cls_tokens impl from Phil Wang, thanks
137
+ dist_token = self.dist_token.expand(B, -1, -1)
138
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
139
+ else:
140
+ cls_tokens = self.cls_token.expand(
141
+ B, -1, -1
142
+ ) # stole cls_tokens impl from Phil Wang, thanks
143
+ x = torch.cat((cls_tokens, x), dim=1)
144
+
145
+ x = x + pos_embed
146
+ x = self.pos_drop(x)
147
+
148
+ for blk in self.blocks:
149
+ x = blk(x)
150
+
151
+ x = self.norm(x)
152
+
153
+ return x
154
+
155
+
156
+ activations = {}
157
+
158
+
159
+ def get_activation(name):
160
+ def hook(model, input, output):
161
+ activations[name] = output
162
+
163
+ return hook
164
+
165
+
166
+ def get_readout_oper(vit_features, features, use_readout, start_index=1):
167
+ if use_readout == "ignore":
168
+ readout_oper = [Slice(start_index)] * len(features)
169
+ elif use_readout == "add":
170
+ readout_oper = [AddReadout(start_index)] * len(features)
171
+ elif use_readout == "project":
172
+ readout_oper = [
173
+ ProjectReadout(vit_features, start_index) for out_feat in features
174
+ ]
175
+ else:
176
+ assert (
177
+ False
178
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
179
+
180
+ return readout_oper
181
+
182
+
183
+ def _make_vit_b16_backbone(
184
+ model,
185
+ features=[96, 192, 384, 768],
186
+ size=[384, 384],
187
+ hooks=[2, 5, 8, 11],
188
+ vit_features=768,
189
+ use_readout="ignore",
190
+ start_index=1,
191
+ ):
192
+ pretrained = nn.Module()
193
+
194
+ pretrained.model = model
195
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
196
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
197
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
198
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
199
+
200
+ pretrained.activations = activations
201
+
202
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
203
+
204
+ # 32, 48, 136, 384
205
+ pretrained.act_postprocess1 = nn.Sequential(
206
+ readout_oper[0],
207
+ Transpose(1, 2),
208
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
209
+ nn.Conv2d(
210
+ in_channels=vit_features,
211
+ out_channels=features[0],
212
+ kernel_size=1,
213
+ stride=1,
214
+ padding=0,
215
+ ),
216
+ nn.ConvTranspose2d(
217
+ in_channels=features[0],
218
+ out_channels=features[0],
219
+ kernel_size=4,
220
+ stride=4,
221
+ padding=0,
222
+ bias=True,
223
+ dilation=1,
224
+ groups=1,
225
+ ),
226
+ )
227
+
228
+ pretrained.act_postprocess2 = nn.Sequential(
229
+ readout_oper[1],
230
+ Transpose(1, 2),
231
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
232
+ nn.Conv2d(
233
+ in_channels=vit_features,
234
+ out_channels=features[1],
235
+ kernel_size=1,
236
+ stride=1,
237
+ padding=0,
238
+ ),
239
+ nn.ConvTranspose2d(
240
+ in_channels=features[1],
241
+ out_channels=features[1],
242
+ kernel_size=2,
243
+ stride=2,
244
+ padding=0,
245
+ bias=True,
246
+ dilation=1,
247
+ groups=1,
248
+ ),
249
+ )
250
+
251
+ pretrained.act_postprocess3 = nn.Sequential(
252
+ readout_oper[2],
253
+ Transpose(1, 2),
254
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
255
+ nn.Conv2d(
256
+ in_channels=vit_features,
257
+ out_channels=features[2],
258
+ kernel_size=1,
259
+ stride=1,
260
+ padding=0,
261
+ ),
262
+ )
263
+
264
+ pretrained.act_postprocess4 = nn.Sequential(
265
+ readout_oper[3],
266
+ Transpose(1, 2),
267
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
268
+ nn.Conv2d(
269
+ in_channels=vit_features,
270
+ out_channels=features[3],
271
+ kernel_size=1,
272
+ stride=1,
273
+ padding=0,
274
+ ),
275
+ nn.Conv2d(
276
+ in_channels=features[3],
277
+ out_channels=features[3],
278
+ kernel_size=3,
279
+ stride=2,
280
+ padding=1,
281
+ ),
282
+ )
283
+
284
+ pretrained.model.start_index = start_index
285
+ pretrained.model.patch_size = [16, 16]
286
+
287
+ # We inject this function into the VisionTransformer instances so that
288
+ # we can use it with interpolated position embeddings without modifying the library source.
289
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
290
+ pretrained.model._resize_pos_embed = types.MethodType(
291
+ _resize_pos_embed, pretrained.model
292
+ )
293
+
294
+ return pretrained
295
+
296
+
297
+ def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
298
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
299
+
300
+ hooks = [5, 11, 17, 23] if hooks == None else hooks
301
+ return _make_vit_b16_backbone(
302
+ model,
303
+ features=[256, 512, 1024, 1024],
304
+ hooks=hooks,
305
+ vit_features=1024,
306
+ use_readout=use_readout,
307
+ )
308
+
309
+
310
+ def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
311
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
312
+
313
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
314
+ return _make_vit_b16_backbone(
315
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
316
+ )
317
+
318
+
319
+ def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
320
+ model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
321
+
322
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
323
+ return _make_vit_b16_backbone(
324
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
325
+ )
326
+
327
+
328
+ def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
329
+ model = timm.create_model(
330
+ "vit_deit_base_distilled_patch16_384", pretrained=pretrained
331
+ )
332
+
333
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
334
+ return _make_vit_b16_backbone(
335
+ model,
336
+ features=[96, 192, 384, 768],
337
+ hooks=hooks,
338
+ use_readout=use_readout,
339
+ start_index=2,
340
+ )
341
+
342
+
343
+ def _make_vit_b_rn50_backbone(
344
+ model,
345
+ features=[256, 512, 768, 768],
346
+ size=[384, 384],
347
+ hooks=[0, 1, 8, 11],
348
+ vit_features=768,
349
+ use_vit_only=False,
350
+ use_readout="ignore",
351
+ start_index=1,
352
+ ):
353
+ pretrained = nn.Module()
354
+
355
+ pretrained.model = model
356
+
357
+ if use_vit_only == True:
358
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
359
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
360
+ else:
361
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
362
+ get_activation("1")
363
+ )
364
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
365
+ get_activation("2")
366
+ )
367
+
368
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
369
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
370
+
371
+ pretrained.activations = activations
372
+
373
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
374
+
375
+ if use_vit_only == True:
376
+ pretrained.act_postprocess1 = nn.Sequential(
377
+ readout_oper[0],
378
+ Transpose(1, 2),
379
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
380
+ nn.Conv2d(
381
+ in_channels=vit_features,
382
+ out_channels=features[0],
383
+ kernel_size=1,
384
+ stride=1,
385
+ padding=0,
386
+ ),
387
+ nn.ConvTranspose2d(
388
+ in_channels=features[0],
389
+ out_channels=features[0],
390
+ kernel_size=4,
391
+ stride=4,
392
+ padding=0,
393
+ bias=True,
394
+ dilation=1,
395
+ groups=1,
396
+ ),
397
+ )
398
+
399
+ pretrained.act_postprocess2 = nn.Sequential(
400
+ readout_oper[1],
401
+ Transpose(1, 2),
402
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
403
+ nn.Conv2d(
404
+ in_channels=vit_features,
405
+ out_channels=features[1],
406
+ kernel_size=1,
407
+ stride=1,
408
+ padding=0,
409
+ ),
410
+ nn.ConvTranspose2d(
411
+ in_channels=features[1],
412
+ out_channels=features[1],
413
+ kernel_size=2,
414
+ stride=2,
415
+ padding=0,
416
+ bias=True,
417
+ dilation=1,
418
+ groups=1,
419
+ ),
420
+ )
421
+ else:
422
+ pretrained.act_postprocess1 = nn.Sequential(
423
+ nn.Identity(), nn.Identity(), nn.Identity()
424
+ )
425
+ pretrained.act_postprocess2 = nn.Sequential(
426
+ nn.Identity(), nn.Identity(), nn.Identity()
427
+ )
428
+
429
+ pretrained.act_postprocess3 = nn.Sequential(
430
+ readout_oper[2],
431
+ Transpose(1, 2),
432
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
433
+ nn.Conv2d(
434
+ in_channels=vit_features,
435
+ out_channels=features[2],
436
+ kernel_size=1,
437
+ stride=1,
438
+ padding=0,
439
+ ),
440
+ )
441
+
442
+ pretrained.act_postprocess4 = nn.Sequential(
443
+ readout_oper[3],
444
+ Transpose(1, 2),
445
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
446
+ nn.Conv2d(
447
+ in_channels=vit_features,
448
+ out_channels=features[3],
449
+ kernel_size=1,
450
+ stride=1,
451
+ padding=0,
452
+ ),
453
+ nn.Conv2d(
454
+ in_channels=features[3],
455
+ out_channels=features[3],
456
+ kernel_size=3,
457
+ stride=2,
458
+ padding=1,
459
+ ),
460
+ )
461
+
462
+ pretrained.model.start_index = start_index
463
+ pretrained.model.patch_size = [16, 16]
464
+
465
+ # We inject this function into the VisionTransformer instances so that
466
+ # we can use it with interpolated position embeddings without modifying the library source.
467
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
468
+
469
+ # We inject this function into the VisionTransformer instances so that
470
+ # we can use it with interpolated position embeddings without modifying the library source.
471
+ pretrained.model._resize_pos_embed = types.MethodType(
472
+ _resize_pos_embed, pretrained.model
473
+ )
474
+
475
+ return pretrained
476
+
477
+
478
+ def _make_pretrained_vitb_rn50_384(
479
+ pretrained, use_readout="ignore", hooks=None, use_vit_only=False
480
+ ):
481
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
482
+
483
+ hooks = [0, 1, 8, 11] if hooks == None else hooks
484
+ return _make_vit_b_rn50_backbone(
485
+ model,
486
+ features=[256, 512, 768, 768],
487
+ size=[384, 384],
488
+ hooks=hooks,
489
+ use_vit_only=use_vit_only,
490
+ use_readout=use_readout,
491
+ )
external/realesrgan/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+ from .archs import *
3
+ from .data import *
4
+ from .models import *
5
+ from .utils import *
6
+ #from .version import *
external/realesrgan/archs/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ from basicsr.utils import scandir
3
+ from os import path as osp
4
+
5
+ # automatically scan and import arch modules for registry
6
+ # scan all the files that end with '_arch.py' under the archs folder
7
+ arch_folder = osp.dirname(osp.abspath(__file__))
8
+ arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
9
+ # import all the arch modules
10
+ _arch_modules = [importlib.import_module(f'realesrgan.archs.{file_name}') for file_name in arch_filenames]
external/realesrgan/archs/discriminator_arch.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from basicsr.utils.registry import ARCH_REGISTRY
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+ from torch.nn.utils import spectral_norm
5
+
6
+
7
+ @ARCH_REGISTRY.register()
8
+ class UNetDiscriminatorSN(nn.Module):
9
+ """Defines a U-Net discriminator with spectral normalization (SN)
10
+
11
+ It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
12
+
13
+ Arg:
14
+ num_in_ch (int): Channel number of inputs. Default: 3.
15
+ num_feat (int): Channel number of base intermediate features. Default: 64.
16
+ skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
17
+ """
18
+
19
+ def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
20
+ super(UNetDiscriminatorSN, self).__init__()
21
+ self.skip_connection = skip_connection
22
+ norm = spectral_norm
23
+ # the first convolution
24
+ self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
25
+ # downsample
26
+ self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
27
+ self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
28
+ self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
29
+ # upsample
30
+ self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
31
+ self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
32
+ self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
33
+ # extra convolutions
34
+ self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
35
+ self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
36
+ self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
37
+
38
+ def forward(self, x):
39
+ # downsample
40
+ x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
41
+ x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
42
+ x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
43
+ x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
44
+
45
+ # upsample
46
+ x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
47
+ x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
48
+
49
+ if self.skip_connection:
50
+ x4 = x4 + x2
51
+ x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
52
+ x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
53
+
54
+ if self.skip_connection:
55
+ x5 = x5 + x1
56
+ x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
57
+ x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
58
+
59
+ if self.skip_connection:
60
+ x6 = x6 + x0
61
+
62
+ # extra convolutions
63
+ out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
64
+ out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
65
+ out = self.conv9(out)
66
+
67
+ return out
external/realesrgan/archs/srvgg_arch.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from basicsr.utils.registry import ARCH_REGISTRY
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+
5
+
6
+ @ARCH_REGISTRY.register()
7
+ class SRVGGNetCompact(nn.Module):
8
+ """A compact VGG-style network structure for super-resolution.
9
+
10
+ It is a compact network structure, which performs upsampling in the last layer and no convolution is
11
+ conducted on the HR feature space.
12
+
13
+ Args:
14
+ num_in_ch (int): Channel number of inputs. Default: 3.
15
+ num_out_ch (int): Channel number of outputs. Default: 3.
16
+ num_feat (int): Channel number of intermediate features. Default: 64.
17
+ num_conv (int): Number of convolution layers in the body network. Default: 16.
18
+ upscale (int): Upsampling factor. Default: 4.
19
+ act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
20
+ """
21
+
22
+ def __init__(self, num_in_ch = 3, num_out_ch = 3, num_feat = 64, num_conv = 16, upscale = 4, act_type = 'prelu'):
23
+ super(SRVGGNetCompact, self).__init__()
24
+ self.num_in_ch = num_in_ch
25
+ self.num_out_ch = num_out_ch
26
+ self.num_feat = num_feat
27
+ self.num_conv = num_conv
28
+ self.upscale = upscale
29
+ self.act_type = act_type
30
+
31
+ self.body = nn.ModuleList()
32
+ # the first conv
33
+ self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
34
+ # the first activation
35
+ if act_type == 'relu':
36
+ activation = nn.ReLU(inplace = True)
37
+ elif act_type == 'prelu':
38
+ activation = nn.PReLU(num_parameters = num_feat)
39
+ elif act_type == 'leakyrelu':
40
+ activation = nn.LeakyReLU(negative_slope = 0.1, inplace = True)
41
+ self.body.append(activation)
42
+
43
+ # the body structure
44
+ for _ in range(num_conv):
45
+ self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
46
+ # activation
47
+ if act_type == 'relu':
48
+ activation = nn.ReLU(inplace = True)
49
+ elif act_type == 'prelu':
50
+ activation = nn.PReLU(num_parameters = num_feat)
51
+ elif act_type == 'leakyrelu':
52
+ activation = nn.LeakyReLU(negative_slope = 0.1, inplace = True)
53
+ self.body.append(activation)
54
+
55
+ # the last conv
56
+ self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
57
+ # upsample
58
+ self.upsampler = nn.PixelShuffle(upscale)
59
+
60
+ def forward(self, x):
61
+ out = x
62
+ for i in range(0, len(self.body)):
63
+ out = self.body[i](out)
64
+
65
+ out = self.upsampler(out)
66
+ # add the nearest upsampled image, so that the network learns the residual
67
+ base = F.interpolate(x, scale_factor = self.upscale, mode = 'nearest')
68
+ out += base
69
+ return out
external/realesrgan/data/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ from basicsr.utils import scandir
3
+ from os import path as osp
4
+
5
+ # automatically scan and import dataset modules for registry
6
+ # scan all the files that end with '_dataset.py' under the data folder
7
+ data_folder = osp.dirname(osp.abspath(__file__))
8
+ dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
9
+ # import all the dataset modules
10
+ _dataset_modules = [importlib.import_module(f'realesrgan.data.{file_name}') for file_name in dataset_filenames]
external/realesrgan/data/realesrgan_dataset.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import os
5
+ import os.path as osp
6
+ import random
7
+ import time
8
+ import torch
9
+ from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
10
+ from basicsr.data.transforms import augment
11
+ from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
12
+ from basicsr.utils.registry import DATASET_REGISTRY
13
+ from torch.utils import data as data
14
+
15
+
16
+ @DATASET_REGISTRY.register()
17
+ class RealESRGANDataset(data.Dataset):
18
+ """Dataset used for Real-ESRGAN model:
19
+ Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
20
+
21
+ It loads gt (Ground-Truth) images, and augments them.
22
+ It also generates blur kernels and sinc kernels for generating low-quality images.
23
+ Note that the low-quality images are processed in tensors on GPUS for faster processing.
24
+
25
+ Args:
26
+ opt (dict): Config for train datasets. It contains the following keys:
27
+ dataroot_gt (str): Data root path for gt.
28
+ meta_info (str): Path for meta information file.
29
+ io_backend (dict): IO backend type and other kwarg.
30
+ use_hflip (bool): Use horizontal flips.
31
+ use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
32
+ Please see more options in the codes.
33
+ """
34
+
35
+ def __init__(self, opt):
36
+ super(RealESRGANDataset, self).__init__()
37
+ self.opt = opt
38
+ self.file_client = None
39
+ self.io_backend_opt = opt['io_backend']
40
+ self.gt_folder = opt['dataroot_gt']
41
+
42
+ # file client (lmdb io backend)
43
+ if self.io_backend_opt['type'] == 'lmdb':
44
+ self.io_backend_opt['db_paths'] = [self.gt_folder]
45
+ self.io_backend_opt['client_keys'] = ['gt']
46
+ if not self.gt_folder.endswith('.lmdb'):
47
+ raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
48
+ with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
49
+ self.paths = [line.split('.')[0] for line in fin]
50
+ else:
51
+ # disk backend with meta_info
52
+ # Each line in the meta_info describes the relative path to an image
53
+ with open(self.opt['meta_info']) as fin:
54
+ paths = [line.strip().split(' ')[0] for line in fin]
55
+ self.paths = [os.path.join(self.gt_folder, v) for v in paths]
56
+
57
+ # blur settings for the first degradation
58
+ self.blur_kernel_size = opt['blur_kernel_size']
59
+ self.kernel_list = opt['kernel_list']
60
+ self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability
61
+ self.blur_sigma = opt['blur_sigma']
62
+ self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels
63
+ self.betap_range = opt['betap_range'] # betap used in plateau blur kernels
64
+ self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters
65
+
66
+ # blur settings for the second degradation
67
+ self.blur_kernel_size2 = opt['blur_kernel_size2']
68
+ self.kernel_list2 = opt['kernel_list2']
69
+ self.kernel_prob2 = opt['kernel_prob2']
70
+ self.blur_sigma2 = opt['blur_sigma2']
71
+ self.betag_range2 = opt['betag_range2']
72
+ self.betap_range2 = opt['betap_range2']
73
+ self.sinc_prob2 = opt['sinc_prob2']
74
+
75
+ # a final sinc filter
76
+ self.final_sinc_prob = opt['final_sinc_prob']
77
+
78
+ self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
79
+ # TODO: kernel range is now hard-coded, should be in the configure file
80
+ self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
81
+ self.pulse_tensor[10, 10] = 1
82
+
83
+ def __getitem__(self, index):
84
+ if self.file_client is None:
85
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
86
+
87
+ # -------------------------------- Load gt images -------------------------------- #
88
+ # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
89
+ gt_path = self.paths[index]
90
+ # avoid errors caused by high latency in reading files
91
+ retry = 3
92
+ while retry > 0:
93
+ try:
94
+ img_bytes = self.file_client.get(gt_path, 'gt')
95
+ except (IOError, OSError) as e:
96
+ logger = get_root_logger()
97
+ logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}')
98
+ # change another file to read
99
+ index = random.randint(0, self.__len__())
100
+ gt_path = self.paths[index]
101
+ time.sleep(1) # sleep 1s for occasional server congestion
102
+ else:
103
+ break
104
+ finally:
105
+ retry -= 1
106
+ img_gt = imfrombytes(img_bytes, float32=True)
107
+
108
+ # -------------------- Do augmentation for training: flip, rotation -------------------- #
109
+ img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])
110
+
111
+ # crop or pad to 400
112
+ # TODO: 400 is hard-coded. You may change it accordingly
113
+ h, w = img_gt.shape[0:2]
114
+ crop_pad_size = 400
115
+ # pad
116
+ if h < crop_pad_size or w < crop_pad_size:
117
+ pad_h = max(0, crop_pad_size - h)
118
+ pad_w = max(0, crop_pad_size - w)
119
+ img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
120
+ # crop
121
+ if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
122
+ h, w = img_gt.shape[0:2]
123
+ # randomly choose top and left coordinates
124
+ top = random.randint(0, h - crop_pad_size)
125
+ left = random.randint(0, w - crop_pad_size)
126
+ img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]
127
+
128
+ # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
129
+ kernel_size = random.choice(self.kernel_range)
130
+ if np.random.uniform() < self.opt['sinc_prob']:
131
+ # this sinc filter setting is for kernels ranging from [7, 21]
132
+ if kernel_size < 13:
133
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
134
+ else:
135
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
136
+ kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
137
+ else:
138
+ kernel = random_mixed_kernels(
139
+ self.kernel_list,
140
+ self.kernel_prob,
141
+ kernel_size,
142
+ self.blur_sigma,
143
+ self.blur_sigma, [-math.pi, math.pi],
144
+ self.betag_range,
145
+ self.betap_range,
146
+ noise_range=None)
147
+ # pad kernel
148
+ pad_size = (21 - kernel_size) // 2
149
+ kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
150
+
151
+ # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
152
+ kernel_size = random.choice(self.kernel_range)
153
+ if np.random.uniform() < self.opt['sinc_prob2']:
154
+ if kernel_size < 13:
155
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
156
+ else:
157
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
158
+ kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
159
+ else:
160
+ kernel2 = random_mixed_kernels(
161
+ self.kernel_list2,
162
+ self.kernel_prob2,
163
+ kernel_size,
164
+ self.blur_sigma2,
165
+ self.blur_sigma2, [-math.pi, math.pi],
166
+ self.betag_range2,
167
+ self.betap_range2,
168
+ noise_range=None)
169
+
170
+ # pad kernel
171
+ pad_size = (21 - kernel_size) // 2
172
+ kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
173
+
174
+ # ------------------------------------- the final sinc kernel ------------------------------------- #
175
+ if np.random.uniform() < self.opt['final_sinc_prob']:
176
+ kernel_size = random.choice(self.kernel_range)
177
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
178
+ sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
179
+ sinc_kernel = torch.FloatTensor(sinc_kernel)
180
+ else:
181
+ sinc_kernel = self.pulse_tensor
182
+
183
+ # BGR to RGB, HWC to CHW, numpy to tensor
184
+ img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
185
+ kernel = torch.FloatTensor(kernel)
186
+ kernel2 = torch.FloatTensor(kernel2)
187
+
188
+ return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
189
+ return return_d
190
+
191
+ def __len__(self):
192
+ return len(self.paths)
external/realesrgan/data/realesrgan_paired_dataset.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
3
+ from basicsr.data.transforms import augment, paired_random_crop
4
+ from basicsr.utils import FileClient
5
+ from basicsr.utils.img_util import imfrombytes, img2tensor
6
+ from basicsr.utils.registry import DATASET_REGISTRY
7
+ from torch.utils import data as data
8
+ from torchvision.transforms.functional import normalize
9
+
10
+
11
+ @DATASET_REGISTRY.register()
12
+ class RealESRGANPairedDataset(data.Dataset):
13
+ """Paired image dataset for image restoration.
14
+
15
+ Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
16
+
17
+ There are three modes:
18
+ 1. 'lmdb': Use lmdb files.
19
+ If opt['io_backend'] == lmdb.
20
+ 2. 'meta_info': Use meta information file to generate paths.
21
+ If opt['io_backend'] != lmdb and opt['meta_info'] is not None.
22
+ 3. 'folder': Scan folders to generate paths.
23
+ The rest.
24
+
25
+ Args:
26
+ opt (dict): Config for train datasets. It contains the following keys:
27
+ dataroot_gt (str): Data root path for gt.
28
+ dataroot_lq (str): Data root path for lq.
29
+ meta_info (str): Path for meta information file.
30
+ io_backend (dict): IO backend type and other kwarg.
31
+ filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
32
+ Default: '{}'.
33
+ gt_size (int): Cropped patched size for gt patches.
34
+ use_hflip (bool): Use horizontal flips.
35
+ use_rot (bool): Use rotation (use vertical flip and transposing h
36
+ and w for implementation).
37
+
38
+ scale (bool): Scale, which will be added automatically.
39
+ phase (str): 'train' or 'val'.
40
+ """
41
+
42
+ def __init__(self, opt):
43
+ super(RealESRGANPairedDataset, self).__init__()
44
+ self.opt = opt
45
+ self.file_client = None
46
+ self.io_backend_opt = opt['io_backend']
47
+ # mean and std for normalizing the input images
48
+ self.mean = opt['mean'] if 'mean' in opt else None
49
+ self.std = opt['std'] if 'std' in opt else None
50
+
51
+ in_channels = opt['in_channels'] if 'in_channels' in opt else 3
52
+ if in_channels == 1:
53
+ self.flag = 'grayscale'
54
+ elif in_channels == 3:
55
+ self.flag = 'color'
56
+ else:
57
+ self.flag = 'unchanged'
58
+
59
+ self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
60
+ self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'
61
+
62
+ # file client (lmdb io backend)
63
+ if self.io_backend_opt['type'] == 'lmdb':
64
+ self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
65
+ self.io_backend_opt['client_keys'] = ['lq', 'gt']
66
+ self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
67
+ elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
68
+ # disk backend with meta_info
69
+ # Each line in the meta_info describes the relative path to an image
70
+ with open(self.opt['meta_info']) as fin:
71
+ paths = [line.strip() for line in fin]
72
+ self.paths = []
73
+ for path in paths:
74
+ gt_path, lq_path = path.split(', ')
75
+ gt_path = os.path.join(self.gt_folder, gt_path)
76
+ lq_path = os.path.join(self.lq_folder, lq_path)
77
+ self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
78
+ else:
79
+ # disk backend
80
+ # it will scan the whole folder to get meta info
81
+ # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file
82
+ self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
83
+
84
+ def __getitem__(self, index):
85
+ if self.file_client is None:
86
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
87
+
88
+ scale = self.opt['scale']
89
+
90
+ # Load gt and lq images. Dimension order: HWC; channel order: BGR;
91
+ # image range: [0, 1], float32.
92
+ gt_path = self.paths[index]['gt_path']
93
+ img_bytes = self.file_client.get(gt_path, 'gt')
94
+ img_gt = imfrombytes(img_bytes, flag = self.flag, float32=True)
95
+ lq_path = self.paths[index]['lq_path']
96
+ img_bytes = self.file_client.get(lq_path, 'lq')
97
+ img_lq = imfrombytes(img_bytes, flag = self.flag, float32=True)
98
+
99
+ # augmentation for training
100
+ if self.opt['phase'] == 'train':
101
+ gt_size = self.opt['gt_size']
102
+ # random crop
103
+ img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
104
+ # flip, rotation
105
+ img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])
106
+
107
+ # BGR to RGB, HWC to CHW, numpy to tensor
108
+ img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
109
+ # normalize
110
+ if self.mean is not None or self.std is not None:
111
+ normalize(img_lq, self.mean, self.std, inplace=True)
112
+ normalize(img_gt, self.mean, self.std, inplace=True)
113
+
114
+ return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
115
+
116
+ def __len__(self):
117
+ return len(self.paths)
external/realesrgan/models/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ from basicsr.utils import scandir
3
+ from os import path as osp
4
+
5
+ # automatically scan and import model modules for registry
6
+ # scan all the files that end with '_model.py' under the model folder
7
+ model_folder = osp.dirname(osp.abspath(__file__))
8
+ model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
9
+ # import all the model modules
10
+ _model_modules = [importlib.import_module(f'realesrgan.models.{file_name}') for file_name in model_filenames]
external/realesrgan/models/realesrgan_model.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+ import torch
4
+ from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
5
+ from basicsr.data.transforms import paired_random_crop
6
+ from basicsr.models.srgan_model import SRGANModel
7
+ from basicsr.utils import DiffJPEG, USMSharp
8
+ from basicsr.utils.img_process_util import filter2D
9
+ from basicsr.utils.registry import MODEL_REGISTRY
10
+ from collections import OrderedDict
11
+ from torch.nn import functional as F
12
+
13
+
14
+ @MODEL_REGISTRY.register()
15
+ class RealESRGANModel(SRGANModel):
16
+ """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
17
+
18
+ It mainly performs:
19
+ 1. randomly synthesize LQ images in GPU tensors
20
+ 2. optimize the networks with GAN training.
21
+ """
22
+
23
+ def __init__(self, opt):
24
+ super(RealESRGANModel, self).__init__(opt)
25
+ self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
26
+ self.usm_sharpener = USMSharp().cuda() # do usm sharpening
27
+ self.queue_size = opt.get('queue_size', 180)
28
+
29
+ @torch.no_grad()
30
+ def _dequeue_and_enqueue(self):
31
+ """It is the training pair pool for increasing the diversity in a batch.
32
+
33
+ Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
34
+ batch could not have different resize scaling factors. Therefore, we employ this training pair pool
35
+ to increase the degradation diversity in a batch.
36
+ """
37
+ # initialize
38
+ b, c, h, w = self.lq.size()
39
+ if not hasattr(self, 'queue_lr'):
40
+ assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
41
+ self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
42
+ _, c, h, w = self.gt.size()
43
+ self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
44
+ self.queue_ptr = 0
45
+ if self.queue_ptr == self.queue_size: # the pool is full
46
+ # do dequeue and enqueue
47
+ # shuffle
48
+ idx = torch.randperm(self.queue_size)
49
+ self.queue_lr = self.queue_lr[idx]
50
+ self.queue_gt = self.queue_gt[idx]
51
+ # get first b samples
52
+ lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
53
+ gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
54
+ # update the queue
55
+ self.queue_lr[0:b, :, :, :] = self.lq.clone()
56
+ self.queue_gt[0:b, :, :, :] = self.gt.clone()
57
+
58
+ self.lq = lq_dequeue
59
+ self.gt = gt_dequeue
60
+ else:
61
+ # only do enqueue
62
+ self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
63
+ self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
64
+ self.queue_ptr = self.queue_ptr + b
65
+
66
+ @torch.no_grad()
67
+ def feed_data(self, data):
68
+ """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
69
+ """
70
+ if self.is_train and self.opt.get('high_order_degradation', True):
71
+ # training data synthesis
72
+ self.gt = data['gt'].to(self.device)
73
+ self.gt_usm = self.usm_sharpener(self.gt)
74
+
75
+ self.kernel1 = data['kernel1'].to(self.device)
76
+ self.kernel2 = data['kernel2'].to(self.device)
77
+ self.sinc_kernel = data['sinc_kernel'].to(self.device)
78
+
79
+ ori_h, ori_w = self.gt.size()[2:4]
80
+
81
+ # ----------------------- The first degradation process ----------------------- #
82
+ # blur
83
+ out = filter2D(self.gt_usm, self.kernel1)
84
+ # random resize
85
+ updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
86
+ if updown_type == 'up':
87
+ scale = np.random.uniform(1, self.opt['resize_range'][1])
88
+ elif updown_type == 'down':
89
+ scale = np.random.uniform(self.opt['resize_range'][0], 1)
90
+ else:
91
+ scale = 1
92
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
93
+ out = F.interpolate(out, scale_factor=scale, mode=mode)
94
+ # add noise
95
+ gray_noise_prob = self.opt['gray_noise_prob']
96
+ if np.random.uniform() < self.opt['gaussian_noise_prob']:
97
+ out = random_add_gaussian_noise_pt(
98
+ out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
99
+ else:
100
+ out = random_add_poisson_noise_pt(
101
+ out,
102
+ scale_range=self.opt['poisson_scale_range'],
103
+ gray_prob=gray_noise_prob,
104
+ clip=True,
105
+ rounds=False)
106
+ # JPEG compression
107
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
108
+ out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
109
+ out = self.jpeger(out, quality=jpeg_p)
110
+
111
+ # ----------------------- The second degradation process ----------------------- #
112
+ # blur
113
+ if np.random.uniform() < self.opt['second_blur_prob']:
114
+ out = filter2D(out, self.kernel2)
115
+ # random resize
116
+ updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
117
+ if updown_type == 'up':
118
+ scale = np.random.uniform(1, self.opt['resize_range2'][1])
119
+ elif updown_type == 'down':
120
+ scale = np.random.uniform(self.opt['resize_range2'][0], 1)
121
+ else:
122
+ scale = 1
123
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
124
+ out = F.interpolate(
125
+ out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
126
+ # add noise
127
+ gray_noise_prob = self.opt['gray_noise_prob2']
128
+ if np.random.uniform() < self.opt['gaussian_noise_prob2']:
129
+ out = random_add_gaussian_noise_pt(
130
+ out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
131
+ else:
132
+ out = random_add_poisson_noise_pt(
133
+ out,
134
+ scale_range=self.opt['poisson_scale_range2'],
135
+ gray_prob=gray_noise_prob,
136
+ clip=True,
137
+ rounds=False)
138
+
139
+ # JPEG compression + the final sinc filter
140
+ # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
141
+ # as one operation.
142
+ # We consider two orders:
143
+ # 1. [resize back + sinc filter] + JPEG compression
144
+ # 2. JPEG compression + [resize back + sinc filter]
145
+ # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
146
+ if np.random.uniform() < 0.5:
147
+ # resize back + the final sinc filter
148
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
149
+ out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
150
+ out = filter2D(out, self.sinc_kernel)
151
+ # JPEG compression
152
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
153
+ out = torch.clamp(out, 0, 1)
154
+ out = self.jpeger(out, quality=jpeg_p)
155
+ else:
156
+ # JPEG compression
157
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
158
+ out = torch.clamp(out, 0, 1)
159
+ out = self.jpeger(out, quality=jpeg_p)
160
+ # resize back + the final sinc filter
161
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
162
+ out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
163
+ out = filter2D(out, self.sinc_kernel)
164
+
165
+ # clamp and round
166
+ self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
167
+
168
+ # random crop
169
+ gt_size = self.opt['gt_size']
170
+ (self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size,
171
+ self.opt['scale'])
172
+
173
+ # training pair pool
174
+ self._dequeue_and_enqueue()
175
+ # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
176
+ self.gt_usm = self.usm_sharpener(self.gt)
177
+ self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
178
+ else:
179
+ # for paired training or validation
180
+ self.lq = data['lq'].to(self.device)
181
+ if 'gt' in data:
182
+ self.gt = data['gt'].to(self.device)
183
+ self.gt_usm = self.usm_sharpener(self.gt)
184
+
185
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
186
+ # do not use the synthetic process during validation
187
+ self.is_train = False
188
+ super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
189
+ self.is_train = True
190
+
191
+ def optimize_parameters(self, current_iter):
192
+ # usm sharpening
193
+ l1_gt = self.gt_usm
194
+ percep_gt = self.gt_usm
195
+ gan_gt = self.gt_usm
196
+ if self.opt['l1_gt_usm'] is False:
197
+ l1_gt = self.gt
198
+ if self.opt['percep_gt_usm'] is False:
199
+ percep_gt = self.gt
200
+ if self.opt['gan_gt_usm'] is False:
201
+ gan_gt = self.gt
202
+
203
+ # optimize net_g
204
+ for p in self.net_d.parameters():
205
+ p.requires_grad = False
206
+
207
+ self.optimizer_g.zero_grad()
208
+ self.output = self.net_g(self.lq)
209
+
210
+ l_g_total = 0
211
+ loss_dict = OrderedDict()
212
+ if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
213
+ # pixel loss
214
+ if self.cri_pix:
215
+ l_g_pix = self.cri_pix(self.output, l1_gt)
216
+ l_g_total += l_g_pix
217
+ loss_dict['l_g_pix'] = l_g_pix
218
+ # perceptual loss
219
+ if self.cri_perceptual:
220
+ l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt)
221
+ if l_g_percep is not None:
222
+ l_g_total += l_g_percep
223
+ loss_dict['l_g_percep'] = l_g_percep
224
+ if l_g_style is not None:
225
+ l_g_total += l_g_style
226
+ loss_dict['l_g_style'] = l_g_style
227
+ # gan loss
228
+ fake_g_pred = self.net_d(self.output)
229
+ l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
230
+ l_g_total += l_g_gan
231
+ loss_dict['l_g_gan'] = l_g_gan
232
+
233
+ l_g_total.backward()
234
+ self.optimizer_g.step()
235
+
236
+ # optimize net_d
237
+ for p in self.net_d.parameters():
238
+ p.requires_grad = True
239
+
240
+ self.optimizer_d.zero_grad()
241
+ # real
242
+ real_d_pred = self.net_d(gan_gt)
243
+ l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
244
+ loss_dict['l_d_real'] = l_d_real
245
+ loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
246
+ l_d_real.backward()
247
+ # fake
248
+ fake_d_pred = self.net_d(self.output.detach().clone()) # clone for pt1.9
249
+ l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
250
+ loss_dict['l_d_fake'] = l_d_fake
251
+ loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
252
+ l_d_fake.backward()
253
+ self.optimizer_d.step()
254
+
255
+ if self.ema_decay > 0:
256
+ self.model_ema(decay=self.ema_decay)
257
+
258
+ self.log_dict = self.reduce_loss_dict(loss_dict)
external/realesrgan/models/realesrnet_model.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+ import torch
4
+ from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
5
+ from basicsr.data.transforms import paired_random_crop
6
+ from basicsr.models.sr_model import SRModel
7
+ from basicsr.utils import DiffJPEG, USMSharp
8
+ from basicsr.utils.img_process_util import filter2D
9
+ from basicsr.utils.registry import MODEL_REGISTRY
10
+ from torch.nn import functional as F
11
+
12
+
13
+ @MODEL_REGISTRY.register()
14
+ class RealESRNetModel(SRModel):
15
+ """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
16
+
17
+ It is trained without GAN losses.
18
+ It mainly performs:
19
+ 1. randomly synthesize LQ images in GPU tensors
20
+ 2. optimize the networks with GAN training.
21
+ """
22
+
23
+ def __init__(self, opt):
24
+ super(RealESRNetModel, self).__init__(opt)
25
+ self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
26
+ self.usm_sharpener = USMSharp().cuda() # do usm sharpening
27
+ self.queue_size = opt.get('queue_size', 180)
28
+
29
+ @torch.no_grad()
30
+ def _dequeue_and_enqueue(self):
31
+ """It is the training pair pool for increasing the diversity in a batch.
32
+
33
+ Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
34
+ batch could not have different resize scaling factors. Therefore, we employ this training pair pool
35
+ to increase the degradation diversity in a batch.
36
+ """
37
+ # initialize
38
+ b, c, h, w = self.lq.size()
39
+ if not hasattr(self, 'queue_lr'):
40
+ assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
41
+ self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
42
+ _, c, h, w = self.gt.size()
43
+ self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
44
+ self.queue_ptr = 0
45
+ if self.queue_ptr == self.queue_size: # the pool is full
46
+ # do dequeue and enqueue
47
+ # shuffle
48
+ idx = torch.randperm(self.queue_size)
49
+ self.queue_lr = self.queue_lr[idx]
50
+ self.queue_gt = self.queue_gt[idx]
51
+ # get first b samples
52
+ lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
53
+ gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
54
+ # update the queue
55
+ self.queue_lr[0:b, :, :, :] = self.lq.clone()
56
+ self.queue_gt[0:b, :, :, :] = self.gt.clone()
57
+
58
+ self.lq = lq_dequeue
59
+ self.gt = gt_dequeue
60
+ else:
61
+ # only do enqueue
62
+ self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
63
+ self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
64
+ self.queue_ptr = self.queue_ptr + b
65
+
66
+ @torch.no_grad()
67
+ def feed_data(self, data):
68
+ """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
69
+ """
70
+ if self.is_train and self.opt.get('high_order_degradation', True):
71
+ # training data synthesis
72
+ self.gt = data['gt'].to(self.device)
73
+ # USM sharpen the GT images
74
+ if self.opt['gt_usm'] is True:
75
+ self.gt = self.usm_sharpener(self.gt)
76
+
77
+ self.kernel1 = data['kernel1'].to(self.device)
78
+ self.kernel2 = data['kernel2'].to(self.device)
79
+ self.sinc_kernel = data['sinc_kernel'].to(self.device)
80
+
81
+ ori_h, ori_w = self.gt.size()[2:4]
82
+
83
+ # ----------------------- The first degradation process ----------------------- #
84
+ # blur
85
+ out = filter2D(self.gt, self.kernel1)
86
+ # random resize
87
+ updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
88
+ if updown_type == 'up':
89
+ scale = np.random.uniform(1, self.opt['resize_range'][1])
90
+ elif updown_type == 'down':
91
+ scale = np.random.uniform(self.opt['resize_range'][0], 1)
92
+ else:
93
+ scale = 1
94
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
95
+ out = F.interpolate(out, scale_factor=scale, mode=mode)
96
+ # add noise
97
+ gray_noise_prob = self.opt['gray_noise_prob']
98
+ if np.random.uniform() < self.opt['gaussian_noise_prob']:
99
+ out = random_add_gaussian_noise_pt(
100
+ out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
101
+ else:
102
+ out = random_add_poisson_noise_pt(
103
+ out,
104
+ scale_range=self.opt['poisson_scale_range'],
105
+ gray_prob=gray_noise_prob,
106
+ clip=True,
107
+ rounds=False)
108
+ # JPEG compression
109
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
110
+ out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
111
+ out = self.jpeger(out, quality=jpeg_p)
112
+
113
+ # ----------------------- The second degradation process ----------------------- #
114
+ # blur
115
+ if np.random.uniform() < self.opt['second_blur_prob']:
116
+ out = filter2D(out, self.kernel2)
117
+ # random resize
118
+ updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
119
+ if updown_type == 'up':
120
+ scale = np.random.uniform(1, self.opt['resize_range2'][1])
121
+ elif updown_type == 'down':
122
+ scale = np.random.uniform(self.opt['resize_range2'][0], 1)
123
+ else:
124
+ scale = 1
125
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
126
+ out = F.interpolate(
127
+ out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
128
+ # add noise
129
+ gray_noise_prob = self.opt['gray_noise_prob2']
130
+ if np.random.uniform() < self.opt['gaussian_noise_prob2']:
131
+ out = random_add_gaussian_noise_pt(
132
+ out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
133
+ else:
134
+ out = random_add_poisson_noise_pt(
135
+ out,
136
+ scale_range=self.opt['poisson_scale_range2'],
137
+ gray_prob=gray_noise_prob,
138
+ clip=True,
139
+ rounds=False)
140
+
141
+ # JPEG compression + the final sinc filter
142
+ # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
143
+ # as one operation.
144
+ # We consider two orders:
145
+ # 1. [resize back + sinc filter] + JPEG compression
146
+ # 2. JPEG compression + [resize back + sinc filter]
147
+ # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
148
+ if np.random.uniform() < 0.5:
149
+ # resize back + the final sinc filter
150
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
151
+ out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
152
+ out = filter2D(out, self.sinc_kernel)
153
+ # JPEG compression
154
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
155
+ out = torch.clamp(out, 0, 1)
156
+ out = self.jpeger(out, quality=jpeg_p)
157
+ else:
158
+ # JPEG compression
159
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
160
+ out = torch.clamp(out, 0, 1)
161
+ out = self.jpeger(out, quality=jpeg_p)
162
+ # resize back + the final sinc filter
163
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
164
+ out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
165
+ out = filter2D(out, self.sinc_kernel)
166
+
167
+ # clamp and round
168
+ self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
169
+
170
+ # random crop
171
+ gt_size = self.opt['gt_size']
172
+ self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale'])
173
+
174
+ # training pair pool
175
+ self._dequeue_and_enqueue()
176
+ self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
177
+ else:
178
+ # for paired training or validation
179
+ self.lq = data['lq'].to(self.device)
180
+ if 'gt' in data:
181
+ self.gt = data['gt'].to(self.device)
182
+ self.gt_usm = self.usm_sharpener(self.gt)
183
+
184
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
185
+ # do not use the synthetic process during validation
186
+ self.is_train = False
187
+ super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
188
+ self.is_train = True
external/realesrgan/train.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+ import os.path as osp
3
+ from basicsr.train import train_pipeline
4
+
5
+ import realesrgan.archs
6
+ import realesrgan.data
7
+ import realesrgan.models
8
+
9
+ if __name__ == '__main__':
10
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
11
+ train_pipeline(root_path)
external/realesrgan/utils.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import os
5
+ import queue
6
+ import threading
7
+ import torch
8
+ from basicsr.utils.download_util import load_file_from_url
9
+ from torch.nn import functional as F
10
+
11
+ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
12
+
13
+
14
+ class RealESRGANer():
15
+ """A helper class for upsampling images with RealESRGAN.
16
+
17
+ Args:
18
+ scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
19
+ model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
20
+ model (nn.Module): The defined network. Default: None.
21
+ tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
22
+ input images into tiles, and then process each of them. Finally, they will be merged into one image.
23
+ 0 denotes for do not use tile. Default: 0.
24
+ tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
25
+ pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
26
+ half (float): Whether to use half precision during inference. Default: False.
27
+ """
28
+
29
+ def __init__(self, scale, model_path, dni_weight = None, model = None, tile = 0, tile_pad = 10, pre_pad = 10, half = False, device = None, gpu_id = None):
30
+ self.scale = scale
31
+ self.tile_size = tile
32
+ self.tile_pad = tile_pad
33
+ self.pre_pad = pre_pad
34
+ self.mod_scale = None
35
+ self.half = half
36
+
37
+ # initialize model
38
+ if gpu_id:
39
+ self.device = torch.device(f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
40
+ else:
41
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
42
+
43
+ if isinstance(model_path, list):
44
+ # dni
45
+ assert len(model_path) == len(dni_weight), 'model_path and dni_weight should have the save length.'
46
+ loadnet = self.dni(model_path[0], model_path[1], dni_weight)
47
+ else:
48
+ # if the model_path starts with https, it will first download models to the folder: weights
49
+ if model_path.startswith('https://'):
50
+ model_path = load_file_from_url(url = model_path, model_dir = os.path.join(ROOT_DIR, 'weights'), progress = True, file_name = None)
51
+ loadnet = torch.load(model_path, map_location = torch.device('cpu'))
52
+
53
+ # prefer to use params_ema
54
+ if 'params_ema' in loadnet:
55
+ keyname = 'params_ema'
56
+ else:
57
+ keyname = 'params'
58
+ model.load_state_dict(loadnet[keyname], strict=True)
59
+
60
+ model.eval()
61
+ self.model = model.to(self.device)
62
+ if self.half:
63
+ self.model = self.model.half()
64
+
65
+ def dni(self, net_a, net_b, dni_weight, key='params', loc='cpu'):
66
+ """Deep network interpolation.
67
+
68
+ ``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition``
69
+ """
70
+ net_a = torch.load(net_a, map_location = torch.device(loc))
71
+ net_b = torch.load(net_b, map_location = torch.device(loc))
72
+ for k, v_a in net_a[key].items():
73
+ net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k]
74
+ return net_a
75
+
76
+ def pre_process(self, img):
77
+ """Pre-process, such as pre-pad and mod pad, so that the images can be divisible
78
+ """
79
+ img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
80
+ self.img = img.unsqueeze(0).to(self.device)
81
+ if self.half:
82
+ self.img = self.img.half()
83
+
84
+ # pre_pad
85
+ if self.pre_pad != 0:
86
+ self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
87
+ # mod pad for divisible borders
88
+ if self.scale == 2:
89
+ self.mod_scale = 2
90
+ elif self.scale == 1:
91
+ self.mod_scale = 4
92
+ if self.mod_scale is not None:
93
+ self.mod_pad_h, self.mod_pad_w = 0, 0
94
+ _, _, h, w = self.img.size()
95
+ if (h % self.mod_scale != 0):
96
+ self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
97
+ if (w % self.mod_scale != 0):
98
+ self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
99
+ self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
100
+
101
+ def process(self):
102
+ # model inference
103
+ self.output = self.model(self.img)
104
+
105
+ def tile_process(self):
106
+ """It will first crop input images to tiles, and then process each tile.
107
+ Finally, all the processed tiles are merged into one images.
108
+
109
+ Modified from: https://github.com/ata4/esrgan-launcher
110
+ """
111
+ batch, channel, height, width = self.img.shape
112
+ output_height = height * self.scale
113
+ output_width = width * self.scale
114
+ output_shape = (batch, channel, output_height, output_width)
115
+
116
+ # start with black image
117
+ self.output = self.img.new_zeros(output_shape)
118
+ tiles_x = math.ceil(width / self.tile_size)
119
+ tiles_y = math.ceil(height / self.tile_size)
120
+
121
+ # loop over all tiles
122
+ for y in range(tiles_y):
123
+ for x in range(tiles_x):
124
+ # extract tile from input image
125
+ ofs_x = x * self.tile_size
126
+ ofs_y = y * self.tile_size
127
+ # input tile area on total image
128
+ input_start_x = ofs_x
129
+ input_end_x = min(ofs_x + self.tile_size, width)
130
+ input_start_y = ofs_y
131
+ input_end_y = min(ofs_y + self.tile_size, height)
132
+
133
+ # input tile area on total image with padding
134
+ input_start_x_pad = max(input_start_x - self.tile_pad, 0)
135
+ input_end_x_pad = min(input_end_x + self.tile_pad, width)
136
+ input_start_y_pad = max(input_start_y - self.tile_pad, 0)
137
+ input_end_y_pad = min(input_end_y + self.tile_pad, height)
138
+
139
+ # input tile dimensions
140
+ input_tile_width = input_end_x - input_start_x
141
+ input_tile_height = input_end_y - input_start_y
142
+ tile_idx = y * tiles_x + x + 1
143
+ input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
144
+
145
+ # upscale tile
146
+ try:
147
+ with torch.no_grad():
148
+ output_tile = self.model(input_tile)
149
+ except RuntimeError as error:
150
+ print('Error', error)
151
+ print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
152
+
153
+ # output tile area on total image
154
+ output_start_x = input_start_x * self.scale
155
+ output_end_x = input_end_x * self.scale
156
+ output_start_y = input_start_y * self.scale
157
+ output_end_y = input_end_y * self.scale
158
+
159
+ # output tile area without padding
160
+ output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
161
+ output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
162
+ output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
163
+ output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
164
+
165
+ # put tile into output image
166
+ self.output[:, :, output_start_y:output_end_y, output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile, output_start_x_tile:output_end_x_tile]
167
+
168
+ def post_process(self):
169
+ # remove extra pad
170
+ if self.mod_scale is not None:
171
+ _, _, h, w = self.output.size()
172
+ self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
173
+ # remove prepad
174
+ if self.pre_pad != 0:
175
+ _, _, h, w = self.output.size()
176
+ self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
177
+ return self.output
178
+
179
+ @torch.no_grad()
180
+ def enhance(self, img, outscale = None, num_out_ch = 3, alpha_upsampler = 'realesrgan'):
181
+ h_input, w_input = img.shape[0:2]
182
+ # img: numpy
183
+ img = img.astype(np.float32)
184
+ if np.max(img) > 256: # 16-bit image
185
+ max_range = 65535
186
+ print('\tInput is a 16-bit image')
187
+ else:
188
+ max_range = 255
189
+ img = img / max_range
190
+ if len(img.shape) == 2: # gray image
191
+ img_mode = 'L'
192
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
193
+ elif img.shape[2] == 4: # RGBA image with alpha channel
194
+ img_mode = 'RGBA'
195
+ if num_out_ch != 3:
196
+ img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA)
197
+ else:
198
+ alpha = img[:, :, 3]
199
+ img = img[:, :, 0:3]
200
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
201
+ if alpha_upsampler == 'realesrgan':
202
+ alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
203
+ else:
204
+ img_mode = 'RGB'
205
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
206
+
207
+ # ------------------- process image (without the alpha channel) ------------------- #
208
+ self.pre_process(img)
209
+ if self.tile_size > 0:
210
+ self.tile_process()
211
+ else:
212
+ self.process()
213
+ output_img = self.post_process()
214
+ output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
215
+ img_struct_list = []
216
+ for i in range(3, num_out_ch):
217
+ img_struct_list.append(i)
218
+ output_img = output_img[[2, 1, 0] + img_struct_list, :, :]
219
+ output_img = np.transpose(output_img, (1, 2, 0))
220
+ if img_mode == 'L':
221
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
222
+
223
+ # ------------------- process the alpha channel if necessary ------------------- #
224
+ if img_mode == 'RGBA' and num_out_ch == 3:
225
+ if alpha_upsampler == 'realesrgan':
226
+ self.pre_process(alpha)
227
+ if self.tile_size > 0:
228
+ self.tile_process()
229
+ else:
230
+ self.process()
231
+ output_alpha = self.post_process()
232
+ output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
233
+ output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
234
+ output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
235
+ else: # use the cv2 resize for alpha channel
236
+ h, w = alpha.shape[0:2]
237
+ output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
238
+
239
+ # merge the alpha channel
240
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
241
+ output_img[:, :, 3] = output_alpha
242
+
243
+ # ------------------------------ return ------------------------------ #
244
+ if max_range == 65535: # 16-bit image
245
+ output = (output_img * 65535.0).round().astype(np.uint16)
246
+ else:
247
+ output = (output_img * 255.0).round().astype(np.uint8)
248
+
249
+ if outscale is not None and outscale != float(self.scale):
250
+ output = cv2.resize(output, (int(w_input * outscale), int(h_input * outscale)), interpolation = cv2.INTER_LANCZOS4)
251
+
252
+ return output, img_mode
253
+
254
+
255
+ class PrefetchReader(threading.Thread):
256
+ """Prefetch images.
257
+
258
+ Args:
259
+ img_list (list[str]): A image list of image paths to be read.
260
+ num_prefetch_queue (int): Number of prefetch queue.
261
+ """
262
+
263
+ def __init__(self, img_list, num_prefetch_queue):
264
+ super().__init__()
265
+ self.que = queue.Queue(num_prefetch_queue)
266
+ self.img_list = img_list
267
+
268
+ def run(self):
269
+ for img_path in self.img_list:
270
+ img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
271
+ self.que.put(img)
272
+
273
+ self.que.put(None)
274
+
275
+ def __next__(self):
276
+ next_item = self.que.get()
277
+ if next_item is None:
278
+ raise StopIteration
279
+ return next_item
280
+
281
+ def __iter__(self):
282
+ return self
283
+
284
+
285
+ class IOConsumer(threading.Thread):
286
+
287
+ def __init__(self, opt, que, qid):
288
+ super().__init__()
289
+ self._queue = que
290
+ self.qid = qid
291
+ self.opt = opt
292
+
293
+ def run(self):
294
+ while True:
295
+ msg = self._queue.get()
296
+ if isinstance(msg, str) and msg == 'quit':
297
+ break
298
+
299
+ output = msg['output']
300
+ save_path = msg['save_path']
301
+ cv2.imwrite(save_path, output)
302
+ print(f'IO worker {self.qid} is done.')
handler.py CHANGED
@@ -1,5 +1,10 @@
1
- import json
2
  import os
 
 
 
 
 
 
3
  from pathlib import Path
4
  from typing import Any, Dict, List
5
 
 
 
1
  import os
2
+ import sys
3
+
4
+ path = os.path.dirname(os.path.abspath(__file__))
5
+ sys.path.insert(1, os.path.join(path, "external"))
6
+
7
+
8
  from pathlib import Path
9
  from typing import Any, Dict, List
10
 
inference.py CHANGED
@@ -17,10 +17,9 @@ from internals.pipelines.img_classifier import ImageClassifier
17
  from internals.pipelines.img_to_text import Image2Text
18
  from internals.pipelines.inpainter import InPainter
19
  from internals.pipelines.object_remove import ObjectRemoval
20
- from internals.pipelines.pose_detector import PoseDetector
21
  from internals.pipelines.prompt_modifier import PromptModifier
22
  from internals.pipelines.realtime_draw import RealtimeDraw
23
- from internals.pipelines.remove_background import RemoveBackgroundV2
24
  from internals.pipelines.replace_background import ReplaceBackground
25
  from internals.pipelines.safety_checker import SafetyChecker
26
  from internals.pipelines.sdxl_tile_upscale import SDXLTileUpscaler
@@ -45,7 +44,6 @@ from internals.util.config import (
45
  set_model_config,
46
  set_root_dir,
47
  )
48
- from internals.util.failure_hander import FailureHandler
49
  from internals.util.lora_style import LoraStyle
50
  from internals.util.model_loader import load_model_from_config
51
  from internals.util.slack import Slack
@@ -57,14 +55,13 @@ auto_mode = False
57
 
58
  prompt_modifier = PromptModifier(num_of_sequences=get_num_return_sequences())
59
  upscaler = Upscaler()
60
- pose_detector = PoseDetector()
61
  inpainter = InPainter()
62
  high_res = HighRes()
63
  img2text = Image2Text()
64
  img_classifier = ImageClassifier()
65
  object_removal = ObjectRemoval()
66
  replace_background = ReplaceBackground()
67
- remove_background_v2 = RemoveBackgroundV2()
68
  replace_background = ReplaceBackground()
69
  controlnet = ControlNet()
70
  lora_style = LoraStyle()
@@ -92,7 +89,7 @@ def get_patched_prompt_text2img(task: Task):
92
 
93
  def get_patched_prompt_tile_upscale(task: Task):
94
  return prompt_util.get_patched_prompt_tile_upscale(
95
- task, avatar, lora_style, img_classifier, img2text
96
  )
97
 
98
 
@@ -126,20 +123,19 @@ def canny(task: Task):
126
  "num_inference_steps": task.get_steps(),
127
  "width": width,
128
  "height": height,
129
- "negative_prompt": [
130
- f"monochrome, neon, x-ray, negative image, oversaturated, {task.get_negative_prompt()}"
131
- ]
132
- * get_num_return_sequences(),
133
  **task.cnc_kwargs(),
134
  **lora_patcher.kwargs(),
135
  }
136
- images, has_nsfw = controlnet.process(**kwargs)
137
  if task.get_high_res_fix():
138
  kwargs = {
139
  "prompt": prompt,
140
  "negative_prompt": [task.get_negative_prompt()]
141
  * get_num_return_sequences(),
142
  "images": images,
 
143
  "width": task.get_width(),
144
  "height": task.get_height(),
145
  "num_inference_steps": task.get_steps(),
@@ -147,6 +143,9 @@ def canny(task: Task):
147
  }
148
  images, _ = high_res.apply(**kwargs)
149
 
 
 
 
150
  generated_image_urls = upload_images(images, "_canny", task.get_taskId())
151
 
152
  lora_patcher.cleanup()
@@ -162,48 +161,102 @@ def canny(task: Task):
162
  @update_db
163
  @auto_clear_cuda_and_gc(controlnet)
164
  @slack.auto_send_alert
165
- def tile_upscale(task: Task):
166
- output_key = "crecoAI/{}_tile_upscaler.png".format(task.get_taskId())
167
-
168
- prompt = get_patched_prompt_tile_upscale(task)
169
-
170
- if get_is_sdxl():
171
- lora_patcher = lora_style.get_patcher(
172
- [sdxl_tileupscaler.pipe, high_res.pipe], task.get_style()
173
- )
174
- lora_patcher.patch()
175
 
176
- images, has_nsfw = sdxl_tileupscaler.process(
177
- prompt=prompt,
178
- imageUrl=task.get_imageUrl(),
179
- resize_dimension=task.get_resize_dimension(),
180
- negative_prompt=task.get_negative_prompt(),
181
- width=task.get_width(),
182
- height=task.get_height(),
183
- model_id=task.get_model_id(),
184
- )
185
 
186
- lora_patcher.cleanup()
187
- else:
188
- controlnet.load_model("tile_upscaler")
189
 
190
- lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
191
- lora_patcher.patch()
 
 
192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  kwargs = {
194
- "imageUrl": task.get_imageUrl(),
 
 
 
195
  "seed": task.get_seed(),
196
- "num_inference_steps": task.get_steps(),
197
- "negative_prompt": task.get_negative_prompt(),
198
  "width": task.get_width(),
199
  "height": task.get_height(),
200
- "prompt": prompt,
201
- "resize_dimension": task.get_resize_dimension(),
202
- **task.cnt_kwargs(),
203
  }
204
- images, has_nsfw = controlnet.process(**kwargs)
205
- lora_patcher.cleanup()
206
- controlnet.cleanup()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
  generated_image_url = upload_image(images[0], output_key)
209
 
@@ -229,12 +282,7 @@ def scribble(task: Task):
229
  )
230
  lora_patcher.patch()
231
 
232
- image = download_image(task.get_imageUrl()).resize((width, height))
233
- if get_is_sdxl():
234
- # We use sketch in SDXL
235
- image = ControlNet.pidinet_image(image)
236
- else:
237
- image = ControlNet.scribble_image(image)
238
 
239
  kwargs = {
240
  "image": [image] * get_num_return_sequences(),
@@ -244,9 +292,10 @@ def scribble(task: Task):
244
  "height": height,
245
  "prompt": prompt,
246
  "negative_prompt": [task.get_negative_prompt()] * get_num_return_sequences(),
 
247
  **task.cns_kwargs(),
248
  }
249
- images, has_nsfw = controlnet.process(**kwargs)
250
 
251
  if task.get_high_res_fix():
252
  kwargs = {
@@ -256,11 +305,15 @@ def scribble(task: Task):
256
  "images": images,
257
  "width": task.get_width(),
258
  "height": task.get_height(),
 
259
  "num_inference_steps": task.get_steps(),
260
  **task.high_res_kwargs(),
261
  }
262
  images, _ = high_res.apply(**kwargs)
263
 
 
 
 
264
  generated_image_urls = upload_images(images, "_scribble", task.get_taskId())
265
 
266
  lora_patcher.cleanup()
@@ -296,16 +349,21 @@ def linearart(task: Task):
296
  "height": height,
297
  "prompt": prompt,
298
  "negative_prompt": [task.get_negative_prompt()] * get_num_return_sequences(),
 
299
  **task.cnl_kwargs(),
300
  }
301
- images, has_nsfw = controlnet.process(**kwargs)
302
 
303
  if task.get_high_res_fix():
 
 
 
304
  kwargs = {
305
  "prompt": prompt,
306
  "negative_prompt": [task.get_negative_prompt()]
307
  * get_num_return_sequences(),
308
  "images": images,
 
309
  "width": task.get_width(),
310
  "height": task.get_height(),
311
  "num_inference_steps": task.get_steps(),
@@ -313,6 +371,22 @@ def linearart(task: Task):
313
  }
314
  images, _ = high_res.apply(**kwargs)
315
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
  generated_image_urls = upload_images(images, "_linearart", task.get_taskId())
317
 
318
  lora_patcher.cleanup()
@@ -341,20 +415,14 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
341
  )
342
  lora_patcher.patch()
343
 
344
- if not task.get_pose_estimation():
 
 
345
  print("Not detecting pose")
346
  pose = download_image(task.get_imageUrl()).resize(
347
  (task.get_width(), task.get_height())
348
  )
349
  poses = [pose] * get_num_return_sequences()
350
- elif task.get_pose_coordinates():
351
- infered_pose = pose_detector.transform(
352
- image=task.get_imageUrl(),
353
- client_coordinates=task.get_pose_coordinates(),
354
- width=task.get_width(),
355
- height=task.get_height(),
356
- )
357
- poses = [infered_pose] * get_num_return_sequences()
358
  else:
359
  poses = [
360
  controlnet.detect_pose(task.get_imageUrl())
@@ -370,8 +438,11 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
370
 
371
  upload_image(depth, "crecoAI/{}_depth.png".format(task.get_taskId()))
372
 
 
 
373
  kwargs = {
374
- "control_guidance_end": [0.5, 1.0],
 
375
  }
376
  else:
377
  images = poses[0]
@@ -389,7 +460,7 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
389
  **task.cnp_kwargs(),
390
  **lora_patcher.kwargs(),
391
  }
392
- images, has_nsfw = controlnet.process(**kwargs)
393
 
394
  if task.get_high_res_fix():
395
  kwargs = {
@@ -400,11 +471,12 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
400
  "width": task.get_width(),
401
  "height": task.get_height(),
402
  "num_inference_steps": task.get_steps(),
 
403
  **task.high_res_kwargs(),
404
  }
405
  images, _ = high_res.apply(**kwargs)
406
 
407
- upload_image(poses[0], "crecoAI/{}_pose.png".format(task.get_taskId()))
408
 
409
  generated_image_urls = upload_images(images, s3_outkey, task.get_taskId())
410
 
@@ -431,12 +503,11 @@ def text2img(task: Task):
431
  )
432
  lora_patcher.patch()
433
 
434
- torch.manual_seed(task.get_seed())
435
-
436
  kwargs = {
437
  "params": params,
438
  "num_inference_steps": task.get_steps(),
439
  "height": height,
 
440
  "width": width,
441
  "negative_prompt": task.get_negative_prompt(),
442
  **task.t2i_kwargs(),
@@ -455,6 +526,7 @@ def text2img(task: Task):
455
  "width": task.get_width(),
456
  "height": task.get_height(),
457
  "num_inference_steps": task.get_steps(),
 
458
  **task.high_res_kwargs(),
459
  }
460
  images, _ = high_res.apply(**kwargs)
@@ -478,11 +550,9 @@ def img2img(task: Task):
478
 
479
  width, height = get_intermediate_dimension(task)
480
 
481
- torch.manual_seed(task.get_seed())
482
-
483
  if get_is_sdxl():
484
  # we run lineart for img2img
485
- controlnet.load_model("linearart")
486
 
487
  lora_patcher = lora_style.get_patcher(
488
  [controlnet.pipe2, high_res.pipe], task.get_style()
@@ -498,10 +568,11 @@ def img2img(task: Task):
498
  "prompt": prompt,
499
  "negative_prompt": [task.get_negative_prompt()]
500
  * get_num_return_sequences(),
501
- **task.cnl_kwargs(),
502
- "adapter_conditioning_scale": 0.3,
 
503
  }
504
- images, has_nsfw = controlnet.process(**kwargs)
505
  else:
506
  lora_patcher = lora_style.get_patcher(
507
  [img2img_pipe.pipe, high_res.pipe], task.get_style()
@@ -516,6 +587,7 @@ def img2img(task: Task):
516
  "num_inference_steps": task.get_steps(),
517
  "width": width,
518
  "height": height,
 
519
  **task.i2i_kwargs(),
520
  **lora_patcher.kwargs(),
521
  }
@@ -530,6 +602,7 @@ def img2img(task: Task):
530
  "width": task.get_width(),
531
  "height": task.get_height(),
532
  "num_inference_steps": task.get_steps(),
 
533
  **task.high_res_kwargs(),
534
  }
535
  images, _ = high_res.apply(**kwargs)
@@ -568,7 +641,9 @@ def inpaint(task: Task):
568
  "num_inference_steps": task.get_steps(),
569
  **task.ip_kwargs(),
570
  }
571
- images = inpainter.process(**kwargs)
 
 
572
 
573
  generated_image_urls = upload_images(images, key, task.get_taskId())
574
 
@@ -617,9 +692,7 @@ def replace_bg(task: Task):
617
  @update_db
618
  @slack.auto_send_alert
619
  def remove_bg(task: Task):
620
- output_image = remove_background_v2.remove(
621
- task.get_imageUrl(), model_type=task.get_modelType()
622
- )
623
 
624
  output_key = "crecoAI/{}_rmbg.png".format(task.get_taskId())
625
  image_url = upload_image(output_image, output_key)
@@ -732,6 +805,67 @@ def rt_draw_img(task: Task):
732
  return {"image": base64_image}
733
 
734
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
735
  def custom_action(task: Task):
736
  from external.scripts import __scripts__
737
 
@@ -759,6 +893,14 @@ def custom_action(task: Task):
759
 
760
 
761
  def load_model_by_task(task_type: TaskType, model_id=-1):
 
 
 
 
 
 
 
 
762
  if not text2img_pipe.is_loaded():
763
  text2img_pipe.load(get_model_dir())
764
  img2img_pipe.create(text2img_pipe)
@@ -782,12 +924,14 @@ def load_model_by_task(task_type: TaskType, model_id=-1):
782
  upscaler.load()
783
  else:
784
  if task_type == TaskType.TILE_UPSCALE:
785
- if get_is_sdxl():
786
- sdxl_tileupscaler.create(high_res, text2img_pipe, model_id)
787
- else:
788
- controlnet.load_model("tile_upscaler")
789
  elif task_type == TaskType.CANNY:
790
  controlnet.load_model("canny")
 
 
791
  elif task_type == TaskType.SCRIBBLE:
792
  controlnet.load_model("scribble")
793
  elif task_type == TaskType.LINEARART:
@@ -798,23 +942,24 @@ def load_model_by_task(task_type: TaskType, model_id=-1):
798
 
799
  def unload_model_by_task(task_type: TaskType):
800
  if task_type == TaskType.INPAINT or task_type == TaskType.OUTPAINT:
801
- inpainter.unload()
 
802
  elif task_type == TaskType.REPLACE_BG:
803
  replace_background.unload()
804
  elif task_type == TaskType.OBJECT_REMOVAL:
805
  object_removal.unload()
806
  elif task_type == TaskType.TILE_UPSCALE:
807
- if get_is_sdxl():
808
- sdxl_tileupscaler.unload()
809
- else:
810
- controlnet.unload()
811
- elif task_type == TaskType.CANNY:
812
  controlnet.unload()
813
- elif task_type == TaskType.SCRIBBLE:
814
- controlnet.unload()
815
- elif task_type == TaskType.LINEARART:
816
- controlnet.unload()
817
- elif task_type == TaskType.POSE:
 
 
818
  controlnet.unload()
819
 
820
 
@@ -831,8 +976,6 @@ def model_fn(model_dir):
831
  set_model_config(config)
832
  set_root_dir(__file__)
833
 
834
- FailureHandler.register()
835
-
836
  avatar.load_local(model_dir)
837
 
838
  lora_style.load(model_dir)
@@ -855,15 +998,12 @@ def auto_unload_task(func):
855
 
856
 
857
  @auto_unload_task
858
- @FailureHandler.clear
859
  def predict_fn(data, pipe):
860
  task = Task(data)
861
  print("task is ", data)
862
 
863
  clear_cuda_and_gc()
864
 
865
- FailureHandler.handle(task)
866
-
867
  try:
868
  task_type = task.get_type()
869
 
@@ -894,11 +1034,16 @@ def predict_fn(data, pipe):
894
  avatar.fetch_from_network(task.get_model_id())
895
 
896
  if task_type == TaskType.TEXT_TO_IMAGE:
 
 
 
897
  return text2img(task)
898
  elif task_type == TaskType.IMAGE_TO_IMAGE:
899
  return img2img(task)
900
  elif task_type == TaskType.CANNY:
901
  return canny(task)
 
 
902
  elif task_type == TaskType.POSE:
903
  return pose(task)
904
  elif task_type == TaskType.TILE_UPSCALE:
 
17
  from internals.pipelines.img_to_text import Image2Text
18
  from internals.pipelines.inpainter import InPainter
19
  from internals.pipelines.object_remove import ObjectRemoval
 
20
  from internals.pipelines.prompt_modifier import PromptModifier
21
  from internals.pipelines.realtime_draw import RealtimeDraw
22
+ from internals.pipelines.remove_background import RemoveBackgroundV3
23
  from internals.pipelines.replace_background import ReplaceBackground
24
  from internals.pipelines.safety_checker import SafetyChecker
25
  from internals.pipelines.sdxl_tile_upscale import SDXLTileUpscaler
 
44
  set_model_config,
45
  set_root_dir,
46
  )
 
47
  from internals.util.lora_style import LoraStyle
48
  from internals.util.model_loader import load_model_from_config
49
  from internals.util.slack import Slack
 
55
 
56
  prompt_modifier = PromptModifier(num_of_sequences=get_num_return_sequences())
57
  upscaler = Upscaler()
 
58
  inpainter = InPainter()
59
  high_res = HighRes()
60
  img2text = Image2Text()
61
  img_classifier = ImageClassifier()
62
  object_removal = ObjectRemoval()
63
  replace_background = ReplaceBackground()
64
+ remove_background_v3 = RemoveBackgroundV3()
65
  replace_background = ReplaceBackground()
66
  controlnet = ControlNet()
67
  lora_style = LoraStyle()
 
89
 
90
  def get_patched_prompt_tile_upscale(task: Task):
91
  return prompt_util.get_patched_prompt_tile_upscale(
92
+ task, avatar, lora_style, img_classifier, img2text, is_sdxl=get_is_sdxl()
93
  )
94
 
95
 
 
123
  "num_inference_steps": task.get_steps(),
124
  "width": width,
125
  "height": height,
126
+ "negative_prompt": [task.get_negative_prompt()] * get_num_return_sequences(),
127
+ "apply_preprocess": task.get_apply_preprocess(),
 
 
128
  **task.cnc_kwargs(),
129
  **lora_patcher.kwargs(),
130
  }
131
+ (images, has_nsfw), control_image = controlnet.process(**kwargs)
132
  if task.get_high_res_fix():
133
  kwargs = {
134
  "prompt": prompt,
135
  "negative_prompt": [task.get_negative_prompt()]
136
  * get_num_return_sequences(),
137
  "images": images,
138
+ "seed": task.get_seed(),
139
  "width": task.get_width(),
140
  "height": task.get_height(),
141
  "num_inference_steps": task.get_steps(),
 
143
  }
144
  images, _ = high_res.apply(**kwargs)
145
 
146
+ upload_image(
147
+ control_image, f"crecoAI/{task.get_taskId()}_condition.png" # pyright: ignore
148
+ )
149
  generated_image_urls = upload_images(images, "_canny", task.get_taskId())
150
 
151
  lora_patcher.cleanup()
 
161
  @update_db
162
  @auto_clear_cuda_and_gc(controlnet)
163
  @slack.auto_send_alert
164
+ def canny_img2img(task: Task):
165
+ prompt, _ = get_patched_prompt(task)
 
 
 
 
 
 
 
 
166
 
167
+ width, height = get_intermediate_dimension(task)
 
 
 
 
 
 
 
 
168
 
169
+ controlnet.load_model("canny_2x")
 
 
170
 
171
+ lora_patcher = lora_style.get_patcher(
172
+ [controlnet.pipe, high_res.pipe], task.get_style()
173
+ )
174
+ lora_patcher.patch()
175
 
176
+ kwargs = {
177
+ "prompt": prompt,
178
+ "imageUrl": task.get_imageUrl(),
179
+ "seed": task.get_seed(),
180
+ "num_inference_steps": task.get_steps(),
181
+ "width": width,
182
+ "height": height,
183
+ "negative_prompt": [task.get_negative_prompt()] * get_num_return_sequences(),
184
+ **task.cnci2i_kwargs(),
185
+ **lora_patcher.kwargs(),
186
+ }
187
+ (images, has_nsfw), control_image = controlnet.process(**kwargs)
188
+ if task.get_high_res_fix():
189
+ # we run both here normal upscaler and highres
190
+ # and show normal upscaler image as output
191
+ # but use highres image for tile upscale
192
  kwargs = {
193
+ "prompt": prompt,
194
+ "negative_prompt": [task.get_negative_prompt()]
195
+ * get_num_return_sequences(),
196
+ "images": images,
197
  "seed": task.get_seed(),
 
 
198
  "width": task.get_width(),
199
  "height": task.get_height(),
200
+ "num_inference_steps": task.get_steps(),
201
+ **task.high_res_kwargs(),
 
202
  }
203
+ images, _ = high_res.apply(**kwargs)
204
+
205
+ # upload_images(images_high_res, "_canny_2x_highres", task.get_taskId())
206
+
207
+ for i, image in enumerate(images):
208
+ img = upscaler.upscale(
209
+ image=image,
210
+ width=task.get_width(),
211
+ height=task.get_height(),
212
+ face_enhance=task.get_face_enhance(),
213
+ resize_dimension=None,
214
+ )
215
+ img = Upscaler.to_pil(img)
216
+ images[i] = img.resize((task.get_width(), task.get_height()))
217
+
218
+ upload_image(
219
+ control_image, f"crecoAI/{task.get_taskId()}_condition.png" # pyright: ignore
220
+ )
221
+ generated_image_urls = upload_images(images, "_canny_2x", task.get_taskId())
222
+
223
+ lora_patcher.cleanup()
224
+ controlnet.cleanup()
225
+
226
+ return {
227
+ "modified_prompts": prompt,
228
+ "generated_image_urls": generated_image_urls,
229
+ "has_nsfw": has_nsfw,
230
+ }
231
+
232
+
233
+ @update_db
234
+ @auto_clear_cuda_and_gc(controlnet)
235
+ @slack.auto_send_alert
236
+ def tile_upscale(task: Task):
237
+ output_key = "crecoAI/{}_tile_upscaler.png".format(task.get_taskId())
238
+
239
+ prompt = get_patched_prompt_tile_upscale(task)
240
+
241
+ controlnet.load_model("tile_upscaler")
242
+
243
+ lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
244
+ lora_patcher.patch()
245
+
246
+ kwargs = {
247
+ "imageUrl": task.get_imageUrl(),
248
+ "seed": task.get_seed(),
249
+ "num_inference_steps": task.get_steps(),
250
+ "negative_prompt": task.get_negative_prompt(),
251
+ "width": task.get_width(),
252
+ "height": task.get_height(),
253
+ "prompt": prompt,
254
+ "resize_dimension": task.get_resize_dimension(),
255
+ **task.cnt_kwargs(),
256
+ }
257
+ (images, has_nsfw), _ = controlnet.process(**kwargs)
258
+ lora_patcher.cleanup()
259
+ controlnet.cleanup()
260
 
261
  generated_image_url = upload_image(images[0], output_key)
262
 
 
282
  )
283
  lora_patcher.patch()
284
 
285
+ image = controlnet.preprocess_image(task.get_imageUrl(), width, height)
 
 
 
 
 
286
 
287
  kwargs = {
288
  "image": [image] * get_num_return_sequences(),
 
292
  "height": height,
293
  "prompt": prompt,
294
  "negative_prompt": [task.get_negative_prompt()] * get_num_return_sequences(),
295
+ "apply_preprocess": task.get_apply_preprocess(),
296
  **task.cns_kwargs(),
297
  }
298
+ (images, has_nsfw), condition_image = controlnet.process(**kwargs)
299
 
300
  if task.get_high_res_fix():
301
  kwargs = {
 
305
  "images": images,
306
  "width": task.get_width(),
307
  "height": task.get_height(),
308
+ "seed": task.get_seed(),
309
  "num_inference_steps": task.get_steps(),
310
  **task.high_res_kwargs(),
311
  }
312
  images, _ = high_res.apply(**kwargs)
313
 
314
+ upload_image(
315
+ condition_image, f"crecoAI/{task.get_taskId()}_condition.png" # pyright: ignore
316
+ )
317
  generated_image_urls = upload_images(images, "_scribble", task.get_taskId())
318
 
319
  lora_patcher.cleanup()
 
349
  "height": height,
350
  "prompt": prompt,
351
  "negative_prompt": [task.get_negative_prompt()] * get_num_return_sequences(),
352
+ "apply_preprocess": task.get_apply_preprocess(),
353
  **task.cnl_kwargs(),
354
  }
355
+ (images, has_nsfw), condition_image = controlnet.process(**kwargs)
356
 
357
  if task.get_high_res_fix():
358
+ # we run both here normal upscaler and highres
359
+ # and show normal upscaler image as output
360
+ # but use highres image for tile upscale
361
  kwargs = {
362
  "prompt": prompt,
363
  "negative_prompt": [task.get_negative_prompt()]
364
  * get_num_return_sequences(),
365
  "images": images,
366
+ "seed": task.get_seed(),
367
  "width": task.get_width(),
368
  "height": task.get_height(),
369
  "num_inference_steps": task.get_steps(),
 
371
  }
372
  images, _ = high_res.apply(**kwargs)
373
 
374
+ # upload_images(images_high_res, "_linearart_highres", task.get_taskId())
375
+ #
376
+ # for i, image in enumerate(images):
377
+ # img = upscaler.upscale(
378
+ # image=image,
379
+ # width=task.get_width(),
380
+ # height=task.get_height(),
381
+ # face_enhance=task.get_face_enhance(),
382
+ # resize_dimension=None,
383
+ # )
384
+ # img = Upscaler.to_pil(img)
385
+ # images[i] = img
386
+
387
+ upload_image(
388
+ condition_image, f"crecoAI/{task.get_taskId()}_condition.png" # pyright: ignore
389
+ )
390
  generated_image_urls = upload_images(images, "_linearart", task.get_taskId())
391
 
392
  lora_patcher.cleanup()
 
415
  )
416
  lora_patcher.patch()
417
 
418
+ if not task.get_apply_preprocess():
419
+ poses = [download_image(task.get_imageUrl()).resize((width, height))]
420
+ elif not task.get_pose_estimation():
421
  print("Not detecting pose")
422
  pose = download_image(task.get_imageUrl()).resize(
423
  (task.get_width(), task.get_height())
424
  )
425
  poses = [pose] * get_num_return_sequences()
 
 
 
 
 
 
 
 
426
  else:
427
  poses = [
428
  controlnet.detect_pose(task.get_imageUrl())
 
438
 
439
  upload_image(depth, "crecoAI/{}_depth.png".format(task.get_taskId()))
440
 
441
+ scale = task.cnp_kwargs().pop("controlnet_conditioning_scale", None)
442
+ factor = task.cnp_kwargs().pop("control_guidance_end", None)
443
  kwargs = {
444
+ "controlnet_conditioning_scale": [1.0, scale or 1.0],
445
+ "control_guidance_end": [0.5, factor or 1.0],
446
  }
447
  else:
448
  images = poses[0]
 
460
  **task.cnp_kwargs(),
461
  **lora_patcher.kwargs(),
462
  }
463
+ (images, has_nsfw), _ = controlnet.process(**kwargs)
464
 
465
  if task.get_high_res_fix():
466
  kwargs = {
 
471
  "width": task.get_width(),
472
  "height": task.get_height(),
473
  "num_inference_steps": task.get_steps(),
474
+ "seed": task.get_seed(),
475
  **task.high_res_kwargs(),
476
  }
477
  images, _ = high_res.apply(**kwargs)
478
 
479
+ upload_image(poses[0], "crecoAI/{}_condition.png".format(task.get_taskId()))
480
 
481
  generated_image_urls = upload_images(images, s3_outkey, task.get_taskId())
482
 
 
503
  )
504
  lora_patcher.patch()
505
 
 
 
506
  kwargs = {
507
  "params": params,
508
  "num_inference_steps": task.get_steps(),
509
  "height": height,
510
+ "seed": task.get_seed(),
511
  "width": width,
512
  "negative_prompt": task.get_negative_prompt(),
513
  **task.t2i_kwargs(),
 
526
  "width": task.get_width(),
527
  "height": task.get_height(),
528
  "num_inference_steps": task.get_steps(),
529
+ "seed": task.get_seed(),
530
  **task.high_res_kwargs(),
531
  }
532
  images, _ = high_res.apply(**kwargs)
 
550
 
551
  width, height = get_intermediate_dimension(task)
552
 
 
 
553
  if get_is_sdxl():
554
  # we run lineart for img2img
555
+ controlnet.load_model("canny")
556
 
557
  lora_patcher = lora_style.get_patcher(
558
  [controlnet.pipe2, high_res.pipe], task.get_style()
 
568
  "prompt": prompt,
569
  "negative_prompt": [task.get_negative_prompt()]
570
  * get_num_return_sequences(),
571
+ "controlnet_conditioning_scale": 0.5,
572
+ # "adapter_conditioning_scale": 0.3,
573
+ **task.i2i_kwargs(),
574
  }
575
+ (images, has_nsfw), _ = controlnet.process(**kwargs)
576
  else:
577
  lora_patcher = lora_style.get_patcher(
578
  [img2img_pipe.pipe, high_res.pipe], task.get_style()
 
587
  "num_inference_steps": task.get_steps(),
588
  "width": width,
589
  "height": height,
590
+ "seed": task.get_seed(),
591
  **task.i2i_kwargs(),
592
  **lora_patcher.kwargs(),
593
  }
 
602
  "width": task.get_width(),
603
  "height": task.get_height(),
604
  "num_inference_steps": task.get_steps(),
605
+ "seed": task.get_seed(),
606
  **task.high_res_kwargs(),
607
  }
608
  images, _ = high_res.apply(**kwargs)
 
641
  "num_inference_steps": task.get_steps(),
642
  **task.ip_kwargs(),
643
  }
644
+ images, mask = inpainter.process(**kwargs)
645
+
646
+ upload_image(mask, "crecoAI/{}_mask.png".format(task.get_taskId()))
647
 
648
  generated_image_urls = upload_images(images, key, task.get_taskId())
649
 
 
692
  @update_db
693
  @slack.auto_send_alert
694
  def remove_bg(task: Task):
695
+ output_image = remove_background_v3.remove(task.get_imageUrl())
 
 
696
 
697
  output_key = "crecoAI/{}_rmbg.png".format(task.get_taskId())
698
  image_url = upload_image(output_image, output_key)
 
805
  return {"image": base64_image}
806
 
807
 
808
+ @update_db
809
+ @auto_clear_cuda_and_gc(controlnet)
810
+ @slack.auto_send_alert
811
+ def depth_rig(task: Task):
812
+ # Note : This task is for only processing a hardcoded character rig model using depth controlnet
813
+ # Hack : This model requires hardcoded depth images for optimal processing, so we pass it by default
814
+ default_depth_url = "https://s3.ap-south-1.amazonaws.com/assets.autodraft.in/character-sheet/rigs/character-rig-depth-map.png"
815
+
816
+ params = get_patched_prompt_text2img(task)
817
+
818
+ width, height = get_intermediate_dimension(task)
819
+
820
+ controlnet.load_model("depth")
821
+
822
+ lora_patcher = lora_style.get_patcher(
823
+ [controlnet.pipe2, high_res.pipe], task.get_style()
824
+ )
825
+ lora_patcher.patch()
826
+
827
+ kwargs = {
828
+ "params": params,
829
+ "prompt": params.prompt,
830
+ "num_inference_steps": task.get_steps(),
831
+ "imageUrl": default_depth_url,
832
+ "height": height,
833
+ "seed": task.get_seed(),
834
+ "width": width,
835
+ "negative_prompt": task.get_negative_prompt(),
836
+ **task.t2i_kwargs(),
837
+ **lora_patcher.kwargs(),
838
+ }
839
+ (images, has_nsfw), condition_image = controlnet.process(**kwargs)
840
+
841
+ if task.get_high_res_fix():
842
+ kwargs = {
843
+ "prompt": params.prompt
844
+ if params.prompt
845
+ else [""] * get_num_return_sequences(),
846
+ "negative_prompt": [task.get_negative_prompt()]
847
+ * get_num_return_sequences(),
848
+ "images": images,
849
+ "width": task.get_width(),
850
+ "height": task.get_height(),
851
+ "num_inference_steps": task.get_steps(),
852
+ "seed": task.get_seed(),
853
+ **task.high_res_kwargs(),
854
+ }
855
+ images, _ = high_res.apply(**kwargs)
856
+
857
+ upload_image(condition_image, "crecoAI/{}_condition.png".format(task.get_taskId()))
858
+ generated_image_urls = upload_images(images, "", task.get_taskId())
859
+
860
+ lora_patcher.cleanup()
861
+
862
+ return {
863
+ **params.__dict__,
864
+ "generated_image_urls": generated_image_urls,
865
+ "has_nsfw": has_nsfw,
866
+ }
867
+
868
+
869
  def custom_action(task: Task):
870
  from external.scripts import __scripts__
871
 
 
893
 
894
 
895
  def load_model_by_task(task_type: TaskType, model_id=-1):
896
+ from internals.pipelines.controlnets import clear_networks
897
+
898
+ # pre-cleanup inpaint and controlnet models
899
+ if task_type == TaskType.INPAINT or task_type == TaskType.OUTPAINT:
900
+ clear_networks()
901
+ else:
902
+ inpainter.unload()
903
+
904
  if not text2img_pipe.is_loaded():
905
  text2img_pipe.load(get_model_dir())
906
  img2img_pipe.create(text2img_pipe)
 
924
  upscaler.load()
925
  else:
926
  if task_type == TaskType.TILE_UPSCALE:
927
+ # if get_is_sdxl():
928
+ # sdxl_tileupscaler.create(high_res, text2img_pipe, model_id)
929
+ # else:
930
+ controlnet.load_model("tile_upscaler")
931
  elif task_type == TaskType.CANNY:
932
  controlnet.load_model("canny")
933
+ elif task_type == TaskType.CANNY_IMG2IMG:
934
+ controlnet.load_model("canny_2x")
935
  elif task_type == TaskType.SCRIBBLE:
936
  controlnet.load_model("scribble")
937
  elif task_type == TaskType.LINEARART:
 
942
 
943
  def unload_model_by_task(task_type: TaskType):
944
  if task_type == TaskType.INPAINT or task_type == TaskType.OUTPAINT:
945
+ # inpainter.unload()
946
+ pass
947
  elif task_type == TaskType.REPLACE_BG:
948
  replace_background.unload()
949
  elif task_type == TaskType.OBJECT_REMOVAL:
950
  object_removal.unload()
951
  elif task_type == TaskType.TILE_UPSCALE:
952
+ # if get_is_sdxl():
953
+ # sdxl_tileupscaler.unload()
954
+ # else:
 
 
955
  controlnet.unload()
956
+ elif (
957
+ task_type == TaskType.CANNY
958
+ or task_type == TaskType.CANNY_IMG2IMG
959
+ or task_type == TaskType.SCRIBBLE
960
+ or task_type == TaskType.LINEARART
961
+ or task_type == TaskType.POSE
962
+ ):
963
  controlnet.unload()
964
 
965
 
 
976
  set_model_config(config)
977
  set_root_dir(__file__)
978
 
 
 
979
  avatar.load_local(model_dir)
980
 
981
  lora_style.load(model_dir)
 
998
 
999
 
1000
  @auto_unload_task
 
1001
  def predict_fn(data, pipe):
1002
  task = Task(data)
1003
  print("task is ", data)
1004
 
1005
  clear_cuda_and_gc()
1006
 
 
 
1007
  try:
1008
  task_type = task.get_type()
1009
 
 
1034
  avatar.fetch_from_network(task.get_model_id())
1035
 
1036
  if task_type == TaskType.TEXT_TO_IMAGE:
1037
+ # Hack : Character Rigging Model Task Redirection
1038
+ if task.get_model_id() == 2000336 or task.get_model_id() == 2000341:
1039
+ return depth_rig(task)
1040
  return text2img(task)
1041
  elif task_type == TaskType.IMAGE_TO_IMAGE:
1042
  return img2img(task)
1043
  elif task_type == TaskType.CANNY:
1044
  return canny(task)
1045
+ elif task_type == TaskType.CANNY_IMG2IMG:
1046
+ return canny_img2img(task)
1047
  elif task_type == TaskType.POSE:
1048
  return pose(task)
1049
  elif task_type == TaskType.TILE_UPSCALE:
internals/data/task.py CHANGED
@@ -11,6 +11,7 @@ class TaskType(Enum):
11
  POSE = "POSE"
12
  CANNY = "CANNY"
13
  REMOVE_BG = "REMOVE_BG"
 
14
  INPAINT = "INPAINT"
15
  UPSCALE_IMAGE = "UPSCALE_IMAGE"
16
  TILE_UPSCALE = "TILE_UPSCALE"
@@ -47,12 +48,18 @@ class Task:
47
  elif len(prompt) > 200:
48
  self.__data["prompt"] = data.get("prompt", "")[:200] + ", "
49
 
 
 
 
50
  def get_taskId(self) -> str:
51
  return self.__data.get("task_id")
52
 
53
  def get_sourceId(self) -> str:
54
  return self.__data.get("source_id")
55
 
 
 
 
56
  def get_imageUrl(self) -> str:
57
  return self.__data.get("imageUrl", None)
58
 
@@ -150,12 +157,18 @@ class Task:
150
  def get_access_token(self) -> str:
151
  return self.__data.get("access_token", "")
152
 
 
 
 
153
  def get_high_res_fix(self) -> bool:
154
  return self.__data.get("high_res_fix", False)
155
 
156
  def get_base_dimension(self):
157
  return self.__data.get("base_dimension", None)
158
 
 
 
 
159
  def get_action_data(self) -> dict:
160
  "If task_type is CUSTOM_ACTION, then this will return the action data with 'name' as key"
161
  return self.__data.get("action_data", {})
@@ -175,6 +188,9 @@ class Task:
175
  def cnc_kwargs(self) -> dict:
176
  return dict(self.__get_kwargs("cnc_"))
177
 
 
 
 
178
  def cnp_kwargs(self) -> dict:
179
  return dict(self.__get_kwargs("cnp_"))
180
 
@@ -192,7 +208,7 @@ class Task:
192
 
193
  def __get_kwargs(self, prefix: str):
194
  for k, v in self.__data.items():
195
- if k.startswith(prefix):
196
  yield k[len(prefix) :], v
197
 
198
  @property
 
11
  POSE = "POSE"
12
  CANNY = "CANNY"
13
  REMOVE_BG = "REMOVE_BG"
14
+ CANNY_IMG2IMG = "CANNY_IMG2IMG"
15
  INPAINT = "INPAINT"
16
  UPSCALE_IMAGE = "UPSCALE_IMAGE"
17
  TILE_UPSCALE = "TILE_UPSCALE"
 
48
  elif len(prompt) > 200:
49
  self.__data["prompt"] = data.get("prompt", "")[:200] + ", "
50
 
51
+ def get_environment(self) -> str:
52
+ return self.__data.get("stage", "prod")
53
+
54
  def get_taskId(self) -> str:
55
  return self.__data.get("task_id")
56
 
57
  def get_sourceId(self) -> str:
58
  return self.__data.get("source_id")
59
 
60
+ def get_slack_url(self) -> str:
61
+ return self.__data.get("slack_url", None)
62
+
63
  def get_imageUrl(self) -> str:
64
  return self.__data.get("imageUrl", None)
65
 
 
157
  def get_access_token(self) -> str:
158
  return self.__data.get("access_token", "")
159
 
160
+ def get_apply_preprocess(self) -> bool:
161
+ return self.__data.get("apply_preprocess", True)
162
+
163
  def get_high_res_fix(self) -> bool:
164
  return self.__data.get("high_res_fix", False)
165
 
166
  def get_base_dimension(self):
167
  return self.__data.get("base_dimension", None)
168
 
169
+ def get_process_mode(self):
170
+ return self.__data.get("process_mode", None)
171
+
172
  def get_action_data(self) -> dict:
173
  "If task_type is CUSTOM_ACTION, then this will return the action data with 'name' as key"
174
  return self.__data.get("action_data", {})
 
188
  def cnc_kwargs(self) -> dict:
189
  return dict(self.__get_kwargs("cnc_"))
190
 
191
+ def cnci2i_kwargs(self) -> dict:
192
+ return dict(self.__get_kwargs("cnci2i_"))
193
+
194
  def cnp_kwargs(self) -> dict:
195
  return dict(self.__get_kwargs("cnp_"))
196
 
 
208
 
209
  def __get_kwargs(self, prefix: str):
210
  for k, v in self.__data.items():
211
+ if k.startswith(prefix) and v != -1:
212
  yield k[len(prefix) :], v
213
 
214
  @property
internals/pipelines/commons.py CHANGED
@@ -11,11 +11,14 @@ from diffusers import (
11
 
12
  from internals.data.result import Result
13
  from internals.pipelines.twoStepPipeline import two_step_pipeline
 
14
  from internals.util.commons import disable_safety_checker, download_image
15
  from internals.util.config import (
 
16
  get_base_model_variant,
17
  get_hf_token,
18
  get_is_sdxl,
 
19
  get_num_return_sequences,
20
  )
21
 
@@ -38,6 +41,9 @@ class Text2Img(AbstractPipeline):
38
 
39
  def load(self, model_dir: str):
40
  if get_is_sdxl():
 
 
 
41
  vae = AutoencoderKL.from_pretrained(
42
  "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
43
  )
@@ -47,6 +53,7 @@ class Text2Img(AbstractPipeline):
47
  token=get_hf_token(),
48
  use_safetensors=True,
49
  variant=get_base_model_variant(),
 
50
  )
51
  pipe.vae = vae
52
  pipe.to("cuda")
@@ -70,9 +77,9 @@ class Text2Img(AbstractPipeline):
70
  self.__patch()
71
 
72
  def __patch(self):
73
- if get_is_sdxl():
74
- self.pipe.enable_vae_tiling()
75
- self.pipe.enable_vae_slicing()
76
  self.pipe.enable_xformers_memory_efficient_attention()
77
 
78
  @torch.inference_mode()
@@ -82,12 +89,15 @@ class Text2Img(AbstractPipeline):
82
  num_inference_steps: int,
83
  height: int,
84
  width: int,
 
85
  negative_prompt: str,
86
  iteration: float = 3.0,
87
  **kwargs,
88
  ):
89
  prompt = params.prompt
90
 
 
 
91
  if params.prompt_left and params.prompt_right:
92
  # multi-character pipelines
93
  prompt = [params.prompt[0], params.prompt_left[0], params.prompt_right[0]]
@@ -99,6 +109,7 @@ class Text2Img(AbstractPipeline):
99
  "width": width,
100
  "num_inference_steps": num_inference_steps,
101
  "negative_prompt": [negative_prompt or ""] * len(prompt),
 
102
  **kwargs,
103
  }
104
  result = self.pipe.multi_character_diffusion(**kwargs)
@@ -125,8 +136,11 @@ class Text2Img(AbstractPipeline):
125
  "width": width,
126
  "negative_prompt": [negative_prompt or ""] * get_num_return_sequences(),
127
  "num_inference_steps": num_inference_steps,
 
 
128
  **kwargs,
129
  }
 
130
  result = self.pipe.__call__(**kwargs)
131
 
132
  return Result.from_result(result)
@@ -145,6 +159,7 @@ class Img2Img(AbstractPipeline):
145
  torch_dtype=torch.float16,
146
  token=get_hf_token(),
147
  variant=get_base_model_variant(),
 
148
  use_safetensors=True,
149
  ).to("cuda")
150
  else:
@@ -183,20 +198,24 @@ class Img2Img(AbstractPipeline):
183
  num_inference_steps: int,
184
  width: int,
185
  height: int,
 
186
  strength: float = 0.75,
187
  guidance_scale: float = 7.5,
188
  **kwargs,
189
  ):
190
  image = download_image(imageUrl).resize((width, height))
191
 
 
 
192
  kwargs = {
193
  "prompt": prompt,
194
- "image": image,
195
  "strength": strength,
196
  "negative_prompt": negative_prompt,
197
  "guidance_scale": guidance_scale,
198
  "num_images_per_prompt": 1,
199
  "num_inference_steps": num_inference_steps,
 
200
  **kwargs,
201
  }
202
  result = self.pipe.__call__(**kwargs)
 
11
 
12
  from internals.data.result import Result
13
  from internals.pipelines.twoStepPipeline import two_step_pipeline
14
+ from internals.util import get_generators
15
  from internals.util.commons import disable_safety_checker, download_image
16
  from internals.util.config import (
17
+ get_base_model_revision,
18
  get_base_model_variant,
19
  get_hf_token,
20
  get_is_sdxl,
21
+ get_low_gpu_mem,
22
  get_num_return_sequences,
23
  )
24
 
 
41
 
42
  def load(self, model_dir: str):
43
  if get_is_sdxl():
44
+ print(
45
+ f"Loading model {model_dir} - {get_base_model_variant()}, {get_base_model_revision()}"
46
+ )
47
  vae = AutoencoderKL.from_pretrained(
48
  "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
49
  )
 
53
  token=get_hf_token(),
54
  use_safetensors=True,
55
  variant=get_base_model_variant(),
56
+ revision=get_base_model_revision(),
57
  )
58
  pipe.vae = vae
59
  pipe.to("cuda")
 
77
  self.__patch()
78
 
79
  def __patch(self):
80
+ if get_is_sdxl() or get_low_gpu_mem():
81
+ self.pipe.vae.enable_tiling()
82
+ self.pipe.vae.enable_slicing()
83
  self.pipe.enable_xformers_memory_efficient_attention()
84
 
85
  @torch.inference_mode()
 
89
  num_inference_steps: int,
90
  height: int,
91
  width: int,
92
+ seed: int,
93
  negative_prompt: str,
94
  iteration: float = 3.0,
95
  **kwargs,
96
  ):
97
  prompt = params.prompt
98
 
99
+ generator = get_generators(seed, get_num_return_sequences())
100
+
101
  if params.prompt_left and params.prompt_right:
102
  # multi-character pipelines
103
  prompt = [params.prompt[0], params.prompt_left[0], params.prompt_right[0]]
 
109
  "width": width,
110
  "num_inference_steps": num_inference_steps,
111
  "negative_prompt": [negative_prompt or ""] * len(prompt),
112
+ "generator": generator,
113
  **kwargs,
114
  }
115
  result = self.pipe.multi_character_diffusion(**kwargs)
 
136
  "width": width,
137
  "negative_prompt": [negative_prompt or ""] * get_num_return_sequences(),
138
  "num_inference_steps": num_inference_steps,
139
+ "guidance_scale": 7.5,
140
+ "generator": generator,
141
  **kwargs,
142
  }
143
+ print(kwargs)
144
  result = self.pipe.__call__(**kwargs)
145
 
146
  return Result.from_result(result)
 
159
  torch_dtype=torch.float16,
160
  token=get_hf_token(),
161
  variant=get_base_model_variant(),
162
+ revision=get_base_model_revision(),
163
  use_safetensors=True,
164
  ).to("cuda")
165
  else:
 
198
  num_inference_steps: int,
199
  width: int,
200
  height: int,
201
+ seed: int,
202
  strength: float = 0.75,
203
  guidance_scale: float = 7.5,
204
  **kwargs,
205
  ):
206
  image = download_image(imageUrl).resize((width, height))
207
 
208
+ generator = get_generators(seed, get_num_return_sequences())
209
+
210
  kwargs = {
211
  "prompt": prompt,
212
+ "image": [image] * get_num_return_sequences(),
213
  "strength": strength,
214
  "negative_prompt": negative_prompt,
215
  "guidance_scale": guidance_scale,
216
  "num_images_per_prompt": 1,
217
  "num_inference_steps": num_inference_steps,
218
+ "generator": generator,
219
  **kwargs,
220
  }
221
  result = self.pipe.__call__(**kwargs)
internals/pipelines/controlnets.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from typing import AbstractSet, List, Literal, Optional, Union
2
 
3
  import cv2
@@ -17,6 +18,7 @@ from diffusers import (
17
  StableDiffusionControlNetImg2ImgPipeline,
18
  StableDiffusionControlNetPipeline,
19
  StableDiffusionXLAdapterPipeline,
 
20
  StableDiffusionXLControlNetPipeline,
21
  T2IAdapter,
22
  UniPCMultistepScheduler,
@@ -29,9 +31,9 @@ from tqdm import gui
29
  from transformers import pipeline
30
 
31
  import internals.util.image as ImageUtil
32
- from external.midas import apply_midas
33
  from internals.data.result import Result
34
  from internals.pipelines.commons import AbstractPipeline
 
35
  from internals.util.cache import clear_cuda_and_gc
36
  from internals.util.commons import download_image
37
  from internals.util.config import (
@@ -39,9 +41,51 @@ from internals.util.config import (
39
  get_hf_token,
40
  get_is_sdxl,
41
  get_model_dir,
 
42
  )
43
 
44
- CONTROLNET_TYPES = Literal["pose", "canny", "scribble", "linearart", "tile_upscaler"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
 
47
  class StableDiffusionNetworkModelPipelineLoader:
@@ -57,11 +101,6 @@ class StableDiffusionNetworkModelPipelineLoader:
57
  pipeline_type,
58
  base_pipe: Optional[AbstractSet] = None,
59
  ):
60
- if is_sdxl and is_img2img:
61
- # Does not matter pipeline type but tile upscale is not supported
62
- print("Warning: Tile upscale is not supported on SDXL")
63
- return None
64
-
65
  if base_pipe is None:
66
  pretrained = True
67
  kwargs = {
@@ -75,7 +114,17 @@ class StableDiffusionNetworkModelPipelineLoader:
75
  kwargs = {
76
  **base_pipe.pipe.components, # pyright: ignore
77
  }
 
 
 
78
 
 
 
 
 
 
 
 
79
  if is_sdxl and pipeline_type == "controlnet":
80
  model = (
81
  StableDiffusionXLControlNetPipeline.from_pretrained
@@ -146,9 +195,10 @@ class ControlNet(AbstractPipeline):
146
  def load_model(self, task_name: CONTROLNET_TYPES):
147
  "Appropriately loads the network module, pipelines and cache it for reuse."
148
 
149
- config = self.__model_sdxl if get_is_sdxl() else self.__model_normal
150
  if self.__current_task_name == task_name:
151
  return
 
 
152
  model = config[task_name]
153
  if not model:
154
  raise Exception(f"ControlNet is not supported for {task_name}")
@@ -176,31 +226,13 @@ class ControlNet(AbstractPipeline):
176
  def __load_network_model(self, model_name, pipeline_type):
177
  "Loads the network module, eg: ControlNet or T2I Adapters"
178
 
179
- def load_controlnet(model):
180
- return ControlNetModel.from_pretrained(
181
- model,
182
- torch_dtype=torch.float16,
183
- cache_dir=get_hf_cache_dir(),
184
- ).to("cuda")
185
-
186
- def load_t2i(model):
187
- return T2IAdapter.from_pretrained(
188
- model,
189
- torch_dtype=torch.float16,
190
- varient="fp16",
191
- ).to("cuda")
192
-
193
  if type(model_name) == str:
194
- if pipeline_type == "controlnet":
195
- return load_controlnet(model_name)
196
- if pipeline_type == "t2i":
197
- return load_t2i(model_name)
198
- raise Exception("Invalid pipeline type")
199
  elif type(model_name) == list:
200
  if pipeline_type == "controlnet":
201
  cns = []
202
  for model in model_name:
203
- cns.append(load_controlnet(model))
204
  return MultiControlNetModel(cns).to("cuda")
205
  elif pipeline_type == "t2i":
206
  raise Exception("Multi T2I adapters are not supported")
@@ -219,9 +251,10 @@ class ControlNet(AbstractPipeline):
219
  pipe.enable_vae_slicing()
220
  pipe.enable_xformers_memory_efficient_attention()
221
  # this scheduler produces good outputs for t2i adapters
222
- pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
223
- pipe.scheduler.config
224
- )
 
225
  else:
226
  pipe.enable_xformers_memory_efficient_attention()
227
  return pipe
@@ -229,7 +262,7 @@ class ControlNet(AbstractPipeline):
229
  # If the pipeline type is changed we should reload all
230
  # the pipelines
231
  if not self.__loaded or self.__pipe_type != pipeline_type:
232
- # controlnet pipeline for tile upscaler
233
  pipe = StableDiffusionNetworkModelPipelineLoader(
234
  is_sdxl=get_is_sdxl(),
235
  is_img2img=True,
@@ -278,6 +311,8 @@ class ControlNet(AbstractPipeline):
278
  def process(self, **kwargs):
279
  if self.__current_task_name == "pose":
280
  return self.process_pose(**kwargs)
 
 
281
  if self.__current_task_name == "canny":
282
  return self.process_canny(**kwargs)
283
  if self.__current_task_name == "scribble":
@@ -286,6 +321,8 @@ class ControlNet(AbstractPipeline):
286
  return self.process_linearart(**kwargs)
287
  if self.__current_task_name == "tile_upscaler":
288
  return self.process_tile_upscaler(**kwargs)
 
 
289
  raise Exception("ControlNet is not loaded with any model")
290
 
291
  @torch.inference_mode()
@@ -298,16 +335,22 @@ class ControlNet(AbstractPipeline):
298
  negative_prompt: List[str],
299
  height: int,
300
  width: int,
301
- guidance_scale: float = 9,
 
302
  **kwargs,
303
  ):
304
  if self.__current_task_name != "canny":
305
  raise Exception("ControlNet is not loaded with canny model")
306
 
307
- torch.manual_seed(seed)
 
 
 
 
 
308
 
309
- init_image = download_image(imageUrl).resize((width, height))
310
- init_image = ControlNet.canny_detect_edge(init_image)
311
 
312
  kwargs = {
313
  "prompt": prompt,
@@ -318,11 +361,67 @@ class ControlNet(AbstractPipeline):
318
  "num_inference_steps": num_inference_steps,
319
  "height": height,
320
  "width": width,
 
321
  **kwargs,
322
  }
323
 
 
324
  result = self.pipe2.__call__(**kwargs)
325
- return Result.from_result(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
 
327
  @torch.inference_mode()
328
  def process_pose(
@@ -340,22 +439,23 @@ class ControlNet(AbstractPipeline):
340
  if self.__current_task_name != "pose":
341
  raise Exception("ControlNet is not loaded with pose model")
342
 
343
- torch.manual_seed(seed)
344
 
345
  kwargs = {
346
  "prompt": prompt[0],
347
  "image": image,
348
- "num_images_per_prompt": 4,
349
  "num_inference_steps": num_inference_steps,
350
  "negative_prompt": negative_prompt[0],
351
  "guidance_scale": guidance_scale,
352
  "height": height,
353
  "width": width,
 
354
  **kwargs,
355
  }
356
  print(kwargs)
357
  result = self.pipe2.__call__(**kwargs)
358
- return Result.from_result(result)
359
 
360
  @torch.inference_mode()
361
  def process_tile_upscaler(
@@ -374,26 +474,60 @@ class ControlNet(AbstractPipeline):
374
  if self.__current_task_name != "tile_upscaler":
375
  raise Exception("ControlNet is not loaded with tile_upscaler model")
376
 
377
- torch.manual_seed(seed)
 
 
 
 
 
378
 
379
- init_image = download_image(imageUrl).resize((width, height))
380
- condition_image = self.__resize_for_condition_image(
381
- init_image, resize_dimension
382
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
 
384
  kwargs = {
385
- "image": condition_image,
386
  "prompt": prompt,
387
  "control_image": condition_image,
388
  "num_inference_steps": num_inference_steps,
389
  "negative_prompt": negative_prompt,
390
  "height": condition_image.size[1],
391
  "width": condition_image.size[0],
392
- "guidance_scale": guidance_scale,
393
  **kwargs,
394
  }
395
  result = self.pipe.__call__(**kwargs)
396
- return Result.from_result(result)
397
 
398
  @torch.inference_mode()
399
  def process_scribble(
@@ -406,16 +540,28 @@ class ControlNet(AbstractPipeline):
406
  height: int,
407
  width: int,
408
  guidance_scale: float = 7.5,
 
409
  **kwargs,
410
  ):
411
  if self.__current_task_name != "scribble":
412
  raise Exception("ControlNet is not loaded with scribble model")
413
 
414
- torch.manual_seed(seed)
 
 
 
 
 
 
 
 
 
 
 
415
 
416
  sdxl_args = (
417
  {
418
- "guidance_scale": 6,
419
  "adapter_conditioning_scale": 1.0,
420
  "adapter_conditioning_factor": 1.0,
421
  }
@@ -431,11 +577,12 @@ class ControlNet(AbstractPipeline):
431
  "height": height,
432
  "width": width,
433
  "guidance_scale": guidance_scale,
 
434
  **sdxl_args,
435
  **kwargs,
436
  }
437
  result = self.pipe2.__call__(**kwargs)
438
- return Result.from_result(result)
439
 
440
  @torch.inference_mode()
441
  def process_linearart(
@@ -448,20 +595,26 @@ class ControlNet(AbstractPipeline):
448
  height: int,
449
  width: int,
450
  guidance_scale: float = 7.5,
 
451
  **kwargs,
452
  ):
453
  if self.__current_task_name != "linearart":
454
  raise Exception("ControlNet is not loaded with linearart model")
455
 
456
- torch.manual_seed(seed)
457
 
458
- init_image = download_image(imageUrl).resize((width, height))
459
- condition_image = ControlNet.linearart_condition_image(init_image)
 
 
 
 
 
460
 
461
  # we use t2i adapter and the conditioning scale should always be 0.8
462
  sdxl_args = (
463
  {
464
- "guidance_scale": 6,
465
  "adapter_conditioning_scale": 1.0,
466
  "adapter_conditioning_factor": 1.0,
467
  }
@@ -470,18 +623,68 @@ class ControlNet(AbstractPipeline):
470
  )
471
 
472
  kwargs = {
473
- "image": [condition_image] * 4,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
474
  "prompt": prompt,
475
  "num_inference_steps": num_inference_steps,
476
  "negative_prompt": negative_prompt,
477
  "height": height,
478
  "width": width,
479
  "guidance_scale": guidance_scale,
 
480
  **sdxl_args,
481
  **kwargs,
482
  }
483
  result = self.pipe2.__call__(**kwargs)
484
- return Result.from_result(result)
485
 
486
  def cleanup(self):
487
  """Doesn't do anything considering new diffusers has itself a cleanup mechanism
@@ -504,12 +707,15 @@ class ControlNet(AbstractPipeline):
504
  def linearart_condition_image(image: Image.Image, **kwargs) -> Image.Image:
505
  processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
506
  if get_is_sdxl():
507
- kwargs = {"detect_resolution": 384, **kwargs}
 
 
508
 
509
  image = processor.__call__(input_image=image, **kwargs)
510
  return image
511
 
512
  @staticmethod
 
513
  def depth_image(image: Image.Image) -> Image.Image:
514
  global midas, midas_transforms
515
  if "midas" not in globals():
@@ -555,6 +761,10 @@ class ControlNet(AbstractPipeline):
555
  canny_image = Image.fromarray(image_array)
556
  return canny_image
557
 
 
 
 
 
558
  def __resize_for_condition_image(self, image: Image.Image, resolution: int):
559
  input_image = image.convert("RGB")
560
  W, H = input_image.size
@@ -572,6 +782,7 @@ class ControlNet(AbstractPipeline):
572
  "linearart": "lllyasviel/control_v11p_sd15_lineart",
573
  "scribble": "lllyasviel/control_v11p_sd15_scribble",
574
  "tile_upscaler": "lllyasviel/control_v11f1e_sd15_tile",
 
575
  }
576
  __model_normal_types = {
577
  "pose": "controlnet",
@@ -579,19 +790,24 @@ class ControlNet(AbstractPipeline):
579
  "linearart": "controlnet",
580
  "scribble": "controlnet",
581
  "tile_upscaler": "controlnet",
 
582
  }
583
 
584
  __model_sdxl = {
585
  "pose": "thibaud/controlnet-openpose-sdxl-1.0",
586
- "canny": "diffusers/controlnet-canny-sdxl-1.0",
 
 
587
  "linearart": "TencentARC/t2i-adapter-lineart-sdxl-1.0",
588
  "scribble": "TencentARC/t2i-adapter-sketch-sdxl-1.0",
589
- "tile_upscaler": None,
590
  }
591
  __model_sdxl_types = {
592
  "pose": "controlnet",
593
  "canny": "controlnet",
 
 
594
  "linearart": "t2i",
595
  "scribble": "t2i",
596
- "tile_upscaler": None,
597
  }
 
1
+ import os
2
  from typing import AbstractSet, List, Literal, Optional, Union
3
 
4
  import cv2
 
18
  StableDiffusionControlNetImg2ImgPipeline,
19
  StableDiffusionControlNetPipeline,
20
  StableDiffusionXLAdapterPipeline,
21
+ StableDiffusionXLControlNetImg2ImgPipeline,
22
  StableDiffusionXLControlNetPipeline,
23
  T2IAdapter,
24
  UniPCMultistepScheduler,
 
31
  from transformers import pipeline
32
 
33
  import internals.util.image as ImageUtil
 
34
  from internals.data.result import Result
35
  from internals.pipelines.commons import AbstractPipeline
36
+ from internals.util import get_generators
37
  from internals.util.cache import clear_cuda_and_gc
38
  from internals.util.commons import download_image
39
  from internals.util.config import (
 
41
  get_hf_token,
42
  get_is_sdxl,
43
  get_model_dir,
44
+ get_num_return_sequences,
45
  )
46
 
47
+ CONTROLNET_TYPES = Literal[
48
+ "pose", "canny", "scribble", "linearart", "tile_upscaler", "canny_2x"
49
+ ]
50
+
51
+ __CN_MODELS = {}
52
+ MAX_CN_MODELS = 3
53
+
54
+
55
+ def clear_networks():
56
+ global __CN_MODELS
57
+ __CN_MODELS = {}
58
+
59
+
60
+ def load_network_model_by_key(repo_id: str, pipeline_type: str):
61
+ global __CN_MODELS
62
+
63
+ if repo_id in __CN_MODELS:
64
+ return __CN_MODELS[repo_id]
65
+
66
+ if len(__CN_MODELS) >= MAX_CN_MODELS:
67
+ __CN_MODELS = {}
68
+
69
+ if pipeline_type == "controlnet":
70
+ model = ControlNetModel.from_pretrained(
71
+ repo_id,
72
+ torch_dtype=torch.float16,
73
+ cache_dir=get_hf_cache_dir(),
74
+ token=get_hf_token(),
75
+ ).to("cuda")
76
+ elif pipeline_type == "t2i":
77
+ model = T2IAdapter.from_pretrained(
78
+ repo_id,
79
+ torch_dtype=torch.float16,
80
+ varient="fp16",
81
+ token=get_hf_token(),
82
+ ).to("cuda")
83
+ else:
84
+ raise Exception("Invalid pipeline type")
85
+
86
+ __CN_MODELS[repo_id] = model
87
+
88
+ return model
89
 
90
 
91
  class StableDiffusionNetworkModelPipelineLoader:
 
101
  pipeline_type,
102
  base_pipe: Optional[AbstractSet] = None,
103
  ):
 
 
 
 
 
104
  if base_pipe is None:
105
  pretrained = True
106
  kwargs = {
 
114
  kwargs = {
115
  **base_pipe.pipe.components, # pyright: ignore
116
  }
117
+ if get_is_sdxl():
118
+ kwargs.pop("image_encoder", None)
119
+ kwargs.pop("feature_extractor", None)
120
 
121
+ if is_sdxl and is_img2img and pipeline_type == "controlnet":
122
+ model = (
123
+ StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained
124
+ if pretrained
125
+ else StableDiffusionXLControlNetImg2ImgPipeline
126
+ )
127
+ return model(controlnet=network_model, **kwargs).to("cuda")
128
  if is_sdxl and pipeline_type == "controlnet":
129
  model = (
130
  StableDiffusionXLControlNetPipeline.from_pretrained
 
195
  def load_model(self, task_name: CONTROLNET_TYPES):
196
  "Appropriately loads the network module, pipelines and cache it for reuse."
197
 
 
198
  if self.__current_task_name == task_name:
199
  return
200
+
201
+ config = self.__model_sdxl if get_is_sdxl() else self.__model_normal
202
  model = config[task_name]
203
  if not model:
204
  raise Exception(f"ControlNet is not supported for {task_name}")
 
226
  def __load_network_model(self, model_name, pipeline_type):
227
  "Loads the network module, eg: ControlNet or T2I Adapters"
228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  if type(model_name) == str:
230
+ return load_network_model_by_key(model_name, pipeline_type)
 
 
 
 
231
  elif type(model_name) == list:
232
  if pipeline_type == "controlnet":
233
  cns = []
234
  for model in model_name:
235
+ cns.append(load_network_model_by_key(model, pipeline_type))
236
  return MultiControlNetModel(cns).to("cuda")
237
  elif pipeline_type == "t2i":
238
  raise Exception("Multi T2I adapters are not supported")
 
251
  pipe.enable_vae_slicing()
252
  pipe.enable_xformers_memory_efficient_attention()
253
  # this scheduler produces good outputs for t2i adapters
254
+ if pipeline_type == "t2i":
255
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
256
+ pipe.scheduler.config
257
+ )
258
  else:
259
  pipe.enable_xformers_memory_efficient_attention()
260
  return pipe
 
262
  # If the pipeline type is changed we should reload all
263
  # the pipelines
264
  if not self.__loaded or self.__pipe_type != pipeline_type:
265
+ # controlnet pipeline for tile upscaler or any pipeline with img2img + network support
266
  pipe = StableDiffusionNetworkModelPipelineLoader(
267
  is_sdxl=get_is_sdxl(),
268
  is_img2img=True,
 
311
  def process(self, **kwargs):
312
  if self.__current_task_name == "pose":
313
  return self.process_pose(**kwargs)
314
+ if self.__current_task_name == "depth":
315
+ return self.process_depth(**kwargs)
316
  if self.__current_task_name == "canny":
317
  return self.process_canny(**kwargs)
318
  if self.__current_task_name == "scribble":
 
321
  return self.process_linearart(**kwargs)
322
  if self.__current_task_name == "tile_upscaler":
323
  return self.process_tile_upscaler(**kwargs)
324
+ if self.__current_task_name == "canny_2x":
325
+ return self.process_canny_2x(**kwargs)
326
  raise Exception("ControlNet is not loaded with any model")
327
 
328
  @torch.inference_mode()
 
335
  negative_prompt: List[str],
336
  height: int,
337
  width: int,
338
+ guidance_scale: float = 7.5,
339
+ apply_preprocess: bool = True,
340
  **kwargs,
341
  ):
342
  if self.__current_task_name != "canny":
343
  raise Exception("ControlNet is not loaded with canny model")
344
 
345
+ generator = get_generators(seed, get_num_return_sequences())
346
+
347
+ init_image = self.preprocess_image(imageUrl, width, height)
348
+ if apply_preprocess:
349
+ init_image = ControlNet.canny_detect_edge(init_image)
350
+ init_image = init_image.resize((width, height))
351
 
352
+ # if get_is_sdxl():
353
+ # kwargs["controlnet_conditioning_scale"] = 0.5
354
 
355
  kwargs = {
356
  "prompt": prompt,
 
361
  "num_inference_steps": num_inference_steps,
362
  "height": height,
363
  "width": width,
364
+ "generator": generator,
365
  **kwargs,
366
  }
367
 
368
+ print(kwargs)
369
  result = self.pipe2.__call__(**kwargs)
370
+ return Result.from_result(result), init_image
371
+
372
+ @torch.inference_mode()
373
+ def process_canny_2x(
374
+ self,
375
+ prompt: List[str],
376
+ imageUrl: str,
377
+ seed: int,
378
+ num_inference_steps: int,
379
+ negative_prompt: List[str],
380
+ height: int,
381
+ width: int,
382
+ guidance_scale: float = 8.5,
383
+ **kwargs,
384
+ ):
385
+ if self.__current_task_name != "canny_2x":
386
+ raise Exception("ControlNet is not loaded with canny model")
387
+
388
+ generator = get_generators(seed, get_num_return_sequences())
389
+
390
+ init_image = self.preprocess_image(imageUrl, width, height)
391
+ canny_image = ControlNet.canny_detect_edge(init_image).resize((width, height))
392
+ depth_image = ControlNet.depth_image(init_image).resize((width, height))
393
+
394
+ condition_scale = kwargs.get("controlnet_conditioning_scale", None)
395
+ condition_factor = kwargs.get("control_guidance_end", None)
396
+ print("condition_scale", condition_scale)
397
+
398
+ if not get_is_sdxl():
399
+ kwargs["guidance_scale"] = 7.5
400
+ kwargs["strength"] = 0.8
401
+ kwargs["controlnet_conditioning_scale"] = [condition_scale or 1.0, 0.3]
402
+ else:
403
+ kwargs["controlnet_conditioning_scale"] = [condition_scale or 0.8, 0.3]
404
+
405
+ kwargs["control_guidance_end"] = [condition_factor or 1.0, 1.0]
406
+
407
+ kwargs = {
408
+ "prompt": prompt[0],
409
+ "image": [init_image] * get_num_return_sequences(),
410
+ "control_image": [canny_image, depth_image],
411
+ "guidance_scale": guidance_scale,
412
+ "num_images_per_prompt": get_num_return_sequences(),
413
+ "negative_prompt": negative_prompt[0],
414
+ "num_inference_steps": num_inference_steps,
415
+ "strength": 1.0,
416
+ "height": height,
417
+ "width": width,
418
+ "generator": generator,
419
+ **kwargs,
420
+ }
421
+ print(kwargs)
422
+
423
+ result = self.pipe.__call__(**kwargs)
424
+ return Result.from_result(result), canny_image
425
 
426
  @torch.inference_mode()
427
  def process_pose(
 
439
  if self.__current_task_name != "pose":
440
  raise Exception("ControlNet is not loaded with pose model")
441
 
442
+ generator = get_generators(seed, get_num_return_sequences())
443
 
444
  kwargs = {
445
  "prompt": prompt[0],
446
  "image": image,
447
+ "num_images_per_prompt": get_num_return_sequences(),
448
  "num_inference_steps": num_inference_steps,
449
  "negative_prompt": negative_prompt[0],
450
  "guidance_scale": guidance_scale,
451
  "height": height,
452
  "width": width,
453
+ "generator": generator,
454
  **kwargs,
455
  }
456
  print(kwargs)
457
  result = self.pipe2.__call__(**kwargs)
458
+ return Result.from_result(result), image
459
 
460
  @torch.inference_mode()
461
  def process_tile_upscaler(
 
474
  if self.__current_task_name != "tile_upscaler":
475
  raise Exception("ControlNet is not loaded with tile_upscaler model")
476
 
477
+ init_image = None
478
+ # find the correct seed and imageUrl from imageUrl
479
+ try:
480
+ p = os.path.splitext(imageUrl)[0]
481
+ p = p.split("/")[-1]
482
+ p = p.split("_")[-1]
483
 
484
+ seed = seed + int(p)
485
+
486
+ if "_canny_2x" or "_linearart" in imageUrl:
487
+ imageUrl = imageUrl.replace("_canny_2x", "_canny_2x_highres").replace(
488
+ "_linearart_highres", ""
489
+ )
490
+ init_image = download_image(imageUrl)
491
+ width, height = init_image.size
492
+
493
+ print("Setting imageUrl with width and height", imageUrl, width, height)
494
+ except Exception as e:
495
+ print("Failed to extract seed from imageUrl", e)
496
+
497
+ print("Setting seed", seed)
498
+ generator = get_generators(seed)
499
+
500
+ if not init_image:
501
+ init_image = download_image(imageUrl).resize((width, height))
502
+
503
+ condition_image = ImageUtil.resize_image(init_image, 1024)
504
+ if get_is_sdxl():
505
+ condition_image = condition_image.resize(init_image.size)
506
+ else:
507
+ condition_image = self.__resize_for_condition_image(
508
+ init_image, resize_dimension
509
+ )
510
+
511
+ if get_is_sdxl():
512
+ kwargs["strength"] = 1.0
513
+ kwargs["controlnet_conditioning_scale"] = 1.0
514
+ kwargs["image"] = init_image
515
+ else:
516
+ kwargs["image"] = condition_image
517
+ kwargs["guidance_scale"] = guidance_scale
518
 
519
  kwargs = {
 
520
  "prompt": prompt,
521
  "control_image": condition_image,
522
  "num_inference_steps": num_inference_steps,
523
  "negative_prompt": negative_prompt,
524
  "height": condition_image.size[1],
525
  "width": condition_image.size[0],
526
+ "generator": generator,
527
  **kwargs,
528
  }
529
  result = self.pipe.__call__(**kwargs)
530
+ return Result.from_result(result), condition_image
531
 
532
  @torch.inference_mode()
533
  def process_scribble(
 
540
  height: int,
541
  width: int,
542
  guidance_scale: float = 7.5,
543
+ apply_preprocess: bool = True,
544
  **kwargs,
545
  ):
546
  if self.__current_task_name != "scribble":
547
  raise Exception("ControlNet is not loaded with scribble model")
548
 
549
+ generator = get_generators(seed, get_num_return_sequences())
550
+
551
+ if apply_preprocess:
552
+ if get_is_sdxl():
553
+ # We use sketch in SDXL
554
+ image = [
555
+ ControlNet.pidinet_image(image[0]).resize((width, height))
556
+ ] * len(image)
557
+ else:
558
+ image = [
559
+ ControlNet.scribble_image(image[0]).resize((width, height))
560
+ ] * len(image)
561
 
562
  sdxl_args = (
563
  {
564
+ "guidance_scale": guidance_scale,
565
  "adapter_conditioning_scale": 1.0,
566
  "adapter_conditioning_factor": 1.0,
567
  }
 
577
  "height": height,
578
  "width": width,
579
  "guidance_scale": guidance_scale,
580
+ "generator": generator,
581
  **sdxl_args,
582
  **kwargs,
583
  }
584
  result = self.pipe2.__call__(**kwargs)
585
+ return Result.from_result(result), image[0]
586
 
587
  @torch.inference_mode()
588
  def process_linearart(
 
595
  height: int,
596
  width: int,
597
  guidance_scale: float = 7.5,
598
+ apply_preprocess: bool = True,
599
  **kwargs,
600
  ):
601
  if self.__current_task_name != "linearart":
602
  raise Exception("ControlNet is not loaded with linearart model")
603
 
604
+ generator = get_generators(seed, get_num_return_sequences())
605
 
606
+ init_image = self.preprocess_image(imageUrl, width, height)
607
+
608
+ if apply_preprocess:
609
+ condition_image = ControlNet.linearart_condition_image(init_image)
610
+ condition_image = condition_image.resize(init_image.size)
611
+ else:
612
+ condition_image = init_image
613
 
614
  # we use t2i adapter and the conditioning scale should always be 0.8
615
  sdxl_args = (
616
  {
617
+ "guidance_scale": guidance_scale,
618
  "adapter_conditioning_scale": 1.0,
619
  "adapter_conditioning_factor": 1.0,
620
  }
 
623
  )
624
 
625
  kwargs = {
626
+ "image": [condition_image] * get_num_return_sequences(),
627
+ "prompt": prompt,
628
+ "num_inference_steps": num_inference_steps,
629
+ "negative_prompt": negative_prompt,
630
+ "height": height,
631
+ "width": width,
632
+ "guidance_scale": guidance_scale,
633
+ "generator": generator,
634
+ **sdxl_args,
635
+ **kwargs,
636
+ }
637
+ result = self.pipe2.__call__(**kwargs)
638
+ return Result.from_result(result), condition_image
639
+
640
+ @torch.inference_mode()
641
+ def process_depth(
642
+ self,
643
+ imageUrl: str,
644
+ prompt: Union[str, List[str]],
645
+ negative_prompt: Union[str, List[str]],
646
+ num_inference_steps: int,
647
+ seed: int,
648
+ height: int,
649
+ width: int,
650
+ guidance_scale: float = 7.5,
651
+ apply_preprocess: bool = True,
652
+ **kwargs,
653
+ ):
654
+ if self.__current_task_name != "depth":
655
+ raise Exception("ControlNet is not loaded with depth model")
656
+
657
+ generator = get_generators(seed, get_num_return_sequences())
658
+
659
+ init_image = self.preprocess_image(imageUrl, width, height)
660
+
661
+ if apply_preprocess:
662
+ condition_image = ControlNet.depth_image(init_image)
663
+ condition_image = condition_image.resize(init_image.size)
664
+ else:
665
+ condition_image = init_image
666
+
667
+ # for using the depth controlnet in this SDXL model, these hyperparamters are optimal
668
+ sdxl_args = (
669
+ {"controlnet_conditioning_scale": 0.2, "control_guidance_end": 0.2}
670
+ if get_is_sdxl()
671
+ else {}
672
+ )
673
+
674
+ kwargs = {
675
+ "image": [condition_image] * get_num_return_sequences(),
676
  "prompt": prompt,
677
  "num_inference_steps": num_inference_steps,
678
  "negative_prompt": negative_prompt,
679
  "height": height,
680
  "width": width,
681
  "guidance_scale": guidance_scale,
682
+ "generator": generator,
683
  **sdxl_args,
684
  **kwargs,
685
  }
686
  result = self.pipe2.__call__(**kwargs)
687
+ return Result.from_result(result), condition_image
688
 
689
  def cleanup(self):
690
  """Doesn't do anything considering new diffusers has itself a cleanup mechanism
 
707
  def linearart_condition_image(image: Image.Image, **kwargs) -> Image.Image:
708
  processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
709
  if get_is_sdxl():
710
+ kwargs = {"detect_resolution": 384, "image_resolution": 1024, **kwargs}
711
+ else:
712
+ kwargs = {}
713
 
714
  image = processor.__call__(input_image=image, **kwargs)
715
  return image
716
 
717
  @staticmethod
718
+ @torch.inference_mode()
719
  def depth_image(image: Image.Image) -> Image.Image:
720
  global midas, midas_transforms
721
  if "midas" not in globals():
 
761
  canny_image = Image.fromarray(image_array)
762
  return canny_image
763
 
764
+ def preprocess_image(self, imageUrl, width, height) -> Image.Image:
765
+ image = download_image(imageUrl, mode="RGBA").resize((width, height))
766
+ return ImageUtil.alpha_to_white(image)
767
+
768
  def __resize_for_condition_image(self, image: Image.Image, resolution: int):
769
  input_image = image.convert("RGB")
770
  W, H = input_image.size
 
782
  "linearart": "lllyasviel/control_v11p_sd15_lineart",
783
  "scribble": "lllyasviel/control_v11p_sd15_scribble",
784
  "tile_upscaler": "lllyasviel/control_v11f1e_sd15_tile",
785
+ "canny_2x": "lllyasviel/control_v11p_sd15_canny, lllyasviel/control_v11f1p_sd15_depth",
786
  }
787
  __model_normal_types = {
788
  "pose": "controlnet",
 
790
  "linearart": "controlnet",
791
  "scribble": "controlnet",
792
  "tile_upscaler": "controlnet",
793
+ "canny_2x": "controlnet",
794
  }
795
 
796
  __model_sdxl = {
797
  "pose": "thibaud/controlnet-openpose-sdxl-1.0",
798
+ "canny": "Autodraft/controlnet-canny-sdxl-1.0",
799
+ "depth": "Autodraft/controlnet-depth-sdxl-1.0",
800
+ "canny_2x": "Autodraft/controlnet-canny-sdxl-1.0, Autodraft/controlnet-depth-sdxl-1.0",
801
  "linearart": "TencentARC/t2i-adapter-lineart-sdxl-1.0",
802
  "scribble": "TencentARC/t2i-adapter-sketch-sdxl-1.0",
803
+ "tile_upscaler": "Autodraft/ControlNet_SDXL_tile_upscale",
804
  }
805
  __model_sdxl_types = {
806
  "pose": "controlnet",
807
  "canny": "controlnet",
808
+ "canny_2x": "controlnet",
809
+ "depth": "controlnet",
810
  "linearart": "t2i",
811
  "scribble": "t2i",
812
+ "tile_upscaler": "controlnet",
813
  }
internals/pipelines/high_res.py CHANGED
@@ -1,15 +1,22 @@
1
  import math
2
- from typing import List, Optional
3
 
4
  from PIL import Image
5
 
6
  from internals.data.result import Result
7
  from internals.pipelines.commons import AbstractPipeline, Img2Img
 
8
  from internals.util.cache import clear_cuda_and_gc
9
- from internals.util.config import get_base_dimension, get_is_sdxl, get_model_dir
 
 
 
 
 
 
10
 
11
 
12
- class HighRes(AbstractPipeline):
13
  def load(self, img2img: Optional[Img2Img] = None):
14
  if hasattr(self, "pipe"):
15
  return
@@ -21,6 +28,9 @@ class HighRes(AbstractPipeline):
21
  self.pipe = img2img.pipe
22
  self.img2img = img2img
23
 
 
 
 
24
  def apply(
25
  self,
26
  prompt: List[str],
@@ -28,6 +38,7 @@ class HighRes(AbstractPipeline):
28
  images,
29
  width: int,
30
  height: int,
 
31
  num_inference_steps: int,
32
  strength: float = 0.5,
33
  guidance_scale: int = 9,
@@ -35,7 +46,18 @@ class HighRes(AbstractPipeline):
35
  ):
36
  clear_cuda_and_gc()
37
 
 
 
38
  images = [image.resize((width, height)) for image in images]
 
 
 
 
 
 
 
 
 
39
  kwargs = {
40
  "prompt": prompt,
41
  "image": images,
@@ -43,9 +65,16 @@ class HighRes(AbstractPipeline):
43
  "negative_prompt": negative_prompt,
44
  "guidance_scale": guidance_scale,
45
  "num_inference_steps": num_inference_steps,
 
46
  **kwargs,
47
  }
 
 
48
  result = self.pipe.__call__(**kwargs)
 
 
 
 
49
  return Result.from_result(result)
50
 
51
  @staticmethod
 
1
  import math
2
+ from typing import Dict, List, Optional
3
 
4
  from PIL import Image
5
 
6
  from internals.data.result import Result
7
  from internals.pipelines.commons import AbstractPipeline, Img2Img
8
+ from internals.util import get_generators
9
  from internals.util.cache import clear_cuda_and_gc
10
+ from internals.util.config import (
11
+ get_base_dimension,
12
+ get_is_sdxl,
13
+ get_model_dir,
14
+ get_num_return_sequences,
15
+ )
16
+ from internals.util.sdxl_lightning import LightningMixin
17
 
18
 
19
+ class HighRes(AbstractPipeline, LightningMixin):
20
  def load(self, img2img: Optional[Img2Img] = None):
21
  if hasattr(self, "pipe"):
22
  return
 
28
  self.pipe = img2img.pipe
29
  self.img2img = img2img
30
 
31
+ if get_is_sdxl():
32
+ self.configure_sdxl_lightning(img2img.pipe)
33
+
34
  def apply(
35
  self,
36
  prompt: List[str],
 
38
  images,
39
  width: int,
40
  height: int,
41
+ seed: int,
42
  num_inference_steps: int,
43
  strength: float = 0.5,
44
  guidance_scale: int = 9,
 
46
  ):
47
  clear_cuda_and_gc()
48
 
49
+ generator = get_generators(seed, get_num_return_sequences())
50
+
51
  images = [image.resize((width, height)) for image in images]
52
+
53
+ # if get_is_sdxl():
54
+ # kwargs["guidance_scale"] = kwargs.get("guidance_scale", 15)
55
+ # kwargs["strength"] = kwargs.get("strength", 0.6)
56
+
57
+ if get_is_sdxl():
58
+ extra_args = self.enable_sdxl_lightning()
59
+ kwargs.update(extra_args)
60
+
61
  kwargs = {
62
  "prompt": prompt,
63
  "image": images,
 
65
  "negative_prompt": negative_prompt,
66
  "guidance_scale": guidance_scale,
67
  "num_inference_steps": num_inference_steps,
68
+ "generator": generator,
69
  **kwargs,
70
  }
71
+
72
+ print(kwargs)
73
  result = self.pipe.__call__(**kwargs)
74
+
75
+ if get_is_sdxl():
76
+ self.disable_sdxl_lightning()
77
+
78
  return Result.from_result(result)
79
 
80
  @staticmethod
internals/pipelines/inpaint_imageprocessor.py ADDED
@@ -0,0 +1,976 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import warnings
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import PIL.Image
20
+ import torch
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
23
+ from PIL import Image, ImageFilter, ImageOps
24
+
25
+ PipelineImageInput = Union[
26
+ PIL.Image.Image,
27
+ np.ndarray,
28
+ torch.FloatTensor,
29
+ List[PIL.Image.Image],
30
+ List[np.ndarray],
31
+ List[torch.FloatTensor],
32
+ ]
33
+
34
+ PipelineDepthInput = PipelineImageInput
35
+
36
+
37
+ class VaeImageProcessor(ConfigMixin):
38
+ """
39
+ Image processor for VAE.
40
+
41
+ Args:
42
+ do_resize (`bool`, *optional*, defaults to `True`):
43
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
44
+ `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
45
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
46
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
47
+ resample (`str`, *optional*, defaults to `lanczos`):
48
+ Resampling filter to use when resizing the image.
49
+ do_normalize (`bool`, *optional*, defaults to `True`):
50
+ Whether to normalize the image to [-1,1].
51
+ do_binarize (`bool`, *optional*, defaults to `False`):
52
+ Whether to binarize the image to 0/1.
53
+ do_convert_rgb (`bool`, *optional*, defaults to be `False`):
54
+ Whether to convert the images to RGB format.
55
+ do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
56
+ Whether to convert the images to grayscale format.
57
+ """
58
+
59
+ config_name = CONFIG_NAME
60
+
61
+ @register_to_config
62
+ def __init__(
63
+ self,
64
+ do_resize: bool = True,
65
+ vae_scale_factor: int = 8,
66
+ resample: str = "lanczos",
67
+ do_normalize: bool = True,
68
+ do_binarize: bool = False,
69
+ do_convert_rgb: bool = False,
70
+ do_convert_grayscale: bool = False,
71
+ ):
72
+ super().__init__()
73
+ if do_convert_rgb and do_convert_grayscale:
74
+ raise ValueError(
75
+ "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
76
+ " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
77
+ " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
78
+ )
79
+ self.config.do_convert_rgb = False
80
+
81
+ @staticmethod
82
+ def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
83
+ """
84
+ Convert a numpy image or a batch of images to a PIL image.
85
+ """
86
+ if images.ndim == 3:
87
+ images = images[None, ...]
88
+ images = (images * 255).round().astype("uint8")
89
+ if images.shape[-1] == 1:
90
+ # special case for grayscale (single channel) images
91
+ pil_images = [
92
+ Image.fromarray(image.squeeze(), mode="L") for image in images
93
+ ]
94
+ else:
95
+ pil_images = [Image.fromarray(image) for image in images]
96
+
97
+ return pil_images
98
+
99
+ @staticmethod
100
+ def pil_to_numpy(
101
+ images: Union[List[PIL.Image.Image], PIL.Image.Image]
102
+ ) -> np.ndarray:
103
+ """
104
+ Convert a PIL image or a list of PIL images to NumPy arrays.
105
+ """
106
+ if not isinstance(images, list):
107
+ images = [images]
108
+ images = [np.array(image).astype(np.float32) / 255.0 for image in images]
109
+ images = np.stack(images, axis=0)
110
+
111
+ return images
112
+
113
+ @staticmethod
114
+ def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
115
+ """
116
+ Convert a NumPy image to a PyTorch tensor.
117
+ """
118
+ if images.ndim == 3:
119
+ images = images[..., None]
120
+
121
+ images = torch.from_numpy(images.transpose(0, 3, 1, 2))
122
+ return images
123
+
124
+ @staticmethod
125
+ def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
126
+ """
127
+ Convert a PyTorch tensor to a NumPy image.
128
+ """
129
+ images = images.cpu().permute(0, 2, 3, 1).float().numpy()
130
+ return images
131
+
132
+ @staticmethod
133
+ def normalize(
134
+ images: Union[np.ndarray, torch.Tensor]
135
+ ) -> Union[np.ndarray, torch.Tensor]:
136
+ """
137
+ Normalize an image array to [-1,1].
138
+ """
139
+ return 2.0 * images - 1.0
140
+
141
+ @staticmethod
142
+ def denormalize(
143
+ images: Union[np.ndarray, torch.Tensor]
144
+ ) -> Union[np.ndarray, torch.Tensor]:
145
+ """
146
+ Denormalize an image array to [0,1].
147
+ """
148
+ return (images / 2 + 0.5).clamp(0, 1)
149
+
150
+ @staticmethod
151
+ def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
152
+ """
153
+ Converts a PIL image to RGB format.
154
+ """
155
+ image = image.convert("RGB")
156
+
157
+ return image
158
+
159
+ @staticmethod
160
+ def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
161
+ """
162
+ Converts a PIL image to grayscale format.
163
+ """
164
+ image = image.convert("L")
165
+
166
+ return image
167
+
168
+ @staticmethod
169
+ def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
170
+ """
171
+ Applies Gaussian blur to an image.
172
+ """
173
+ image = image.filter(ImageFilter.GaussianBlur(blur_factor))
174
+
175
+ return image
176
+
177
+ @staticmethod
178
+ def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
179
+ """
180
+ Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect ratio of the original image;
181
+ for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128.
182
+
183
+ Args:
184
+ mask_image (PIL.Image.Image): Mask image.
185
+ width (int): Width of the image to be processed.
186
+ height (int): Height of the image to be processed.
187
+ pad (int, optional): Padding to be added to the crop region. Defaults to 0.
188
+
189
+ Returns:
190
+ tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and matches the original aspect ratio.
191
+ """
192
+
193
+ mask_image = mask_image.convert("L")
194
+ mask = np.array(mask_image)
195
+
196
+ # 1. find a rectangular region that contains all masked ares in an image
197
+ h, w = mask.shape
198
+ crop_left = 0
199
+ for i in range(w):
200
+ if not (mask[:, i] == 0).all():
201
+ break
202
+ crop_left += 1
203
+
204
+ crop_right = 0
205
+ for i in reversed(range(w)):
206
+ if not (mask[:, i] == 0).all():
207
+ break
208
+ crop_right += 1
209
+
210
+ crop_top = 0
211
+ for i in range(h):
212
+ if not (mask[i] == 0).all():
213
+ break
214
+ crop_top += 1
215
+
216
+ crop_bottom = 0
217
+ for i in reversed(range(h)):
218
+ if not (mask[i] == 0).all():
219
+ break
220
+ crop_bottom += 1
221
+
222
+ # 2. add padding to the crop region
223
+ x1, y1, x2, y2 = (
224
+ int(max(crop_left - pad, 0)),
225
+ int(max(crop_top - pad, 0)),
226
+ int(min(w - crop_right + pad, w)),
227
+ int(min(h - crop_bottom + pad, h)),
228
+ )
229
+
230
+ # 3. expands crop region to match the aspect ratio of the image to be processed
231
+ ratio_crop_region = (x2 - x1) / (y2 - y1)
232
+ ratio_processing = width / height
233
+
234
+ if ratio_crop_region > ratio_processing:
235
+ desired_height = (x2 - x1) / ratio_processing
236
+ desired_height_diff = int(desired_height - (y2 - y1))
237
+ y1 -= desired_height_diff // 2
238
+ y2 += desired_height_diff - desired_height_diff // 2
239
+ if y2 >= mask_image.height:
240
+ diff = y2 - mask_image.height
241
+ y2 -= diff
242
+ y1 -= diff
243
+ if y1 < 0:
244
+ y2 -= y1
245
+ y1 -= y1
246
+ if y2 >= mask_image.height:
247
+ y2 = mask_image.height
248
+ else:
249
+ desired_width = (y2 - y1) * ratio_processing
250
+ desired_width_diff = int(desired_width - (x2 - x1))
251
+ x1 -= desired_width_diff // 2
252
+ x2 += desired_width_diff - desired_width_diff // 2
253
+ if x2 >= mask_image.width:
254
+ diff = x2 - mask_image.width
255
+ x2 -= diff
256
+ x1 -= diff
257
+ if x1 < 0:
258
+ x2 -= x1
259
+ x1 -= x1
260
+ if x2 >= mask_image.width:
261
+ x2 = mask_image.width
262
+
263
+ return x1, y1, x2, y2
264
+
265
+ def _resize_and_fill(
266
+ self,
267
+ image: PIL.Image.Image,
268
+ width: int,
269
+ height: int,
270
+ ) -> PIL.Image.Image:
271
+ """
272
+ Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
273
+
274
+ Args:
275
+ image: The image to resize.
276
+ width: The width to resize the image to.
277
+ height: The height to resize the image to.
278
+ """
279
+
280
+ ratio = width / height
281
+ src_ratio = image.width / image.height
282
+
283
+ src_w = width if ratio < src_ratio else image.width * height // image.height
284
+ src_h = height if ratio >= src_ratio else image.height * width // image.width
285
+
286
+ resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
287
+ res = Image.new("RGB", (width, height))
288
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
289
+
290
+ if ratio < src_ratio:
291
+ fill_height = height // 2 - src_h // 2
292
+ if fill_height > 0:
293
+ res.paste(
294
+ resized.resize((width, fill_height), box=(0, 0, width, 0)),
295
+ box=(0, 0),
296
+ )
297
+ res.paste(
298
+ resized.resize(
299
+ (width, fill_height),
300
+ box=(0, resized.height, width, resized.height),
301
+ ),
302
+ box=(0, fill_height + src_h),
303
+ )
304
+ elif ratio > src_ratio:
305
+ fill_width = width // 2 - src_w // 2
306
+ if fill_width > 0:
307
+ res.paste(
308
+ resized.resize((fill_width, height), box=(0, 0, 0, height)),
309
+ box=(0, 0),
310
+ )
311
+ res.paste(
312
+ resized.resize(
313
+ (fill_width, height),
314
+ box=(resized.width, 0, resized.width, height),
315
+ ),
316
+ box=(fill_width + src_w, 0),
317
+ )
318
+
319
+ return res
320
+
321
+ def _resize_and_crop(
322
+ self,
323
+ image: PIL.Image.Image,
324
+ width: int,
325
+ height: int,
326
+ ) -> PIL.Image.Image:
327
+ """
328
+ Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
329
+
330
+ Args:
331
+ image: The image to resize.
332
+ width: The width to resize the image to.
333
+ height: The height to resize the image to.
334
+ """
335
+ ratio = width / height
336
+ src_ratio = image.width / image.height
337
+
338
+ src_w = width if ratio > src_ratio else image.width * height // image.height
339
+ src_h = height if ratio <= src_ratio else image.height * width // image.width
340
+
341
+ resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
342
+ res = Image.new("RGB", (width, height))
343
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
344
+ return res
345
+
346
+ def resize(
347
+ self,
348
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
349
+ height: int,
350
+ width: int,
351
+ resize_mode: str = "default", # "defalt", "fill", "crop"
352
+ ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
353
+ """
354
+ Resize image.
355
+
356
+ Args:
357
+ image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
358
+ The image input, can be a PIL image, numpy array or pytorch tensor.
359
+ height (`int`):
360
+ The height to resize to.
361
+ width (`int`):
362
+ The width to resize to.
363
+ resize_mode (`str`, *optional*, defaults to `default`):
364
+ The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit
365
+ within the specified width and height, and it may not maintaining the original aspect ratio.
366
+ If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
367
+ within the dimensions, filling empty with data from image.
368
+ If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
369
+ within the dimensions, cropping the excess.
370
+ Note that resize_mode `fill` and `crop` are only supported for PIL image input.
371
+
372
+ Returns:
373
+ `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
374
+ The resized image.
375
+ """
376
+ if resize_mode != "default" and not isinstance(image, PIL.Image.Image):
377
+ raise ValueError(
378
+ f"Only PIL image input is supported for resize_mode {resize_mode}"
379
+ )
380
+ if isinstance(image, PIL.Image.Image):
381
+ if resize_mode == "default":
382
+ image = image.resize(
383
+ (width, height), resample=PIL_INTERPOLATION[self.config.resample]
384
+ )
385
+ elif resize_mode == "fill":
386
+ image = self._resize_and_fill(image, width, height)
387
+ elif resize_mode == "crop":
388
+ image = self._resize_and_crop(image, width, height)
389
+ else:
390
+ raise ValueError(f"resize_mode {resize_mode} is not supported")
391
+
392
+ elif isinstance(image, torch.Tensor):
393
+ image = torch.nn.functional.interpolate(
394
+ image,
395
+ size=(height, width),
396
+ )
397
+ elif isinstance(image, np.ndarray):
398
+ image = self.numpy_to_pt(image)
399
+ image = torch.nn.functional.interpolate(
400
+ image,
401
+ size=(height, width),
402
+ )
403
+ image = self.pt_to_numpy(image)
404
+ return image
405
+
406
+ def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
407
+ """
408
+ Create a mask.
409
+
410
+ Args:
411
+ image (`PIL.Image.Image`):
412
+ The image input, should be a PIL image.
413
+
414
+ Returns:
415
+ `PIL.Image.Image`:
416
+ The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1.
417
+ """
418
+ image[image < 0.5] = 0
419
+ image[image >= 0.5] = 1
420
+
421
+ return image
422
+
423
+ def get_default_height_width(
424
+ self,
425
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
426
+ height: Optional[int] = None,
427
+ width: Optional[int] = None,
428
+ ) -> Tuple[int, int]:
429
+ """
430
+ This function return the height and width that are downscaled to the next integer multiple of
431
+ `vae_scale_factor`.
432
+
433
+ Args:
434
+ image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
435
+ The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have
436
+ shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should
437
+ have shape `[batch, channel, height, width]`.
438
+ height (`int`, *optional*, defaults to `None`):
439
+ The height in preprocessed image. If `None`, will use the height of `image` input.
440
+ width (`int`, *optional*`, defaults to `None`):
441
+ The width in preprocessed. If `None`, will use the width of the `image` input.
442
+ """
443
+
444
+ if height is None:
445
+ if isinstance(image, PIL.Image.Image):
446
+ height = image.height
447
+ elif isinstance(image, torch.Tensor):
448
+ height = image.shape[2]
449
+ else:
450
+ height = image.shape[1]
451
+
452
+ if width is None:
453
+ if isinstance(image, PIL.Image.Image):
454
+ width = image.width
455
+ elif isinstance(image, torch.Tensor):
456
+ width = image.shape[3]
457
+ else:
458
+ width = image.shape[2]
459
+
460
+ width, height = (
461
+ x - x % self.config.vae_scale_factor for x in (width, height)
462
+ ) # resize to integer multiple of vae_scale_factor
463
+
464
+ return height, width
465
+
466
+ def preprocess(
467
+ self,
468
+ image: PipelineImageInput,
469
+ height: Optional[int] = None,
470
+ width: Optional[int] = None,
471
+ resize_mode: str = "default", # "defalt", "fill", "crop"
472
+ crops_coords: Optional[Tuple[int, int, int, int]] = None,
473
+ ) -> torch.Tensor:
474
+ """
475
+ Preprocess the image input.
476
+
477
+ Args:
478
+ image (`pipeline_image_input`):
479
+ The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of supported formats.
480
+ height (`int`, *optional*, defaults to `None`):
481
+ The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default height.
482
+ width (`int`, *optional*`, defaults to `None`):
483
+ The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
484
+ resize_mode (`str`, *optional*, defaults to `default`):
485
+ The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit
486
+ within the specified width and height, and it may not maintaining the original aspect ratio.
487
+ If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
488
+ within the dimensions, filling empty with data from image.
489
+ If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
490
+ within the dimensions, cropping the excess.
491
+ Note that resize_mode `fill` and `crop` are only supported for PIL image input.
492
+ crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
493
+ The crop coordinates for each image in the batch. If `None`, will not crop the image.
494
+ """
495
+ supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
496
+
497
+ # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
498
+ if (
499
+ self.config.do_convert_grayscale
500
+ and isinstance(image, (torch.Tensor, np.ndarray))
501
+ and image.ndim == 3
502
+ ):
503
+ if isinstance(image, torch.Tensor):
504
+ # if image is a pytorch tensor could have 2 possible shapes:
505
+ # 1. batch x height x width: we should insert the channel dimension at position 1
506
+ # 2. channnel x height x width: we should insert batch dimension at position 0,
507
+ # however, since both channel and batch dimension has same size 1, it is same to insert at position 1
508
+ # for simplicity, we insert a dimension of size 1 at position 1 for both cases
509
+ image = image.unsqueeze(1)
510
+ else:
511
+ # if it is a numpy array, it could have 2 possible shapes:
512
+ # 1. batch x height x width: insert channel dimension on last position
513
+ # 2. height x width x channel: insert batch dimension on first position
514
+ if image.shape[-1] == 1:
515
+ image = np.expand_dims(image, axis=0)
516
+ else:
517
+ image = np.expand_dims(image, axis=-1)
518
+
519
+ if isinstance(image, supported_formats):
520
+ image = [image]
521
+ elif not (
522
+ isinstance(image, list)
523
+ and all(isinstance(i, supported_formats) for i in image)
524
+ ):
525
+ raise ValueError(
526
+ f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}"
527
+ )
528
+
529
+ if isinstance(image[0], PIL.Image.Image):
530
+ if crops_coords is not None:
531
+ image = [i.crop(crops_coords) for i in image]
532
+ if self.config.do_resize:
533
+ height, width = self.get_default_height_width(image[0], height, width)
534
+ image = [
535
+ self.resize(i, height, width, resize_mode=resize_mode)
536
+ for i in image
537
+ ]
538
+ if self.config.do_convert_rgb:
539
+ image = [self.convert_to_rgb(i) for i in image]
540
+ elif self.config.do_convert_grayscale:
541
+ image = [self.convert_to_grayscale(i) for i in image]
542
+ image = self.pil_to_numpy(image) # to np
543
+ image = self.numpy_to_pt(image) # to pt
544
+
545
+ elif isinstance(image[0], np.ndarray):
546
+ image = (
547
+ np.concatenate(image, axis=0)
548
+ if image[0].ndim == 4
549
+ else np.stack(image, axis=0)
550
+ )
551
+
552
+ image = self.numpy_to_pt(image)
553
+
554
+ height, width = self.get_default_height_width(image, height, width)
555
+ if self.config.do_resize:
556
+ image = self.resize(image, height, width)
557
+
558
+ elif isinstance(image[0], torch.Tensor):
559
+ image = (
560
+ torch.cat(image, axis=0)
561
+ if image[0].ndim == 4
562
+ else torch.stack(image, axis=0)
563
+ )
564
+
565
+ if self.config.do_convert_grayscale and image.ndim == 3:
566
+ image = image.unsqueeze(1)
567
+
568
+ channel = image.shape[1]
569
+ # don't need any preprocess if the image is latents
570
+ if channel == 4:
571
+ return image
572
+
573
+ height, width = self.get_default_height_width(image, height, width)
574
+ if self.config.do_resize:
575
+ image = self.resize(image, height, width)
576
+
577
+ # expected range [0,1], normalize to [-1,1]
578
+ do_normalize = self.config.do_normalize
579
+ if do_normalize and image.min() < 0:
580
+ warnings.warn(
581
+ "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
582
+ f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
583
+ FutureWarning,
584
+ )
585
+ do_normalize = False
586
+
587
+ if do_normalize:
588
+ image = self.normalize(image)
589
+
590
+ if self.config.do_binarize:
591
+ image = self.binarize(image)
592
+
593
+ return image
594
+
595
+ def postprocess(
596
+ self,
597
+ image: torch.FloatTensor,
598
+ output_type: str = "pil",
599
+ do_denormalize: Optional[List[bool]] = None,
600
+ ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
601
+ """
602
+ Postprocess the image output from tensor to `output_type`.
603
+
604
+ Args:
605
+ image (`torch.FloatTensor`):
606
+ The image input, should be a pytorch tensor with shape `B x C x H x W`.
607
+ output_type (`str`, *optional*, defaults to `pil`):
608
+ The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
609
+ do_denormalize (`List[bool]`, *optional*, defaults to `None`):
610
+ Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
611
+ `VaeImageProcessor` config.
612
+
613
+ Returns:
614
+ `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
615
+ The postprocessed image.
616
+ """
617
+ if not isinstance(image, torch.Tensor):
618
+ raise ValueError(
619
+ f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
620
+ )
621
+ if output_type not in ["latent", "pt", "np", "pil"]:
622
+ deprecation_message = (
623
+ f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
624
+ "`pil`, `np`, `pt`, `latent`"
625
+ )
626
+ deprecate(
627
+ "Unsupported output_type",
628
+ "1.0.0",
629
+ deprecation_message,
630
+ standard_warn=False,
631
+ )
632
+ output_type = "np"
633
+
634
+ if output_type == "latent":
635
+ return image
636
+
637
+ if do_denormalize is None:
638
+ do_denormalize = [self.config.do_normalize] * image.shape[0]
639
+
640
+ image = torch.stack(
641
+ [
642
+ self.denormalize(image[i]) if do_denormalize[i] else image[i]
643
+ for i in range(image.shape[0])
644
+ ]
645
+ )
646
+
647
+ if output_type == "pt":
648
+ return image
649
+
650
+ image = self.pt_to_numpy(image)
651
+
652
+ if output_type == "np":
653
+ return image
654
+
655
+ if output_type == "pil":
656
+ return self.numpy_to_pil(image)
657
+
658
+ def apply_overlay(
659
+ self,
660
+ mask: PIL.Image.Image,
661
+ init_image: PIL.Image.Image,
662
+ image: PIL.Image.Image,
663
+ crop_coords: Optional[Tuple[int, int, int, int]] = None,
664
+ ) -> PIL.Image.Image:
665
+ """
666
+ overlay the inpaint output to the original image
667
+ """
668
+
669
+ image = image.resize(init_image.size)
670
+ width, height = image.width, image.height
671
+
672
+ init_image = self.resize(init_image, width=width, height=height)
673
+ mask = self.resize(mask, width=width, height=height)
674
+
675
+ init_image_masked = PIL.Image.new("RGBa", (width, height))
676
+ init_image_masked.paste(
677
+ init_image.convert("RGBA").convert("RGBa"),
678
+ mask=ImageOps.invert(mask.convert("L")),
679
+ )
680
+ init_image_masked = init_image_masked.convert("RGBA")
681
+
682
+ if crop_coords is not None:
683
+ x, y, x2, y2 = crop_coords
684
+ w = x2 - x
685
+ h = y2 - y
686
+ base_image = PIL.Image.new("RGBA", (width, height))
687
+ image = self.resize(image, height=h, width=w, resize_mode="crop")
688
+ base_image.paste(image, (x, y))
689
+ image = base_image.convert("RGB")
690
+
691
+ image = image.convert("RGBA")
692
+ image.alpha_composite(init_image_masked)
693
+ image = image.convert("RGB")
694
+
695
+ return image
696
+
697
+
698
+ class VaeImageProcessorLDM3D(VaeImageProcessor):
699
+ """
700
+ Image processor for VAE LDM3D.
701
+
702
+ Args:
703
+ do_resize (`bool`, *optional*, defaults to `True`):
704
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
705
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
706
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
707
+ resample (`str`, *optional*, defaults to `lanczos`):
708
+ Resampling filter to use when resizing the image.
709
+ do_normalize (`bool`, *optional*, defaults to `True`):
710
+ Whether to normalize the image to [-1,1].
711
+ """
712
+
713
+ config_name = CONFIG_NAME
714
+
715
+ @register_to_config
716
+ def __init__(
717
+ self,
718
+ do_resize: bool = True,
719
+ vae_scale_factor: int = 8,
720
+ resample: str = "lanczos",
721
+ do_normalize: bool = True,
722
+ ):
723
+ super().__init__()
724
+
725
+ @staticmethod
726
+ def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
727
+ """
728
+ Convert a NumPy image or a batch of images to a PIL image.
729
+ """
730
+ if images.ndim == 3:
731
+ images = images[None, ...]
732
+ images = (images * 255).round().astype("uint8")
733
+ if images.shape[-1] == 1:
734
+ # special case for grayscale (single channel) images
735
+ pil_images = [
736
+ Image.fromarray(image.squeeze(), mode="L") for image in images
737
+ ]
738
+ else:
739
+ pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
740
+
741
+ return pil_images
742
+
743
+ @staticmethod
744
+ def depth_pil_to_numpy(
745
+ images: Union[List[PIL.Image.Image], PIL.Image.Image]
746
+ ) -> np.ndarray:
747
+ """
748
+ Convert a PIL image or a list of PIL images to NumPy arrays.
749
+ """
750
+ if not isinstance(images, list):
751
+ images = [images]
752
+
753
+ images = [
754
+ np.array(image).astype(np.float32) / (2**16 - 1) for image in images
755
+ ]
756
+ images = np.stack(images, axis=0)
757
+ return images
758
+
759
+ @staticmethod
760
+ def rgblike_to_depthmap(
761
+ image: Union[np.ndarray, torch.Tensor]
762
+ ) -> Union[np.ndarray, torch.Tensor]:
763
+ """
764
+ Args:
765
+ image: RGB-like depth image
766
+
767
+ Returns: depth map
768
+
769
+ """
770
+ return image[:, :, 1] * 2**8 + image[:, :, 2]
771
+
772
+ def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
773
+ """
774
+ Convert a NumPy depth image or a batch of images to a PIL image.
775
+ """
776
+ if images.ndim == 3:
777
+ images = images[None, ...]
778
+ images_depth = images[:, :, :, 3:]
779
+ if images.shape[-1] == 6:
780
+ images_depth = (images_depth * 255).round().astype("uint8")
781
+ pil_images = [
782
+ Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16")
783
+ for image_depth in images_depth
784
+ ]
785
+ elif images.shape[-1] == 4:
786
+ images_depth = (images_depth * 65535.0).astype(np.uint16)
787
+ pil_images = [
788
+ Image.fromarray(image_depth, mode="I;16")
789
+ for image_depth in images_depth
790
+ ]
791
+ else:
792
+ raise Exception("Not supported")
793
+
794
+ return pil_images
795
+
796
+ def postprocess(
797
+ self,
798
+ image: torch.FloatTensor,
799
+ output_type: str = "pil",
800
+ do_denormalize: Optional[List[bool]] = None,
801
+ ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
802
+ """
803
+ Postprocess the image output from tensor to `output_type`.
804
+
805
+ Args:
806
+ image (`torch.FloatTensor`):
807
+ The image input, should be a pytorch tensor with shape `B x C x H x W`.
808
+ output_type (`str`, *optional*, defaults to `pil`):
809
+ The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
810
+ do_denormalize (`List[bool]`, *optional*, defaults to `None`):
811
+ Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
812
+ `VaeImageProcessor` config.
813
+
814
+ Returns:
815
+ `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
816
+ The postprocessed image.
817
+ """
818
+ if not isinstance(image, torch.Tensor):
819
+ raise ValueError(
820
+ f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
821
+ )
822
+ if output_type not in ["latent", "pt", "np", "pil"]:
823
+ deprecation_message = (
824
+ f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
825
+ "`pil`, `np`, `pt`, `latent`"
826
+ )
827
+ deprecate(
828
+ "Unsupported output_type",
829
+ "1.0.0",
830
+ deprecation_message,
831
+ standard_warn=False,
832
+ )
833
+ output_type = "np"
834
+
835
+ if do_denormalize is None:
836
+ do_denormalize = [self.config.do_normalize] * image.shape[0]
837
+
838
+ image = torch.stack(
839
+ [
840
+ self.denormalize(image[i]) if do_denormalize[i] else image[i]
841
+ for i in range(image.shape[0])
842
+ ]
843
+ )
844
+
845
+ image = self.pt_to_numpy(image)
846
+
847
+ if output_type == "np":
848
+ if image.shape[-1] == 6:
849
+ image_depth = np.stack(
850
+ [self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0
851
+ )
852
+ else:
853
+ image_depth = image[:, :, :, 3:]
854
+ return image[:, :, :, :3], image_depth
855
+
856
+ if output_type == "pil":
857
+ return self.numpy_to_pil(image), self.numpy_to_depth(image)
858
+ else:
859
+ raise Exception(f"This type {output_type} is not supported")
860
+
861
+ def preprocess(
862
+ self,
863
+ rgb: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
864
+ depth: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
865
+ height: Optional[int] = None,
866
+ width: Optional[int] = None,
867
+ target_res: Optional[int] = None,
868
+ ) -> torch.Tensor:
869
+ """
870
+ Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
871
+ """
872
+ supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
873
+
874
+ # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
875
+ if (
876
+ self.config.do_convert_grayscale
877
+ and isinstance(rgb, (torch.Tensor, np.ndarray))
878
+ and rgb.ndim == 3
879
+ ):
880
+ raise Exception("This is not yet supported")
881
+
882
+ if isinstance(rgb, supported_formats):
883
+ rgb = [rgb]
884
+ depth = [depth]
885
+ elif not (
886
+ isinstance(rgb, list) and all(isinstance(i, supported_formats) for i in rgb)
887
+ ):
888
+ raise ValueError(
889
+ f"Input is in incorrect format: {[type(i) for i in rgb]}. Currently, we only support {', '.join(supported_formats)}"
890
+ )
891
+
892
+ if isinstance(rgb[0], PIL.Image.Image):
893
+ if self.config.do_convert_rgb:
894
+ raise Exception("This is not yet supported")
895
+ # rgb = [self.convert_to_rgb(i) for i in rgb]
896
+ # depth = [self.convert_to_depth(i) for i in depth] #TODO define convert_to_depth
897
+ if self.config.do_resize or target_res:
898
+ height, width = (
899
+ self.get_default_height_width(rgb[0], height, width)
900
+ if not target_res
901
+ else target_res
902
+ )
903
+ rgb = [self.resize(i, height, width) for i in rgb]
904
+ depth = [self.resize(i, height, width) for i in depth]
905
+ rgb = self.pil_to_numpy(rgb) # to np
906
+ rgb = self.numpy_to_pt(rgb) # to pt
907
+
908
+ depth = self.depth_pil_to_numpy(depth) # to np
909
+ depth = self.numpy_to_pt(depth) # to pt
910
+
911
+ elif isinstance(rgb[0], np.ndarray):
912
+ rgb = (
913
+ np.concatenate(rgb, axis=0)
914
+ if rgb[0].ndim == 4
915
+ else np.stack(rgb, axis=0)
916
+ )
917
+ rgb = self.numpy_to_pt(rgb)
918
+ height, width = self.get_default_height_width(rgb, height, width)
919
+ if self.config.do_resize:
920
+ rgb = self.resize(rgb, height, width)
921
+
922
+ depth = (
923
+ np.concatenate(depth, axis=0)
924
+ if rgb[0].ndim == 4
925
+ else np.stack(depth, axis=0)
926
+ )
927
+ depth = self.numpy_to_pt(depth)
928
+ height, width = self.get_default_height_width(depth, height, width)
929
+ if self.config.do_resize:
930
+ depth = self.resize(depth, height, width)
931
+
932
+ elif isinstance(rgb[0], torch.Tensor):
933
+ raise Exception("This is not yet supported")
934
+ # rgb = torch.cat(rgb, axis=0) if rgb[0].ndim == 4 else torch.stack(rgb, axis=0)
935
+
936
+ # if self.config.do_convert_grayscale and rgb.ndim == 3:
937
+ # rgb = rgb.unsqueeze(1)
938
+
939
+ # channel = rgb.shape[1]
940
+
941
+ # height, width = self.get_default_height_width(rgb, height, width)
942
+ # if self.config.do_resize:
943
+ # rgb = self.resize(rgb, height, width)
944
+
945
+ # depth = torch.cat(depth, axis=0) if depth[0].ndim == 4 else torch.stack(depth, axis=0)
946
+
947
+ # if self.config.do_convert_grayscale and depth.ndim == 3:
948
+ # depth = depth.unsqueeze(1)
949
+
950
+ # channel = depth.shape[1]
951
+ # # don't need any preprocess if the image is latents
952
+ # if depth == 4:
953
+ # return rgb, depth
954
+
955
+ # height, width = self.get_default_height_width(depth, height, width)
956
+ # if self.config.do_resize:
957
+ # depth = self.resize(depth, height, width)
958
+ # expected range [0,1], normalize to [-1,1]
959
+ do_normalize = self.config.do_normalize
960
+ if rgb.min() < 0 and do_normalize:
961
+ warnings.warn(
962
+ "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
963
+ f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{rgb.min()},{rgb.max()}]",
964
+ FutureWarning,
965
+ )
966
+ do_normalize = False
967
+
968
+ if do_normalize:
969
+ rgb = self.normalize(rgb)
970
+ depth = self.normalize(depth)
971
+
972
+ if self.config.do_binarize:
973
+ rgb = self.binarize(rgb)
974
+ depth = self.binarize(depth)
975
+
976
+ return rgb, depth
internals/pipelines/inpainter.py CHANGED
@@ -1,18 +1,27 @@
1
  from typing import List, Union
2
 
3
  import torch
4
- from diffusers import StableDiffusionInpaintPipeline, StableDiffusionXLInpaintPipeline
 
 
 
 
5
 
6
  from internals.pipelines.commons import AbstractPipeline
 
 
 
7
  from internals.util.cache import clear_cuda_and_gc
8
  from internals.util.commons import disable_safety_checker, download_image
9
  from internals.util.config import (
 
10
  get_base_inpaint_model_variant,
11
  get_hf_cache_dir,
12
  get_hf_token,
13
  get_inpaint_model_path,
14
  get_is_sdxl,
15
  get_model_dir,
 
16
  )
17
 
18
 
@@ -32,13 +41,27 @@ class InPainter(AbstractPipeline):
32
  return
33
 
34
  if get_is_sdxl():
35
- self.pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
 
36
  get_inpaint_model_path(),
37
  torch_dtype=torch.float16,
38
  cache_dir=get_hf_cache_dir(),
39
  token=get_hf_token(),
 
40
  variant=get_base_inpaint_model_variant(),
 
41
  ).to("cuda")
 
 
 
 
 
 
 
 
 
 
 
42
  else:
43
  self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
44
  get_inpaint_model_path(),
@@ -90,11 +113,18 @@ class InPainter(AbstractPipeline):
90
  num_inference_steps: int,
91
  **kwargs,
92
  ):
93
- torch.manual_seed(seed)
94
 
95
  input_img = download_image(image_url).resize((width, height))
96
  mask_img = download_image(mask_image_url).resize((width, height))
97
 
 
 
 
 
 
 
 
98
  kwargs = {
99
  "prompt": prompt,
100
  "image": input_img,
@@ -104,6 +134,7 @@ class InPainter(AbstractPipeline):
104
  "negative_prompt": negative_prompt,
105
  "num_inference_steps": num_inference_steps,
106
  "strength": 1.0,
 
107
  **kwargs,
108
  }
109
- return self.pipe.__call__(**kwargs).images
 
1
  from typing import List, Union
2
 
3
  import torch
4
+ from diffusers import (
5
+ StableDiffusionInpaintPipeline,
6
+ StableDiffusionXLInpaintPipeline,
7
+ UNet2DConditionModel,
8
+ )
9
 
10
  from internals.pipelines.commons import AbstractPipeline
11
+ from internals.pipelines.high_res import HighRes
12
+ from internals.pipelines.inpaint_imageprocessor import VaeImageProcessor
13
+ from internals.util import get_generators
14
  from internals.util.cache import clear_cuda_and_gc
15
  from internals.util.commons import disable_safety_checker, download_image
16
  from internals.util.config import (
17
+ get_base_inpaint_model_revision,
18
  get_base_inpaint_model_variant,
19
  get_hf_cache_dir,
20
  get_hf_token,
21
  get_inpaint_model_path,
22
  get_is_sdxl,
23
  get_model_dir,
24
+ get_num_return_sequences,
25
  )
26
 
27
 
 
41
  return
42
 
43
  if get_is_sdxl():
44
+ # only take UNet from the repo
45
+ unet = UNet2DConditionModel.from_pretrained(
46
  get_inpaint_model_path(),
47
  torch_dtype=torch.float16,
48
  cache_dir=get_hf_cache_dir(),
49
  token=get_hf_token(),
50
+ subfolder="unet",
51
  variant=get_base_inpaint_model_variant(),
52
+ revision=get_base_inpaint_model_revision(),
53
  ).to("cuda")
54
+ kwargs = {**self.__base.pipe.components, "unet": unet}
55
+ self.pipe = StableDiffusionXLInpaintPipeline(**kwargs).to("cuda")
56
+ self.pipe.mask_processor = VaeImageProcessor(
57
+ vae_scale_factor=self.pipe.vae_scale_factor,
58
+ do_normalize=False,
59
+ do_binarize=True,
60
+ do_convert_grayscale=True,
61
+ )
62
+ self.pipe.image_processor = VaeImageProcessor(
63
+ vae_scale_factor=self.pipe.vae_scale_factor
64
+ )
65
  else:
66
  self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
67
  get_inpaint_model_path(),
 
113
  num_inference_steps: int,
114
  **kwargs,
115
  ):
116
+ generator = get_generators(seed, get_num_return_sequences())
117
 
118
  input_img = download_image(image_url).resize((width, height))
119
  mask_img = download_image(mask_image_url).resize((width, height))
120
 
121
+ if get_is_sdxl():
122
+ width, height = HighRes.find_closest_sdxl_aspect_ratio(width, height)
123
+ mask_img = self.pipe.mask_processor.blur(mask_img, blur_factor=33)
124
+
125
+ kwargs["strength"] = 0.999
126
+ kwargs["padding_mask_crop"] = 1000
127
+
128
  kwargs = {
129
  "prompt": prompt,
130
  "image": input_img,
 
134
  "negative_prompt": negative_prompt,
135
  "num_inference_steps": num_inference_steps,
136
  "strength": 1.0,
137
+ "generator": generator,
138
  **kwargs,
139
  }
140
+ return self.pipe.__call__(**kwargs).images, mask_img
internals/pipelines/prompt_modifier.py CHANGED
@@ -2,6 +2,8 @@ from typing import List, Optional
2
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
4
 
 
 
5
 
6
  class PromptModifier:
7
  __loaded = False
@@ -38,7 +40,7 @@ class PromptModifier:
38
  do_sample=False,
39
  max_new_tokens=75,
40
  num_beams=4,
41
- num_return_sequences=num_of_sequences,
42
  eos_token_id=eos_id,
43
  pad_token_id=eos_id,
44
  length_penalty=-1.0,
 
2
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
4
 
5
+ from internals.util.config import get_num_return_sequences
6
+
7
 
8
  class PromptModifier:
9
  __loaded = False
 
40
  do_sample=False,
41
  max_new_tokens=75,
42
  num_beams=4,
43
+ num_return_sequences=get_num_return_sequences(),
44
  eos_token_id=eos_id,
45
  pad_token_id=eos_id,
46
  length_penalty=-1.0,
internals/pipelines/realtime_draw.py CHANGED
@@ -9,7 +9,13 @@ from internals.pipelines.commons import AbstractPipeline
9
  from internals.pipelines.controlnets import ControlNet
10
  from internals.pipelines.high_res import HighRes
11
  from internals.pipelines.sdxl_llite_pipeline import SDXLLLiteImg2ImgPipeline
12
- from internals.util.config import get_base_dimension, get_hf_cache_dir, get_is_sdxl
 
 
 
 
 
 
13
 
14
 
15
  class RealtimeDraw(AbstractPipeline):
@@ -60,7 +66,7 @@ class RealtimeDraw(AbstractPipeline):
60
  if get_is_sdxl():
61
  raise Exception("SDXL is not supported for this method")
62
 
63
- torch.manual_seed(seed)
64
 
65
  image = ImageUtil.resize_image(image, 512)
66
 
@@ -70,6 +76,7 @@ class RealtimeDraw(AbstractPipeline):
70
  prompt=prompt,
71
  num_inference_steps=15,
72
  negative_prompt=negative_prompt,
 
73
  guidance_scale=10,
74
  strength=0.8,
75
  ).images[0]
@@ -84,7 +91,7 @@ class RealtimeDraw(AbstractPipeline):
84
  image: Optional[Image.Image] = None,
85
  image2: Optional[Image.Image] = None,
86
  ):
87
- torch.manual_seed(seed)
88
 
89
  b_dimen = get_base_dimension()
90
 
@@ -104,6 +111,8 @@ class RealtimeDraw(AbstractPipeline):
104
  size = HighRes.find_closest_sdxl_aspect_ratio(image.size[0], image.size[1])
105
  image = image.resize(size)
106
 
 
 
107
  images = self.pipe.__call__(
108
  image=image,
109
  condition_image=image,
@@ -129,6 +138,7 @@ class RealtimeDraw(AbstractPipeline):
129
  num_inference_steps=15,
130
  negative_prompt=negative_prompt,
131
  guidance_scale=10,
 
132
  strength=0.9,
133
  width=image.size[0],
134
  height=image.size[1],
 
9
  from internals.pipelines.controlnets import ControlNet
10
  from internals.pipelines.high_res import HighRes
11
  from internals.pipelines.sdxl_llite_pipeline import SDXLLLiteImg2ImgPipeline
12
+ from internals.util import get_generators
13
+ from internals.util.config import (
14
+ get_base_dimension,
15
+ get_hf_cache_dir,
16
+ get_is_sdxl,
17
+ get_num_return_sequences,
18
+ )
19
 
20
 
21
  class RealtimeDraw(AbstractPipeline):
 
66
  if get_is_sdxl():
67
  raise Exception("SDXL is not supported for this method")
68
 
69
+ generator = get_generators(seed, get_num_return_sequences())
70
 
71
  image = ImageUtil.resize_image(image, 512)
72
 
 
76
  prompt=prompt,
77
  num_inference_steps=15,
78
  negative_prompt=negative_prompt,
79
+ generator=generator,
80
  guidance_scale=10,
81
  strength=0.8,
82
  ).images[0]
 
91
  image: Optional[Image.Image] = None,
92
  image2: Optional[Image.Image] = None,
93
  ):
94
+ generator = get_generators(seed, get_num_return_sequences())
95
 
96
  b_dimen = get_base_dimension()
97
 
 
111
  size = HighRes.find_closest_sdxl_aspect_ratio(image.size[0], image.size[1])
112
  image = image.resize(size)
113
 
114
+ torch.manual_seed(seed)
115
+
116
  images = self.pipe.__call__(
117
  image=image,
118
  condition_image=image,
 
138
  num_inference_steps=15,
139
  negative_prompt=negative_prompt,
140
  guidance_scale=10,
141
+ generator=generator,
142
  strength=0.9,
143
  width=image.size[0],
144
  height=image.size[1],
internals/pipelines/remove_background.py CHANGED
@@ -1,20 +1,22 @@
1
  import io
2
  from pathlib import Path
3
  from typing import Union
4
- import numpy as np
5
- import cv2
6
 
 
 
 
 
7
  import torch
8
  import torch.nn.functional as F
 
9
  from PIL import Image
10
  from rembg import remove
11
- from internals.data.task import ModelType
12
 
13
  import internals.util.image as ImageUtil
14
  from carvekit.api.high import HiInterface
 
15
  from internals.util.commons import download_image, read_url
16
- import onnxruntime as rt
17
- import huggingface_hub
18
 
19
 
20
  class RemoveBackground:
@@ -94,3 +96,51 @@ class RemoveBackgroundV2:
94
  img = np.concatenate([img, mask], axis=2, dtype=np.uint8)
95
  mask = mask.repeat(3, axis=2)
96
  return mask, img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import io
2
  from pathlib import Path
3
  from typing import Union
 
 
4
 
5
+ import cv2
6
+ import huggingface_hub
7
+ import numpy as np
8
+ import onnxruntime as rt
9
  import torch
10
  import torch.nn.functional as F
11
+ from briarmbg import BriaRMBG # pyright: ignore
12
  from PIL import Image
13
  from rembg import remove
14
+ from torchvision.transforms.functional import normalize
15
 
16
  import internals.util.image as ImageUtil
17
  from carvekit.api.high import HiInterface
18
+ from internals.data.task import ModelType
19
  from internals.util.commons import download_image, read_url
 
 
20
 
21
 
22
  class RemoveBackground:
 
96
  img = np.concatenate([img, mask], axis=2, dtype=np.uint8)
97
  mask = mask.repeat(3, axis=2)
98
  return mask, img
99
+
100
+
101
+ class RemoveBackgroundV3:
102
+ def __init__(self):
103
+ net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
104
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
105
+ net.to(device)
106
+ self.net = net
107
+
108
+ def remove(self, image: Union[str, Image.Image]) -> Image.Image:
109
+ if type(image) is str:
110
+ image = download_image(image, mode="RGBA")
111
+
112
+ orig_image = image
113
+ w, h = orig_im_size = orig_image.size
114
+ image = self.__resize_image(orig_image)
115
+ im_np = np.array(image)
116
+ im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
117
+ im_tensor = torch.unsqueeze(im_tensor, 0)
118
+ im_tensor = torch.divide(im_tensor, 255.0)
119
+ im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
120
+ if torch.cuda.is_available():
121
+ im_tensor = im_tensor.cuda()
122
+
123
+ # inference
124
+ result = self.net(im_tensor)
125
+ # post process
126
+ result = torch.squeeze(
127
+ F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0
128
+ )
129
+ ma = torch.max(result)
130
+ mi = torch.min(result)
131
+ result = (result - mi) / (ma - mi)
132
+ # image to pil
133
+ im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
134
+ pil_im = Image.fromarray(np.squeeze(im_array))
135
+ # paste the mask on the original image
136
+ new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
137
+ new_im.paste(orig_image, mask=pil_im)
138
+ # new_orig_image = orig_image.convert('RGBA')
139
+
140
+ return new_im
141
+
142
+ def __resize_image(self, image):
143
+ image = image.convert("RGB")
144
+ model_input_size = (1024, 1024)
145
+ image = image.resize(model_input_size, Image.BILINEAR)
146
+ return image
internals/pipelines/replace_background.py CHANGED
@@ -16,11 +16,12 @@ import internals.util.image as ImageUtil
16
  from internals.data.result import Result
17
  from internals.data.task import ModelType
18
  from internals.pipelines.commons import AbstractPipeline
19
- from internals.pipelines.controlnets import ControlNet
20
  from internals.pipelines.high_res import HighRes
21
  from internals.pipelines.inpainter import InPainter
22
  from internals.pipelines.remove_background import RemoveBackgroundV2
23
  from internals.pipelines.upscaler import Upscaler
 
24
  from internals.util.cache import clear_cuda_and_gc
25
  from internals.util.commons import download_image
26
  from internals.util.config import (
@@ -28,6 +29,7 @@ from internals.util.config import (
28
  get_hf_token,
29
  get_inpaint_model_path,
30
  get_model_dir,
 
31
  )
32
 
33
 
@@ -43,11 +45,9 @@ class ReplaceBackground(AbstractPipeline):
43
  ):
44
  if self.__loaded:
45
  return
46
- controlnet_model = ControlNetModel.from_pretrained(
47
- "lllyasviel/control_v11p_sd15_canny",
48
- torch_dtype=torch.float16,
49
- cache_dir=get_hf_cache_dir(),
50
- ).to("cuda")
51
  if base:
52
  pipe = StableDiffusionControlNetPipeline(
53
  **base.pipe.components,
@@ -109,8 +109,7 @@ class ReplaceBackground(AbstractPipeline):
109
  if type(image) is str:
110
  image = download_image(image)
111
 
112
- torch.manual_seed(seed)
113
- torch.cuda.manual_seed(seed)
114
 
115
  image = image.convert("RGB")
116
  if max(image.size) > 1024:
@@ -148,6 +147,7 @@ class ReplaceBackground(AbstractPipeline):
148
  guidance_scale=9,
149
  height=height,
150
  num_inference_steps=steps,
 
151
  width=width,
152
  )
153
  result = Result.from_result(result)
 
16
  from internals.data.result import Result
17
  from internals.data.task import ModelType
18
  from internals.pipelines.commons import AbstractPipeline
19
+ from internals.pipelines.controlnets import ControlNet, load_network_model_by_key
20
  from internals.pipelines.high_res import HighRes
21
  from internals.pipelines.inpainter import InPainter
22
  from internals.pipelines.remove_background import RemoveBackgroundV2
23
  from internals.pipelines.upscaler import Upscaler
24
+ from internals.util import get_generators
25
  from internals.util.cache import clear_cuda_and_gc
26
  from internals.util.commons import download_image
27
  from internals.util.config import (
 
29
  get_hf_token,
30
  get_inpaint_model_path,
31
  get_model_dir,
32
+ get_num_return_sequences,
33
  )
34
 
35
 
 
45
  ):
46
  if self.__loaded:
47
  return
48
+ controlnet_model = load_network_model_by_key(
49
+ "lllyasviel/control_v11p_sd15_canny", "controlnet"
50
+ )
 
 
51
  if base:
52
  pipe = StableDiffusionControlNetPipeline(
53
  **base.pipe.components,
 
109
  if type(image) is str:
110
  image = download_image(image)
111
 
112
+ generator = get_generators(seed, get_num_return_sequences())
 
113
 
114
  image = image.convert("RGB")
115
  if max(image.size) > 1024:
 
147
  guidance_scale=9,
148
  height=height,
149
  num_inference_steps=steps,
150
+ generator=generator,
151
  width=width,
152
  )
153
  result = Result.from_result(result)
internals/pipelines/safety_checker.py CHANGED
@@ -31,10 +31,11 @@ class SafetyChecker:
31
  self.__loaded = True
32
 
33
  def apply(self, pipeline: AbstractPipeline):
34
- model = self.model if not get_nsfw_access() else None
35
- if model:
36
  self.load()
37
 
 
 
38
  if not pipeline:
39
  return
40
  if hasattr(pipeline, "pipe"):
 
31
  self.__loaded = True
32
 
33
  def apply(self, pipeline: AbstractPipeline):
34
+ if not get_nsfw_access():
 
35
  self.load()
36
 
37
+ model = self.model if not get_nsfw_access() else None
38
+
39
  if not pipeline:
40
  return
41
  if hasattr(pipeline, "pipe"):
internals/pipelines/sdxl_llite_pipeline.py CHANGED
@@ -1251,6 +1251,8 @@ class PipelineLike:
1251
 
1252
 
1253
  class SDXLLLiteImg2ImgPipeline:
 
 
1254
  def __init__(self):
1255
  self.SCHEDULER_LINEAR_START = 0.00085
1256
  self.SCHEDULER_LINEAR_END = 0.0120
@@ -1261,7 +1263,7 @@ class SDXLLLiteImg2ImgPipeline:
1261
 
1262
  def replace_unet_modules(
1263
  self,
1264
- unet: diffusers.models.unet_2d_condition.UNet2DConditionModel,
1265
  mem_eff_attn,
1266
  xformers,
1267
  sdpa,
 
1251
 
1252
 
1253
  class SDXLLLiteImg2ImgPipeline:
1254
+ from diffusers import UNet2DConditionModel
1255
+
1256
  def __init__(self):
1257
  self.SCHEDULER_LINEAR_START = 0.00085
1258
  self.SCHEDULER_LINEAR_END = 0.0120
 
1263
 
1264
  def replace_unet_modules(
1265
  self,
1266
+ unet: UNet2DConditionModel,
1267
  mem_eff_attn,
1268
  xformers,
1269
  sdpa,
internals/pipelines/sdxl_tile_upscale.py CHANGED
@@ -4,8 +4,10 @@ from PIL import Image
4
  from torchvision import transforms
5
 
6
  import internals.util.image as ImageUtils
 
7
  from carvekit.api import high
8
  from internals.data.result import Result
 
9
  from internals.pipelines.commons import AbstractPipeline, Text2Img
10
  from internals.pipelines.controlnets import ControlNet
11
  from internals.pipelines.demofusion_sdxl import DemoFusionSDXLControlNetPipeline
@@ -19,18 +21,16 @@ controlnet = ControlNet()
19
 
20
  class SDXLTileUpscaler(AbstractPipeline):
21
  __loaded = False
 
22
 
23
  def create(self, high_res: HighRes, pipeline: Text2Img, model_id: int):
24
  if self.__loaded:
25
  return
26
  # temporal hack for upscale model till multicontrolnet support is added
27
- model = (
28
- "thibaud/controlnet-openpose-sdxl-1.0"
29
- if int(model_id) == 2000293
30
- else "diffusers/controlnet-canny-sdxl-1.0"
31
- )
32
 
33
- controlnet = ControlNetModel.from_pretrained(model, torch_dtype=torch.float16)
 
 
34
  pipe = DemoFusionSDXLControlNetPipeline(
35
  **pipeline.pipe.components, controlnet=controlnet
36
  )
@@ -43,6 +43,7 @@ class SDXLTileUpscaler(AbstractPipeline):
43
 
44
  self.pipe = pipe
45
 
 
46
  self.__loaded = True
47
 
48
  def unload(self):
@@ -52,6 +53,26 @@ class SDXLTileUpscaler(AbstractPipeline):
52
 
53
  clear_cuda_and_gc()
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  def process(
56
  self,
57
  prompt: str,
@@ -61,21 +82,36 @@ class SDXLTileUpscaler(AbstractPipeline):
61
  width: int,
62
  height: int,
63
  model_id: int,
 
 
64
  ):
65
- if int(model_id) == 2000293:
 
 
 
 
 
66
  condition_image = controlnet.detect_pose(imageUrl)
67
  else:
 
68
  condition_image = download_image(imageUrl)
69
  condition_image = ControlNet.canny_detect_edge(condition_image)
70
- img = download_image(imageUrl).resize((width, height))
71
 
72
- img = ImageUtils.resize_image(img, get_base_dimension())
73
  condition_image = condition_image.resize(img.size)
74
 
75
  img2 = self.__resize_for_condition_image(img, resize_dimension)
76
 
 
77
  image_lr = self.load_and_process_image(img)
78
- print("img", img2.size, img.size)
 
 
 
 
 
 
79
  if int(model_id) == 2000173:
80
  kwargs = {
81
  "prompt": prompt,
@@ -83,6 +119,7 @@ class SDXLTileUpscaler(AbstractPipeline):
83
  "image": img2,
84
  "strength": 0.3,
85
  "num_inference_steps": 30,
 
86
  }
87
  images = self.high_res.pipe.__call__(**kwargs).images
88
  else:
@@ -90,20 +127,24 @@ class SDXLTileUpscaler(AbstractPipeline):
90
  image_lr=image_lr,
91
  prompt=prompt,
92
  condition_image=condition_image,
93
- negative_prompt="blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
 
94
  guidance_scale=11,
95
  sigma=0.8,
96
  num_inference_steps=24,
97
- width=img2.size[0],
98
- height=img2.size[1],
 
 
99
  )
100
  images = images[::-1]
 
 
101
  return images, False
102
 
103
  def load_and_process_image(self, pil_image):
104
  transform = transforms.Compose(
105
  [
106
- transforms.Resize((1024, 1024)),
107
  transforms.ToTensor(),
108
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
109
  ]
@@ -113,6 +154,36 @@ class SDXLTileUpscaler(AbstractPipeline):
113
  image = image.to("cuda")
114
  return image
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  def __resize_for_condition_image(self, image: Image.Image, resolution: int):
117
  input_image = image.convert("RGB")
118
  W, H = input_image.size
 
4
  from torchvision import transforms
5
 
6
  import internals.util.image as ImageUtils
7
+ import internals.util.image as ImageUtil
8
  from carvekit.api import high
9
  from internals.data.result import Result
10
+ from internals.data.task import TaskType
11
  from internals.pipelines.commons import AbstractPipeline, Text2Img
12
  from internals.pipelines.controlnets import ControlNet
13
  from internals.pipelines.demofusion_sdxl import DemoFusionSDXLControlNetPipeline
 
21
 
22
  class SDXLTileUpscaler(AbstractPipeline):
23
  __loaded = False
24
+ __current_process_mode = None
25
 
26
  def create(self, high_res: HighRes, pipeline: Text2Img, model_id: int):
27
  if self.__loaded:
28
  return
29
  # temporal hack for upscale model till multicontrolnet support is added
 
 
 
 
 
30
 
31
+ controlnet = ControlNetModel.from_pretrained(
32
+ "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
33
+ )
34
  pipe = DemoFusionSDXLControlNetPipeline(
35
  **pipeline.pipe.components, controlnet=controlnet
36
  )
 
43
 
44
  self.pipe = pipe
45
 
46
+ self.__current_process_mode = TaskType.CANNY.name
47
  self.__loaded = True
48
 
49
  def unload(self):
 
53
 
54
  clear_cuda_and_gc()
55
 
56
+ def __reload_controlnet(self, process_mode: str):
57
+ if self.__current_process_mode == process_mode:
58
+ return
59
+
60
+ model = (
61
+ "thibaud/controlnet-openpose-sdxl-1.0"
62
+ if process_mode == TaskType.POSE.name
63
+ else "diffusers/controlnet-canny-sdxl-1.0"
64
+ )
65
+ controlnet = ControlNetModel.from_pretrained(
66
+ model, torch_dtype=torch.float16
67
+ ).to("cuda")
68
+
69
+ if hasattr(self, "pipe"):
70
+ self.pipe.controlnet = controlnet
71
+
72
+ self.__current_process_mode = process_mode
73
+
74
+ clear_cuda_and_gc()
75
+
76
  def process(
77
  self,
78
  prompt: str,
 
82
  width: int,
83
  height: int,
84
  model_id: int,
85
+ seed: int,
86
+ process_mode: str,
87
  ):
88
+ generator = torch.manual_seed(seed)
89
+
90
+ self.__reload_controlnet(process_mode)
91
+
92
+ if process_mode == TaskType.POSE.name:
93
+ print("Running POSE")
94
  condition_image = controlnet.detect_pose(imageUrl)
95
  else:
96
+ print("Running CANNY")
97
  condition_image = download_image(imageUrl)
98
  condition_image = ControlNet.canny_detect_edge(condition_image)
99
+ width, height = HighRes.find_closest_sdxl_aspect_ratio(width, height)
100
 
101
+ img = download_image(imageUrl).resize((width, height))
102
  condition_image = condition_image.resize(img.size)
103
 
104
  img2 = self.__resize_for_condition_image(img, resize_dimension)
105
 
106
+ img = self.pad_image(img)
107
  image_lr = self.load_and_process_image(img)
108
+
109
+ out_img = self.pad_image(img2)
110
+ condition_image = self.pad_image(condition_image)
111
+
112
+ print("img", img.size)
113
+ print("img2", img2.size)
114
+ print("condition", condition_image.size)
115
  if int(model_id) == 2000173:
116
  kwargs = {
117
  "prompt": prompt,
 
119
  "image": img2,
120
  "strength": 0.3,
121
  "num_inference_steps": 30,
122
+ "generator": generator,
123
  }
124
  images = self.high_res.pipe.__call__(**kwargs).images
125
  else:
 
127
  image_lr=image_lr,
128
  prompt=prompt,
129
  condition_image=condition_image,
130
+ negative_prompt="blurry, ugly, duplicate, poorly drawn, deformed, mosaic, "
131
+ + negative_prompt,
132
  guidance_scale=11,
133
  sigma=0.8,
134
  num_inference_steps=24,
135
+ controlnet_conditioning_scale=0.5,
136
+ generator=generator,
137
+ width=out_img.size[0],
138
+ height=out_img.size[1],
139
  )
140
  images = images[::-1]
141
+ iv = ImageUtil.resize_image(img2, images[0].size[0])
142
+ images = [self.unpad_image(images[0], iv.size)]
143
  return images, False
144
 
145
  def load_and_process_image(self, pil_image):
146
  transform = transforms.Compose(
147
  [
 
148
  transforms.ToTensor(),
149
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
150
  ]
 
154
  image = image.to("cuda")
155
  return image
156
 
157
+ def pad_image(self, image):
158
+ w, h = image.size
159
+ if w == h:
160
+ return image
161
+ elif w > h:
162
+ new_image = Image.new(image.mode, (w, w), (0, 0, 0))
163
+ pad_w = 0
164
+ pad_h = (w - h) // 2
165
+ new_image.paste(image, (0, pad_h))
166
+ return new_image
167
+ else:
168
+ new_image = Image.new(image.mode, (h, h), (0, 0, 0))
169
+ pad_w = (h - w) // 2
170
+ pad_h = 0
171
+ new_image.paste(image, (pad_w, 0))
172
+ return new_image
173
+
174
+ def unpad_image(self, padded_image, original_size):
175
+ w, h = original_size
176
+ if w == h:
177
+ return padded_image
178
+ elif w > h:
179
+ pad_h = (w - h) // 2
180
+ unpadded_image = padded_image.crop((0, pad_h, w, h + pad_h))
181
+ return unpadded_image
182
+ else:
183
+ pad_w = (h - w) // 2
184
+ unpadded_image = padded_image.crop((pad_w, 0, w + pad_w, h))
185
+ return unpadded_image
186
+
187
  def __resize_for_condition_image(self, image: Image.Image, resolution: int):
188
  input_image = image.convert("RGB")
189
  W, H = input_image.size
internals/pipelines/upscaler.py CHANGED
@@ -1,7 +1,8 @@
 
1
  import math
2
  import os
3
  from pathlib import Path
4
- from typing import Union
5
 
6
  import cv2
7
  import numpy as np
@@ -10,7 +11,7 @@ from basicsr.archs.srvgg_arch import SRVGGNetCompact
10
  from basicsr.utils.download_util import load_file_from_url
11
  from gfpgan import GFPGANer
12
  from PIL import Image
13
- from realesrgan import RealESRGANer
14
 
15
  import internals.util.image as ImageUtil
16
  from internals.util.commons import download_image
@@ -55,8 +56,12 @@ class Upscaler:
55
  width: int,
56
  height: int,
57
  face_enhance: bool,
58
- resize_dimension: int,
59
  ) -> bytes:
 
 
 
 
60
  model = SRVGGNetCompact(
61
  num_in_ch=3,
62
  num_out_ch=3,
@@ -67,7 +72,7 @@ class Upscaler:
67
  )
68
  return self.__internal_upscale(
69
  image,
70
- resize_dimension,
71
  face_enhance,
72
  width,
73
  height,
@@ -83,6 +88,10 @@ class Upscaler:
83
  face_enhance: bool,
84
  resize_dimension: int,
85
  ) -> bytes:
 
 
 
 
86
  model = RRDBNet(
87
  num_in_ch=3,
88
  num_out_ch=3,
@@ -124,18 +133,22 @@ class Upscaler:
124
  model,
125
  ) -> bytes:
126
  if type(image) is str:
127
- image = download_image(image)
128
 
129
  w, h = image.size
130
- if max(w, h) > 1024:
131
- image = ImageUtil.resize_image(image, dimension=1024)
132
 
133
  in_path = str(Path.home() / ".cache" / "input_upscale.png")
134
  image.save(in_path)
135
  input_image = cv2.imread(in_path, cv2.IMREAD_UNCHANGED)
136
- dimension = min(input_image.shape[0], input_image.shape[1])
 
 
137
  scale = max(math.floor(resize_dimension / dimension), 2)
138
 
 
 
139
  os.chdir(str(Path.home() / ".cache"))
140
  if scale == 4:
141
  print("Using 4x-Ultrasharp")
@@ -174,3 +187,7 @@ class Upscaler:
174
  cv2.imwrite("out.png", output)
175
  out_bytes = cv2.imencode(".png", output)[1].tobytes()
176
  return out_bytes
 
 
 
 
 
1
+ import io
2
  import math
3
  import os
4
  from pathlib import Path
5
+ from typing import Optional, Union
6
 
7
  import cv2
8
  import numpy as np
 
11
  from basicsr.utils.download_util import load_file_from_url
12
  from gfpgan import GFPGANer
13
  from PIL import Image
14
+ from realesrgan import RealESRGANer # pyright: ignore
15
 
16
  import internals.util.image as ImageUtil
17
  from internals.util.commons import download_image
 
56
  width: int,
57
  height: int,
58
  face_enhance: bool,
59
+ resize_dimension: Optional[int] = None,
60
  ) -> bytes:
61
+ "if resize dimension is not provided, use the smaller of width and height"
62
+
63
+ self.load()
64
+
65
  model = SRVGGNetCompact(
66
  num_in_ch=3,
67
  num_out_ch=3,
 
72
  )
73
  return self.__internal_upscale(
74
  image,
75
+ resize_dimension, # type: ignore
76
  face_enhance,
77
  width,
78
  height,
 
88
  face_enhance: bool,
89
  resize_dimension: int,
90
  ) -> bytes:
91
+ "if resize dimension is not provided, use the smaller of width and height"
92
+
93
+ self.load()
94
+
95
  model = RRDBNet(
96
  num_in_ch=3,
97
  num_out_ch=3,
 
133
  model,
134
  ) -> bytes:
135
  if type(image) is str:
136
+ image = download_image(image, mode="RGBA")
137
 
138
  w, h = image.size
139
+ # if max(w, h) > 1024:
140
+ # image = ImageUtil.resize_image(image, dimension=1024)
141
 
142
  in_path = str(Path.home() / ".cache" / "input_upscale.png")
143
  image.save(in_path)
144
  input_image = cv2.imread(in_path, cv2.IMREAD_UNCHANGED)
145
+ dimension = max(input_image.shape[0], input_image.shape[1])
146
+ if not resize_dimension:
147
+ resize_dimension = max(width, height)
148
  scale = max(math.floor(resize_dimension / dimension), 2)
149
 
150
+ print("Upscaling by: ", scale)
151
+
152
  os.chdir(str(Path.home() / ".cache"))
153
  if scale == 4:
154
  print("Using 4x-Ultrasharp")
 
187
  cv2.imwrite("out.png", output)
188
  out_bytes = cv2.imencode(".png", output)[1].tobytes()
189
  return out_bytes
190
+
191
+ @staticmethod
192
+ def to_pil(buffer: bytes, mode="RGB") -> Image.Image:
193
+ return Image.open(io.BytesIO(buffer)).convert(mode)
internals/util/__init__.py CHANGED
@@ -1,7 +1,13 @@
1
  import os
2
 
 
 
3
  from internals.util.config import get_root_dir
4
 
5
 
6
  def getcwd():
7
  return get_root_dir()
 
 
 
 
 
1
  import os
2
 
3
+ import torch
4
+
5
  from internals.util.config import get_root_dir
6
 
7
 
8
  def getcwd():
9
  return get_root_dir()
10
+
11
+
12
+ def get_generators(seed, num_generators=1):
13
+ return [torch.Generator().manual_seed(seed + i) for i in range(num_generators)]
internals/util/cache.py CHANGED
@@ -1,5 +1,6 @@
1
  import gc
2
  import os
 
3
  import psutil
4
  import torch
5
 
@@ -7,6 +8,7 @@ import torch
7
  def print_memory_usage():
8
  process = psutil.Process(os.getpid())
9
  print(f"Memory usage: {process.memory_info().rss / 1024 ** 2:2f} MB")
 
10
 
11
 
12
  def clear_cuda_and_gc():
 
1
  import gc
2
  import os
3
+
4
  import psutil
5
  import torch
6
 
 
8
  def print_memory_usage():
9
  process = psutil.Process(os.getpid())
10
  print(f"Memory usage: {process.memory_info().rss / 1024 ** 2:2f} MB")
11
+ print(f"GPU usage: {torch.cuda.memory_allocated() / 1024 ** 2:2f} MB")
12
 
13
 
14
  def clear_cuda_and_gc():
internals/util/commons.py CHANGED
@@ -11,7 +11,7 @@ from typing import Any, Optional, Union
11
  import boto3
12
  import requests
13
 
14
- from internals.util.config import api_endpoint, api_headers
15
 
16
  s3 = boto3.client("s3")
17
  import io
@@ -103,7 +103,7 @@ def upload_images(images, processName: str, taskId: str):
103
  img_io.seek(0)
104
  key = "crecoAI/{}{}_{}.png".format(taskId, processName, i)
105
  res = requests.post(
106
- api_endpoint()
107
  + "/autodraft-content/v1.0/upload/crecoai-assets-2?fileName="
108
  + "{}{}_{}.png".format(taskId, processName, i),
109
  headers=api_headers(),
@@ -129,12 +129,12 @@ def upload_image(image: Union[Image.Image, BytesIO], out_path):
129
 
130
  image.seek(0)
131
  print(
132
- api_endpoint()
133
  + "/autodraft-content/v1.0/upload/crecoai-assets-2?fileName="
134
  + str(out_path).replace("crecoAI/", ""),
135
  )
136
  res = requests.post(
137
- api_endpoint()
138
  + "/autodraft-content/v1.0/upload/crecoai-assets-2?fileName="
139
  + str(out_path).replace("crecoAI/", ""),
140
  headers=api_headers(),
 
11
  import boto3
12
  import requests
13
 
14
+ from internals.util.config import api_endpoint, api_headers, elb_endpoint
15
 
16
  s3 = boto3.client("s3")
17
  import io
 
103
  img_io.seek(0)
104
  key = "crecoAI/{}{}_{}.png".format(taskId, processName, i)
105
  res = requests.post(
106
+ elb_endpoint()
107
  + "/autodraft-content/v1.0/upload/crecoai-assets-2?fileName="
108
  + "{}{}_{}.png".format(taskId, processName, i),
109
  headers=api_headers(),
 
129
 
130
  image.seek(0)
131
  print(
132
+ elb_endpoint()
133
  + "/autodraft-content/v1.0/upload/crecoai-assets-2?fileName="
134
  + str(out_path).replace("crecoAI/", ""),
135
  )
136
  res = requests.post(
137
+ elb_endpoint()
138
  + "/autodraft-content/v1.0/upload/crecoai-assets-2?fileName="
139
  + str(out_path).replace("crecoAI/", ""),
140
  headers=api_headers(),
internals/util/config.py CHANGED
@@ -13,7 +13,7 @@ access_token = ""
13
  root_dir = ""
14
  model_config = None
15
  hf_token = base64.b64decode(
16
- b"aGZfVFZCTHNUam1tT3d6T0h1dlVZWkhEbEZ4WVdOSUdGamVCbA=="
17
  ).decode()
18
  hf_cache_dir = "/tmp/hf_hub"
19
 
@@ -46,7 +46,7 @@ def set_model_config(config: ModelConfig):
46
 
47
  def set_configs_from_task(task: Task):
48
  global env, nsfw_threshold, nsfw_access, access_token, base_dimension, num_return_sequences
49
- name = task.get_queue_name()
50
  if name.startswith("gamma"):
51
  env = "gamma"
52
  else:
@@ -120,14 +120,25 @@ def get_base_model_variant():
120
  return model_config.base_model_variant # pyright: ignore
121
 
122
 
 
 
 
 
 
123
  def get_base_inpaint_model_variant():
124
  global model_config
125
  return model_config.base_inpaint_model_variant # pyright: ignore
126
 
127
 
 
 
 
 
 
128
  def api_headers():
129
  return {
130
  "Access-Token": access_token,
 
131
  }
132
 
133
 
@@ -138,8 +149,11 @@ def api_endpoint():
138
  return "https://gamma-api.autodraft.in"
139
 
140
 
141
- def comic_url():
 
 
 
142
  if env == "prod":
143
- return "http://internal-k8s-prod-internal-bb9c57a6bb-1524739074.ap-south-1.elb.amazonaws.com:80"
144
  else:
145
- return "http://internal-k8s-gamma-internal-ea8e32da94-1997933257.ap-south-1.elb.amazonaws.com:80"
 
13
  root_dir = ""
14
  model_config = None
15
  hf_token = base64.b64decode(
16
+ b"aGZfaXRvVVJzTmN1RHZab1hXZ3hIeFRRRGdvSHdrQ2VNUldGbA=="
17
  ).decode()
18
  hf_cache_dir = "/tmp/hf_hub"
19
 
 
46
 
47
  def set_configs_from_task(task: Task):
48
  global env, nsfw_threshold, nsfw_access, access_token, base_dimension, num_return_sequences
49
+ name = task.get_environment()
50
  if name.startswith("gamma"):
51
  env = "gamma"
52
  else:
 
120
  return model_config.base_model_variant # pyright: ignore
121
 
122
 
123
+ def get_base_model_revision():
124
+ global model_config
125
+ return model_config.base_model_revision # pyright: ignore
126
+
127
+
128
  def get_base_inpaint_model_variant():
129
  global model_config
130
  return model_config.base_inpaint_model_variant # pyright: ignore
131
 
132
 
133
+ def get_base_inpaint_model_revision():
134
+ global model_config
135
+ return model_config.base_inpaint_model_revision # pyright: ignore
136
+
137
+
138
  def api_headers():
139
  return {
140
  "Access-Token": access_token,
141
+ "Host": "api.autodraft.in" if env == "prod" else "gamma-api.autodraft.in",
142
  }
143
 
144
 
 
149
  return "https://gamma-api.autodraft.in"
150
 
151
 
152
+ def elb_endpoint():
153
+ # We use the ELB endpoint for uploading images since
154
+ # cloudflare has a hard limit of 100mb when the
155
+ # DNS is proxied
156
  if env == "prod":
157
+ return "http://k8s-prod-ingresse-8ba91151af-2105029163.ap-south-1.elb.amazonaws.com"
158
  else:
159
+ return "http://k8s-gamma-ingresse-fc1051bc41-1227070426.ap-south-1.elb.amazonaws.com"
internals/util/failure_hander.py CHANGED
@@ -16,10 +16,13 @@ class FailureHandler:
16
  path = FailureHandler.__task_path
17
  path.parent.mkdir(parents=True, exist_ok=True)
18
  if path.exists():
19
- task = Task(json.loads(path.read_text()))
20
- set_configs_from_task(task)
21
- # Slack().error_alert(task, Exception("CATASTROPHIC FAILURE"))
22
- updateSource(task.get_sourceId(), task.get_userId(), "FAILED")
 
 
 
23
  os.remove(path)
24
 
25
  @staticmethod
 
16
  path = FailureHandler.__task_path
17
  path.parent.mkdir(parents=True, exist_ok=True)
18
  if path.exists():
19
+ try:
20
+ task = Task(json.loads(path.read_text()))
21
+ set_configs_from_task(task)
22
+ # Slack().error_alert(task, Exception("CATASTROPHIC FAILURE"))
23
+ updateSource(task.get_sourceId(), task.get_userId(), "FAILED")
24
+ except Exception as e:
25
+ print("Failed to handle task", e)
26
  os.remove(path)
27
 
28
  @staticmethod
internals/util/image.py CHANGED
@@ -48,3 +48,21 @@ def padd_image(image: Image.Image, to_width: int, to_height: int) -> Image.Image
48
  img = Image.new("RGBA", (to_width, to_height), (0, 0, 0, 0))
49
  img.paste(image, ((to_width - iw) // 2, (to_height - ih) // 2))
50
  return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  img = Image.new("RGBA", (to_width, to_height), (0, 0, 0, 0))
49
  img.paste(image, ((to_width - iw) // 2, (to_height - ih) // 2))
50
  return img
51
+
52
+
53
+ def alpha_to_white(img: Image.Image) -> Image.Image:
54
+ if img.mode == "RGBA":
55
+ data = img.getdata()
56
+
57
+ new_data = []
58
+
59
+ for item in data:
60
+ if item[3] == 0:
61
+ new_data.append((255, 255, 255, 255))
62
+ else:
63
+ new_data.append(item)
64
+
65
+ img.putdata(new_data)
66
+
67
+ img = img.convert("RGB")
68
+ return img
internals/util/lora_style.py CHANGED
@@ -52,9 +52,18 @@ class LoraStyle:
52
  def patch(self):
53
  def run(pipe):
54
  path = self.__style["path"]
55
- pipe.load_lora_weights(
56
- os.path.dirname(path), weight_name=os.path.basename(path)
57
- )
 
 
 
 
 
 
 
 
 
58
 
59
  for p in self.pipe:
60
  run(p)
@@ -105,7 +114,17 @@ class LoraStyle:
105
  def prepend_style_to_prompt(self, prompt: str, key: str) -> str:
106
  if key in self.__styles:
107
  style = self.__styles[key]
108
- return f"{', '.join(style['text'])}, {prompt}"
 
 
 
 
 
 
 
 
 
 
109
  return prompt
110
 
111
  def get_patcher(
@@ -140,7 +159,9 @@ class LoraStyle:
140
  "path": str(file_path),
141
  "weight": attr["weight"],
142
  "type": attr["type"],
 
143
  "text": attr["text"],
 
144
  "negativePrompt": attr["negativePrompt"],
145
  }
146
  return styles
@@ -159,4 +180,7 @@ class LoraStyle:
159
 
160
  @staticmethod
161
  def unload_lora_weights(pipe):
162
- pipe.unload_lora_weights()
 
 
 
 
52
  def patch(self):
53
  def run(pipe):
54
  path = self.__style["path"]
55
+ name = str(self.__style["tag"]).replace(" ", "_")
56
+ weight = self.__style.get("weight", 1.0)
57
+ if name not in pipe.get_list_adapters().get("unet", []):
58
+ print(
59
+ f"Loading lora {os.path.basename(path)} with weights {weight}, name: {name}"
60
+ )
61
+ pipe.load_lora_weights(
62
+ os.path.dirname(path),
63
+ weight_name=os.path.basename(path),
64
+ adapter_name=name,
65
+ )
66
+ pipe.set_adapters([name], adapter_weights=[weight])
67
 
68
  for p in self.pipe:
69
  run(p)
 
114
  def prepend_style_to_prompt(self, prompt: str, key: str) -> str:
115
  if key in self.__styles:
116
  style = self.__styles[key]
117
+ prompt = f"{', '.join(style['text'])}, {prompt}"
118
+ prompt = prompt.replace("<NOSEP>, ", "")
119
+ return prompt
120
+
121
+ def append_style_to_prompt(self, prompt: str, key: str) -> str:
122
+ if key in self.__styles and "text_append" in self.__styles[key]:
123
+ style = self.__styles[key]
124
+ if prompt.endswith(","):
125
+ prompt = prompt[:-1]
126
+ prompt = f"{prompt}, {', '.join(style['text_append'])}"
127
+ prompt = prompt.replace("<NOSEP>, ", "")
128
  return prompt
129
 
130
  def get_patcher(
 
159
  "path": str(file_path),
160
  "weight": attr["weight"],
161
  "type": attr["type"],
162
+ "tag": item["tag"],
163
  "text": attr["text"],
164
+ "text_append": attr.get("text_append", []),
165
  "negativePrompt": attr["negativePrompt"],
166
  }
167
  return styles
 
180
 
181
  @staticmethod
182
  def unload_lora_weights(pipe):
183
+ # we keep the lora layers in the adapters and unset it whenever
184
+ # not required instead of completely unloading it
185
+ pipe.set_adapters([])
186
+ # pipe.unload_lora_weights()
internals/util/model_loader.py CHANGED
@@ -18,7 +18,9 @@ class ModelConfig:
18
  base_dimension: int = 512
19
  low_gpu_mem: bool = False
20
  base_model_variant: Optional[str] = None
 
21
  base_inpaint_model_variant: Optional[str] = None
 
22
 
23
 
24
  def load_model_from_config(path):
@@ -31,7 +33,11 @@ def load_model_from_config(path):
31
  is_sdxl = config.get("is_sdxl", False)
32
  base_dimension = config.get("base_dimension", 512)
33
  base_model_variant = config.get("base_model_variant", None)
 
34
  base_inpaint_model_variant = config.get("base_inpaint_model_variant", None)
 
 
 
35
 
36
  m_config.base_model_path = model_path
37
  m_config.base_inpaint_model_path = inpaint_model_path
@@ -39,7 +45,9 @@ def load_model_from_config(path):
39
  m_config.base_dimension = base_dimension
40
  m_config.low_gpu_mem = config.get("low_gpu_mem", False)
41
  m_config.base_model_variant = base_model_variant
 
42
  m_config.base_inpaint_model_variant = base_inpaint_model_variant
 
43
 
44
  #
45
  # if config.get("model_type") == "huggingface":
 
18
  base_dimension: int = 512
19
  low_gpu_mem: bool = False
20
  base_model_variant: Optional[str] = None
21
+ base_model_revision: Optional[str] = None
22
  base_inpaint_model_variant: Optional[str] = None
23
+ base_inpaint_model_revision: Optional[str] = None
24
 
25
 
26
  def load_model_from_config(path):
 
33
  is_sdxl = config.get("is_sdxl", False)
34
  base_dimension = config.get("base_dimension", 512)
35
  base_model_variant = config.get("base_model_variant", None)
36
+ base_model_revision = config.get("base_model_revision", None)
37
  base_inpaint_model_variant = config.get("base_inpaint_model_variant", None)
38
+ base_inpaint_model_revision = config.get(
39
+ "base_inpaint_model_revision", None
40
+ )
41
 
42
  m_config.base_model_path = model_path
43
  m_config.base_inpaint_model_path = inpaint_model_path
 
45
  m_config.base_dimension = base_dimension
46
  m_config.low_gpu_mem = config.get("low_gpu_mem", False)
47
  m_config.base_model_variant = base_model_variant
48
+ m_config.base_model_revision = base_model_revision
49
  m_config.base_inpaint_model_variant = base_inpaint_model_variant
50
+ m_config.base_inpaint_model_revision = base_inpaint_model_revision
51
 
52
  #
53
  # if config.get("model_type") == "huggingface":
internals/util/prompt.py CHANGED
@@ -21,6 +21,7 @@ def get_patched_prompt(
21
  for i in range(len(prompt)):
22
  prompt[i] = avatar.add_code_names(prompt[i])
23
  prompt[i] = lora_style.prepend_style_to_prompt(prompt[i], task.get_style())
 
24
  if additional:
25
  prompt[i] = additional + " " + prompt[i]
26
 
@@ -51,6 +52,7 @@ def get_patched_prompt_text2img(
51
  def add_style_and_character(prompt: str, prepend: str = ""):
52
  prompt = avatar.add_code_names(prompt)
53
  prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style())
 
54
  prompt = prepend + prompt
55
  return prompt
56
 
@@ -102,6 +104,7 @@ def get_patched_prompt_tile_upscale(
102
  lora_style: LoraStyle,
103
  img_classifier: ImageClassifier,
104
  img2text: Image2Text,
 
105
  ):
106
  if task.get_prompt():
107
  prompt = task.get_prompt()
@@ -114,10 +117,12 @@ def get_patched_prompt_tile_upscale(
114
  prompt = task.PROMPT.merge_blip(blip)
115
 
116
  # remove anomalies in prompt
117
- prompt = remove_colors(prompt)
 
118
 
119
  prompt = avatar.add_code_names(prompt)
120
  prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style())
 
121
 
122
  if not task.get_style():
123
  class_name = img_classifier.classify(
 
21
  for i in range(len(prompt)):
22
  prompt[i] = avatar.add_code_names(prompt[i])
23
  prompt[i] = lora_style.prepend_style_to_prompt(prompt[i], task.get_style())
24
+ prompt[i] = lora_style.append_style_to_prompt(prompt[i], task.get_style())
25
  if additional:
26
  prompt[i] = additional + " " + prompt[i]
27
 
 
52
  def add_style_and_character(prompt: str, prepend: str = ""):
53
  prompt = avatar.add_code_names(prompt)
54
  prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style())
55
+ prompt = lora_style.append_style_to_prompt(prompt, task.get_style())
56
  prompt = prepend + prompt
57
  return prompt
58
 
 
104
  lora_style: LoraStyle,
105
  img_classifier: ImageClassifier,
106
  img2text: Image2Text,
107
+ is_sdxl=False,
108
  ):
109
  if task.get_prompt():
110
  prompt = task.get_prompt()
 
117
  prompt = task.PROMPT.merge_blip(blip)
118
 
119
  # remove anomalies in prompt
120
+ if not is_sdxl:
121
+ prompt = remove_colors(prompt)
122
 
123
  prompt = avatar.add_code_names(prompt)
124
  prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style())
125
+ prompt = lora_style.append_style_to_prompt(prompt, task.get_style())
126
 
127
  if not task.get_style():
128
  class_name = img_classifier.classify(
internals/util/sdxl_lightning.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from re import S
3
+ from typing import List, Union
4
+
5
+ from diffusers import EulerDiscreteScheduler, StableDiffusionXLPipeline
6
+ from diffusers.loaders.lora import StableDiffusionXLLoraLoaderMixin
7
+ from torchvision.datasets.utils import download_url
8
+
9
+
10
+ class LightningMixin:
11
+ LORA_8_STEP_URL = "https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_8step_lora.safetensors"
12
+
13
+ __scheduler_old = None
14
+ __pipe: StableDiffusionXLPipeline = None
15
+ __scheduler = None
16
+
17
+ def configure_sdxl_lightning(self, pipe: StableDiffusionXLPipeline):
18
+ lora_path = Path.home() / ".cache" / "lora_8_step.safetensors"
19
+
20
+ download_url(self.LORA_8_STEP_URL, str(lora_path.parent), lora_path.name)
21
+
22
+ pipe.load_lora_weights(str(lora_path), adapter_name="8step_lora")
23
+ pipe.set_adapters([])
24
+
25
+ self.__scheduler = EulerDiscreteScheduler.from_config(
26
+ pipe.scheduler.config, timestep_spacing="trailing"
27
+ )
28
+ self.__scheduler_old = pipe.scheduler
29
+ self.__pipe = pipe
30
+
31
+ def enable_sdxl_lightning(self):
32
+ pipe = self.__pipe
33
+ pipe.scheduler = self.__scheduler
34
+
35
+ current = pipe.get_active_adapters()
36
+ current.extend(["8step_lora"])
37
+
38
+ weights = self.__find_adapter_weights(current)
39
+ pipe.set_adapters(current, adapter_weights=weights)
40
+
41
+ return {"guidance_scale": 0, "num_inference_steps": 8}
42
+
43
+ def disable_sdxl_lightning(self):
44
+ pipe = self.__pipe
45
+ pipe.scheduler = self.__scheduler_old
46
+
47
+ current = pipe.get_active_adapters()
48
+ current = [adapter for adapter in current if adapter != "8step_lora"]
49
+
50
+ weights = self.__find_adapter_weights(current)
51
+ pipe.set_adapters(current, adapter_weights=weights)
52
+
53
+ def __find_adapter_weights(self, names: List[str]):
54
+ pipe = self.__pipe
55
+
56
+ model = pipe.unet
57
+
58
+ from peft.tuners.tuners_utils import BaseTunerLayer
59
+
60
+ weights = []
61
+ for adapter_name in names:
62
+ weight = 1.0
63
+ for module in model.modules():
64
+ if isinstance(module, BaseTunerLayer):
65
+ if adapter_name in module.scaling:
66
+ weight = (
67
+ module.scaling[adapter_name]
68
+ * module.r[adapter_name]
69
+ / module.lora_alpha[adapter_name]
70
+ )
71
+
72
+ weights.append(weight)
73
+
74
+ return weights
internals/util/slack.py CHANGED
@@ -14,6 +14,8 @@ class Slack:
14
  self.error_webhook = "https://hooks.slack.com/services/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"
15
 
16
  def send_alert(self, task: Task, args: Optional[dict]):
 
 
17
  raw = task.get_raw().copy()
18
 
19
  raw["environment"] = get_environment()
@@ -23,6 +25,7 @@ class Slack:
23
  raw.pop("task_id", None)
24
  raw.pop("maskImageUrl", None)
25
  raw.pop("aux_imageUrl", None)
 
26
 
27
  if args is not None:
28
  raw.update(args.items())
 
14
  self.error_webhook = "https://hooks.slack.com/services/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"
15
 
16
  def send_alert(self, task: Task, args: Optional[dict]):
17
+ if task.get_slack_url():
18
+ self.webhook_url = task.get_slack_url()
19
  raw = task.get_raw().copy()
20
 
21
  raw["environment"] = get_environment()
 
25
  raw.pop("task_id", None)
26
  raw.pop("maskImageUrl", None)
27
  raw.pop("aux_imageUrl", None)
28
+ raw.pop("slack_url", None)
29
 
30
  if args is not None:
31
  raw.update(args.items())