import tops
import torch
from tops import checkpointer
from tops.config import instantiate
from tops.logger import warn
from dp2.generator.deep_privacy1 import MSGGenerator


def load_generator_state(ckpt, G: torch.nn.Module, ckpt_mapper=None):
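    """Load trained (EMA) generator weights from a checkpoint dict into G.

    Prefers the "EMA_generator" entry and falls back to
    "running_average_generator". Optionally maps the state dict through
    ckpt_mapper, then restores W cluster centers if present in the
    checkpoint or the state dict.
    """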
    state = ckpt["EMA_generator"] if "EMA_generator" in ckpt else ckpt["running_average_generator"]
    if ckpt_mapper is not None:
        state = ckpt_mapper(state)
    if isinstance(G, MSGGenerator):
        G.load_state_dict(state)
    else:
        load_state_dict(G, state)
    tops.logger.log(f"Generator loaded, num parameters: {tops.num_parameters(G)/1e6}M")
    if "w_centers" in ckpt:
        G.style_net.register_buffer("w_centers", ckpt["w_centers"])
        tops.logger.log(f"W cluster centers loaded. Number of centers: {len(G.style_net.w_centers)}")
    if "style_net.w_centers" in state:
        G.style_net.register_buffer("w_centers", state["style_net.w_centers"])
        tops.logger.log(f"W cluster centers loaded. Number of centers: {len(G.style_net.w_centers)}")


def build_trained_generator(cfg, map_location=None):
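    """Instantiate the generator from cfg and load its trained weights.

    Weights come from cfg.common.model_url if set, otherwise from
    cfg.checkpoint_dir. Returns the generator in eval mode on
    map_location (defaults to tops.get_device()); missing checkpoints
    only produce a warning.
    """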
    map_location = map_location if map_location is not None else tops.get_device()
    G = instantiate(cfg.generator)
    G.eval()
    G.imsize = tuple(cfg.data.imsize) if hasattr(cfg, "data") else None
    if hasattr(cfg, "ckpt_mapper"):
        ckpt_mapper = instantiate(cfg.ckpt_mapper)
    else:
        ckpt_mapper = None
    if "model_url" in cfg.common:
        ckpt = tops.load_file_or_url(cfg.common.model_url, md5sum=cfg.common.model_md5sum)
        load_generator_state(ckpt, G, ckpt_mapper)
        return G.to(map_location)
    try:
        ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu")
        load_generator_state(ckpt, G, ckpt_mapper)
    except FileNotFoundError:
        tops.logger.warn(f"Did not find generator checkpoint in: {cfg.checkpoint_dir}")
    return G.to(map_location)


def build_trained_discriminator(cfg, map_location=None):
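    """Instantiate the discriminator from cfg and load its trained weights,
    warning (instead of raising) if no checkpoint is found."""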
    map_location = map_location if map_location is not None else tops.get_device()
    D = instantiate(cfg.discriminator).to(map_location)
    D.eval()
    try:
        ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu")
        if hasattr(cfg, "ckpt_mapper_D"):
            ckpt["discriminator"] = instantiate(cfg.ckpt_mapper_D)(ckpt["discriminator"])
        D.load_state_dict(ckpt["discriminator"])
    except FileNotFoundError:
        tops.logger.warn(f"Did not find discriminator checkpoint in: {cfg.checkpoint_dir}")
    return D


def load_state_dict(module: torch.nn.Module, state_dict: dict):
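    """Non-strict state dict loading: drops entries whose shapes do not
    match the module's parameters, and warns about keys that exist on
    only one side."""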
    module_sd = module.state_dict()
    # Drop entries whose shapes disagree with the current model.
    to_remove = []
    for key, item in state_dict.items():
        if key not in module_sd:
            continue
        if item.shape != module_sd[key].shape:
            to_remove.append(key)
            warn(f"Incorrect shape. Current model: {module_sd[key].shape}, in state dict: {item.shape} for key: {key}")
    for key in to_remove:
        state_dict.pop(key)
    # Warn about keys present on only one side before the non-strict load.
    for key in state_dict:
        if key not in module_sd:
            warn(f"Did not find key in model state dict: {key}")
    for key in module_sd:
        if key not in state_dict:
            warn(f"Did not find key in state dict: {key}")
    module.load_state_dict(state_dict, strict=False)
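

# Minimal usage sketch (assumptions: cfg is a lazy config exposing
# `generator`, `common`, and `checkpoint_dir`, loaded via
# tops.config.LazyConfig; the config path below is hypothetical, not part
# of this module):
#
#     from tops.config import LazyConfig
#     cfg = LazyConfig.load("configs/my_config.py")  # hypothetical path
#     G = build_trained_generator(cfg)  # eval-mode generator on tops.get_device()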