Spaces:
Build error
Build error
Fix typo
Browse files- fromage/models.py +1 -1
fromage/models.py
CHANGED
@@ -644,7 +644,7 @@ def load_fromage(embeddings_dir: str, model_args_path: str, model_ckpt_path: str
|
|
644 |
model = model.bfloat16()
|
645 |
model = model.cuda()
|
646 |
|
647 |
-
Load pretrained linear mappings and [RET] embeddings.
|
648 |
checkpoint = torch.load(model_ckpt_path)
|
649 |
model.load_state_dict(checkpoint['state_dict'], strict=False)
|
650 |
with torch.no_grad():
|
|
|
644 |
model = model.bfloat16()
|
645 |
model = model.cuda()
|
646 |
|
647 |
+
# Load pretrained linear mappings and [RET] embeddings.
|
648 |
checkpoint = torch.load(model_ckpt_path)
|
649 |
model.load_state_dict(checkpoint['state_dict'], strict=False)
|
650 |
with torch.no_grad():
|