jiuface commited on
Commit
e60ffba
·
verified ·
1 Parent(s): 625a0c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -40,6 +40,7 @@ good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtyp
40
 
41
  txt2img_pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
42
  txt2img_pipe.__class__.load_lora_into_transformer = classmethod(load_lora_into_transformer)
 
43
 
44
 
45
  MAX_SEED = 2**32 - 1
@@ -123,6 +124,7 @@ def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, s
123
  # Load LoRA weights
124
  gr.Info("Start to load LoRA ...")
125
  with calculateDuration("Unloading LoRA"):
 
126
  txt2img_pipe.unload_lora_weights()
127
  print(device)
128
  lora_configs = None
@@ -139,7 +141,7 @@ def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, s
139
 
140
  with calculateDuration("Loading LoRA weights"):
141
  adapter_weights = []
142
-
143
  for idx, lora_info in enumerate(lora_configs):
144
  lora_repo = lora_info.get("repo")
145
  weights = lora_info.get("weights")
@@ -165,10 +167,12 @@ def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, s
165
  try:
166
  gr.Info("Start to generate images ...")
167
  with calculateDuration(f"Make a new generator: {seed}"):
 
168
  generator = torch.Generator(device=device).manual_seed(seed)
169
  print(device)
170
  with calculateDuration("Generating image"):
171
  # Generate image
 
172
  joint_attention_kwargs = {"scale": 1}
173
  final_image = txt2img_pipe(
174
  prompt=prompt,
 
40
 
41
  txt2img_pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
42
  txt2img_pipe.__class__.load_lora_into_transformer = classmethod(load_lora_into_transformer)
43
+ txt2img_pipe.to(device)
44
 
45
 
46
  MAX_SEED = 2**32 - 1
 
124
  # Load LoRA weights
125
  gr.Info("Start to load LoRA ...")
126
  with calculateDuration("Unloading LoRA"):
127
+ txt2img_pipe.to(device)
128
  txt2img_pipe.unload_lora_weights()
129
  print(device)
130
  lora_configs = None
 
141
 
142
  with calculateDuration("Loading LoRA weights"):
143
  adapter_weights = []
144
+ txt2img_pipe.to(device)
145
  for idx, lora_info in enumerate(lora_configs):
146
  lora_repo = lora_info.get("repo")
147
  weights = lora_info.get("weights")
 
167
  try:
168
  gr.Info("Start to generate images ...")
169
  with calculateDuration(f"Make a new generator: {seed}"):
170
+ txt2img_pipe.to(device)
171
  generator = torch.Generator(device=device).manual_seed(seed)
172
  print(device)
173
  with calculateDuration("Generating image"):
174
  # Generate image
175
+ txt2img_pipe.to(device)
176
  joint_attention_kwargs = {"scale": 1}
177
  final_image = txt2img_pipe(
178
  prompt=prompt,