Spaces:
Runtime error
Runtime error
Topallaj Denis
commited on
Commit
•
027aeb0
1
Parent(s):
84b598c
modified torch load
Browse files
main.py
CHANGED
@@ -68,7 +68,7 @@ class EndpointHandler():
|
|
68 |
self.vocab = WordVocab(vocab_content)
|
69 |
self.trfm = TrfmSeq2seq(len(self.vocab), 256, len(self.vocab), 4)
|
70 |
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
71 |
-
self.trfm.load_state_dict(torch.load(trfm_path
|
72 |
self.trfm.eval()
|
73 |
|
74 |
# path to the pretrained models
|
|
|
68 |
self.vocab = WordVocab(vocab_content)
|
69 |
self.trfm = TrfmSeq2seq(len(self.vocab), 256, len(self.vocab), 4)
|
70 |
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
71 |
+
self.trfm.load_state_dict(torch.load(trfm_path, map_location=device))
|
72 |
self.trfm.eval()
|
73 |
|
74 |
# path to the pretrained models
|