Hhblvjgvg commited on
Commit
afe3886
verified
1 Parent(s): f5f6d36

Update convert.py

Browse files
Files changed (1) hide show
  1. convert.py +21 -226
convert.py CHANGED
@@ -13,7 +13,6 @@ from huggingface_hub import HfApi, Repository, hf_hub_download
13
  from huggingface_hub.file_download import repo_folder_name
14
  from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file
15
 
16
- # Descripci贸n para el reporte en formato Markdown
17
  REPORT_DESCRIPTION = """
18
  Este es un reporte automatizado creado con una herramienta de conversi贸n personalizada.
19
 
@@ -25,52 +24,35 @@ https://colab.research.google.com/github/huggingface/notebooks/blob/main/safeten
25
 
26
  Los widgets en la p谩gina de tu modelo funcionar谩n usando este modelo, asegurando que el archivo realmente funcione.
27
 
28
- Si encuentras alg煤n problema: por favor rep贸rtalo en el siguiente enlace: https://huggingface.co/spaces/Hhblvjgvg/convert/discussions
29
 
30
  Si茅ntete libre de ignorar este reporte.
31
  """
32
 
33
- # Tipo de resultado de conversi贸n: Lista de archivos convertidos y lista de errores
34
  ConversionResult = Tuple[List[str], List[Tuple[str, "Exception"]]]
35
 
36
- def _remove_duplicate_names(
37
- state_dict: Dict[str, torch.Tensor],
38
- *,
39
- preferred_names: List[str] = None,
40
- discard_names: List[str] = None,
41
- ) -> Dict[str, List[str]]:
42
- """
43
- Elimina nombres duplicados en el state_dict bas谩ndose en las preferencias y nombres a descartar.
44
- """
45
  if preferred_names is None:
46
  preferred_names = []
47
  preferred_names = set(preferred_names)
48
  if discard_names is None:
49
  discard_names = []
50
  discard_names = set(discard_names)
51
-
52
  shareds = _find_shared_tensors(state_dict)
53
  to_remove = defaultdict(list)
54
  for shared in shareds:
55
  complete_names = set([name for name in shared if _is_complete(state_dict[name])])
56
  if not complete_names:
57
  if len(shared) == 1:
58
- # Forzar contig眉idad
59
  name = list(shared)[0]
60
  state_dict[name] = state_dict[name].clone()
61
  complete_names = {name}
62
  else:
63
- raise RuntimeError(
64
- f"Error al intentar encontrar nombres para remover al guardar el state dict, pero no se encontr贸 un nombre adecuado para mantener entre: {shared}. Ninguno cubre todo el almacenamiento. Rechazando guardar/cargar el modelo ya que podr铆as estar almacenando mucha m谩s memoria de la necesaria. Por favor, refi茅rete a https://huggingface.co/docs/safetensors/torch_shared_tensors para m谩s informaci贸n. O abre un issue."
65
- )
66
-
67
  keep_name = sorted(list(complete_names))[0]
68
-
69
- # Mecanismo para seleccionar preferentemente claves a mantener provenientes del archivo en disco
70
  preferred = complete_names.difference(discard_names)
71
  if preferred:
72
  keep_name = sorted(list(preferred))[0]
73
-
74
  if preferred_names:
75
  preferred = preferred_names.intersection(complete_names)
76
  if preferred:
@@ -81,66 +63,34 @@ def _remove_duplicate_names(
81
  return to_remove
82
 
83
  def get_discard_names(model_id: str, revision: Optional[str], folder: str, token: Optional[str]) -> List[str]:
84
- """
85
- Obtiene los nombres de pesos atados que deben ser descartados seg煤n la arquitectura del modelo.
86
- """
87
  try:
88
  import transformers
89
-
90
- config_filename = hf_hub_download(
91
- model_id, revision=revision, filename="config.json", token=token, cache_dir=folder
92
- )
93
  with open(config_filename, "r") as f:
94
  config = json.load(f)
95
  architecture = config["architectures"][0]
96
-
97
  class_ = getattr(transformers, architecture)
98
-
99
- # Nombre para esta variable depende de la versi贸n de transformers.
100
  discard_names = getattr(class_, "_tied_weights_keys", [])
101
-
102
  except Exception:
103
  discard_names = []
104
  return discard_names
105
 
106
  def check_file_size(sf_filename: str, pt_filename: str):
107
- """
108
- Verifica que la diferencia de tama帽o entre el archivo safetensors y el original sea menor al 1%.
109
- """
110
  sf_size = os.stat(sf_filename).st_size
111
  pt_size = os.stat(pt_filename).st_size
112
-
113
  if (sf_size - pt_size) / pt_size > 0.01:
114
- raise RuntimeError(
115
- f"""La diferencia de tama帽o de archivo es mayor al 1%:
116
- - {sf_filename}: {sf_size} bytes
117
- - {pt_filename}: {pt_size} bytes
118
- """
119
- )
120
 
121
  def rename(model_id: str, pt_filename: str) -> str:
122
- """
123
- Renombra el archivo PyTorch a safetensors usando el model_id para un mapeo autom谩tico.
124
- """
125
  filename, ext = os.path.splitext(pt_filename)
126
- # Extraer el nombre base del archivo sin directorios
127
  base_name = os.path.basename(filename)
128
- # Generar el nombre safetensors basado en el model_id y el nombre base
129
  safetensors_name = f"{model_id.replace('/', '_')}_{base_name}.safetensors"
130
  return safetensors_name
131
 
132
- def convert_multi(
133
- model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str]
134
- ) -> ConversionResult:
135
- """
136
- Convierte modelos con m煤ltiples archivos de pesos (multi-file).
137
- """
138
- filename = hf_hub_download(
139
- repo_id=model_id, revision=revision, filename="pytorch_model.bin.index.json", token=token, cache_dir=folder
140
- )
141
  with open(filename, "r") as f:
142
  data = json.load(f)
143
-
144
  filenames = set(data["weight_map"].values())
145
  local_filenames = []
146
  errors = []
@@ -153,8 +103,6 @@ def convert_multi(
153
  local_filenames.append(sf_filepath)
154
  except Exception as e:
155
  errors.append((filename, e))
156
-
157
- # Crear el archivo de 铆ndice para safetensors
158
  index = os.path.join(folder, f"{model_id.replace('/', '_')}_model.safetensors.index.json")
159
  try:
160
  with open(index, "w") as f:
@@ -165,19 +113,11 @@ def convert_multi(
165
  local_filenames.append(index)
166
  except Exception as e:
167
  errors.append((index, e))
168
-
169
  return local_filenames, errors
170
 
171
- def convert_single(
172
- model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str]
173
- ) -> ConversionResult:
174
- """
175
- Convierte un modelo con un 煤nico archivo de pesos.
176
- """
177
  try:
178
- pt_filename = hf_hub_download(
179
- repo_id=model_id, revision=revision, filename="pytorch_model.bin", token=token, cache_dir=folder
180
- )
181
  sf_name = rename(model_id, "pytorch_model.bin")
182
  sf_filepath = os.path.join(folder, sf_name)
183
  convert_file(pt_filename, sf_filepath, discard_names)
@@ -188,28 +128,18 @@ def convert_single(
188
  errors = [("pytorch_model.bin", e)]
189
  return local_filenames, errors
190
 
191
- def convert_file(
192
- pt_filename: str,
193
- sf_filename: str,
194
- discard_names: List[str],
195
- ):
196
- """
197
- Convierte un archivo de pesos de PyTorch a safetensors.
198
- """
199
  loaded = torch.load(pt_filename, map_location="cpu", weights_only=True)
200
  if "state_dict" in loaded:
201
  loaded = loaded["state_dict"]
202
  to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)
203
-
204
  metadata = {"format": "pt"}
205
  for kept_name, to_remove_group in to_removes.items():
206
  for to_remove in to_remove_group:
207
  if to_remove not in metadata:
208
  metadata[to_remove] = kept_name
209
  del loaded[to_remove]
210
- # Forzar que los tensores sean contiguos
211
  loaded = {k: v.contiguous() for k, v in loaded.items()}
212
-
213
  dirname = os.path.dirname(sf_filename)
214
  os.makedirs(dirname, exist_ok=True)
215
  save_file(loaded, sf_filename, metadata=metadata)
@@ -221,27 +151,17 @@ def convert_file(
221
  if not torch.equal(pt_tensor, sf_tensor):
222
  raise RuntimeError(f"Los tensores de salida no coinciden para la clave {k}")
223
 
224
- def convert_generic(
225
- model_id: str, *, revision: Optional[str], folder: str, filenames: Set[str], token: Optional[str]
226
- ) -> ConversionResult:
227
- """
228
- Convierte modelos que no utilizan la librer铆a Transformers o que tienen una estructura gen茅rica.
229
- """
230
  local_filenames = []
231
  errors = []
232
-
233
- # Agregar ".pth" a las extensiones soportadas
234
  extensions = set([".bin", ".ckpt", ".pth"])
235
  for filename in filenames:
236
  prefix, ext = os.path.splitext(filename)
237
  if ext in extensions:
238
  try:
239
- pt_filename = hf_hub_download(
240
- model_id, revision=revision, filename=filename, token=token, cache_dir=folder
241
- )
242
  dirname, raw_filename = os.path.split(filename)
243
  if raw_filename in {"pytorch_model.bin", "pytorch_model.pth"}:
244
- # Manejar casos especiales para transformers
245
  sf_in_repo = rename(model_id, raw_filename)
246
  else:
247
  sf_in_repo = rename(model_id, filename)
@@ -252,21 +172,9 @@ def convert_generic(
252
  errors.append((filename, e))
253
  return local_filenames, errors
254
 
255
- def prepare_target_repo_files(
256
- model_id: str,
257
- revision: Optional[str],
258
- folder: str,
259
- token: str,
260
- repo_dir: str
261
- ):
262
- """
263
- Prepara los archivos adicionales necesarios en el repositorio de destino.
264
- Descarga o crea archivos como .gitattributes, LICENSE.txt, README.md, USE_POLICY.md, config.json,
265
- generation_config.json, special_tokens_map.json, tokenizer.json, tokenizer_config.json.
266
- """
267
  api = HfApi()
268
  try:
269
- # Descargar archivos comunes del modelo original
270
  common_files = [
271
  ".gitattributes",
272
  "LICENSE.txt",
@@ -280,70 +188,44 @@ def prepare_target_repo_files(
280
  ]
281
  for file in common_files:
282
  try:
283
- file_path = hf_hub_download(
284
- repo_id=model_id,
285
- revision=revision,
286
- filename=file,
287
- token=token,
288
- cache_dir=folder
289
- )
290
  shutil.copy(file_path, repo_dir)
291
- print(f"Archivo descargado y copiado: {file}")
292
- except Exception as e:
293
- # Si el archivo no existe en el modelo original, crear uno vac铆o o con contenido por defecto
294
  if file == ".gitattributes":
295
  gitattributes_content = "model.safetensors filter=safetensors diff=safetensors merge=safetensors -text\n"
296
  with open(os.path.join(repo_dir, file), "w") as f:
297
  f.write(gitattributes_content)
298
- print(f"Archivo creado: {file} con configuraci贸n para Git LFS")
299
  elif file == "LICENSE.txt":
300
- # Crear un archivo LICENSE.txt gen茅rico o personalizado
301
  default_license = "MIT License\n\nCopyright (c) 2024"
302
  with open(os.path.join(repo_dir, file), "w") as f:
303
  f.write(default_license)
304
- print(f"Archivo creado: {file} con licencia por defecto")
305
  elif file == "README.md":
306
- # Crear un README.md gen茅rico
307
  readme_content = f"# {model_id.replace('/', ' ').title()}\n\nModelo convertido a safetensors."
308
  with open(os.path.join(repo_dir, file), "w") as f:
309
  f.write(readme_content)
310
- print(f"Archivo creado: {file} con contenido b谩sico de README")
311
  elif file == "USE_POLICY.md":
312
- # Crear un USE_POLICY.md gen茅rico
313
  use_policy_content = "### Pol铆tica de Uso\n\nEste modelo se distribuye bajo t茅rminos de uso est谩ndar."
314
  with open(os.path.join(repo_dir, file), "w") as f:
315
  f.write(use_policy_content)
316
- print(f"Archivo creado: {file} con pol铆tica de uso por defecto")
317
  elif file in {"config.json", "generation_config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"}:
318
- # Crear archivos JSON vac铆os o con contenido por defecto
319
  default_json_content = {}
320
  with open(os.path.join(repo_dir, file), "w") as f:
321
  json.dump(default_json_content, f, indent=4)
322
- print(f"Archivo creado: {file} con contenido JSON vac铆o")
323
- else:
324
- print(f"Error al manejar el archivo {file}: {e}")
325
  except Exception as e:
326
- print(f"Error al preparar archivos adicionales: {e}")
327
  raise e
328
 
329
  def generate_report(model_id: str, local_filenames: List[str], errors: List[Tuple[str, Exception]], output_md_path: str):
330
- """
331
- Genera un reporte en formato Markdown y JSON detallando los resultados de la conversi贸n.
332
- """
333
- # Generar reporte Markdown
334
  report_lines = [
335
  f"# Reporte de Conversi贸n para el Modelo `{model_id}`",
336
  f"Fecha y Hora: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
337
  "",
338
  "## Archivos Convertidos Exitosamente",
339
  ]
340
-
341
  if local_filenames:
342
  for filename in local_filenames:
343
  report_lines.append(f"- `{os.path.basename(filename)}`")
344
  else:
345
  report_lines.append("No se convirtieron archivos.")
346
-
347
  report_lines.append("")
348
  report_lines.append("## Errores Durante la Conversi贸n")
349
  if errors:
@@ -351,14 +233,9 @@ def generate_report(model_id: str, local_filenames: List[str], errors: List[Tupl
351
  report_lines.append(f"- **Archivo**: `{os.path.basename(filename)}`\n - **Error**: {error}")
352
  else:
353
  report_lines.append("No hubo errores durante la conversi贸n.")
354
-
355
  report_content_md = "\n".join(report_lines)
356
-
357
- # Guardar reporte Markdown
358
  with open(output_md_path, "w") as f:
359
  f.write(report_content_md)
360
-
361
- # Generar reporte JSON
362
  report_json = {
363
  "model_id": model_id,
364
  "timestamp": datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
@@ -366,31 +243,20 @@ def generate_report(model_id: str, local_filenames: List[str], errors: List[Tupl
366
  "errors": [{"file": os.path.basename(f), "error": str(e)} for f, e in errors],
367
  "description": REPORT_DESCRIPTION.strip()
368
  }
369
-
370
- # Guardar reporte JSON
371
  json_output_path = os.path.splitext(output_md_path)[0] + "_report.json"
372
  with open(json_output_path, "w") as f:
373
  json.dump(report_json, f, indent=4)
374
-
375
  print(f"Reportes generados en: {output_md_path} y {json_output_path}")
376
 
377
- def convert(
378
- model_id: str, revision: Optional[str] = None, force: bool = False, token: Optional[str] = None
379
- ) -> ConversionResult:
380
- """
381
- Orquesta la conversi贸n del modelo especificado.
382
- """
383
  api = HfApi()
384
  info = api.model_info(repo_id=model_id, revision=revision)
385
  filenames = set(s.rfilename for s in info.siblings)
386
-
387
  with TemporaryDirectory() as d:
388
  folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
389
  os.makedirs(folder, exist_ok=True)
390
  local_filenames = []
391
  errors = []
392
-
393
- # Verificar si ya existen archivos .safetensors y si no forzar
394
  if not force and any(filename.endswith(".safetensors") for filename in filenames):
395
  print(f"El modelo `{model_id}` ya tiene archivos `.safetensors` convertidos. Usando report existente o forzando con --force.")
396
  else:
@@ -398,32 +264,22 @@ def convert(
398
  if library_name == "transformers":
399
  discard_names = get_discard_names(model_id, revision=revision, folder=folder, token=token)
400
  if "pytorch_model.bin" in filenames or "pytorch_model.pth" in filenames:
401
- converted, conv_errors = convert_single(
402
- model_id, revision=revision, folder=folder, token=token, discard_names=discard_names
403
- )
404
  local_filenames.extend(converted)
405
  errors.extend(conv_errors)
406
  elif "pytorch_model.bin.index.json" in filenames:
407
- converted, conv_errors = convert_multi(
408
- model_id, revision=revision, folder=folder, token=token, discard_names=discard_names
409
- )
410
  local_filenames.extend(converted)
411
  errors.extend(conv_errors)
412
  else:
413
  print(f"El modelo `{model_id}` no parece ser un modelo v谩lido de PyTorch. No se puede convertir.")
414
  else:
415
- converted, conv_errors = convert_generic(
416
- model_id, revision=revision, folder=folder, filenames=filenames, token=token
417
- )
418
  local_filenames.extend(converted)
419
  errors.extend(conv_errors)
420
-
421
  return local_filenames, errors
422
 
423
  def read_token(token_file: Optional[str]) -> Optional[str]:
424
- """
425
- Lee el token de autenticaci贸n desde un archivo o variable de entorno.
426
- """
427
  if token_file:
428
  if os.path.isfile(token_file):
429
  with open(token_file, "r") as f:
@@ -433,74 +289,40 @@ def read_token(token_file: Optional[str]) -> Optional[str]:
433
  print(f"El archivo de token especificado no existe: {token_file}")
434
  return None
435
  else:
436
- # Leer desde variable de entorno
437
  return os.getenv("HF_TOKEN")
438
 
439
  def create_target_repo(model_id: str, api: HfApi, token: str) -> str:
440
- """
441
- Crea un nuevo repositorio en Hugging Face Hub bajo tu perfil para almacenar los archivos safetensors.
442
- """
443
- # Definir el nombre del nuevo repositorio
444
  target_repo_id = f"{api.whoami(token=token)['name']}/{model_id.replace('/', '_')}_safetensors"
445
-
446
  try:
447
- api.create_repo(
448
- name=f"{model_id.replace('/', '_')}_safetensors",
449
- repo_type="model",
450
- exist_ok=True,
451
- token=token
452
- )
453
  print(f"Repositorio creado o ya existente: {target_repo_id}")
454
  except Exception as e:
455
  print(f"Error al crear el repositorio `{target_repo_id}`: {e}")
456
  raise e
457
-
458
  return target_repo_id
459
 
460
  def upload_to_hf(local_filenames: List[str], target_repo_id: str, token: str, additional_files: List[str]):
461
- """
462
- Sube los archivos convertidos y archivos adicionales a un nuevo repositorio en Hugging Face Hub.
463
- """
464
- # Inicializar Repository
465
  repo_dir = "./temp_repo"
466
  if os.path.exists(repo_dir):
467
  shutil.rmtree(repo_dir)
468
  os.makedirs(repo_dir, exist_ok=True)
469
-
470
  try:
471
- # Clonar el repositorio vac铆o (crear uno nuevo)
472
- repo = Repository(
473
- local_dir=repo_dir,
474
- clone_from=target_repo_id,
475
- use_auth_token=token
476
- )
477
-
478
- # Copiar archivos convertidos al repositorio local
479
  for file_path in local_filenames:
480
  shutil.copy(file_path, repo_dir)
481
-
482
- # Copiar archivos adicionales al repositorio local
483
  for file_path in additional_files:
484
  shutil.copy(file_path, repo_dir)
485
-
486
- # A帽adir y hacer commit de los archivos
487
  repo.git_add(auto_lfs_track=True)
488
  repo.git_commit("A帽adiendo archivos safetensors convertidos")
489
- # Push al repositorio
490
  repo.git_push()
491
-
492
  print(f"Archivos subidos exitosamente al repositorio: {target_repo_id}")
493
  except Exception as e:
494
  print(f"Error al subir archivos al repositorio `{target_repo_id}`: {e}")
495
  raise e
496
  finally:
497
- # Limpiar el directorio temporal del repositorio
498
  shutil.rmtree(repo_dir)
499
 
500
  def main():
501
- """
502
- Funci贸n principal que maneja la interacci贸n con el usuario y coordina la conversi贸n, subida y generaci贸n de reportes.
503
- """
504
  DESCRIPTION = """
505
  Herramienta de utilidad simple para convertir autom谩ticamente algunos pesos en el hub al formato `safetensors`.
506
  Actualmente exclusiva para PyTorch.
@@ -547,14 +369,10 @@ def main():
547
  )
548
  args = parser.parse_args()
549
  model_id = args.model_id
550
-
551
- # Leer el token de autenticaci贸n
552
  token = read_token(args.token_file)
553
  if not token:
554
  print("No se proporcion贸 un token de autenticaci贸n v谩lido. Por favor, proporci贸nalo mediante --token-file o establece la variable de entorno 'HF_TOKEN'.")
555
  return
556
-
557
- # Inicializar HfApi con el token
558
  api = HfApi()
559
  try:
560
  user_info = api.whoami(token=token)
@@ -562,8 +380,6 @@ def main():
562
  except Exception as e:
563
  print(f"No se pudo autenticar con Hugging Face Hub: {e}")
564
  return
565
-
566
- # Confirmaci贸n de seguridad
567
  if args.y:
568
  proceed = True
569
  else:
@@ -573,48 +389,27 @@ def main():
573
  " 驴Continuar [Y/n] ? "
574
  )
575
  proceed = txt.lower() in {"", "y", "yes"}
576
-
577
  if proceed:
578
  try:
579
  with TemporaryDirectory() as d:
580
  folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
581
  os.makedirs(folder, exist_ok=True)
582
-
583
- # Realizar la conversi贸n
584
  local_filenames, errors = convert(model_id, revision=args.revision, force=args.force, token=token)
585
-
586
- # Crear el repositorio de destino en tu perfil
587
  target_repo_id = create_target_repo(model_id, api, token)
588
-
589
- # Preparar archivos adicionales en el repositorio local
590
  with TemporaryDirectory() as repo_temp_dir:
591
  prepare_target_repo_files(model_id, args.revision, folder, token, repo_temp_dir)
592
-
593
- # Obtener la lista de archivos adicionales
594
  additional_files = [os.path.join(repo_temp_dir, f) for f in os.listdir(repo_temp_dir)]
595
-
596
- # Subir los archivos convertidos y adicionales al repositorio de destino
597
  if local_filenames or additional_files:
598
  upload_to_hf(local_filenames, target_repo_id, token, additional_files)
599
  print(f"Archivos convertidos y adicionales subidos exitosamente a: {target_repo_id}")
600
  else:
601
  print("No hay archivos convertidos ni adicionales para subir.")
602
-
603
- # Definir la ruta de salida para el reporte
604
  output_md = args.output
605
  if args.output_json:
606
  output_json = args.output_json
607
  else:
608
  output_json = os.path.splitext(output_md)[0] + "_report.json"
609
-
610
- # Generar el reporte
611
  generate_report(model_id, local_filenames, errors, output_md)
612
-
613
- # Generar reporte JSON adicional si se especific贸
614
- if args.output_json:
615
- # Ya se ha generado en `generate_report`
616
- pass
617
-
618
  except Exception as e:
619
  print(f"Ocurri贸 un error inesperado: {e}")
620
  else:
 
13
  from huggingface_hub.file_download import repo_folder_name
14
  from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file
15
 
 
16
  REPORT_DESCRIPTION = """
17
  Este es un reporte automatizado creado con una herramienta de conversi贸n personalizada.
18
 
 
24
 
25
  Los widgets en la p谩gina de tu modelo funcionar谩n usando este modelo, asegurando que el archivo realmente funcione.
26
 
27
+ Si encuentras alg煤n problema: por favor rep贸rtalo en el siguiente enlace: https://huggingface.co/spaces/safetensors/convert/discussions
28
 
29
  Si茅ntete libre de ignorar este reporte.
30
  """
31
 
 
32
  ConversionResult = Tuple[List[str], List[Tuple[str, "Exception"]]]
33
 
34
+ def _remove_duplicate_names(state_dict: Dict[str, torch.Tensor], *, preferred_names: List[str] = None, discard_names: List[str] = None) -> Dict[str, List[str]]:
 
 
 
 
 
 
 
 
35
  if preferred_names is None:
36
  preferred_names = []
37
  preferred_names = set(preferred_names)
38
  if discard_names is None:
39
  discard_names = []
40
  discard_names = set(discard_names)
 
41
  shareds = _find_shared_tensors(state_dict)
42
  to_remove = defaultdict(list)
43
  for shared in shareds:
44
  complete_names = set([name for name in shared if _is_complete(state_dict[name])])
45
  if not complete_names:
46
  if len(shared) == 1:
 
47
  name = list(shared)[0]
48
  state_dict[name] = state_dict[name].clone()
49
  complete_names = {name}
50
  else:
51
+ raise RuntimeError(f"Error al intentar encontrar nombres para remover al guardar el state dict, pero no se encontr贸 un nombre adecuado para mantener entre: {shared}. Ninguno cubre todo el almacenamiento. Rechazando guardar/cargar el modelo ya que podr铆as estar almacenando mucha m谩s memoria de la necesaria. Por favor, refi茅rete a https://huggingface.co/docs/safetensors/torch_shared_tensors para m谩s informaci贸n. O abre un issue.")
 
 
 
52
  keep_name = sorted(list(complete_names))[0]
 
 
53
  preferred = complete_names.difference(discard_names)
54
  if preferred:
55
  keep_name = sorted(list(preferred))[0]
 
56
  if preferred_names:
57
  preferred = preferred_names.intersection(complete_names)
58
  if preferred:
 
63
  return to_remove
64
 
65
  def get_discard_names(model_id: str, revision: Optional[str], folder: str, token: Optional[str]) -> List[str]:
 
 
 
66
  try:
67
  import transformers
68
+ config_filename = hf_hub_download(model_id, revision=revision, filename="config.json", token=token, cache_dir=folder)
 
 
 
69
  with open(config_filename, "r") as f:
70
  config = json.load(f)
71
  architecture = config["architectures"][0]
 
72
  class_ = getattr(transformers, architecture)
 
 
73
  discard_names = getattr(class_, "_tied_weights_keys", [])
 
74
  except Exception:
75
  discard_names = []
76
  return discard_names
77
 
78
  def check_file_size(sf_filename: str, pt_filename: str):
 
 
 
79
  sf_size = os.stat(sf_filename).st_size
80
  pt_size = os.stat(pt_filename).st_size
 
81
  if (sf_size - pt_size) / pt_size > 0.01:
82
+ raise RuntimeError(f"La diferencia de tama帽o de archivo es mayor al 1%:\n - {sf_filename}: {sf_size} bytes\n - {pt_filename}: {pt_size} bytes")
 
 
 
 
 
83
 
84
  def rename(model_id: str, pt_filename: str) -> str:
 
 
 
85
  filename, ext = os.path.splitext(pt_filename)
 
86
  base_name = os.path.basename(filename)
 
87
  safetensors_name = f"{model_id.replace('/', '_')}_{base_name}.safetensors"
88
  return safetensors_name
89
 
90
+ def convert_multi(model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str]) -> ConversionResult:
91
+ filename = hf_hub_download(repo_id=model_id, revision=revision, filename="pytorch_model.bin.index.json", token=token, cache_dir=folder)
 
 
 
 
 
 
 
92
  with open(filename, "r") as f:
93
  data = json.load(f)
 
94
  filenames = set(data["weight_map"].values())
95
  local_filenames = []
96
  errors = []
 
103
  local_filenames.append(sf_filepath)
104
  except Exception as e:
105
  errors.append((filename, e))
 
 
106
  index = os.path.join(folder, f"{model_id.replace('/', '_')}_model.safetensors.index.json")
107
  try:
108
  with open(index, "w") as f:
 
113
  local_filenames.append(index)
114
  except Exception as e:
115
  errors.append((index, e))
 
116
  return local_filenames, errors
117
 
118
+ def convert_single(model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str]) -> ConversionResult:
 
 
 
 
 
119
  try:
120
+ pt_filename = hf_hub_download(repo_id=model_id, revision=revision, filename="pytorch_model.bin", token=token, cache_dir=folder)
 
 
121
  sf_name = rename(model_id, "pytorch_model.bin")
122
  sf_filepath = os.path.join(folder, sf_name)
123
  convert_file(pt_filename, sf_filepath, discard_names)
 
128
  errors = [("pytorch_model.bin", e)]
129
  return local_filenames, errors
130
 
131
+ def convert_file(pt_filename: str, sf_filename: str, discard_names: List[str]):
 
 
 
 
 
 
 
132
  loaded = torch.load(pt_filename, map_location="cpu", weights_only=True)
133
  if "state_dict" in loaded:
134
  loaded = loaded["state_dict"]
135
  to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)
 
136
  metadata = {"format": "pt"}
137
  for kept_name, to_remove_group in to_removes.items():
138
  for to_remove in to_remove_group:
139
  if to_remove not in metadata:
140
  metadata[to_remove] = kept_name
141
  del loaded[to_remove]
 
142
  loaded = {k: v.contiguous() for k, v in loaded.items()}
 
143
  dirname = os.path.dirname(sf_filename)
144
  os.makedirs(dirname, exist_ok=True)
145
  save_file(loaded, sf_filename, metadata=metadata)
 
151
  if not torch.equal(pt_tensor, sf_tensor):
152
  raise RuntimeError(f"Los tensores de salida no coinciden para la clave {k}")
153
 
154
+ def convert_generic(model_id: str, *, revision: Optional[str], folder: str, filenames: Set[str], token: Optional[str]) -> ConversionResult:
 
 
 
 
 
155
  local_filenames = []
156
  errors = []
 
 
157
  extensions = set([".bin", ".ckpt", ".pth"])
158
  for filename in filenames:
159
  prefix, ext = os.path.splitext(filename)
160
  if ext in extensions:
161
  try:
162
+ pt_filename = hf_hub_download(model_id, revision=revision, filename=filename, token=token, cache_dir=folder)
 
 
163
  dirname, raw_filename = os.path.split(filename)
164
  if raw_filename in {"pytorch_model.bin", "pytorch_model.pth"}:
 
165
  sf_in_repo = rename(model_id, raw_filename)
166
  else:
167
  sf_in_repo = rename(model_id, filename)
 
172
  errors.append((filename, e))
173
  return local_filenames, errors
174
 
175
+ def prepare_target_repo_files(model_id: str, revision: Optional[str], folder: str, token: str, repo_dir: str):
 
 
 
 
 
 
 
 
 
 
 
176
  api = HfApi()
177
  try:
 
178
  common_files = [
179
  ".gitattributes",
180
  "LICENSE.txt",
 
188
  ]
189
  for file in common_files:
190
  try:
191
+ file_path = hf_hub_download(repo_id=model_id, revision=revision, filename=file, token=token, cache_dir=folder)
 
 
 
 
 
 
192
  shutil.copy(file_path, repo_dir)
193
+ except Exception:
 
 
194
  if file == ".gitattributes":
195
  gitattributes_content = "model.safetensors filter=safetensors diff=safetensors merge=safetensors -text\n"
196
  with open(os.path.join(repo_dir, file), "w") as f:
197
  f.write(gitattributes_content)
 
198
  elif file == "LICENSE.txt":
 
199
  default_license = "MIT License\n\nCopyright (c) 2024"
200
  with open(os.path.join(repo_dir, file), "w") as f:
201
  f.write(default_license)
 
202
  elif file == "README.md":
 
203
  readme_content = f"# {model_id.replace('/', ' ').title()}\n\nModelo convertido a safetensors."
204
  with open(os.path.join(repo_dir, file), "w") as f:
205
  f.write(readme_content)
 
206
  elif file == "USE_POLICY.md":
 
207
  use_policy_content = "### Pol铆tica de Uso\n\nEste modelo se distribuye bajo t茅rminos de uso est谩ndar."
208
  with open(os.path.join(repo_dir, file), "w") as f:
209
  f.write(use_policy_content)
 
210
  elif file in {"config.json", "generation_config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"}:
 
211
  default_json_content = {}
212
  with open(os.path.join(repo_dir, file), "w") as f:
213
  json.dump(default_json_content, f, indent=4)
 
 
 
214
  except Exception as e:
 
215
  raise e
216
 
217
  def generate_report(model_id: str, local_filenames: List[str], errors: List[Tuple[str, Exception]], output_md_path: str):
 
 
 
 
218
  report_lines = [
219
  f"# Reporte de Conversi贸n para el Modelo `{model_id}`",
220
  f"Fecha y Hora: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
221
  "",
222
  "## Archivos Convertidos Exitosamente",
223
  ]
 
224
  if local_filenames:
225
  for filename in local_filenames:
226
  report_lines.append(f"- `{os.path.basename(filename)}`")
227
  else:
228
  report_lines.append("No se convirtieron archivos.")
 
229
  report_lines.append("")
230
  report_lines.append("## Errores Durante la Conversi贸n")
231
  if errors:
 
233
  report_lines.append(f"- **Archivo**: `{os.path.basename(filename)}`\n - **Error**: {error}")
234
  else:
235
  report_lines.append("No hubo errores durante la conversi贸n.")
 
236
  report_content_md = "\n".join(report_lines)
 
 
237
  with open(output_md_path, "w") as f:
238
  f.write(report_content_md)
 
 
239
  report_json = {
240
  "model_id": model_id,
241
  "timestamp": datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
 
243
  "errors": [{"file": os.path.basename(f), "error": str(e)} for f, e in errors],
244
  "description": REPORT_DESCRIPTION.strip()
245
  }
 
 
246
  json_output_path = os.path.splitext(output_md_path)[0] + "_report.json"
247
  with open(json_output_path, "w") as f:
248
  json.dump(report_json, f, indent=4)
 
249
  print(f"Reportes generados en: {output_md_path} y {json_output_path}")
250
 
251
+ def convert(model_id: str, revision: Optional[str] = None, force: bool = False, token: Optional[str] = None) -> ConversionResult:
 
 
 
 
 
252
  api = HfApi()
253
  info = api.model_info(repo_id=model_id, revision=revision)
254
  filenames = set(s.rfilename for s in info.siblings)
 
255
  with TemporaryDirectory() as d:
256
  folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
257
  os.makedirs(folder, exist_ok=True)
258
  local_filenames = []
259
  errors = []
 
 
260
  if not force and any(filename.endswith(".safetensors") for filename in filenames):
261
  print(f"El modelo `{model_id}` ya tiene archivos `.safetensors` convertidos. Usando report existente o forzando con --force.")
262
  else:
 
264
  if library_name == "transformers":
265
  discard_names = get_discard_names(model_id, revision=revision, folder=folder, token=token)
266
  if "pytorch_model.bin" in filenames or "pytorch_model.pth" in filenames:
267
+ converted, conv_errors = convert_single(model_id, revision=revision, folder=folder, token=token, discard_names=discard_names)
 
 
268
  local_filenames.extend(converted)
269
  errors.extend(conv_errors)
270
  elif "pytorch_model.bin.index.json" in filenames:
271
+ converted, conv_errors = convert_multi(model_id, revision=revision, folder=folder, token=token, discard_names=discard_names)
 
 
272
  local_filenames.extend(converted)
273
  errors.extend(conv_errors)
274
  else:
275
  print(f"El modelo `{model_id}` no parece ser un modelo v谩lido de PyTorch. No se puede convertir.")
276
  else:
277
+ converted, conv_errors = convert_generic(model_id, revision=revision, folder=folder, filenames=filenames, token=token)
 
 
278
  local_filenames.extend(converted)
279
  errors.extend(conv_errors)
 
280
  return local_filenames, errors
281
 
282
  def read_token(token_file: Optional[str]) -> Optional[str]:
 
 
 
283
  if token_file:
284
  if os.path.isfile(token_file):
285
  with open(token_file, "r") as f:
 
289
  print(f"El archivo de token especificado no existe: {token_file}")
290
  return None
291
  else:
 
292
  return os.getenv("HF_TOKEN")
293
 
294
  def create_target_repo(model_id: str, api: HfApi, token: str) -> str:
 
 
 
 
295
  target_repo_id = f"{api.whoami(token=token)['name']}/{model_id.replace('/', '_')}_safetensors"
 
296
  try:
297
+ api.create_repo(name=f"{model_id.replace('/', '_')}_safetensors", repo_type="model", exist_ok=True, token=token)
 
 
 
 
 
298
  print(f"Repositorio creado o ya existente: {target_repo_id}")
299
  except Exception as e:
300
  print(f"Error al crear el repositorio `{target_repo_id}`: {e}")
301
  raise e
 
302
  return target_repo_id
303
 
304
  def upload_to_hf(local_filenames: List[str], target_repo_id: str, token: str, additional_files: List[str]):
 
 
 
 
305
  repo_dir = "./temp_repo"
306
  if os.path.exists(repo_dir):
307
  shutil.rmtree(repo_dir)
308
  os.makedirs(repo_dir, exist_ok=True)
 
309
  try:
310
+ repo = Repository(local_dir=repo_dir, clone_from=target_repo_id, use_auth_token=token)
 
 
 
 
 
 
 
311
  for file_path in local_filenames:
312
  shutil.copy(file_path, repo_dir)
 
 
313
  for file_path in additional_files:
314
  shutil.copy(file_path, repo_dir)
 
 
315
  repo.git_add(auto_lfs_track=True)
316
  repo.git_commit("A帽adiendo archivos safetensors convertidos")
 
317
  repo.git_push()
 
318
  print(f"Archivos subidos exitosamente al repositorio: {target_repo_id}")
319
  except Exception as e:
320
  print(f"Error al subir archivos al repositorio `{target_repo_id}`: {e}")
321
  raise e
322
  finally:
 
323
  shutil.rmtree(repo_dir)
324
 
325
  def main():
 
 
 
326
  DESCRIPTION = """
327
  Herramienta de utilidad simple para convertir autom谩ticamente algunos pesos en el hub al formato `safetensors`.
328
  Actualmente exclusiva para PyTorch.
 
369
  )
370
  args = parser.parse_args()
371
  model_id = args.model_id
 
 
372
  token = read_token(args.token_file)
373
  if not token:
374
  print("No se proporcion贸 un token de autenticaci贸n v谩lido. Por favor, proporci贸nalo mediante --token-file o establece la variable de entorno 'HF_TOKEN'.")
375
  return
 
 
376
  api = HfApi()
377
  try:
378
  user_info = api.whoami(token=token)
 
380
  except Exception as e:
381
  print(f"No se pudo autenticar con Hugging Face Hub: {e}")
382
  return
 
 
383
  if args.y:
384
  proceed = True
385
  else:
 
389
  " 驴Continuar [Y/n] ? "
390
  )
391
  proceed = txt.lower() in {"", "y", "yes"}
 
392
  if proceed:
393
  try:
394
  with TemporaryDirectory() as d:
395
  folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
396
  os.makedirs(folder, exist_ok=True)
 
 
397
  local_filenames, errors = convert(model_id, revision=args.revision, force=args.force, token=token)
 
 
398
  target_repo_id = create_target_repo(model_id, api, token)
 
 
399
  with TemporaryDirectory() as repo_temp_dir:
400
  prepare_target_repo_files(model_id, args.revision, folder, token, repo_temp_dir)
 
 
401
  additional_files = [os.path.join(repo_temp_dir, f) for f in os.listdir(repo_temp_dir)]
 
 
402
  if local_filenames or additional_files:
403
  upload_to_hf(local_filenames, target_repo_id, token, additional_files)
404
  print(f"Archivos convertidos y adicionales subidos exitosamente a: {target_repo_id}")
405
  else:
406
  print("No hay archivos convertidos ni adicionales para subir.")
 
 
407
  output_md = args.output
408
  if args.output_json:
409
  output_json = args.output_json
410
  else:
411
  output_json = os.path.splitext(output_md)[0] + "_report.json"
 
 
412
  generate_report(model_id, local_filenames, errors, output_md)
 
 
 
 
 
 
413
  except Exception as e:
414
  print(f"Ocurri贸 un error inesperado: {e}")
415
  else: