John6666 committed on
Commit
465d252
·
verified ·
1 Parent(s): 6bc8c1e

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -5
app.py CHANGED
@@ -33,9 +33,9 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
33
  base_model = models[0]
34
  controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union'
35
  #controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union-alpha'
36
- taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
37
- good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
38
- pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
39
  controlnet_union = None
40
  controlnet = None
41
  last_model = models[0]
@@ -61,7 +61,7 @@ def change_base_model(repo_id: str, cn_on: bool):
61
  clear_cache()
62
  controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=dtype)
63
  controlnet = FluxMultiControlNetModel([controlnet_union])
64
- pipe = FluxControlNetPipeline.from_pretrained(repo_id, controlnet=controlnet, torch_dtype=dtype)
65
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
66
  last_model = repo_id
67
  last_cn_on = cn_on
@@ -71,7 +71,7 @@ def change_base_model(repo_id: str, cn_on: bool):
71
  #progress(0, desc=f"Loading model: {repo_id}")
72
  print(f"Loading model: {repo_id}")
73
  clear_cache()
74
- pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype)
75
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
76
  last_model = repo_id
77
  last_cn_on = cn_on
@@ -123,9 +123,13 @@ def update_selection(evt: gr.SelectData, width, height):
123
  @spaces.GPU(duration=70)
124
  def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, cn_on, progress=gr.Progress(track_tqdm=True)):
125
  global pipe
 
 
126
  global controlnet
127
  global controlnet_union
128
  try:
 
 
129
  pipe.to("cuda")
130
  generator = torch.Generator(device="cuda").manual_seed(seed)
131
 
@@ -172,6 +176,10 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scal
172
  def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
173
  lora_scale, lora_json, cn_on, progress=gr.Progress(track_tqdm=True)):
174
  global pipe
 
 
 
 
175
  if selected_index is None and not is_valid_lora(lora_json):
176
  gr.Info("LoRA isn't selected.")
177
  # raise gr.Error("You must select a LoRA before proceeding.")
@@ -221,6 +229,8 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, wid
221
  pipe.unload_lora_weights()
222
  if selected_index is not None: pipe.unload_lora_weights()
223
  pipe.to("cpu")
 
 
224
  if controlnet is not None: controlnet.to("cpu")
225
  if controlnet_union is not None: controlnet_union.to("cpu")
226
  clear_cache()
 
33
  base_model = models[0]
34
  controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union'
35
  #controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union-alpha'
36
+ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
37
+ good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype)
38
+ pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1)
39
  controlnet_union = None
40
  controlnet = None
41
  last_model = models[0]
 
61
  clear_cache()
62
  controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=dtype)
63
  controlnet = FluxMultiControlNetModel([controlnet_union])
64
+ pipe = FluxControlNetPipeline.from_pretrained(repo_id, controlnet=controlnet, torch_dtype=dtype, vae=taef1)
65
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
66
  last_model = repo_id
67
  last_cn_on = cn_on
 
71
  #progress(0, desc=f"Loading model: {repo_id}")
72
  print(f"Loading model: {repo_id}")
73
  clear_cache()
74
+ pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype, vae=taef1)
75
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
76
  last_model = repo_id
77
  last_cn_on = cn_on
 
123
  @spaces.GPU(duration=70)
124
  def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, cn_on, progress=gr.Progress(track_tqdm=True)):
125
  global pipe
126
+ global taef1
127
+ global good_vae
128
  global controlnet
129
  global controlnet_union
130
  try:
131
+ good_vae.to("cuda")
132
+ taef1.to("cuda")
133
  pipe.to("cuda")
134
  generator = torch.Generator(device="cuda").manual_seed(seed)
135
 
 
176
  def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
177
  lora_scale, lora_json, cn_on, progress=gr.Progress(track_tqdm=True)):
178
  global pipe
179
+ global taef1
180
+ global good_vae
181
+ global controlnet
182
+ global controlnet_union
183
  if selected_index is None and not is_valid_lora(lora_json):
184
  gr.Info("LoRA isn't selected.")
185
  # raise gr.Error("You must select a LoRA before proceeding.")
 
229
  pipe.unload_lora_weights()
230
  if selected_index is not None: pipe.unload_lora_weights()
231
  pipe.to("cpu")
232
+ good_vae.to("cpu")
233
+ taef1.to("cpu")
234
  if controlnet is not None: controlnet.to("cpu")
235
  if controlnet_union is not None: controlnet_union.to("cpu")
236
  clear_cache()