File size: 338 Bytes
c59e741 |
1 2 3 4 5 6 7 8 9 10 11 |
import torch
# Strip a training checkpoint down to its model weights: load the checkpoint,
# keep only its ``state_dict`` entry (when it has one), and save that alone.

DEFAULT_SRC = '/workspace/mixed_wefa_unanon/step_300k.ckpt'
DEFAULT_DST = 'step_300k_slim.ckpt'


def slim_checkpoint(src=DEFAULT_SRC, dst=DEFAULT_DST, map_location='cpu'):
    """Load the checkpoint at ``src`` and save only its state dict to ``dst``.

    Args:
        src: Path of the checkpoint file to read.
        dst: Path the extracted object is written to (overwritten if present).
        map_location: Forwarded to ``torch.load``; the default ``'cpu'`` loads
            every tensor onto the CPU regardless of where it was saved from.

    Returns:
        The object that was written to ``dst``: ``checkpoint['state_dict']``
        when the loaded checkpoint is a dict containing that key, otherwise
        the loaded object unchanged (e.g. a file that already holds a bare
        state dict, a single tensor, or a pickled module).
    """
    # NOTE(review): torch.load unpickles the file, so only run this on
    # checkpoints you trust. Recent PyTorch releases default to
    # weights_only=True, which may reject checkpoints that pickle non-tensor
    # objects; if loading fails, pass weights_only explicitly (confirm against
    # the installed torch version).
    checkpoint = torch.load(src, map_location=map_location)
    # Check for a dict first: a bare tensor or module raises on ``'key' in obj``
    # instead of falling through to the save-as-is branch.
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    torch.save(state_dict, dst)
    return state_dict


def main(argv=None):
    """Command-line entry point.

    With no arguments, it uses the original fixed input and output paths.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='Save only the state_dict of a training checkpoint.')
    parser.add_argument('src', nargs='?', default=DEFAULT_SRC,
                        help='checkpoint to read (default: %(default)s)')
    parser.add_argument('dst', nargs='?', default=DEFAULT_DST,
                        help='output path (default: %(default)s)')
    args = parser.parse_args(argv)
    slim_checkpoint(args.src, args.dst)


if __name__ == '__main__':
    main()
|