Victor Shirasuna
commited on
Commit
·
8c39e88
1
Parent(s):
6f04789
Fix typo
Browse files
smi-ted/finetune/smi_ted_large/load.py
CHANGED
@@ -377,7 +377,7 @@ class Smi_ted(nn.Module):
|
|
377 |
self.decoder = MoLDecoder(self.n_vocab, self.config['max_len'], self.config['n_embd'])
|
378 |
self.net = Net(self.config['n_embd'], n_output=self.config['n_output'], dropout=self.config['dropout'])
|
379 |
|
380 |
-
def load_checkpoint(self, ckpt_path,
|
381 |
# load checkpoint file
|
382 |
checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
|
383 |
|
|
|
377 |
self.decoder = MoLDecoder(self.n_vocab, self.config['max_len'], self.config['n_embd'])
|
378 |
self.net = Net(self.config['n_embd'], n_output=self.config['n_output'], dropout=self.config['dropout'])
|
379 |
|
380 |
+
def load_checkpoint(self, ckpt_path, n_output, eval=False):
|
381 |
# load checkpoint file
|
382 |
checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
|
383 |
|