jiuface commited on
Commit
16c2491
·
verified ·
1 Parent(s): 422bc49

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -15
app.py CHANGED
@@ -20,18 +20,18 @@ from io import BytesIO
20
  import re
21
  import json
22
 
23
- # 登录 Hugging Face Hub
24
  HF_TOKEN = os.environ.get("HF_TOKEN")
25
  login(token=HF_TOKEN)
26
  import diffusers
27
  print(diffusers.__version__)
28
 
29
- # 初始化
30
- dtype = torch.float16 # 您可以根据需要调整数据类型
31
  device = "cuda" if torch.cuda.is_available() else "cpu"
32
- base_model = "black-forest-labs/FLUX.1-dev" # 替换为您的模型
33
 
34
- # 加载管道
35
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
36
 
37
  MAX_SEED = 2**32 - 1
@@ -56,10 +56,7 @@ class calculateDuration:
56
  else:
57
  print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
58
 
59
- print(f"Activity: {self.activity_name}, End time: {self.start_time_formatted}")
60
-
61
- # 生成图像的函数
62
- @spaces.GPU
63
  @torch.inference_mode()
64
  def generate_image(prompt, steps, seed, cfg_scale, width, height, progress):
65
  pipe.to(device)
@@ -121,10 +118,11 @@ def run_lora(prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed,
121
  adapter_name = lora_info.get("adapter_name")
122
  adapter_weight = lora_info.get("adapter_weight")
123
  if lora_repo and weights and adapter_name:
124
- # 加载 LoRA 权重
125
  pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
126
  adapter_names.append(adapter_name)
127
  adapter_weights.append(adapter_weight)
 
128
  pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
129
 
130
  # Set random seed for reproducibility
@@ -133,8 +131,11 @@ def run_lora(prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed,
133
  seed = random.randint(0, MAX_SEED)
134
 
135
  # Generate image
136
- final_image = generate_image(prompt, steps, seed, cfg_scale, width, height, progress)
137
-
 
 
 
138
  if final_image:
139
  if upload_to_r2:
140
  with calculateDuration("Upload image"):
@@ -142,12 +143,15 @@ def run_lora(prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed,
142
  result = {"status": "success", "message": "upload image success", "url": url}
143
  else:
144
  result = {"status": "success", "message": "Image generated but not uploaded"}
145
-
 
 
146
  progress(100, "Completed!")
147
 
148
  return final_image, seed, json.dumps(result)
149
 
150
- # Gradio 界面
 
151
  css="""
152
  #col-container {
153
  margin: 0 auto;
@@ -156,7 +160,7 @@ css="""
156
  """
157
 
158
  with gr.Blocks(css=css) as demo:
159
- gr.Markdown("Flux with LoRA")
160
  with gr.Row():
161
 
162
  with gr.Column():
 
20
  import re
21
  import json
22
 
23
+ # Login Hugging Face Hub
24
  HF_TOKEN = os.environ.get("HF_TOKEN")
25
  login(token=HF_TOKEN)
26
  import diffusers
27
  print(diffusers.__version__)
28
 
29
+ # init
30
+ dtype = torch.float16 # use float16 for fast generate
31
  device = "cuda" if torch.cuda.is_available() else "cpu"
32
+ base_model = "black-forest-labs/FLUX.1-dev"
33
 
34
+ # load pipe
35
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
36
 
37
  MAX_SEED = 2**32 - 1
 
56
  else:
57
  print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
58
 
59
+ @spaces.GPU(duration=120)
 
 
 
60
  @torch.inference_mode()
61
  def generate_image(prompt, steps, seed, cfg_scale, width, height, progress):
62
  pipe.to(device)
 
118
  adapter_name = lora_info.get("adapter_name")
119
  adapter_weight = lora_info.get("adapter_weight")
120
  if lora_repo and weights and adapter_name:
121
+ # load lora
122
  pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
123
  adapter_names.append(adapter_name)
124
  adapter_weights.append(adapter_weight)
125
+ # set lora weights
126
  pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
127
 
128
  # Set random seed for reproducibility
 
131
  seed = random.randint(0, MAX_SEED)
132
 
133
  # Generate image
134
+ try:
135
+ final_image = generate_image(prompt, steps, seed, cfg_scale, width, height, progress)
136
+ except:
137
+ final_image = None
138
+
139
  if final_image:
140
  if upload_to_r2:
141
  with calculateDuration("Upload image"):
 
143
  result = {"status": "success", "message": "upload image success", "url": url}
144
  else:
145
  result = {"status": "success", "message": "Image generated but not uploaded"}
146
+ else:
147
+ result = {"status": "failed", "message": "Image generate failed"}
148
+
149
  progress(100, "Completed!")
150
 
151
  return final_image, seed, json.dumps(result)
152
 
153
+ # Gradio interface
154
+
155
  css="""
156
  #col-container {
157
  margin: 0 auto;
 
160
  """
161
 
162
  with gr.Blocks(css=css) as demo:
163
+ gr.Markdown("flux-dev-multi-lora")
164
  with gr.Row():
165
 
166
  with gr.Column():