File size: 7,471 Bytes
056ab49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import torch
import numpy as np

class DDPMSampler:
    """DDPM noise scheduler: forward noising and reverse (denoising) steps.

    Naming follows the DDPM paper (https://arxiv.org/pdf/2006.11239.pdf):
    ``betas`` is the per-step noise schedule, ``alphas = 1 - betas`` and
    ``alphas_cumprod`` is the cumulative product of the alphas ("alpha bar").
    """

    def __init__(
        self,
        generator: torch.Generator,
        num_training_steps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.0120,
    ):
        """Build the beta schedule and the default (full-length) timestep list.

        Args:
            generator: Random generator used for every noise draw.
            num_training_steps: Length of the beta schedule.
            beta_start: First beta of the schedule.
            beta_end: Last beta of the schedule.
        """
        # Default "beta_start"/"beta_end" taken from:
        # https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/configs/stable-diffusion/v1-inference.yaml#L5C8-L5C8
        # The schedule is linearly spaced in sqrt(beta) and then squared.
        self.betas = torch.linspace(
            beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32
        ) ** 2
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)  # alpha bar
        self.one = torch.tensor(1.0)

        self.generator = generator

        self.num_train_timesteps = num_training_steps
        # Until set_inference_timesteps() is called, every training step is
        # visited, so the inference step count equals the training step count.
        # Without this default, set_strength()/step() raised AttributeError.
        self.num_inference_steps = num_training_steps
        # Descending order, e.g. [999, 998, ..., 0].
        self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy())

    def set_inference_timesteps(self, num_inference_steps=50):
        """Select an evenly spaced, descending subset of the training timesteps.

        Example: 1000 training steps and 50 inference steps give a step ratio
        of 20 and timesteps [980, 960, ..., 0].

        Raises:
            ValueError: If ``num_inference_steps`` is not in
                ``[1, num_train_timesteps]``; a larger value would make the
                step ratio 0 and collapse every timestep to 0.
        """
        if not 1 <= num_inference_steps <= self.num_train_timesteps:
            raise ValueError(
                f"num_inference_steps must be in [1, {self.num_train_timesteps}], "
                f"got {num_inference_steps}"
            )
        self.num_inference_steps = num_inference_steps
        step_ratio = self.num_train_timesteps // self.num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio)[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps)

    def _get_previous_timestep(self, timestep: int) -> int:
        """Return the timestep visited after ``timestep`` (may be negative).

        Example: with a step ratio of 20, timestep 960 maps to 940.
        """
        return timestep - self.num_train_timesteps // self.num_inference_steps

    def _get_variance(self, timestep: int) -> torch.Tensor:
        """Return the posterior variance used when sampling x_{t-1} from x_t.

        See formulas (6) and (7) of https://arxiv.org/pdf/2006.11239.pdf.
        """
        prev_t = self._get_previous_timestep(timestep)

        alpha_prod_t = self.alphas_cumprod[timestep]  # alpha bar at t
        # Before the first training step nothing has been noised: alpha bar = 1.
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        # Effective beta for the (possibly multi-step) jump from prev_t to t.
        current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev

        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t

        # Clamp away from 0 so that later sqrt/log operations stay finite.
        return torch.clamp(variance, min=1e-20)

    def set_strength(self, strength=1):
        """Set how much noise to add to the input image.

        More noise (strength ~ 1) means that the output will be further from
        the input image. Less noise (strength ~ 0) means that the output will
        be closer to the input image.

        ``start_step`` is the number of leading (noisiest) timesteps to skip,
        e.g. with 50 inference steps: strength = 1 -> start_step = 0 (nothing
        skipped), strength = 0 -> start_step = 50 (everything skipped).

        Raises:
            ValueError: If ``strength`` is outside ``[0, 1]``; a value above 1
                would produce a negative slice start and silently keep only
                the tail of the schedule.
        """
        if not 0 <= strength <= 1:
            raise ValueError(f"strength must be in [0, 1], got {strength}")
        start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
        self.timesteps = self.timesteps[start_step:]
        self.start_step = start_step

    def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        """Run one reverse-diffusion step and return the sample for x_{t-1}.

        Args:
            timestep: Current timestep t.
            latents: Current noisy sample x_t.
            model_output: Noise predicted by the model for x_t.

        Returns:
            The predicted previous sample; fresh noise is added only when t > 0.
        """
        t = timestep
        prev_t = self._get_previous_timestep(t)

        # 1. Compute alphas and betas for the jump prev_t -> t.
        alpha_prod_t = self.alphas_cumprod[t]  # alpha bar at t
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        # 2. Predict the original sample x_0 from the predicted noise,
        # formula (15) of https://arxiv.org/pdf/2006.11239.pdf.
        pred_original_sample = (latents - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5

        # 3. Coefficients of x_0 and x_t in the posterior mean, formula (7).
        pred_original_sample_coeff = (alpha_prod_t_prev ** 0.5 * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** 0.5 * beta_prod_t_prev / beta_prod_t

        # 4. Posterior mean mu_t.
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents

        # 5. Add noise: X ~ N(mu, sigma^2) is obtained as X = mu + sigma * N(0, 1).
        # No noise is added on the final step (t == 0).
        if t > 0:
            noise = torch.randn(
                model_output.shape,
                generator=self.generator,
                device=model_output.device,
                dtype=model_output.dtype,
            )
            pred_prev_sample = pred_prev_sample + (self._get_variance(t) ** 0.5) * noise

        return pred_prev_sample

    @staticmethod
    def _expand_to_rank(values: torch.Tensor, rank: int) -> torch.Tensor:
        """Flatten ``values`` and append singleton dims until it has ``rank`` dims."""
        return values.reshape(-1, *([1] * (rank - 1)))

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """Noise ``original_samples`` to the given timesteps (forward process).

        Samples from q(x_t | x_0) = N(x_t; sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) I),
        equation (4) of https://arxiv.org/pdf/2006.11239.pdf.

        Args:
            original_samples: Clean samples x_0.
            timesteps: Timestep indices into the alpha-bar schedule.

        Returns:
            The noisy samples x_t, drawn with ``self.generator``.
        """
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)
        rank = len(original_samples.shape)

        # Trailing singleton dims let the per-timestep scalars broadcast
        # against the sample tensor.
        alpha_bar_t = alphas_cumprod[timesteps]
        sqrt_alpha_prod = self._expand_to_rank(alpha_bar_t ** 0.5, rank)
        sqrt_one_minus_alpha_prod = self._expand_to_rank((1 - alpha_bar_t) ** 0.5, rank)

        # X = mu + sigma * N(0, 1), with mu = sqrt_alpha_prod * x_0 and
        # sigma = sqrt_one_minus_alpha_prod.
        noise = torch.randn(
            original_samples.shape,
            generator=self.generator,
            device=original_samples.device,
            dtype=original_samples.dtype,
        )
        return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise