import torch as th

from .script_util import create_gaussian_diffusion


def sample(
    glide_model,
    glide_options,
    side_x,
    side_y,
    prompt,
    batch_size=1,
    guidance_scale=4,
    device="cpu",
    prediction_respacing="100",
    upsample_enabled=False,
    upsample_temp=0.997,
    mode="",
):
    """Sample from the base model (with classifier-free guidance) or the upsampler."""
    eval_diffusion = create_gaussian_diffusion(
        steps=glide_options["diffusion_steps"],
        learn_sigma=glide_options["learn_sigma"],
        noise_schedule=glide_options["noise_schedule"],
        predict_xstart=glide_options["predict_xstart"],
        rescale_timesteps=glide_options["rescale_timesteps"],
        rescale_learned_sigmas=glide_options["rescale_learned_sigmas"],
        timestep_respacing=prediction_respacing,
    )

    # Classifier-free guidance: double the batch so the conditional and
    # unconditional passes run together. The unconditional reference is an
    # all-ones tensor with the same shape as the conditional reference.
    full_batch_size = batch_size * 2
    cond_ref = prompt["ref"]
    uncond_ref = th.ones_like(cond_ref)

    model_kwargs = {}
    model_kwargs["ref"] = th.cat([cond_ref, uncond_ref], 0).to(device)

    def cfg_model_fn(x_t, ts, **kwargs):
        # Both halves of the batch share the same x_t, so keep only the first
        # half and duplicate it before the forward pass.
        half = x_t[: len(x_t) // 2]
        combined = th.cat([half, half], dim=0)
        model_out = glide_model(combined, ts, **kwargs)
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
        # Guided epsilon: move the unconditional prediction toward the
        # conditional one, scaled by guidance_scale.
        half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
        eps = th.cat([half_eps, half_eps], dim=0)
        return th.cat([eps, rest], dim=1)

    if upsample_enabled:
        # Upsampler: condition on the low-resolution image and sample without
        # classifier-free guidance, so keep only the conditional half of `ref`.
        model_kwargs["low_res"] = prompt["low_res"].to(device)
        model_kwargs["ref"] = model_kwargs["ref"][:batch_size]
        noise = th.randn((batch_size, 3, side_y, side_x), device=device) * upsample_temp
        model_fn = glide_model
        sample_shape = (batch_size, 3, side_y, side_x)
    else:
        # Base model: use classifier-free guidance on the doubled batch, with
        # the same starting noise for the conditional and unconditional halves.
        noise = th.randn((batch_size, 3, side_y, side_x), device=device)
        noise = th.cat([noise, noise], 0)
        model_fn = cfg_model_fn
        sample_shape = (full_batch_size, 3, side_y, side_x)

    samples = eval_diffusion.p_sample_loop(
        model_fn,
        sample_shape,
        noise=noise,
        device=device,
        clip_denoised=True,
        progress=False,
        model_kwargs=model_kwargs,
        cond_fn=None,
    )[:batch_size]

    return samples
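

# Example usage (a minimal sketch, not part of the original module): the names
# `load_glide_model` / `glide_options` construction below are hypothetical
# stand-ins for however the surrounding project builds the model and its option
# dict; the `prompt` layout simply mirrors what `sample` reads (`prompt["ref"]`,
# plus `prompt["low_res"]` when `upsample_enabled=True`).
#
#     glide_model, glide_options = load_glide_model(device="cuda")  # hypothetical loader
#     prompt = {"ref": ref_images}  # (batch_size, C, H, W) reference tensor
#     images = sample(
#         glide_model,
#         glide_options,
#         side_x=64,
#         side_y=64,
#         prompt=prompt,
#         batch_size=1,
#         guidance_scale=4.0,
#         device="cuda",
#         prediction_respacing="100",
#     )
#     # `images` is a (batch_size, 3, side_y, side_x) tensor clipped to [-1, 1]
#     # (clip_denoised=True in p_sample_loop).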