Spaces:
Running
on
Zero
Running
on
Zero
Make sure to always load the highest trained safetensors file for all cases (#36)
Browse files- Make sure to always load the highest trained safetensors file for all cases (c576d4ca850474f1cbb84f4ca2ff8ff449ea1f68)
Co-authored-by: Sylvain Filoni <[email protected]>
app.py
CHANGED
@@ -12,6 +12,7 @@ from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_d
|
|
12 |
import copy
|
13 |
import random
|
14 |
import time
|
|
|
15 |
|
16 |
# Load LoRAs from JSON file
|
17 |
with open('loras.json', 'r') as f:
|
@@ -172,30 +173,73 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
|
|
172 |
yield final_image, seed, gr.update(value=progress_bar, visible=False)
|
173 |
|
174 |
def get_huggingface_safetensors(link):
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
199 |
|
200 |
def check_custom_model(link):
|
201 |
if(link.startswith("https://")):
|
|
|
12 |
import copy
|
13 |
import random
|
14 |
import time
|
15 |
+
import re
|
16 |
|
17 |
# Load LoRAs from JSON file
|
18 |
with open('loras.json', 'r') as f:
|
|
|
173 |
yield final_image, seed, gr.update(value=progress_bar, visible=False)
|
174 |
|
175 |
def get_huggingface_safetensors(link):
|
176 |
+
split_link = link.split("/")
|
177 |
+
if len(split_link) != 2:
|
178 |
+
raise Exception("Invalid Hugging Face repository link format.")
|
179 |
+
|
180 |
+
# Load model card
|
181 |
+
model_card = ModelCard.load(link)
|
182 |
+
base_model = model_card.data.get("base_model")
|
183 |
+
print(base_model)
|
184 |
+
|
185 |
+
# Validate model type
|
186 |
+
if base_model not in {"black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"}:
|
187 |
+
raise Exception("Not a FLUX LoRA!")
|
188 |
+
|
189 |
+
# Extract image and trigger word
|
190 |
+
image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
|
191 |
+
trigger_word = model_card.data.get("instance_prompt", "")
|
192 |
+
image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
|
193 |
+
|
194 |
+
# Initialize Hugging Face file system
|
195 |
+
fs = HfFileSystem()
|
196 |
+
try:
|
197 |
+
list_of_files = fs.ls(link, detail=False)
|
198 |
+
|
199 |
+
# Initialize variables for safetensors selection
|
200 |
+
safetensors_name = None
|
201 |
+
highest_trained_file = None
|
202 |
+
highest_steps = -1
|
203 |
+
last_safetensors_file = None
|
204 |
+
step_pattern = re.compile(r"_0{3,}\d+") # Detects step count `_000...`
|
205 |
+
|
206 |
+
for file in list_of_files:
|
207 |
+
filename = file.split("/")[-1]
|
208 |
+
|
209 |
+
# Select safetensors file
|
210 |
+
if filename.endswith(".safetensors"):
|
211 |
+
last_safetensors_file = filename # Track last encountered file
|
212 |
+
|
213 |
+
match = step_pattern.search(filename)
|
214 |
+
if not match:
|
215 |
+
# Found a full model without step numbers, return immediately
|
216 |
+
safetensors_name = filename
|
217 |
+
break
|
218 |
+
else:
|
219 |
+
# Extract step count and track highest
|
220 |
+
steps = int(match.group().lstrip("_"))
|
221 |
+
if steps > highest_steps:
|
222 |
+
highest_trained_file = filename
|
223 |
+
highest_steps = steps
|
224 |
+
|
225 |
+
# Select an image file if not found in model card
|
226 |
+
if not image_url and filename.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
|
227 |
+
image_url = f"https://huggingface.co/{link}/resolve/main/{filename}"
|
228 |
+
|
229 |
+
# If no full model found, fall back to the most trained safetensors file
|
230 |
+
if not safetensors_name:
|
231 |
+
safetensors_name = highest_trained_file if highest_trained_file else last_safetensors_file
|
232 |
+
|
233 |
+
# If still no safetensors file found, raise an exception
|
234 |
+
if not safetensors_name:
|
235 |
+
raise Exception("No valid *.safetensors file found in the repository.")
|
236 |
+
|
237 |
+
except Exception as e:
|
238 |
+
print(e)
|
239 |
+
raise Exception("You didn't include a valid Hugging Face repository with a *.safetensors LoRA")
|
240 |
+
|
241 |
+
return split_link[1], link, safetensors_name, trigger_word, image_url
|
242 |
+
|
243 |
|
244 |
def check_custom_model(link):
|
245 |
if(link.startswith("https://")):
|