jiuface commited on
Commit
a209f33
·
verified ·
1 Parent(s): 9a9c4b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -6
app.py CHANGED
@@ -20,6 +20,8 @@ import boto3
20
  from io import BytesIO
21
  import re
22
  import json
 
 
23
 
24
  # Login Hugging Face Hub
25
  HF_TOKEN = os.environ.get("HF_TOKEN")
@@ -143,6 +145,9 @@ def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
143
  print("upload thumbnail finish", thumbnail_file)
144
  return image_file
145
 
 
 
 
146
  def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, steps, randomize_seed, seed, width, height, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
147
  print("run_lora", prompt, lora_strings_json, cfg_scale, steps, width, height)
148
  gr.Info("Starting process")
@@ -165,7 +170,7 @@ def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, s
165
 
166
  lora_configs = None
167
  adapter_names = []
168
-
169
  if lora_strings_json:
170
  try:
171
  lora_configs = json.loads(lora_strings_json)
@@ -174,30 +179,34 @@ def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, s
174
  print("parse lora config json failed")
175
 
176
  if lora_configs:
 
177
  with calculateDuration("Loading LoRA weights"):
178
  adapter_weights = []
179
- for lora_info in lora_configs:
 
180
  lora_repo = lora_info.get("repo")
181
  weights = lora_info.get("weights")
182
  adapter_name = lora_info.get("adapter_name")
 
 
183
  adapter_weight = lora_info.get("adapter_weight")
184
  adapter_names.append(adapter_name)
185
  adapter_weights.append(adapter_weight)
186
  if lora_repo and weights and adapter_name:
187
  try:
188
  if img2img_model:
189
- img2img_pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
190
  else:
191
- txt2img_pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
192
  except:
193
  print("load lora error")
194
 
195
  # set lora weights
196
  if len(adapter_names) > 0:
197
  if img2img_model:
198
- img2img_pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
199
  else:
200
- txt2img_pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
201
 
202
  print(txt2img_pipe.get_active_adapters())
203
 
 
20
  from io import BytesIO
21
  import re
22
  import json
23
+ import random
24
+ import string
25
 
26
  # Login Hugging Face Hub
27
  HF_TOKEN = os.environ.get("HF_TOKEN")
 
145
  print("upload thumbnail finish", thumbnail_file)
146
  return image_file
147
 
148
+ def generate_random_4_digit_string():
149
+ return ''.join(random.choices(string.digits, k=4))
150
+
151
  def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, steps, randomize_seed, seed, width, height, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
152
  print("run_lora", prompt, lora_strings_json, cfg_scale, steps, width, height)
153
  gr.Info("Starting process")
 
170
 
171
  lora_configs = None
172
  adapter_names = []
173
+ lora_names = []
174
  if lora_strings_json:
175
  try:
176
  lora_configs = json.loads(lora_strings_json)
 
179
  print("parse lora config json failed")
180
 
181
  if lora_configs:
182
+
183
  with calculateDuration("Loading LoRA weights"):
184
  adapter_weights = []
185
+
186
+ for idx, lora_info in enumerate(lora_configs):
187
  lora_repo = lora_info.get("repo")
188
  weights = lora_info.get("weights")
189
  adapter_name = lora_info.get("adapter_name")
190
+ lora_name = f"lora_{generate_random_4_digit_string()}"
191
+ lora_names.append(load_image)
192
  adapter_weight = lora_info.get("adapter_weight")
193
  adapter_names.append(adapter_name)
194
  adapter_weights.append(adapter_weight)
195
  if lora_repo and weights and adapter_name:
196
  try:
197
  if img2img_model:
198
+ img2img_pipe.load_lora_weights(lora_repo, weight_name=weights, low_cpu_mem_usage=True, adapter_name=lora_name)
199
  else:
200
+ txt2img_pipe.load_lora_weights(lora_repo, weight_name=weights, low_cpu_mem_usage=True, adapter_name=lora_name)
201
  except:
202
  print("load lora error")
203
 
204
  # set lora weights
205
  if len(adapter_names) > 0:
206
  if img2img_model:
207
+ img2img_pipe.set_adapters(lora_names, adapter_weights=adapter_weights)
208
  else:
209
+ txt2img_pipe.set_adapters(lora_names, adapter_weights=adapter_weights)
210
 
211
  print(txt2img_pipe.get_active_adapters())
212