from __future__ import annotations

import logging
from typing import List, Optional, Tuple, Union

import torch
from numpy import ndarray
from torch import Generator, FloatTensor
from diffusers.schedulers.scheduling_euler_discrete import (
    EulerDiscreteScheduler as DiffusersEulerDiscreteScheduler,
    EulerDiscreteSchedulerOutput,
)
from diffusers.utils.torch_utils import randn_tensor

from ..utils.noise_util import video_fusion_noise

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


class EulerDiscreteScheduler(DiffusersEulerDiscreteScheduler):
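    """Euler discrete scheduler with selectable noise for video diffusion.

    Behaves like `diffusers.EulerDiscreteScheduler`, but `step` additionally
    accepts `noise_type` and `w_ind_noise`: when `noise_type == "video_fusion"`,
    the stochastic churn noise is drawn from `video_fusion_noise` instead of
    i.i.d. Gaussian noise from `randn_tensor`.
    """
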
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: ndarray | List[float] | None = None,
        prediction_type: str = "epsilon",
        interpolation_type: str = "linear",
        use_karras_sigmas: bool | None = False,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        super().__init__(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            trained_betas=trained_betas,
            prediction_type=prediction_type,
            interpolation_type=interpolation_type,
            use_karras_sigmas=use_karras_sigmas,
            timestep_spacing=timestep_spacing,
            steps_offset=steps_offset,
        )

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
        w_ind_noise: float = 0.5,
        noise_type: str = "random",
    ) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
""" | |
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion | |
process from the learned model outputs (most often the predicted noise). | |
Args: | |
model_output (`torch.FloatTensor`): | |
The direct output from learned diffusion model. | |
timestep (`float`): | |
The current discrete timestep in the diffusion chain. | |
sample (`torch.FloatTensor`): | |
A current instance of a sample created by the diffusion process. | |
s_churn (`float`): | |
s_tmin (`float`): | |
s_tmax (`float`): | |
s_noise (`float`, defaults to 1.0): | |
Scaling factor for noise added to the sample. | |
generator (`torch.Generator`, *optional*): | |
A random number generator. | |
return_dict (`bool`): | |
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or | |
tuple. | |
Returns: | |
[`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: | |
If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is | |
returned, otherwise a tuple is returned where the first element is the sample tensor. | |
""" | |
        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
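
        # Stochastic "churn" (Karras et al., 2022): when sigma lies inside
        # [s_tmin, s_tmax], temporarily inflate it to sigma_hat and inject
        # matching noise below; gamma == 0 recovers the deterministic Euler step.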
        gamma = (
            min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
            if s_tmin <= sigma <= s_tmax
            else 0.0
        )
        device = model_output.device
        if noise_type == "random":
            noise = randn_tensor(
                model_output.shape,
                dtype=model_output.dtype,
                device=device,
                generator=generator,
            )
        elif noise_type == "video_fusion":
            # Noise produced by `video_fusion_noise`; `w_ind_noise` controls the mix.
            noise = video_fusion_noise(
                model_output, w_ind_noise=w_ind_noise, generator=generator
            )
        else:
            raise ValueError(
                f"noise_type given as {noise_type} must be one of `random` or `video_fusion`"
            )

        eps = noise * s_noise
        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        # NOTE: "original_sample" should not be an expected prediction_type but is left in for
        # backwards compatibility
        if self.config.prediction_type in ("original_sample", "sample"):
            pred_original_sample = model_output
        elif self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma_hat * model_output
        elif self.config.prediction_type == "v_prediction":
            # x_0 = model_output * c_out + sample * c_skip, with
            # c_out = -sigma / sqrt(sigma**2 + 1) and c_skip = 1 / (sigma**2 + 1)
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
                sample / (sigma**2 + 1)
            )
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )
        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma_hat
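
        # 3. First-order Euler update from sigma_hat down to the next sigma in the schedule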
        dt = self.sigmas[self.step_index + 1] - sigma_hat
        prev_sample = sample + derivative * dt

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return EulerDiscreteSchedulerOutput(
            prev_sample=prev_sample, pred_original_sample=pred_original_sample
        )

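    # Legacy implementation kept for reference: identical to `step` above except
    # that it resolves the step index by searching `self.timesteps` instead of
    # using the `step_index` bookkeeping from newer diffusers releases.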
    def step_bk(
        self,
        model_output: FloatTensor,
        timestep: float | FloatTensor,
        sample: FloatTensor,
        s_churn: float = 0,
        s_tmin: float = 0,
        s_tmax: float = float("inf"),
        s_noise: float = 1,
        generator: Generator | None = None,
        return_dict: bool = True,
        w_ind_noise: float = 0.5,
        noise_type: str = "random",
    ) -> EulerDiscreteSchedulerOutput | Tuple:
""" | |
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion | |
process from the learned model outputs (most often the predicted noise). | |
Args: | |
model_output (`torch.FloatTensor`): direct output from learned diffusion model. | |
timestep (`float`): current timestep in the diffusion chain. | |
sample (`torch.FloatTensor`): | |
current instance of sample being created by diffusion process. | |
s_churn (`float`) | |
s_tmin (`float`) | |
s_tmax (`float`) | |
s_noise (`float`) | |
generator (`torch.Generator`, optional): Random number generator. | |
return_dict (`bool`): option for returning tuple rather than EulerDiscreteSchedulerOutput class | |
Returns: | |
[`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] or `tuple`: | |
[`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a | |
`tuple`. When returning a tuple, the first element is the sample tensor. | |
""" | |
        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)

        step_index = (self.timesteps == timestep).nonzero().item()
        sigma = self.sigmas[step_index]
        gamma = (
            min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
            if s_tmin <= sigma <= s_tmax
            else 0.0
        )
        device = model_output.device
        if noise_type == "random":
            noise = randn_tensor(
                model_output.shape,
                dtype=model_output.dtype,
                device=device,
                generator=generator,
            )
        elif noise_type == "video_fusion":
            noise = video_fusion_noise(
                model_output, w_ind_noise=w_ind_noise, generator=generator
            )
        else:
            raise ValueError(
                f"noise_type given as {noise_type} must be one of `random` or `video_fusion`"
            )

        eps = noise * s_noise
        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        # NOTE: "original_sample" should not be an expected prediction_type but is left in for
        # backwards compatibility
        if self.config.prediction_type in ("original_sample", "sample"):
            pred_original_sample = model_output
        elif self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma_hat * model_output
        elif self.config.prediction_type == "v_prediction":
            # x_0 = model_output * c_out + sample * c_skip, with
            # c_out = -sigma / sqrt(sigma**2 + 1) and c_skip = 1 / (sigma**2 + 1)
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
                sample / (sigma**2 + 1)
            )
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )
        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma_hat
        dt = self.sigmas[step_index + 1] - sigma_hat
        prev_sample = sample + derivative * dt

        if not return_dict:
            return (prev_sample,)

        return EulerDiscreteSchedulerOutput(
            prev_sample=prev_sample, pred_original_sample=pred_original_sample
        )
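

# Minimal usage sketch (illustrative only; `unet` and `latents` are assumed to
# exist elsewhere and are not part of this module):
#
#   scheduler = EulerDiscreteScheduler()
#   scheduler.set_timesteps(30, device=latents.device)
#   for t in scheduler.timesteps:
#       latent_input = scheduler.scale_model_input(latents, t)
#       noise_pred = unet(latent_input, t).sample
#       latents = scheduler.step(
#           noise_pred,
#           t,
#           latents,
#           noise_type="video_fusion",
#           w_ind_noise=0.5,
#       ).prev_sample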