|
import argparse |
|
import json |
|
import os |
|
import pickle |
|
|
|
import numpy as np |
|
import torch |
|
|
|
from .config import FLAGS |
|
from .torch_model import Generator |
|
|
|
|
|
class AttrDict(dict):
    """Dictionary whose entries are also reachable as attributes.

    The instance ``__dict__`` is bound to the mapping itself, so
    ``d.key`` and ``d["key"]`` always read and write the same entry.
    """

    def __init__(self, *args, **kwargs):
        """Build the mapping exactly as ``dict(*args, **kwargs)`` would."""
        super().__init__(*args, **kwargs)
        # Alias attribute storage to the mapping so both views stay in sync.
        self.__dict__ = self
|
|
|
|
|
def load_checkpoint(filepath, device):
    """Load a PyTorch checkpoint from ``filepath``.

    Args:
        filepath: Path to the checkpoint file.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        Whatever object ``torch.load`` deserializes from the file.

    Raises:
        FileNotFoundError: If ``filepath`` is not an existing regular file.
    """
    # Explicit check rather than ``assert`` so the validation is not
    # stripped when Python runs with ``-O``.
    if not os.path.isfile(filepath):
        raise FileNotFoundError(f"Checkpoint file not found: {filepath!r}")
    print("Loading '{}'".format(filepath))
    # NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict
|
|
|
|
|
def _haiku_module_name(torch_name, h):
    """Map a PyTorch ``state_dict`` key to the module path used as export key.

    Keys that match none of the known prefixes are returned unchanged.
    """
    if torch_name.startswith("conv_pre"):
        return "generator/~/conv1_d"
    if torch_name.startswith("conv_post"):
        return "generator/~/conv1_d_1"
    if torch_name.startswith("ups."):
        index = torch_name.split(".")[1]
        return f"generator/~/ups_{index}"
    if torch_name.startswith("resblocks."):
        # Expects exactly five dot-separated fields; any other count raises
        # ValueError from the unpacking below.
        _, block, group, layer, _ = torch_name.split(".")
        return f"generator/~/res_block{h.resblock}_{block}/~/{group}_{layer}"
    return torch_name


def convert_to_haiku(a, h, device):
    """Export the generator weights of a checkpoint as a pickled nested dict.

    Builds ``Generator(h)``, loads the ``"generator"`` entry of the checkpoint
    at ``a.checkpoint_file``, removes weight norm, and writes a mapping
    ``{module_path: {"w": ndarray, "b": ndarray}}`` to
    ``FLAGS.ckpt_dir / "hk_hifi.pickle"``.

    Args:
        a: Parsed CLI arguments; only ``a.checkpoint_file`` is read.
        h: Generator hyper-parameters; ``h.resblock`` is used when naming
            residual-block modules.
        device: Device on which the model is built and the checkpoint loaded.
    """
    generator = Generator(h).to(device)
    state_dict_g = load_checkpoint(a.checkpoint_file, device)
    generator.load_state_dict(state_dict_g["generator"])
    generator.eval()
    generator.remove_weight_norm()

    hk_map = {}
    # Distinct loop names: the original reused ``a`` and shadowed the
    # CLI-arguments parameter.
    for torch_name, tensor in generator.state_dict().items():
        print(torch_name, tensor.shape)
        hk_name = _haiku_module_name(torch_name, h)
        print(hk_name, tensor.shape)

        # ``.numpy()`` only works on CPU tensors; move first so the export
        # also works when ``device`` is an accelerator.
        array = tensor.detach().cpu().numpy()
        params = hk_map.setdefault(hk_name, {})
        if array.ndim == 1:
            # One-dimensional parameters are stored as the bias.
            params["b"] = array
        elif "ups" in hk_name:
            # NOTE(review): presumably converts the PyTorch transposed-conv
            # kernel layout to the one Haiku expects -- confirm.
            params["w"] = np.rot90(array, k=1, axes=(0, 2))
        elif "conv" in hk_name:
            # Swap the first and last axes of the kernel.
            params["w"] = np.swapaxes(array, 0, 2)
        else:
            params["w"] = array

    FLAGS.ckpt_dir.mkdir(parents=True, exist_ok=True)
    with open(FLAGS.ckpt_dir / "hk_hifi.pickle", "wb") as f:
        pickle.dump(hk_map, f)
|
|
|
|
|
def main():
    """Command-line entry point: convert a PyTorch checkpoint for Haiku.

    Reads ``--checkpoint-file`` and ``--config-file`` (both required), wraps
    the parsed JSON config in an ``AttrDict``, and runs the conversion on CPU.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint-file", required=True)
    parser.add_argument("--config-file", required=True)
    a = parser.parse_args()

    # Open explicitly as UTF-8 rather than relying on the locale-dependent
    # default encoding of ``open``.
    with open(a.config_file, encoding="utf-8") as f:
        json_config = json.load(f)
    h = AttrDict(json_config)

    device = torch.device("cpu")
    convert_to_haiku(a, h, device)
|
|
|
|
|
# Run the converter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|