File size: 3,340 Bytes
5d756f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import tops
import torch
from tops import checkpointer
from tops.config import instantiate
from tops.logger import warn
from dp2.generator.deep_privacy1 import MSGGenerator


def load_generator_state(ckpt, G: torch.nn.Module, ckpt_mapper=None):
    """Load generator weights, and any stored W cluster centers, from a checkpoint dict.

    The weights are taken from ``ckpt["EMA_generator"]`` when that key exists,
    otherwise from ``ckpt["running_average_generator"]``. They are optionally
    passed through ``ckpt_mapper`` first.

    ``MSGGenerator`` instances are loaded with the strict ``nn.Module.load_state_dict``.
    Every other generator goes through this module's lenient ``load_state_dict``
    helper.

    Cluster centers are (re-)registered as the ``w_centers`` buffer of
    ``G.style_net`` when found. They are looked up first under the top-level
    ``"w_centers"`` key of ``ckpt``, then under ``"style_net.w_centers"`` of the
    (mapped) state. When both exist, the latter wins because it is registered last.

    Args:
        ckpt: checkpoint mapping.
        G: generator module to load into (modified in place).
        ckpt_mapper: optional callable that rewrites the state dict before loading.
    """
    weights_key = "EMA_generator" if "EMA_generator" in ckpt else "running_average_generator"
    state = ckpt[weights_key]
    if ckpt_mapper is not None:
        state = ckpt_mapper(state)

    if not isinstance(G, MSGGenerator):
        load_state_dict(G, state)
    else:
        G.load_state_dict(state)
    millions = tops.num_parameters(G) / 1e6
    tops.logger.log(f"Generator loaded, num parameters: {millions}M")

    # Checked in this order on purpose: a centers entry inside the state dict
    # overrides a top-level one in the checkpoint.
    center_sources = ((ckpt, "w_centers"), (state, "style_net.w_centers"))
    for container, centers_key in center_sources:
        if centers_key not in container:
            continue
        G.style_net.register_buffer("w_centers", container[centers_key])
        tops.logger.log(f"W cluster centers loaded. Number of centers: {len(G.style_net.w_centers)}")


def build_trained_generator(cfg, map_location=None):
    """Instantiate the generator described by ``cfg`` and load its trained weights.

    Weights come from ``cfg.common.model_url`` when that key is present in
    ``cfg.common`` (verified against ``cfg.common.model_md5sum``). Otherwise they
    come from ``checkpointer.load_checkpoint(cfg.checkpoint_dir)``. If no local
    checkpoint is found, only a warning is logged and the freshly instantiated
    weights are kept.

    Args:
        cfg: config providing ``generator`` and ``common``. It may also provide
            ``data.imsize``, ``ckpt_mapper`` and ``checkpoint_dir``.
        map_location: device the returned generator is moved to; defaults to
            ``tops.get_device()``.

    Returns:
        The generator in eval mode, moved to ``map_location``. Its ``imsize``
        attribute is set from ``cfg.data.imsize``, or to ``None`` when ``cfg``
        has no ``data``.
    """
    map_location = map_location if map_location is not None else tops.get_device()
    G = instantiate(cfg.generator)
    G.eval()
    G.imsize = tuple(cfg.data.imsize) if hasattr(cfg, "data") else None
    ckpt_mapper = instantiate(cfg.ckpt_mapper) if hasattr(cfg, "ckpt_mapper") else None
    if "model_url" in cfg.common:
        ckpt = tops.load_file_or_url(cfg.common.model_url, md5sum=cfg.common.model_md5sum)
        load_generator_state(ckpt, G, ckpt_mapper)
        return G.to(map_location)
    try:
        ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu")
    except FileNotFoundError:
        tops.logger.warn(f"Did not find generator checkpoint in: {cfg.checkpoint_dir}")
    else:
        # Loading the weights sits outside the try so that errors raised while
        # applying a found checkpoint are not misreported as a missing checkpoint.
        load_generator_state(ckpt, G, ckpt_mapper)
    return G.to(map_location)


def build_trained_discriminator(cfg, map_location=None):
    """Instantiate the discriminator described by ``cfg`` and load its trained weights.

    The weights are read from the ``"discriminator"`` entry of the checkpoint in
    ``cfg.checkpoint_dir``. If ``cfg`` has a ``ckpt_mapper_D`` entry, it is
    instantiated and applied to that state first. If no checkpoint is found,
    only a warning is logged and the freshly instantiated weights are kept.

    Args:
        cfg: config providing ``discriminator`` and ``checkpoint_dir``, and
            optionally ``ckpt_mapper_D``.
        map_location: device the discriminator is moved to; defaults to
            ``tops.get_device()``.

    Returns:
        The discriminator in eval mode on ``map_location``.
    """
    map_location = map_location if map_location is not None else tops.get_device()
    D = instantiate(cfg.discriminator).to(map_location)
    D.eval()
    try:
        ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu")
    except FileNotFoundError:
        tops.logger.warn(f"Did not find discriminator checkpoint in: {cfg.checkpoint_dir}")
        return D
    # Only the checkpoint lookup is guarded. Mapper or load failures on a found
    # checkpoint must surface instead of being reported as "not found".
    state = ckpt["discriminator"]
    if hasattr(cfg, "ckpt_mapper_D"):
        state = instantiate(cfg.ckpt_mapper_D)(state)
    D.load_state_dict(state)
    return D


def load_state_dict(module: torch.nn.Module, state_dict: dict):
    """Leniently load ``state_dict`` into ``module``.

    The following cases produce a warning instead of an error:

    - Entries whose ``.shape`` differs from the module's own entry are skipped.
    - Keys found in ``state_dict`` but not in the module are reported.
    - Keys found in the module but not in the (filtered) state are reported.

    The result is then loaded with ``module.load_state_dict(..., strict=False)``.

    The caller's ``state_dict`` is NOT modified. Skipped entries are dropped
    from a private copy, so the caller can still inspect every original entry
    after loading.

    Args:
        module: module to load into (modified in place).
        state_dict: mapping from parameter/buffer names to tensors.
    """
    from collections import OrderedDict

    module_sd = module.state_dict()
    mismatched = set()
    for key, item in state_dict.items():
        if key in module_sd and item.shape != module_sd[key].shape:
            mismatched.add(key)
            warn(f"Incorrect shape. Current model: {module_sd[key].shape}, in state dict: {item.shape} for key: {key}")

    filtered = OrderedDict((key, item) for key, item in state_dict.items() if key not in mismatched)
    # torch's own load_state_dict carries _metadata over when copying, because
    # modules read per-module version info from it in _load_from_state_dict.
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        filtered._metadata = metadata

    for key in filtered:
        if key not in module_sd:
            warn(f"Did not find key in model state dict: {key}")
    for key in module_sd:
        if key not in filtered:
            warn(f"Did not find key in state dict: {key}")
    module.load_state_dict(filtered, strict=False)