Topallaj Denis commited on
Commit
027aeb0
1 Parent(s): 84b598c

modified torch load

Browse files
Files changed (1) hide show
  1. main.py +1 -1
main.py CHANGED
@@ -68,7 +68,7 @@ class EndpointHandler():
68
  self.vocab = WordVocab(vocab_content)
69
  self.trfm = TrfmSeq2seq(len(self.vocab), 256, len(self.vocab), 4)
70
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
71
- self.trfm.load_state_dict(torch.load(trfm_path), map_location=device)
72
  self.trfm.eval()
73
 
74
  # path to the pretrained models
 
68
  self.vocab = WordVocab(vocab_content)
69
  self.trfm = TrfmSeq2seq(len(self.vocab), 256, len(self.vocab), 4)
70
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
71
+ self.trfm.load_state_dict(torch.load(trfm_path, map_location=device))
72
  self.trfm.eval()
73
 
74
  # path to the pretrained models