import torch
from omegaconf import OmegaConf

from sgm.util import instantiate_from_config
from sgm.modules.diffusionmodules.sampling import *
# SDXL base resolutions, keyed by aspect ratio (width / height) -> (width, height).
SD_XL_BASE_RATIOS = {
    "0.5": (704, 1408),
    "0.52": (704, 1344),
    "0.57": (768, 1344),
    "0.6": (768, 1280),
    "0.68": (832, 1216),
    "0.72": (832, 1152),
    "0.78": (896, 1152),
    "0.82": (896, 1088),
    "0.88": (960, 1088),
    "0.94": (960, 1024),
    "1.0": (1024, 1024),
    "1.07": (1024, 960),
    "1.13": (1088, 960),
    "1.21": (1088, 896),
    "1.29": (1152, 896),
    "1.38": (1152, 832),
    "1.46": (1216, 832),
    "1.67": (1280, 768),
    "1.75": (1344, 768),
    "1.91": (1344, 704),
    "2.0": (1408, 704),
    "2.09": (1472, 704),
    "2.4": (1536, 640),
    "2.5": (1600, 640),
    "2.89": (1664, 576),
    "3.0": (1728, 576),
}
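

# Illustrative helper, not part of the original module: a minimal sketch showing
# how SD_XL_BASE_RATIOS can be used to pick the bucket closest to a target
# aspect ratio. The name closest_sdxl_resolution is an assumption for this example.
def closest_sdxl_resolution(aspect_ratio: float):
    key = min(SD_XL_BASE_RATIOS, key=lambda k: abs(float(k) - aspect_ratio))
    return SD_XL_BASE_RATIOS[key]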
def init_model(cfgs):
    # Instantiate the model from its OmegaConf config and load checkpoint weights.
    model_cfg = OmegaConf.load(cfgs.model_cfg_path)
    ckpt = cfgs.load_ckpt_path

    model = instantiate_from_config(model_cfg.model)
    model.init_from_ckpt(ckpt)

    if cfgs.type == "train":
        model.train()
    else:
        # Inference: move to the requested GPU, switch to eval mode, freeze parameters.
        if cfgs.use_gpu:
            model.to(torch.device("cuda", index=cfgs.gpu))
        model.eval()
        model.freeze()

    return model
def init_sampling(cfgs):
    discretization_config = {
        "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
    }

    if cfgs.dual_conditioner:
        # Dual classifier-free guidance: the guider receives the full list of scales.
        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.DualCFG",
            "params": {"scale": cfgs.scale},
        }
        sampler = EulerEDMDualSampler(
            num_steps=cfgs.steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            s_churn=0.0,
            s_tmin=0.0,
            s_tmax=999.0,
            s_noise=1.0,
            verbose=True,
            device=torch.device("cuda", index=cfgs.gpu),
        )
    else:
        # Vanilla classifier-free guidance: only the first scale is used.
        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
            "params": {"scale": cfgs.scale[0]},
        }
        sampler = EulerEDMSampler(
            num_steps=cfgs.steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            s_churn=0.0,
            s_tmin=0.0,
            s_tmax=999.0,
            s_noise=1.0,
            verbose=True,
            device=torch.device("cuda", index=cfgs.gpu),
        )

    return sampler
def deep_copy(batch):
    # Copy a batch dict: tensors are cloned, lists/tuples are shallow-copied,
    # all other values are shared as-is.
    c_batch = {}
    for key in batch:
        if isinstance(batch[key], torch.Tensor):
            c_batch[key] = torch.clone(batch[key])
        elif isinstance(batch[key], (tuple, list)):
            # list has .copy() but tuple does not; rebuild with the same type instead.
            c_batch[key] = type(batch[key])(batch[key])
        else:
            c_batch[key] = batch[key]
    return c_batch
def prepare_batch(cfgs, batch):
    # Move tensors to the configured GPU, then build the unconditional batch(es)
    # used for classifier-free guidance.
    for key in batch:
        if isinstance(batch[key], torch.Tensor) and cfgs.use_gpu:
            batch[key] = batch[key].to(torch.device("cuda", index=cfgs.gpu))

    if not cfgs.dual_conditioner:
        # Single guider: one unconditional batch where the prompt is replaced by the
        # negative prompt (if provided) or emptied, and labels are emptied.
        batch_uc = deep_copy(batch)
        if "ntxt" in batch:
            batch_uc["txt"] = batch["ntxt"]
        else:
            batch_uc["txt"] = ["" for _ in range(len(batch["txt"]))]
        if "label" in batch:
            batch_uc["label"] = ["" for _ in range(len(batch["label"]))]
        return batch, batch_uc, None
    else:
        # Dual guider: two unconditional batches, both with the "ref" entry zeroed,
        # the first additionally with labels emptied.
        batch_uc_1 = deep_copy(batch)
        batch_uc_2 = deep_copy(batch)
        batch_uc_1["ref"] = torch.zeros_like(batch["ref"])
        batch_uc_2["ref"] = torch.zeros_like(batch["ref"])
        batch_uc_1["label"] = ["" for _ in range(len(batch["label"]))]
        return batch, batch_uc_1, batch_uc_2
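

# Illustrative usage sketch, not part of the original module. The config fields
# mirror those read above (model_cfg_path, load_ckpt_path, type, use_gpu, gpu,
# dual_conditioner, scale, steps); the paths and the toy batch below are
# placeholders and would need to point at a real config file and checkpoint.
if __name__ == "__main__":
    cfgs = OmegaConf.create({
        "model_cfg_path": "configs/model.yaml",      # placeholder path
        "load_ckpt_path": "checkpoints/model.ckpt",  # placeholder path
        "type": "test",
        "use_gpu": True,
        "gpu": 0,
        "dual_conditioner": False,
        "scale": [5.0],
        "steps": 50,
    })

    model = init_model(cfgs)
    sampler = init_sampling(cfgs)

    # A toy batch with the keys prepare_batch expects in the single-conditioner path.
    batch = {
        "txt": ["a sample prompt"],
        "label": ["a sample label"],
    }
    batch, batch_uc, _ = prepare_batch(cfgs, batch)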