import numpy as np
import torch
import torch.nn as nn
from einops import rearrange

from .block import Block
from .conv import CausalConv3d
from .normalize import Normalize
from .ops import video_to_image
class LinearAttention(Block):
    """Multi-head linear attention over the spatial positions of a feature map."""

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(
            qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
        )
        # Normalize the keys, then aggregate a (dim_head x dim_head) context matrix,
        # so the cost grows linearly with the number of spatial positions.
        k = k.softmax(dim=-1)
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhde,bhdn->bhen", context, q)
        out = rearrange(
            out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
        )
        return self.to_out(out)
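# Illustrative usage sketch (not part of the model code): LinearAttention preserves
# the input shape, and the example sizes below are arbitrary, not values used by
# the VAE configs.
#
#   attn = LinearAttention(dim=64, heads=4, dim_head=32)
#   x = torch.randn(1, 64, 32, 32)      # (batch, channels, height, width)
#   y = attn(x)
#   assert y.shape == x.shape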
class LinAttnBlock(LinearAttention):
    """Single-head linear attention wrapper to match the AttnBlock interface."""

    def __init__(self, in_channels):
        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
class AttnBlock3D(Block):
    """Spatial self-attention over video features, kept for compatibility with old versions.

    Known issue, use with caution: the reshape below folds (b, c, t, h, w) into
    (b * t, c, h * w) without first moving the time axis in front of the channel
    axis, which mixes channels into the batch dimension. See AttnBlock3DFix for
    the corrected version.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # Scaled dot-product attention weights between spatial positions.
        b, c, t, h, w = q.shape
        q = q.reshape(b * t, c, h * w)
        q = q.permute(0, 2, 1)
        k = k.reshape(b * t, c, h * w)
        w_ = torch.bmm(q, k)
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # Attend to the values.
        v = v.reshape(b * t, c, h * w)
        w_ = w_.permute(0, 2, 1)
        h_ = torch.bmm(v, w_)
        h_ = h_.reshape(b, c, t, h, w)

        h_ = self.proj_out(h_)

        return x + h_
class AttnBlock3DFix(nn.Module):
    """
    Spatial self-attention applied to each frame of a video independently.

    Thanks to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/172.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # (b, c, t, h, w) -> (b*t, hw, c): fold time into the batch so every frame
        # attends over its own spatial positions.
        b, c, t, h, w = q.shape
        q = q.permute(0, 2, 1, 3, 4)
        q = q.reshape(b * t, c, h * w)
        q = q.permute(0, 2, 1)

        # (b, c, t, h, w) -> (b*t, c, hw)
        k = k.permute(0, 2, 1, 3, 4)
        k = k.reshape(b * t, c, h * w)

        # Scaled dot-product attention weights between spatial positions.
        w_ = torch.bmm(q, k)
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # Attend to the values: (b*t, c, hw) x (b*t, hw, hw) -> (b*t, c, hw).
        v = v.permute(0, 2, 1, 3, 4)
        v = v.reshape(b * t, c, h * w)
        w_ = w_.permute(0, 2, 1)
        h_ = torch.bmm(v, w_)

        # (b*t, c, hw) -> (b, c, t, h, w)
        h_ = h_.reshape(b, t, c, h, w)
        h_ = h_.permute(0, 2, 1, 3, 4)

        h_ = self.proj_out(h_)

        return x + h_
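# Illustrative usage sketch (not part of the model): AttnBlock3DFix expects a
# (batch, channels, time, height, width) tensor and preserves its shape. The sizes
# below are arbitrary; in_channels must also satisfy whatever grouping Normalize
# imposes (e.g. a multiple of 32 if it is a 32-group GroupNorm).
#
#   block = AttnBlock3DFix(in_channels=64)
#   x = torch.randn(1, 64, 5, 16, 16)    # (b, c, t, h, w)
#   assert block(x).shape == x.shape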
class AttnBlock(Block):
    """Single-head spatial self-attention over 2D (image) features."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    @video_to_image
    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # Scaled dot-product attention weights between all spatial positions.
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)
        k = k.reshape(b, c, h * w)
        w_ = torch.bmm(q, k)
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # Attend to the values.
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)
        h_ = torch.bmm(v, w_)
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_
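# Note (an assumption based on the decorator's name, not verified here): the
# video_to_image decorator is expected to fold a (b, c, t, h, w) video into
# (b*t, c, h, w) frames before the 2D attention and restore the time axis
# afterwards, so AttnBlock can be fed either layout, e.g.:
#
#   block = AttnBlock(in_channels=64)
#   image = torch.randn(1, 64, 16, 16)       # image-style input
#   video = torch.randn(1, 64, 5, 16, 16)    # video-style input, handled by the decorator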
class TemporalAttnBlock(Block):
    """Single-head self-attention over the time axis, applied independently at each spatial position."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv3d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv3d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv3d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv3d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # Fold the spatial positions into the batch so attention runs over time only.
        b, c, t, h, w = q.shape
        q = rearrange(q, "b c t h w -> (b h w) t c")
        k = rearrange(k, "b c t h w -> (b h w) c t")
        v = rearrange(v, "b c t h w -> (b h w) c t")
        w_ = torch.bmm(q, k)
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # Attend to the values and restore the (b, c, t, h, w) layout.
        w_ = w_.permute(0, 2, 1)
        h_ = torch.bmm(v, w_)
        h_ = rearrange(h_, "(b h w) c t -> b c t h w", h=h, w=w)
        h_ = self.proj_out(h_)

        return x + h_
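# Illustrative usage sketch (not part of the model): TemporalAttnBlock mixes
# information across frames at every spatial location and preserves the input
# shape. Sizes below are arbitrary examples; in_channels must be compatible with
# whatever grouping Normalize imposes.
#
#   block = TemporalAttnBlock(in_channels=64)
#   x = torch.randn(1, 64, 8, 16, 16)    # (b, c, t, h, w)
#   assert block(x).shape == x.shape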
def make_attn(in_channels, attn_type="vanilla"):
    assert attn_type in ["vanilla", "linear", "none", "vanilla3D"], f"attn_type {attn_type} unknown"
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    elif attn_type == "vanilla3D":
        return AttnBlock3D(in_channels)
    elif attn_type == "none":
        return nn.Identity()
    else:
        return LinAttnBlock(in_channels)
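# Hedged smoke-test sketch (not part of the original module): exercises a few of
# the blocks above with arbitrary tensor sizes. It assumes the surrounding package
# (.normalize, .conv, .block, .ops) is importable, e.g. via
# `python -m <package>.attention`, and that Normalize accepts in_channels=64.
if __name__ == "__main__":
    x_video = torch.randn(1, 64, 4, 16, 16)  # (b, c, t, h, w)

    fix_block = AttnBlock3DFix(64)
    assert fix_block(x_video).shape == x_video.shape

    temporal_block = TemporalAttnBlock(64)
    assert temporal_block(x_video).shape == x_video.shape

    linear_block = make_attn(64, attn_type="linear")
    x_image = torch.randn(1, 64, 16, 16)  # (b, c, h, w)
    assert linear_block(x_image).shape == x_image.shape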