Victor Shirasuna commited on
Commit
8c39e88
·
1 Parent(s): 6f04789
smi-ted/finetune/smi_ted_large/load.py CHANGED
@@ -377,7 +377,7 @@ class Smi_ted(nn.Module):
377
  self.decoder = MoLDecoder(self.n_vocab, self.config['max_len'], self.config['n_embd'])
378
  self.net = Net(self.config['n_embd'], n_output=self.config['n_output'], dropout=self.config['dropout'])
379
 
380
- def load_checkpoint(self, ckpt_path, n_outputm eval=False):
381
  # load checkpoint file
382
  checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
383
 
 
377
  self.decoder = MoLDecoder(self.n_vocab, self.config['max_len'], self.config['n_embd'])
378
  self.net = Net(self.config['n_embd'], n_output=self.config['n_output'], dropout=self.config['dropout'])
379
 
380
+ def load_checkpoint(self, ckpt_path, n_output, eval=False):
381
  # load checkpoint file
382
  checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
383