"""Strip a training checkpoint down to its model weights.

Loads a checkpoint file and keeps only its ``state_dict`` entry. If there is
no such entry, it keeps the whole loaded object. The result is saved to a new
file.
"""

import torch

# Default input/output paths used when the script is run directly.
SRC_PATH = '/workspace/mixed_wefa_unanon/step_300k.ckpt'
DST_PATH = 'step_300k_slim.ckpt'


def slim_checkpoint(src_path=SRC_PATH, dst_path=DST_PATH):
    """Save only the ``state_dict`` of the checkpoint at *src_path* to *dst_path*.

    Args:
        src_path: Path of the checkpoint to read.
        dst_path: Path the extracted object is written to.

    Returns:
        The object that was saved: ``checkpoint['state_dict']`` when the loaded
        checkpoint is a dict containing that key, otherwise the loaded object
        itself.
    """
    # map_location='cpu' remaps all storages to CPU, so loading does not need
    # the devices the checkpoint was saved from.
    # NOTE(review): torch.load's `weights_only` default differs across torch
    # versions; checkpoints holding arbitrary pickled objects (e.g. hyper-
    # parameters) may need weights_only=False, which is only safe for
    # trusted files.
    checkpoint = torch.load(src_path, map_location='cpu')

    # Only a mapping can be probed for the key; anything else (a bare tensor,
    # a module, ...) is treated as already being the weights.
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    torch.save(state_dict, dst_path)
    return state_dict


if __name__ == "__main__":
    slim_checkpoint()