taki0112 commited on
Commit
eeb7e29
Β·
1 Parent(s): 9953e90
Files changed (1) hide show
  1. app.py +22 -218
app.py CHANGED
@@ -5,24 +5,18 @@ import gradio as gr
5
  import os, json
6
  import numpy as np
7
  from PIL import Image
8
-
9
- from pipelines.pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
10
  from pipelines.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
11
- from diffusers import ControlNetModel, AutoencoderKL
12
- from transformers import DPTFeatureExtractor, DPTForDepthEstimation
13
  from random import randint
14
  from utils import init_latent
15
  from transformers import Blip2Processor, Blip2ForConditionalGeneration
16
  from diffusers import DDIMScheduler
17
 
18
- # device = 'cuda' if torch.cuda.is_available() else 'cpu'
19
- # if device == 'cpu':
20
- # torch_dtype = torch.float32
21
- # else:
22
- # torch_dtype = torch.float16
23
 
24
- device = 'cuda'
25
- torch_dtype = torch.float16
26
  def memory_efficient(model):
27
  try:
28
  model.to(device)
@@ -37,33 +31,16 @@ def memory_efficient(model):
37
  model.enable_vae_slicing()
38
  except AttributeError:
39
  print("enable_vae_slicing is not supported.")
40
- # if device == 'cuda':
41
- # try:
42
- # model.enable_xformers_memory_efficient_attention()
43
- # except AttributeError:
44
- # print("enable_xformers_memory_efficient_attention is not supported.")
45
-
46
- controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch_dtype)
47
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype)
48
-
49
- model_controlnet = StableDiffusionXLControlNetPipeline.from_pretrained(
50
- "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch_dtype
51
- )
52
 
53
  model = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch_dtype)
54
-
55
- print("vae")
56
- memory_efficient(vae)
57
- print("control")
58
- memory_efficient(controlnet)
59
- print("ControlNet-SDXL")
60
- memory_efficient(model_controlnet)
61
  print("SDXL")
62
  memory_efficient(model)
63
 
64
- depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
65
- feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
66
-
67
  blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
68
  blip_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch_dtype).to(device)
69
 
@@ -75,44 +52,8 @@ def parse_config(config):
75
  config = json.load(f)
76
  return config
77
 
78
- def get_depth_map(image):
79
- image = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
80
- with torch.no_grad(), torch.autocast(device):
81
- depth_map = depth_estimator(image).predicted_depth
82
-
83
- depth_map = torch.nn.functional.interpolate(
84
- depth_map.unsqueeze(1),
85
- size=(1024, 1024),
86
- mode="bicubic",
87
- align_corners=False,
88
- )
89
- depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
90
- depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
91
- depth_map = (depth_map - depth_min) / (depth_max - depth_min)
92
- image = torch.cat([depth_map] * 3, dim=1)
93
-
94
- image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
95
- image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
96
- return image
97
-
98
-
99
- def get_depth_edge_array(depth_img_path):
100
- depth_image_tmp = Image.fromarray(depth_img_path)
101
-
102
- # get depth map
103
- depth_map = get_depth_map(depth_image_tmp)
104
-
105
- return depth_map
106
-
107
- def blip_inf_prompt(image):
108
- inputs = blip_processor(images=image, return_tensors="pt").to(device, torch.float16)
109
 
110
- generated_ids = blip_model.generate(**inputs)
111
- generated_text = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
112
-
113
- return generated_text
114
-
115
- def load_example_controlnet():
116
  folder_path = 'assets/ref'
117
  examples = []
118
  for filename in os.listdir(folder_path):
@@ -125,33 +66,22 @@ def load_example_controlnet():
125
  config = parse_config(config_path)
126
  inf_object_name = config["inference_info"]["inf_object_list"][0]
127
 
128
- canny_path = './assets/depth_dir/gundam.png'
129
- image_info = [image_path, canny_path, style_name, inf_object_name, 1, 0.5, 50]
130
-
131
  examples.append(image_info)
132
 
133
  return examples
134
 
135
- def load_example_style():
136
- folder_path = 'assets/ref'
137
- examples = []
138
- for filename in os.listdir(folder_path):
139
- if filename.endswith((".png")):
140
- image_path = os.path.join(folder_path, filename)
141
- image_name = os.path.basename(image_path)
142
- style_name = image_name.split('_')[1]
143
-
144
- config_path = './config/{}.json'.format(style_name)
145
- config = parse_config(config_path)
146
- inf_object_name = config["inference_info"]["inf_object_list"][0]
147
 
148
- image_info = [image_path, style_name, inf_object_name, 1, 50]
149
- examples.append(image_info)
150
 
151
- return examples
152
 
153
  @spaces.GPU
154
- def style_fn(image_path, style_name, content_text, output_number, diffusion_step=50):
 
155
  user_image_flag = not style_name.strip() # empty
156
 
157
  if not user_image_flag:
@@ -272,131 +202,23 @@ def style_fn(image_path, style_name, content_text, output_number, diffusion_step
272
  )[0][1:]
273
 
274
  n_row = 1
275
- n_col = len(inf_seeds)
276
 
277
  # make grid
278
  grid = create_image_grid(images, n_row, n_col, padding=10)
279
 
280
  return grid
281
 
282
- @spaces.GPU
283
- def controlnet_fn(image_path, depth_image_path, style_name, content_text, output_number, controlnet_scale=0.5, diffusion_step=50):
284
- config_path = './config/{}.json'.format(style_name)
285
- config = parse_config(config_path)
286
-
287
- inf_object = content_text
288
- inf_seeds = [randint(0, 10**10) for _ in range(int(output_number))]
289
- # inf_seeds = [i for i in range(int(output_number))]
290
-
291
- activate_layer_indices_list = config['inference_info']['activate_layer_indices_list']
292
- activate_step_indices_list = config['inference_info']['activate_step_indices_list']
293
- ref_seed = config['reference_info']['ref_seeds'][0]
294
-
295
- attn_map_save_steps = config['inference_info']['attn_map_save_steps']
296
- guidance_scale = config['guidance_scale']
297
- use_inf_negative_prompt = config['inference_info']['use_negative_prompt']
298
-
299
- style_name = config["style_name_list"][0]
300
-
301
- ref_object = config["reference_info"]["ref_object_list"][0]
302
- ref_with_style_description = config['reference_info']['with_style_description']
303
- inf_with_style_description = config['inference_info']['with_style_description']
304
-
305
- use_shared_attention = config['inference_info']['use_shared_attention']
306
- adain_queries = config['inference_info']['adain_queries']
307
- adain_keys = config['inference_info']['adain_keys']
308
- adain_values = config['inference_info']['adain_values']
309
-
310
- use_advanced_sampling = config['inference_info']['use_advanced_sampling']
311
-
312
- #get canny edge array
313
- depth_image = get_depth_edge_array(depth_image_path)
314
-
315
- style_description_pos, style_description_neg = STYLE_DESCRIPTION_DICT[style_name][0], \
316
- STYLE_DESCRIPTION_DICT[style_name][1]
317
-
318
- # Inference
319
- with torch.inference_mode():
320
- grid = None
321
- if ref_with_style_description:
322
- ref_prompt = style_description_pos.replace("{object}", ref_object)
323
- else:
324
- ref_prompt = ref_object
325
-
326
- if inf_with_style_description:
327
- inf_prompt = style_description_pos.replace("{object}", inf_object)
328
- else:
329
- inf_prompt = inf_object
330
-
331
- for activate_layer_indices in activate_layer_indices_list:
332
-
333
- for activate_step_indices in activate_step_indices_list:
334
-
335
- str_activate_layer, str_activate_step = model_controlnet.activate_layer(
336
- activate_layer_indices=activate_layer_indices,
337
- attn_map_save_steps=attn_map_save_steps,
338
- activate_step_indices=activate_step_indices,
339
- use_shared_attention=use_shared_attention,
340
- adain_queries=adain_queries,
341
- adain_keys=adain_keys,
342
- adain_values=adain_values,
343
- )
344
-
345
- # ref_latent = model_controlnet.get_init_latent(ref_seed, precomputed_path=None)
346
- ref_latent = init_latent(model_controlnet, device_name=device, dtype=torch_dtype, seed=ref_seed)
347
- latents = [ref_latent]
348
-
349
- for inf_seed in inf_seeds:
350
- # latents.append(model_controlnet.get_init_latent(inf_seed, precomputed_path=None))
351
- inf_latent = init_latent(model_controlnet, device_name=device, dtype=torch_dtype, seed=inf_seed)
352
- latents.append(inf_latent)
353
-
354
-
355
- latents = torch.cat(latents, dim=0)
356
- latents.to(device)
357
-
358
- images = model_controlnet.generated_ve_inference(
359
- prompt=ref_prompt,
360
- negative_prompt=style_description_neg,
361
- guidance_scale=guidance_scale,
362
- num_inference_steps=diffusion_step,
363
- controlnet_conditioning_scale=controlnet_scale,
364
- latents=latents,
365
- num_images_per_prompt=len(inf_seeds) + 1,
366
- target_prompt=inf_prompt,
367
- image=depth_image,
368
- use_inf_negative_prompt=use_inf_negative_prompt,
369
- use_advanced_sampling=use_advanced_sampling
370
- )[0][1:]
371
-
372
- n_row = 1
373
- n_col = len(inf_seeds) # μ›λ³ΈμΆ”κ°€ν•˜λ €λ©΄ + 1
374
-
375
- # make grid
376
- grid = create_image_grid(images, n_row, n_col)
377
-
378
- return grid
379
-
380
-
381
  description_md = """
382
 
383
  ### We introduce `Visual Style Prompting`, which reflects the style of a reference image to the images generated by a pretrained text-to-image diffusion model without finetuning or optimization (e.g., Figure N).
384
  ### πŸ“– [[Paper](https://arxiv.org/abs/2402.12974)] | ✨ [[Project page](https://curryjung.github.io/VisualStylePrompt)] | ✨ [[Code](https://github.com/naver-ai/Visual-Style-Prompting)]
 
385
  ---
386
- ### πŸ‘‰ To better reflect the style of a user's image, the higher the resolution, the better.
387
  ### πŸ”₯ To try out our vanilla demo,
388
  1. Choose a `style reference` from the collection of images below.
389
  2. Enter the `text prompt`.
390
  3. Choose the `number of outputs`.
391
- ---
392
- ### ✨ Visual Style Prompting also works on `ControlNet` which specifies the shape of the results by depthmap or keypoints.
393
- ### ‼️ w/ ControlNet ver does not support user style images.
394
- ### πŸ”₯ To try out our demo with ControlNet,
395
- 1. Upload an `image for depth control`. An off-the-shelf model will produce the depthmap from it.
396
- 2. Choose `ControlNet scale` which determines the alignment to the depthmap.
397
- 3. Choose a `style reference` from the collection of images below.
398
- 4. Enter the `text prompt`. (`Empty text` is okay, but a depthmap description helps.)
399
- 5. Choose the `number of outputs`.
400
 
401
  ### πŸ‘‰ To achieve faster results, we recommend lowering the diffusion steps to 30.
402
  ### Enjoy ! πŸ˜„
@@ -417,22 +239,4 @@ iface_style = gr.Interface(
417
  examples=load_example_style(),
418
  )
419
 
420
- iface_controlnet = gr.Interface(
421
- fn=controlnet_fn,
422
- inputs=[
423
- gr.components.Image(label="Style image"),
424
- gr.components.Image(label="Depth image"),
425
- gr.components.Textbox(label='Style name', visible=False),
426
- gr.components.Textbox(label="Text prompt", placeholder="Enter Text prompt"),
427
- gr.components.Textbox(label="Number of outputs", placeholder="Enter Number of outputs"),
428
- gr.components.Slider(minimum=0.5, maximum=10, step=0.5, value=0.5, label="Controlnet scale"),
429
- gr.components.Slider(minimum=10, maximum=50, step=10, value=50, label="Diffusion steps")
430
- ],
431
- outputs=gr.components.Image(label="Generated Image"),
432
- title="🎨 Visual Style Prompting (w/ ControlNet)",
433
- description=description_md,
434
- examples=load_example_controlnet(),
435
- )
436
-
437
- iface = gr.TabbedInterface([iface_style, iface_controlnet], ["Vanilla", "w/ ControlNet"])
438
- iface.launch(debug=True)
 
5
  import os, json
6
  import numpy as np
7
  from PIL import Image
 
 
8
  from pipelines.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
 
 
9
  from random import randint
10
  from utils import init_latent
11
  from transformers import Blip2Processor, Blip2ForConditionalGeneration
12
  from diffusers import DDIMScheduler
13
 
14
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
+ if device == 'cpu':
16
+ torch_dtype = torch.float32
17
+ else:
18
+ torch_dtype = torch.float16
19
 
 
 
20
  def memory_efficient(model):
21
  try:
22
  model.to(device)
 
31
  model.enable_vae_slicing()
32
  except AttributeError:
33
  print("enable_vae_slicing is not supported.")
34
+ if device == 'cuda':
35
+ try:
36
+ model.enable_xformers_memory_efficient_attention()
37
+ except AttributeError:
38
+ print("enable_xformers_memory_efficient_attention is not supported.")
 
 
 
 
 
 
 
39
 
40
  model = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch_dtype)
 
 
 
 
 
 
 
41
  print("SDXL")
42
  memory_efficient(model)
43
 
 
 
 
44
  blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
45
  blip_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch_dtype).to(device)
46
 
 
52
  config = json.load(f)
53
  return config
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ def load_example_style():
 
 
 
 
 
57
  folder_path = 'assets/ref'
58
  examples = []
59
  for filename in os.listdir(folder_path):
 
66
  config = parse_config(config_path)
67
  inf_object_name = config["inference_info"]["inf_object_list"][0]
68
 
69
+ image_info = [image_path, style_name, inf_object_name, 1, 50]
 
 
70
  examples.append(image_info)
71
 
72
  return examples
73
 
74
+ def blip_inf_prompt(image):
75
+ inputs = blip_processor(images=image, return_tensors="pt").to(device, torch.float16)
 
 
 
 
 
 
 
 
 
 
76
 
77
+ generated_ids = blip_model.generate(**inputs)
78
+ generated_text = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
79
 
80
+ return generated_text
81
 
82
  @spaces.GPU
83
+ def style_fn(image_path, style_name, content_text, output_number=1, diffusion_step=50):
84
+
85
  user_image_flag = not style_name.strip() # empty
86
 
87
  if not user_image_flag:
 
202
  )[0][1:]
203
 
204
  n_row = 1
205
+ n_col = len(inf_seeds) + 1 # μ›λ³ΈμΆ”κ°€ν•˜λ €λ©΄ + 1
206
 
207
  # make grid
208
  grid = create_image_grid(images, n_row, n_col, padding=10)
209
 
210
  return grid
211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  description_md = """
213
 
214
  ### We introduce `Visual Style Prompting`, which reflects the style of a reference image to the images generated by a pretrained text-to-image diffusion model without finetuning or optimization (e.g., Figure N).
215
  ### πŸ“– [[Paper](https://arxiv.org/abs/2402.12974)] | ✨ [[Project page](https://curryjung.github.io/VisualStylePrompt)] | ✨ [[Code](https://github.com/naver-ai/Visual-Style-Prompting)]
216
+ ### πŸ”₯ [[w/ Controlnet ver](https://huggingface.co/spaces/naver-ai/VisualStylePrompting_Controlnet)]
217
  ---
 
218
  ### πŸ”₯ To try out our vanilla demo,
219
  1. Choose a `style reference` from the collection of images below.
220
  2. Enter the `text prompt`.
221
  3. Choose the `number of outputs`.
 
 
 
 
 
 
 
 
 
222
 
223
  ### πŸ‘‰ To achieve faster results, we recommend lowering the diffusion steps to 30.
224
  ### Enjoy ! πŸ˜„
 
239
  examples=load_example_style(),
240
  )
241
 
242
+ iface_style.launch(debug=True)