Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -20,6 +20,8 @@ import boto3
|
|
20 |
from io import BytesIO
|
21 |
import re
|
22 |
import json
|
|
|
|
|
23 |
|
24 |
# Login Hugging Face Hub
|
25 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
@@ -143,6 +145,9 @@ def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
|
|
143 |
print("upload thumbnail finish", thumbnail_file)
|
144 |
return image_file
|
145 |
|
|
|
|
|
|
|
146 |
def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, steps, randomize_seed, seed, width, height, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
|
147 |
print("run_lora", prompt, lora_strings_json, cfg_scale, steps, width, height)
|
148 |
gr.Info("Starting process")
|
@@ -165,7 +170,7 @@ def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, s
|
|
165 |
|
166 |
lora_configs = None
|
167 |
adapter_names = []
|
168 |
-
|
169 |
if lora_strings_json:
|
170 |
try:
|
171 |
lora_configs = json.loads(lora_strings_json)
|
@@ -174,30 +179,34 @@ def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, s
|
|
174 |
print("parse lora config json failed")
|
175 |
|
176 |
if lora_configs:
|
|
|
177 |
with calculateDuration("Loading LoRA weights"):
|
178 |
adapter_weights = []
|
179 |
-
|
|
|
180 |
lora_repo = lora_info.get("repo")
|
181 |
weights = lora_info.get("weights")
|
182 |
adapter_name = lora_info.get("adapter_name")
|
|
|
|
|
183 |
adapter_weight = lora_info.get("adapter_weight")
|
184 |
adapter_names.append(adapter_name)
|
185 |
adapter_weights.append(adapter_weight)
|
186 |
if lora_repo and weights and adapter_name:
|
187 |
try:
|
188 |
if img2img_model:
|
189 |
-
img2img_pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=
|
190 |
else:
|
191 |
-
txt2img_pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=
|
192 |
except:
|
193 |
print("load lora error")
|
194 |
|
195 |
# set lora weights
|
196 |
if len(adapter_names) > 0:
|
197 |
if img2img_model:
|
198 |
-
img2img_pipe.set_adapters(
|
199 |
else:
|
200 |
-
txt2img_pipe.set_adapters(
|
201 |
|
202 |
print(txt2img_pipe.get_active_adapters())
|
203 |
|
|
|
20 |
from io import BytesIO
|
21 |
import re
|
22 |
import json
|
23 |
+
import random
|
24 |
+
import string
|
25 |
|
26 |
# Login Hugging Face Hub
|
27 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
|
|
145 |
print("upload thumbnail finish", thumbnail_file)
|
146 |
return image_file
|
147 |
|
148 |
+
def generate_random_4_digit_string():
|
149 |
+
return ''.join(random.choices(string.digits, k=4))
|
150 |
+
|
151 |
def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, steps, randomize_seed, seed, width, height, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
|
152 |
print("run_lora", prompt, lora_strings_json, cfg_scale, steps, width, height)
|
153 |
gr.Info("Starting process")
|
|
|
170 |
|
171 |
lora_configs = None
|
172 |
adapter_names = []
|
173 |
+
lora_names = []
|
174 |
if lora_strings_json:
|
175 |
try:
|
176 |
lora_configs = json.loads(lora_strings_json)
|
|
|
179 |
print("parse lora config json failed")
|
180 |
|
181 |
if lora_configs:
|
182 |
+
|
183 |
with calculateDuration("Loading LoRA weights"):
|
184 |
adapter_weights = []
|
185 |
+
|
186 |
+
for idx, lora_info in enumerate(lora_configs):
|
187 |
lora_repo = lora_info.get("repo")
|
188 |
weights = lora_info.get("weights")
|
189 |
adapter_name = lora_info.get("adapter_name")
|
190 |
+
lora_name = f"lora_{generate_random_4_digit_string()}"
|
191 |
+
lora_names.append(load_image)
|
192 |
adapter_weight = lora_info.get("adapter_weight")
|
193 |
adapter_names.append(adapter_name)
|
194 |
adapter_weights.append(adapter_weight)
|
195 |
if lora_repo and weights and adapter_name:
|
196 |
try:
|
197 |
if img2img_model:
|
198 |
+
img2img_pipe.load_lora_weights(lora_repo, weight_name=weights, low_cpu_mem_usage=True, adapter_name=lora_name)
|
199 |
else:
|
200 |
+
txt2img_pipe.load_lora_weights(lora_repo, weight_name=weights, low_cpu_mem_usage=True, adapter_name=lora_name)
|
201 |
except:
|
202 |
print("load lora error")
|
203 |
|
204 |
# set lora weights
|
205 |
if len(adapter_names) > 0:
|
206 |
if img2img_model:
|
207 |
+
img2img_pipe.set_adapters(lora_names, adapter_weights=adapter_weights)
|
208 |
else:
|
209 |
+
txt2img_pipe.set_adapters(lora_names, adapter_weights=adapter_weights)
|
210 |
|
211 |
print(txt2img_pipe.get_active_adapters())
|
212 |
|