import math
from inspect import isfunction
from functools import partial

import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange

import torch
from torch import nn, einsum
import torch.nn.functional as F


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


class Residual(nn.Module):
    """Adds the input back to the output of the wrapped function."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


def Upsample(dim):
    # doubles the spatial resolution
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)


def Downsample(dim):
    # halves the spatial resolution
    return nn.Conv2d(dim, dim, 4, 2, 1)


class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal timestep embeddings, as in the Transformer positional encoding."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings


class Block(nn.Module):
    """Conv -> GroupNorm -> (optional scale/shift) -> SiLU."""

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x


class ResnetBlock(nn.Module):
    """https://arxiv.org/abs/1512.03385"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
            if exists(time_emb_dim)
            else None
        )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.block1(x)

        if exists(self.mlp) and exists(time_emb):
            # project the time embedding and add it as a per-channel bias
            time_emb = self.mlp(time_emb)
            h = rearrange(time_emb, "b c -> b c 1 1") + h

        h = self.block2(h)
        return h + self.res_conv(x)


class ConvNextBlock(nn.Module):
    """https://arxiv.org/abs/2201.03545"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )

        # depthwise 7x7 convolution
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)

        self.net = nn.Sequential(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.ds_conv(x)

        if exists(self.mlp) and exists(time_emb):
            condition = self.mlp(time_emb)
            h = h + rearrange(condition, "b c -> b c 1 1")

        h = self.net(h)
        return h + self.res_conv(x)
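

# Usage sketch (illustrative only; the dimensions below are assumptions chosen
# for the example, not values taken from this module): how a ResnetBlock is
# conditioned on a diffusion timestep via the sinusoidal embeddings above.
#
#   time_mlp = nn.Sequential(
#       SinusoidalPositionEmbeddings(64),
#       nn.Linear(64, 256),
#       nn.GELU(),
#       nn.Linear(256, 256),
#   )
#   block = ResnetBlock(64, 128, time_emb_dim=256)
#   x = torch.randn(2, 64, 32, 32)            # (batch, channels, height, width)
#   t = torch.randint(0, 1000, (2,)).float()  # diffusion timesteps
#   out = block(x, time_mlp(t))               # -> (2, 128, 32, 32)
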
class Attention(nn.Module):
    """Standard (quadratic) multi-head self-attention over spatial positions."""

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        q = q * self.scale

        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        # subtract the per-row maximum for numerical stability before the softmax
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)


class LinearCrossAttention(nn.Module):
    """Cross-attention: queries come from x, keys and values from `cross_attend`."""

    def __init__(self, dim, heads=4, dim_head=32) -> None:
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_kv = nn.Conv2d(dim, hidden_dim * 2, 1, bias=False)
        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
        self.out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x, cross_attend):
        b, c, h, w = x.shape
        q = self.to_q(x)
        k, v = self.to_kv(cross_attend).chunk(2, dim=1)
        q = rearrange(q, "b (h c) x y -> b h c (x y)", h=self.heads)
        k = rearrange(k, "b (h c) x y -> b h c (x y)", h=self.heads)
        v = rearrange(v, "b (h c) x y -> b h c (x y)", h=self.heads)
        q = q * self.scale

        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.out(out)


class LinearAttention(nn.Module):
    """Self-attention with linear complexity in the number of spatial positions."""

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        # softmax over the feature dimension of q and over the spatial dimension of k
        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        # scale the queries by 1/sqrt(dim_head)
        q = q * self.scale

        # aggregate the values against the keys: a (dim_head x dim_head) context per head
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        # apply the context to the queries
        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        # reshape back to the (b, c, h, w) convention
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)


class PreNorm(nn.Module):
    """Applies GroupNorm to the input before calling the wrapped function."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x, *args, **kwargs):
        x = self.norm(x)
        return self.fn(x, *args, **kwargs)
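

# Minimal shape-check sketch for the attention variants and resampling helpers.
# The sizes used here are assumptions for illustration only.
if __name__ == "__main__":
    x = torch.randn(2, 32, 16, 16)

    # self-attention is usually wrapped in PreNorm and a Residual connection
    attn = Residual(PreNorm(32, Attention(32)))
    assert attn(x).shape == x.shape

    # linear attention exposes the same interface with lower memory cost
    lin_attn = Residual(PreNorm(32, LinearAttention(32)))
    assert lin_attn(x).shape == x.shape

    # cross-attention: queries from x, keys/values from a second feature map
    cross = LinearCrossAttention(32)
    assert cross(x, torch.randn(2, 32, 16, 16)).shape == x.shape

    # Downsample halves and Upsample doubles the spatial resolution
    assert Downsample(32)(x).shape == (2, 32, 8, 8)
    assert Upsample(32)(x).shape == (2, 32, 32, 32)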