from torch import nn
import yaml
import torch
from omegaconf import OmegaConf

from .vqgan import VQModel, GumbelVQ


def load_config(config_path, display=False):
    """Load an OmegaConf config from a YAML file, optionally printing it."""
    config = OmegaConf.load(config_path)
    if display:
        print(yaml.dump(OmegaConf.to_container(config)))
    return config


def load_vqgan(config, ckpt_path=None, is_gumbel=False):
    """Instantiate a VQModel (or GumbelVQ) and optionally load checkpoint weights."""
    if is_gumbel:
        model = GumbelVQ(**config.model.params)
    else:
        model = VQModel(**config.model.params)
    if ckpt_path is not None:
        # Checkpoints store the weights under "state_dict"; load non-strictly so
        # keys that are absent from the inference-time model can be skipped.
        sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        missing, unexpected = model.load_state_dict(sd, strict=False)
    return model.eval()


class SDVQVAEWrapper(nn.Module):
    """Stub wrapper exposing an encode/decode interface for a VQ-VAE; not implemented yet."""

    def __init__(self, name):
        super(SDVQVAEWrapper, self).__init__()
        raise NotImplementedError

    def encode(self, x):  # b c h w
        raise NotImplementedError

    def decode(self, x):
        raise NotImplementedError
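

# Minimal usage sketch (assumptions: the taming-transformers VQModel/GumbelVQ
# interface, where encode() returns (quant, emb_loss, info) and decode() maps the
# quantized latents back to image space; the config/checkpoint paths below are
# hypothetical placeholders, not files shipped with this module). Run it via
# `python -m <package>.<this_module>` so the relative import above resolves.
if __name__ == "__main__":
    cfg = load_config("configs/vqgan_example.yaml", display=True)  # placeholder path
    vqgan = load_vqgan(cfg, ckpt_path="checkpoints/vqgan_example.ckpt")  # placeholder path
    with torch.no_grad():
        x = torch.randn(1, 3, 256, 256)  # dummy batch in b c h w layout
        quant, _, _ = vqgan.encode(x)
        recon = vqgan.decode(quant)
    print("reconstruction shape:", recon.shape)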