multimodalart HF staff fffiloni commited on
Commit
3c7efe2
·
verified ·
1 Parent(s): 717456e

Make sure to always load the highest trained safetensors file for all cases (#36)

Browse files

- Make sure to always load the highest trained safetensors file for all cases (c576d4ca850474f1cbb84f4ca2ff8ff449ea1f68)


Co-authored-by: Sylvain Filoni <[email protected]>

Files changed (1) hide show
  1. app.py +68 -24
app.py CHANGED
@@ -12,6 +12,7 @@ from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_d
12
  import copy
13
  import random
14
  import time
 
15
 
16
  # Load LoRAs from JSON file
17
  with open('loras.json', 'r') as f:
@@ -172,30 +173,73 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
172
  yield final_image, seed, gr.update(value=progress_bar, visible=False)
173
 
174
  def get_huggingface_safetensors(link):
175
- split_link = link.split("/")
176
- if(len(split_link) == 2):
177
- model_card = ModelCard.load(link)
178
- base_model = model_card.data.get("base_model")
179
- print(base_model)
180
- if((base_model != "black-forest-labs/FLUX.1-dev") and (base_model != "black-forest-labs/FLUX.1-schnell")):
181
- raise Exception("Not a FLUX LoRA!")
182
- image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
183
- trigger_word = model_card.data.get("instance_prompt", "")
184
- image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
185
- fs = HfFileSystem()
186
- try:
187
- list_of_files = fs.ls(link, detail=False)
188
- for file in list_of_files:
189
- if(file.endswith(".safetensors")):
190
- safetensors_name = file.split("/")[-1]
191
- if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
192
- image_elements = file.split("/")
193
- image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
194
- except Exception as e:
195
- print(e)
196
- gr.Warning(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
197
- raise Exception(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
198
- return split_link[1], link, safetensors_name, trigger_word, image_url
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
  def check_custom_model(link):
201
  if(link.startswith("https://")):
 
12
  import copy
13
  import random
14
  import time
15
+ import re
16
 
17
  # Load LoRAs from JSON file
18
  with open('loras.json', 'r') as f:
 
173
  yield final_image, seed, gr.update(value=progress_bar, visible=False)
174
 
175
  def get_huggingface_safetensors(link):
176
+ split_link = link.split("/")
177
+ if len(split_link) != 2:
178
+ raise Exception("Invalid Hugging Face repository link format.")
179
+
180
+ # Load model card
181
+ model_card = ModelCard.load(link)
182
+ base_model = model_card.data.get("base_model")
183
+ print(base_model)
184
+
185
+ # Validate model type
186
+ if base_model not in {"black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"}:
187
+ raise Exception("Not a FLUX LoRA!")
188
+
189
+ # Extract image and trigger word
190
+ image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
191
+ trigger_word = model_card.data.get("instance_prompt", "")
192
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
193
+
194
+ # Initialize Hugging Face file system
195
+ fs = HfFileSystem()
196
+ try:
197
+ list_of_files = fs.ls(link, detail=False)
198
+
199
+ # Initialize variables for safetensors selection
200
+ safetensors_name = None
201
+ highest_trained_file = None
202
+ highest_steps = -1
203
+ last_safetensors_file = None
204
+ step_pattern = re.compile(r"_0{3,}\d+") # Detects step count `_000...`
205
+
206
+ for file in list_of_files:
207
+ filename = file.split("/")[-1]
208
+
209
+ # Select safetensors file
210
+ if filename.endswith(".safetensors"):
211
+ last_safetensors_file = filename # Track last encountered file
212
+
213
+ match = step_pattern.search(filename)
214
+ if not match:
215
+ # Found a full model without step numbers, return immediately
216
+ safetensors_name = filename
217
+ break
218
+ else:
219
+ # Extract step count and track highest
220
+ steps = int(match.group().lstrip("_"))
221
+ if steps > highest_steps:
222
+ highest_trained_file = filename
223
+ highest_steps = steps
224
+
225
+ # Select an image file if not found in model card
226
+ if not image_url and filename.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
227
+ image_url = f"https://huggingface.co/{link}/resolve/main/{filename}"
228
+
229
+ # If no full model found, fall back to the most trained safetensors file
230
+ if not safetensors_name:
231
+ safetensors_name = highest_trained_file if highest_trained_file else last_safetensors_file
232
+
233
+ # If still no safetensors file found, raise an exception
234
+ if not safetensors_name:
235
+ raise Exception("No valid *.safetensors file found in the repository.")
236
+
237
+ except Exception as e:
238
+ print(e)
239
+ raise Exception("You didn't include a valid Hugging Face repository with a *.safetensors LoRA")
240
+
241
+ return split_link[1], link, safetensors_name, trigger_word, image_url
242
+
243
 
244
  def check_custom_model(link):
245
  if(link.startswith("https://")):