# train-wefadoor-master / strip_state.py
# 3v324v23 · commit a34661c ("276") · 388 bytes
# (header copied from the hosting site's file view)
"""Strip a training checkpoint down to just its ``state_dict``.

Run as a script to load ``DEFAULT_SRC``, keep only its ``'state_dict'`` entry
(or the whole loaded object, if it has no such key), and save the result to
``DEFAULT_DST``. Any other top-level entries of the checkpoint are dropped.
Both paths can be overridden on the command line::

    python strip_state.py [SRC] [DST]
"""
import argparse
from collections.abc import Mapping

import torch

DEFAULT_SRC = '/workspace/train-wefadoor-master/anydoor/step_276500.ckpt'
DEFAULT_DST = '/workspace/train-wefadoor-master/anydoor/step_276k.ckpt'


def extract_state_dict(checkpoint):
    """Return the state dict held by a loaded checkpoint object.

    If *checkpoint* is a mapping containing a ``'state_dict'`` key, that entry
    is returned. Otherwise *checkpoint* is assumed to already be a state dict
    and is returned unchanged.
    """
    # The isinstance guard keeps non-mapping objects (e.g. a pickled module)
    # from hitting an unsupported ``in`` test.
    if isinstance(checkpoint, Mapping) and 'state_dict' in checkpoint:
        return checkpoint['state_dict']
    return checkpoint


def strip_state(src=DEFAULT_SRC, dst=DEFAULT_DST, *, weights_only=False):
    """Load the checkpoint at *src* and save only its state dict to *dst*.

    Args:
        src: Path of the checkpoint file to read.
        dst: Path the extracted state dict is written to (overwritten if it
            already exists).
        weights_only: Forwarded to ``torch.load``. It is passed explicitly
            because torch's own default changed to ``True`` in 2.6. That
            setting rejects checkpoints that pickle non-tensor objects next to
            the weights. ``False`` unpickles arbitrary objects, so use it only
            on files you trust.

    Returns:
        The extracted state dict (the same object that was saved).
    """
    # map_location='cpu' lets a checkpoint saved from GPU load on any machine.
    checkpoint = torch.load(src, map_location='cpu', weights_only=weights_only)
    state_dict = extract_state_dict(checkpoint)
    torch.save(state_dict, dst)
    return state_dict


def main(argv=None):
    """Command-line entry point: ``strip_state.py [SRC] [DST]``."""
    parser = argparse.ArgumentParser(
        description='Save only the state_dict of a training checkpoint.')
    parser.add_argument('src', nargs='?', default=DEFAULT_SRC,
                        help='input checkpoint (default: %(default)s)')
    parser.add_argument('dst', nargs='?', default=DEFAULT_DST,
                        help='output state-dict file (default: %(default)s)')
    args = parser.parse_args(argv)
    strip_state(args.src, args.dst)


if __name__ == '__main__':
    main()