Fabrice-TIERCELIN commited on
Commit
f1dff10
·
verified ·
1 Parent(s): d2f25e6

KO -> more comment

Browse files
Files changed (1) hide show
  1. app.py +15 -15
app.py CHANGED
@@ -51,21 +51,21 @@ class Tango:
51
  self.stft = TacotronSTFT(**stft_config).to(device)
52
  self.model = AudioDiffusion(**main_config).to(device)
53
 
54
- vae_weights = torch.load("{}/pytorch_model_vae.bin".format(path), map_location = device)
55
- stft_weights = torch.load("{}/pytorch_model_stft.bin".format(path), map_location = device)
56
- main_weights = torch.load("{}/pytorch_model_main.bin".format(path), map_location = device)
57
-
58
- self.vae.load_state_dict(vae_weights)
59
- self.stft.load_state_dict(stft_weights)
60
- self.model.load_state_dict(main_weights)
61
-
62
- print ("Successfully loaded checkpoint from:", name)
63
-
64
- self.vae.eval()
65
- self.stft.eval()
66
- self.model.eval()
67
-
68
- self.scheduler = DDPMScheduler.from_pretrained(main_config["scheduler_name"], subfolder = "scheduler")
69
 
70
  def chunks(self, lst, n):
71
  # Yield successive n-sized chunks from a list
 
51
  self.stft = TacotronSTFT(**stft_config).to(device)
52
  self.model = AudioDiffusion(**main_config).to(device)
53
 
54
+ # vae_weights = torch.load("{}/pytorch_model_vae.bin".format(path), map_location = device)
55
+ # stft_weights = torch.load("{}/pytorch_model_stft.bin".format(path), map_location = device)
56
+ # main_weights = torch.load("{}/pytorch_model_main.bin".format(path), map_location = device)
57
+ #
58
+ # self.vae.load_state_dict(vae_weights)
59
+ # self.stft.load_state_dict(stft_weights)
60
+ # self.model.load_state_dict(main_weights)
61
+ #
62
+ # print ("Successfully loaded checkpoint from:", name)
63
+ #
64
+ # self.vae.eval()
65
+ # self.stft.eval()
66
+ # self.model.eval()
67
+ #
68
+ # self.scheduler = DDPMScheduler.from_pretrained(main_config["scheduler_name"], subfolder = "scheduler")
69
 
70
  def chunks(self, lst, n):
71
  # Yield successive n-sized chunks from a list