# yuancwang
# init
# 5548515
# raw
# history blame
# 3.97 kB
# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
# ## Citations
# ```bibtex
# @inproceedings{yao2021wenet,
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
# booktitle={Proc. Interspeech},
# year={2021},
# address={Brno, Czech Republic},
# organization={IEEE}
# }
# @article{zhang2022wenet,
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
# journal={arXiv preprint arXiv:2203.15455},
# year={2022}
# }
# ```
import logging
import os
import re
import yaml
import torch
from collections import OrderedDict
import datetime
def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
    """Load weights from ``path`` into ``model`` and return the saved infos.

    The checkpoint is read with ``torch.load`` (mapped to CPU when CUDA is
    not available) and applied with ``strict=False``, so missing or
    unexpected keys do not raise.

    The accompanying info file is looked up next to the checkpoint: a
    trailing ``.pt`` is replaced by ``.yaml``; if ``path`` does not end in
    ``.pt``, ``.yaml`` is appended instead so the checkpoint itself is never
    parsed as YAML.

    Args:
        model: Module whose parameters are updated in place.
        path: Path of the checkpoint file.

    Returns:
        The parsed content of the YAML info file, or an empty dict when the
        file does not exist or is empty.
    """
    # NOTE(security): torch.load unpickles and yaml.FullLoader can construct
    # Python objects; only load checkpoints/info files from trusted sources.
    if torch.cuda.is_available():
        logging.info("Checkpoint: loading from checkpoint %s for GPU", path)
        checkpoint = torch.load(path)
    else:
        logging.info("Checkpoint: loading from checkpoint %s for CPU", path)
        checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint, strict=False)

    # Escape the dot so only a literal ".pt" suffix is rewritten
    # (the unescaped pattern turned e.g. "ckpt" into "c.yaml").
    info_path = re.sub(r"\.pt$", ".yaml", path)
    if info_path == path:
        # No ".pt" suffix: without this, the info path would be the
        # checkpoint file itself.
        info_path = path + ".yaml"

    if not os.path.exists(info_path):
        return {}
    with open(info_path, "r", encoding="utf-8") as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    # An empty YAML document parses to None; honour the dict return type.
    return configs if configs is not None else {}
def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
    """Save the model's state dict to ``path`` and write a YAML info file.

    For ``DataParallel`` / ``DistributedDataParallel`` wrappers the state
    dict of the wrapped ``model.module`` is saved.

    The info file is written next to the checkpoint: a trailing ``.pt`` is
    replaced by ``.yaml``; if ``path`` does not end in ``.pt``, ``.yaml`` is
    appended instead so the info file can never overwrite the checkpoint.

    Args:
        model: Module (optionally wrapped) whose weights are saved.
        path: Destination path of the checkpoint file.
        infos (dict or None): any info you want to save. A ``save_time``
            entry is added to the saved copy; the caller's dict is left
            untouched.
    """
    logging.info("Checkpoint: save to checkpoint %s", path)
    parallel_wrappers = (
        torch.nn.DataParallel,
        torch.nn.parallel.DistributedDataParallel,
    )
    if isinstance(model, parallel_wrappers):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    torch.save(state_dict, path)

    # Escape the dot so only a literal ".pt" suffix is rewritten
    # (the unescaped pattern turned e.g. "ckpt" into "c.yaml").
    info_path = re.sub(r"\.pt$", ".yaml", path)
    if info_path == path:
        # No ".pt" suffix: without this, the YAML dump below would
        # overwrite the checkpoint that was just saved.
        info_path = path + ".yaml"

    # Work on a copy so the caller's dict is not mutated.
    infos = {} if infos is None else dict(infos)
    infos["save_time"] = datetime.datetime.now().strftime("%d/%m/%Y %H:%M:%S")

    # Serialise before opening the file so a dump failure does not leave a
    # truncated info file behind.
    data = yaml.dump(infos)
    with open(info_path, "w", encoding="utf-8") as fout:
        fout.write(data)
def filter_modules(model_state_dict, modules):
    """Keep only the module prefixes that occur in ``model_state_dict``.

    A prefix is kept when at least one key of ``model_state_dict`` starts
    with it. Prefixes without any matching key are reported through
    ``logging.warning`` together with the available keys.

    Args:
        model_state_dict: Mapping whose keys are parameter names.
        modules: Iterable of key prefixes to check.

    Returns:
        List of the prefixes from ``modules`` that matched, in input order.
    """
    available_keys = model_state_dict.keys()
    matched, unmatched = [], []

    # Single pass over ``modules``: route each prefix to the right bucket.
    for prefix in modules:
        has_match = any(name.startswith(prefix) for name in available_keys)
        bucket = matched if has_match else unmatched
        bucket.append(prefix)

    if unmatched:
        logging.warning(
            "module(s) %s don't match or (partially match) "
            "available modules in model.",
            unmatched,
        )
        logging.warning("for information, the existing modules in model are:")
        logging.warning("%s", available_keys)

    return matched
def load_trained_modules(model: torch.nn.Module, args: None):
    """Initialise parts of ``model`` from a pre-trained checkpoint.

    Reads the checkpoint path from ``args.enc_init`` and the key prefixes
    to copy from ``args.enc_init_mods``. When the checkpoint file exists,
    every entry whose key starts with one of the prefixes accepted by
    ``filter_modules`` replaces the matching entry of the model's current
    state dict. When it does not exist, a warning is logged and the model's
    own state dict is loaded back unchanged.

    Args:
        model: Module that receives the pre-trained weights.
        args: Object exposing ``enc_init`` and ``enc_init_mods`` attributes.

    Returns:
        An empty dict.
    """
    # Load encoder modules with pre-trained model(s).
    ckpt_path = args.enc_init
    requested_prefixes = args.enc_init_mods
    merged_state = model.state_dict()

    logging.warning("model(s) found for pre-initialization")
    if not os.path.isfile(ckpt_path):
        logging.warning("model was not found : %s", ckpt_path)
    else:
        logging.info("Checkpoint: loading from checkpoint %s for CPU" % ckpt_path)
        pretrained = torch.load(ckpt_path, map_location="cpu")
        # str.startswith accepts a tuple of prefixes; an empty tuple
        # matches nothing, like any() over an empty list.
        prefixes = tuple(filter_modules(pretrained, requested_prefixes))
        merged_state.update(
            (name, tensor)
            for name, tensor in pretrained.items()
            if name.startswith(prefixes)
        )

    model.load_state_dict(merged_state)
    return {}