Geonmo commited on
Commit
58554eb
·
1 Parent(s): 2b48bba

add app.py

Browse files
Files changed (2) hide show
  1. app.py +804 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,804 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Graphit
3
+ Copyright (c) 2023-present NAVER Corp.
4
+ Apache-2.0
5
+ """
6
+ import os
7
+ import numpy as np
8
+ import base64
9
+ import requests
10
+ from io import BytesIO
11
+ import json
12
+ import time
13
+ import math
14
+ import argparse
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ import gradio as gr
19
+
20
+ import types
21
+ from typing import Union, List, Optional, Callable
22
+ import diffusers
23
+ import torch
24
+ from diffusers import AutoencoderKL, UNet2DConditionModel
25
+ from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler
26
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
27
+ from diffusers.models import AutoencoderKL
28
+ from transformers import CLIPTextModel
29
+
30
+ import datasets
31
+
32
+ from torchvision import transforms
33
+ from torchvision.transforms.functional import to_pil_image, pil_to_tensor
34
+
35
+ import PIL
36
+ from PIL import Image, ImageOps
37
+
38
+ import compodiff
39
+ from transformers import DPTFeatureExtractor, DPTForDepthEstimation
40
+ from transparent_background import Remover
41
+ from huggingface_hub import hf_hub_url, cached_download
42
+ from RealESRGAN import RealESRGAN
43
+ import einops
44
+ import cv2
45
+ from skimage import segmentation, color, graph
46
+ import random
47
+
48
+
49
+ def preprocess(image, mode):
50
+ image = np.array(image)[None, :].astype(np.float32) / 255.0
51
+ image = image
52
+ image = image.transpose(0, 3, 1, 2)
53
+ image = 2.0 * image - 1.0
54
+ if mode == 'scr2i':
55
+ image[image > 0.0] = 0.0
56
+ image = torch.from_numpy(image)
57
+ return image
58
+
59
+
60
+ class GraphitPipeline(StableDiffusionInstructPix2PixPipeline):
61
+ '''
62
+ override:
63
+ /opt/conda/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
64
+ '''
65
+ def prepare_image_latents(
66
+ self, image, mask, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None
67
+ ):
68
+ if not isinstance(image, (torch.Tensor, Image.Image, list)):
69
+ raise ValueError(
70
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
71
+ )
72
+
73
+ image = image.to(device=device, dtype=dtype)
74
+ mask = mask.to(device=device, dtype=dtype)
75
+
76
+ batch_size = batch_size * num_images_per_prompt
77
+ if isinstance(generator, list) and len(generator) != batch_size:
78
+ raise ValueError(
79
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
80
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
81
+ )
82
+
83
+ if isinstance(generator, list):
84
+ image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
85
+ image_latents = torch.cat(image_latents, dim=0)
86
+ else:
87
+ image_latents = self.vae.encode(image).latent_dist.mode()
88
+
89
+ mask = torch.nn.functional.interpolate(
90
+ mask, #.unsqueeze(0).unsqueeze(0),
91
+ size=(image_latents.shape[-2], image_latents.shape[-1]),
92
+ mode='bicubic',
93
+ align_corners=False,
94
+ )
95
+
96
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
97
+ # expand image_latents for batch_size
98
+ deprecation_message = (
99
+ f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
100
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
101
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
102
+ " your script to pass as many initial images as text prompts to suppress this warning."
103
+ )
104
+ #deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
105
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
106
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
107
+ mask = torch.cat([mask] * additional_image_per_prompt, dim=0)
108
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
109
+ raise ValueError(
110
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
111
+ )
112
+ else:
113
+ image_latents = torch.cat([image_latents], dim=0)
114
+ image_latents *= 0.18215
115
+ if do_classifier_free_guidance:
116
+ uncond_image_latents = torch.zeros_like(image_latents)
117
+ image_latents = torch.cat([image_latents, image_latents], dim=0)
118
+ mask = torch.cat([mask, mask], dim=0)
119
+ image_latents = torch.cat([image_latents, mask], dim=1)
120
+
121
+ return image_latents
122
+
123
+ @torch.no_grad()
124
+ def __call__(
125
+ self,
126
+ prompt: Union[str, List[str]] = None,
127
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
128
+ mask: Union[torch.FloatTensor, PIL.Image.Image] = None,
129
+ depth_map: Union[torch.FloatTensor, PIL.Image.Image] = None,
130
+ num_inference_steps: int = 100,
131
+ guidance_scale: float = 3.5,
132
+ use_depth_map_as_input: bool = False,
133
+ apply_mask_to_input: bool = True,
134
+ mode: str = None,
135
+ negative_prompt: Optional[Union[str, List[str]]] = None,
136
+ num_images_per_prompt: Optional[int] = 1,
137
+ eta: float = 0.0,
138
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
139
+ latents: Optional[torch.FloatTensor] = None,
140
+ image_cond_embeds: Optional[torch.FloatTensor] = None,
141
+ negative_image_cond_embeds: Optional[torch.FloatTensor] = None,
142
+ output_type: Optional[str] = "pil",
143
+ return_dict: bool = True,
144
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
145
+ callback_steps: Optional[int] = 1,
146
+ ):
147
+ # 0. Check inputs
148
+ self.check_inputs(prompt, callback_steps)
149
+
150
+ if image is None:
151
+ raise ValueError("`image` input cannot be undefined.")
152
+
153
+ # 1. Define call parameters
154
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
155
+ device = self._execution_device
156
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
157
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
158
+ # corresponds to doing no classifier free guidance.
159
+ do_classifier_free_guidance = True#guidance_scale >= 1.0 and image_guidance_scale >= 1.0
160
+ # check if scheduler is in sigmas space
161
+ scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
162
+
163
+ # 2. Encode input prompt
164
+ cond_embeds = torch.cat([image_cond_embeds, negative_image_cond_embeds])
165
+ cond_embeds = einops.repeat(cond_embeds, 'b n d -> (b num) n d', num=num_images_per_prompt).to(torch.float16)
166
+ prompt_embeds = cond_embeds
167
+
168
+ # 3. Preprocess image
169
+ image = preprocess(image, mode)
170
+
171
+ if len(mask.shape) > 2:
172
+ edge_map = mask[:,:,1:]
173
+ edge_map = preprocess(edge_map, mode)
174
+ mask = mask[:,:,0]
175
+ else:
176
+ edge_map = None
177
+ mask = mask.unsqueeze(0).unsqueeze(0)
178
+ if torch.sum(mask).item() == 0.0 and use_depth_map_as_input:
179
+ image = depth_map
180
+ if edge_map is None:
181
+ if apply_mask_to_input:
182
+ image = image * (1 - mask)
183
+ else:
184
+ image = image * (1 - mask) + edge_map * mask
185
+ height, width = image.shape[-2:]
186
+
187
+ # 4. set timesteps
188
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
189
+ timesteps = self.scheduler.timesteps
190
+
191
+ # 5. Prepare Image latents
192
+ image_latents = self.prepare_image_latents(
193
+ image,
194
+ mask,
195
+ batch_size,
196
+ num_images_per_prompt,
197
+ prompt_embeds.dtype,
198
+ device,
199
+ do_classifier_free_guidance,
200
+ generator,
201
+ )
202
+
203
+ if mode == 't2i':
204
+ image_latents = torch.zeros_like(image_latents)
205
+
206
+ # 6. Prepare latent variables
207
+ num_channels_latents = self.vae.config.latent_channels
208
+ latents = self.prepare_latents(
209
+ batch_size * num_images_per_prompt,
210
+ num_channels_latents,
211
+ height,
212
+ width,
213
+ prompt_embeds.dtype,
214
+ device,
215
+ generator,
216
+ latents,
217
+ )
218
+
219
+ # 7. Check that shapes of latents and image match the UNet channels
220
+ num_channels_image = image_latents.shape[1]
221
+ if num_channels_latents + num_channels_image != self.unet.config.in_channels:
222
+ raise ValueError(
223
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
224
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
225
+ f" `num_channels_image`: {num_channels_image} "
226
+ f" = {num_channels_latents+num_channels_image}. Please verify the config of"
227
+ " `pipeline.unet` or your `image` input."
228
+ )
229
+
230
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
231
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
232
+
233
+ # 9. Denoising loop
234
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
235
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
236
+ for i, t in enumerate(timesteps):
237
+ # Expand the latents if we are doing classifier free guidance.
238
+ # The latents are expanded 3 times because for pix2pix the guidance\
239
+ # is applied for both the text and the input image.
240
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
241
+
242
+ # concat latents, image_latents in the channel dimension
243
+ scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
244
+ scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1)
245
+
246
+ # predict the noise residual
247
+ noise_pred = self.unet(scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
248
+
249
+ # Hack:
250
+ # For karras style schedulers the model does classifer free guidance using the
251
+ # predicted_original_sample instead of the noise_pred. So we need to compute the
252
+ # predicted_original_sample here if we are using a karras style scheduler.
253
+ if scheduler_is_in_sigma_space:
254
+ step_index = (self.scheduler.timesteps == t).nonzero().item()
255
+ sigma = self.scheduler.sigmas[step_index]
256
+ noise_pred = latent_model_input - sigma * noise_pred
257
+
258
+ # perform guidance
259
+ if do_classifier_free_guidance:
260
+ noise_pred_full, noise_pred_uncond = noise_pred.chunk(2)
261
+ noise_pred = (
262
+ noise_pred_uncond
263
+ + guidance_scale * (noise_pred_full - noise_pred_uncond)
264
+ )
265
+
266
+ # Hack:
267
+ # For karras style schedulers the model does classifer free guidance using the
268
+ # predicted_original_sample instead of the noise_pred. But the scheduler.step function
269
+ # expects the noise_pred and computes the predicted_original_sample internally. So we
270
+ # need to overwrite the noise_pred here such that the value of the computed
271
+ # predicted_original_sample is correct.
272
+ if scheduler_is_in_sigma_space:
273
+ noise_pred = (noise_pred - latents) / (-sigma)
274
+
275
+ # compute the previous noisy sample x_t -> x_t-1
276
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
277
+
278
+ # call the callback, if provided
279
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
280
+ progress_bar.update()
281
+ if callback is not None and i % callback_steps == 0:
282
+ callback(i, t, latents)
283
+
284
+ # 10. Post-processing
285
+ image = self.decode_latents(latents)
286
+
287
+ # 11. Run safety checker
288
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
289
+
290
+ # 12. Convert to PIL
291
+ if output_type == "pil":
292
+ image = self.numpy_to_pil(image)
293
+
294
+ if not return_dict:
295
+ return (image, has_nsfw_concept)
296
+
297
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
298
+
299
+
300
+ class CustomRealESRGAN(RealESRGAN):
301
+ @torch.no_grad()
302
+ @torch.cuda.amp.autocast()
303
+ def predict(self, pil_lr_image_list):
304
+ device = self.device
305
+ # batchfy
306
+ batch_lr_images = (torch.stack([pil_to_tensor(pil_lr_image) for pil_lr_image in pil_lr_image_list]).float() / 255).to(device)
307
+ batch_outputs = self.model(batch_lr_images).clamp_(0, 1)
308
+
309
+ # to pil images
310
+ return [to_pil_image(output) for output in batch_outputs]
311
+
312
+
313
+ def build_models(args):
314
+ # Load scheduler, tokenizer and models.
315
+
316
+ model_path = 'navervision/Graphit-SD'
317
+ unet = UNet2DConditionModel.from_pretrained(
318
+ model_path, torch_dtype=torch.float16,
319
+ )
320
+
321
+ vae_name = 'stabilityai/sd-vae-ft-ema'
322
+ vae = AutoencoderKL.from_pretrained(vae_name, torch_dtype=torch.float16)
323
+
324
+ model_name = 'timbrooks/instruct-pix2pix'
325
+ pipe = GraphitPipeline.from_pretrained(model_name, torch_dtype=torch.float16, safety_checker=None,
326
+ unet = unet,
327
+ vae = vae,
328
+ )
329
+ pipe = pipe.to('cuda:0')
330
+
331
+ ## load CompoDiff
332
+ compodiff_model, clip_model, clip_preprocess, clip_tokenizer = compodiff.build_model()
333
+ compodiff_model, clip_model = compodiff_model.to('cuda:0'), clip_model.to('cuda:0')
334
+
335
+ ## load third-party models
336
+ model_name = 'Intel/dpt-large'
337
+ depth_preprocess = DPTFeatureExtractor.from_pretrained(model_name)
338
+ depth_predictor = DPTForDepthEstimation.from_pretrained(model_name, torch_dtype=torch.float16)
339
+ depth_predictor = depth_predictor.to('cuda:0')
340
+
341
+ if not os.path.exists('./third_party/remover_fast.pth'):
342
+ model_file_url = hf_hub_url(repo_id='Geonmo/remover_fast', filename='remover_fast.pth')
343
+ cached_download(model_file_url, cache_dir='./third_party', force_filename='remover_fast.pth')
344
+ remover = Remover(fast=True, jit=False, device='cuda:0', ckpt='./third_party/remover_fast.pth')
345
+
346
+ sr_model = CustomRealESRGAN('cuda:0', scale=2)
347
+ sr_model.load_weights('./third_party/RealESRGAN_x2.pth', download=True)
348
+
349
+ dataset = datasets.load_dataset("FredZhang7/stable-diffusion-prompts-2.47M")
350
+
351
+ train = dataset["train"]
352
+ prompts = train["text"]
353
+
354
+ model_dict = {'pipe': pipe,
355
+ 'compodiff': compodiff_model,
356
+ 'clip_preprocess': clip_preprocess,
357
+ 'clip_tokenizer': clip_tokenizer,
358
+ 'clip_model': clip_model,
359
+ 'depth_preprocess': depth_preprocess,
360
+ 'depth_predictor': depth_predictor,
361
+ 'remover': remover,
362
+ 'sr_model': sr_model,
363
+ 'prompt_candidates': prompts,
364
+ }
365
+ return model_dict
366
+
367
+
368
+ def predict_compodiff(image, text_input, negative_text, cfg_image_scale, cfg_text_scale, mask, random_seed):
369
+ text_token_dict = model_dict['clip_tokenizer'](text=text_input, return_tensors='pt', padding='max_length', truncation=True)
370
+ text_tokens, text_attention_mask = text_token_dict['input_ids'].to('cuda:0'), text_token_dict['attention_mask'].to('cuda:0')
371
+
372
+ negative_text_token_dict = model_dict['clip_tokenizer'](text=negative_text, return_tensors='pt', padding='max_length', truncation=True)
373
+ negative_text_tokens, negative_text_attention_mask = negative_text_token_dict['input_ids'].to('cuda:0'), text_token_dict['attention_mask'].to('cuda:0')
374
+
375
+ with torch.no_grad():
376
+ if image is None:
377
+ image_cond = torch.zeros([1,1,768]).to('cuda:0')
378
+ mask = torch.tensor(np.zeros([64, 64], dtype='float32')).to('cuda:0').unsqueeze(0)
379
+ else:
380
+ image_source = image.resize((512, 512))
381
+ image_source = model_dict['clip_preprocess'](image_source, return_tensors='pt')['pixel_values'].to('cuda:0')
382
+ mask = mask.resize((512, 512))
383
+ mask = model_dict['clip_preprocess'](mask, do_normalize=False, return_tensors='pt')['pixel_values']
384
+ mask = mask[:,:1,:,:]
385
+ mask = (mask > 0.5).float().to('cuda:0')
386
+ image_source = image_source * (1 - mask)
387
+ image_cond = model_dict['clip_model'].encode_images(image_source)
388
+ mask = transforms.Resize([64, 64])(mask)[:,0,:,:]
389
+ mask = (mask > 0.5).float()
390
+
391
+ text_cond = model_dict['clip_model'].encode_texts(text_tokens, text_attention_mask)
392
+ negative_text_cond = model_dict['clip_model'].encode_texts(negative_text_tokens, negative_text_attention_mask)
393
+
394
+ sampled_image_features = model_dict['compodiff'].sample(image_cond, text_cond, negative_text_cond, mask, timesteps=25, cond_scale=(1.0 if image is None else 1.3, cfg_text_scale), num_samples_per_batch=4, random_seed=random_seed).unsqueeze(1)
395
+ return sampled_image_features, image_cond
396
+
397
+
398
+ def generate_depth_map(image, height, width):
399
+ depth_inputs = {k: v.to('cuda:0', dtype=torch.float16) for k, v in model_dict['depth_preprocess'](images=image, return_tensors='pt').items()}
400
+ depth_map = model_dict['depth_predictor'](**depth_inputs).predicted_depth.unsqueeze(1)
401
+ depth_min = torch.amin(depth_map, dim=[1,2,3], keepdim=True)
402
+ depth_max = torch.amax(depth_map, dim=[1,2,3], keepdim=True)
403
+ depth_map = 2.0 * ((depth_map - depth_min) / (depth_max - depth_min)) - 1.0
404
+ depth_map = torch.nn.functional.interpolate(
405
+ depth_map,
406
+ size=(height, width),
407
+ mode='bicubic',
408
+ align_corners=False,
409
+ )
410
+ return depth_map
411
+
412
+
413
+ def generate_color(image, compactness=30, n_segments=100, thresh=35, blur_kernel=3, blur_std=0):
414
+ img = image # 0 ~ 255 uint8
415
+ labels = segmentation.slic(img, compactness=compactness, n_segments=n_segments)#, start_label=1)
416
+ g = graph.rag_mean_color(img, labels)
417
+ labels2 = graph.cut_threshold(labels, g, thresh=thresh)
418
+ out = color.label2rgb(labels2, img, kind='avg', bg_label=-1)
419
+ return out
420
+
421
+
422
+ @torch.no_grad()
423
+ def generate(image_source, image_reference, text_input, negative_prompt, steps, random_seed, cfg_image_scale, cfg_text_scale, cfg_image_space_scale, cfg_image_reference_mix_weight, cfg_image_source_mix_weight, mask_scale, use_edge, t2i_height, t2i_width, do_sr, mode):
424
+ text_input = text_input.lower()
425
+ if negative_prompt == '':
426
+ print('running without a negative prompt')
427
+ # prepare an input image
428
+ use_mask = False
429
+ mask = None
430
+ is_null_image_source = False
431
+ if type(image_source) == dict:
432
+ image_source, mask = image_source['image'], image_source['mask']
433
+ elif image_source is None:
434
+ image_source = Image.fromarray(np.zeros([t2i_height, t2i_width, 3]).astype('uint8'))
435
+ is_null_image_source = True
436
+
437
+ try:
438
+ image_source = ImageOps.exif_transpose(image_source)
439
+ except:
440
+ pass
441
+
442
+ width, height = image_source.size
443
+ factor = 512 / max(width, height)
444
+ factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
445
+ width = int((width * factor) // 64) * 64
446
+ height = int((height * factor) // 64) * 64
447
+
448
+ image_source = org_image_source = ImageOps.fit(image_source, (width, height), method=Image.Resampling.LANCZOS)
449
+
450
+ if mask is not None:
451
+ mask_pil = mask = ImageOps.fit(mask, (width, height), method=Image.Resampling.LANCZOS)
452
+ mask = ((torch.tensor(np.array(mask.convert('L'))).float() / 255.0) > 0.5).float()
453
+ if torch.sum(mask).item() > 0.0:
454
+ print('now using mask')
455
+ use_mask = True
456
+ else:
457
+ mask = torch.zeros([height, width])
458
+ mask_pil = to_pil_image(mask)
459
+
460
+ use_depth_map_as_input = False
461
+ if mode == 's2i' or mode == 'scr2i': # sketch to image
462
+ image_source = mask
463
+ image_source = einops.repeat(image_source, 'h w -> r h w', r=3)
464
+ mask = image_source[0,:,:]
465
+ image_source = org_image_source = to_pil_image(image_source)
466
+ mask_pil = to_pil_image(mask)
467
+ mask *= mask_scale
468
+ use_mask = False
469
+ elif mode == 'cs2i':
470
+ mask = torch.tensor((np.array(image_source)[:,:,0] != 255)).float() * mask_scale
471
+ mask_pil = Image.fromarray(((np.array(image_source)[:,:,0] != 255) * 255).astype('uint8'))
472
+ use_mask = False #True
473
+ elif mode == 'd2i': # depth to image
474
+ use_depth_map_as_input = True
475
+ elif mode == 'e2i': # edge to image
476
+ image_source = einops.repeat(cv2.Canny(cv2.cvtColor(np.array(image_source)[:,:,::-1], cv2.COLOR_BGR2GRAY), threshold1=100, threshold2=200), 'h w -> h w r', r=3)
477
+ image_source = Image.fromarray(image_source) #to_pil_image(image_source)
478
+ org_image_source = image_source
479
+ elif mode == 'inped':
480
+ # mask = torch.Size([512, 512])
481
+ mask_np = (einops.repeat(mask.numpy(), 'h w -> h w r', r=1) * 255).astype('uint8')
482
+ gray = mask_np #cv2.cvtColor(mask_np, cv2.COLOR_BGR2GRAY)
483
+ _, thresh = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
484
+ contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
485
+ x, y, w, h = cv2.boundingRect(contours[0])
486
+ cv2.rectangle(mask_np, (x, y), (x+w, y+h), 255, -1)
487
+ mask_np = mask_np.astype('float32') / 255
488
+ if image_reference is not None:
489
+ edge_reference = image_reference.resize((w, h))
490
+ color_map = generate_color(np.array(edge_reference)).astype('float32')
491
+ reference_map = (model_dict['remover'].process(edge_reference, type='map') > 16).astype('float32')
492
+ edge_reference = einops.repeat(cv2.Canny(cv2.cvtColor(np.array(edge_reference)[:,:,::-1], cv2.COLOR_BGR2GRAY), threshold1=100, threshold2=200), 'h w -> h w r', r=3).astype('float32')
493
+ edge_np = np.zeros_like(np.array(image_source)).astype('float32')
494
+ if text_input != '':
495
+ edge_np[y:y+h,x:x+w] = edge_reference * reference_map
496
+ elif use_edge and mask_scale > 0.0:
497
+ print('mode: color inped with with_edge')
498
+ edge_np[y:y+h,x:x+w] = (255 - edge_reference) / 255 * color_map * reference_map + (1 - mask_scale) * edge_reference / 255 * reference_map
499
+ else:
500
+ print('mode: color inped with no_edge')
501
+ edge_np[y:y+h,x:x+w] = color_map * reference_map
502
+ mask_np = np.zeros_like(np.array(image_source)).astype('float32')
503
+ mask_np[y:y+h,x:x+w] = reference_map #edge_reference
504
+ mask_np = mask_np[:,:,:1]
505
+ else:
506
+ edge_np = einops.repeat(cv2.Canny(cv2.cvtColor(np.array(image_source)[:,:,::-1], cv2.COLOR_BGR2GRAY), threshold1=100, threshold2=200), 'h w -> h w r', r=3).astype('float32')
507
+ # concat edge to mask_np
508
+ mask = torch.tensor(np.concatenate([mask_np, edge_np], axis=-1))
509
+ mask_pil = to_pil_image(mask_np[:,:,0].astype('uint8') * 255)
510
+ #mask_pil = to_pil_image((mask_np[:,:,0] * 255).astype('uint8'))
511
+
512
+ with torch.no_grad():
513
+ # do reference first
514
+ if image_reference is not None:
515
+ image_cond_reference = ImageOps.exif_transpose(image_reference)
516
+ image_cond_reference = model_dict['clip_preprocess'](image_cond_reference, return_tensors='pt')['pixel_values'].to('cuda:0')
517
+ image_cond_reference = model_dict['clip_model'].encode_images(image_cond_reference)
518
+ else:
519
+ image_cond_reference = torch.zeros([1, 1, 768]).to(torch.float16).to('cuda:0')
520
+
521
+ # do source or knn
522
+ image_cond_source = None
523
+ if text_input != '':
524
+ if mode in ['t2i', 'd2i', 'e2i', 's2i', 'scr2i', 'cs2i']:
525
+ if mode == 'cs2i':
526
+ image_cond, image_cond_source = predict_compodiff(None, text_input, negative_prompt, cfg_image_scale, cfg_text_scale, mask=mask_pil, random_seed=random_seed)
527
+ image_cond_color_compensation, _ = predict_compodiff(image_source, text_input, negative_prompt, cfg_image_scale, cfg_text_scale, mask=mask_pil, random_seed=random_seed)
528
+ image_cond = 0.9 * image_cond + 0.1 * image_cond_color_compensation
529
+ else:
530
+ image_cond, image_cond_source = predict_compodiff(None, text_input, negative_prompt, cfg_image_scale, cfg_text_scale, mask=mask_pil, random_seed=random_seed)
531
+ else:
532
+ image_cond, image_cond_source = predict_compodiff(image_source, text_input, negative_prompt, cfg_image_scale, cfg_text_scale, mask=mask_pil, random_seed=random_seed)
533
+ image_cond = image_cond.to(torch.float16).to('cuda:0')
534
+ image_cond_source = image_cond_source.to(torch.float16).to('cuda:0')
535
+ else:
536
+ image_cond = torch.zeros([1, 1, 768]).to(torch.float16).to('cuda:0')
537
+
538
+ if image_cond_source is None and mode != 't2i':
539
+ image_cond_source = image_source.resize((512, 512))
540
+ image_cond_source = model_dict['clip_preprocess'](image_cond_source, return_tensors='pt')['pixel_values'].to('cuda:0')
541
+ image_cond_source = model_dict['clip_model'].encode_images(image_cond_source)
542
+
543
+ if cfg_image_reference_mix_weight > 0.0 and torch.sum(image_cond_reference).item() != 0.0:
544
+ if torch.sum(image_cond).item() == 0.0:
545
+ image_cond = image_cond_reference
546
+ else:
547
+ image_cond = (1.0 - cfg_image_reference_mix_weight) * image_cond + cfg_image_reference_mix_weight * image_cond_reference
548
+
549
+ if cfg_image_source_mix_weight > 0.0:
550
+ image_cond = (1.0 - cfg_image_source_mix_weight) * image_cond + cfg_image_source_mix_weight * image_cond_source
551
+
552
+ if negative_prompt != '':
553
+ negative_image_cond, _ = predict_compodiff(None, negative_prompt, '', cfg_image_scale, cfg_text_scale, mask=mask_pil, random_seed=random_seed)
554
+ negative_image_cond = negative_image_cond.to(torch.float16).to('cuda:0')
555
+ else:
556
+ negative_image_cond = torch.zeros_like(image_cond)
557
+
558
+ # negative_prompt_embeds
559
+ image_source = torch.tensor(np.array(image_source))
560
+ depth_map = einops.repeat(generate_depth_map(image_source, height, width), 'n c h w -> n (c r) h w', r=3).float().cpu()
561
+
562
+ images = model_dict['pipe'](text_input,
563
+ image=image_source,
564
+ mask=mask,
565
+ depth_map=depth_map,
566
+ num_inference_steps=int(steps),
567
+ image_cond_embeds=image_cond,
568
+ negative_image_cond_embeds=negative_image_cond,
569
+ guidance_scale=cfg_image_space_scale,
570
+ use_depth_map_as_input=use_depth_map_as_input,
571
+ apply_mask_to_input=use_mask,
572
+ mode=mode,
573
+ generator=torch.manual_seed(random_seed),
574
+ num_images_per_prompt=2).images
575
+ if do_sr:
576
+ images = model_dict['sr_model'].predict(images)
577
+
578
+ return images, [org_image_source, mask_pil, to_pil_image(0.5 * (depth_map[0] + 1.0))]
579
+
580
+
581
+ def generate_canvas(image):
582
+ return Image.fromarray((np.ones([512, 512, 3]) * 255).astype('uint8'))
583
+
584
+
585
+ def surprise_me():
586
+ return random.sample(model_dict['prompt_candidates'], k=1)[0]
587
+
588
+
589
+ if __name__ == "__main__":
590
+ parser = argparse.ArgumentParser('Demo')
591
+ parser.add_argument('--model_folder', default=None, type=str, help='path to model_folder')
592
+
593
+ args = parser.parse_args()
594
+
595
+
596
+ global model_dict
597
+
598
+ model_dict = build_models(args)
599
+
600
+ ### define gradio demo
601
+ title = 'Graphit demo'
602
+
603
+ md_title = f'''# {title}
604
+ Diffusion on GPU.
605
+ '''
606
+ neg_default = 'watermark, longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
607
+ with gr.Blocks(title=title) as demo:
608
+ gr.Markdown(md_title)
609
+ mode_t2i = gr.Textbox(value='t2i', label='mode selection', visible=False)
610
+ mode_i2i = gr.Textbox(value='i2i', label='mode selection', visible=False)
611
+ mode_inpaint = gr.Textbox(value='inpaint', label='mode selection', visible=False)
612
+ mode_s2i = gr.Textbox(value='s2i', label='mode selection', visible=False)
613
+ mode_scr2i = gr.Textbox(value='scr2i', label='mode selection', visible=False)
614
+ mode_d2i = gr.Textbox(value='d2i', label='mode selection', visible=False)
615
+ mode_e2i = gr.Textbox(value='e2i', label='mode selection', visible=False)
616
+ mode_inped = gr.Textbox(value='inped', label='mode selection', visible=False)
617
+ mode_cs2i = gr.Textbox(value='cs2i', label='mode selection', visible=False)
618
+ mask_scale_default = gr.Number(value=1.0, label='mask scale', visible=False)
619
+ use_edge_default = gr.Checkbox(value=True, label='use color map with edge map', visible=False)
620
+ height_default = gr.Number(value=512, precision=0, label='height', visible=False)
621
+ width_default = gr.Number(value=512, precision=0, label='width', visible=False)
622
+ with gr.Row():
623
+ with gr.Column():
624
+ with gr.Tabs():
625
+ '''
626
+ image to image
627
+ inpainting
628
+ depth to image
629
+ saliency map to image
630
+ '''
631
+ with gr.TabItem("Text to Image"):
632
+ image_source_t2i = gr.Image(type='pil', label='Source image', visible=False)
633
+ with gr.Row():
634
+ steps_input_t2i = gr.Radio(['5', '10', '25', '50'], value='25', label='denoising steps')
635
+ random_seed_t2i = gr.Number(value=12345, precision=0, label='Seed')
636
+ with gr.Accordion('Advanced options', open=False):
637
+ with gr.Row():
638
+ cfg_image_scale_t2i = gr.Number(value=1.1, label='attn source image scale', visible=False)
639
+ cfg_image_space_scale_t2i = gr.Number(value=7.5, label='attn image space scale')
640
+ cfg_text_scale_t2i = gr.Number(value=7.5, label='attn text scale')
641
+ negative_text_input_t2i = gr.Textbox(value=neg_default, label='Negative text')
642
+ with gr.Row():
643
+ cfg_image_source_mix_weight_t2i = gr.Number(value=0.0, label='weight for mixing source image (0.0~1.0)', visible=False)
644
+ cfg_image_reference_mix_weight_t2i = gr.Number(value=0.65, label='weight for mixing reference image (0.0~1.0)')
645
+ with gr.Row():
646
+ height_t2i = gr.Number(value=512, precision=0, label='height (~512)')
647
+ width_t2i = gr.Number(value=512, precision=0, label='width (~512)')
648
+ submit_button_t2i = gr.Button('Generate images')
649
+ with gr.TabItem("Image to Image"):
650
+ image_source_i2i = gr.Image(type='pil', label='Source image')
651
+ with gr.Row():
652
+ steps_input_i2i = gr.Radio(['5', '10', '25', '50'], value='25', label='denoising steps')
653
+ random_seed_i2i = gr.Number(value=12345, precision=0, label='Seed')
654
+ with gr.Accordion('Advanced options', open=False):
655
+ with gr.Row():
656
+ cfg_image_scale_i2i = gr.Number(value=1.1, label='attn source image scale', visible=False)
657
+ cfg_image_space_scale_i2i = gr.Number(value=7.5, label='attn image space scale')
658
+ cfg_text_scale_i2i = gr.Number(value=7.5, label='attn text scale')
659
+ negative_text_input_i2i = gr.Textbox(value=neg_default, label='Negative text')
660
+ with gr.Row():
661
+ cfg_image_source_mix_weight_i2i = gr.Number(value=0.05, label='weight for mixing source image (0.0~1.0)')
662
+ cfg_image_reference_mix_weight_i2i = gr.Number(value=0.65, label='weight for mixing reference image (0.0~1.0)')
663
+ submit_button_i2i = gr.Button('Generate images')
664
+ with gr.TabItem("Depth to Image"):
665
+ image_source_d2i = gr.Image(type='pil', label='Source image')
666
+ with gr.Row():
667
+ steps_input_d2i = gr.Radio(['5', '10', '25', '50'], value='25', label='denoising steps')
668
+ random_seed_d2i = gr.Number(value=12345, precision=0, label='Seed')
669
+ with gr.Accordion('Advanced options', open=False):
670
+ with gr.Row():
671
+ cfg_image_scale_d2i = gr.Number(value=1.1, label='attn source image scale', visible=False)
672
+ cfg_image_space_scale_d2i = gr.Number(value=7.5, label='attn image space scale')
673
+ cfg_text_scale_d2i = gr.Number(value=7.5, label='attn text scale')
674
+ negative_text_input_d2i = gr.Textbox(value=neg_default, label='Negative text')
675
+ with gr.Row():
676
+ cfg_image_source_mix_weight_d2i = gr.Number(value=0.0, label='weight for mixing source image (0.0~1.0)', visible=False)
677
+ cfg_image_reference_mix_weight_d2i = gr.Number(value=1.0, label='weight for mixing reference image (0.0~1.0)')
678
+ submit_button_d2i = gr.Button('Generate images')
679
+ with gr.TabItem("Edge to Image"):
680
+ image_source_e2i = gr.Image(type='pil', label='Source image')
681
+ with gr.Row():
682
+ steps_input_e2i = gr.Radio(['5', '10', '25', '50'], value='25', label='denoising steps')
683
+ random_seed_e2i = gr.Number(value=12345, precision=0, label='Seed')
684
+ with gr.Accordion('Advanced options', open=False):
685
+ with gr.Row():
686
+ cfg_image_scale_e2i = gr.Number(value=1.1, label='attn source image scale', visible=False)
687
+ cfg_image_space_scale_e2i = gr.Number(value=7.5, label='attn image space scale')
688
+ cfg_text_scale_e2i = gr.Number(value=7.5, label='attn text scale')
689
+ negative_text_input_e2i = gr.Textbox(value=neg_default, label='Negative text')
690
+ with gr.Row():
691
+ cfg_image_source_mix_weight_e2i = gr.Number(value=0.0, label='weight for mixing source image (0.0~1.0)', visible=False)
692
+ cfg_image_reference_mix_weight_e2i = gr.Number(value=1.0, label='weight for mixing reference image (0.0~1.0)')
693
+ submit_button_e2i = gr.Button('Generate images')
694
+ with gr.TabItem("Inpaint"):
695
+ image_source_inp = gr.Image(type='pil', label='Source image', tool='sketch')
696
+ with gr.Row():
697
+ steps_input_inp = gr.Radio(['5', '10', '25', '50'], value='25', label='denoising steps')
698
+ random_seed_inp = gr.Number(value=12345, precision=0, label='Seed')
699
+ with gr.Accordion('Advanced options', open=False):
700
+ with gr.Row():
701
+ cfg_image_scale_inp = gr.Number(value=1.1, label='attn source image scale', visible=False)
702
+ cfg_image_space_scale_inp = gr.Number(value=7.5, label='attn image space scale')
703
+ cfg_text_scale_inp = gr.Number(value=7.5, label='attn text scale')
704
+ negative_text_input_inp = gr.Textbox(value='', label='Negative text')
705
+ with gr.Row():
706
+ cfg_image_source_mix_weight_inp = gr.Number(value=0.0, label='weight for mixing source image (0.0~1.0)', visible=False)
707
+ cfg_image_reference_mix_weight_inp = gr.Number(value=0.65, label='weight for mixing reference image (0.0~1.0)')
708
+ submit_button_inp = gr.Button('Generate images')
709
+ with gr.TabItem("Blending"):
710
+ image_source_inped = gr.Image(type='pil', label='Source image', tool='sketch')
711
+ with gr.Row():
712
+ steps_input_inped = gr.Radio(['5', '10', '25', '50'], value='25', label='denoising steps')
713
+ random_seed_inped = gr.Number(value=12345, precision=0, label='Seed')
714
+ with gr.Accordion('Advanced options', open=False):
715
+ with gr.Row():
716
+ cfg_image_scale_inped = gr.Number(value=1.1, label='attn source image scale', visible=False)
717
+ cfg_image_space_scale_inped = gr.Number(value=7.5, label='attn image space scale')
718
+ cfg_text_scale_inped = gr.Number(value=7.5, label='attn text scale')
719
+ negative_text_input_inped = gr.Textbox(value=neg_default, label='Negative text')
720
+ with gr.Row():
721
+ cfg_image_source_mix_weight_inped = gr.Number(value=0.0, label='weight for mixing source image (0.0~1.0)', visible=False)
722
+ cfg_image_reference_mix_weight_inped = gr.Number(value=0.35, label='weight for mixing reference image (0.0~1.0)')
723
+ with gr.Row():
724
+ mask_scale_inped = gr.Number(value=1.0, label='edge scale')
725
+ use_edge_inped = gr.Checkbox(value=False, label='use a color map with an edge map')
726
+ submit_button_inped = gr.Button('Generate images')
727
+ with gr.TabItem("Sketch (Rough) to Image"):
728
+ with gr.Column():
729
+ image_source_s2i = gr.Image(type='pil', label='Source image', tool='sketch', brush_radius=100).style(height=256, width=256)
730
+ build_canvas_s2i = gr.Button('Build canvas')
731
+ with gr.Row():
732
+ steps_input_s2i = gr.Radio(['5', '10', '25', '50'], value='25', label='denoising steps')
733
+ random_seed_s2i = gr.Number(value=12345, precision=0, label='Seed')
734
+ with gr.Accordion('Advanced options', open=False):
735
+ with gr.Row():
736
+ cfg_image_scale_s2i = gr.Number(value=1.1, label='attn source image scale', visible=False)
737
+ cfg_image_space_scale_s2i = gr.Number(value=7.5, label='attn image space scale')
738
+ cfg_text_scale_s2i = gr.Number(value=7.5, label='attn text scale')
739
+ negative_text_input_s2i = gr.Textbox(value=neg_default, label='Negative text')
740
+ with gr.Row():
741
+ cfg_image_source_mix_weight_s2i = gr.Number(value=0.0, label='weight for mixing source image (0.0~1.0)', visible=False)
742
+ cfg_image_reference_mix_weight_s2i = gr.Number(value=0.65, label='weight for mixing reference image (0.0~1.0)')
743
+ mask_scale_s2i = gr.Number(value=0.5, label='sketch weight (0.0~1.0)')
744
+ submit_button_s2i = gr.Button('Generate images')
745
+ with gr.TabItem("Sketch (Detail) to Image"):
746
+ with gr.Column():
747
+ image_source_scr2i = gr.Image(type='pil', label='Source image', tool='sketch', brush_radius=10).style(height=256, width=256)
748
+ build_canvas_scr2i = gr.Button('Build canvas')
749
+ with gr.Row():
750
+ steps_input_scr2i = gr.Radio(['5', '10', '25', '50'], value='25', label='denoising steps')
751
+ random_seed_scr2i = gr.Number(value=12345, precision=0, label='Seed')
752
+ with gr.Accordion('Advanced options', open=False):
753
+ with gr.Row():
754
+ cfg_image_scale_scr2i = gr.Number(value=1.1, label='attn source image scale', visible=False)
755
+ cfg_image_space_scale_scr2i = gr.Number(value=7.5, label='attn image space scale')
756
+ cfg_text_scale_scr2i = gr.Number(value=7.5, label='attn text scale')
757
+ negative_text_input_scr2i = gr.Textbox(value=neg_default, label='Negative text')
758
+ with gr.Row():
759
+ cfg_image_source_mix_weight_scr2i = gr.Number(value=0.0, label='weight for mixing source image (0.0~1.0)', visible=False)
760
+ cfg_image_reference_mix_weight_scr2i = gr.Number(value=0.65, label='weight for mixing reference image (0.0~1.0)')
761
+ mask_scale_scr2i = gr.Number(value=0.5, label='sketch weight (0.0~1.0)')
762
+ submit_button_scr2i = gr.Button('Generate images')
763
+ with gr.TabItem("Color Sketch to Image"):
764
+ with gr.Column():
765
+ image_source_cs2i = gr.Image(type='pil', source='canvas', label='Source image', tool='color-sketch').style(height=256, width=256)
766
+ #build_canvas_cs2i = gr.Button('Build canvas')
767
+ with gr.Row():
768
+ steps_input_cs2i = gr.Radio(['5', '10', '25', '50'], value='25', label='denoising steps')
769
+ random_seed_cs2i = gr.Number(value=12345, precision=0, label='Seed')
770
+ with gr.Accordion('Advanced options', open=False):
771
+ with gr.Row():
772
+ cfg_image_scale_cs2i = gr.Number(value=1.1, label='attn source image scale', visible=False)
773
+ cfg_image_space_scale_cs2i = gr.Number(value=7.5, label='attn image space scale')
774
+ cfg_text_scale_cs2i = gr.Number(value=7.5, label='attn text scale')
775
+ negative_text_input_cs2i = gr.Textbox(value=neg_default, label='Negative text')
776
+ with gr.Row():
777
+ cfg_image_source_mix_weight_cs2i = gr.Number(value=0.0, label='weight for mixing source image (0.0~1.0)', visible=False)
778
+ cfg_image_reference_mix_weight_cs2i = gr.Number(value=0.65, label='weight for mixing reference image (0.0~1.0)')
779
+ mask_scale_cs2i = gr.Number(value=0.5, label='sketch weight (0.0~1.0)')
780
+ submit_button_cs2i = gr.Button('Generate images')
781
+ text_input = gr.Textbox(value='', label='Input text')
782
+ submit_surprise_me = gr.Button('Surprise me')
783
+ #swap_button = gr.Button('Swap source with reference', visible=False)
784
+ with gr.Column():
785
+ with gr.Row():
786
+ do_sr = gr.Checkbox(value=False, label='Super-resolution')
787
+ image_reference = gr.Image(type='pil', label='Reference image')
788
+ gallery_outputs = gr.Gallery(label='Generated outputs').style(grid=[2], height='auto')
789
+ gallery_inputs = gr.Gallery(label='Processed inputs').style(grid=[2], height='auto')
790
+
791
+ submit_button_t2i.click(generate, inputs=[image_source_t2i, image_reference, text_input, negative_text_input_t2i, steps_input_t2i, random_seed_t2i, cfg_image_scale_t2i, cfg_text_scale_t2i, cfg_image_space_scale_t2i, cfg_image_reference_mix_weight_t2i, cfg_image_source_mix_weight_t2i, mask_scale_default, use_edge_default, height_t2i, width_t2i, do_sr, mode_t2i], outputs=[gallery_outputs, gallery_inputs])
792
+ submit_button_i2i.click(generate, inputs=[image_source_i2i, image_reference, text_input, negative_text_input_i2i, steps_input_i2i, random_seed_i2i, cfg_image_scale_i2i, cfg_text_scale_i2i, cfg_image_space_scale_i2i, cfg_image_reference_mix_weight_i2i, cfg_image_source_mix_weight_i2i, mask_scale_default, use_edge_default, height_default, width_default, do_sr, mode_i2i], outputs=[gallery_outputs, gallery_inputs])
793
+ submit_button_d2i.click(generate, inputs=[image_source_d2i, image_reference, text_input, negative_text_input_d2i, steps_input_d2i, random_seed_d2i, cfg_image_scale_d2i, cfg_text_scale_d2i, cfg_image_space_scale_d2i, cfg_image_reference_mix_weight_d2i, cfg_image_source_mix_weight_d2i, mask_scale_default, use_edge_default, height_default, width_default, do_sr, mode_d2i], outputs=[gallery_outputs, gallery_inputs])
794
+ submit_button_e2i.click(generate, inputs=[image_source_e2i, image_reference, text_input, negative_text_input_e2i, steps_input_e2i, random_seed_e2i, cfg_image_scale_e2i, cfg_text_scale_e2i, cfg_image_space_scale_e2i, cfg_image_reference_mix_weight_e2i, cfg_image_source_mix_weight_e2i, mask_scale_default, use_edge_default, height_default, width_default, do_sr, mode_e2i], outputs=[gallery_outputs, gallery_inputs])
795
+ submit_button_inp.click(generate, inputs=[image_source_inp, image_reference, text_input, negative_text_input_inp, steps_input_inp, random_seed_inp, cfg_image_scale_inp, cfg_text_scale_inp, cfg_image_space_scale_inp, cfg_image_reference_mix_weight_inp, cfg_image_source_mix_weight_inp, mask_scale_default, use_edge_default, height_default, width_default, do_sr, mode_inpaint], outputs=[gallery_outputs, gallery_inputs])
796
+ submit_button_inped.click(generate, inputs=[image_source_inped, image_reference, text_input, negative_text_input_inped, steps_input_inped, random_seed_inped, cfg_image_scale_inped, cfg_text_scale_inped, cfg_image_space_scale_inped, cfg_image_reference_mix_weight_inped, cfg_image_source_mix_weight_inped, mask_scale_inped, use_edge_inped, height_default, width_default, do_sr, mode_inped], outputs=[gallery_outputs, gallery_inputs])
797
+ submit_button_s2i.click(generate, inputs=[image_source_s2i, image_reference, text_input, negative_text_input_s2i, steps_input_s2i, random_seed_s2i, cfg_image_scale_s2i, cfg_text_scale_s2i, cfg_image_space_scale_s2i, cfg_image_reference_mix_weight_s2i, cfg_image_source_mix_weight_s2i, mask_scale_s2i, use_edge_default, height_default, width_default, do_sr, mode_s2i], outputs=[gallery_outputs, gallery_inputs])
798
+ submit_button_scr2i.click(generate, inputs=[image_source_scr2i, image_reference, text_input, negative_text_input_scr2i, steps_input_scr2i, random_seed_scr2i, cfg_image_scale_scr2i, cfg_text_scale_scr2i, cfg_image_space_scale_scr2i, cfg_image_reference_mix_weight_scr2i, cfg_image_source_mix_weight_scr2i, mask_scale_scr2i, use_edge_default, height_default, width_default, do_sr, mode_scr2i], outputs=[gallery_outputs, gallery_inputs])
799
+ submit_button_cs2i.click(generate, inputs=[image_source_cs2i, image_reference, text_input, negative_text_input_cs2i, steps_input_cs2i, random_seed_cs2i, cfg_image_scale_cs2i, cfg_text_scale_cs2i, cfg_image_space_scale_cs2i, cfg_image_reference_mix_weight_cs2i, cfg_image_source_mix_weight_cs2i, mask_scale_cs2i, use_edge_default, height_default, width_default, do_sr, mode_cs2i], outputs=[gallery_outputs, gallery_inputs])
800
+ build_canvas_s2i.click(generate_canvas, inputs=[image_source_s2i], outputs=[image_source_s2i])
801
+ build_canvas_scr2i.click(generate_canvas, inputs=[image_source_scr2i], outputs=[image_source_scr2i])
802
+ submit_surprise_me.click(surprise_me, outputs=[text_input])
803
+ demo.queue()
804
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch>=1.13.0
2
+ torchvision>=0.9
3
+ transformers
4
+ diffusers
5
+ huggingface_hub
6
+ git+https://github.com/navervision/CompoDiff.git
7
+ transparent-background
8
+ git+https://github.com/sberbank-ai/Real-ESRGAN.git
9
+ gradio