File size: 2,332 Bytes
12da6cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import argparse
import json
import os
import pickle

import numpy as np
import torch

from .config import FLAGS
from .torch_model import Generator


class AttrDict(dict):
    """A ``dict`` whose entries are also reachable as attributes.

    Because the instance's attribute namespace is the mapping itself,
    ``obj.key`` and ``obj["key"]`` read and write the same storage.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Point attribute lookup/assignment at the dict contents themselves.
        self.__dict__ = self


def load_checkpoint(filepath, device):
    """Load a checkpoint file with ``torch.load``, mapping storages to *device*.

    Args:
        filepath: Path to the checkpoint file.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The object deserialized by ``torch.load``.

    Raises:
        FileNotFoundError: If *filepath* is not an existing regular file.
    """
    # Explicit check rather than ``assert``: assertions are stripped under -O.
    if not os.path.isfile(filepath):
        raise FileNotFoundError(f"Checkpoint file not found: '{filepath}'")
    print("Loading '{}'".format(filepath))
    # NOTE(security): torch.load unpickles the file and may execute arbitrary
    # code embedded in it; only load checkpoints from trusted sources.
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict


def _haiku_module_name(torch_name, h):
    """Map a PyTorch state-dict key to its Haiku module path (unknown keys pass through)."""
    if torch_name.startswith("conv_pre"):
        return "generator/~/conv1_d"
    if torch_name.startswith("conv_post"):
        return "generator/~/conv1_d_1"
    if torch_name.startswith("ups."):
        layer_idx = torch_name.split(".")[1]
        return f"generator/~/ups_{layer_idx}"
    if torch_name.startswith("resblocks."):
        # Expects exactly five dot-separated parts:
        # "resblocks.<block>.<group>.<idx>.<param>".
        _, block_idx, conv_group, conv_idx, _ = torch_name.split(".")
        return f"generator/~/res_block{h.resblock}_{block_idx}/~/{conv_group}_{conv_idx}"
    return torch_name


def _to_haiku_param(hk_name, array):
    """Return ``(key, array)``: ``"b"`` for 1-D arrays, else ``"w"`` with axes rearranged for Haiku."""
    if array.ndim == 1:
        return "b", array
    if "ups" in hk_name:
        # Reverses axis 2 and swaps axes 0 and 2; presumably PyTorch
        # ConvTranspose1d (in, out, k) -> (k, out, in) with flipped taps -- TODO confirm.
        return "w", np.rot90(array, k=1, axes=(0, 2))
    if "conv" in hk_name:
        # Swaps axes 0 and 2, e.g. PyTorch Conv1d (out, in, k) -> (k, in, out).
        return "w", np.swapaxes(array, 0, 2)
    return "w", array


def convert_to_haiku(a, h, device):
    """Convert the generator weights of a PyTorch checkpoint into a Haiku parameter pickle.

    Builds ``Generator(h)`` on *device*, loads ``checkpoint["generator"]`` from
    ``a.checkpoint_file``, calls ``remove_weight_norm()``, renames each
    parameter to a Haiku module path, rearranges weight axes, and pickles the
    resulting ``{module_path: {"w"|"b": ndarray}}`` mapping to
    ``FLAGS.ckpt_dir / "hk_hifi.pickle"`` (creating the directory if needed).

    Args:
        a: Parsed CLI arguments; only ``a.checkpoint_file`` is read.
        h: Model config passed to ``Generator``; ``h.resblock`` is embedded in
            residual-block module names.
        device: Device the generator and checkpoint tensors are loaded onto.
    """
    generator = Generator(h).to(device)
    state_dict_g = load_checkpoint(a.checkpoint_file, device)
    generator.load_state_dict(state_dict_g["generator"])
    generator.eval()
    generator.remove_weight_norm()

    hk_map = {}
    for torch_name, tensor in generator.state_dict().items():
        print(torch_name, tensor.shape)
        hk_name = _haiku_module_name(torch_name, h)
        print(hk_name, tensor.shape)
        # .cpu() so this also works when *device* is not the CPU: .numpy()
        # requires a host tensor.
        param_key, array = _to_haiku_param(hk_name, tensor.detach().cpu().numpy())
        hk_map.setdefault(hk_name, {})[param_key] = array

    FLAGS.ckpt_dir.mkdir(parents=True, exist_ok=True)
    with open(FLAGS.ckpt_dir / "hk_hifi.pickle", "wb") as f:
        pickle.dump(hk_map, f)


def main():
    """Command-line entry point: convert a checkpoint using a JSON config.

    Reads ``--checkpoint-file`` and ``--config-file`` from the command line,
    parses the config file (JSON) into an ``AttrDict``, and runs
    ``convert_to_haiku`` on the CPU.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint-file", required=True)
    parser.add_argument("--config-file", required=True)
    args = parser.parse_args()

    # Explicit encoding: the default is locale-dependent, while JSON is UTF-8.
    with open(args.config_file, encoding="utf-8") as f:
        h = AttrDict(json.load(f))

    device = torch.device("cpu")
    convert_to_haiku(args, h, device)


if __name__ == "__main__":
    main()