John6666 commited on
Commit
650af8f
·
verified ·
1 Parent(s): 4641302

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -44
app.py CHANGED
@@ -4,8 +4,7 @@ import json
4
  import logging
5
  import torch
6
  from PIL import Image
7
- from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
8
- from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
9
  from diffusers import FluxControlNetPipeline, FluxControlNetModel, FluxMultiControlNetModel
10
  from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
11
  import copy
@@ -22,29 +21,16 @@ from flux import (search_civitai_lora, select_civitai_lora, search_civitai_lora_
22
  from tagger.tagger import predict_tags_wd, compose_prompt_to_copy
23
  from tagger.fl2flux import predict_tags_fl2_flux
24
 
25
- # Load LoRAs from JSON file
26
- with open('loras.json', 'r') as f:
27
- loras = json.load(f)
28
-
29
- dtype = torch.bfloat16
30
- #dtype = torch.float8_e4m3fn
31
- device = "cuda" if torch.cuda.is_available() else "cpu"
32
  # Initialize the base model
33
  base_model = models[0]
34
  controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union'
35
  #controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union-alpha'
36
- taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
37
- good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
38
- pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
39
  controlnet_union = None
40
  controlnet = None
41
  last_model = models[0]
42
  last_cn_on = False
43
 
44
- MAX_SEED = 2**32-1
45
-
46
- pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
47
-
48
  # https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union
49
  # https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union
50
  def change_base_model(repo_id: str, cn_on: bool):
@@ -53,6 +39,8 @@ def change_base_model(repo_id: str, cn_on: bool):
53
  global controlnet
54
  global last_model
55
  global last_cn_on
 
 
56
  try:
57
  if (repo_id == last_model and cn_on is last_cn_on) or not is_repo_name(repo_id) or not is_repo_exists(repo_id): return gr.update(visible=True)
58
  if cn_on:
@@ -62,7 +50,6 @@ def change_base_model(repo_id: str, cn_on: bool):
62
  controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=dtype)
63
  controlnet = FluxMultiControlNetModel([controlnet_union])
64
  pipe = FluxControlNetPipeline.from_pretrained(repo_id, controlnet=controlnet, torch_dtype=dtype)
65
- #pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
66
  last_model = repo_id
67
  last_cn_on = cn_on
68
  #progress(1, desc=f"Model loaded: {repo_id} / ControlNet Loaded: {controlnet_model_union_repo}")
@@ -71,8 +58,7 @@ def change_base_model(repo_id: str, cn_on: bool):
71
  #progress(0, desc=f"Loading model: {repo_id}")
72
  print(f"Loading model: {repo_id}")
73
  clear_cache()
74
- pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype, vae=taef1)
75
- pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
76
  last_model = repo_id
77
  last_cn_on = cn_on
78
  #progress(1, desc=f"Model loaded: {repo_id}")
@@ -84,6 +70,12 @@ def change_base_model(repo_id: str, cn_on: bool):
84
 
85
  change_base_model.zerogpu = True
86
 
 
 
 
 
 
 
87
  class calculateDuration:
88
  def __init__(self, activity_name=""):
89
  self.activity_name = activity_name
@@ -123,13 +115,9 @@ def update_selection(evt: gr.SelectData, width, height):
123
  @spaces.GPU(duration=70)
124
  def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, cn_on, progress=gr.Progress(track_tqdm=True)):
125
  global pipe
126
- global taef1
127
- global good_vae
128
  global controlnet
129
  global controlnet_union
130
  try:
131
- good_vae.to("cuda")
132
- taef1.to("cuda")
133
  pipe.to("cuda")
134
  generator = torch.Generator(device="cuda").manual_seed(seed)
135
 
@@ -138,7 +126,7 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scal
138
  modes, images, scales = get_control_params()
139
  if not cn_on or len(modes) == 0:
140
  progress(0, desc="Start Inference.")
141
- for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
142
  prompt=prompt_mash,
143
  num_inference_steps=steps,
144
  guidance_scale=cfg_scale,
@@ -146,15 +134,12 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scal
146
  height=height,
147
  generator=generator,
148
  joint_attention_kwargs={"scale": lora_scale},
149
- output_type="pil",
150
- good_vae=good_vae,
151
- ):
152
- yield img
153
  else:
154
  progress(0, desc="Start Inference with ControlNet.")
155
  if controlnet is not None: controlnet.to("cuda")
156
  if controlnet_union is not None: controlnet_union.to("cuda")
157
- for img in pipe(
158
  prompt=prompt_mash,
159
  control_image=images,
160
  control_mode=modes,
@@ -165,19 +150,15 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scal
165
  controlnet_conditioning_scale=scales,
166
  generator=generator,
167
  joint_attention_kwargs={"scale": lora_scale},
168
- ).images:
169
- yield img
170
  except Exception as e:
171
  print(e)
172
  raise gr.Error(f"Inference Error: {e}")
 
173
 
174
  def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
175
  lora_scale, lora_json, cn_on, progress=gr.Progress(track_tqdm=True)):
176
  global pipe
177
- global taef1
178
- global good_vae
179
- global controlnet
180
- global controlnet_union
181
  if selected_index is None and not is_valid_lora(lora_json):
182
  gr.Info("LoRA isn't selected.")
183
  # raise gr.Error("You must select a LoRA before proceeding.")
@@ -216,23 +197,17 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, wid
216
  seed = random.randint(0, MAX_SEED)
217
 
218
  progress(0, desc="Running Inference.")
219
- image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, cn_on, progress)
220
- # Consume the generator to get the final image
221
- final_image = None
222
- for image in image_generator:
223
- final_image = image
224
- yield image, seed # Yield intermediate images and seed
225
  if is_valid_lora(lora_json):
226
  pipe.unfuse_lora()
227
  pipe.unload_lora_weights()
228
  if selected_index is not None: pipe.unload_lora_weights()
229
  pipe.to("cpu")
230
- good_vae.to("cpu")
231
- taef1.to("cpu")
232
  if controlnet is not None: controlnet.to("cpu")
233
  if controlnet_union is not None: controlnet_union.to("cpu")
234
  clear_cache()
235
- return final_image, seed # Return the final image and seed
236
 
237
  def get_huggingface_safetensors(link):
238
  split_link = link.split("/")
 
4
  import logging
5
  import torch
6
  from PIL import Image
7
+ from diffusers import DiffusionPipeline
 
8
  from diffusers import FluxControlNetPipeline, FluxControlNetModel, FluxMultiControlNetModel
9
  from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
10
  import copy
 
21
  from tagger.tagger import predict_tags_wd, compose_prompt_to_copy
22
  from tagger.fl2flux import predict_tags_fl2_flux
23
 
 
 
 
 
 
 
 
24
  # Initialize the base model
25
  base_model = models[0]
26
  controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union'
27
  #controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union-alpha'
28
+ pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
 
 
29
  controlnet_union = None
30
  controlnet = None
31
  last_model = models[0]
32
  last_cn_on = False
33
 
 
 
 
 
34
  # https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union
35
  # https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union
36
  def change_base_model(repo_id: str, cn_on: bool):
 
39
  global controlnet
40
  global last_model
41
  global last_cn_on
42
+ dtype = torch.bfloat16
43
+ #dtype = torch.float8_e4m3fn
44
  try:
45
  if (repo_id == last_model and cn_on is last_cn_on) or not is_repo_name(repo_id) or not is_repo_exists(repo_id): return gr.update(visible=True)
46
  if cn_on:
 
50
  controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=dtype)
51
  controlnet = FluxMultiControlNetModel([controlnet_union])
52
  pipe = FluxControlNetPipeline.from_pretrained(repo_id, controlnet=controlnet, torch_dtype=dtype)
 
53
  last_model = repo_id
54
  last_cn_on = cn_on
55
  #progress(1, desc=f"Model loaded: {repo_id} / ControlNet Loaded: {controlnet_model_union_repo}")
 
58
  #progress(0, desc=f"Loading model: {repo_id}")
59
  print(f"Loading model: {repo_id}")
60
  clear_cache()
61
+ pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype)
 
62
  last_model = repo_id
63
  last_cn_on = cn_on
64
  #progress(1, desc=f"Model loaded: {repo_id}")
 
70
 
71
  change_base_model.zerogpu = True
72
 
73
+ # Load LoRAs from JSON file
74
+ with open('loras.json', 'r') as f:
75
+ loras = json.load(f)
76
+
77
+ MAX_SEED = 2**32-1
78
+
79
  class calculateDuration:
80
  def __init__(self, activity_name=""):
81
  self.activity_name = activity_name
 
115
  @spaces.GPU(duration=70)
116
  def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, cn_on, progress=gr.Progress(track_tqdm=True)):
117
  global pipe
 
 
118
  global controlnet
119
  global controlnet_union
120
  try:
 
 
121
  pipe.to("cuda")
122
  generator = torch.Generator(device="cuda").manual_seed(seed)
123
 
 
126
  modes, images, scales = get_control_params()
127
  if not cn_on or len(modes) == 0:
128
  progress(0, desc="Start Inference.")
129
+ image = pipe(
130
  prompt=prompt_mash,
131
  num_inference_steps=steps,
132
  guidance_scale=cfg_scale,
 
134
  height=height,
135
  generator=generator,
136
  joint_attention_kwargs={"scale": lora_scale},
137
+ ).images[0]
 
 
 
138
  else:
139
  progress(0, desc="Start Inference with ControlNet.")
140
  if controlnet is not None: controlnet.to("cuda")
141
  if controlnet_union is not None: controlnet_union.to("cuda")
142
+ image = pipe(
143
  prompt=prompt_mash,
144
  control_image=images,
145
  control_mode=modes,
 
150
  controlnet_conditioning_scale=scales,
151
  generator=generator,
152
  joint_attention_kwargs={"scale": lora_scale},
153
+ ).images[0]
 
154
  except Exception as e:
155
  print(e)
156
  raise gr.Error(f"Inference Error: {e}")
157
+ return image
158
 
159
  def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
160
  lora_scale, lora_json, cn_on, progress=gr.Progress(track_tqdm=True)):
161
  global pipe
 
 
 
 
162
  if selected_index is None and not is_valid_lora(lora_json):
163
  gr.Info("LoRA isn't selected.")
164
  # raise gr.Error("You must select a LoRA before proceeding.")
 
197
  seed = random.randint(0, MAX_SEED)
198
 
199
  progress(0, desc="Running Inference.")
200
+
201
+ image = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, cn_on, progress)
 
 
 
 
202
  if is_valid_lora(lora_json):
203
  pipe.unfuse_lora()
204
  pipe.unload_lora_weights()
205
  if selected_index is not None: pipe.unload_lora_weights()
206
  pipe.to("cpu")
 
 
207
  if controlnet is not None: controlnet.to("cpu")
208
  if controlnet_union is not None: controlnet_union.to("cpu")
209
  clear_cache()
210
+ return image, seed
211
 
212
  def get_huggingface_safetensors(link):
213
  split_link = link.split("/")