jiuface committed (verified)
Commit 92f4fa3 · 1 Parent(s): a1c8c21

Update app.py

Files changed (1): app.py (+19 -21)
app.py CHANGED
@@ -9,7 +9,7 @@ import logging
 from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image
 from huggingface_hub import login
 from diffusers.utils import load_image
-
+from lora_loading_patch import load_lora_into_transformer
 import time
 from datetime import datetime
 from io import BytesIO
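The only change in this hunk is the new import; the hook itself is attached in the next hunk. Since lora_loading_patch.py is not included in this commit, the stub below only illustrates the interface the import has to provide: a module-level function taking the class as its first argument, so it can be wrapped in classmethod() further down. The signature shown mirrors the diffusers hook it replaces and is an assumption, not the module's actual contents.

```python
# lora_loading_patch.py -- HYPOTHETICAL stub, for illustration only.
# All this commit requires is a module-level function whose first
# parameter is the pipeline class, so that
# classmethod(load_lora_into_transformer) can be bound as the
# pipeline's LoRA-loading hook (see the next hunk).
def load_lora_into_transformer(cls, state_dict, network_alphas,
                               transformer, adapter_name=None,
                               _pipeline=None):
    raise NotImplementedError("illustrative stub; real patch not shown")
```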
@@ -34,19 +34,13 @@ base_model = "black-forest-labs/FLUX.1-dev"
 # load pipe
 taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
 good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
-pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
 
-# img2img model
-img2img = AutoPipelineForImage2Image.from_pretrained(base_model,
-    vae=good_vae,
-    transformer=pipe.transformer,
-    text_encoder=pipe.text_encoder,
-    tokenizer=pipe.tokenizer,
-    text_encoder_2=pipe.text_encoder_2,
-    tokenizer_2=pipe.tokenizer_2,
-    torch_dtype=dtype
-)
+txt2img_pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
+txt2img_pipe.__class__.load_lora_into_transformer = classmethod(load_lora_into_transformer)
 
+# img2img model
+img2img_pipe = AutoPipelineForImage2Image.from_pretrained(base_model, vae=good_vae, transformer=txt2img_pipe.transformer, text_encoder=txt2img_pipe.text_encoder, tokenizer=txt2img_pipe.tokenizer, text_encoder_2=txt2img_pipe.text_encoder_2, tokenizer_2=txt2img_pipe.tokenizer_2, torch_dtype=dtype)
+img2img_pipe.__class__.load_lora_into_transformer = classmethod(load_lora_into_transformer)
 
 
 MAX_SEED = 2**32 - 1
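This hunk does two things: it splits the single pipe into txt2img_pipe and img2img_pipe that share the transformer, text encoders, and tokenizers (so the FLUX weights are loaded only once), and it monkey-patches the LoRA-loading hook at class level. A minimal, self-contained toy (not diffusers code) showing why the patch is assigned as classmethod(...) on the class: every instance sharing that class, including ones created later, picks it up.

```python
# Toy model of the class-level monkey-patch above; FakePipeline stands in
# for the real diffusers pipeline class.
class FakePipeline:
    @classmethod
    def load_lora_into_transformer(cls):
        return "original hook"

def patched_hook(cls):
    # Receives the class itself once wrapped in classmethod().
    return f"patched hook on {cls.__name__}"

FakePipeline.load_lora_into_transformer = classmethod(patched_hook)

a, b = FakePipeline(), FakePipeline()
assert a.load_lora_into_transformer() == "patched hook on FakePipeline"
assert b.load_lora_into_transformer() == "patched hook on FakePipeline"
```

For the component sharing itself, diffusers also offers AutoPipelineForImage2Image.from_pipe(txt2img_pipe), which reuses components the same way as the explicit transformer=/text_encoder= keywords used here.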
@@ -76,15 +70,19 @@ def generate_image(orginal_image, prompt, adapter_names, steps, seed, image_str
 
 
     gr.Info("Start to generate images ...")
-    with calculateDuration(f"Make a new generator:{seed}"):
-        pipe.to(device)
+    with calculateDuration(f"Make a new generator: {seed}"):
+        if orginal_image:
+            img2img_pipe.to(device)
+        else:
+            txt2img_pipe.to(device)
+
     generator = torch.Generator(device=device).manual_seed(seed)
 
     with calculateDuration("Generating image"):
         # Generate image
-        joint_attention_kwargs = {"scale": 1}
+        joint_attention_kwargs = {"scale": 1}
         if orginal_image:
-            generated_image = img2img(
+            generated_image = img2img_pipe(
                 prompt=prompt,
                 image=orginal_image,
                 strength=image_strength,
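The rewritten setup block now moves only the pipeline that will actually run onto the GPU, timed by calculateDuration. That helper is defined elsewhere in app.py and is not part of this diff; below is a minimal sketch consistent with how it is called here, with the implementation assumed:

```python
import time
from contextlib import contextmanager

# Assumed shape of the calculateDuration helper: a context manager that
# logs how long the wrapped block took. The real implementation lives
# elsewhere in app.py and may differ.
@contextmanager
def calculateDuration(activity_name: str):
    start = time.time()
    yield
    print(f"{activity_name} took {time.time() - start:.2f} seconds")
```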
@@ -96,7 +94,7 @@ def generate_image(orginal_image, prompt, adapter_names, steps, seed, image_str
                 joint_attention_kwargs=joint_attention_kwargs
             ).images[0]
         else:
-            generated_image = pipe(
+            generated_image = txt2img_pipe(
                 prompt=prompt,
                 num_inference_steps=steps,
                 guidance_scale=cfg_scale,
@@ -189,18 +187,18 @@ def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, s
         if lora_repo and weights and adapter_name:
             try:
                 if img2img_model:
-                    img2img.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
+                    img2img_pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
                 else:
-                    pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
+                    txt2img_pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
             except:
                 print("load lora error")
 
     # set lora weights
     if len(adapter_names) > 0:
         if img2img_model:
-            img2img.set_adapters(adapter_names, adapter_weights=adapter_weights)
+            img2img_pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
         else:
-            pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
+            txt2img_pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
 
 
     # Generate image
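Taken together, the commit renames pipe/img2img to txt2img_pipe/img2img_pipe, routes LoRA loading through the patched hook, and leaves adapter weighting unchanged. A hedged end-to-end sketch of the renamed text-to-image pipeline after this commit; the LoRA repo id, weight filename, and adapter name are placeholders, not values from the app:

```python
# Placeholder LoRA -- substitute a real FLUX.1-dev LoRA repo and filename.
txt2img_pipe.load_lora_weights(
    "some-user/some-flux-lora",        # hypothetical repo id
    weight_name="lora.safetensors",    # hypothetical filename
    adapter_name="style",
)
txt2img_pipe.set_adapters(["style"], adapter_weights=[0.8])

image = txt2img_pipe(
    prompt="a watercolor fox in a misty forest",
    num_inference_steps=28,
    guidance_scale=3.5,
    generator=torch.Generator(device=device).manual_seed(42),
    joint_attention_kwargs={"scale": 1},  # same kwargs the app passes
).images[0]
image.save("out.png")
```

(txt2img_pipe, torch, and device come from the app's own namespace above.)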
 