import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from einops import rearrange
from fast_pytorch_kmeans import KMeans


def get_timestep_embedding(timesteps, embedding_dim):
    """
    Build sinusoidal timestep embeddings, as used in Denoising Diffusion
    Probabilistic Models (adapted from Fairseq). This matches the
    implementation in tensor2tensor, but differs slightly from the
    description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero-pad the last channel if embedding_dim is odd
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
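

# Worked example of the formula above (derived from the code): with
# freq_i = exp(-log(10000) * i / (half_dim - 1)) for i in [0, half_dim),
#   emb[t, i]            = sin(t * freq_i)
#   emb[t, half_dim + i] = cos(t * freq_i)
# e.g. get_timestep_embedding(torch.arange(4), 128) -> tensor of shape (4, 128).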


def nonlinearity(x):
    # swish / SiLU activation
    return x * torch.sigmoid(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, so pad manually in forward
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if self.with_conv:
            # pad right/bottom by one so the strided conv exactly halves the resolution
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            # project the residual branch when the channel count changes
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention weights
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)      # b, hw, c
        k = k.reshape(b, c, h * w)  # b, c, hw
        w_ = torch.bmm(q, k)        # b, hw, hw; w_[b, i, j] = sum_c q[b, i, c] k[b, c, j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)    # b, hw (of k), hw (of q)
        h_ = torch.bmm(v, w_)       # b, c, hw (of q)
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_
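

# In equation form (a sketch matching the code above): for flattened spatial
# positions i, j and channel dimension c,
#   w_[b, i, j] = softmax_j( q[b, i, :] . k[b, :, j] / sqrt(c) )
#   h_[b, :, i] = sum_j v[b, :, j] * w_[b, i, j]
# i.e. standard scaled dot-product self-attention over the h*w grid.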


class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class Encoder(nn.Module):
    """
    Encoder of VQ-GAN to map an input batch of images to latent space.
    Dimension transformations for the default arguments (3x512x512 input,
    two ResnetBlocks per stage; one shown per line for brevity):
    3x512x512 --Conv2d--> 128x512x512
    for loop:
        --ResBlock--> 128x512x512 --Downsample--> 128x256x256
        --ResBlock--> 128x256x256 --Downsample--> 128x128x128
        --ResBlock--> 256x128x128 --Downsample--> 256x64x64
        --ResBlock--> 512x64x64 --Downsample--> 512x32x32
        --ResBlock--> 512x32x32 --AttnBlock--> 512x32x32
    --ResBlock--> 512x32x32
    --AttnBlock--> 512x32x32
    --ResBlock--> 512x32x32
    --GroupNorm-->
    --Swish-->
    --Conv2d--> 256x32x32
    """

    def __init__(self, in_channels=3, channels=[128, 128, 128, 256, 512, 512], attn_resolutions=[32], resolution=512, dropout=0.0, num_res_blocks=2, z_channels=256, **kwargs):
        super().__init__()
        layers = [nn.Conv2d(in_channels, channels[0], 3, 1, 1)]
        for i in range(len(channels) - 1):
            in_channels = channels[i]
            out_channels = channels[i + 1]
            for j in range(num_res_blocks):
                layers.append(ResnetBlock(in_channels=in_channels, out_channels=out_channels, dropout=dropout))
                in_channels = out_channels
            if resolution in attn_resolutions:
                layers.append(AttnBlock(in_channels))
            if i < len(channels) - 2:
                layers.append(Downsample(channels[i + 1], with_conv=True))
                resolution //= 2
        layers.append(ResnetBlock(in_channels=channels[-1], out_channels=channels[-1], dropout=dropout))
        layers.append(AttnBlock(channels[-1]))
        layers.append(ResnetBlock(in_channels=channels[-1], out_channels=channels[-1], dropout=dropout))
        layers.append(Normalize(channels[-1]))
        layers.append(Swish())
        layers.append(nn.Conv2d(channels[-1], z_channels, 3, 1, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
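

# Shape sketch under the default arguments: with channels=[128, 128, 128, 256,
# 512, 512] there are four Downsample stages, so the spatial size shrinks by
# 2**4 = 16 and a 512x512 input maps to a 32x32 latent grid:
#   enc = Encoder()
#   z = enc(torch.randn(1, 3, 512, 512))  # z.shape == (1, 256, 32, 32)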


class Decoder(nn.Module):
    def __init__(self, out_channels=3, channels=[128, 128, 128, 256, 512, 512], attn_resolutions=[32], resolution=512, dropout=0.0, num_res_blocks=2, z_channels=256, **kwargs):
        super().__init__()
        ch_mult = channels[1:]
        num_resolutions = len(ch_mult)
        block_in = ch_mult[num_resolutions - 1]
        curr_res = resolution // 2 ** (num_resolutions - 1)

        layers = [nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1),
                  ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout),
                  AttnBlock(block_in),
                  ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)]

        for i in reversed(range(num_resolutions)):
            block_out = ch_mult[i]
            for i_block in range(num_res_blocks + 1):
                layers.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
                block_in = block_out
            if curr_res in attn_resolutions:
                layers.append(AttnBlock(block_in))
            if i > 0:
                layers.append(Upsample(block_in, with_conv=True))
                curr_res = curr_res * 2

        layers.append(Normalize(block_in))
        layers.append(Swish())
        layers.append(nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class Codebook(nn.Module):
    """
    Improved version over VectorQuantizer; can be used as a drop-in replacement. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.
    """

    def __init__(self, codebook_size, codebook_dim, beta, init_steps=2000, reservoir_size=2e5):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.beta = beta

        self.embedding = nn.Embedding(self.codebook_size, self.codebook_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)

        # schedule: start collecting latents at q_start_collect, quantize from
        # q_init onwards, and re-fit the codebook with k-means every q_re_step
        # steps until q_re_end
        self.q_start_collect, self.q_init, self.q_re_end, self.q_re_step = init_steps, init_steps * 3, init_steps * 30, init_steps // 2
        self.q_counter = 0
        self.reservoir_size = int(reservoir_size)
        self.reservoir = None

    def forward(self, z):
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        batch_size = z.size(0)
        z_flattened = z.view(-1, self.codebook_dim)
        if self.training:
            self.q_counter += 1
            # reservoir sampling: keep a bounded random subset of recent latents
            if self.q_counter > self.q_start_collect:
                z_new = z_flattened.clone().detach().view(batch_size, -1, self.codebook_dim)
                z_new = z_new[:, torch.randperm(z_new.size(1))][:, :10].reshape(-1, self.codebook_dim)
                self.reservoir = z_new if self.reservoir is None else torch.cat([self.reservoir, z_new], dim=0)
                self.reservoir = self.reservoir[torch.randperm(self.reservoir.size(0))[:self.reservoir_size]].detach()
            if self.q_counter < self.q_init:
                # warm-up: pass latents through unquantized while the reservoir fills
                z_q = rearrange(z, 'b h w c -> b c h w').contiguous()
                return z_q, z_q.new_tensor(0), None
            else:
                if self.q_init <= self.q_counter < self.q_re_end:
                    # re-fit every q_re_step steps, with a final update at the end of the window
                    if (self.q_counter - self.q_init) % self.q_re_step == 0 or self.q_counter == self.q_re_end - 1:
                        kmeans = KMeans(n_clusters=self.codebook_size)
                        # guard for single-process runs where the default process group is absent
                        world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
                        print("Updating codebook from reservoir.")
                        if world_size > 1:
                            # gather reservoirs from all ranks before clustering
                            global_reservoir = [torch.zeros_like(self.reservoir) for _ in range(world_size)]
                            dist.all_gather(global_reservoir, self.reservoir.clone())
                            global_reservoir = torch.cat(global_reservoir, dim=0)
                        else:
                            global_reservoir = self.reservoir
                        kmeans.fit_predict(global_reservoir)
                        self.embedding.weight.data = kmeans.centroids.detach()

        # squared distances ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e
        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)

        # codebook loss plus beta-weighted commitment loss
        loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)

        # straight-through estimator: copy gradients from z_q to the encoder output
        z_q = z + (z_q - z).detach()

        z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()

        return z_q, loss, min_encoding_indices

    def get_codebook_entry(self, indices, shape):
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # shape is (batch, height, width, channel); move channels back to dim 1
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q
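

# Minimal usage sketch (hypothetical sizes): in eval mode the k-means warm-up
# branches above are skipped and the forward pass is a plain nearest-neighbour
# lookup followed by the straight-through estimator.
#   cb = Codebook(codebook_size=1024, codebook_dim=256, beta=0.25)
#   cb.eval()
#   z_q, loss, indices = cb(torch.randn(1, 256, 32, 32))
#   # z_q: (1, 256, 32, 32); indices: (1024,), one id per spatial position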


if __name__ == '__main__':
    enc = Encoder()
    dec = Decoder()
    print(sum(p.numel() for p in enc.parameters()))
    print(sum(p.numel() for p in dec.parameters()))
    x = torch.randn(1, 3, 512, 512)
    res = enc(x)
    print(res.shape)
    res = dec(res)
    print(res.shape)