jiuface commited on
Commit
521316d
·
verified ·
1 Parent(s): c21899b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -30,7 +30,7 @@ import diffusers
30
 
31
  # init
32
  dtype = torch.bfloat16
33
- device = "cuda"
34
  base_model = "black-forest-labs/FLUX.1-dev"
35
 
36
  # load pipe
@@ -42,7 +42,7 @@ txt2img_pipe.__class__.load_lora_into_transformer = classmethod(load_lora_into_t
42
 
43
  # img2img model
44
  img2img_pipe = AutoPipelineForImage2Image.from_pretrained(base_model, vae=good_vae, transformer=txt2img_pipe.transformer, text_encoder=txt2img_pipe.text_encoder, tokenizer=txt2img_pipe.tokenizer, text_encoder_2=txt2img_pipe.text_encoder_2, tokenizer_2=txt2img_pipe.tokenizer_2, torch_dtype=dtype)
45
- # img2img_pipe.__class__.load_lora_into_transformer = classmethod(load_lora_into_transformer)
46
 
47
 
48
  MAX_SEED = 2**32 - 1
@@ -152,13 +152,16 @@ def generate_random_4_digit_string():
152
  def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, steps, randomize_seed, seed, width, height, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
153
  print("run_lora", prompt, lora_strings_json, cfg_scale, steps, width, height)
154
  gr.Info("Starting process")
155
- txt2img_pipe.to("cuda")
156
  torch.cuda.empty_cache()
157
  img2img_model = False
158
  orginal_image = None
159
  if image_url:
160
  orginal_image = load_image(image_url)
161
  img2img_model = True
 
 
 
162
  # Set random seed for reproducibility
163
  if randomize_seed:
164
  with calculateDuration("Set random seed"):
 
30
 
31
  # init
32
  dtype = torch.bfloat16
33
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
34
  base_model = "black-forest-labs/FLUX.1-dev"
35
 
36
  # load pipe
 
42
 
43
  # img2img model
44
  img2img_pipe = AutoPipelineForImage2Image.from_pretrained(base_model, vae=good_vae, transformer=txt2img_pipe.transformer, text_encoder=txt2img_pipe.text_encoder, tokenizer=txt2img_pipe.tokenizer, text_encoder_2=txt2img_pipe.text_encoder_2, tokenizer_2=txt2img_pipe.tokenizer_2, torch_dtype=dtype)
45
+ img2img_pipe.__class__.load_lora_into_transformer = classmethod(load_lora_into_transformer)
46
 
47
 
48
  MAX_SEED = 2**32 - 1
 
152
  def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, steps, randomize_seed, seed, width, height, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
153
  print("run_lora", prompt, lora_strings_json, cfg_scale, steps, width, height)
154
  gr.Info("Starting process")
155
+
156
  torch.cuda.empty_cache()
157
  img2img_model = False
158
  orginal_image = None
159
  if image_url:
160
  orginal_image = load_image(image_url)
161
  img2img_model = True
162
+ img2img_pipe.to(device)
163
+ else:
164
+ txt2img_pipe.to(device)
165
  # Set random seed for reproducibility
166
  if randomize_seed:
167
  with calculateDuration("Set random seed"):