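# NOTE(review): the three lines below look like residue from the hosting
# page's header, not source code; kept as comments so the module parses.
# Spaces:
# Runtime error
# Runtime error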
import datetime
import hashlib
import json
import os
import sys
from collections import OrderedDict

import torch
# Working directory at import time. It is appended to sys.path (presumably so
# project-local modules resolve when run from the project root) and is used
# to locate assets/config.json. Only added once to avoid duplicate entries.
now_dir = os.getcwd()
if now_dir not in sys.path:
    sys.path.append(now_dir)
def replace_keys_in_dict(d, old_key_part, new_key_part):
    """Return a copy of *d* with *old_key_part* replaced by *new_key_part* in keys.

    Nested ``dict`` values (including ``OrderedDict``) are processed
    recursively; all other values are carried over as-is (not copied).
    An ``OrderedDict`` input yields an ``OrderedDict``; any other mapping
    yields a plain ``dict``. Keys that are not strings are kept unchanged
    instead of failing on ``.replace``. If two keys map to the same new key,
    the one visited later wins.

    Args:
        d: Mapping to transform; must provide ``.items()``.
        old_key_part: Substring to search for in each string key.
        new_key_part: Replacement substring.

    Returns:
        A new mapping with the renamed keys.
    """
    updated_dict = OrderedDict() if isinstance(d, OrderedDict) else {}
    for key, value in d.items():
        # Substring replacement only makes sense for str keys; leave
        # ints, tuples, bytes, etc. untouched.
        new_key = (
            key.replace(old_key_part, new_key_part) if isinstance(key, str) else key
        )
        if isinstance(value, dict):
            value = replace_keys_in_dict(value, old_key_part, new_key_part)
        updated_dict[new_key] = value
    return updated_dict
def _read_json_value(path, key):
    """Return ``json.load(path).get(key)``, or ``None`` if *path* does not exist.

    Args:
        path: Path to a UTF-8 JSON file whose top level is an object.
        key: Top-level key to look up.
    """
    if not os.path.exists(path):
        return None
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f).get(key, None)


def extract_model(
    ckpt,
    sr,
    pitch_guidance,
    name,
    model_dir,
    epoch,
    step,
    version,
    hps,
    overtrain_info,
):
    """Save a half-precision weights file with metadata, then convert *model_dir*.

    Steps:
      1. Build an ``OrderedDict`` holding ``weight`` (every entry of *ckpt*
         whose key does not contain ``"enc_q"``, cast with ``.half()``),
         ``config`` (hyper-parameters read from *hps*), and metadata (epoch,
         step, sr, f0, version, creation date, a SHA-256 hash, overtraining
         info, dataset length, model name, author). Save it in the directory
         of *model_dir* as ``{name}_{epoch}e_{step}s[_best_epoch].pth``.
      2. Load the file at *model_dir* and rename
         ``.parametrizations.weight.original1`` -> ``.weight_v`` and
         ``.parametrizations.weight.original0`` -> ``.weight_g`` in its keys.
         Replace *model_dir* with the result via a single ``os.replace``, so
         the file is never absent.

    Missing ``model_info.json`` or ``assets/config.json`` leaves the
    corresponding metadata (``dataset_lenght`` / ``author``) as ``None``.
    Any exception is caught and reported with ``print``; the function never
    raises and returns ``None``.

    Args:
        ckpt: Mapping of parameter name to tensor (presumably a generator
            ``state_dict()`` — confirm against caller).
        sr: Value stored under ``"sr"`` (presumably a sample-rate label).
        pitch_guidance: Value stored under ``"f0"``.
        name: Model name used in the output file name and metadata.
        model_dir: Path of the ``.pth`` file converted in place; its parent
            directory receives the new file and is created if needed.
        epoch: Epoch number stored in metadata and the file name.
        step: Step number stored in metadata and the file name.
        version: Stored under ``"version"``.
        hps: Hyper-parameter object exposing the ``hps.data.*`` and
            ``hps.model.*`` attributes read below.
        overtrain_info: Opaque value stored under ``"overtrain_info"``.
    """
    try:
        print(f"Saved model '{model_dir}' (epoch {epoch} and step {step})")
        model_dir_path = os.path.dirname(model_dir)
        # os.makedirs("") raises; a bare file name means "current directory".
        if model_dir_path:
            os.makedirs(model_dir_path, exist_ok=True)

        suffix = "_best_epoch" if "best_epoch" in model_dir else ""
        pth_file = f"{name}_{epoch}e_{step}s{suffix}.pth"
        pth_file_old_version_path = os.path.join(
            model_dir_path, f"{pth_file}_old_version.pth"
        )

        dataset_lenght = _read_json_value(
            os.path.join(model_dir_path, "model_info.json"), "total_dataset_duration"
        )
        model_author = _read_json_value(
            os.path.join(now_dir, "assets", "config.json"), "model_author"
        )

        # Weights only, in half precision; "enc_q" entries are excluded.
        opt = OrderedDict(
            weight={
                key: value.half() for key, value in ckpt.items() if "enc_q" not in key
            }
        )
        opt["config"] = [
            # presumably the number of one-sided STFT bins — confirm
            hps.data.filter_length // 2 + 1,
            32,  # hard-coded as in the original; meaning not visible here
            hps.model.inter_channels,
            hps.model.hidden_channels,
            hps.model.filter_channels,
            hps.model.n_heads,
            hps.model.n_layers,
            hps.model.kernel_size,
            hps.model.p_dropout,
            hps.model.resblock,
            hps.model.resblock_kernel_sizes,
            hps.model.resblock_dilation_sizes,
            hps.model.upsample_rates,
            hps.model.upsample_initial_channel,
            hps.model.upsample_kernel_sizes,
            hps.model.spk_embed_dim,
            hps.model.gin_channels,
            hps.data.sample_rate,
        ]
        # One timestamp shared by creation_date and the hash input, so both
        # describe the same instant.
        created_at = datetime.datetime.now().isoformat()
        opt["epoch"] = epoch
        opt["step"] = step
        opt["sr"] = sr
        opt["f0"] = pitch_guidance
        opt["version"] = version
        opt["creation_date"] = created_at
        hash_input = f"{str(ckpt)} {epoch} {step} {created_at}"
        opt["model_hash"] = hashlib.sha256(hash_input.encode()).hexdigest()
        opt["overtrain_info"] = overtrain_info
        # Key spelling kept as-is: it is part of the saved file format.
        opt["dataset_lenght"] = dataset_lenght
        opt["model_name"] = name
        opt["author"] = model_author
        torch.save(opt, os.path.join(model_dir_path, pth_file))

        # Rewrite parametrized weight-norm key names to the legacy
        # weight_v / weight_g names.
        model = torch.load(model_dir, map_location=torch.device("cpu"))
        legacy_model = replace_keys_in_dict(
            replace_keys_in_dict(
                model, ".parametrizations.weight.original1", ".weight_v"
            ),
            ".parametrizations.weight.original0",
            ".weight_g",
        )
        torch.save(legacy_model, pth_file_old_version_path)
        # A single replace overwrites model_dir in one step (same directory),
        # unlike remove()+rename(), which leaves a window with no checkpoint.
        os.replace(pth_file_old_version_path, model_dir)
    except Exception as error:
        print(f"An error occurred extracting the model: {error}")