# Copyright 2023 The HuggingFace Team. All rights reserved.
# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/resnet.py
from __future__ import annotations

from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from diffusers.models.resnet import TemporalConvLayer as DiffusersTemporalConvLayer
from ..data.data_util import batch_index_fill, batch_index_select
from . import Model_Register


@Model_Register.register
class TemporalConvLayer(nn.Module):
    """
    Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: Optional[int] = None,
        dropout: float = 0.0,
        keep_content_condition: bool = False,
        femb_channels: Optional[int] = None,
        need_temporal_weight: bool = True,
    ):
        super().__init__()
        out_dim = out_dim or in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.keep_content_condition = keep_content_condition
        self.femb_channels = femb_channels
        self.need_temporal_weight = need_temporal_weight
        # conv layers
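        # NOTE: conv2/conv3/conv4 map out_dim back to in_dim while their
        # GroupNorms expect out_dim channels, so this block (like the upstream
        # diffusers implementation it is adapted from) effectively assumes
        # in_dim == out_dim.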
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_dim),
            nn.SiLU(),
            nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv3 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv4 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )

        # learnable scale for the temporal branch, initialized near zero so the
        # block starts close to an identity mapping
        self.temporal_weight = nn.Parameter(torch.tensor([1e-5]))
        # zero out the last layer's params so the conv block starts as identity
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)
        self.skip_temporal_layers = False  # whether to bypass the temporal layers

    def forward(
        self,
        hidden_states,
        num_frames: int = 1,
        sample_index: Optional[torch.LongTensor] = None,
        vision_conditon_frames_sample_index: Optional[torch.LongTensor] = None,
        femb: Optional[torch.Tensor] = None,
    ):
        """Expects `hidden_states` of shape `((b t), c, h, w)`. `sample_index`
        and `femb` are accepted for interface compatibility but unused here."""
        if self.skip_temporal_layers:
            return hidden_states
        hidden_states_dtype = hidden_states.dtype
        hidden_states = rearrange(
            hidden_states, "(b t) c h w -> b c t h w", t=num_frames
        )
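        # now 5-D (b, c, t, h, w): the flattened frame batch is unfolded so the
        # (3, 1, 1) Conv3d kernels can mix information along the t axis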
        identity = hidden_states
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.conv2(hidden_states)
        hidden_states = self.conv3(hidden_states)
        hidden_states = self.conv4(hidden_states)
        # keep the condition frames unchanged so the preceding content frames
        # are preserved, improving temporal consistency
        if self.keep_content_condition:
            mask = torch.ones_like(hidden_states)
            mask = batch_index_fill(
                mask, dim=2, index=vision_conditon_frames_sample_index, value=0
            )
            if self.need_temporal_weight:
                hidden_states = (
                    identity + torch.abs(self.temporal_weight) * mask * hidden_states
                )
            else:
                hidden_states = identity + mask * hidden_states
        else:
            if self.need_temporal_weight:
                hidden_states = (
                    identity + torch.abs(self.temporal_weight) * hidden_states
                )
            else:
                hidden_states = identity + hidden_states
        hidden_states = rearrange(hidden_states, "b c t h w -> (b t) c h w")
        hidden_states = hidden_states.to(dtype=hidden_states_dtype)
        return hidden_states
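

# A minimal smoke test (an illustrative sketch, not part of the original
# module): it assumes the package-relative imports above resolve and that
# in_dim == out_dim, which the conv2/3/4 GroupNorm shapes require. The input
# is a flattened frame batch of shape ((b t), c, h, w); the residual design
# keeps the output shape identical to the input.
if __name__ == "__main__":
    layer = TemporalConvLayer(in_dim=64)
    frames = torch.randn(2 * 8, 64, 32, 32)  # b=2 clips, t=8 frames each
    out = layer(frames, num_frames=8)
    assert out.shape == frames.shape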