# Copyright 2023 The HuggingFace Team. All rights reserved.
# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .activations import get_activation
from .attention import AdaGroupNorm
from .attention_processor import SpatialNorm


class Upsample1D(nn.Module):
    """A 1D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `"conv"`):
            stored as `self.name`; not otherwise used by this layer.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        # `use_conv_transpose` takes precedence over `use_conv`.
        self.conv = None
        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, inputs):
        """Upsample `inputs` (channel dimension at index 1) by a factor of 2."""
        assert inputs.shape[1] == self.channels
        if self.use_conv_transpose:
            # The strided transposed convolution does the upsampling itself.
            return self.conv(inputs)

        outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            outputs = self.conv(outputs)

        return outputs


class Downsample1D(nn.Module):
    """A 1D downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
        name (`str`, default `"conv"`):
            stored as `self.name`; not otherwise used by this layer.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            # Average pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

    def forward(self, inputs):
        """Downsample `inputs` (channel dimension at index 1) with stride 2."""
        assert inputs.shape[1] == self.channels
        return self.conv(inputs)


class Upsample2D(nn.Module):
    """A 2D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `"conv"`):
            attribute name under which the convolution is registered: `conv` when `name == "conv"`, otherwise
            `Conv2d_0`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def _conv_layer(self):
        """Return the convolution under whichever attribute name `__init__` registered it."""
        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        return self.conv if self.name == "conv" else self.Conv2d_0

    def forward(self, hidden_states, output_size=None):
        """Upsample `hidden_states` (channel dimension at index 1).

        Args:
            hidden_states: input tensor.
            output_size: if given, it is passed to `F.interpolate` as `size` and replaces the default
                `scale_factor=2.0`. Ignored on the transposed-convolution path.
        """
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            # Look the layer up by its registered name; reading `self.conv` directly
            # raised AttributeError whenever `name != "conv"`.
            return self._conv_layer()(hidden_states)

        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        if self.use_conv:
            hidden_states = self._conv_layer()(hidden_states)

        return hidden_states


class Downsample2D(nn.Module):
    """A 2D downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
        name (`str`, default `"conv"`):
            when equal to `"conv"`, the layer is registered under both `Conv2d_0` and `conv`; for any other value
            it is registered as `conv` only.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            # Average pooling cannot change the channel count.
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        # Registration order (`Conv2d_0` first, then `conv`) is kept so the module's
        # parameter ordering does not change.
        if name == "conv":
            self.Conv2d_0 = conv
        self.conv = conv

    def forward(self, hidden_states):
        """Downsample `hidden_states` (channel dimension at index 1) with stride 2."""
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            # With an unpadded convolution, zero-pad one element at the end of each of
            # the last two dimensions before applying the strided conv.
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        hidden_states = self.conv(hidden_states)

        return hidden_states


class FirUpsample2D(nn.Module):
    """A 2D FIR upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Upsample with an FIR filter, optionally fused with a transposed convolution.

        When `self.use_conv` is set, `weight` is applied through `F.conv_transpose2d` with stride `factor` and the
        result is FIR-filtered; otherwise the input is zero-upsampled and FIR-filtered by `upfirdn2d_native`.

        Args:
            hidden_states:
                Input tensor; dimensions 2 and 3 are treated as height and width, dimension 1 as channels.
            weight:
                Convolution weight, indexed here as `[outChannels, inChannels, filterH, filterW]` (the layout of
                `nn.Conv2d.weight`). Only read when `self.use_conv` is set.
            kernel:
                FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`.
            factor:
                Integer upsampling factor (default: 2).
            gain:
                Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: the tensor produced by `upfirdn2d_native`.
        """

        assert isinstance(factor, int) and factor >= 1

        # Setup filter kernel.
        if kernel is None:
            kernel = [1] * factor

        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            # A 1D kernel is separable: expand it to 2D.
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        # Scale by factor**2 on top of `gain`.
        kernel = kernel * (gain * (factor**2))

        if self.use_conv:
            convH = weight.shape[2]
            convW = weight.shape[3]
            inC = weight.shape[1]

            pad_value = (kernel.shape[0] - factor) - (convW - 1)

            stride = (factor, factor)
            # Determine data dimensions.
            output_shape = (
                (hidden_states.shape[2] - 1) * factor + convH,
                (hidden_states.shape[3] - 1) * factor + convW,
            )
            output_padding = (
                output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
                output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
            num_groups = hidden_states.shape[1] // inC

            # Transpose weights.
            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
            weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            inverse_conv = F.conv_transpose2d(
                hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
            )

            # `kernel` is already a tensor: move it instead of re-wrapping it with
            # `torch.tensor`, which emits a copy-construct UserWarning.
            output = upfirdn2d_native(
                inverse_conv,
                kernel.to(device=inverse_conv.device),
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
            )
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                up=factor,
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        """Upsample `hidden_states` by 2, adding the convolution bias when `use_conv` is set."""
        if self.use_conv:
            height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
            # `_upsample_2d` only applies the weight, so the bias is added here.
            height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return height


class FirDownsample2D(nn.Module):
    """A 2D FIR downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Downsample with an FIR filter, optionally fused with a strided convolution.

        When `self.use_conv` is set, the input is FIR-filtered and then convolved with `weight` at stride `factor`;
        otherwise `upfirdn2d_native` filters and decimates the input in one call.

        Args:
            hidden_states:
                Input tensor; it is passed to `upfirdn2d_native`, which unpacks a 4D shape.
            weight:
                Convolution weight whose last two dimensions are `[filterH, filterW]` (the layout of
                `nn.Conv2d.weight`). Only read when `self.use_conv` is set.
            kernel:
                FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
                corresponds to average pooling.
            factor:
                Integer downsampling factor (default: 2).
            gain:
                Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: the downsampled tensor.
        """

        assert isinstance(factor, int) and factor >= 1
        if kernel is None:
            kernel = [1] * factor

        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            # A 1D kernel is separable: expand it to 2D.
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * gain

        if self.use_conv:
            _, _, convH, convW = weight.shape
            pad_value = (kernel.shape[0] - factor) + (convW - 1)
            stride_value = [factor, factor]
            # `kernel` is already a tensor: move it instead of re-wrapping it with
            # `torch.tensor`, which emits a copy-construct UserWarning.
            upfirdn_input = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                pad=((pad_value + 1) // 2, pad_value // 2),
            )
            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                down=factor,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        """Downsample `hidden_states` by 2, adding the convolution bias when `use_conv` is set."""
        if self.use_conv:
            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            # `_downsample_2d` only applies the weight, so the bias is added here.
            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states


# downsample/upsample layer used in
# k-upscaler, might be able to use FirDownsample2D/FirUpsample2D instead
class KDownsample2D(nn.Module):
    """Downsample by 2 with a fixed 4x4 blur kernel applied independently to every channel.

    Parameters:
        pad_mode (`str`, default `"reflect"`):
            padding mode handed to `F.pad` before the strided convolution.
    """

    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
        self.pad = kernel_1d.shape[1] // 2 - 1
        # Outer product of the 1D taps; non-persistent, so it is not saved in the state dict.
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, inputs):
        """Pad `inputs` (4D, channels at index 1) and convolve with stride 2."""
        inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
        # Build a dense [C, C, kH, kW] weight that is zero except on the channel
        # diagonal, so channels are filtered independently.
        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(inputs.shape[1], device=inputs.device)
        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
        weight[indices, indices] = kernel
        return F.conv2d(inputs, weight, stride=2)


class KUpsample2D(nn.Module):
    """Upsample by 2 with a fixed 4x4 kernel applied independently to every channel.

    Parameters:
        pad_mode (`str`, default `"reflect"`):
            padding mode handed to `F.pad` before the transposed convolution.
    """

    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        # Same taps as `KDownsample2D`, scaled by 2 per axis.
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
        self.pad = kernel_1d.shape[1] // 2 - 1
        # Outer product of the 1D taps; non-persistent, so it is not saved in the state dict.
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, inputs):
        """Pad `inputs` (4D, channels at index 1) and apply a stride-2 transposed convolution."""
        inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
        # Build a dense [C, C, kH, kW] weight that is zero except on the channel
        # diagonal, so channels are filtered independently.
        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(inputs.shape[1], device=inputs.device)
        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
        weight[indices, indices] = kernel
        return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)


class ResnetBlock2D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        conv_shortcut (`bool`, *optional*, defaults to `False`): stored as `self.use_conv_shortcut`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, default to None):
            The number of groups to use for the second normalization layer. if set to None, same as `groups`.
        pre_norm (`bool`, *optional*, defaults to `True`):
            accepted for interface compatibility; `self.pre_norm` is always set to `True`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
        skip_time_act (`bool`, *optional*, defaults to `False`):
            If `True`, the activation is not applied to `temb` before `time_emb_proj`.
        time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
            "ada_group" for a stronger conditioning with scale and shift. "spatial" uses `SpatialNorm` layers.
        kernel (`str`, *optional*, default to `None`):
            Selects the resampling op used when `up` or `down` is set: `"fir"` uses `upsample_2d` /
            `downsample_2d` with a `(1, 3, 3, 1)` filter, `"sde_vp"` uses nearest interpolation / average pooling,
            anything else uses `Upsample2D` / `Downsample2D` without convolution.
        output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
        use_in_shortcut (`bool`, *optional*, default to `None`):
            If `True`, add a 1x1 nn.conv2d layer for skip-connection. If `None`, the layer is added exactly when
            `in_channels` differs from the number of output channels.
        up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
        down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer (ignored when `up` is set).
        conv_shortcut_bias (`bool`, *optional*, default to `True`):  If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
            If None, same as `out_channels`.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        skip_time_act=False,
        time_embedding_norm="default",  # default, scale_shift, ada_group, spatial
        kernel=None,
        output_scale_factor=1.0,
        use_in_shortcut=None,
        up=False,
        down=False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        # The `pre_norm` argument is ignored: the attribute has always been forced to True.
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm
        self.skip_time_act = skip_time_act

        if groups_out is None:
            groups_out = groups

        if self.time_embedding_norm == "ada_group":
            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm1 = SpatialNorm(in_channels, temb_channels)
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
            elif self.time_embedding_norm == "scale_shift":
                # Twice the channels: the projection is later chunked into scale and shift.
                self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
            elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
                # These modes feed `temb` straight into the norm layers instead.
                self.time_emb_proj = None
            else:
                raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
        else:
            self.time_emb_proj = None

        if self.time_embedding_norm == "ada_group":
            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm2 = SpatialNorm(out_channels, temb_channels)
        else:
            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = get_activation(non_linearity)

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(
                in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
            )

    def forward(self, input_tensor, temb):
        """Run the residual block.

        Args:
            input_tensor: input feature map (channels at index 1).
            temb: conditioning embedding. It is passed to the norm layers in "ada_group" / "spatial" mode and
                projected by `time_emb_proj` when that layer exists.

        Returns:
            `(shortcut + residual) / output_scale_factor`.
        """
        hidden_states = input_tensor

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm1(hidden_states, temb)
        else:
            hidden_states = self.norm1(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)

        # Resample both the shortcut and the residual branch so they can be summed below.
        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if self.time_emb_proj is not None:
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
            # Add two trailing singleton dims so the projection broadcasts over the feature map.
            temb = self.time_emb_proj(temb)[:, :, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm2(hidden_states, temb)
        else:
            hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor


# unet_rl.py
def rearrange_dims(tensor):
    """Insert or drop a singleton dimension depending on the rank of `tensor`.

    * 2D `[a, b]`       -> `[a, b, 1]`
    * 3D `[a, b, c]`    -> `[a, b, 1, c]`
    * 4D `[a, b, _, c]` -> `[a, b, c]` (index 0 is taken along dimension 2)

    Raises:
        ValueError: if `tensor` does not have 2, 3 or 4 dimensions.
    """
    num_dims = len(tensor.shape)
    if num_dims == 2:
        return tensor[:, :, None]
    if num_dims == 3:
        return tensor[:, :, None, :]
    if num_dims == 4:
        return tensor[:, :, 0, :]
    # Report the number of dimensions that was actually checked; `len(tensor)` is
    # the size of dimension 0 and made the message misleading.
    raise ValueError(f"`len(tensor.shape)`: {num_dims} has to be 2, 3 or 4.")


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish

    Parameters:
        inp_channels (`int`): number of input channels of the convolution.
        out_channels (`int`): number of output channels of the convolution.
        kernel_size (`int`): convolution kernel size; padding is `kernel_size // 2`.
        n_groups (`int`, default `8`): number of groups for the group normalization.
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.mish = nn.Mish()

    def forward(self, inputs):
        intermediate_repr = self.conv1d(inputs)
        # The norm is applied with an extra singleton dimension inserted by
        # `rearrange_dims`, which is removed again right after.
        intermediate_repr = rearrange_dims(intermediate_repr)
        intermediate_repr = self.group_norm(intermediate_repr)
        intermediate_repr = rearrange_dims(intermediate_repr)
        output = self.mish(intermediate_repr)
        return output


# unet_rl.py
class ResidualTemporalBlock1D(nn.Module):
    """Two `Conv1dBlock`s with an additive time embedding and a residual connection.

    Parameters:
        inp_channels (`int`): number of input channels.
        out_channels (`int`): number of output channels.
        embed_dim (`int`): size of the time embedding fed to the linear projection.
        kernel_size (`int`, default `5`): kernel size of both convolution blocks.
    """

    def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
        super().__init__()
        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)

        self.time_emb_act = nn.Mish()
        self.time_emb = nn.Linear(embed_dim, out_channels)

        # A 1x1 convolution matches channel counts on the skip path when they differ.
        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
        )

    def forward(self, inputs, t):
        """
        Args:
            inputs : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]

        returns:
            out : [ batch_size x out_channels x horizon ]
        """
        t = self.time_emb_act(t)
        t = self.time_emb(t)
        # `rearrange_dims` appends a singleton dim so the embedding broadcasts over the horizon.
        out = self.conv_in(inputs) + rearrange_dims(t)
        out = self.conv_out(out)
        return out + self.residual_conv(inputs)


def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Upsample a batch of 2D images with the given filter.

    The filter is normalized to sum to one and then scaled by `gain * factor**2`. Padding added by
    `upfirdn2d_native` is zero padding.

    Args:
        hidden_states:
            Input tensor; it is passed to `upfirdn2d_native`, which unpacks a 4D shape.
        kernel:
            FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
            corresponds to nearest-neighbor upsampling.
        factor:
            Integer upsampling factor (default: 2).
        gain:
            Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        # A 1D kernel is separable: expand it to 2D.
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * (gain * (factor**2))
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
        kernel.to(device=hidden_states.device),
        up=factor,
        pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
    )
    return output


def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Downsample a batch of 2D images with the given filter.

    The filter is normalized to sum to one and then scaled by `gain`. Padding added by `upfirdn2d_native` is
    zero padding.

    Args:
        hidden_states:
            Input tensor; it is passed to `upfirdn2d_native`, which unpacks a 4D shape.
        kernel:
            FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
            corresponds to average pooling.
        factor:
            Integer downsampling factor (default: 2).
        gain:
            Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H // factor, W // factor]`
    """

    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        # A 1D kernel is separable: expand it to 2D.
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * gain
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
    )
    return output


def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
    """Zero-upsample, pad, filter with `kernel` and decimate a 4D tensor.

    Args:
        tensor: input whose shape unpacks as `(_, channel, in_h, in_w)`.
        kernel: 2D filter; it is flipped along both axes before `F.conv2d`.
        up: integer upsampling factor, used for both spatial axes.
        down: integer decimation step, used for both spatial axes.
        pad: `(pad0, pad1)` applied before/after each spatial axis. Positive values zero-pad, negative values
            crop.

    Returns:
        Tensor viewed as `(-1, channel, out_h, out_w)`.
    """
    up_x = up_y = up
    down_x = down_y = down
    pad_x0 = pad_y0 = pad[0]
    pad_x1 = pad_y1 = pad[1]

    _, channel, in_h, in_w = tensor.shape
    # Fold batch and channels together so every channel is filtered as its own
    # single-channel image.
    tensor = tensor.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = tensor.shape
    kernel_h, kernel_w = kernel.shape

    # Insert `up - 1` zeros after every sample along both spatial axes.
    out = tensor.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    # Positive pads are applied here; negative pads are handled by the crop below.
    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out.to(tensor.device)  # Move back to mps if necessary
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    # Decimate by keeping every `down`-th sample.
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)


class TemporalConvLayer(nn.Module):
    """
    Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016

    Parameters:
        in_dim (`int`): number of input channels.
        out_dim (`int`, optional): number of intermediate channels. Defaults to `in_dim`.
        dropout (`float`, default `0.0`): dropout probability used in `conv2`, `conv3` and `conv4`.
    """

    def __init__(self, in_dim, out_dim=None, dropout=0.0):
        super().__init__()
        out_dim = out_dim or in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim

        # conv layers
        # NOTE(review): conv2-conv4 normalize `out_dim` channels but emit `in_dim`
        # channels, so chaining them looks valid only when `out_dim == in_dim` —
        # confirm before using a different `out_dim`. Left unchanged to keep
        # parameter shapes identical.
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
        )
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv3 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv4 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, hidden_states, num_frames=1):
        """Apply the temporal convolutions across groups of `num_frames` consecutive entries of dimension 0.

        Args:
            hidden_states: 4D tensor whose first dimension is split into `(-1, num_frames)`.
            num_frames: number of frames per group (default: 1).

        Returns:
            Tensor reshaped back to four dimensions.
        """
        # Split dim 0 into (groups, frames), then move the frame axis behind the
        # channel axis so `Conv3d` convolves over frames.
        hidden_states = (
            hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
        )

        identity = hidden_states
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.conv2(hidden_states)
        hidden_states = self.conv3(hidden_states)
        hidden_states = self.conv4(hidden_states)

        hidden_states = identity + hidden_states

        # Undo the permutation and merge groups and frames back into dim 0.
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
            (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
        )
        return hidden_states