jiuface commited on
Commit
55e274c
·
verified ·
1 Parent(s): 35b4517

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -30,16 +30,16 @@ import diffusers
30
 
31
  # init
32
  dtype = torch.bfloat16
33
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
34
  print(device)
35
  base_model = "black-forest-labs/FLUX.1-dev"
36
 
37
  # load pipe
38
- taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
39
- good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
40
 
41
- txt2img_pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
42
  txt2img_pipe.__class__.load_lora_into_transformer = classmethod(load_lora_into_transformer)
 
43
 
44
 
45
  MAX_SEED = 2**32 - 1
@@ -108,6 +108,7 @@ def generate_random_4_digit_string():
108
  return ''.join(random.choices(string.digits, k=4))
109
 
110
  @spaces.GPU(duration=120)
 
111
  def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, steps, randomize_seed, seed, width, height, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
112
  print("run_lora", prompt, lora_strings_json, cfg_scale, steps, width, height)
113
  gr.Info("Starting process")
 
30
 
31
  # init
32
  dtype = torch.bfloat16
33
+ device = "cuda"
34
+
35
  print(device)
36
  base_model = "black-forest-labs/FLUX.1-dev"
37
 
38
  # load pipe
 
 
39
 
40
+ txt2img_pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype)
41
  txt2img_pipe.__class__.load_lora_into_transformer = classmethod(load_lora_into_transformer)
42
+ txt2img_pipe = txt2img_pipe.to(device)
43
 
44
 
45
  MAX_SEED = 2**32 - 1
 
108
  return ''.join(random.choices(string.digits, k=4))
109
 
110
  @spaces.GPU(duration=120)
111
+
112
  def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, steps, randomize_seed, seed, width, height, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
113
  print("run_lora", prompt, lora_strings_json, cfg_scale, steps, width, height)
114
  gr.Info("Starting process")