anime_diffusion / models /structure /Advanced_Network_Helpers.py
pawlo2013's picture
fixed code readability
5086590
import math
from inspect import isfunction
from functools import partial
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange
import torch
from torch import nn, einsum
import torch.nn.functional as F
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, *args, **kwargs):
return self.fn(x, *args, **kwargs) + x
def Upsample(dim):
return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
def Downsample(dim):
return nn.Conv2d(dim, dim, 4, 2, 1)
class SinusoidalPositionEmbeddings(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, time):
device = time.device
half_dim = self.dim // 2
embeddings = math.log(10000) / (half_dim - 1)
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
embeddings = time[:, None] * embeddings[None, :]
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
return embeddings
class Block(nn.Module):
def __init__(self, dim, dim_out, groups=8):
super().__init__()
self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU()
def forward(self, x, scale_shift=None):
x = self.proj(x)
x = self.norm(x)
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift
x = self.act(x)
return x
class ResnetBlock(nn.Module):
"""https://arxiv.org/abs/1512.03385"""
def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
super().__init__()
self.mlp = (
nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
if exists(time_emb_dim)
else None
)
self.block1 = Block(dim, dim_out, groups=groups)
self.block2 = Block(dim_out, dim_out, groups=groups)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, time_emb=None):
h = self.block1(x)
if exists(self.mlp) and exists(time_emb):
time_emb = self.mlp(time_emb)
h = rearrange(time_emb, "b c -> b c 1 1") + h
h = self.block2(h)
return h + self.res_conv(x)
class ConvNextBlock(nn.Module):
"""https://arxiv.org/abs/2201.03545"""
def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
super().__init__()
self.mlp = (
nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
if exists(time_emb_dim)
else None
)
self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
self.net = nn.Sequential(
nn.GroupNorm(1, dim) if norm else nn.Identity(),
nn.Conv2d(dim, dim_out * mult, 3, padding=1),
nn.GELU(),
nn.GroupNorm(1, dim_out * mult),
nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, time_emb=None):
h = self.ds_conv(x)
if exists(self.mlp) and exists(time_emb):
assert exists(time_emb), "time embedding must be passed in"
condition = self.mlp(time_emb)
h = h + rearrange(condition, "b c -> b c 1 1")
h = self.net(h)
return h + self.res_conv(x)
class Attention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.scale = dim_head**-0.5
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
self.to_k = nn.Conv2d(dim, hidden_dim, 1, bias=False)
self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x).chunk(3, dim=1)
q, k, v = map(
lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
)
q = q * self.scale
sim = einsum("b h d i, b h d j -> b h i j", q, k)
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
attn = sim.softmax(dim=-1)
out = einsum("b h i j, b h d j -> b h i d", attn, v)
out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
return self.to_out(out)
class LinearCrossAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32) -> None:
super().__init__()
self.scale = dim_head**-0.5
self.heads = heads
hidden_dim = dim_head * heads
self.to_kv = nn.Conv2d(dim, hidden_dim * 2, 1, bias=False)
self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
self.out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x, cross_attend):
b, c, h, w = x.shape
q = self.to_q(x)
k, v = self.to_kv(cross_attend).chunk(2, dim=1)
q = rearrange(q, "b (h c) x y -> b h c (x y)", h=self.heads)
k = rearrange(k, "b (h c) x y -> b h c (x y)", h=self.heads)
v = rearrange(v, "b (h c) x y -> b h c (x y)", h=self.heads)
q = q * self.scale
sim = einsum("b h d i, b h d j -> b h i j", q, k)
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
attn = sim.softmax(dim=-1)
out = einsum("b h i j, b h d j -> b h i d", attn, v)
out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
return self.out(out)
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.scale = dim_head**-0.5
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
self.to_k = nn.Conv2d(dim, hidden_dim, 1, bias=False)
self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x).chunk(3, dim=1)
q, k, v = map(
lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
)
# calculate the softmax with respect to columns softmax of equivalent to q^T with respect to last dim
q = q.softmax(dim=-2)
# calculate the softmax with respect to rows of k
k = k.softmax(dim=-1)
# normalize the values in the attention matrix
q = q * self.scale
# dot product of q and v matrices
context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
# dot product of context and q
out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
# rearrange the output to match the pytorch convention
out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
return self.to_out(out)
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = nn.GroupNorm(1, dim)
def forward(self, x, *args, **kwargs):
x = self.norm(x)
return self.fn(x, *args, **kwargs)