File size: 276 Bytes
587b6c9
 
04a69d4
 
 
 
 
 
587b6c9
04a69d4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
import os
import pickle
import shutil
import tempfile

import jax

CKPT_PATH = "./wavegru_vocoder_tpu_gta_preemphasis_pruning_0800000.ckpt"


def strip_optimizer_state(path=CKPT_PATH):
    """Remove the optimizer state from a pickled checkpoint, in place.

    The checkpoint at *path* is unpickled and its ``"optim_state_dict"`` entry
    is dropped. The remaining contents are pulled to host memory with
    ``jax.device_get`` and pickled back to *path*.

    The new contents are first written to a temporary file in the same
    directory and then moved over *path* with ``os.replace``. A failure while
    pickling therefore leaves the original checkpoint intact instead of
    truncated.

    Args:
        path: Checkpoint file to rewrite. It must come from a trusted source,
            because ``pickle.load`` can execute arbitrary code.

    Returns:
        True if an ``"optim_state_dict"`` entry was removed. False if the
        checkpoint had none (e.g. it was already stripped); the file is then
        left untouched.
    """
    # SECURITY: pickle.load runs arbitrary code; only use on trusted files.
    with open(path, "rb") as f:
        ckpt = pickle.load(f)

    if "optim_state_dict" not in ckpt:
        return False
    # Drop it before device_get so the optimizer state is never copied to host.
    del ckpt["optim_state_dict"]
    ckpt = jax.device_get(ckpt)

    # Write next to the target so os.replace stays on one filesystem (atomic).
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(ckpt, f)
        # mkstemp creates the file 0600; keep the original checkpoint's mode.
        shutil.copymode(path, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return True


if __name__ == "__main__":
    strip_optimizer_state()