Vijish committed on
Commit f4f26b2
1 Parent(s): becc578

Upload 5 files

src/__init__.py ADDED
File without changes
src/core.py ADDED
@@ -0,0 +1,466 @@
+ import base64
+ import json
+ import os
+ import re
+ import time
+ import uuid
+ from io import BytesIO
+ from pathlib import Path
+ import cv2
+
+ # For inpainting
+
+ import numpy as np
+ import pandas as pd
+ import streamlit as st
+ from PIL import Image
+ #from streamlit_drawable_canvas import st_canvas
+
+
+ import argparse
+ import io
+ import multiprocessing
+ from typing import Union
+
+ import torch
+
+ try:
+     torch._C._jit_override_can_fuse_on_cpu(False)
+     torch._C._jit_override_can_fuse_on_gpu(False)
+     torch._C._jit_set_texpr_fuser_enabled(False)
+     torch._C._jit_set_nvfuser_enabled(False)
+ except:
+     pass
+
+ from src.helper import (
+     download_model,
+     load_img,
+     norm_img,
+     numpy_to_bytes,
+     pad_img_to_modulo,
+     resize_max_size,
+ )
+
+ NUM_THREADS = str(multiprocessing.cpu_count())
+
+ os.environ["OMP_NUM_THREADS"] = NUM_THREADS
+ os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
+ os.environ["MKL_NUM_THREADS"] = NUM_THREADS
+ os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
+ os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
+ if os.environ.get("CACHE_DIR"):
+     os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
+
+ #BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")
+
+ # For Seam-carving
+
+ from scipy import ndimage as ndi
+
+ SEAM_COLOR = np.array([255, 200, 200])  # seam visualization color (BGR)
+ SHOULD_DOWNSIZE = True                  # if True, downsize image for faster carving
+ DOWNSIZE_WIDTH = 500                    # resized image width if SHOULD_DOWNSIZE is True
+ ENERGY_MASK_CONST = 100000.0            # large energy value for protective masking
+ MASK_THRESHOLD = 10                     # minimum pixel intensity for binary mask
+ USE_FORWARD_ENERGY = True               # if True, use forward energy algorithm
+
+ device = torch.device("cpu")
+ model_path = "./assets/big-lama.pt"
+ model = torch.jit.load(model_path, map_location="cpu")
+ model = model.to(device)
+ model.eval()
+
+
+ ########################################
+ # UTILITY CODE
+ ########################################
+
+
+ def visualize(im, boolmask=None, rotate=False):
+     vis = im.astype(np.uint8)
+     if boolmask is not None:
+         vis[np.where(boolmask == False)] = SEAM_COLOR
+     if rotate:
+         vis = rotate_image(vis, False)
+     cv2.imshow("visualization", vis)
+     cv2.waitKey(1)
+     return vis
+
+ def resize(image, width):
+     dim = None
+     h, w = image.shape[:2]
+     dim = (width, int(h * width / float(w)))
+     image = image.astype('float32')
+     return cv2.resize(image, dim)
+
+ def rotate_image(image, clockwise):
+     k = 1 if clockwise else 3
+     return np.rot90(image, k)
+
+
+ ########################################
+ # ENERGY FUNCTIONS
+ ########################################
+
+ def backward_energy(im):
+     """
+     Simple gradient magnitude energy map.
+     """
+     xgrad = ndi.convolve1d(im, np.array([1, 0, -1]), axis=1, mode='wrap')
+     ygrad = ndi.convolve1d(im, np.array([1, 0, -1]), axis=0, mode='wrap')
+
+     grad_mag = np.sqrt(np.sum(xgrad**2, axis=2) + np.sum(ygrad**2, axis=2))
+
+     # vis = visualize(grad_mag)
+     # cv2.imwrite("backward_energy_demo.jpg", vis)
+
+     return grad_mag
+
+ def forward_energy(im):
+     """
+     Forward energy algorithm as described in "Improved Seam Carving for Video Retargeting"
+     by Rubinstein, Shamir, Avidan.
+     Vectorized code adapted from
+     https://github.com/axu2/improved-seam-carving.
+     """
+     h, w = im.shape[:2]
+     im = cv2.cvtColor(im.astype(np.uint8), cv2.COLOR_BGR2GRAY).astype(np.float64)
+
+     energy = np.zeros((h, w))
+     m = np.zeros((h, w))
+
+     U = np.roll(im, 1, axis=0)
+     L = np.roll(im, 1, axis=1)
+     R = np.roll(im, -1, axis=1)
+
+     cU = np.abs(R - L)
+     cL = np.abs(U - L) + cU
+     cR = np.abs(U - R) + cU
+
+     for i in range(1, h):
+         mU = m[i-1]
+         mL = np.roll(mU, 1)
+         mR = np.roll(mU, -1)
+
+         mULR = np.array([mU, mL, mR])
+         cULR = np.array([cU[i], cL[i], cR[i]])
+         mULR += cULR
+
+         argmins = np.argmin(mULR, axis=0)
+         m[i] = np.choose(argmins, mULR)
+         energy[i] = np.choose(argmins, cULR)
+
+     # vis = visualize(energy)
+     # cv2.imwrite("forward_energy_demo.jpg", vis)
+
+     return energy
+
+ ########################################
+ # SEAM HELPER FUNCTIONS
+ ########################################
+
+ def add_seam(im, seam_idx):
+     """
+     Add a vertical seam to a 3-channel color image at the indices provided
+     by averaging the pixel values to the left and right of the seam.
+     Code adapted from https://github.com/vivianhylee/seam-carving.
+     """
+     h, w = im.shape[:2]
+     output = np.zeros((h, w + 1, 3))
+     for row in range(h):
+         col = seam_idx[row]
+         for ch in range(3):
+             if col == 0:
+                 p = np.mean(im[row, col: col + 2, ch])
+                 output[row, col, ch] = im[row, col, ch]
+                 output[row, col + 1, ch] = p
+                 output[row, col + 1:, ch] = im[row, col:, ch]
+             else:
+                 p = np.mean(im[row, col - 1: col + 1, ch])
+                 output[row, : col, ch] = im[row, : col, ch]
+                 output[row, col, ch] = p
+                 output[row, col + 1:, ch] = im[row, col:, ch]
+
+     return output
+
+ def add_seam_grayscale(im, seam_idx):
+     """
+     Add a vertical seam to a grayscale image at the indices provided
+     by averaging the pixel values to the left and right of the seam.
+     """
+     h, w = im.shape[:2]
+     output = np.zeros((h, w + 1))
+     for row in range(h):
+         col = seam_idx[row]
+         if col == 0:
+             p = np.mean(im[row, col: col + 2])
+             output[row, col] = im[row, col]
+             output[row, col + 1] = p
+             output[row, col + 1:] = im[row, col:]
+         else:
+             p = np.mean(im[row, col - 1: col + 1])
+             output[row, : col] = im[row, : col]
+             output[row, col] = p
+             output[row, col + 1:] = im[row, col:]
+
+     return output
+
+ def remove_seam(im, boolmask):
+     h, w = im.shape[:2]
+     boolmask3c = np.stack([boolmask] * 3, axis=2)
+     return im[boolmask3c].reshape((h, w - 1, 3))
+
+ def remove_seam_grayscale(im, boolmask):
+     h, w = im.shape[:2]
+     return im[boolmask].reshape((h, w - 1))
+
+ def get_minimum_seam(im, mask=None, remove_mask=None):
+     """
+     DP algorithm for finding the seam of minimum energy. Code adapted from
+     https://karthikkaranth.me/blog/implementing-seam-carving-with-python/
+     """
+     h, w = im.shape[:2]
+     energyfn = forward_energy if USE_FORWARD_ENERGY else backward_energy
+     M = energyfn(im)
+
+     if mask is not None:
+         M[np.where(mask > MASK_THRESHOLD)] = ENERGY_MASK_CONST
+
+     # give removal mask priority over protective mask by using larger negative value
+     if remove_mask is not None:
+         M[np.where(remove_mask > MASK_THRESHOLD)] = -ENERGY_MASK_CONST * 100
+
+     seam_idx, boolmask = compute_shortest_path(M, im, h, w)
+
+     return np.array(seam_idx), boolmask
+
+ def compute_shortest_path(M, im, h, w):
+     backtrack = np.zeros_like(M, dtype=np.int_)
+
+
+     # populate DP matrix
+     for i in range(1, h):
+         for j in range(0, w):
+             if j == 0:
+                 idx = np.argmin(M[i - 1, j:j + 2])
+                 backtrack[i, j] = idx + j
+                 min_energy = M[i-1, idx + j]
+             else:
+                 idx = np.argmin(M[i - 1, j - 1:j + 2])
+                 backtrack[i, j] = idx + j - 1
+                 min_energy = M[i - 1, idx + j - 1]
+
+             M[i, j] += min_energy
+
+     # backtrack to find path
+     seam_idx = []
+     boolmask = np.ones((h, w), dtype=np.bool_)
+     j = np.argmin(M[-1])
+     for i in range(h-1, -1, -1):
+         boolmask[i, j] = False
+         seam_idx.append(j)
+         j = backtrack[i, j]
+
+     seam_idx.reverse()
+     return seam_idx, boolmask
+
+ ########################################
+ # MAIN ALGORITHM
+ ########################################
+
+ def seams_removal(im, num_remove, mask=None, vis=False, rot=False):
+     for _ in range(num_remove):
+         seam_idx, boolmask = get_minimum_seam(im, mask)
+         if vis:
+             visualize(im, boolmask, rotate=rot)
+         im = remove_seam(im, boolmask)
+         if mask is not None:
+             mask = remove_seam_grayscale(mask, boolmask)
+     return im, mask
+
+
+ def seams_insertion(im, num_add, mask=None, vis=False, rot=False):
+     seams_record = []
+     temp_im = im.copy()
+     temp_mask = mask.copy() if mask is not None else None
+
+     for _ in range(num_add):
+         seam_idx, boolmask = get_minimum_seam(temp_im, temp_mask)
+         if vis:
+             visualize(temp_im, boolmask, rotate=rot)
+
+         seams_record.append(seam_idx)
+         temp_im = remove_seam(temp_im, boolmask)
+         if temp_mask is not None:
+             temp_mask = remove_seam_grayscale(temp_mask, boolmask)
+
+     seams_record.reverse()
+
+     for _ in range(num_add):
+         seam = seams_record.pop()
+         im = add_seam(im, seam)
+         if vis:
+             visualize(im, rotate=rot)
+         if mask is not None:
+             mask = add_seam_grayscale(mask, seam)
+
+         # update the remaining seam indices
+         for remaining_seam in seams_record:
+             remaining_seam[np.where(remaining_seam >= seam)] += 2
+
+     return im, mask
+
+ ########################################
+ # MAIN DRIVER FUNCTIONS
+ ########################################
+
+ def seam_carve(im, dy, dx, mask=None, vis=False):
+     im = im.astype(np.float64)
+     h, w = im.shape[:2]
+     assert h + dy > 0 and w + dx > 0 and dy <= h and dx <= w
+
+     if mask is not None:
+         mask = mask.astype(np.float64)
+
+     output = im
+
+     if dx < 0:
+         output, mask = seams_removal(output, -dx, mask, vis)
+
+     elif dx > 0:
+         output, mask = seams_insertion(output, dx, mask, vis)
+
+     if dy < 0:
+         output = rotate_image(output, True)
+         if mask is not None:
+             mask = rotate_image(mask, True)
+         output, mask = seams_removal(output, -dy, mask, vis, rot=True)
+         output = rotate_image(output, False)
+
+     elif dy > 0:
+         output = rotate_image(output, True)
+         if mask is not None:
+             mask = rotate_image(mask, True)
+         output, mask = seams_insertion(output, dy, mask, vis, rot=True)
+         output = rotate_image(output, False)
+
+     return output
+
+
+ def object_removal(im, rmask, mask=None, vis=False, horizontal_removal=False):
+     im = im.astype(np.float64)
+     rmask = rmask.astype(np.float64)
+     if mask is not None:
+         mask = mask.astype(np.float64)
+     output = im
+
+     h, w = im.shape[:2]
+
+     if horizontal_removal:
+         output = rotate_image(output, True)
+         rmask = rotate_image(rmask, True)
+         if mask is not None:
+             mask = rotate_image(mask, True)
+
+     while len(np.where(rmask > MASK_THRESHOLD)[0]) > 0:
+         seam_idx, boolmask = get_minimum_seam(output, mask, rmask)
+         if vis:
+             visualize(output, boolmask, rotate=horizontal_removal)
+         output = remove_seam(output, boolmask)
+         rmask = remove_seam_grayscale(rmask, boolmask)
+         if mask is not None:
+             mask = remove_seam_grayscale(mask, boolmask)
+
+     num_add = (h if horizontal_removal else w) - output.shape[1]
+     output, mask = seams_insertion(output, num_add, mask, vis, rot=horizontal_removal)
+     if horizontal_removal:
+         output = rotate_image(output, False)
+
+     return output
+
+
+
+ def s_image(im, mask, vs, hs, mode="resize"):
+     im = cv2.cvtColor(im, cv2.COLOR_RGBA2RGB)
+     mask = 255 - mask[:, :, 3]
+     h, w = im.shape[:2]
+     if SHOULD_DOWNSIZE and w > DOWNSIZE_WIDTH:
+         im = resize(im, width=DOWNSIZE_WIDTH)
+         if mask is not None:
+             mask = resize(mask, width=DOWNSIZE_WIDTH)
+
+     # image resize mode
+     if mode == "resize":
+         dy = hs  # reverse
+         dx = vs  # reverse
+         assert dy is not None and dx is not None
+         output = seam_carve(im, dy, dx, mask, False)
+
+
+     # object removal mode
+     elif mode == "remove":
+         assert mask is not None
+         output = object_removal(im, mask, None, False, True)
+
+     return output
+
+
+ ##### Inpainting helper code
+
+ def run(image, mask):
+     """
+     image: [C, H, W]
+     mask: [1, H, W]
+     return: BGR IMAGE
+     """
+     origin_height, origin_width = image.shape[1:]
+     image = pad_img_to_modulo(image, mod=8)
+     mask = pad_img_to_modulo(mask, mod=8)
+
+     mask = (mask > 0) * 1
+     image = torch.from_numpy(image).unsqueeze(0).to(device)
+     mask = torch.from_numpy(mask).unsqueeze(0).to(device)
+
+     start = time.time()
+     with torch.no_grad():
+         inpainted_image = model(image, mask)
+
+     print(f"process time: {(time.time() - start)*1000}ms")
+     cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
+     cur_res = cur_res[0:origin_height, 0:origin_width, :]
+     cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
+     cur_res = cv2.cvtColor(cur_res, cv2.COLOR_BGR2RGB)
+     return cur_res
+
+
+ def get_args_parser():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--port", default=8080, type=int)
+     parser.add_argument("--device", default="cuda", type=str)
+     parser.add_argument("--debug", action="store_true")
+     return parser.parse_args()
+
+
+ def process_inpaint(image, mask):
+     image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
+     original_shape = image.shape
+     interpolation = cv2.INTER_CUBIC
+
+     #size_limit: Union[int, str] = request.form.get("sizeLimit", "1080")
+     #if size_limit == "Original":
+     size_limit = max(image.shape)
+     #else:
+     #    size_limit = int(size_limit)
+
+     print(f"Origin image shape: {original_shape}")
+     image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
+     print(f"Resized image shape: {image.shape}")
+     image = norm_img(image)
+
+     mask = 255 - mask[:, :, 3]
+     mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
+     mask = norm_img(mask)
+
+     res_np_img = run(image, mask)
+
+     return cv2.cvtColor(res_np_img, cv2.COLOR_BGR2RGB)
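For orientation, a minimal usage sketch of the two entry points above (`process_inpaint` for LaMa inpainting and `s_image` for seam carving); the file path and the way the RGBA mask is built here are illustrative assumptions, not taken from the repository. Both functions read the mask as `255 - mask[:, :, 3]`, so pixels whose alpha is 0 are the ones treated as the region to remove or inpaint.

import cv2
import numpy as np
from src.core import process_inpaint, s_image

# Hypothetical RGBA input image (the path is an example only).
image = cv2.cvtColor(cv2.imread("photo.jpg"), cv2.COLOR_BGR2RGBA)

# RGBA mask of the same size: alpha == 0 marks the area to edit.
mask = np.zeros_like(image)
mask[..., 3] = 255                 # keep everything by default
mask[100:200, 150:300, 3] = 0      # example rectangle to remove / inpaint

inpainted = process_inpaint(image, mask)        # LaMa inpainting, returns a color array
carved = s_image(image, mask, vs=-50, hs=0)     # seam carving: drop 50 vertical seams;
                                                # in "resize" mode the mask protects, in "remove" mode it deletes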
src/helper.py ADDED
@@ -0,0 +1,87 @@
+ import os
+ import sys
+
+ from urllib.parse import urlparse
+ import cv2
+ import numpy as np
+ import torch
+ from torch.hub import download_url_to_file, get_dir
+
+ LAMA_MODEL_URL = os.environ.get(
+     "LAMA_MODEL_URL",
+     "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
+ )
+
+
+ def download_model(url=LAMA_MODEL_URL):
+     parts = urlparse(url)
+     hub_dir = get_dir()
+     model_dir = os.path.join(hub_dir, "checkpoints")
+     if not os.path.isdir(model_dir):
+         os.makedirs(os.path.join(model_dir, "hub", "checkpoints"))
+     filename = os.path.basename(parts.path)
+     cached_file = os.path.join(model_dir, filename)
+     if not os.path.exists(cached_file):
+         sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
+         hash_prefix = None
+         download_url_to_file(url, cached_file, hash_prefix, progress=True)
+     return cached_file
+
+
+ def ceil_modulo(x, mod):
+     if x % mod == 0:
+         return x
+     return (x // mod + 1) * mod
+
+
+ def numpy_to_bytes(image_numpy: np.ndarray) -> bytes:
+     data = cv2.imencode(".jpg", image_numpy)[1]
+     image_bytes = data.tobytes()
+     return image_bytes
+
+
+ def load_img(img_bytes, gray: bool = False):
+     nparr = np.frombuffer(img_bytes, np.uint8)
+     if gray:
+         np_img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
+     else:
+         np_img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
+         if len(np_img.shape) == 3 and np_img.shape[2] == 4:
+             np_img = cv2.cvtColor(np_img, cv2.COLOR_BGRA2RGB)
+         else:
+             np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
+
+     return np_img
+
+
+ def norm_img(np_img):
+     if len(np_img.shape) == 2:
+         np_img = np_img[:, :, np.newaxis]
+     np_img = np.transpose(np_img, (2, 0, 1))
+     np_img = np_img.astype("float32") / 255
+     return np_img
+
+
+ def resize_max_size(
+     np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
+ ) -> np.ndarray:
+     # Resize image's longer side to size_limit if it is larger than size_limit
+     h, w = np_img.shape[:2]
+     if max(h, w) > size_limit:
+         ratio = size_limit / max(h, w)
+         new_w = int(w * ratio + 0.5)
+         new_h = int(h * ratio + 0.5)
+         return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
+     else:
+         return np_img
+
+
+ def pad_img_to_modulo(img, mod):
+     channels, height, width = img.shape
+     out_height = ceil_modulo(height, mod)
+     out_width = ceil_modulo(width, mod)
+     return np.pad(
+         img,
+         ((0, 0), (0, out_height - height), (0, out_width - width)),
+         mode="symmetric",
+     )
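As a quick sanity check of how these helpers compose into the channel-first, 8-pixel-padded input the LaMa model expects, a small sketch with a dummy array; the shapes are illustrative only.

import numpy as np
from src.helper import ceil_modulo, norm_img, pad_img_to_modulo, resize_max_size

img = np.zeros((501, 333, 3), dtype=np.uint8)   # dummy H x W x C image
img = resize_max_size(img, size_limit=256)      # longer side capped at 256 -> shape (256, 170, 3)
chw = norm_img(img)                             # float32 in [0, 1], shape (3, 256, 170)
padded = pad_img_to_modulo(chw, mod=8)          # width padded from 170 up to 176
assert padded.shape == (3, ceil_modulo(256, 8), ceil_modulo(170, 8))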
src/pipeline_stable_diffusion_controlnet_inpaint.py ADDED
@@ -0,0 +1,500 @@
+ import torch
+ import PIL.Image
+ import numpy as np
+
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import *
+
+ EXAMPLE_DOC_STRING = """
+     Examples:
+         ```py
+         >>> # !pip install opencv-python transformers accelerate
+         >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
+         >>> from diffusers.utils import load_image
+         >>> import numpy as np
+         >>> import torch
+         >>> import cv2
+         >>> from PIL import Image
+         >>> # download an image
+         >>> image = load_image(
+         ...     "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+         ... )
+         >>> image = np.array(image)
+         >>> mask_image = load_image(
+         ...     "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+         ... )
+         >>> mask_image = np.array(mask_image)
+         >>> # get canny image
+         >>> canny_image = cv2.Canny(image, 100, 200)
+         >>> canny_image = canny_image[:, :, None]
+         >>> canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
+         >>> canny_image = Image.fromarray(canny_image)
+         >>> # load control net and stable diffusion v1-5
+         >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+         >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+         ...     "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
+         ... )
+         >>> # speed up diffusion process with faster scheduler and memory optimization
+         >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+         >>> # remove following line if xformers is not installed
+         >>> pipe.enable_xformers_memory_efficient_attention()
+         >>> pipe.enable_model_cpu_offload()
+         >>> # generate image
+         >>> generator = torch.manual_seed(0)
+         >>> image = pipe(
+         ...     "futuristic-looking doggo",
+         ...     num_inference_steps=20,
+         ...     generator=generator,
+         ...     image=image,
+         ...     control_image=canny_image,
+         ...     mask_image=mask_image
+         ... ).images[0]
+         ```
+ """
+
+
+ def prepare_mask_and_masked_image(image, mask):
+     """
+     Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
+     converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+     ``image`` and ``1`` for the ``mask``.
+     The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+     binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+     Args:
+         image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+             It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+             ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+         mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
+             It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+             ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+     Raises:
+         ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
+             should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+         TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
+             (or the other way around).
+     Returns:
+         tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+             dimensions: ``batch x channels x height x width``.
+     """
+     if isinstance(image, torch.Tensor):
+         if not isinstance(mask, torch.Tensor):
+             raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
+
+         # Batch single image
+         if image.ndim == 3:
+             assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
+             image = image.unsqueeze(0)
+
+         # Batch and add channel dim for single mask
+         if mask.ndim == 2:
+             mask = mask.unsqueeze(0).unsqueeze(0)
+
+         # Batch single mask or add channel dim
+         if mask.ndim == 3:
+             # Single batched mask, no channel dim or single mask not batched but channel dim
+             if mask.shape[0] == 1:
+                 mask = mask.unsqueeze(0)
+
+             # Batched masks no channel dim
+             else:
+                 mask = mask.unsqueeze(1)
+
+         assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+         assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+         assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+
+         # Check image is in [-1, 1]
+         if image.min() < -1 or image.max() > 1:
+             raise ValueError("Image should be in [-1, 1] range")
+
+         # Check mask is in [0, 1]
+         if mask.min() < 0 or mask.max() > 1:
+             raise ValueError("Mask should be in [0, 1] range")
+
+         # Binarize mask
+         mask[mask < 0.5] = 0
+         mask[mask >= 0.5] = 1
+
+         # Image as float32
+         image = image.to(dtype=torch.float32)
+     elif isinstance(mask, torch.Tensor):
+         raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
+     else:
+         # preprocess image
+         if isinstance(image, (PIL.Image.Image, np.ndarray)):
+             image = [image]
+
+         if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+             image = [np.array(i.convert("RGB"))[None, :] for i in image]
+             image = np.concatenate(image, axis=0)
+         elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+             image = np.concatenate([i[None, :] for i in image], axis=0)
+
+         image = image.transpose(0, 3, 1, 2)
+         image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+         # preprocess mask
+         if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+             mask = [mask]
+
+         if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+             mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+             mask = mask.astype(np.float32) / 255.0
+         elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+             mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+         mask[mask < 0.5] = 0
+         mask[mask >= 0.5] = 1
+         mask = torch.from_numpy(mask)
+
+     masked_image = image * (mask < 0.5)
+
+     return mask, masked_image
+
+ class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline):
+     r"""
+     Pipeline for text-guided image inpainting using Stable Diffusion with ControlNet guidance.
+     This model inherits from [`StableDiffusionControlNetPipeline`]. Check the superclass documentation for the generic methods the
+     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+     Args:
+         vae ([`AutoencoderKL`]):
+             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+         text_encoder ([`CLIPTextModel`]):
+             Frozen text-encoder. Stable Diffusion uses the text portion of
+             [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+             the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+         tokenizer (`CLIPTokenizer`):
+             Tokenizer of class
+             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+         unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+         controlnet ([`ControlNetModel`]):
+             Provides additional conditioning to the unet during the denoising process
+         scheduler ([`SchedulerMixin`]):
+             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+             [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+         safety_checker ([`StableDiffusionSafetyChecker`]):
+             Classification module that estimates whether generated images could be considered offensive or harmful.
+             Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+         feature_extractor ([`CLIPFeatureExtractor`]):
+             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+     """
+
+     def prepare_mask_latents(
+         self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+     ):
+         # resize the mask to latents shape as we concatenate the mask to the latents
+         # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+         # and half precision
+         mask = torch.nn.functional.interpolate(
+             mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+         )
+         mask = mask.to(device=device, dtype=dtype)
+
+         masked_image = masked_image.to(device=device, dtype=dtype)
+
+         # encode the mask image into latents space so we can concatenate it to the latents
+         if isinstance(generator, list):
+             masked_image_latents = [
+                 self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
+                 for i in range(batch_size)
+             ]
+             masked_image_latents = torch.cat(masked_image_latents, dim=0)
+         else:
+             masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
+         masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
+
+         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+         if mask.shape[0] < batch_size:
+             if not batch_size % mask.shape[0] == 0:
+                 raise ValueError(
+                     "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+                     f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+                     " of masks that you pass is divisible by the total requested batch size."
+                 )
+             mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+         if masked_image_latents.shape[0] < batch_size:
+             if not batch_size % masked_image_latents.shape[0] == 0:
+                 raise ValueError(
+                     "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+                     f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+                     " Make sure the number of images that you pass is divisible by the total requested batch size."
+                 )
+             masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+         mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+         masked_image_latents = (
+             torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+         )
+
+         # aligning device to prevent device errors when concating it with the latent model input
+         masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+         return mask, masked_image_latents
+
+     @torch.no_grad()
+     @replace_example_docstring(EXAMPLE_DOC_STRING)
+     def __call__(
+         self,
+         prompt: Union[str, List[str]] = None,
+         image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+         control_image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
+         mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+         height: Optional[int] = None,
+         width: Optional[int] = None,
+         num_inference_steps: int = 50,
+         guidance_scale: float = 7.5,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         num_images_per_prompt: Optional[int] = 1,
+         eta: float = 0.0,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.FloatTensor] = None,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+         callback_steps: int = 1,
+         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+         controlnet_conditioning_scale: float = 1.0,
+     ):
+         r"""
+         Function invoked when calling the pipeline for generation.
+         Args:
+             prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                 instead.
+             image (`PIL.Image.Image`):
+                 `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+                 be masked out with `mask_image` and repainted according to `prompt`.
+             control_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
+                 The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
+                 the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
+                 also be accepted as an image. The control image is automatically resized to fit the output image.
+             mask_image (`PIL.Image.Image`):
+                 `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+                 repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+                 to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+                 instead of 3, so the expected shape would be `(B, H, W, 1)`.
+             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                 The height in pixels of the generated image.
+             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                 The width in pixels of the generated image.
+             num_inference_steps (`int`, *optional*, defaults to 50):
+                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                 expense of slower inference.
+             guidance_scale (`float`, *optional*, defaults to 7.5):
+                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                 usually at the expense of lower image quality.
+             negative_prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                 `negative_prompt_embeds` instead.
+                 Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+             num_images_per_prompt (`int`, *optional*, defaults to 1):
+                 The number of images to generate per prompt.
+             eta (`float`, *optional*, defaults to 0.0):
+                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                 [`schedulers.DDIMScheduler`], will be ignored for others.
+             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                 to make generation deterministic.
+             latents (`torch.FloatTensor`, *optional*):
+                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                 tensor will be generated by sampling using the supplied random `generator`.
+             prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                 provided, text embeddings will be generated from `prompt` input argument.
+             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                 argument.
+             output_type (`str`, *optional*, defaults to `"pil"`):
+                 The output format of the generated image. Choose between
+                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                 plain tuple.
+             callback (`Callable`, *optional*):
+                 A function that will be called every `callback_steps` steps during inference. The function will be
+                 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+             callback_steps (`int`, *optional*, defaults to 1):
+                 The frequency at which the `callback` function will be called. If not specified, the callback will be
+                 called at every step.
+             cross_attention_kwargs (`dict`, *optional*):
+                 A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+                 `self.processor` in
+                 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+             controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
+                 The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
+                 to the residual in the original unet.
+         Examples:
+         Returns:
+             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+             When returning a tuple, the first element is a list with the generated images, and the second element is a
+             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+             (nsfw) content, according to the `safety_checker`.
+         """
+         # 0. Default height and width to unet
+         height, width = self._default_height_width(height, width, control_image)
+
+         # 1. Check inputs. Raise error if not correct
+         self.check_inputs(
+             prompt, control_image, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+         )
+
+         # 2. Define call parameters
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         device = self._execution_device
+         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+         # corresponds to doing no classifier free guidance.
+         do_classifier_free_guidance = guidance_scale > 1.0
+
+         # 3. Encode input prompt
+         prompt_embeds = self._encode_prompt(
+             prompt,
+             device,
+             num_images_per_prompt,
+             do_classifier_free_guidance,
+             negative_prompt,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+         )
+
+         # 4. Prepare image
+         control_image = self.prepare_image(
+             control_image,
+             width,
+             height,
+             batch_size * num_images_per_prompt,
+             num_images_per_prompt,
+             device,
+             self.controlnet.dtype,
+         )
+
+         if do_classifier_free_guidance:
+             control_image = torch.cat([control_image] * 2)
+
+         # 5. Prepare timesteps
+         self.scheduler.set_timesteps(num_inference_steps, device=device)
+         timesteps = self.scheduler.timesteps
+
+         # 6. Prepare latent variables
+         num_channels_latents = self.controlnet.in_channels
+         latents = self.prepare_latents(
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             latents,
+         )
+
+         # EXTRA: prepare mask latents
+         mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
+         mask, masked_image_latents = self.prepare_mask_latents(
+             mask,
+             masked_image,
+             batch_size * num_images_per_prompt,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             do_classifier_free_guidance,
+         )
+
+         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+         # 8. Denoising loop
+         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 # expand the latents if we are doing classifier free guidance
+                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                 down_block_res_samples, mid_block_res_sample = self.controlnet(
+                     latent_model_input,
+                     t,
+                     encoder_hidden_states=prompt_embeds,
+                     controlnet_cond=control_image,
+                     return_dict=False,
+                 )
+
+                 down_block_res_samples = [
+                     down_block_res_sample * controlnet_conditioning_scale
+                     for down_block_res_sample in down_block_res_samples
+                 ]
+                 mid_block_res_sample *= controlnet_conditioning_scale
+
+                 # predict the noise residual
+                 latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+                 noise_pred = self.unet(
+                     latent_model_input,
+                     t,
+                     encoder_hidden_states=prompt_embeds,
+                     cross_attention_kwargs=cross_attention_kwargs,
+                     down_block_additional_residuals=down_block_res_samples,
+                     mid_block_additional_residual=mid_block_res_sample,
+                 ).sample
+
+                 # perform guidance
+                 if do_classifier_free_guidance:
+                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                 # compute the previous noisy sample x_t -> x_t-1
+                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                 # call the callback, if provided
+                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                     progress_bar.update()
+                     if callback is not None and i % callback_steps == 0:
+                         callback(i, t, latents)
+
+         # If we do sequential model offloading, let's offload unet and controlnet
+         # manually for max memory savings
+         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+             self.unet.to("cpu")
+             self.controlnet.to("cpu")
+             torch.cuda.empty_cache()
+
+         if output_type == "latent":
+             image = latents
+             has_nsfw_concept = None
+         elif output_type == "pil":
+             # 8. Post-processing
+             image = self.decode_latents(latents)
+
+             # 9. Run safety checker
+             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+             # 10. Convert to PIL
+             image = self.numpy_to_pil(image)
+         else:
+             # 8. Post-processing
+             image = self.decode_latents(latents)
+
+             # 9. Run safety checker
+             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+         # Offload last model to CPU
+         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+             self.final_offload_hook.offload()
+
+         if not return_dict:
+             return (image, has_nsfw_concept)
+
+         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
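The docstring example above imports the pipeline class from `diffusers`; when using the class defined in this file, only the import changes. A sketch under that assumption, with the model IDs and images taken from the docstring example:

import cv2
import numpy as np
import torch
from PIL import Image
from diffusers import ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
from src.pipeline_stable_diffusion_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline

# Same demo images as in the docstring example.
image = load_image("https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png")
mask_image = load_image("https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png")
canny = cv2.Canny(np.array(image), 100, 200)
canny_image = Image.fromarray(np.stack([canny] * 3, axis=2))

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

# image / mask_image feed the inpainting branch; control_image feeds the ControlNet branch.
result = pipe(
    "futuristic-looking doggo",
    image=image,
    mask_image=mask_image,
    control_image=canny_image,
    num_inference_steps=20,
    generator=torch.manual_seed(0),
).images[0]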
src/st_style.py ADDED
@@ -0,0 +1,42 @@
+ button_style = """
+ <style>
+ div.stButton > button:first-child {
+     background-color: rgb(255, 75, 75);
+     color: rgb(255, 255, 255);
+ }
+ div.stButton > button:hover {
+     background-color: rgb(255, 75, 75);
+     color: rgb(255, 255, 255);
+ }
+ div.stButton > button:active {
+     background-color: rgb(255, 75, 75);
+     color: rgb(255, 255, 255);
+ }
+ div.stButton > button:focus {
+     background-color: rgb(255, 75, 75);
+     color: rgb(255, 255, 255);
+ }
+ .css-1cpxqw2:focus:not(:active) {
+     background-color: rgb(255, 75, 75);
+     border-color: rgb(255, 75, 75);
+     color: rgb(255, 255, 255);
+ }
+ </style>
+ """
+
+ style = """
+ <style>
+ #MainMenu {
+     visibility: hidden;
+ }
+ footer {
+     visibility: hidden;
+ }
+ header {
+     visibility: hidden;
+ }
+ </style>
+ """
+
+
+ def apply_prod_style(st):
+     return st.markdown(style, unsafe_allow_html=True)
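A minimal sketch of how these snippets are typically injected into a Streamlit app; the button label is a placeholder.

import streamlit as st
from src.st_style import apply_prod_style, button_style

apply_prod_style(st)                               # hides Streamlit's menu, header and footer
st.markdown(button_style, unsafe_allow_html=True)  # applies the red button styling

if st.button("Run"):
    st.write("Button styled by st_style.py")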