# open-chameleon / modules.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from fast_pytorch_kmeans import KMeans
from einops import rearrange

def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0,1,0,0))
return emb
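# Example (illustrative): embedding a batch of 4 timesteps into 128 dims
#   t = torch.tensor([0, 10, 100, 999])
#   emb = get_timestep_embedding(t, 128)  # -> shape (4, 128)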
def nonlinearity(x):
# swish
return x*torch.sigmoid(x)
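# Note: swish(x) = x * sigmoid(x) is the same function as SiLU
# (torch.nn.functional.silu).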
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=2,
padding=0)
def forward(self, x):
if self.with_conv:
pad = (0,1,0,1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
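# Note: the (0, 1, 0, 1) pad in Downsample right/bottom-pads by one pixel so
# the stride-2 3x3 conv halves even spatial sizes exactly (e.g. 512 -> 256).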
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x+h
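# ResnetBlock follows the pre-activation pattern (GroupNorm -> Swish -> Conv),
# with a 1x1 "nin" shortcut (or an optional 3x3 conv shortcut) applied to the
# residual path when the channel count changes.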
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b,c,h,w = q.shape
q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b,c,h*w)
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b,c,h,w)
h_ = self.proj_out(h_)
return x+h_
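# Note: AttnBlock is single-head (non-local) self-attention over the h*w
# spatial positions. On PyTorch >= 2.0 the explicit bmm/softmax above could
# likely be replaced by F.scaled_dot_product_attention on (b, hw, c) tensors;
# the explicit form is kept here for clarity.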
class Swish(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
class Encoder(nn.Module):
"""
Encoder of VQ-GAN to map input batch of images to latent space.
Dimension Transformations:
3x256x256 --Conv2d--> 32x256x256
for loop:
--ResBlock--> 64x256x256 --DownBlock--> 64x128x128
--ResBlock--> 128x128x128 --DownBlock--> 128x64x64
--ResBlock--> 256x64x64 --DownBlock--> 256x32x32
--ResBlock--> 512x32x32
--ResBlock--> 512x32x32
--NonLocalBlock--> 512x32x32
--ResBlock--> 512x32x32
--GroupNorm-->
--Swish-->
--Conv2d-> 256x32x32
"""
def __init__(self, in_channels=3, channels=[128, 128, 128, 256, 512, 512], attn_resolutions=[32], resolution=512, dropout=0.0, num_res_blocks=2, z_channels=256, **kwargs):
        super().__init__()
layers = [nn.Conv2d(in_channels, channels[0], 3, 1, 1)]
for i in range(len(channels) - 1):
in_channels = channels[i]
out_channels = channels[i + 1]
for j in range(num_res_blocks):
                layers.append(ResnetBlock(in_channels=in_channels, out_channels=out_channels, dropout=dropout))
in_channels = out_channels
if resolution in attn_resolutions:
layers.append(AttnBlock(in_channels))
if i < len(channels) - 2:
layers.append(Downsample(channels[i + 1], with_conv=True))
resolution //= 2
        layers.append(ResnetBlock(in_channels=channels[-1], out_channels=channels[-1], dropout=dropout))
        layers.append(AttnBlock(channels[-1]))
        layers.append(ResnetBlock(in_channels=channels[-1], out_channels=channels[-1], dropout=dropout))
layers.append(Normalize(channels[-1]))
layers.append(Swish())
layers.append(nn.Conv2d(channels[-1], z_channels, 3, 1, 1))
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
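# With the defaults, Encoder maps (B, 3, 512, 512) -> (B, 256, 32, 32):
# four stride-2 downsamples (512 / 2**4 = 32) followed by a conv to z_channels.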
class Decoder(nn.Module):
def __init__(self, out_channels=3, channels=[128, 128, 128, 256, 512, 512], attn_resolutions=[32], resolution=512, dropout=0.0, num_res_blocks=2, z_channels=256, **kwargs):
        super().__init__()
        ch_mult = channels[1:]
        num_resolutions = len(ch_mult)
        block_in = ch_mult[num_resolutions - 1]
        curr_res = resolution // 2 ** (num_resolutions - 1)
        layers = [nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1),
                  ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout),
                  AttnBlock(block_in),
                  ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)]
for i in reversed(range(num_resolutions)):
block_out = ch_mult[i]
for i_block in range(num_res_blocks+1):
                layers.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
layers.append(AttnBlock(block_in))
if i > 0:
layers.append(Upsample(block_in, with_conv=True))
curr_res = curr_res * 2
layers.append(Normalize(block_in))
layers.append(Swish())
layers.append(nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1))
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
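# Decoder mirrors the Encoder: with the defaults it maps a (B, 256, 32, 32)
# latent back to a (B, 3, 512, 512) image via four Upsample stages.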
class Codebook(nn.Module):
"""
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
"""
def __init__(self, codebook_size, codebook_dim, beta, init_steps=2000, reservoir_size=2e5):
super().__init__()
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.beta = beta
self.embedding = nn.Embedding(self.codebook_size, self.codebook_dim)
self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
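        # warm-up schedule: start collecting latents after `init_steps` steps,
        # quantize from step 3*init_steps on, and re-fit k-means every
        # init_steps//2 steps until step 30*init_steps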
self.q_start_collect, self.q_init, self.q_re_end, self.q_re_step = init_steps, init_steps * 3, init_steps * 30, init_steps // 2
self.q_counter = 0
self.reservoir_size = int(reservoir_size)
self.reservoir = None
def forward(self, z):
z = rearrange(z, 'b c h w -> b h w c').contiguous()
batch_size = z.size(0)
z_flattened = z.view(-1, self.codebook_dim)
if self.training:
self.q_counter += 1
if self.q_counter > self.q_start_collect:
z_new = z_flattened.clone().detach().view(batch_size, -1, self.codebook_dim)
z_new = z_new[:, torch.randperm(z_new.size(1))][:, :10].reshape(-1, self.codebook_dim)
self.reservoir = z_new if self.reservoir is None else torch.cat([self.reservoir, z_new], dim=0)
self.reservoir = self.reservoir[torch.randperm(self.reservoir.size(0))[:self.reservoir_size]].detach()
if self.q_counter < self.q_init:
z_q = rearrange(z, 'b h w c -> b c h w').contiguous()
return z_q, z_q.new_tensor(0), None # z_q, loss, min_encoding_indices
else:
if self.q_init <= self.q_counter < self.q_re_end:
if (self.q_counter - self.q_init) % self.q_re_step == 0 or self.q_counter == self.q_init + self.q_re_end - 1:
kmeans = KMeans(n_clusters=self.codebook_size)
                        world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
print("Updating codebook from reservoir.")
if world_size > 1:
global_reservoir = [torch.zeros_like(self.reservoir) for _ in range(world_size)]
dist.all_gather(global_reservoir, self.reservoir.clone())
global_reservoir = torch.cat(global_reservoir, dim=0)
else:
global_reservoir = self.reservoir
                        kmeans.fit_predict(global_reservoir)  # fit k-means on up to `reservoir_size` latents per rank
self.embedding.weight.data = kmeans.centroids.detach()
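        # squared distances ||z - e||^2 = ||z||^2 - 2<z, e> + ||e||^2 between
        # each flattened latent and every codebook entry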
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
torch.sum(self.embedding.weight**2, dim=1) - 2 * \
torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
min_encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(min_encoding_indices).view(z.shape)
# compute loss for embedding
loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
return z_q, loss, min_encoding_indices
def get_codebook_entry(self, indices, shape):
# get quantized latent vectors
z_q = self.embedding(indices)
if shape is not None:
z_q = z_q.view(shape)
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q
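# Example (illustrative shapes): get_codebook_entry turns flat indices for a
# 32x32 latent grid back into a (1, 256, 32, 32) feature map:
#   z_q = codebook.get_codebook_entry(indices, shape=(1, 32, 32, 256))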
if __name__ == '__main__':
enc = Encoder()
dec = Decoder()
    print(sum(p.numel() for p in enc.parameters()))
    print(sum(p.numel() for p in dec.parameters()))
x = torch.randn(1, 3, 512, 512)
res = enc(x)
print(res.shape)
res = dec(res)
print(res.shape)
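    # Hedged addition: quantization smoke test. Running the Codebook in eval
    # mode skips the reservoir/k-means warm-up that only happens in training.
    quant = Codebook(codebook_size=1024, codebook_dim=256, beta=0.25)
    quant.eval()
    z_q, q_loss, indices = quant(enc(x))
    print(z_q.shape, q_loss.item(), indices.shape)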