# Spaces: Running  (appears to be web-page header residue from extraction; not part of the module)
# Category labels, listed in alphabetical order.
# NOTE(review): presumably these correspond to the category indices `c` taken by
# ConditionalViT.forward -- confirm, including how index 0 ("No given category"
# per that docstring) lines up with the positions in this list.
categories = [
    "Bags", "Feet", "Hands", "Head", "Lower Body",
    "Neck", "Outwear", "Upper Body", "Waist", "Whole Body",
]
import logging
from collections import OrderedDict

import torch
from torch import nn

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class LayerNorm(nn.LayerNorm):
    """LayerNorm that tolerates inputs whose dtype differs from its weights (e.g. fp16).

    When the input dtype does not match ``self.weight.dtype`` the input is cast
    to the weight dtype, normalized, and the result is cast back to the input's
    original dtype. Otherwise the stock ``nn.LayerNorm.forward`` is used as is.
    """

    def forward(self, x: torch.Tensor):
        """Normalize ``x`` and return a tensor with the same dtype as ``x``."""
        param_dtype = self.weight.dtype
        input_dtype = x.dtype

        # Fast path: dtypes already agree, no casting needed.
        if input_dtype == param_dtype:
            return super().forward(x)

        # Compute in the parameters' dtype, then restore the caller's dtype.
        normalized = super().forward(x.type(param_dtype))
        return normalized.type(input_dtype)
class QuickGELU(nn.Module):
    """Sigmoid-gated activation: ``x * sigmoid(1.702 * x)``, applied elementwise."""

    def forward(self, x: torch.Tensor):
        """Return ``x`` scaled elementwise by ``sigmoid(1.702 * x)``."""
        gate = torch.sigmoid(1.702 * x)
        return x * gate
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention and an MLP, each added residually.

    Args:
        d_model: Embedding width used by the attention, the norms and the MLP.
        n_head: Number of attention heads.
        attn_mask: Optional mask forwarded to ``nn.MultiheadAttention``;
            ``None`` disables masking.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        attn_mask: torch.Tensor = None,
    ):
        super().__init__()

        hidden_dim = d_model * 4  # MLP expands to 4x the model width

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict(
                c_fc=nn.Linear(d_model, hidden_dim),
                gelu=QuickGELU(),
                c_proj=nn.Linear(hidden_dim, d_model),
            )
        )
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        """Run self-attention over ``x`` (query = key = value) and return its output.

        The stored mask, if any, is converted to ``x``'s dtype and device and
        written back to ``self.attn_mask`` before being used.
        """
        mask = self.attn_mask
        if mask is not None:
            mask = mask.to(dtype=x.dtype, device=x.device)
        self.attn_mask = mask

        # need_weights=False: only the attention output (first element) is used.
        attended, _ = self.attn(x, x, x, need_weights=False, attn_mask=mask)
        return attended

    def forward(self, x: torch.Tensor):
        """Apply ``x + attn(ln_1(x))`` followed by ``x + mlp(ln_2(x))``."""
        after_attn = x + self.attention(self.ln_1(x))
        return after_attn + self.mlp(self.ln_2(after_attn))
class Transformer(nn.Module):
    """A sequential stack of ``layers`` ResidualAttentionBlock modules.

    Args:
        width: Embedding width passed to every block.
        layers: Number of blocks in the stack.
        heads: Number of attention heads per block.
        attn_mask: Optional attention mask handed to every block.
    """

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        attn_mask: torch.Tensor = None,
    ):
        super().__init__()
        self.width = width
        self.layers = layers

        blocks = (
            ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)
        )
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        """Pass ``x`` through every block in order and return the result."""
        return self.resblocks(x)
class ConditionalViT(nn.Module):
    """Vision Transformer image encoder with an optional category token.

    Images are cut into ``patch_size`` x ``patch_size`` patches by a strided
    convolution, a class token is prepended and, when a category is supplied,
    a learned category-embedding token is appended. After the transformer,
    the class token is layer-normalized and projected to ``output_dim``.

    Args:
        input_resolution: Image side length used to size the positional embedding.
        patch_size: Kernel size and stride of the patch convolution.
        width: Token embedding width.
        layers: Number of transformer blocks.
        heads: Number of attention heads per block.
        output_dim: Size of the returned embedding.
        n_categories: Size of the category embedding table. If falsy, no
            category embedding is created and ``c`` is ignored in ``forward``.
        **kwargs: Ignored; a warning is logged when any are given.
    """

    def __init__(
        self,
        input_resolution: int,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        output_dim: int,
        n_categories: int = None,
        **kwargs,
    ):
        if kwargs:
            logger.warning(f"Got unused kwargs : {kwargs}")
        super().__init__()

        self.input_resolution = input_resolution
        self.output_dim = output_dim

        # Non-overlapping patch embedding: kernel size == stride == patch size.
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

        init_std = width**-0.5
        self.class_embedding = nn.Parameter(init_std * torch.randn(width))

        self.n_categories = n_categories
        if self.n_categories:
            self.c_embedding = nn.Embedding(self.n_categories, width)
            self.c_pos_embedding = nn.Parameter(init_std * torch.randn(1, width))

        # One position per patch plus one for the class token.
        n_positions = (input_resolution // patch_size) ** 2 + 1
        self.positional_embedding = nn.Parameter(
            init_std * torch.randn(n_positions, width)
        )

        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(width, layers, heads)
        self.ln_post = LayerNorm(width)

        # Learnable scalar initialised to 4.6052; not used inside forward().
        self.logit_scale = torch.nn.Parameter(torch.ones([]) * 4.6052)
        self.proj = nn.Parameter(init_std * torch.randn(width, output_dim))

    def forward(self, imgs: torch.Tensor, c: torch.Tensor = None):
        """Encode a batch of images, optionally conditioned on a category.

        imgs : Batch of images
        c : category indices. 0 = "No given category".

        Returns the projected class-token embedding, one row per image.
        """
        feature_map = self.conv1(imgs)  # [N, width, grid, grid]
        batch_size = feature_map.shape[0]
        patches = feature_map.flatten(2).transpose(1, 2)  # [N, grid ** 2, width]

        # Sequence is [CLS, patches] plus, optionally, one category token.
        cls_token = self.class_embedding.tile(batch_size, 1, 1)  # [N, 1, width]
        token_parts = [cls_token, patches]
        pos_parts = [self.positional_embedding]  # [L, width]

        # When c is None (or the model has no categories) no token is added.
        if self.n_categories and c is not None:
            token_parts.append(self.c_embedding(c).unsqueeze(1))  # [N, 1, width]
            pos_parts.append(self.c_pos_embedding)  # [1, width]

        sequence = torch.cat(token_parts, dim=1)  # [N, L or L + 1, width]
        positions = torch.cat(pos_parts, dim=0).unsqueeze(0)  # [1, L or L + 1, width]
        sequence = self.ln_pre(sequence + positions)

        # The transformer works on sequence-first tensors.
        sequence = sequence.permute(1, 0, 2)  # NLD -> LND
        sequence = self.transformer(sequence)
        sequence = sequence.permute(1, 0, 2)  # LND -> NLD

        # Only the class token (position 0) is kept and projected.
        cls_out = self.ln_post(sequence[:, 0, :])
        return cls_out @ self.proj
# SIZES
# Each entry's keys match ConditionalViT.__init__ parameter names.
# NOTE(review): presumably used as ConditionalViT(**params[name]) -- confirm at call sites.
B32_Params = dict(
    input_resolution=224,
    patch_size=32,
    width=768,
    layers=12,
    heads=12,
    output_dim=512,
)

# Same as B32 except for the patch size (key order is preserved).
B16_Params = {**B32_Params, "patch_size": 16}

params = dict(B32=B32_Params, B16=B16_Params)