File size: 2,548 Bytes
7a11626
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from pathlib import Path
import torch
from ml_collections.config_flags import config_flags

from sde.config import get_config
from sde import ddpm, ncsnv2, ncsnpp  # need to import to trigger its registry
from sde import utils as mutils
from sde.ema import ExponentialMovingAverage

from adapt import ScoreAdapter

# Default device for the score model and checkpoint loading below. Prefer CUDA,
# but fall back to CPU so the module stays usable on hosts without a GPU
# (torch.load(map_location="cuda") would otherwise fail there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def restore_checkpoint(ckpt_dir, state, device):
    """Load a checkpoint file into an existing state dict, in place.

    Args:
        ckpt_dir: path of the checkpoint file handed to ``torch.load``.
        state: dict holding at least ``'model'`` and ``'ema'`` objects that
            expose ``load_state_dict``; its ``'step'`` entry is overwritten.
        device: ``map_location`` used when loading the tensors.

    Returns:
        The same ``state`` dict object that was passed in, now updated.

    Note: the optimizer state stored by ``save_checkpoint`` is deliberately
    not restored here.
    """
    checkpoint = torch.load(ckpt_dir, map_location=device)
    # strict=False: keys missing from / unexpected in the checkpoint are
    # tolerated instead of raising.
    state['model'].load_state_dict(checkpoint['model'], strict=False)
    state['ema'].load_state_dict(checkpoint['ema'])
    state['step'] = checkpoint['step']
    return state


def save_checkpoint(ckpt_dir, state):
    """Serialize a training state dict to ``ckpt_dir`` with ``torch.save``.

    Args:
        ckpt_dir: destination file path passed to ``torch.save``.
        state: dict with ``'optimizer'``, ``'model'`` and ``'ema'`` objects
            (each exposing ``state_dict()``) plus a ``'step'`` value.

    The written payload has the keys ``'optimizer'``, ``'model'``, ``'ema'``
    (their ``state_dict()`` results) and ``'step'``, in that order.
    """
    payload = {}
    for component in ('optimizer', 'model', 'ema'):
        payload[component] = state[component].state_dict()
    payload['step'] = state['step']
    torch.save(payload, ckpt_dir)


class VESDE(ScoreAdapter):
    """ScoreAdapter around a pretrained VE-SDE score network from the `sde` package.

    Construction builds the network from ``sde.config.get_config()``, loads
    ``<checkpoint_root>/sde/checkpoint_127.pth`` into it and its EMA tracker,
    copies the EMA weights into the network, switches it to eval mode and
    keeps only its ``.module`` (dropping the DataParallel wrapper).
    """

    def __init__(self):
        cfg = get_config()
        cfg.device = device  # module-level default device
        ckpt_path = self.checkpoint_root() / "sde" / 'checkpoint_127.pth'

        net = mutils.create_model(cfg)
        # EMA tracker over the freshly built parameters; its contents are
        # replaced by the checkpoint's 'ema' entry below.
        ema = ExponentialMovingAverage(net.parameters(), decay=cfg.model.ema_rate)
        state = dict(model=net, ema=ema, step=0)

        side = cfg.data.image_size
        self._data_shape = (cfg.data.num_channels, side, side)

        # σ_min is exposed at twice the configured sigma_min; since `denoise`
        # halves σ before conditioning the network, presumably this makes the
        # smallest σ the network actually sees equal cfg.model.sigma_min.
        self._σ_min = float(cfg.model.sigma_min * 2)

        # restore_checkpoint updates `state` (and thus `net`/`ema`) in place.
        restore_checkpoint(ckpt_path, state, device=cfg.device)
        # Copy the EMA tracker's weights into the live network for inference.
        ema.copy_to(net.parameters())

        net.eval()
        self.model = net.module  # remove DataParallel
        self._device = device

    def data_shape(self):
        """Return the per-sample shape ``(num_channels, image_size, image_size)``."""
        return self._data_shape

    @property
    def σ_min(self):
        """Smallest noise level, i.e. twice the config's ``model.sigma_min``."""
        return self._σ_min

    @torch.no_grad()
    def denoise(self, xs, σ):
        """Return the denoised estimate ``xs + σ * model(xs, 0.5 * σ)``.

        ``xs`` is batched along dim 0; every sample is conditioned on the same
        noise level. ``self.device`` is not defined in this class and is
        presumably supplied by ScoreAdapter (likely returning ``self._device``)
        — confirm against adapt.py.
        """
        batch = xs.shape[0]
        # see Karras eqn. 212-215 for the 1/2 σ correction
        cond_t = (0.5 * σ) * torch.ones(batch, device=self.device)
        # NOTE: the model's forward function has been modified (see the
        # comments there); its output is scaled by σ and added to xs.
        noise_term = self.model(xs, cond_t)
        return xs + σ * noise_term

    def unet_is_cond(self):
        """This model takes no conditioning input; always False."""
        return False

    def use_cls_guidance(self):
        """Classifier guidance is not used; always False."""
        return False

    def snap_t_to_nearest_tick(self, t):
        """Defer to the ScoreAdapter base implementation unchanged."""
        return super().snap_t_to_nearest_tick(t)