# Source: PITI-Synthesis / glide_text2im / text2im_model.py
# Hugging Face file-viewer metadata: uploaded by tfwang, commit "add app file"
# (bd366ed), file size 6.6 kB.
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import random
from .nn import timestep_embedding
from .unet import UNetModel
from .xf import LayerNorm, Transformer, convert_module_to_f16
from timm.models.vision_transformer import PatchEmbed
class Text2ImModel(nn.Module):
    """
    Pairs an `Encoder` over a reference input with a `Text2ImUNet` decoder.

    The encoder turns `ref` into a dict of latent outputs; the decoder
    consumes that dict together with the noisy input and the timesteps.

    :param text_ctx: accepted for signature compatibility; not used here.
    :param xf_width: width handed to the encoder, and forwarded to the decoder
                     as `encoder_channels`.
    :param xf_layers: accepted but not used; the encoder is always built with
                      8 transformer layers.
    :param xf_heads: number of attention heads handed to the encoder.
    :param xf_final_ln: accepted for signature compatibility; not used here.
    :param model_channels: forwarded to both the encoder and the decoder.
    :param out_channels: forwarded to the decoder.
    :param num_res_blocks: forwarded to the decoder.
    :param attention_resolutions: forwarded to the decoder.
    :param dropout: forwarded to the decoder.
    :param channel_mult: forwarded to the decoder.
    :param use_fp16: forwarded to the decoder.
    :param num_heads: forwarded to the decoder.
    :param num_heads_upsample: forwarded to the decoder.
    :param num_head_channels: forwarded to the decoder.
    :param use_scale_shift_norm: forwarded to the decoder.
    :param resblock_updown: forwarded to the decoder.
    :param in_channels: number of input channels of the decoder.
    :param n_class: number of channels of the encoder input (`in_chans`).
    :param image_size: handed to the encoder as `img_size`; `image_size // 16`
                       is handed over as `patch_size`.
    """

    def __init__(
        self,
        text_ctx,
        xf_width,
        xf_layers,
        xf_heads,
        xf_final_ln,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout,
        channel_mult,
        use_fp16,
        num_heads,
        num_heads_upsample,
        num_head_channels,
        use_scale_shift_norm,
        resblock_updown,
        in_channels = 3,
        n_class = 3,
        image_size = 64,
    ):
        super().__init__()

        # NOTE(review): the encoder depth is fixed at 8 regardless of the
        # `xf_layers` argument -- presumably to match existing checkpoints;
        # confirm before wiring `xf_layers` through.
        encoder_kwargs = dict(
            img_size=image_size,
            patch_size=image_size // 16,
            in_chans=n_class,
            xf_width=xf_width,
            xf_layers=8,
            xf_heads=xf_heads,
            model_channels=model_channels,
        )
        self.encoder = Encoder(**encoder_kwargs)

        self.in_channels = in_channels

        decoder_kwargs = dict(
            dropout=dropout,
            channel_mult=channel_mult,
            use_fp16=use_fp16,
            num_heads=num_heads,
            num_heads_upsample=num_heads_upsample,
            num_head_channels=num_head_channels,
            use_scale_shift_norm=use_scale_shift_norm,
            resblock_updown=resblock_updown,
            encoder_channels=xf_width,
        )
        self.decoder = Text2ImUNet(
            in_channels,
            model_channels,
            out_channels,
            num_res_blocks,
            attention_resolutions,
            **decoder_kwargs,
        )

    def forward(self, xt, timesteps, ref=None, uncond_p=0.0):
        """
        Encode `ref` and run the decoder on `xt` conditioned on the result.

        :param xt: input handed to the decoder as its first argument.
        :param timesteps: timesteps handed to the decoder.
        :param ref: reference input handed to the encoder.
        :param uncond_p: handed to the encoder unchanged.
        :return: the decoder output.
        """
        conditioning = self.encoder(ref, uncond_p)
        return self.decoder(xt, timesteps, conditioning)
class Text2ImUNet(UNetModel):
    """
    A `UNetModel` conditioned on the latent outputs of an encoder.

    All constructor arguments are forwarded to `UNetModel`. On top of that a
    linear layer maps the 512-dimensional pooled encoder vector onto
    `model_channels * 4` features so it can be added to the timestep
    embedding.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The input width is fixed at 512 here, independent of any
        # constructor argument.
        self.transformer_proj = nn.Linear(512, self.model_channels * 4)

    def forward(self, x, timesteps, latent_outputs):
        """
        Run the U-Net on `x`, conditioned on `timesteps` and `latent_outputs`.

        :param x: input tensor; it is cast to `self.dtype` for the blocks and
                  the result is cast back to `x.dtype` before the output head.
        :param timesteps: timesteps fed to `timestep_embedding`.
        :param latent_outputs: mapping with keys "xf_proj" (pooled vector fed
                               through `transformer_proj`) and "xf_out" (passed
                               to every block as its third argument).
        :return: the result of `self.out` applied to the final features.
        """
        time_emb = self.time_embed(
            timestep_embedding(timesteps, self.model_channels)
        )
        pooled = latent_outputs["xf_proj"]
        context = latent_outputs["xf_out"]

        # Merge the projected pooled vector into the timestep embedding,
        # matching the embedding's dtype/device.
        cond_emb = time_emb + self.transformer_proj(pooled).to(time_emb)

        skips = []
        feats = x.type(self.dtype)
        for block in self.input_blocks:
            feats = block(feats, cond_emb, context)
            skips.append(feats)

        feats = self.middle_block(feats, cond_emb, context)

        # Consume the skip activations in reverse order (last in, first out),
        # concatenating each along the channel dimension.
        for block in self.output_blocks:
            skip = skips.pop()
            feats = block(th.cat([feats, skip], dim=1), cond_emb, context)

        return self.out(feats.type(x.dtype))
class Encoder(nn.Module):
    """
    Encodes a reference input into a pooled vector and a token sequence.

    The input goes through `CNN`, is flattened into a token sequence, gets a
    class token appended at the END of the sequence, and is then processed by
    a `Transformer` followed by a `LayerNorm`.

    :param img_size: accepted but not used by this class.
    :param patch_size: accepted but not used by this class.
    :param in_chans: number of input channels of the CNN.
    :param xf_width: transformer width; also the feature size of the class
                     token and the positional embedding.
    :param xf_layers: number of transformer layers.
    :param xf_heads: number of transformer attention heads.
    :param model_channels: accepted but not used by this class.
    """

    def __init__(
        self,
        img_size,
        patch_size,
        in_chans,
        xf_width,
        xf_layers,
        xf_heads,
        model_channels,
    ):
        super().__init__()
        self.transformer = Transformer(xf_width, xf_layers, xf_heads)
        self.cnn = CNN(in_chans)
        self.final_ln = LayerNorm(xf_width)

        # Both parameters are allocated with `th.empty`, i.e. left
        # uninitialised here -- presumably they are filled by loading a
        # checkpoint; TODO confirm for training from scratch.
        cls_shape = (1, 1, xf_width)
        pos_shape = (1, 256 + 1, xf_width)  # 256 feature tokens + 1 class token
        self.cls_token = nn.Parameter(th.empty(*cls_shape, dtype=th.float32))
        self.positional_embedding = nn.Parameter(
            th.empty(*pos_shape, dtype=th.float32)
        )

    def forward(self, ref, uncond_p=0.0):
        """
        Encode `ref`.

        :param ref: input handed to the CNN.
        :param uncond_p: accepted but not used by this method.
        :return: a dict with "xf_proj", the transformer output at the class
                 token (last sequence position), and "xf_out", the remaining
                 token outputs permuted from NLC to NCL.
        """
        feature_map = self.cnn(ref)
        # Collapse all dimensions from index 2 onward into one token axis and
        # move it in front of the channel axis.
        tokens = feature_map.flatten(2).transpose(1, 2)

        # Slot 0 of the positional embedding belongs to the class token, the
        # remaining slots to the feature tokens.
        cls_pos = self.positional_embedding[:, :1, :]
        token_pos = self.positional_embedding[:, 1:, :]

        tokens = tokens + token_pos
        cls_batch = (self.cls_token + cls_pos).expand(tokens.shape[0], -1, -1)
        sequence = th.cat((tokens, cls_batch), dim=1)

        hidden = self.transformer(sequence)
        if self.final_ln is not None:
            hidden = self.final_ln(hidden)

        pooled = hidden[:, -1]
        per_token = hidden[:, :-1].permute(0, 2, 1)  # NLC -> NCL
        return {"xf_proj": pooled, "xf_out": per_token}
class SuperResText2ImModel(Text2ImModel):
    """
    A text2im model that performs super-resolution.

    Expects an extra kwarg `low_res` to condition on a low-resolution image.
    The low-resolution image is upsampled to the size of the input and
    concatenated to it along the channel dimension, so the underlying model is
    built with twice the requested `in_channels`.
    """

    def __init__(self, *args, **kwargs):
        """
        Build the base model with `in_channels` doubled.

        Accepts the same arguments as `Text2ImModel.__init__`. `in_channels`
        may be given by keyword, positionally, or omitted (in which case the
        base-class default is doubled).
        """
        import inspect  # local import: only needed to locate `in_channels`

        kwargs = dict(kwargs)
        if "in_channels" in kwargs:
            kwargs["in_channels"] = kwargs["in_channels"] * 2
        else:
            # Look up where `in_channels` sits in the base signature instead of
            # hard-coding an index: a fixed `args[1]` would double `xf_width`.
            base_params = inspect.signature(Text2ImModel.__init__).parameters
            # The unbound signature lists `self` first, which is not in *args.
            position = list(base_params).index("in_channels") - 1
            if len(args) > position:
                args = list(args)
                args[position] = args[position] * 2
            else:
                # Not supplied at all: double the base-class default.
                kwargs["in_channels"] = base_params["in_channels"].default * 2
        super().__init__(*args, **kwargs)

    def forward(self, x, timesteps, low_res=None, **kwargs):
        """
        Run the model on `x` concatenated with the upsampled `low_res`.

        :param x: a 4-D input tensor; its last two dimensions give the target
                  size for upsampling.
        :param timesteps: timesteps forwarded to the base model.
        :param low_res: low-resolution conditioning image (required).
        :param kwargs: forwarded to `Text2ImModel.forward`.
        :return: the output of the base model.
        :raises ValueError: if `low_res` is not provided.
        """
        if low_res is None:
            raise ValueError(
                "SuperResText2ImModel.forward requires the `low_res` argument"
            )
        _, _, new_height, new_width = x.shape
        upsampled = F.interpolate(
            low_res, (new_height, new_width), mode="bilinear", align_corners=False
        )
        # NOTE: a timestep-dependent noise augmentation of `upsampled` was
        # tried at this point in the past and is currently disabled.
        x = th.cat([x, upsampled], dim=1)
        return super().forward(x, timesteps, **kwargs)
def conv3x3(in_channels, out_channels, stride=1):
    """
    Create a biased 3x3 `nn.Conv2d` with padding 1.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param stride: convolution stride.
    :return: the new `nn.Conv2d` layer.
    """
    layer = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=True,
    )
    return layer
def conv7x7(in_channels, out_channels, stride=1):
    """
    Create a biased 7x7 `nn.Conv2d` with padding 3.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param stride: convolution stride.
    :return: the new `nn.Conv2d` layer.
    """
    layer = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=7,
        stride=stride,
        padding=3,
        bias=True,
    )
    return layer
class CNN(nn.Module):
    """
    Convolutional feature extractor producing a 512-channel feature map.

    Five stages of convolution -> affine InstanceNorm -> LeakyReLU(0.2) are
    followed by one plain 3x3 convolution. The first stage is a stride-1 7x7
    convolution; stages two to five are stride-2 3x3 convolutions, so the
    network applies four stride-2 steps in total.

    :param in_channels: number of channels of the input.
    """

    # Output width of each normalised stage, in order.
    _STAGE_WIDTHS = (32, 64, 128, 256, 512)

    def __init__(self, in_channels=3):
        super().__init__()
        previous = in_channels
        # Sub-modules are registered as conv{i} / norm{i} / LReLU{i}, in that
        # order, for i = 1..5.
        for index, width in enumerate(self._STAGE_WIDTHS, start=1):
            if index == 1:
                conv = conv7x7(previous, width)
            else:
                conv = conv3x3(previous, width, 2)
            setattr(self, f"conv{index}", conv)
            setattr(self, f"norm{index}", nn.InstanceNorm2d(width, affine=True))
            setattr(self, f"LReLU{index}", nn.LeakyReLU(0.2))
            previous = width
        self.conv6 = conv3x3(512, 512, 1)

    def forward(self, x):
        """
        Apply the five normalised stages and the final convolution to `x`.

        :param x: input handed to the first convolution.
        :return: the output of the final convolution.
        """
        for index in range(1, len(self._STAGE_WIDTHS) + 1):
            conv = getattr(self, f"conv{index}")
            norm = getattr(self, f"norm{index}")
            activation = getattr(self, f"LReLU{index}")
            x = activation(norm(conv(x)))
        return self.conv6(x)