3v324v23 commited on
Commit
c59e741
·
1 Parent(s): 421cdfe
Files changed (1) hide show
  1. strip_state.py +10 -0
strip_state.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ # Load the checkpoint
4
+ checkpoint = torch.load('/workspace/mixed_wefa_unanon/step_300k.ckpt', map_location='cpu')
5
+
6
+ # Extract the state dictionary
7
+ state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
8
+
9
+ # Save the state dictionary to a new checkpoint
10
+ torch.save(state_dict, 'step_300k_slim.ckpt')