|
from torch import nn |
|
import yaml |
|
import torch |
|
from omegaconf import OmegaConf |
|
from .vqgan import VQModel, GumbelVQ |
|
|
|
def load_config(config_path, display=False): |
|
config = OmegaConf.load(config_path) |
|
if display: |
|
print(yaml.dump(OmegaConf.to_container(config))) |
|
return config |
|
|
|
|
|
def load_vqgan(config, ckpt_path=None, is_gumbel=False): |
|
if is_gumbel: |
|
model = GumbelVQ(**config.model.params) |
|
else: |
|
model = VQModel(**config.model.params) |
|
if ckpt_path is not None: |
|
sd = torch.load(ckpt_path, map_location="cpu")["state_dict"] |
|
missing, unexpected = model.load_state_dict(sd, strict=False) |
|
return model.eval() |
|
|
|
|
|
class SDVQVAEWrapper(nn.Module): |
|
def __init__(self, name): |
|
super(SDVQVAEWrapper, self).__init__() |
|
raise NotImplementedError |
|
|
|
def encode(self, x): |
|
raise NotImplementedError |
|
|
|
def decode(self, x): |
|
raise NotImplementedError |
|
|