# Copyright 2023 The HuggingFace Team. All rights reserved.
# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/resnet.py
from __future__ import annotations

from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from diffusers.models.resnet import TemporalConvLayer as DiffusersTemporalConvLayer

from ..data.data_util import batch_index_fill, batch_index_select
from . import Model_Register


@Model_Register.register
class TemporalConvLayer(nn.Module):
"""
    Temporal convolutional layer that can be used for video (sequence of images) input. Code mostly copied from:
https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
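
    Args (inferred from the code in this module, not from upstream documentation):
        in_dim (int): number of input channels.
        out_dim (int, optional): number of intermediate channels; defaults to ``in_dim``.
        dropout (float): dropout probability applied inside conv2-conv4.
        keep_content_condition (bool): if True, the frames indexed by
            ``vision_conditon_frames_sample_index`` are excluded from the temporal
            residual, so the condition (content) frames pass through unchanged.
        femb_channels (int, optional): channels of an optional frame embedding;
            stored but not used by this layer's forward pass.
        need_temporal_weight (bool): if True, the temporal residual is scaled by the
            learnable ``temporal_weight`` parameter (initialized to 1e-5).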
"""
def __init__(
self,
in_dim,
out_dim=None,
dropout=0.0,
keep_content_condition: bool = False,
femb_channels: Optional[int] = None,
need_temporal_weight: bool = True,
):
super().__init__()
out_dim = out_dim or in_dim
self.in_dim = in_dim
self.out_dim = out_dim
self.keep_content_condition = keep_content_condition
self.femb_channels = femb_channels
self.need_temporal_weight = need_temporal_weight
# conv layers
self.conv1 = nn.Sequential(
nn.GroupNorm(32, in_dim),
nn.SiLU(),
nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
)
self.conv2 = nn.Sequential(
nn.GroupNorm(32, out_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
)
self.conv3 = nn.Sequential(
nn.GroupNorm(32, out_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
)
self.conv4 = nn.Sequential(
nn.GroupNorm(32, out_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
)
        # zero out the last layer params, so the conv block starts as an identity mapping
        # nn.init.zeros_(self.conv4[-1].weight)
        # nn.init.zeros_(self.conv4[-1].bias)
        self.temporal_weight = nn.Parameter(
            torch.tensor(
                [
                    1e-5,
                ]
            )
        )  # initialize the temporal weight with a small value (1e-5)
        # zero out the last layer params, so the conv block starts as an identity mapping
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)
        self.skip_temporal_layers = False  # whether to skip the temporal layers entirely

    def forward(
        self,
        hidden_states,
        num_frames=1,
        sample_index: Optional[torch.LongTensor] = None,
        vision_conditon_frames_sample_index: Optional[torch.LongTensor] = None,
        femb: Optional[torch.Tensor] = None,
    ):
if self.skip_temporal_layers is True:
return hidden_states
hidden_states_dtype = hidden_states.dtype
hidden_states = rearrange(
hidden_states, "(b t) c h w -> b c t h w", t=num_frames
)
identity = hidden_states
hidden_states = self.conv1(hidden_states)
hidden_states = self.conv2(hidden_states)
hidden_states = self.conv3(hidden_states)
hidden_states = self.conv4(hidden_states)
        # keep the frames that correspond to the condition, so the preceding
        # content frames are preserved and temporal consistency improves
if self.keep_content_condition:
mask = torch.ones_like(hidden_states, device=hidden_states.device)
mask = batch_index_fill(
mask, dim=2, index=vision_conditon_frames_sample_index, value=0
)
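            # mask is 0 on the condition frames and 1 elsewhere, so only the
            # non-condition frames receive the temporal residual added below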
if self.need_temporal_weight:
hidden_states = (
identity + torch.abs(self.temporal_weight) * mask * hidden_states
)
else:
hidden_states = identity + mask * hidden_states
else:
if self.need_temporal_weight:
hidden_states = (
identity + torch.abs(self.temporal_weight) * hidden_states
)
else:
hidden_states = identity + hidden_states
hidden_states = rearrange(hidden_states, " b c t h w -> (b t) c h w")
hidden_states = hidden_states.to(dtype=hidden_states_dtype)
return hidden_states
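

if __name__ == "__main__":
    # Minimal smoke test, added here as an illustrative sketch; it is not part of the
    # original module. It assumes the package-relative imports above resolve, i.e. the
    # file is run as a module inside its package (e.g. `python -m <package>.<module>`),
    # and the shapes below are arbitrary example values.
    batch_size, num_frames, channels, height, width = 1, 4, 64, 8, 8

    layer = TemporalConvLayer(
        in_dim=channels,
        keep_content_condition=False,
        need_temporal_weight=True,
    )

    # The layer expects frames flattened into the batch dimension: (b * t, c, h, w).
    hidden_states = torch.randn(batch_size * num_frames, channels, height, width)
    out = layer(hidden_states, num_frames=num_frames)

    # With conv4 zero-initialized and temporal_weight at 1e-5, the block starts close
    # to an identity mapping; the output shape matches the input shape.
    assert out.shape == hidden_states.shape
    print(out.shape)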