John6666 committed on
Commit 09fcf0c · verified · 1 Parent(s): 8afe134

Upload 2 files

Files changed (2)
  1. app.py +3 -3
  2. mod.py +8 -4
app.py CHANGED
@@ -48,7 +48,7 @@ pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_ret
  # https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union
  # https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union
  @spaces.GPU()
- def change_base_model(repo_id: str, cn_on: bool, progress=gr.Progress(track_tqdm=True)): # , progress=gr.Progress(track_tqdm=True) # gradio.exceptions.Error: 'Model load Error: too many values to unpack (expected 2)'
+ def change_base_model(repo_id: str, cn_on: bool): # , progress=gr.Progress(track_tqdm=True) # gradio.exceptions.Error: 'Model load Error: too many values to unpack (expected 2)'
      global pipe
      global controlnet_union
      global controlnet
@@ -59,7 +59,7 @@ def change_base_model(repo_id: str, cn_on: bool, progress=gr.Progress(track_tqdm
      if cn_on:
          #progress(0, desc=f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
          print(f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
-         #clear_cache()
+         clear_cache()
          controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=dtype)#.to(device)
          controlnet = FluxMultiControlNetModel([controlnet_union])#.to(device)
          pipe = FluxControlNetPipeline.from_pretrained(repo_id, controlnet=controlnet, torch_dtype=dtype)#.to(device)
@@ -71,7 +71,7 @@ def change_base_model(repo_id: str, cn_on: bool, progress=gr.Progress(track_tqdm
      else:
          #progress(0, desc=f"Loading model: {repo_id}")
          print(f"Loading model: {repo_id}")
-         #clear_cache()
+         clear_cache()
          pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype)#, vae=taef1 .to(device)
          pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
          last_model = repo_id
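
For context (not part of the commit), a minimal runnable sketch of what the app.py change does at runtime: clear_cache() is now called before every model reload instead of staying commented out. The dtype and the inlined stub of clear_cache below are placeholders, not the Space's actual code; app.py presumably imports the real helper from mod.py.

# Sketch only: placeholders standing in for the Space's globals and imports.
import gc
import torch
from diffusers import DiffusionPipeline

dtype = torch.bfloat16  # placeholder; app.py defines its own dtype

def clear_cache():
    # Stand-in for the helper patched in mod.py (see the next diff).
    try:
        torch.cuda.empty_cache()
        gc.collect()
    except Exception as e:
        print(e)

def change_base_model(repo_id: str):
    # Free memory held by the previous pipeline before loading the next one.
    clear_cache()
    return DiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype)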
mod.py CHANGED
@@ -74,10 +74,13 @@ def is_repo_exists(repo_id):
 
 
  def clear_cache():
-     torch.cuda.empty_cache()
-     torch.cuda.reset_max_memory_allocated()
-     torch.cuda.reset_peak_memory_stats()
-     gc.collect()
+     try:
+         torch.cuda.empty_cache()
+         torch.cuda.reset_max_memory_allocated()
+         torch.cuda.reset_peak_memory_stats()
+         gc.collect()
+     except Exception as e:
+         print(e)
 
 
  def deselect_lora():
@@ -348,3 +351,4 @@ load_prompt_enhancer.zerogpu = True
  fuse_loras.zerogpu = True
  preprocess_image.zerogpu = True
  get_control_params.zerogpu = True
+ clear_cache.zerogpu = True
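
As a standalone sketch (imports added here, not in the commit), the patched helper reads as follows. The try/except keeps the reset calls from raising on machines where CUDA is unavailable or not yet initialized, and the trailing .zerogpu attribute mirrors the flag mod.py already sets on its other helpers.

import gc
import torch

def clear_cache():
    try:
        torch.cuda.empty_cache()                 # release unoccupied cached GPU memory
        torch.cuda.reset_max_memory_allocated()  # deprecated alias of reset_peak_memory_stats()
        torch.cuda.reset_peak_memory_stats()     # reset the peak-memory counters
        gc.collect()                             # drop unreachable Python objects
    except Exception as e:
        # Without a usable CUDA device the reset calls can raise; logging
        # instead of re-raising keeps model reloads from failing here.
        print(e)

clear_cache.zerogpu = True  # same attribute mod.py sets on its other GPU-bound helpers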