import math
from inspect import isfunction
from functools import partial

import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange

import torch
from torch import nn, einsum
import torch.nn.functional as F

def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

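# Hedged usage note (examples are illustrative, not from the original file):
#   default(None, 7)          # -> 7
#   default(3, 7)             # -> 3
#   default(None, lambda: 7)  # -> 7 (callables are evaluated lazily)
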
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


def Upsample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)


def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)

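# Hedged shape sketch (illustrative, not from the original file): both helpers keep
# the channel count and change spatial resolution by a factor of 2, e.g.
#   x = torch.randn(1, 64, 32, 32)
#   Downsample(64)(x).shape  # -> torch.Size([1, 64, 16, 16])
#   Upsample(64)(x).shape    # -> torch.Size([1, 64, 64, 64])
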
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

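# Hedged usage sketch (illustrative): for a batch of integer timesteps the module
# returns a (batch, dim) embedding, with the first dim // 2 columns sines and the
# rest cosines at geometrically spaced frequencies, as in the Transformer paper.
#   t = torch.tensor([0, 10, 100])
#   SinusoidalPositionEmbeddings(32)(t).shape  # -> torch.Size([3, 32])
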
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)
        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift
        x = self.act(x)
        return x

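# Hedged note (not from the original file): scale_shift lets a conditioning vector
# modulate the normalized activations FiLM-style, x * (scale + 1) + shift, before
# the SiLU. Quick shape check:
#   Block(64, 128)(torch.randn(2, 64, 16, 16)).shape  # -> torch.Size([2, 128, 16, 16])
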
class ResnetBlock(nn.Module):
    """https://arxiv.org/abs/1512.03385"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
            if exists(time_emb_dim)
            else None
        )
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.block1(x)
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            h = rearrange(time_emb, "b c -> b c 1 1") + h
        h = self.block2(h)
        return h + self.res_conv(x)

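# Hedged usage sketch (names are illustrative): the time embedding is projected to
# dim_out and broadcast over the spatial grid after block1.
#   block = ResnetBlock(64, 128, time_emb_dim=256)
#   x, t_emb = torch.randn(2, 64, 16, 16), torch.randn(2, 256)
#   block(x, t_emb).shape  # -> torch.Size([2, 128, 16, 16])
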
class ConvNextBlock(nn.Module):
    """https://arxiv.org/abs/2201.03545"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
        self.net = nn.Sequential(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
        )
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.ds_conv(x)
        if exists(self.mlp) and exists(time_emb):
            condition = self.mlp(time_emb)
            h = h + rearrange(condition, "b c -> b c 1 1")
        h = self.net(h)
        return h + self.res_conv(x)

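# Hedged note (not from the original file): this follows the ConvNeXt recipe of a
# 7x7 depthwise conv followed by an inverted bottleneck that widens the channels by
# `mult` before projecting to dim_out. Quick shape check:
#   cnb = ConvNextBlock(64, 128, time_emb_dim=256)
#   cnb(torch.randn(2, 64, 16, 16), torch.randn(2, 256)).shape  # -> (2, 128, 16, 16)
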
class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        q = q * self.scale
        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)

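# Hedged note (illustrative): standard multi-head self-attention over all h*w
# positions, so the attention matrix grows quadratically with the spatial size.
#   Attention(64)(torch.randn(1, 64, 16, 16)).shape  # -> torch.Size([1, 64, 16, 16])
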
class LinearCrossAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32) -> None:
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_kv = nn.Conv2d(dim, hidden_dim * 2, 1, bias=False)
        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x, cross_attend):
        b, c, h, w = x.shape
        q = self.to_q(x)
        k, v = self.to_kv(cross_attend).chunk(2, dim=1)
        q = rearrange(q, "b (h c) x y -> b h c (x y)", h=self.heads)
        k = rearrange(k, "b (h c) x y -> b h c (x y)", h=self.heads)
        v = rearrange(v, "b (h c) x y -> b h c (x y)", h=self.heads)

        q = q * self.scale
        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.out(out)

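# Hedged usage sketch (illustrative): queries come from x and keys/values from
# cross_attend (same channel count), so the output keeps x's spatial layout. Despite
# the class name, this is standard softmax cross-attention, not the linear variant below.
#   ca = LinearCrossAttention(64)
#   x, ctx = torch.randn(1, 64, 16, 16), torch.randn(1, 64, 8, 8)
#   ca(x, ctx).shape  # -> torch.Size([1, 64, 16, 16])
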
class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))
    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        # normalize each query over its feature dimension (dim -2 of "b h c n")
        q = q.softmax(dim=-2)
        # normalize each key over the spatial dimension (dim -1 of "b h c n")
        k = k.softmax(dim=-1)
        # scale the queries by 1/sqrt(dim_head), as in standard attention
        q = q * self.scale

        # contract keys with values first: a (d, e) context matrix per head,
        # which keeps the cost linear in the number of spatial positions
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        # apply the context to the queries
        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        # reshape back to the (b, c, h, w) feature-map layout
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)

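# Hedged note (illustrative): because keys and values are contracted before touching
# the queries, the per-head cost is linear in the number of pixels, unlike the
# quadratic attention matrix built in Attention above.
#   LinearAttention(64)(torch.randn(1, 64, 32, 32)).shape  # -> torch.Size([1, 64, 32, 32])
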
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x, *args, **kwargs):
        x = self.norm(x)
        return self.fn(x, *args, **kwargs)

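# Hedged composition sketch (illustrative of how these wrappers are typically combined
# in U-Net stages): group-normalize, attend, then add the residual.
#   attn = Residual(PreNorm(64, LinearAttention(64)))
#   attn(torch.randn(1, 64, 16, 16)).shape  # -> torch.Size([1, 64, 16, 16])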