import torch as th

from .script_util import (
    create_gaussian_diffusion,
    create_model_and_diffusion,
    model_and_diffusion_defaults,
)

# Sample from the base model (or the upsampler, when `upsample_enabled` is True).

#@th.inference_mode()
def sample(
    glide_model,
    glide_options,
    side_x,
    side_y,
    prompt,
    batch_size=1,
    guidance_scale=4,
    device="cpu",
    prediction_respacing="100",
    upsample_enabled=False,
    upsample_temp=0.997,
    mode="",
):
    """Draw `batch_size` samples of shape (3, side_y, side_x) from `glide_model`.

    `prompt` is a dict of conditioning inputs: a reference image under 'ref' and,
    when `upsample_enabled` is True, a low-resolution image under 'low_res'.
    The base model is sampled with classifier-free guidance at `guidance_scale`;
    the upsampler is sampled without guidance. `mode` is currently unused.
    """

    eval_diffusion = create_gaussian_diffusion(
        steps=glide_options["diffusion_steps"],
        learn_sigma=glide_options["learn_sigma"],
        noise_schedule=glide_options["noise_schedule"],
        predict_xstart=glide_options["predict_xstart"],
        rescale_timesteps=glide_options["rescale_timesteps"],
        rescale_learned_sigmas=glide_options["rescale_learned_sigmas"],
        timestep_respacing=prediction_respacing
    )
 
    # Build the conditional / unconditional inputs for classifier-free guidance:
    # the conditional half of the batch uses the reference image from `prompt`,
    # the unconditional half uses an all-ones tensor as the "null" reference.
    full_batch_size = batch_size * 2
    cond_ref = prompt['ref']
    uncond_ref = th.ones_like(cond_ref)

    model_kwargs = {}
    model_kwargs['ref'] = th.cat([cond_ref, uncond_ref], 0).to(device)

    def cfg_model_fn(x_t, ts, **kwargs):
        # The batch is [conditional half; unconditional half] starting from the
        # same noise, so only the first half of x_t needs to be fed to the model
        # (duplicated once per conditioning).
        half = x_t[: len(x_t) // 2]
        combined = th.cat([half, half], dim=0)
        model_out = glide_model(combined, ts, **kwargs)
        eps, rest = model_out[:, :3], model_out[:, 3:]  # rest holds the learned sigmas
        cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)

        # Classifier-free guidance: push the prediction away from the unconditional
        # epsilon and towards the conditional one.
        half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)

        eps = th.cat([half_eps, half_eps], dim=0)
        return th.cat([eps, rest], dim=1)

    if upsample_enabled:
        # Upsampler path: condition on the low-resolution image and sample without
        # classifier-free guidance, so only the conditional references are kept.
        model_kwargs['low_res'] = prompt['low_res'].to(device)
        noise = th.randn((batch_size, 3, side_y, side_x), device=device) * upsample_temp
        model_fn = glide_model  # no CFG wrapper needed for the upsampler
        model_kwargs['ref'] = model_kwargs['ref'][:batch_size]

        samples = eval_diffusion.p_sample_loop(
            model_fn,
            (batch_size, 3, side_y, side_x),
            noise=noise,
            device=device,
            clip_denoised=True,
            progress=False,
            model_kwargs=model_kwargs,
            cond_fn=None,
        )[:batch_size]

    else:
        # Base-model path: duplicate the noise so the conditional and unconditional
        # halves of the batch are denoised from the same starting point.
        model_fn = cfg_model_fn  # use CFG for the base model
        noise = th.randn((batch_size, 3, side_y, side_x), device=device)
        noise = th.cat([noise, noise], 0)
        samples = eval_diffusion.p_sample_loop(
            model_fn,
            (full_batch_size, 3, side_y, side_x),
            noise=noise,
            device=device,
            clip_denoised=True,
            progress=False,
            model_kwargs=model_kwargs,
            cond_fn=None,
        )[:batch_size]

    return samples
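

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes the usual
# guided-diffusion conventions for the .script_util helpers imported above:
# model_and_diffusion_defaults() returning an options dict whose keys include
# the ones read in sample(), and create_model_and_diffusion(**options)
# returning a (model, diffusion) pair whose model accepts the 'ref' kwarg.
# The 64x64 size and the zero-filled reference tensor are placeholders;
# checkpoint loading is omitted because it depends on the surrounding project.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    options = model_and_diffusion_defaults()
    model, _ = create_model_and_diffusion(**options)
    model.eval()

    ref = th.zeros(1, 3, 64, 64)  # stand-in for a real reference image in [-1, 1]
    out = sample(
        glide_model=model,
        glide_options=options,
        side_x=64,
        side_y=64,
        prompt={"ref": ref},
        batch_size=1,
        guidance_scale=4,
        device="cpu",
        prediction_respacing="100",
    )
    print(out.shape)  # expected: (1, 3, 64, 64)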