Jcrios commited on
Commit
ca92bd9
·
1 Parent(s): 23af708

carga de modelo

Browse files
Files changed (1) hide show
  1. utils.py +51 -3
utils.py CHANGED
@@ -1,10 +1,58 @@
1
  import numpy as np
2
  import torch
3
  from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
 
4
 
5
- def cargar_modelo(model_name = 'ceyda/butterfly_cropped_uniq1K_512', model_version = None):
6
- gan = LightweightGAN.from_pretrained(model_name, version=model_version)
7
- gan.eval
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  return gan
9
 
10
  def genera(gan, batch_size=1):
 
1
  import numpy as np
2
  import torch
3
  from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
4
+ from huggingface_hub import hf_hub_download
5
 
6
+ CONFIG_NAME = "config.json"
7
+ revision = None
8
+ cache_dir = None
9
+ force_download = False
10
+ proxies = None
11
+ resume_download = False
12
+ local_files_only = False
13
+ token = None
14
+
15
+ def carga_modelo(model_name="ceyda/butterfly_cropped_uniq1K_512"):
16
+ """
17
+ Loads a pre-trained LightweightGAN model from Hugging Face Model Hub.
18
+ Args:
19
+ model_name (str): The name of the pre-trained model to load. Defaults to "ceyda/butterfly_cropped_uniq1K_512".
20
+ model_version (str): The version of the pre-trained model to load. Defaults to None.
21
+ Returns:
22
+ LightweightGAN: The loaded pre-trained model.
23
+ """
24
+ # Load the config
25
+ config_file = hf_hub_download(
26
+ repo_id=str(model_name),
27
+ filename=CONFIG_NAME,
28
+ revision=revision,
29
+ cache_dir=cache_dir,
30
+ force_download=force_download,
31
+ proxies=proxies,
32
+ resume_download=resume_download,
33
+ token=token,
34
+ local_files_only=local_files_only,
35
+ )
36
+ with open(config_file, "r", encoding="utf-8") as f:
37
+ config = json.load(f)
38
+
39
+ # Call the _from_pretrained with all the needed arguments
40
+ gan = LightweightGAN(latent_dim=256, image_size=512)
41
+
42
+ gan = gan._from_pretrained(
43
+ model_id=str(model_name),
44
+ revision=revision,
45
+ cache_dir=cache_dir,
46
+ force_download=force_download,
47
+ proxies=proxies,
48
+ resume_download=resume_download,
49
+ local_files_only=local_files_only,
50
+ token=token,
51
+ use_auth_token=False,
52
+ config=config, # usually in **model_kwargs
53
+ )
54
+
55
+ gan.eval()
56
  return gan
57
 
58
  def genera(gan, batch_size=1):