diff --git a/.gitignore b/.gitignore index f24b7378899ae1bd896dd9cceacd6991ce4e1e57..f7c0c41c815a4a34cc2758b562175b6f20fed4a5 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,8 @@ .ipynb_checkpoints █ .vscode env -test.py +test*.py *.jpeg __pycache__ +sample_task.txt +.idea diff --git a/external/briarmbg.py b/external/briarmbg.py new file mode 100644 index 0000000000000000000000000000000000000000..6ae7650b1fbc2e592d07f09571772865df844ba8 --- /dev/null +++ b/external/briarmbg.py @@ -0,0 +1,460 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from huggingface_hub import PyTorchModelHubMixin + + +class REBNCONV(nn.Module): + def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1): + super(REBNCONV, self).__init__() + + self.conv_s1 = nn.Conv2d( + in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride + ) + self.bn_s1 = nn.BatchNorm2d(out_ch) + self.relu_s1 = nn.ReLU(inplace=True) + + def forward(self, x): + hx = x + xout = self.relu_s1(self.bn_s1(self.conv_s1(hx))) + + return xout + + +## upsample tensor 'src' to have the same spatial size with tensor 'tar' +def _upsample_like(src, tar): + src = F.interpolate(src, size=tar.shape[2:], mode="bilinear") + + return src + + +### RSU-7 ### +class RSU7(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512): + super(RSU7, self).__init__() + + self.in_ch = in_ch + self.mid_ch = mid_ch + self.out_ch = out_ch + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2 + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + b, c, h, w = x.shape + + hx = x + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + + hx5 = self.rebnconv5(hx) + hx = self.pool5(hx5) + + hx6 = self.rebnconv6(hx) + + hx7 = self.rebnconv7(hx6) + + hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1)) + hx6dup = _upsample_like(hx6d, hx5) + + hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1)) + hx5dup = _upsample_like(hx5d, hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + + return hx1d + hxin + + +### RSU-6 ### +class RSU6(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU6, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + + hx5 = self.rebnconv5(hx) + + hx6 = self.rebnconv6(hx5) + + hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1)) + hx5dup = _upsample_like(hx5d, hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + + return hx1d + hxin + + +### RSU-5 ### +class RSU5(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU5, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + + hx5 = self.rebnconv5(hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + + return hx1d + hxin + + +### RSU-4 ### +class RSU4(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU4, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + + hx4 = self.rebnconv4(hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + + return hx1d + hxin + + +### RSU-4F ### +class RSU4F(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU4F, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2) + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8) + + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx2 = self.rebnconv2(hx1) + hx3 = self.rebnconv3(hx2) + + hx4 = self.rebnconv4(hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) + hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1)) + hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1)) + + return hx1d + hxin + + +class myrebnconv(nn.Module): + def __init__( + self, + in_ch=3, + out_ch=1, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + groups=1, + ): + super(myrebnconv, self).__init__() + + self.conv = nn.Conv2d( + in_ch, + out_ch, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + self.bn = nn.BatchNorm2d(out_ch) + self.rl = nn.ReLU(inplace=True) + + def forward(self, x): + return self.rl(self.bn(self.conv(x))) + + +class BriaRMBG(nn.Module, PyTorchModelHubMixin): + def __init__(self, config: dict = {"in_ch": 3, "out_ch": 1}): + super(BriaRMBG, self).__init__() + in_ch = config["in_ch"] + out_ch = config["out_ch"] + self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1) + self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage1 = RSU7(64, 32, 64) + self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage2 = RSU6(64, 32, 128) + self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage3 = RSU5(128, 64, 256) + self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage4 = RSU4(256, 128, 512) + self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage5 = RSU4F(512, 256, 512) + self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage6 = RSU4F(512, 256, 512) + + # decoder + self.stage5d = RSU4F(1024, 256, 512) + self.stage4d = RSU4(1024, 128, 256) + self.stage3d = RSU5(512, 64, 128) + self.stage2d = RSU6(256, 32, 64) + self.stage1d = RSU7(128, 16, 64) + + self.side1 = nn.Conv2d(64, out_ch, 3, padding=1) + self.side2 = nn.Conv2d(64, out_ch, 3, padding=1) + self.side3 = nn.Conv2d(128, out_ch, 3, padding=1) + self.side4 = nn.Conv2d(256, out_ch, 3, padding=1) + self.side5 = nn.Conv2d(512, out_ch, 3, padding=1) + self.side6 = nn.Conv2d(512, out_ch, 3, padding=1) + + # self.outconv = nn.Conv2d(6*out_ch,out_ch,1) + + def forward(self, x): + hx = x + + hxin = self.conv_in(hx) + # hx = self.pool_in(hxin) + + # stage 1 + hx1 = self.stage1(hxin) + hx = self.pool12(hx1) + + # stage 2 + hx2 = self.stage2(hx) + hx = self.pool23(hx2) + + # stage 3 + hx3 = self.stage3(hx) + hx = self.pool34(hx3) + + # stage 4 + hx4 = self.stage4(hx) + hx = self.pool45(hx4) + + # stage 5 + hx5 = self.stage5(hx) + hx = self.pool56(hx5) + + # stage 6 + hx6 = self.stage6(hx) + hx6up = _upsample_like(hx6, hx5) + + # -------------------- decoder -------------------- + hx5d = self.stage5d(torch.cat((hx6up, hx5), 1)) + hx5dup = _upsample_like(hx5d, hx4) + + hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1)) + + # side output + d1 = self.side1(hx1d) + d1 = _upsample_like(d1, x) + + d2 = self.side2(hx2d) + d2 = _upsample_like(d2, x) + + d3 = self.side3(hx3d) + d3 = _upsample_like(d3, x) + + d4 = self.side4(hx4d) + d4 = _upsample_like(d4, x) + + d5 = self.side5(hx5d) + d5 = _upsample_like(d5, x) + + d6 = self.side6(hx6) + d6 = _upsample_like(d6, x) + + return [ + F.sigmoid(d1), + F.sigmoid(d2), + F.sigmoid(d3), + F.sigmoid(d4), + F.sigmoid(d5), + F.sigmoid(d6), + ], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6] diff --git a/external/llite/library/custom_train_functions.py b/external/llite/library/custom_train_functions.py index 629e1a2ebe0a0df63b637217fbf80d6f558a89db..e0a026dae865fa02d2f133f2232f361b56ea012c 100644 --- a/external/llite/library/custom_train_functions.py +++ b/external/llite/library/custom_train_functions.py @@ -1,529 +1,529 @@ -import torch -import argparse -import random -import re -from typing import List, Optional, Union - - -def prepare_scheduler_for_custom_training(noise_scheduler, device): - if hasattr(noise_scheduler, "all_snr"): - return - - alphas_cumprod = noise_scheduler.alphas_cumprod - sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) - sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) - alpha = sqrt_alphas_cumprod - sigma = sqrt_one_minus_alphas_cumprod - all_snr = (alpha / sigma) ** 2 - - noise_scheduler.all_snr = all_snr.to(device) - - -def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler): - # fix beta: zero terminal SNR - print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891") - - def enforce_zero_terminal_snr(betas): - # Convert betas to alphas_bar_sqrt - alphas = 1 - betas - alphas_bar = alphas.cumprod(0) - alphas_bar_sqrt = alphas_bar.sqrt() - - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - # Shift so last timestep is zero. - alphas_bar_sqrt -= alphas_bar_sqrt_T - # Scale so first timestep is back to old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 - alphas = alphas_bar[1:] / alphas_bar[:-1] - alphas = torch.cat([alphas_bar[0:1], alphas]) - betas = 1 - alphas - return betas - - betas = noise_scheduler.betas - betas = enforce_zero_terminal_snr(betas) - alphas = 1.0 - betas - alphas_cumprod = torch.cumprod(alphas, dim=0) - - # print("original:", noise_scheduler.betas) - # print("fixed:", betas) - - noise_scheduler.betas = betas - noise_scheduler.alphas = alphas - noise_scheduler.alphas_cumprod = alphas_cumprod - - -def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False): - snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) - min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma)) - if v_prediction: - snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device) - else: - snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device) - loss = loss * snr_weight - return loss - - -def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler): - scale = get_snr_scale(timesteps, noise_scheduler) - loss = loss * scale - return loss - - -def get_snr_scale(timesteps, noise_scheduler): - snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size - snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 - scale = snr_t / (snr_t + 1) - # # show debug info - # print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}") - return scale - - -def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss): - scale = get_snr_scale(timesteps, noise_scheduler) - # print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}") - loss = loss + loss / scale * v_pred_like_loss - return loss - -def apply_debiased_estimation(loss, timesteps, noise_scheduler): - snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size - snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 - weight = 1/torch.sqrt(snr_t) - loss = weight * loss - return loss - -# TODO train_utilと分散しているのでどちらかに寄せる - - -def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True): - parser.add_argument( - "--min_snr_gamma", - type=float, - default=None, - help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨", - ) - parser.add_argument( - "--scale_v_pred_loss_like_noise_pred", - action="store_true", - help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする", - ) - parser.add_argument( - "--v_pred_like_loss", - type=float, - default=None, - help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する", - ) - parser.add_argument( - "--debiased_estimation_loss", - action="store_true", - help="debiased estimation loss / debiased estimation loss", - ) - if support_weighted_captions: - parser.add_argument( - "--weighted_captions", - action="store_true", - default=False, - help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意", - ) - - -re_attention = re.compile( - r""" -\\\(| -\\\)| -\\\[| -\\]| -\\\\| -\\| -\(| -\[| -:([+-]?[.\d]+)\)| -\)| -]| -[^\\()\[\]:]+| -: -""", - re.X, -) - - -def parse_prompt_attention(text): - """ - Parses a string with attention tokens and returns a list of pairs: text and its associated weight. - Accepted tokens are: - (abc) - increases attention to abc by a multiplier of 1.1 - (abc:3.12) - increases attention to abc by a multiplier of 3.12 - [abc] - decreases attention to abc by a multiplier of 1.1 - \( - literal character '(' - \[ - literal character '[' - \) - literal character ')' - \] - literal character ']' - \\ - literal character '\' - anything else - just text - >>> parse_prompt_attention('normal text') - [['normal text', 1.0]] - >>> parse_prompt_attention('an (important) word') - [['an ', 1.0], ['important', 1.1], [' word', 1.0]] - >>> parse_prompt_attention('(unbalanced') - [['unbalanced', 1.1]] - >>> parse_prompt_attention('\(literal\]') - [['(literal]', 1.0]] - >>> parse_prompt_attention('(unnecessary)(parens)') - [['unnecessaryparens', 1.1]] - >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') - [['a ', 1.0], - ['house', 1.5730000000000004], - [' ', 1.1], - ['on', 1.0], - [' a ', 1.1], - ['hill', 0.55], - [', sun, ', 1.1], - ['sky', 1.4641000000000006], - ['.', 1.1]] - """ - - res = [] - round_brackets = [] - square_brackets = [] - - round_bracket_multiplier = 1.1 - square_bracket_multiplier = 1 / 1.1 - - def multiply_range(start_position, multiplier): - for p in range(start_position, len(res)): - res[p][1] *= multiplier - - for m in re_attention.finditer(text): - text = m.group(0) - weight = m.group(1) - - if text.startswith("\\"): - res.append([text[1:], 1.0]) - elif text == "(": - round_brackets.append(len(res)) - elif text == "[": - square_brackets.append(len(res)) - elif weight is not None and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), float(weight)) - elif text == ")" and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), round_bracket_multiplier) - elif text == "]" and len(square_brackets) > 0: - multiply_range(square_brackets.pop(), square_bracket_multiplier) - else: - res.append([text, 1.0]) - - for pos in round_brackets: - multiply_range(pos, round_bracket_multiplier) - - for pos in square_brackets: - multiply_range(pos, square_bracket_multiplier) - - if len(res) == 0: - res = [["", 1.0]] - - # merge runs of identical weights - i = 0 - while i + 1 < len(res): - if res[i][1] == res[i + 1][1]: - res[i][0] += res[i + 1][0] - res.pop(i + 1) - else: - i += 1 - - return res - - -def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int): - r""" - Tokenize a list of prompts and return its tokens with weights of each token. - - No padding, starting or ending token is included. - """ - tokens = [] - weights = [] - truncated = False - for text in prompt: - texts_and_weights = parse_prompt_attention(text) - text_token = [] - text_weight = [] - for word, weight in texts_and_weights: - # tokenize and discard the starting and the ending token - token = tokenizer(word).input_ids[1:-1] - text_token += token - # copy the weight by length of token - text_weight += [weight] * len(token) - # stop if the text is too long (longer than truncation limit) - if len(text_token) > max_length: - truncated = True - break - # truncate - if len(text_token) > max_length: - truncated = True - text_token = text_token[:max_length] - text_weight = text_weight[:max_length] - tokens.append(text_token) - weights.append(text_weight) - if truncated: - print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") - return tokens, weights - - -def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77): - r""" - Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. - """ - max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) - weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length - for i in range(len(tokens)): - tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i])) - if no_boseos_middle: - weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) - else: - w = [] - if len(weights[i]) == 0: - w = [1.0] * weights_length - else: - for j in range(max_embeddings_multiples): - w.append(1.0) # weight for starting token in this chunk - w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] - w.append(1.0) # weight for ending token in this chunk - w += [1.0] * (weights_length - len(w)) - weights[i] = w[:] - - return tokens, weights - - -def get_unweighted_text_embeddings( - tokenizer, - text_encoder, - text_input: torch.Tensor, - chunk_length: int, - clip_skip: int, - eos: int, - pad: int, - no_boseos_middle: Optional[bool] = True, -): - """ - When the length of tokens is a multiple of the capacity of the text encoder, - it should be split into chunks and sent to the text encoder individually. - """ - max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) - if max_embeddings_multiples > 1: - text_embeddings = [] - for i in range(max_embeddings_multiples): - # extract the i-th chunk - text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() - - # cover the head and the tail by the starting and the ending tokens - text_input_chunk[:, 0] = text_input[0, 0] - if pad == eos: # v1 - text_input_chunk[:, -1] = text_input[0, -1] - else: # v2 - for j in range(len(text_input_chunk)): - if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある - text_input_chunk[j, -1] = eos - if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD - text_input_chunk[j, 1] = eos - - if clip_skip is None or clip_skip == 1: - text_embedding = text_encoder(text_input_chunk)[0] - else: - enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) - text_embedding = enc_out["hidden_states"][-clip_skip] - text_embedding = text_encoder.text_model.final_layer_norm(text_embedding) - - if no_boseos_middle: - if i == 0: - # discard the ending token - text_embedding = text_embedding[:, :-1] - elif i == max_embeddings_multiples - 1: - # discard the starting token - text_embedding = text_embedding[:, 1:] - else: - # discard both starting and ending tokens - text_embedding = text_embedding[:, 1:-1] - - text_embeddings.append(text_embedding) - text_embeddings = torch.concat(text_embeddings, axis=1) - else: - if clip_skip is None or clip_skip == 1: - text_embeddings = text_encoder(text_input)[0] - else: - enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) - text_embeddings = enc_out["hidden_states"][-clip_skip] - text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings) - return text_embeddings - - -def get_weighted_text_embeddings( - tokenizer, - text_encoder, - prompt: Union[str, List[str]], - device, - max_embeddings_multiples: Optional[int] = 3, - no_boseos_middle: Optional[bool] = False, - clip_skip=None, -): - r""" - Prompts can be assigned with local weights using brackets. For example, - prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', - and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. - - Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - no_boseos_middle (`bool`, *optional*, defaults to `False`): - If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and - ending token in each of the chunk in the middle. - skip_parsing (`bool`, *optional*, defaults to `False`): - Skip the parsing of brackets. - skip_weighting (`bool`, *optional*, defaults to `False`): - Skip the weighting. When the parsing is skipped, it is forced True. - """ - max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - if isinstance(prompt, str): - prompt = [prompt] - - prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2) - - # round up the longest length of tokens to a multiple of (model_max_length - 2) - max_length = max([len(token) for token in prompt_tokens]) - - max_embeddings_multiples = min( - max_embeddings_multiples, - (max_length - 1) // (tokenizer.model_max_length - 2) + 1, - ) - max_embeddings_multiples = max(1, max_embeddings_multiples) - max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - - # pad the length of tokens and weights - bos = tokenizer.bos_token_id - eos = tokenizer.eos_token_id - pad = tokenizer.pad_token_id - prompt_tokens, prompt_weights = pad_tokens_and_weights( - prompt_tokens, - prompt_weights, - max_length, - bos, - eos, - no_boseos_middle=no_boseos_middle, - chunk_length=tokenizer.model_max_length, - ) - prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device) - - # get the embeddings - text_embeddings = get_unweighted_text_embeddings( - tokenizer, - text_encoder, - prompt_tokens, - tokenizer.model_max_length, - clip_skip, - eos, - pad, - no_boseos_middle=no_boseos_middle, - ) - prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device) - - # assign weights to the prompts and normalize in the sense of mean - previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1) - current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) - - return text_embeddings - - -# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2 -def pyramid_noise_like(noise, device, iterations=6, discount=0.4): - b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant! - u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device) - for i in range(iterations): - r = random.random() * 2 + 2 # Rather than always going 2x, - wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i))) - noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i - if wn == 1 or hn == 1: - break # Lowest resolution is 1x1 - return noise / noise.std() # Scaled back to roughly unit variance - - -# https://www.crosslabs.org//blog/diffusion-with-offset-noise -def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): - if noise_offset is None: - return noise - if adaptive_noise_scale is not None: - # latent shape: (batch_size, channels, height, width) - # abs mean value for each channel - latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True)) - - # multiply adaptive noise scale to the mean value and add it to the noise offset - noise_offset = noise_offset + adaptive_noise_scale * latent_mean - noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative - - noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) - return noise - - -""" -########################################## -# Perlin Noise -def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3): - delta = (res[0] / shape[0], res[1] / shape[1]) - d = (shape[0] // res[0], shape[1] // res[1]) - - grid = ( - torch.stack( - torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)), - dim=-1, - ) - % 1 - ) - angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device) - gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1) - - tile_grads = ( - lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]] - .repeat_interleave(d[0], 0) - .repeat_interleave(d[1], 1) - ) - dot = lambda grad, shift: ( - torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1) - * grad[: shape[0], : shape[1]] - ).sum(dim=-1) - - n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) - n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) - n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) - n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) - t = fade(grid[: shape[0], : shape[1]]) - return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]) - - -def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5): - noise = torch.zeros(shape, device=device) - frequency = 1 - amplitude = 1 - for _ in range(octaves): - noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1])) - frequency *= 2 - amplitude *= persistence - return noise - - -def perlin_noise(noise, device, octaves): - _, c, w, h = noise.shape - perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves) - noise_perlin = [] - for _ in range(c): - noise_perlin.append(perlin()) - noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h) - noise += noise_perlin # broadcast for each batch - return noise / noise.std() # Scaled back to roughly unit variance -""" +import torch +import argparse +import random +import re +from typing import List, Optional, Union + + +def prepare_scheduler_for_custom_training(noise_scheduler, device): + if hasattr(noise_scheduler, "all_snr"): + return + + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) + sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) + alpha = sqrt_alphas_cumprod + sigma = sqrt_one_minus_alphas_cumprod + all_snr = (alpha / sigma) ** 2 + + noise_scheduler.all_snr = all_snr.to(device) + + +def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler): + # fix beta: zero terminal SNR + print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891") + + def enforce_zero_terminal_snr(betas): + # Convert betas to alphas_bar_sqrt + alphas = 1 - betas + alphas_bar = alphas.cumprod(0) + alphas_bar_sqrt = alphas_bar.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + # Shift so last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + # Scale so first timestep is back to old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 + alphas = alphas_bar[1:] / alphas_bar[:-1] + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + return betas + + betas = noise_scheduler.betas + betas = enforce_zero_terminal_snr(betas) + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + + # print("original:", noise_scheduler.betas) + # print("fixed:", betas) + + noise_scheduler.betas = betas + noise_scheduler.alphas = alphas + noise_scheduler.alphas_cumprod = alphas_cumprod + + +def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False): + snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) + min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma)) + if v_prediction: + snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device) + else: + snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device) + loss = loss * snr_weight + return loss + + +def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler): + scale = get_snr_scale(timesteps, noise_scheduler) + loss = loss * scale + return loss + + +def get_snr_scale(timesteps, noise_scheduler): + snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size + snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 + scale = snr_t / (snr_t + 1) + # # show debug info + # print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}") + return scale + + +def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss): + scale = get_snr_scale(timesteps, noise_scheduler) + # print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}") + loss = loss + loss / scale * v_pred_like_loss + return loss + +def apply_debiased_estimation(loss, timesteps, noise_scheduler): + snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size + snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 + weight = 1/torch.sqrt(snr_t) + loss = weight * loss + return loss + +# TODO train_utilと分散しているのでどちらかに寄せる + + +def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True): + parser.add_argument( + "--min_snr_gamma", + type=float, + default=None, + help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨", + ) + parser.add_argument( + "--scale_v_pred_loss_like_noise_pred", + action="store_true", + help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする", + ) + parser.add_argument( + "--v_pred_like_loss", + type=float, + default=None, + help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する", + ) + parser.add_argument( + "--debiased_estimation_loss", + action="store_true", + help="debiased estimation loss / debiased estimation loss", + ) + if support_weighted_captions: + parser.add_argument( + "--weighted_captions", + action="store_true", + default=False, + help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意", + ) + + +re_attention = re.compile( + r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, +) + + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + + No padding, starting or ending token is included. + """ + tokens = [] + weights = [] + truncated = False + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = tokenizer(word).input_ids[1:-1] + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + +def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) + weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length + for i in range(len(tokens)): + tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i])) + if no_boseos_middle: + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] + + return tokens, weights + + +def get_unweighted_text_embeddings( + tokenizer, + text_encoder, + text_input: torch.Tensor, + chunk_length: int, + clip_skip: int, + eos: int, + pad: int, + no_boseos_middle: Optional[bool] = True, +): + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. + """ + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + if pad == eos: # v1 + text_input_chunk[:, -1] = text_input[0, -1] + else: # v2 + for j in range(len(text_input_chunk)): + if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある + text_input_chunk[j, -1] = eos + if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD + text_input_chunk[j, 1] = eos + + if clip_skip is None or clip_skip == 1: + text_embedding = text_encoder(text_input_chunk)[0] + else: + enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) + text_embedding = enc_out["hidden_states"][-clip_skip] + text_embedding = text_encoder.text_model.final_layer_norm(text_embedding) + + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeddings = torch.concat(text_embeddings, axis=1) + else: + if clip_skip is None or clip_skip == 1: + text_embeddings = text_encoder(text_input)[0] + else: + enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) + text_embeddings = enc_out["hidden_states"][-clip_skip] + text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings) + return text_embeddings + + +def get_weighted_text_embeddings( + tokenizer, + text_encoder, + prompt: Union[str, List[str]], + device, + max_embeddings_multiples: Optional[int] = 3, + no_boseos_middle: Optional[bool] = False, + clip_skip=None, +): + r""" + Prompts can be assigned with local weights using brackets. For example, + prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', + and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. + + Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + no_boseos_middle (`bool`, *optional*, defaults to `False`): + If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and + ending token in each of the chunk in the middle. + skip_parsing (`bool`, *optional*, defaults to `False`): + Skip the parsing of brackets. + skip_weighting (`bool`, *optional*, defaults to `False`): + Skip the weighting. When the parsing is skipped, it is forced True. + """ + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + if isinstance(prompt, str): + prompt = [prompt] + + prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2) + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) + + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (tokenizer.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + + # pad the length of tokens and weights + bos = tokenizer.bos_token_id + eos = tokenizer.eos_token_id + pad = tokenizer.pad_token_id + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, + max_length, + bos, + eos, + no_boseos_middle=no_boseos_middle, + chunk_length=tokenizer.model_max_length, + ) + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device) + + # get the embeddings + text_embeddings = get_unweighted_text_embeddings( + tokenizer, + text_encoder, + prompt_tokens, + tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device) + + # assign weights to the prompts and normalize in the sense of mean + previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + + return text_embeddings + + +# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2 +def pyramid_noise_like(noise, device, iterations=6, discount=0.4): + b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant! + u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device) + for i in range(iterations): + r = random.random() * 2 + 2 # Rather than always going 2x, + wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i))) + noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i + if wn == 1 or hn == 1: + break # Lowest resolution is 1x1 + return noise / noise.std() # Scaled back to roughly unit variance + + +# https://www.crosslabs.org//blog/diffusion-with-offset-noise +def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): + if noise_offset is None: + return noise + if adaptive_noise_scale is not None: + # latent shape: (batch_size, channels, height, width) + # abs mean value for each channel + latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True)) + + # multiply adaptive noise scale to the mean value and add it to the noise offset + noise_offset = noise_offset + adaptive_noise_scale * latent_mean + noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative + + noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + return noise + + +""" +########################################## +# Perlin Noise +def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3): + delta = (res[0] / shape[0], res[1] / shape[1]) + d = (shape[0] // res[0], shape[1] // res[1]) + + grid = ( + torch.stack( + torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)), + dim=-1, + ) + % 1 + ) + angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device) + gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1) + + tile_grads = ( + lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]] + .repeat_interleave(d[0], 0) + .repeat_interleave(d[1], 1) + ) + dot = lambda grad, shift: ( + torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1) + * grad[: shape[0], : shape[1]] + ).sum(dim=-1) + + n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) + n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) + n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) + n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) + t = fade(grid[: shape[0], : shape[1]]) + return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]) + + +def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5): + noise = torch.zeros(shape, device=device) + frequency = 1 + amplitude = 1 + for _ in range(octaves): + noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1])) + frequency *= 2 + amplitude *= persistence + return noise + + +def perlin_noise(noise, device, octaves): + _, c, w, h = noise.shape + perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves) + noise_perlin = [] + for _ in range(c): + noise_perlin.append(perlin()) + noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h) + noise += noise_perlin # broadcast for each batch + return noise / noise.std() # Scaled back to roughly unit variance +""" diff --git a/external/midas/__init__.py b/external/midas/__init__.py index 3bc396acf23049f182cac19f9e801bc319371138..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/external/midas/__init__.py +++ b/external/midas/__init__.py @@ -1,39 +0,0 @@ -import cv2 -import numpy as np -import torch -from einops import rearrange - -from .api import MiDaSInference - -model = None - - -def apply_midas(input_image, a=np.pi * 2.0, bg_th=0.1): - global model - if not model: - model = MiDaSInference(model_type="dpt_hybrid").cuda() - assert input_image.ndim == 3 - image_depth = input_image - with torch.no_grad(): - image_depth = torch.from_numpy(image_depth).float().cuda() - image_depth = image_depth / 127.5 - 1.0 - image_depth = rearrange(image_depth, "h w c -> 1 c h w") - depth = model(image_depth)[0] - - depth_pt = depth.clone() - depth_pt -= torch.min(depth_pt) - depth_pt /= torch.max(depth_pt) - depth_pt = depth_pt.cpu().numpy() - depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8) - - depth_np = depth.cpu().numpy() - x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3) - y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3) - z = np.ones_like(x) * a - x[depth_pt < bg_th] = 0 - y[depth_pt < bg_th] = 0 - normal = np.stack([x, y, z], axis=2) - normal /= np.sum(normal**2.0, axis=2, keepdims=True) ** 0.5 - normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8) - - return depth_image, normal_image diff --git a/external/midas/base_model.py b/external/midas/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd --- /dev/null +++ b/external/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/external/midas/blocks.py b/external/midas/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78 --- /dev/null +++ b/external/midas/blocks.py @@ -0,0 +1,342 @@ +import torch +import torch.nn as nn + +from .vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand==True: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/external/midas/dpt_depth.py b/external/midas/dpt_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1 --- /dev/null +++ b/external/midas/dpt_depth.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock, + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_vit, +) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) + diff --git a/external/midas/midas_net.py b/external/midas/midas_net.py new file mode 100644 index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f --- /dev/null +++ b/external/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/external/midas/midas_net_custom.py b/external/midas/midas_net_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c --- /dev/null +++ b/external/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/external/midas/transforms.py b/external/midas/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9 --- /dev/null +++ b/external/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/external/midas/vit.py b/external/midas/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74 --- /dev/null +++ b/external/midas/vit.py @@ -0,0 +1,491 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/external/realesrgan/__init__.py b/external/realesrgan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bfea78f284116dee22510d4aa91f9e44afb7d472 --- /dev/null +++ b/external/realesrgan/__init__.py @@ -0,0 +1,6 @@ +# flake8: noqa +from .archs import * +from .data import * +from .models import * +from .utils import * +#from .version import * diff --git a/external/realesrgan/archs/__init__.py b/external/realesrgan/archs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f3fbbf3b78e33b61fd4c33a564a9a617010d90de --- /dev/null +++ b/external/realesrgan/archs/__init__.py @@ -0,0 +1,10 @@ +import importlib +from basicsr.utils import scandir +from os import path as osp + +# automatically scan and import arch modules for registry +# scan all the files that end with '_arch.py' under the archs folder +arch_folder = osp.dirname(osp.abspath(__file__)) +arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] +# import all the arch modules +_arch_modules = [importlib.import_module(f'realesrgan.archs.{file_name}') for file_name in arch_filenames] diff --git a/external/realesrgan/archs/discriminator_arch.py b/external/realesrgan/archs/discriminator_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..4b66ab1226d6793de846bc9828bbe427031a0e2d --- /dev/null +++ b/external/realesrgan/archs/discriminator_arch.py @@ -0,0 +1,67 @@ +from basicsr.utils.registry import ARCH_REGISTRY +from torch import nn as nn +from torch.nn import functional as F +from torch.nn.utils import spectral_norm + + +@ARCH_REGISTRY.register() +class UNetDiscriminatorSN(nn.Module): + """Defines a U-Net discriminator with spectral normalization (SN) + + It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + Arg: + num_in_ch (int): Channel number of inputs. Default: 3. + num_feat (int): Channel number of base intermediate features. Default: 64. + skip_connection (bool): Whether to use skip connections between U-Net. Default: True. + """ + + def __init__(self, num_in_ch, num_feat=64, skip_connection=True): + super(UNetDiscriminatorSN, self).__init__() + self.skip_connection = skip_connection + norm = spectral_norm + # the first convolution + self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1) + # downsample + self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False)) + self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False)) + self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False)) + # upsample + self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False)) + self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False)) + self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False)) + # extra convolutions + self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) + self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) + self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1) + + def forward(self, x): + # downsample + x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True) + x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True) + x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True) + x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True) + + # upsample + x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False) + x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x4 = x4 + x2 + x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False) + x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x5 = x5 + x1 + x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False) + x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x6 = x6 + x0 + + # extra convolutions + out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True) + out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True) + out = self.conv9(out) + + return out diff --git a/external/realesrgan/archs/srvgg_arch.py b/external/realesrgan/archs/srvgg_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..12ecf436886c7f8886bf3902174ae40c9c33928d --- /dev/null +++ b/external/realesrgan/archs/srvgg_arch.py @@ -0,0 +1,69 @@ +from basicsr.utils.registry import ARCH_REGISTRY +from torch import nn as nn +from torch.nn import functional as F + + +@ARCH_REGISTRY.register() +class SRVGGNetCompact(nn.Module): + """A compact VGG-style network structure for super-resolution. + + It is a compact network structure, which performs upsampling in the last layer and no convolution is + conducted on the HR feature space. + + Args: + num_in_ch (int): Channel number of inputs. Default: 3. + num_out_ch (int): Channel number of outputs. Default: 3. + num_feat (int): Channel number of intermediate features. Default: 64. + num_conv (int): Number of convolution layers in the body network. Default: 16. + upscale (int): Upsampling factor. Default: 4. + act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu. + """ + + def __init__(self, num_in_ch = 3, num_out_ch = 3, num_feat = 64, num_conv = 16, upscale = 4, act_type = 'prelu'): + super(SRVGGNetCompact, self).__init__() + self.num_in_ch = num_in_ch + self.num_out_ch = num_out_ch + self.num_feat = num_feat + self.num_conv = num_conv + self.upscale = upscale + self.act_type = act_type + + self.body = nn.ModuleList() + # the first conv + self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) + # the first activation + if act_type == 'relu': + activation = nn.ReLU(inplace = True) + elif act_type == 'prelu': + activation = nn.PReLU(num_parameters = num_feat) + elif act_type == 'leakyrelu': + activation = nn.LeakyReLU(negative_slope = 0.1, inplace = True) + self.body.append(activation) + + # the body structure + for _ in range(num_conv): + self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) + # activation + if act_type == 'relu': + activation = nn.ReLU(inplace = True) + elif act_type == 'prelu': + activation = nn.PReLU(num_parameters = num_feat) + elif act_type == 'leakyrelu': + activation = nn.LeakyReLU(negative_slope = 0.1, inplace = True) + self.body.append(activation) + + # the last conv + self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) + # upsample + self.upsampler = nn.PixelShuffle(upscale) + + def forward(self, x): + out = x + for i in range(0, len(self.body)): + out = self.body[i](out) + + out = self.upsampler(out) + # add the nearest upsampled image, so that the network learns the residual + base = F.interpolate(x, scale_factor = self.upscale, mode = 'nearest') + out += base + return out diff --git a/external/realesrgan/data/__init__.py b/external/realesrgan/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a3f8fdd1aa47c12de9687c578094303eb7369246 --- /dev/null +++ b/external/realesrgan/data/__init__.py @@ -0,0 +1,10 @@ +import importlib +from basicsr.utils import scandir +from os import path as osp + +# automatically scan and import dataset modules for registry +# scan all the files that end with '_dataset.py' under the data folder +data_folder = osp.dirname(osp.abspath(__file__)) +dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] +# import all the dataset modules +_dataset_modules = [importlib.import_module(f'realesrgan.data.{file_name}') for file_name in dataset_filenames] diff --git a/external/realesrgan/data/realesrgan_dataset.py b/external/realesrgan/data/realesrgan_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4cf2d9e6583a6789b771679734ce55bb8a22e628 --- /dev/null +++ b/external/realesrgan/data/realesrgan_dataset.py @@ -0,0 +1,192 @@ +import cv2 +import math +import numpy as np +import os +import os.path as osp +import random +import time +import torch +from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels +from basicsr.data.transforms import augment +from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from basicsr.utils.registry import DATASET_REGISTRY +from torch.utils import data as data + + +@DATASET_REGISTRY.register() +class RealESRGANDataset(data.Dataset): + """Dataset used for Real-ESRGAN model: + Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + It loads gt (Ground-Truth) images, and augments them. + It also generates blur kernels and sinc kernels for generating low-quality images. + Note that the low-quality images are processed in tensors on GPUS for faster processing. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_gt (str): Data root path for gt. + meta_info (str): Path for meta information file. + io_backend (dict): IO backend type and other kwarg. + use_hflip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). + Please see more options in the codes. + """ + + def __init__(self, opt): + super(RealESRGANDataset, self).__init__() + self.opt = opt + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.gt_folder = opt['dataroot_gt'] + + # file client (lmdb io backend) + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = [self.gt_folder] + self.io_backend_opt['client_keys'] = ['gt'] + if not self.gt_folder.endswith('.lmdb'): + raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}") + with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: + self.paths = [line.split('.')[0] for line in fin] + else: + # disk backend with meta_info + # Each line in the meta_info describes the relative path to an image + with open(self.opt['meta_info']) as fin: + paths = [line.strip().split(' ')[0] for line in fin] + self.paths = [os.path.join(self.gt_folder, v) for v in paths] + + # blur settings for the first degradation + self.blur_kernel_size = opt['blur_kernel_size'] + self.kernel_list = opt['kernel_list'] + self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability + self.blur_sigma = opt['blur_sigma'] + self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels + self.betap_range = opt['betap_range'] # betap used in plateau blur kernels + self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters + + # blur settings for the second degradation + self.blur_kernel_size2 = opt['blur_kernel_size2'] + self.kernel_list2 = opt['kernel_list2'] + self.kernel_prob2 = opt['kernel_prob2'] + self.blur_sigma2 = opt['blur_sigma2'] + self.betag_range2 = opt['betag_range2'] + self.betap_range2 = opt['betap_range2'] + self.sinc_prob2 = opt['sinc_prob2'] + + # a final sinc filter + self.final_sinc_prob = opt['final_sinc_prob'] + + self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21 + # TODO: kernel range is now hard-coded, should be in the configure file + self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect + self.pulse_tensor[10, 10] = 1 + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # -------------------------------- Load gt images -------------------------------- # + # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32. + gt_path = self.paths[index] + # avoid errors caused by high latency in reading files + retry = 3 + while retry > 0: + try: + img_bytes = self.file_client.get(gt_path, 'gt') + except (IOError, OSError) as e: + logger = get_root_logger() + logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}') + # change another file to read + index = random.randint(0, self.__len__()) + gt_path = self.paths[index] + time.sleep(1) # sleep 1s for occasional server congestion + else: + break + finally: + retry -= 1 + img_gt = imfrombytes(img_bytes, float32=True) + + # -------------------- Do augmentation for training: flip, rotation -------------------- # + img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot']) + + # crop or pad to 400 + # TODO: 400 is hard-coded. You may change it accordingly + h, w = img_gt.shape[0:2] + crop_pad_size = 400 + # pad + if h < crop_pad_size or w < crop_pad_size: + pad_h = max(0, crop_pad_size - h) + pad_w = max(0, crop_pad_size - w) + img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101) + # crop + if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size: + h, w = img_gt.shape[0:2] + # randomly choose top and left coordinates + top = random.randint(0, h - crop_pad_size) + left = random.randint(0, w - crop_pad_size) + img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...] + + # ------------------------ Generate kernels (used in the first degradation) ------------------------ # + kernel_size = random.choice(self.kernel_range) + if np.random.uniform() < self.opt['sinc_prob']: + # this sinc filter setting is for kernels ranging from [7, 21] + if kernel_size < 13: + omega_c = np.random.uniform(np.pi / 3, np.pi) + else: + omega_c = np.random.uniform(np.pi / 5, np.pi) + kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) + else: + kernel = random_mixed_kernels( + self.kernel_list, + self.kernel_prob, + kernel_size, + self.blur_sigma, + self.blur_sigma, [-math.pi, math.pi], + self.betag_range, + self.betap_range, + noise_range=None) + # pad kernel + pad_size = (21 - kernel_size) // 2 + kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size))) + + # ------------------------ Generate kernels (used in the second degradation) ------------------------ # + kernel_size = random.choice(self.kernel_range) + if np.random.uniform() < self.opt['sinc_prob2']: + if kernel_size < 13: + omega_c = np.random.uniform(np.pi / 3, np.pi) + else: + omega_c = np.random.uniform(np.pi / 5, np.pi) + kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) + else: + kernel2 = random_mixed_kernels( + self.kernel_list2, + self.kernel_prob2, + kernel_size, + self.blur_sigma2, + self.blur_sigma2, [-math.pi, math.pi], + self.betag_range2, + self.betap_range2, + noise_range=None) + + # pad kernel + pad_size = (21 - kernel_size) // 2 + kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size))) + + # ------------------------------------- the final sinc kernel ------------------------------------- # + if np.random.uniform() < self.opt['final_sinc_prob']: + kernel_size = random.choice(self.kernel_range) + omega_c = np.random.uniform(np.pi / 3, np.pi) + sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21) + sinc_kernel = torch.FloatTensor(sinc_kernel) + else: + sinc_kernel = self.pulse_tensor + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0] + kernel = torch.FloatTensor(kernel) + kernel2 = torch.FloatTensor(kernel2) + + return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path} + return return_d + + def __len__(self): + return len(self.paths) diff --git a/external/realesrgan/data/realesrgan_paired_dataset.py b/external/realesrgan/data/realesrgan_paired_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..29751ae6749dabd3b8d1ca657f4a4cd392a9080e --- /dev/null +++ b/external/realesrgan/data/realesrgan_paired_dataset.py @@ -0,0 +1,117 @@ +import os +from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb +from basicsr.data.transforms import augment, paired_random_crop +from basicsr.utils import FileClient +from basicsr.utils.img_util import imfrombytes, img2tensor +from basicsr.utils.registry import DATASET_REGISTRY +from torch.utils import data as data +from torchvision.transforms.functional import normalize + + +@DATASET_REGISTRY.register() +class RealESRGANPairedDataset(data.Dataset): + """Paired image dataset for image restoration. + + Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. + + There are three modes: + 1. 'lmdb': Use lmdb files. + If opt['io_backend'] == lmdb. + 2. 'meta_info': Use meta information file to generate paths. + If opt['io_backend'] != lmdb and opt['meta_info'] is not None. + 3. 'folder': Scan folders to generate paths. + The rest. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_gt (str): Data root path for gt. + dataroot_lq (str): Data root path for lq. + meta_info (str): Path for meta information file. + io_backend (dict): IO backend type and other kwarg. + filename_tmpl (str): Template for each filename. Note that the template excludes the file extension. + Default: '{}'. + gt_size (int): Cropped patched size for gt patches. + use_hflip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h + and w for implementation). + + scale (bool): Scale, which will be added automatically. + phase (str): 'train' or 'val'. + """ + + def __init__(self, opt): + super(RealESRGANPairedDataset, self).__init__() + self.opt = opt + self.file_client = None + self.io_backend_opt = opt['io_backend'] + # mean and std for normalizing the input images + self.mean = opt['mean'] if 'mean' in opt else None + self.std = opt['std'] if 'std' in opt else None + + in_channels = opt['in_channels'] if 'in_channels' in opt else 3 + if in_channels == 1: + self.flag = 'grayscale' + elif in_channels == 3: + self.flag = 'color' + else: + self.flag = 'unchanged' + + self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] + self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}' + + # file client (lmdb io backend) + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] + self.io_backend_opt['client_keys'] = ['lq', 'gt'] + self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) + elif 'meta_info' in self.opt and self.opt['meta_info'] is not None: + # disk backend with meta_info + # Each line in the meta_info describes the relative path to an image + with open(self.opt['meta_info']) as fin: + paths = [line.strip() for line in fin] + self.paths = [] + for path in paths: + gt_path, lq_path = path.split(', ') + gt_path = os.path.join(self.gt_folder, gt_path) + lq_path = os.path.join(self.lq_folder, lq_path) + self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)])) + else: + # disk backend + # it will scan the whole folder to get meta info + # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file + self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + scale = self.opt['scale'] + + # Load gt and lq images. Dimension order: HWC; channel order: BGR; + # image range: [0, 1], float32. + gt_path = self.paths[index]['gt_path'] + img_bytes = self.file_client.get(gt_path, 'gt') + img_gt = imfrombytes(img_bytes, flag = self.flag, float32=True) + lq_path = self.paths[index]['lq_path'] + img_bytes = self.file_client.get(lq_path, 'lq') + img_lq = imfrombytes(img_bytes, flag = self.flag, float32=True) + + # augmentation for training + if self.opt['phase'] == 'train': + gt_size = self.opt['gt_size'] + # random crop + img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) + # flip, rotation + img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) + # normalize + if self.mean is not None or self.std is not None: + normalize(img_lq, self.mean, self.std, inplace=True) + normalize(img_gt, self.mean, self.std, inplace=True) + + return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} + + def __len__(self): + return len(self.paths) diff --git a/external/realesrgan/models/__init__.py b/external/realesrgan/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0be7105dc75d150c49976396724085f678dc0675 --- /dev/null +++ b/external/realesrgan/models/__init__.py @@ -0,0 +1,10 @@ +import importlib +from basicsr.utils import scandir +from os import path as osp + +# automatically scan and import model modules for registry +# scan all the files that end with '_model.py' under the model folder +model_folder = osp.dirname(osp.abspath(__file__)) +model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] +# import all the model modules +_model_modules = [importlib.import_module(f'realesrgan.models.{file_name}') for file_name in model_filenames] diff --git a/external/realesrgan/models/realesrgan_model.py b/external/realesrgan/models/realesrgan_model.py new file mode 100644 index 0000000000000000000000000000000000000000..c298a09c42433177f90001a0a31d029576072ccd --- /dev/null +++ b/external/realesrgan/models/realesrgan_model.py @@ -0,0 +1,258 @@ +import numpy as np +import random +import torch +from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt +from basicsr.data.transforms import paired_random_crop +from basicsr.models.srgan_model import SRGANModel +from basicsr.utils import DiffJPEG, USMSharp +from basicsr.utils.img_process_util import filter2D +from basicsr.utils.registry import MODEL_REGISTRY +from collections import OrderedDict +from torch.nn import functional as F + + +@MODEL_REGISTRY.register() +class RealESRGANModel(SRGANModel): + """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + It mainly performs: + 1. randomly synthesize LQ images in GPU tensors + 2. optimize the networks with GAN training. + """ + + def __init__(self, opt): + super(RealESRGANModel, self).__init__(opt) + self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts + self.usm_sharpener = USMSharp().cuda() # do usm sharpening + self.queue_size = opt.get('queue_size', 180) + + @torch.no_grad() + def _dequeue_and_enqueue(self): + """It is the training pair pool for increasing the diversity in a batch. + + Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a + batch could not have different resize scaling factors. Therefore, we employ this training pair pool + to increase the degradation diversity in a batch. + """ + # initialize + b, c, h, w = self.lq.size() + if not hasattr(self, 'queue_lr'): + assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}' + self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda() + _, c, h, w = self.gt.size() + self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda() + self.queue_ptr = 0 + if self.queue_ptr == self.queue_size: # the pool is full + # do dequeue and enqueue + # shuffle + idx = torch.randperm(self.queue_size) + self.queue_lr = self.queue_lr[idx] + self.queue_gt = self.queue_gt[idx] + # get first b samples + lq_dequeue = self.queue_lr[0:b, :, :, :].clone() + gt_dequeue = self.queue_gt[0:b, :, :, :].clone() + # update the queue + self.queue_lr[0:b, :, :, :] = self.lq.clone() + self.queue_gt[0:b, :, :, :] = self.gt.clone() + + self.lq = lq_dequeue + self.gt = gt_dequeue + else: + # only do enqueue + self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone() + self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone() + self.queue_ptr = self.queue_ptr + b + + @torch.no_grad() + def feed_data(self, data): + """Accept data from dataloader, and then add two-order degradations to obtain LQ images. + """ + if self.is_train and self.opt.get('high_order_degradation', True): + # training data synthesis + self.gt = data['gt'].to(self.device) + self.gt_usm = self.usm_sharpener(self.gt) + + self.kernel1 = data['kernel1'].to(self.device) + self.kernel2 = data['kernel2'].to(self.device) + self.sinc_kernel = data['sinc_kernel'].to(self.device) + + ori_h, ori_w = self.gt.size()[2:4] + + # ----------------------- The first degradation process ----------------------- # + # blur + out = filter2D(self.gt_usm, self.kernel1) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, scale_factor=scale, mode=mode) + # add noise + gray_noise_prob = self.opt['gray_noise_prob'] + if np.random.uniform() < self.opt['gaussian_noise_prob']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) + out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts + out = self.jpeger(out, quality=jpeg_p) + + # ----------------------- The second degradation process ----------------------- # + # blur + if np.random.uniform() < self.opt['second_blur_prob']: + out = filter2D(out, self.kernel2) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range2'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range2'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate( + out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode) + # add noise + gray_noise_prob = self.opt['gray_noise_prob2'] + if np.random.uniform() < self.opt['gaussian_noise_prob2']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range2'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + + # JPEG compression + the final sinc filter + # We also need to resize images to desired sizes. We group [resize back + sinc filter] together + # as one operation. + # We consider two orders: + # 1. [resize back + sinc filter] + JPEG compression + # 2. JPEG compression + [resize back + sinc filter] + # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines. + if np.random.uniform() < 0.5: + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + else: + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + + # clamp and round + self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255. + + # random crop + gt_size = self.opt['gt_size'] + (self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size, + self.opt['scale']) + + # training pair pool + self._dequeue_and_enqueue() + # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue + self.gt_usm = self.usm_sharpener(self.gt) + self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract + else: + # for paired training or validation + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + self.gt_usm = self.usm_sharpener(self.gt) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + # do not use the synthetic process during validation + self.is_train = False + super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img) + self.is_train = True + + def optimize_parameters(self, current_iter): + # usm sharpening + l1_gt = self.gt_usm + percep_gt = self.gt_usm + gan_gt = self.gt_usm + if self.opt['l1_gt_usm'] is False: + l1_gt = self.gt + if self.opt['percep_gt_usm'] is False: + percep_gt = self.gt + if self.opt['gan_gt_usm'] is False: + gan_gt = self.gt + + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + + l_g_total = 0 + loss_dict = OrderedDict() + if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, l1_gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + # perceptual loss + if self.cri_perceptual: + l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt) + if l_g_percep is not None: + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + if l_g_style is not None: + l_g_total += l_g_style + loss_dict['l_g_style'] = l_g_style + # gan loss + fake_g_pred = self.net_d(self.output) + l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan'] = l_g_gan + + l_g_total.backward() + self.optimizer_g.step() + + # optimize net_d + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # real + real_d_pred = self.net_d(gan_gt) + l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) + loss_dict['l_d_real'] = l_d_real + loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) + l_d_real.backward() + # fake + fake_d_pred = self.net_d(self.output.detach().clone()) # clone for pt1.9 + l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) + loss_dict['l_d_fake'] = l_d_fake + loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) + l_d_fake.backward() + self.optimizer_d.step() + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + self.log_dict = self.reduce_loss_dict(loss_dict) diff --git a/external/realesrgan/models/realesrnet_model.py b/external/realesrgan/models/realesrnet_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d11668f3712bffcd062c57db14d22ca3a0e1e59d --- /dev/null +++ b/external/realesrgan/models/realesrnet_model.py @@ -0,0 +1,188 @@ +import numpy as np +import random +import torch +from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt +from basicsr.data.transforms import paired_random_crop +from basicsr.models.sr_model import SRModel +from basicsr.utils import DiffJPEG, USMSharp +from basicsr.utils.img_process_util import filter2D +from basicsr.utils.registry import MODEL_REGISTRY +from torch.nn import functional as F + + +@MODEL_REGISTRY.register() +class RealESRNetModel(SRModel): + """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + It is trained without GAN losses. + It mainly performs: + 1. randomly synthesize LQ images in GPU tensors + 2. optimize the networks with GAN training. + """ + + def __init__(self, opt): + super(RealESRNetModel, self).__init__(opt) + self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts + self.usm_sharpener = USMSharp().cuda() # do usm sharpening + self.queue_size = opt.get('queue_size', 180) + + @torch.no_grad() + def _dequeue_and_enqueue(self): + """It is the training pair pool for increasing the diversity in a batch. + + Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a + batch could not have different resize scaling factors. Therefore, we employ this training pair pool + to increase the degradation diversity in a batch. + """ + # initialize + b, c, h, w = self.lq.size() + if not hasattr(self, 'queue_lr'): + assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}' + self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda() + _, c, h, w = self.gt.size() + self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda() + self.queue_ptr = 0 + if self.queue_ptr == self.queue_size: # the pool is full + # do dequeue and enqueue + # shuffle + idx = torch.randperm(self.queue_size) + self.queue_lr = self.queue_lr[idx] + self.queue_gt = self.queue_gt[idx] + # get first b samples + lq_dequeue = self.queue_lr[0:b, :, :, :].clone() + gt_dequeue = self.queue_gt[0:b, :, :, :].clone() + # update the queue + self.queue_lr[0:b, :, :, :] = self.lq.clone() + self.queue_gt[0:b, :, :, :] = self.gt.clone() + + self.lq = lq_dequeue + self.gt = gt_dequeue + else: + # only do enqueue + self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone() + self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone() + self.queue_ptr = self.queue_ptr + b + + @torch.no_grad() + def feed_data(self, data): + """Accept data from dataloader, and then add two-order degradations to obtain LQ images. + """ + if self.is_train and self.opt.get('high_order_degradation', True): + # training data synthesis + self.gt = data['gt'].to(self.device) + # USM sharpen the GT images + if self.opt['gt_usm'] is True: + self.gt = self.usm_sharpener(self.gt) + + self.kernel1 = data['kernel1'].to(self.device) + self.kernel2 = data['kernel2'].to(self.device) + self.sinc_kernel = data['sinc_kernel'].to(self.device) + + ori_h, ori_w = self.gt.size()[2:4] + + # ----------------------- The first degradation process ----------------------- # + # blur + out = filter2D(self.gt, self.kernel1) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, scale_factor=scale, mode=mode) + # add noise + gray_noise_prob = self.opt['gray_noise_prob'] + if np.random.uniform() < self.opt['gaussian_noise_prob']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) + out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts + out = self.jpeger(out, quality=jpeg_p) + + # ----------------------- The second degradation process ----------------------- # + # blur + if np.random.uniform() < self.opt['second_blur_prob']: + out = filter2D(out, self.kernel2) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range2'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range2'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate( + out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode) + # add noise + gray_noise_prob = self.opt['gray_noise_prob2'] + if np.random.uniform() < self.opt['gaussian_noise_prob2']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range2'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + + # JPEG compression + the final sinc filter + # We also need to resize images to desired sizes. We group [resize back + sinc filter] together + # as one operation. + # We consider two orders: + # 1. [resize back + sinc filter] + JPEG compression + # 2. JPEG compression + [resize back + sinc filter] + # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines. + if np.random.uniform() < 0.5: + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + else: + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + + # clamp and round + self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255. + + # random crop + gt_size = self.opt['gt_size'] + self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale']) + + # training pair pool + self._dequeue_and_enqueue() + self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract + else: + # for paired training or validation + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + self.gt_usm = self.usm_sharpener(self.gt) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + # do not use the synthetic process during validation + self.is_train = False + super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img) + self.is_train = True diff --git a/external/realesrgan/train.py b/external/realesrgan/train.py new file mode 100644 index 0000000000000000000000000000000000000000..8a9cec9ed80d9f362984779548dcec921a636a04 --- /dev/null +++ b/external/realesrgan/train.py @@ -0,0 +1,11 @@ +# flake8: noqa +import os.path as osp +from basicsr.train import train_pipeline + +import realesrgan.archs +import realesrgan.data +import realesrgan.models + +if __name__ == '__main__': + root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + train_pipeline(root_path) diff --git a/external/realesrgan/utils.py b/external/realesrgan/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f0c4688cba5f1db83441feee425943d76d00f1cd --- /dev/null +++ b/external/realesrgan/utils.py @@ -0,0 +1,302 @@ +import cv2 +import math +import numpy as np +import os +import queue +import threading +import torch +from basicsr.utils.download_util import load_file_from_url +from torch.nn import functional as F + +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +class RealESRGANer(): + """A helper class for upsampling images with RealESRGAN. + + Args: + scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4. + model_path (str): The path to the pretrained model. It can be urls (will first download it automatically). + model (nn.Module): The defined network. Default: None. + tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop + input images into tiles, and then process each of them. Finally, they will be merged into one image. + 0 denotes for do not use tile. Default: 0. + tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10. + pre_pad (int): Pad the input images to avoid border artifacts. Default: 10. + half (float): Whether to use half precision during inference. Default: False. + """ + + def __init__(self, scale, model_path, dni_weight = None, model = None, tile = 0, tile_pad = 10, pre_pad = 10, half = False, device = None, gpu_id = None): + self.scale = scale + self.tile_size = tile + self.tile_pad = tile_pad + self.pre_pad = pre_pad + self.mod_scale = None + self.half = half + + # initialize model + if gpu_id: + self.device = torch.device(f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device + else: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device + + if isinstance(model_path, list): + # dni + assert len(model_path) == len(dni_weight), 'model_path and dni_weight should have the save length.' + loadnet = self.dni(model_path[0], model_path[1], dni_weight) + else: + # if the model_path starts with https, it will first download models to the folder: weights + if model_path.startswith('https://'): + model_path = load_file_from_url(url = model_path, model_dir = os.path.join(ROOT_DIR, 'weights'), progress = True, file_name = None) + loadnet = torch.load(model_path, map_location = torch.device('cpu')) + + # prefer to use params_ema + if 'params_ema' in loadnet: + keyname = 'params_ema' + else: + keyname = 'params' + model.load_state_dict(loadnet[keyname], strict=True) + + model.eval() + self.model = model.to(self.device) + if self.half: + self.model = self.model.half() + + def dni(self, net_a, net_b, dni_weight, key='params', loc='cpu'): + """Deep network interpolation. + + ``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition`` + """ + net_a = torch.load(net_a, map_location = torch.device(loc)) + net_b = torch.load(net_b, map_location = torch.device(loc)) + for k, v_a in net_a[key].items(): + net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k] + return net_a + + def pre_process(self, img): + """Pre-process, such as pre-pad and mod pad, so that the images can be divisible + """ + img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float() + self.img = img.unsqueeze(0).to(self.device) + if self.half: + self.img = self.img.half() + + # pre_pad + if self.pre_pad != 0: + self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect') + # mod pad for divisible borders + if self.scale == 2: + self.mod_scale = 2 + elif self.scale == 1: + self.mod_scale = 4 + if self.mod_scale is not None: + self.mod_pad_h, self.mod_pad_w = 0, 0 + _, _, h, w = self.img.size() + if (h % self.mod_scale != 0): + self.mod_pad_h = (self.mod_scale - h % self.mod_scale) + if (w % self.mod_scale != 0): + self.mod_pad_w = (self.mod_scale - w % self.mod_scale) + self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect') + + def process(self): + # model inference + self.output = self.model(self.img) + + def tile_process(self): + """It will first crop input images to tiles, and then process each tile. + Finally, all the processed tiles are merged into one images. + + Modified from: https://github.com/ata4/esrgan-launcher + """ + batch, channel, height, width = self.img.shape + output_height = height * self.scale + output_width = width * self.scale + output_shape = (batch, channel, output_height, output_width) + + # start with black image + self.output = self.img.new_zeros(output_shape) + tiles_x = math.ceil(width / self.tile_size) + tiles_y = math.ceil(height / self.tile_size) + + # loop over all tiles + for y in range(tiles_y): + for x in range(tiles_x): + # extract tile from input image + ofs_x = x * self.tile_size + ofs_y = y * self.tile_size + # input tile area on total image + input_start_x = ofs_x + input_end_x = min(ofs_x + self.tile_size, width) + input_start_y = ofs_y + input_end_y = min(ofs_y + self.tile_size, height) + + # input tile area on total image with padding + input_start_x_pad = max(input_start_x - self.tile_pad, 0) + input_end_x_pad = min(input_end_x + self.tile_pad, width) + input_start_y_pad = max(input_start_y - self.tile_pad, 0) + input_end_y_pad = min(input_end_y + self.tile_pad, height) + + # input tile dimensions + input_tile_width = input_end_x - input_start_x + input_tile_height = input_end_y - input_start_y + tile_idx = y * tiles_x + x + 1 + input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad] + + # upscale tile + try: + with torch.no_grad(): + output_tile = self.model(input_tile) + except RuntimeError as error: + print('Error', error) + print(f'\tTile {tile_idx}/{tiles_x * tiles_y}') + + # output tile area on total image + output_start_x = input_start_x * self.scale + output_end_x = input_end_x * self.scale + output_start_y = input_start_y * self.scale + output_end_y = input_end_y * self.scale + + # output tile area without padding + output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale + output_end_x_tile = output_start_x_tile + input_tile_width * self.scale + output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale + output_end_y_tile = output_start_y_tile + input_tile_height * self.scale + + # put tile into output image + self.output[:, :, output_start_y:output_end_y, output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile, output_start_x_tile:output_end_x_tile] + + def post_process(self): + # remove extra pad + if self.mod_scale is not None: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale] + # remove prepad + if self.pre_pad != 0: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale] + return self.output + + @torch.no_grad() + def enhance(self, img, outscale = None, num_out_ch = 3, alpha_upsampler = 'realesrgan'): + h_input, w_input = img.shape[0:2] + # img: numpy + img = img.astype(np.float32) + if np.max(img) > 256: # 16-bit image + max_range = 65535 + print('\tInput is a 16-bit image') + else: + max_range = 255 + img = img / max_range + if len(img.shape) == 2: # gray image + img_mode = 'L' + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + elif img.shape[2] == 4: # RGBA image with alpha channel + img_mode = 'RGBA' + if num_out_ch != 3: + img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA) + else: + alpha = img[:, :, 3] + img = img[:, :, 0:3] + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + if alpha_upsampler == 'realesrgan': + alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB) + else: + img_mode = 'RGB' + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + # ------------------- process image (without the alpha channel) ------------------- # + self.pre_process(img) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_img = self.post_process() + output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy() + img_struct_list = [] + for i in range(3, num_out_ch): + img_struct_list.append(i) + output_img = output_img[[2, 1, 0] + img_struct_list, :, :] + output_img = np.transpose(output_img, (1, 2, 0)) + if img_mode == 'L': + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) + + # ------------------- process the alpha channel if necessary ------------------- # + if img_mode == 'RGBA' and num_out_ch == 3: + if alpha_upsampler == 'realesrgan': + self.pre_process(alpha) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_alpha = self.post_process() + output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0)) + output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY) + else: # use the cv2 resize for alpha channel + h, w = alpha.shape[0:2] + output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR) + + # merge the alpha channel + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA) + output_img[:, :, 3] = output_alpha + + # ------------------------------ return ------------------------------ # + if max_range == 65535: # 16-bit image + output = (output_img * 65535.0).round().astype(np.uint16) + else: + output = (output_img * 255.0).round().astype(np.uint8) + + if outscale is not None and outscale != float(self.scale): + output = cv2.resize(output, (int(w_input * outscale), int(h_input * outscale)), interpolation = cv2.INTER_LANCZOS4) + + return output, img_mode + + +class PrefetchReader(threading.Thread): + """Prefetch images. + + Args: + img_list (list[str]): A image list of image paths to be read. + num_prefetch_queue (int): Number of prefetch queue. + """ + + def __init__(self, img_list, num_prefetch_queue): + super().__init__() + self.que = queue.Queue(num_prefetch_queue) + self.img_list = img_list + + def run(self): + for img_path in self.img_list: + img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) + self.que.put(img) + + self.que.put(None) + + def __next__(self): + next_item = self.que.get() + if next_item is None: + raise StopIteration + return next_item + + def __iter__(self): + return self + + +class IOConsumer(threading.Thread): + + def __init__(self, opt, que, qid): + super().__init__() + self._queue = que + self.qid = qid + self.opt = opt + + def run(self): + while True: + msg = self._queue.get() + if isinstance(msg, str) and msg == 'quit': + break + + output = msg['output'] + save_path = msg['save_path'] + cv2.imwrite(save_path, output) + print(f'IO worker {self.qid} is done.') \ No newline at end of file diff --git a/handler.py b/handler.py index c84abb898fc5dbd14f7726ba43c6cc9fc965bd29..9dc69e66add1792d3cb2b759de61a6afa5c18eea 100644 --- a/handler.py +++ b/handler.py @@ -1,5 +1,10 @@ -import json import os +import sys + +path = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(1, os.path.join(path, "external")) + + from pathlib import Path from typing import Any, Dict, List diff --git a/inference.py b/inference.py index aa63ffc23aa54a45d7974baf160a8b2dc25bf462..c52b117820b73d69450ca082edcaf9e00679407a 100644 --- a/inference.py +++ b/inference.py @@ -17,10 +17,9 @@ from internals.pipelines.img_classifier import ImageClassifier from internals.pipelines.img_to_text import Image2Text from internals.pipelines.inpainter import InPainter from internals.pipelines.object_remove import ObjectRemoval -from internals.pipelines.pose_detector import PoseDetector from internals.pipelines.prompt_modifier import PromptModifier from internals.pipelines.realtime_draw import RealtimeDraw -from internals.pipelines.remove_background import RemoveBackgroundV2 +from internals.pipelines.remove_background import RemoveBackgroundV3 from internals.pipelines.replace_background import ReplaceBackground from internals.pipelines.safety_checker import SafetyChecker from internals.pipelines.sdxl_tile_upscale import SDXLTileUpscaler @@ -45,7 +44,6 @@ from internals.util.config import ( set_model_config, set_root_dir, ) -from internals.util.failure_hander import FailureHandler from internals.util.lora_style import LoraStyle from internals.util.model_loader import load_model_from_config from internals.util.slack import Slack @@ -57,14 +55,13 @@ auto_mode = False prompt_modifier = PromptModifier(num_of_sequences=get_num_return_sequences()) upscaler = Upscaler() -pose_detector = PoseDetector() inpainter = InPainter() high_res = HighRes() img2text = Image2Text() img_classifier = ImageClassifier() object_removal = ObjectRemoval() replace_background = ReplaceBackground() -remove_background_v2 = RemoveBackgroundV2() +remove_background_v3 = RemoveBackgroundV3() replace_background = ReplaceBackground() controlnet = ControlNet() lora_style = LoraStyle() @@ -92,7 +89,7 @@ def get_patched_prompt_text2img(task: Task): def get_patched_prompt_tile_upscale(task: Task): return prompt_util.get_patched_prompt_tile_upscale( - task, avatar, lora_style, img_classifier, img2text + task, avatar, lora_style, img_classifier, img2text, is_sdxl=get_is_sdxl() ) @@ -126,20 +123,19 @@ def canny(task: Task): "num_inference_steps": task.get_steps(), "width": width, "height": height, - "negative_prompt": [ - f"monochrome, neon, x-ray, negative image, oversaturated, {task.get_negative_prompt()}" - ] - * get_num_return_sequences(), + "negative_prompt": [task.get_negative_prompt()] * get_num_return_sequences(), + "apply_preprocess": task.get_apply_preprocess(), **task.cnc_kwargs(), **lora_patcher.kwargs(), } - images, has_nsfw = controlnet.process(**kwargs) + (images, has_nsfw), control_image = controlnet.process(**kwargs) if task.get_high_res_fix(): kwargs = { "prompt": prompt, "negative_prompt": [task.get_negative_prompt()] * get_num_return_sequences(), "images": images, + "seed": task.get_seed(), "width": task.get_width(), "height": task.get_height(), "num_inference_steps": task.get_steps(), @@ -147,6 +143,9 @@ def canny(task: Task): } images, _ = high_res.apply(**kwargs) + upload_image( + control_image, f"crecoAI/{task.get_taskId()}_condition.png" # pyright: ignore + ) generated_image_urls = upload_images(images, "_canny", task.get_taskId()) lora_patcher.cleanup() @@ -162,48 +161,102 @@ def canny(task: Task): @update_db @auto_clear_cuda_and_gc(controlnet) @slack.auto_send_alert -def tile_upscale(task: Task): - output_key = "crecoAI/{}_tile_upscaler.png".format(task.get_taskId()) - - prompt = get_patched_prompt_tile_upscale(task) - - if get_is_sdxl(): - lora_patcher = lora_style.get_patcher( - [sdxl_tileupscaler.pipe, high_res.pipe], task.get_style() - ) - lora_patcher.patch() +def canny_img2img(task: Task): + prompt, _ = get_patched_prompt(task) - images, has_nsfw = sdxl_tileupscaler.process( - prompt=prompt, - imageUrl=task.get_imageUrl(), - resize_dimension=task.get_resize_dimension(), - negative_prompt=task.get_negative_prompt(), - width=task.get_width(), - height=task.get_height(), - model_id=task.get_model_id(), - ) + width, height = get_intermediate_dimension(task) - lora_patcher.cleanup() - else: - controlnet.load_model("tile_upscaler") + controlnet.load_model("canny_2x") - lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style()) - lora_patcher.patch() + lora_patcher = lora_style.get_patcher( + [controlnet.pipe, high_res.pipe], task.get_style() + ) + lora_patcher.patch() + kwargs = { + "prompt": prompt, + "imageUrl": task.get_imageUrl(), + "seed": task.get_seed(), + "num_inference_steps": task.get_steps(), + "width": width, + "height": height, + "negative_prompt": [task.get_negative_prompt()] * get_num_return_sequences(), + **task.cnci2i_kwargs(), + **lora_patcher.kwargs(), + } + (images, has_nsfw), control_image = controlnet.process(**kwargs) + if task.get_high_res_fix(): + # we run both here normal upscaler and highres + # and show normal upscaler image as output + # but use highres image for tile upscale kwargs = { - "imageUrl": task.get_imageUrl(), + "prompt": prompt, + "negative_prompt": [task.get_negative_prompt()] + * get_num_return_sequences(), + "images": images, "seed": task.get_seed(), - "num_inference_steps": task.get_steps(), - "negative_prompt": task.get_negative_prompt(), "width": task.get_width(), "height": task.get_height(), - "prompt": prompt, - "resize_dimension": task.get_resize_dimension(), - **task.cnt_kwargs(), + "num_inference_steps": task.get_steps(), + **task.high_res_kwargs(), } - images, has_nsfw = controlnet.process(**kwargs) - lora_patcher.cleanup() - controlnet.cleanup() + images, _ = high_res.apply(**kwargs) + + # upload_images(images_high_res, "_canny_2x_highres", task.get_taskId()) + + for i, image in enumerate(images): + img = upscaler.upscale( + image=image, + width=task.get_width(), + height=task.get_height(), + face_enhance=task.get_face_enhance(), + resize_dimension=None, + ) + img = Upscaler.to_pil(img) + images[i] = img.resize((task.get_width(), task.get_height())) + + upload_image( + control_image, f"crecoAI/{task.get_taskId()}_condition.png" # pyright: ignore + ) + generated_image_urls = upload_images(images, "_canny_2x", task.get_taskId()) + + lora_patcher.cleanup() + controlnet.cleanup() + + return { + "modified_prompts": prompt, + "generated_image_urls": generated_image_urls, + "has_nsfw": has_nsfw, + } + + +@update_db +@auto_clear_cuda_and_gc(controlnet) +@slack.auto_send_alert +def tile_upscale(task: Task): + output_key = "crecoAI/{}_tile_upscaler.png".format(task.get_taskId()) + + prompt = get_patched_prompt_tile_upscale(task) + + controlnet.load_model("tile_upscaler") + + lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style()) + lora_patcher.patch() + + kwargs = { + "imageUrl": task.get_imageUrl(), + "seed": task.get_seed(), + "num_inference_steps": task.get_steps(), + "negative_prompt": task.get_negative_prompt(), + "width": task.get_width(), + "height": task.get_height(), + "prompt": prompt, + "resize_dimension": task.get_resize_dimension(), + **task.cnt_kwargs(), + } + (images, has_nsfw), _ = controlnet.process(**kwargs) + lora_patcher.cleanup() + controlnet.cleanup() generated_image_url = upload_image(images[0], output_key) @@ -229,12 +282,7 @@ def scribble(task: Task): ) lora_patcher.patch() - image = download_image(task.get_imageUrl()).resize((width, height)) - if get_is_sdxl(): - # We use sketch in SDXL - image = ControlNet.pidinet_image(image) - else: - image = ControlNet.scribble_image(image) + image = controlnet.preprocess_image(task.get_imageUrl(), width, height) kwargs = { "image": [image] * get_num_return_sequences(), @@ -244,9 +292,10 @@ def scribble(task: Task): "height": height, "prompt": prompt, "negative_prompt": [task.get_negative_prompt()] * get_num_return_sequences(), + "apply_preprocess": task.get_apply_preprocess(), **task.cns_kwargs(), } - images, has_nsfw = controlnet.process(**kwargs) + (images, has_nsfw), condition_image = controlnet.process(**kwargs) if task.get_high_res_fix(): kwargs = { @@ -256,11 +305,15 @@ def scribble(task: Task): "images": images, "width": task.get_width(), "height": task.get_height(), + "seed": task.get_seed(), "num_inference_steps": task.get_steps(), **task.high_res_kwargs(), } images, _ = high_res.apply(**kwargs) + upload_image( + condition_image, f"crecoAI/{task.get_taskId()}_condition.png" # pyright: ignore + ) generated_image_urls = upload_images(images, "_scribble", task.get_taskId()) lora_patcher.cleanup() @@ -296,16 +349,21 @@ def linearart(task: Task): "height": height, "prompt": prompt, "negative_prompt": [task.get_negative_prompt()] * get_num_return_sequences(), + "apply_preprocess": task.get_apply_preprocess(), **task.cnl_kwargs(), } - images, has_nsfw = controlnet.process(**kwargs) + (images, has_nsfw), condition_image = controlnet.process(**kwargs) if task.get_high_res_fix(): + # we run both here normal upscaler and highres + # and show normal upscaler image as output + # but use highres image for tile upscale kwargs = { "prompt": prompt, "negative_prompt": [task.get_negative_prompt()] * get_num_return_sequences(), "images": images, + "seed": task.get_seed(), "width": task.get_width(), "height": task.get_height(), "num_inference_steps": task.get_steps(), @@ -313,6 +371,22 @@ def linearart(task: Task): } images, _ = high_res.apply(**kwargs) + # upload_images(images_high_res, "_linearart_highres", task.get_taskId()) + # + # for i, image in enumerate(images): + # img = upscaler.upscale( + # image=image, + # width=task.get_width(), + # height=task.get_height(), + # face_enhance=task.get_face_enhance(), + # resize_dimension=None, + # ) + # img = Upscaler.to_pil(img) + # images[i] = img + + upload_image( + condition_image, f"crecoAI/{task.get_taskId()}_condition.png" # pyright: ignore + ) generated_image_urls = upload_images(images, "_linearart", task.get_taskId()) lora_patcher.cleanup() @@ -341,20 +415,14 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None): ) lora_patcher.patch() - if not task.get_pose_estimation(): + if not task.get_apply_preprocess(): + poses = [download_image(task.get_imageUrl()).resize((width, height))] + elif not task.get_pose_estimation(): print("Not detecting pose") pose = download_image(task.get_imageUrl()).resize( (task.get_width(), task.get_height()) ) poses = [pose] * get_num_return_sequences() - elif task.get_pose_coordinates(): - infered_pose = pose_detector.transform( - image=task.get_imageUrl(), - client_coordinates=task.get_pose_coordinates(), - width=task.get_width(), - height=task.get_height(), - ) - poses = [infered_pose] * get_num_return_sequences() else: poses = [ controlnet.detect_pose(task.get_imageUrl()) @@ -370,8 +438,11 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None): upload_image(depth, "crecoAI/{}_depth.png".format(task.get_taskId())) + scale = task.cnp_kwargs().pop("controlnet_conditioning_scale", None) + factor = task.cnp_kwargs().pop("control_guidance_end", None) kwargs = { - "control_guidance_end": [0.5, 1.0], + "controlnet_conditioning_scale": [1.0, scale or 1.0], + "control_guidance_end": [0.5, factor or 1.0], } else: images = poses[0] @@ -389,7 +460,7 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None): **task.cnp_kwargs(), **lora_patcher.kwargs(), } - images, has_nsfw = controlnet.process(**kwargs) + (images, has_nsfw), _ = controlnet.process(**kwargs) if task.get_high_res_fix(): kwargs = { @@ -400,11 +471,12 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None): "width": task.get_width(), "height": task.get_height(), "num_inference_steps": task.get_steps(), + "seed": task.get_seed(), **task.high_res_kwargs(), } images, _ = high_res.apply(**kwargs) - upload_image(poses[0], "crecoAI/{}_pose.png".format(task.get_taskId())) + upload_image(poses[0], "crecoAI/{}_condition.png".format(task.get_taskId())) generated_image_urls = upload_images(images, s3_outkey, task.get_taskId()) @@ -431,12 +503,11 @@ def text2img(task: Task): ) lora_patcher.patch() - torch.manual_seed(task.get_seed()) - kwargs = { "params": params, "num_inference_steps": task.get_steps(), "height": height, + "seed": task.get_seed(), "width": width, "negative_prompt": task.get_negative_prompt(), **task.t2i_kwargs(), @@ -455,6 +526,7 @@ def text2img(task: Task): "width": task.get_width(), "height": task.get_height(), "num_inference_steps": task.get_steps(), + "seed": task.get_seed(), **task.high_res_kwargs(), } images, _ = high_res.apply(**kwargs) @@ -478,11 +550,9 @@ def img2img(task: Task): width, height = get_intermediate_dimension(task) - torch.manual_seed(task.get_seed()) - if get_is_sdxl(): # we run lineart for img2img - controlnet.load_model("linearart") + controlnet.load_model("canny") lora_patcher = lora_style.get_patcher( [controlnet.pipe2, high_res.pipe], task.get_style() @@ -498,10 +568,11 @@ def img2img(task: Task): "prompt": prompt, "negative_prompt": [task.get_negative_prompt()] * get_num_return_sequences(), - **task.cnl_kwargs(), - "adapter_conditioning_scale": 0.3, + "controlnet_conditioning_scale": 0.5, + # "adapter_conditioning_scale": 0.3, + **task.i2i_kwargs(), } - images, has_nsfw = controlnet.process(**kwargs) + (images, has_nsfw), _ = controlnet.process(**kwargs) else: lora_patcher = lora_style.get_patcher( [img2img_pipe.pipe, high_res.pipe], task.get_style() @@ -516,6 +587,7 @@ def img2img(task: Task): "num_inference_steps": task.get_steps(), "width": width, "height": height, + "seed": task.get_seed(), **task.i2i_kwargs(), **lora_patcher.kwargs(), } @@ -530,6 +602,7 @@ def img2img(task: Task): "width": task.get_width(), "height": task.get_height(), "num_inference_steps": task.get_steps(), + "seed": task.get_seed(), **task.high_res_kwargs(), } images, _ = high_res.apply(**kwargs) @@ -568,7 +641,9 @@ def inpaint(task: Task): "num_inference_steps": task.get_steps(), **task.ip_kwargs(), } - images = inpainter.process(**kwargs) + images, mask = inpainter.process(**kwargs) + + upload_image(mask, "crecoAI/{}_mask.png".format(task.get_taskId())) generated_image_urls = upload_images(images, key, task.get_taskId()) @@ -617,9 +692,7 @@ def replace_bg(task: Task): @update_db @slack.auto_send_alert def remove_bg(task: Task): - output_image = remove_background_v2.remove( - task.get_imageUrl(), model_type=task.get_modelType() - ) + output_image = remove_background_v3.remove(task.get_imageUrl()) output_key = "crecoAI/{}_rmbg.png".format(task.get_taskId()) image_url = upload_image(output_image, output_key) @@ -732,6 +805,67 @@ def rt_draw_img(task: Task): return {"image": base64_image} +@update_db +@auto_clear_cuda_and_gc(controlnet) +@slack.auto_send_alert +def depth_rig(task: Task): + # Note : This task is for only processing a hardcoded character rig model using depth controlnet + # Hack : This model requires hardcoded depth images for optimal processing, so we pass it by default + default_depth_url = "https://s3.ap-south-1.amazonaws.com/assets.autodraft.in/character-sheet/rigs/character-rig-depth-map.png" + + params = get_patched_prompt_text2img(task) + + width, height = get_intermediate_dimension(task) + + controlnet.load_model("depth") + + lora_patcher = lora_style.get_patcher( + [controlnet.pipe2, high_res.pipe], task.get_style() + ) + lora_patcher.patch() + + kwargs = { + "params": params, + "prompt": params.prompt, + "num_inference_steps": task.get_steps(), + "imageUrl": default_depth_url, + "height": height, + "seed": task.get_seed(), + "width": width, + "negative_prompt": task.get_negative_prompt(), + **task.t2i_kwargs(), + **lora_patcher.kwargs(), + } + (images, has_nsfw), condition_image = controlnet.process(**kwargs) + + if task.get_high_res_fix(): + kwargs = { + "prompt": params.prompt + if params.prompt + else [""] * get_num_return_sequences(), + "negative_prompt": [task.get_negative_prompt()] + * get_num_return_sequences(), + "images": images, + "width": task.get_width(), + "height": task.get_height(), + "num_inference_steps": task.get_steps(), + "seed": task.get_seed(), + **task.high_res_kwargs(), + } + images, _ = high_res.apply(**kwargs) + + upload_image(condition_image, "crecoAI/{}_condition.png".format(task.get_taskId())) + generated_image_urls = upload_images(images, "", task.get_taskId()) + + lora_patcher.cleanup() + + return { + **params.__dict__, + "generated_image_urls": generated_image_urls, + "has_nsfw": has_nsfw, + } + + def custom_action(task: Task): from external.scripts import __scripts__ @@ -759,6 +893,14 @@ def custom_action(task: Task): def load_model_by_task(task_type: TaskType, model_id=-1): + from internals.pipelines.controlnets import clear_networks + + # pre-cleanup inpaint and controlnet models + if task_type == TaskType.INPAINT or task_type == TaskType.OUTPAINT: + clear_networks() + else: + inpainter.unload() + if not text2img_pipe.is_loaded(): text2img_pipe.load(get_model_dir()) img2img_pipe.create(text2img_pipe) @@ -782,12 +924,14 @@ def load_model_by_task(task_type: TaskType, model_id=-1): upscaler.load() else: if task_type == TaskType.TILE_UPSCALE: - if get_is_sdxl(): - sdxl_tileupscaler.create(high_res, text2img_pipe, model_id) - else: - controlnet.load_model("tile_upscaler") + # if get_is_sdxl(): + # sdxl_tileupscaler.create(high_res, text2img_pipe, model_id) + # else: + controlnet.load_model("tile_upscaler") elif task_type == TaskType.CANNY: controlnet.load_model("canny") + elif task_type == TaskType.CANNY_IMG2IMG: + controlnet.load_model("canny_2x") elif task_type == TaskType.SCRIBBLE: controlnet.load_model("scribble") elif task_type == TaskType.LINEARART: @@ -798,23 +942,24 @@ def load_model_by_task(task_type: TaskType, model_id=-1): def unload_model_by_task(task_type: TaskType): if task_type == TaskType.INPAINT or task_type == TaskType.OUTPAINT: - inpainter.unload() + # inpainter.unload() + pass elif task_type == TaskType.REPLACE_BG: replace_background.unload() elif task_type == TaskType.OBJECT_REMOVAL: object_removal.unload() elif task_type == TaskType.TILE_UPSCALE: - if get_is_sdxl(): - sdxl_tileupscaler.unload() - else: - controlnet.unload() - elif task_type == TaskType.CANNY: + # if get_is_sdxl(): + # sdxl_tileupscaler.unload() + # else: controlnet.unload() - elif task_type == TaskType.SCRIBBLE: - controlnet.unload() - elif task_type == TaskType.LINEARART: - controlnet.unload() - elif task_type == TaskType.POSE: + elif ( + task_type == TaskType.CANNY + or task_type == TaskType.CANNY_IMG2IMG + or task_type == TaskType.SCRIBBLE + or task_type == TaskType.LINEARART + or task_type == TaskType.POSE + ): controlnet.unload() @@ -831,8 +976,6 @@ def model_fn(model_dir): set_model_config(config) set_root_dir(__file__) - FailureHandler.register() - avatar.load_local(model_dir) lora_style.load(model_dir) @@ -855,15 +998,12 @@ def auto_unload_task(func): @auto_unload_task -@FailureHandler.clear def predict_fn(data, pipe): task = Task(data) print("task is ", data) clear_cuda_and_gc() - FailureHandler.handle(task) - try: task_type = task.get_type() @@ -894,11 +1034,16 @@ def predict_fn(data, pipe): avatar.fetch_from_network(task.get_model_id()) if task_type == TaskType.TEXT_TO_IMAGE: + # Hack : Character Rigging Model Task Redirection + if task.get_model_id() == 2000336 or task.get_model_id() == 2000341: + return depth_rig(task) return text2img(task) elif task_type == TaskType.IMAGE_TO_IMAGE: return img2img(task) elif task_type == TaskType.CANNY: return canny(task) + elif task_type == TaskType.CANNY_IMG2IMG: + return canny_img2img(task) elif task_type == TaskType.POSE: return pose(task) elif task_type == TaskType.TILE_UPSCALE: diff --git a/internals/data/task.py b/internals/data/task.py index 991560e9db3ce1afb75c0476bc1b0d75dadc388b..b9972ca3681a888e5ac5a81536d83721480a2e95 100644 --- a/internals/data/task.py +++ b/internals/data/task.py @@ -11,6 +11,7 @@ class TaskType(Enum): POSE = "POSE" CANNY = "CANNY" REMOVE_BG = "REMOVE_BG" + CANNY_IMG2IMG = "CANNY_IMG2IMG" INPAINT = "INPAINT" UPSCALE_IMAGE = "UPSCALE_IMAGE" TILE_UPSCALE = "TILE_UPSCALE" @@ -47,12 +48,18 @@ class Task: elif len(prompt) > 200: self.__data["prompt"] = data.get("prompt", "")[:200] + ", " + def get_environment(self) -> str: + return self.__data.get("stage", "prod") + def get_taskId(self) -> str: return self.__data.get("task_id") def get_sourceId(self) -> str: return self.__data.get("source_id") + def get_slack_url(self) -> str: + return self.__data.get("slack_url", None) + def get_imageUrl(self) -> str: return self.__data.get("imageUrl", None) @@ -150,12 +157,18 @@ class Task: def get_access_token(self) -> str: return self.__data.get("access_token", "") + def get_apply_preprocess(self) -> bool: + return self.__data.get("apply_preprocess", True) + def get_high_res_fix(self) -> bool: return self.__data.get("high_res_fix", False) def get_base_dimension(self): return self.__data.get("base_dimension", None) + def get_process_mode(self): + return self.__data.get("process_mode", None) + def get_action_data(self) -> dict: "If task_type is CUSTOM_ACTION, then this will return the action data with 'name' as key" return self.__data.get("action_data", {}) @@ -175,6 +188,9 @@ class Task: def cnc_kwargs(self) -> dict: return dict(self.__get_kwargs("cnc_")) + def cnci2i_kwargs(self) -> dict: + return dict(self.__get_kwargs("cnci2i_")) + def cnp_kwargs(self) -> dict: return dict(self.__get_kwargs("cnp_")) @@ -192,7 +208,7 @@ class Task: def __get_kwargs(self, prefix: str): for k, v in self.__data.items(): - if k.startswith(prefix): + if k.startswith(prefix) and v != -1: yield k[len(prefix) :], v @property diff --git a/internals/pipelines/commons.py b/internals/pipelines/commons.py index d67b23659532e257394efc7f37e2d65c1400dbb6..3be2676dc8e69d2e4023180858942b2daac1c044 100644 --- a/internals/pipelines/commons.py +++ b/internals/pipelines/commons.py @@ -11,11 +11,14 @@ from diffusers import ( from internals.data.result import Result from internals.pipelines.twoStepPipeline import two_step_pipeline +from internals.util import get_generators from internals.util.commons import disable_safety_checker, download_image from internals.util.config import ( + get_base_model_revision, get_base_model_variant, get_hf_token, get_is_sdxl, + get_low_gpu_mem, get_num_return_sequences, ) @@ -38,6 +41,9 @@ class Text2Img(AbstractPipeline): def load(self, model_dir: str): if get_is_sdxl(): + print( + f"Loading model {model_dir} - {get_base_model_variant()}, {get_base_model_revision()}" + ) vae = AutoencoderKL.from_pretrained( "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16 ) @@ -47,6 +53,7 @@ class Text2Img(AbstractPipeline): token=get_hf_token(), use_safetensors=True, variant=get_base_model_variant(), + revision=get_base_model_revision(), ) pipe.vae = vae pipe.to("cuda") @@ -70,9 +77,9 @@ class Text2Img(AbstractPipeline): self.__patch() def __patch(self): - if get_is_sdxl(): - self.pipe.enable_vae_tiling() - self.pipe.enable_vae_slicing() + if get_is_sdxl() or get_low_gpu_mem(): + self.pipe.vae.enable_tiling() + self.pipe.vae.enable_slicing() self.pipe.enable_xformers_memory_efficient_attention() @torch.inference_mode() @@ -82,12 +89,15 @@ class Text2Img(AbstractPipeline): num_inference_steps: int, height: int, width: int, + seed: int, negative_prompt: str, iteration: float = 3.0, **kwargs, ): prompt = params.prompt + generator = get_generators(seed, get_num_return_sequences()) + if params.prompt_left and params.prompt_right: # multi-character pipelines prompt = [params.prompt[0], params.prompt_left[0], params.prompt_right[0]] @@ -99,6 +109,7 @@ class Text2Img(AbstractPipeline): "width": width, "num_inference_steps": num_inference_steps, "negative_prompt": [negative_prompt or ""] * len(prompt), + "generator": generator, **kwargs, } result = self.pipe.multi_character_diffusion(**kwargs) @@ -125,8 +136,11 @@ class Text2Img(AbstractPipeline): "width": width, "negative_prompt": [negative_prompt or ""] * get_num_return_sequences(), "num_inference_steps": num_inference_steps, + "guidance_scale": 7.5, + "generator": generator, **kwargs, } + print(kwargs) result = self.pipe.__call__(**kwargs) return Result.from_result(result) @@ -145,6 +159,7 @@ class Img2Img(AbstractPipeline): torch_dtype=torch.float16, token=get_hf_token(), variant=get_base_model_variant(), + revision=get_base_model_revision(), use_safetensors=True, ).to("cuda") else: @@ -183,20 +198,24 @@ class Img2Img(AbstractPipeline): num_inference_steps: int, width: int, height: int, + seed: int, strength: float = 0.75, guidance_scale: float = 7.5, **kwargs, ): image = download_image(imageUrl).resize((width, height)) + generator = get_generators(seed, get_num_return_sequences()) + kwargs = { "prompt": prompt, - "image": image, + "image": [image] * get_num_return_sequences(), "strength": strength, "negative_prompt": negative_prompt, "guidance_scale": guidance_scale, "num_images_per_prompt": 1, "num_inference_steps": num_inference_steps, + "generator": generator, **kwargs, } result = self.pipe.__call__(**kwargs) diff --git a/internals/pipelines/controlnets.py b/internals/pipelines/controlnets.py index 7b4d6dac93136d71e951506f25b2499f3f77a2bd..3fea5d71968b8c2495cdd6754e57f5e95f0852a2 100644 --- a/internals/pipelines/controlnets.py +++ b/internals/pipelines/controlnets.py @@ -1,3 +1,4 @@ +import os from typing import AbstractSet, List, Literal, Optional, Union import cv2 @@ -17,6 +18,7 @@ from diffusers import ( StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetPipeline, StableDiffusionXLAdapterPipeline, + StableDiffusionXLControlNetImg2ImgPipeline, StableDiffusionXLControlNetPipeline, T2IAdapter, UniPCMultistepScheduler, @@ -29,9 +31,9 @@ from tqdm import gui from transformers import pipeline import internals.util.image as ImageUtil -from external.midas import apply_midas from internals.data.result import Result from internals.pipelines.commons import AbstractPipeline +from internals.util import get_generators from internals.util.cache import clear_cuda_and_gc from internals.util.commons import download_image from internals.util.config import ( @@ -39,9 +41,51 @@ from internals.util.config import ( get_hf_token, get_is_sdxl, get_model_dir, + get_num_return_sequences, ) -CONTROLNET_TYPES = Literal["pose", "canny", "scribble", "linearart", "tile_upscaler"] +CONTROLNET_TYPES = Literal[ + "pose", "canny", "scribble", "linearart", "tile_upscaler", "canny_2x" +] + +__CN_MODELS = {} +MAX_CN_MODELS = 3 + + +def clear_networks(): + global __CN_MODELS + __CN_MODELS = {} + + +def load_network_model_by_key(repo_id: str, pipeline_type: str): + global __CN_MODELS + + if repo_id in __CN_MODELS: + return __CN_MODELS[repo_id] + + if len(__CN_MODELS) >= MAX_CN_MODELS: + __CN_MODELS = {} + + if pipeline_type == "controlnet": + model = ControlNetModel.from_pretrained( + repo_id, + torch_dtype=torch.float16, + cache_dir=get_hf_cache_dir(), + token=get_hf_token(), + ).to("cuda") + elif pipeline_type == "t2i": + model = T2IAdapter.from_pretrained( + repo_id, + torch_dtype=torch.float16, + varient="fp16", + token=get_hf_token(), + ).to("cuda") + else: + raise Exception("Invalid pipeline type") + + __CN_MODELS[repo_id] = model + + return model class StableDiffusionNetworkModelPipelineLoader: @@ -57,11 +101,6 @@ class StableDiffusionNetworkModelPipelineLoader: pipeline_type, base_pipe: Optional[AbstractSet] = None, ): - if is_sdxl and is_img2img: - # Does not matter pipeline type but tile upscale is not supported - print("Warning: Tile upscale is not supported on SDXL") - return None - if base_pipe is None: pretrained = True kwargs = { @@ -75,7 +114,17 @@ class StableDiffusionNetworkModelPipelineLoader: kwargs = { **base_pipe.pipe.components, # pyright: ignore } + if get_is_sdxl(): + kwargs.pop("image_encoder", None) + kwargs.pop("feature_extractor", None) + if is_sdxl and is_img2img and pipeline_type == "controlnet": + model = ( + StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained + if pretrained + else StableDiffusionXLControlNetImg2ImgPipeline + ) + return model(controlnet=network_model, **kwargs).to("cuda") if is_sdxl and pipeline_type == "controlnet": model = ( StableDiffusionXLControlNetPipeline.from_pretrained @@ -146,9 +195,10 @@ class ControlNet(AbstractPipeline): def load_model(self, task_name: CONTROLNET_TYPES): "Appropriately loads the network module, pipelines and cache it for reuse." - config = self.__model_sdxl if get_is_sdxl() else self.__model_normal if self.__current_task_name == task_name: return + + config = self.__model_sdxl if get_is_sdxl() else self.__model_normal model = config[task_name] if not model: raise Exception(f"ControlNet is not supported for {task_name}") @@ -176,31 +226,13 @@ class ControlNet(AbstractPipeline): def __load_network_model(self, model_name, pipeline_type): "Loads the network module, eg: ControlNet or T2I Adapters" - def load_controlnet(model): - return ControlNetModel.from_pretrained( - model, - torch_dtype=torch.float16, - cache_dir=get_hf_cache_dir(), - ).to("cuda") - - def load_t2i(model): - return T2IAdapter.from_pretrained( - model, - torch_dtype=torch.float16, - varient="fp16", - ).to("cuda") - if type(model_name) == str: - if pipeline_type == "controlnet": - return load_controlnet(model_name) - if pipeline_type == "t2i": - return load_t2i(model_name) - raise Exception("Invalid pipeline type") + return load_network_model_by_key(model_name, pipeline_type) elif type(model_name) == list: if pipeline_type == "controlnet": cns = [] for model in model_name: - cns.append(load_controlnet(model)) + cns.append(load_network_model_by_key(model, pipeline_type)) return MultiControlNetModel(cns).to("cuda") elif pipeline_type == "t2i": raise Exception("Multi T2I adapters are not supported") @@ -219,9 +251,10 @@ class ControlNet(AbstractPipeline): pipe.enable_vae_slicing() pipe.enable_xformers_memory_efficient_attention() # this scheduler produces good outputs for t2i adapters - pipe.scheduler = EulerAncestralDiscreteScheduler.from_config( - pipe.scheduler.config - ) + if pipeline_type == "t2i": + pipe.scheduler = EulerAncestralDiscreteScheduler.from_config( + pipe.scheduler.config + ) else: pipe.enable_xformers_memory_efficient_attention() return pipe @@ -229,7 +262,7 @@ class ControlNet(AbstractPipeline): # If the pipeline type is changed we should reload all # the pipelines if not self.__loaded or self.__pipe_type != pipeline_type: - # controlnet pipeline for tile upscaler + # controlnet pipeline for tile upscaler or any pipeline with img2img + network support pipe = StableDiffusionNetworkModelPipelineLoader( is_sdxl=get_is_sdxl(), is_img2img=True, @@ -278,6 +311,8 @@ class ControlNet(AbstractPipeline): def process(self, **kwargs): if self.__current_task_name == "pose": return self.process_pose(**kwargs) + if self.__current_task_name == "depth": + return self.process_depth(**kwargs) if self.__current_task_name == "canny": return self.process_canny(**kwargs) if self.__current_task_name == "scribble": @@ -286,6 +321,8 @@ class ControlNet(AbstractPipeline): return self.process_linearart(**kwargs) if self.__current_task_name == "tile_upscaler": return self.process_tile_upscaler(**kwargs) + if self.__current_task_name == "canny_2x": + return self.process_canny_2x(**kwargs) raise Exception("ControlNet is not loaded with any model") @torch.inference_mode() @@ -298,16 +335,22 @@ class ControlNet(AbstractPipeline): negative_prompt: List[str], height: int, width: int, - guidance_scale: float = 9, + guidance_scale: float = 7.5, + apply_preprocess: bool = True, **kwargs, ): if self.__current_task_name != "canny": raise Exception("ControlNet is not loaded with canny model") - torch.manual_seed(seed) + generator = get_generators(seed, get_num_return_sequences()) + + init_image = self.preprocess_image(imageUrl, width, height) + if apply_preprocess: + init_image = ControlNet.canny_detect_edge(init_image) + init_image = init_image.resize((width, height)) - init_image = download_image(imageUrl).resize((width, height)) - init_image = ControlNet.canny_detect_edge(init_image) + # if get_is_sdxl(): + # kwargs["controlnet_conditioning_scale"] = 0.5 kwargs = { "prompt": prompt, @@ -318,11 +361,67 @@ class ControlNet(AbstractPipeline): "num_inference_steps": num_inference_steps, "height": height, "width": width, + "generator": generator, **kwargs, } + print(kwargs) result = self.pipe2.__call__(**kwargs) - return Result.from_result(result) + return Result.from_result(result), init_image + + @torch.inference_mode() + def process_canny_2x( + self, + prompt: List[str], + imageUrl: str, + seed: int, + num_inference_steps: int, + negative_prompt: List[str], + height: int, + width: int, + guidance_scale: float = 8.5, + **kwargs, + ): + if self.__current_task_name != "canny_2x": + raise Exception("ControlNet is not loaded with canny model") + + generator = get_generators(seed, get_num_return_sequences()) + + init_image = self.preprocess_image(imageUrl, width, height) + canny_image = ControlNet.canny_detect_edge(init_image).resize((width, height)) + depth_image = ControlNet.depth_image(init_image).resize((width, height)) + + condition_scale = kwargs.get("controlnet_conditioning_scale", None) + condition_factor = kwargs.get("control_guidance_end", None) + print("condition_scale", condition_scale) + + if not get_is_sdxl(): + kwargs["guidance_scale"] = 7.5 + kwargs["strength"] = 0.8 + kwargs["controlnet_conditioning_scale"] = [condition_scale or 1.0, 0.3] + else: + kwargs["controlnet_conditioning_scale"] = [condition_scale or 0.8, 0.3] + + kwargs["control_guidance_end"] = [condition_factor or 1.0, 1.0] + + kwargs = { + "prompt": prompt[0], + "image": [init_image] * get_num_return_sequences(), + "control_image": [canny_image, depth_image], + "guidance_scale": guidance_scale, + "num_images_per_prompt": get_num_return_sequences(), + "negative_prompt": negative_prompt[0], + "num_inference_steps": num_inference_steps, + "strength": 1.0, + "height": height, + "width": width, + "generator": generator, + **kwargs, + } + print(kwargs) + + result = self.pipe.__call__(**kwargs) + return Result.from_result(result), canny_image @torch.inference_mode() def process_pose( @@ -340,22 +439,23 @@ class ControlNet(AbstractPipeline): if self.__current_task_name != "pose": raise Exception("ControlNet is not loaded with pose model") - torch.manual_seed(seed) + generator = get_generators(seed, get_num_return_sequences()) kwargs = { "prompt": prompt[0], "image": image, - "num_images_per_prompt": 4, + "num_images_per_prompt": get_num_return_sequences(), "num_inference_steps": num_inference_steps, "negative_prompt": negative_prompt[0], "guidance_scale": guidance_scale, "height": height, "width": width, + "generator": generator, **kwargs, } print(kwargs) result = self.pipe2.__call__(**kwargs) - return Result.from_result(result) + return Result.from_result(result), image @torch.inference_mode() def process_tile_upscaler( @@ -374,26 +474,60 @@ class ControlNet(AbstractPipeline): if self.__current_task_name != "tile_upscaler": raise Exception("ControlNet is not loaded with tile_upscaler model") - torch.manual_seed(seed) + init_image = None + # find the correct seed and imageUrl from imageUrl + try: + p = os.path.splitext(imageUrl)[0] + p = p.split("/")[-1] + p = p.split("_")[-1] - init_image = download_image(imageUrl).resize((width, height)) - condition_image = self.__resize_for_condition_image( - init_image, resize_dimension - ) + seed = seed + int(p) + + if "_canny_2x" or "_linearart" in imageUrl: + imageUrl = imageUrl.replace("_canny_2x", "_canny_2x_highres").replace( + "_linearart_highres", "" + ) + init_image = download_image(imageUrl) + width, height = init_image.size + + print("Setting imageUrl with width and height", imageUrl, width, height) + except Exception as e: + print("Failed to extract seed from imageUrl", e) + + print("Setting seed", seed) + generator = get_generators(seed) + + if not init_image: + init_image = download_image(imageUrl).resize((width, height)) + + condition_image = ImageUtil.resize_image(init_image, 1024) + if get_is_sdxl(): + condition_image = condition_image.resize(init_image.size) + else: + condition_image = self.__resize_for_condition_image( + init_image, resize_dimension + ) + + if get_is_sdxl(): + kwargs["strength"] = 1.0 + kwargs["controlnet_conditioning_scale"] = 1.0 + kwargs["image"] = init_image + else: + kwargs["image"] = condition_image + kwargs["guidance_scale"] = guidance_scale kwargs = { - "image": condition_image, "prompt": prompt, "control_image": condition_image, "num_inference_steps": num_inference_steps, "negative_prompt": negative_prompt, "height": condition_image.size[1], "width": condition_image.size[0], - "guidance_scale": guidance_scale, + "generator": generator, **kwargs, } result = self.pipe.__call__(**kwargs) - return Result.from_result(result) + return Result.from_result(result), condition_image @torch.inference_mode() def process_scribble( @@ -406,16 +540,28 @@ class ControlNet(AbstractPipeline): height: int, width: int, guidance_scale: float = 7.5, + apply_preprocess: bool = True, **kwargs, ): if self.__current_task_name != "scribble": raise Exception("ControlNet is not loaded with scribble model") - torch.manual_seed(seed) + generator = get_generators(seed, get_num_return_sequences()) + + if apply_preprocess: + if get_is_sdxl(): + # We use sketch in SDXL + image = [ + ControlNet.pidinet_image(image[0]).resize((width, height)) + ] * len(image) + else: + image = [ + ControlNet.scribble_image(image[0]).resize((width, height)) + ] * len(image) sdxl_args = ( { - "guidance_scale": 6, + "guidance_scale": guidance_scale, "adapter_conditioning_scale": 1.0, "adapter_conditioning_factor": 1.0, } @@ -431,11 +577,12 @@ class ControlNet(AbstractPipeline): "height": height, "width": width, "guidance_scale": guidance_scale, + "generator": generator, **sdxl_args, **kwargs, } result = self.pipe2.__call__(**kwargs) - return Result.from_result(result) + return Result.from_result(result), image[0] @torch.inference_mode() def process_linearart( @@ -448,20 +595,26 @@ class ControlNet(AbstractPipeline): height: int, width: int, guidance_scale: float = 7.5, + apply_preprocess: bool = True, **kwargs, ): if self.__current_task_name != "linearart": raise Exception("ControlNet is not loaded with linearart model") - torch.manual_seed(seed) + generator = get_generators(seed, get_num_return_sequences()) - init_image = download_image(imageUrl).resize((width, height)) - condition_image = ControlNet.linearart_condition_image(init_image) + init_image = self.preprocess_image(imageUrl, width, height) + + if apply_preprocess: + condition_image = ControlNet.linearart_condition_image(init_image) + condition_image = condition_image.resize(init_image.size) + else: + condition_image = init_image # we use t2i adapter and the conditioning scale should always be 0.8 sdxl_args = ( { - "guidance_scale": 6, + "guidance_scale": guidance_scale, "adapter_conditioning_scale": 1.0, "adapter_conditioning_factor": 1.0, } @@ -470,18 +623,68 @@ class ControlNet(AbstractPipeline): ) kwargs = { - "image": [condition_image] * 4, + "image": [condition_image] * get_num_return_sequences(), + "prompt": prompt, + "num_inference_steps": num_inference_steps, + "negative_prompt": negative_prompt, + "height": height, + "width": width, + "guidance_scale": guidance_scale, + "generator": generator, + **sdxl_args, + **kwargs, + } + result = self.pipe2.__call__(**kwargs) + return Result.from_result(result), condition_image + + @torch.inference_mode() + def process_depth( + self, + imageUrl: str, + prompt: Union[str, List[str]], + negative_prompt: Union[str, List[str]], + num_inference_steps: int, + seed: int, + height: int, + width: int, + guidance_scale: float = 7.5, + apply_preprocess: bool = True, + **kwargs, + ): + if self.__current_task_name != "depth": + raise Exception("ControlNet is not loaded with depth model") + + generator = get_generators(seed, get_num_return_sequences()) + + init_image = self.preprocess_image(imageUrl, width, height) + + if apply_preprocess: + condition_image = ControlNet.depth_image(init_image) + condition_image = condition_image.resize(init_image.size) + else: + condition_image = init_image + + # for using the depth controlnet in this SDXL model, these hyperparamters are optimal + sdxl_args = ( + {"controlnet_conditioning_scale": 0.2, "control_guidance_end": 0.2} + if get_is_sdxl() + else {} + ) + + kwargs = { + "image": [condition_image] * get_num_return_sequences(), "prompt": prompt, "num_inference_steps": num_inference_steps, "negative_prompt": negative_prompt, "height": height, "width": width, "guidance_scale": guidance_scale, + "generator": generator, **sdxl_args, **kwargs, } result = self.pipe2.__call__(**kwargs) - return Result.from_result(result) + return Result.from_result(result), condition_image def cleanup(self): """Doesn't do anything considering new diffusers has itself a cleanup mechanism @@ -504,12 +707,15 @@ class ControlNet(AbstractPipeline): def linearart_condition_image(image: Image.Image, **kwargs) -> Image.Image: processor = LineartDetector.from_pretrained("lllyasviel/Annotators") if get_is_sdxl(): - kwargs = {"detect_resolution": 384, **kwargs} + kwargs = {"detect_resolution": 384, "image_resolution": 1024, **kwargs} + else: + kwargs = {} image = processor.__call__(input_image=image, **kwargs) return image @staticmethod + @torch.inference_mode() def depth_image(image: Image.Image) -> Image.Image: global midas, midas_transforms if "midas" not in globals(): @@ -555,6 +761,10 @@ class ControlNet(AbstractPipeline): canny_image = Image.fromarray(image_array) return canny_image + def preprocess_image(self, imageUrl, width, height) -> Image.Image: + image = download_image(imageUrl, mode="RGBA").resize((width, height)) + return ImageUtil.alpha_to_white(image) + def __resize_for_condition_image(self, image: Image.Image, resolution: int): input_image = image.convert("RGB") W, H = input_image.size @@ -572,6 +782,7 @@ class ControlNet(AbstractPipeline): "linearart": "lllyasviel/control_v11p_sd15_lineart", "scribble": "lllyasviel/control_v11p_sd15_scribble", "tile_upscaler": "lllyasviel/control_v11f1e_sd15_tile", + "canny_2x": "lllyasviel/control_v11p_sd15_canny, lllyasviel/control_v11f1p_sd15_depth", } __model_normal_types = { "pose": "controlnet", @@ -579,19 +790,24 @@ class ControlNet(AbstractPipeline): "linearart": "controlnet", "scribble": "controlnet", "tile_upscaler": "controlnet", + "canny_2x": "controlnet", } __model_sdxl = { "pose": "thibaud/controlnet-openpose-sdxl-1.0", - "canny": "diffusers/controlnet-canny-sdxl-1.0", + "canny": "Autodraft/controlnet-canny-sdxl-1.0", + "depth": "Autodraft/controlnet-depth-sdxl-1.0", + "canny_2x": "Autodraft/controlnet-canny-sdxl-1.0, Autodraft/controlnet-depth-sdxl-1.0", "linearart": "TencentARC/t2i-adapter-lineart-sdxl-1.0", "scribble": "TencentARC/t2i-adapter-sketch-sdxl-1.0", - "tile_upscaler": None, + "tile_upscaler": "Autodraft/ControlNet_SDXL_tile_upscale", } __model_sdxl_types = { "pose": "controlnet", "canny": "controlnet", + "canny_2x": "controlnet", + "depth": "controlnet", "linearart": "t2i", "scribble": "t2i", - "tile_upscaler": None, + "tile_upscaler": "controlnet", } diff --git a/internals/pipelines/high_res.py b/internals/pipelines/high_res.py index ad06b8f34c69038ae9fae9832e199d44113391c0..f86149ec68fb57ee5bae8cc158c73b2a57ab63aa 100644 --- a/internals/pipelines/high_res.py +++ b/internals/pipelines/high_res.py @@ -1,15 +1,22 @@ import math -from typing import List, Optional +from typing import Dict, List, Optional from PIL import Image from internals.data.result import Result from internals.pipelines.commons import AbstractPipeline, Img2Img +from internals.util import get_generators from internals.util.cache import clear_cuda_and_gc -from internals.util.config import get_base_dimension, get_is_sdxl, get_model_dir +from internals.util.config import ( + get_base_dimension, + get_is_sdxl, + get_model_dir, + get_num_return_sequences, +) +from internals.util.sdxl_lightning import LightningMixin -class HighRes(AbstractPipeline): +class HighRes(AbstractPipeline, LightningMixin): def load(self, img2img: Optional[Img2Img] = None): if hasattr(self, "pipe"): return @@ -21,6 +28,9 @@ class HighRes(AbstractPipeline): self.pipe = img2img.pipe self.img2img = img2img + if get_is_sdxl(): + self.configure_sdxl_lightning(img2img.pipe) + def apply( self, prompt: List[str], @@ -28,6 +38,7 @@ class HighRes(AbstractPipeline): images, width: int, height: int, + seed: int, num_inference_steps: int, strength: float = 0.5, guidance_scale: int = 9, @@ -35,7 +46,18 @@ class HighRes(AbstractPipeline): ): clear_cuda_and_gc() + generator = get_generators(seed, get_num_return_sequences()) + images = [image.resize((width, height)) for image in images] + + # if get_is_sdxl(): + # kwargs["guidance_scale"] = kwargs.get("guidance_scale", 15) + # kwargs["strength"] = kwargs.get("strength", 0.6) + + if get_is_sdxl(): + extra_args = self.enable_sdxl_lightning() + kwargs.update(extra_args) + kwargs = { "prompt": prompt, "image": images, @@ -43,9 +65,16 @@ class HighRes(AbstractPipeline): "negative_prompt": negative_prompt, "guidance_scale": guidance_scale, "num_inference_steps": num_inference_steps, + "generator": generator, **kwargs, } + + print(kwargs) result = self.pipe.__call__(**kwargs) + + if get_is_sdxl(): + self.disable_sdxl_lightning() + return Result.from_result(result) @staticmethod diff --git a/internals/pipelines/inpaint_imageprocessor.py b/internals/pipelines/inpaint_imageprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..40dfce83aaff0c941bdac962515c6b01eb171d88 --- /dev/null +++ b/internals/pipelines/inpaint_imageprocessor.py @@ -0,0 +1,976 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate +from PIL import Image, ImageFilter, ImageOps + +PipelineImageInput = Union[ + PIL.Image.Image, + np.ndarray, + torch.FloatTensor, + List[PIL.Image.Image], + List[np.ndarray], + List[torch.FloatTensor], +] + +PipelineDepthInput = PipelineImageInput + + +class VaeImageProcessor(ConfigMixin): + """ + Image processor for VAE. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept + `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method. + vae_scale_factor (`int`, *optional*, defaults to `8`): + VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. + resample (`str`, *optional*, defaults to `lanczos`): + Resampling filter to use when resizing the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image to [-1,1]. + do_binarize (`bool`, *optional*, defaults to `False`): + Whether to binarize the image to 0/1. + do_convert_rgb (`bool`, *optional*, defaults to be `False`): + Whether to convert the images to RGB format. + do_convert_grayscale (`bool`, *optional*, defaults to be `False`): + Whether to convert the images to grayscale format. + """ + + config_name = CONFIG_NAME + + @register_to_config + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 8, + resample: str = "lanczos", + do_normalize: bool = True, + do_binarize: bool = False, + do_convert_rgb: bool = False, + do_convert_grayscale: bool = False, + ): + super().__init__() + if do_convert_rgb and do_convert_grayscale: + raise ValueError( + "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`," + " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.", + " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`", + ) + self.config.do_convert_rgb = False + + @staticmethod + def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]: + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [ + Image.fromarray(image.squeeze(), mode="L") for image in images + ] + else: + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + @staticmethod + def pil_to_numpy( + images: Union[List[PIL.Image.Image], PIL.Image.Image] + ) -> np.ndarray: + """ + Convert a PIL image or a list of PIL images to NumPy arrays. + """ + if not isinstance(images, list): + images = [images] + images = [np.array(image).astype(np.float32) / 255.0 for image in images] + images = np.stack(images, axis=0) + + return images + + @staticmethod + def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor: + """ + Convert a NumPy image to a PyTorch tensor. + """ + if images.ndim == 3: + images = images[..., None] + + images = torch.from_numpy(images.transpose(0, 3, 1, 2)) + return images + + @staticmethod + def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray: + """ + Convert a PyTorch tensor to a NumPy image. + """ + images = images.cpu().permute(0, 2, 3, 1).float().numpy() + return images + + @staticmethod + def normalize( + images: Union[np.ndarray, torch.Tensor] + ) -> Union[np.ndarray, torch.Tensor]: + """ + Normalize an image array to [-1,1]. + """ + return 2.0 * images - 1.0 + + @staticmethod + def denormalize( + images: Union[np.ndarray, torch.Tensor] + ) -> Union[np.ndarray, torch.Tensor]: + """ + Denormalize an image array to [0,1]. + """ + return (images / 2 + 0.5).clamp(0, 1) + + @staticmethod + def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image: + """ + Converts a PIL image to RGB format. + """ + image = image.convert("RGB") + + return image + + @staticmethod + def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image: + """ + Converts a PIL image to grayscale format. + """ + image = image.convert("L") + + return image + + @staticmethod + def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image: + """ + Applies Gaussian blur to an image. + """ + image = image.filter(ImageFilter.GaussianBlur(blur_factor)) + + return image + + @staticmethod + def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0): + """ + Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect ratio of the original image; + for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128. + + Args: + mask_image (PIL.Image.Image): Mask image. + width (int): Width of the image to be processed. + height (int): Height of the image to be processed. + pad (int, optional): Padding to be added to the crop region. Defaults to 0. + + Returns: + tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and matches the original aspect ratio. + """ + + mask_image = mask_image.convert("L") + mask = np.array(mask_image) + + # 1. find a rectangular region that contains all masked ares in an image + h, w = mask.shape + crop_left = 0 + for i in range(w): + if not (mask[:, i] == 0).all(): + break + crop_left += 1 + + crop_right = 0 + for i in reversed(range(w)): + if not (mask[:, i] == 0).all(): + break + crop_right += 1 + + crop_top = 0 + for i in range(h): + if not (mask[i] == 0).all(): + break + crop_top += 1 + + crop_bottom = 0 + for i in reversed(range(h)): + if not (mask[i] == 0).all(): + break + crop_bottom += 1 + + # 2. add padding to the crop region + x1, y1, x2, y2 = ( + int(max(crop_left - pad, 0)), + int(max(crop_top - pad, 0)), + int(min(w - crop_right + pad, w)), + int(min(h - crop_bottom + pad, h)), + ) + + # 3. expands crop region to match the aspect ratio of the image to be processed + ratio_crop_region = (x2 - x1) / (y2 - y1) + ratio_processing = width / height + + if ratio_crop_region > ratio_processing: + desired_height = (x2 - x1) / ratio_processing + desired_height_diff = int(desired_height - (y2 - y1)) + y1 -= desired_height_diff // 2 + y2 += desired_height_diff - desired_height_diff // 2 + if y2 >= mask_image.height: + diff = y2 - mask_image.height + y2 -= diff + y1 -= diff + if y1 < 0: + y2 -= y1 + y1 -= y1 + if y2 >= mask_image.height: + y2 = mask_image.height + else: + desired_width = (y2 - y1) * ratio_processing + desired_width_diff = int(desired_width - (x2 - x1)) + x1 -= desired_width_diff // 2 + x2 += desired_width_diff - desired_width_diff // 2 + if x2 >= mask_image.width: + diff = x2 - mask_image.width + x2 -= diff + x1 -= diff + if x1 < 0: + x2 -= x1 + x1 -= x1 + if x2 >= mask_image.width: + x2 = mask_image.width + + return x1, y1, x2, y2 + + def _resize_and_fill( + self, + image: PIL.Image.Image, + width: int, + height: int, + ) -> PIL.Image.Image: + """ + Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image. + + Args: + image: The image to resize. + width: The width to resize the image to. + height: The height to resize the image to. + """ + + ratio = width / height + src_ratio = image.width / image.height + + src_w = width if ratio < src_ratio else image.width * height // image.height + src_h = height if ratio >= src_ratio else image.height * width // image.width + + resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"]) + res = Image.new("RGB", (width, height)) + res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) + + if ratio < src_ratio: + fill_height = height // 2 - src_h // 2 + if fill_height > 0: + res.paste( + resized.resize((width, fill_height), box=(0, 0, width, 0)), + box=(0, 0), + ) + res.paste( + resized.resize( + (width, fill_height), + box=(0, resized.height, width, resized.height), + ), + box=(0, fill_height + src_h), + ) + elif ratio > src_ratio: + fill_width = width // 2 - src_w // 2 + if fill_width > 0: + res.paste( + resized.resize((fill_width, height), box=(0, 0, 0, height)), + box=(0, 0), + ) + res.paste( + resized.resize( + (fill_width, height), + box=(resized.width, 0, resized.width, height), + ), + box=(fill_width + src_w, 0), + ) + + return res + + def _resize_and_crop( + self, + image: PIL.Image.Image, + width: int, + height: int, + ) -> PIL.Image.Image: + """ + Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess. + + Args: + image: The image to resize. + width: The width to resize the image to. + height: The height to resize the image to. + """ + ratio = width / height + src_ratio = image.width / image.height + + src_w = width if ratio > src_ratio else image.width * height // image.height + src_h = height if ratio <= src_ratio else image.height * width // image.width + + resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"]) + res = Image.new("RGB", (width, height)) + res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) + return res + + def resize( + self, + image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], + height: int, + width: int, + resize_mode: str = "default", # "defalt", "fill", "crop" + ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]: + """ + Resize image. + + Args: + image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`): + The image input, can be a PIL image, numpy array or pytorch tensor. + height (`int`): + The height to resize to. + width (`int`): + The width to resize to. + resize_mode (`str`, *optional*, defaults to `default`): + The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit + within the specified width and height, and it may not maintaining the original aspect ratio. + If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image + within the dimensions, filling empty with data from image. + If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image + within the dimensions, cropping the excess. + Note that resize_mode `fill` and `crop` are only supported for PIL image input. + + Returns: + `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`: + The resized image. + """ + if resize_mode != "default" and not isinstance(image, PIL.Image.Image): + raise ValueError( + f"Only PIL image input is supported for resize_mode {resize_mode}" + ) + if isinstance(image, PIL.Image.Image): + if resize_mode == "default": + image = image.resize( + (width, height), resample=PIL_INTERPOLATION[self.config.resample] + ) + elif resize_mode == "fill": + image = self._resize_and_fill(image, width, height) + elif resize_mode == "crop": + image = self._resize_and_crop(image, width, height) + else: + raise ValueError(f"resize_mode {resize_mode} is not supported") + + elif isinstance(image, torch.Tensor): + image = torch.nn.functional.interpolate( + image, + size=(height, width), + ) + elif isinstance(image, np.ndarray): + image = self.numpy_to_pt(image) + image = torch.nn.functional.interpolate( + image, + size=(height, width), + ) + image = self.pt_to_numpy(image) + return image + + def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image: + """ + Create a mask. + + Args: + image (`PIL.Image.Image`): + The image input, should be a PIL image. + + Returns: + `PIL.Image.Image`: + The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1. + """ + image[image < 0.5] = 0 + image[image >= 0.5] = 1 + + return image + + def get_default_height_width( + self, + image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], + height: Optional[int] = None, + width: Optional[int] = None, + ) -> Tuple[int, int]: + """ + This function return the height and width that are downscaled to the next integer multiple of + `vae_scale_factor`. + + Args: + image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`): + The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have + shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should + have shape `[batch, channel, height, width]`. + height (`int`, *optional*, defaults to `None`): + The height in preprocessed image. If `None`, will use the height of `image` input. + width (`int`, *optional*`, defaults to `None`): + The width in preprocessed. If `None`, will use the width of the `image` input. + """ + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[2] + else: + height = image.shape[1] + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[3] + else: + width = image.shape[2] + + width, height = ( + x - x % self.config.vae_scale_factor for x in (width, height) + ) # resize to integer multiple of vae_scale_factor + + return height, width + + def preprocess( + self, + image: PipelineImageInput, + height: Optional[int] = None, + width: Optional[int] = None, + resize_mode: str = "default", # "defalt", "fill", "crop" + crops_coords: Optional[Tuple[int, int, int, int]] = None, + ) -> torch.Tensor: + """ + Preprocess the image input. + + Args: + image (`pipeline_image_input`): + The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of supported formats. + height (`int`, *optional*, defaults to `None`): + The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default height. + width (`int`, *optional*`, defaults to `None`): + The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width. + resize_mode (`str`, *optional*, defaults to `default`): + The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit + within the specified width and height, and it may not maintaining the original aspect ratio. + If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image + within the dimensions, filling empty with data from image. + If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image + within the dimensions, cropping the excess. + Note that resize_mode `fill` and `crop` are only supported for PIL image input. + crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`): + The crop coordinates for each image in the batch. If `None`, will not crop the image. + """ + supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor) + + # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image + if ( + self.config.do_convert_grayscale + and isinstance(image, (torch.Tensor, np.ndarray)) + and image.ndim == 3 + ): + if isinstance(image, torch.Tensor): + # if image is a pytorch tensor could have 2 possible shapes: + # 1. batch x height x width: we should insert the channel dimension at position 1 + # 2. channnel x height x width: we should insert batch dimension at position 0, + # however, since both channel and batch dimension has same size 1, it is same to insert at position 1 + # for simplicity, we insert a dimension of size 1 at position 1 for both cases + image = image.unsqueeze(1) + else: + # if it is a numpy array, it could have 2 possible shapes: + # 1. batch x height x width: insert channel dimension on last position + # 2. height x width x channel: insert batch dimension on first position + if image.shape[-1] == 1: + image = np.expand_dims(image, axis=0) + else: + image = np.expand_dims(image, axis=-1) + + if isinstance(image, supported_formats): + image = [image] + elif not ( + isinstance(image, list) + and all(isinstance(i, supported_formats) for i in image) + ): + raise ValueError( + f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}" + ) + + if isinstance(image[0], PIL.Image.Image): + if crops_coords is not None: + image = [i.crop(crops_coords) for i in image] + if self.config.do_resize: + height, width = self.get_default_height_width(image[0], height, width) + image = [ + self.resize(i, height, width, resize_mode=resize_mode) + for i in image + ] + if self.config.do_convert_rgb: + image = [self.convert_to_rgb(i) for i in image] + elif self.config.do_convert_grayscale: + image = [self.convert_to_grayscale(i) for i in image] + image = self.pil_to_numpy(image) # to np + image = self.numpy_to_pt(image) # to pt + + elif isinstance(image[0], np.ndarray): + image = ( + np.concatenate(image, axis=0) + if image[0].ndim == 4 + else np.stack(image, axis=0) + ) + + image = self.numpy_to_pt(image) + + height, width = self.get_default_height_width(image, height, width) + if self.config.do_resize: + image = self.resize(image, height, width) + + elif isinstance(image[0], torch.Tensor): + image = ( + torch.cat(image, axis=0) + if image[0].ndim == 4 + else torch.stack(image, axis=0) + ) + + if self.config.do_convert_grayscale and image.ndim == 3: + image = image.unsqueeze(1) + + channel = image.shape[1] + # don't need any preprocess if the image is latents + if channel == 4: + return image + + height, width = self.get_default_height_width(image, height, width) + if self.config.do_resize: + image = self.resize(image, height, width) + + # expected range [0,1], normalize to [-1,1] + do_normalize = self.config.do_normalize + if do_normalize and image.min() < 0: + warnings.warn( + "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] " + f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]", + FutureWarning, + ) + do_normalize = False + + if do_normalize: + image = self.normalize(image) + + if self.config.do_binarize: + image = self.binarize(image) + + return image + + def postprocess( + self, + image: torch.FloatTensor, + output_type: str = "pil", + do_denormalize: Optional[List[bool]] = None, + ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]: + """ + Postprocess the image output from tensor to `output_type`. + + Args: + image (`torch.FloatTensor`): + The image input, should be a pytorch tensor with shape `B x C x H x W`. + output_type (`str`, *optional*, defaults to `pil`): + The output type of the image, can be one of `pil`, `np`, `pt`, `latent`. + do_denormalize (`List[bool]`, *optional*, defaults to `None`): + Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the + `VaeImageProcessor` config. + + Returns: + `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`: + The postprocessed image. + """ + if not isinstance(image, torch.Tensor): + raise ValueError( + f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor" + ) + if output_type not in ["latent", "pt", "np", "pil"]: + deprecation_message = ( + f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: " + "`pil`, `np`, `pt`, `latent`" + ) + deprecate( + "Unsupported output_type", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + output_type = "np" + + if output_type == "latent": + return image + + if do_denormalize is None: + do_denormalize = [self.config.do_normalize] * image.shape[0] + + image = torch.stack( + [ + self.denormalize(image[i]) if do_denormalize[i] else image[i] + for i in range(image.shape[0]) + ] + ) + + if output_type == "pt": + return image + + image = self.pt_to_numpy(image) + + if output_type == "np": + return image + + if output_type == "pil": + return self.numpy_to_pil(image) + + def apply_overlay( + self, + mask: PIL.Image.Image, + init_image: PIL.Image.Image, + image: PIL.Image.Image, + crop_coords: Optional[Tuple[int, int, int, int]] = None, + ) -> PIL.Image.Image: + """ + overlay the inpaint output to the original image + """ + + image = image.resize(init_image.size) + width, height = image.width, image.height + + init_image = self.resize(init_image, width=width, height=height) + mask = self.resize(mask, width=width, height=height) + + init_image_masked = PIL.Image.new("RGBa", (width, height)) + init_image_masked.paste( + init_image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(mask.convert("L")), + ) + init_image_masked = init_image_masked.convert("RGBA") + + if crop_coords is not None: + x, y, x2, y2 = crop_coords + w = x2 - x + h = y2 - y + base_image = PIL.Image.new("RGBA", (width, height)) + image = self.resize(image, height=h, width=w, resize_mode="crop") + base_image.paste(image, (x, y)) + image = base_image.convert("RGB") + + image = image.convert("RGBA") + image.alpha_composite(init_image_masked) + image = image.convert("RGB") + + return image + + +class VaeImageProcessorLDM3D(VaeImageProcessor): + """ + Image processor for VAE LDM3D. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. + vae_scale_factor (`int`, *optional*, defaults to `8`): + VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. + resample (`str`, *optional*, defaults to `lanczos`): + Resampling filter to use when resizing the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image to [-1,1]. + """ + + config_name = CONFIG_NAME + + @register_to_config + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 8, + resample: str = "lanczos", + do_normalize: bool = True, + ): + super().__init__() + + @staticmethod + def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]: + """ + Convert a NumPy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [ + Image.fromarray(image.squeeze(), mode="L") for image in images + ] + else: + pil_images = [Image.fromarray(image[:, :, :3]) for image in images] + + return pil_images + + @staticmethod + def depth_pil_to_numpy( + images: Union[List[PIL.Image.Image], PIL.Image.Image] + ) -> np.ndarray: + """ + Convert a PIL image or a list of PIL images to NumPy arrays. + """ + if not isinstance(images, list): + images = [images] + + images = [ + np.array(image).astype(np.float32) / (2**16 - 1) for image in images + ] + images = np.stack(images, axis=0) + return images + + @staticmethod + def rgblike_to_depthmap( + image: Union[np.ndarray, torch.Tensor] + ) -> Union[np.ndarray, torch.Tensor]: + """ + Args: + image: RGB-like depth image + + Returns: depth map + + """ + return image[:, :, 1] * 2**8 + image[:, :, 2] + + def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]: + """ + Convert a NumPy depth image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images_depth = images[:, :, :, 3:] + if images.shape[-1] == 6: + images_depth = (images_depth * 255).round().astype("uint8") + pil_images = [ + Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") + for image_depth in images_depth + ] + elif images.shape[-1] == 4: + images_depth = (images_depth * 65535.0).astype(np.uint16) + pil_images = [ + Image.fromarray(image_depth, mode="I;16") + for image_depth in images_depth + ] + else: + raise Exception("Not supported") + + return pil_images + + def postprocess( + self, + image: torch.FloatTensor, + output_type: str = "pil", + do_denormalize: Optional[List[bool]] = None, + ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]: + """ + Postprocess the image output from tensor to `output_type`. + + Args: + image (`torch.FloatTensor`): + The image input, should be a pytorch tensor with shape `B x C x H x W`. + output_type (`str`, *optional*, defaults to `pil`): + The output type of the image, can be one of `pil`, `np`, `pt`, `latent`. + do_denormalize (`List[bool]`, *optional*, defaults to `None`): + Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the + `VaeImageProcessor` config. + + Returns: + `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`: + The postprocessed image. + """ + if not isinstance(image, torch.Tensor): + raise ValueError( + f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor" + ) + if output_type not in ["latent", "pt", "np", "pil"]: + deprecation_message = ( + f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: " + "`pil`, `np`, `pt`, `latent`" + ) + deprecate( + "Unsupported output_type", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + output_type = "np" + + if do_denormalize is None: + do_denormalize = [self.config.do_normalize] * image.shape[0] + + image = torch.stack( + [ + self.denormalize(image[i]) if do_denormalize[i] else image[i] + for i in range(image.shape[0]) + ] + ) + + image = self.pt_to_numpy(image) + + if output_type == "np": + if image.shape[-1] == 6: + image_depth = np.stack( + [self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0 + ) + else: + image_depth = image[:, :, :, 3:] + return image[:, :, :, :3], image_depth + + if output_type == "pil": + return self.numpy_to_pil(image), self.numpy_to_depth(image) + else: + raise Exception(f"This type {output_type} is not supported") + + def preprocess( + self, + rgb: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray], + depth: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray], + height: Optional[int] = None, + width: Optional[int] = None, + target_res: Optional[int] = None, + ) -> torch.Tensor: + """ + Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors. + """ + supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor) + + # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image + if ( + self.config.do_convert_grayscale + and isinstance(rgb, (torch.Tensor, np.ndarray)) + and rgb.ndim == 3 + ): + raise Exception("This is not yet supported") + + if isinstance(rgb, supported_formats): + rgb = [rgb] + depth = [depth] + elif not ( + isinstance(rgb, list) and all(isinstance(i, supported_formats) for i in rgb) + ): + raise ValueError( + f"Input is in incorrect format: {[type(i) for i in rgb]}. Currently, we only support {', '.join(supported_formats)}" + ) + + if isinstance(rgb[0], PIL.Image.Image): + if self.config.do_convert_rgb: + raise Exception("This is not yet supported") + # rgb = [self.convert_to_rgb(i) for i in rgb] + # depth = [self.convert_to_depth(i) for i in depth] #TODO define convert_to_depth + if self.config.do_resize or target_res: + height, width = ( + self.get_default_height_width(rgb[0], height, width) + if not target_res + else target_res + ) + rgb = [self.resize(i, height, width) for i in rgb] + depth = [self.resize(i, height, width) for i in depth] + rgb = self.pil_to_numpy(rgb) # to np + rgb = self.numpy_to_pt(rgb) # to pt + + depth = self.depth_pil_to_numpy(depth) # to np + depth = self.numpy_to_pt(depth) # to pt + + elif isinstance(rgb[0], np.ndarray): + rgb = ( + np.concatenate(rgb, axis=0) + if rgb[0].ndim == 4 + else np.stack(rgb, axis=0) + ) + rgb = self.numpy_to_pt(rgb) + height, width = self.get_default_height_width(rgb, height, width) + if self.config.do_resize: + rgb = self.resize(rgb, height, width) + + depth = ( + np.concatenate(depth, axis=0) + if rgb[0].ndim == 4 + else np.stack(depth, axis=0) + ) + depth = self.numpy_to_pt(depth) + height, width = self.get_default_height_width(depth, height, width) + if self.config.do_resize: + depth = self.resize(depth, height, width) + + elif isinstance(rgb[0], torch.Tensor): + raise Exception("This is not yet supported") + # rgb = torch.cat(rgb, axis=0) if rgb[0].ndim == 4 else torch.stack(rgb, axis=0) + + # if self.config.do_convert_grayscale and rgb.ndim == 3: + # rgb = rgb.unsqueeze(1) + + # channel = rgb.shape[1] + + # height, width = self.get_default_height_width(rgb, height, width) + # if self.config.do_resize: + # rgb = self.resize(rgb, height, width) + + # depth = torch.cat(depth, axis=0) if depth[0].ndim == 4 else torch.stack(depth, axis=0) + + # if self.config.do_convert_grayscale and depth.ndim == 3: + # depth = depth.unsqueeze(1) + + # channel = depth.shape[1] + # # don't need any preprocess if the image is latents + # if depth == 4: + # return rgb, depth + + # height, width = self.get_default_height_width(depth, height, width) + # if self.config.do_resize: + # depth = self.resize(depth, height, width) + # expected range [0,1], normalize to [-1,1] + do_normalize = self.config.do_normalize + if rgb.min() < 0 and do_normalize: + warnings.warn( + "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] " + f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{rgb.min()},{rgb.max()}]", + FutureWarning, + ) + do_normalize = False + + if do_normalize: + rgb = self.normalize(rgb) + depth = self.normalize(depth) + + if self.config.do_binarize: + rgb = self.binarize(rgb) + depth = self.binarize(depth) + + return rgb, depth diff --git a/internals/pipelines/inpainter.py b/internals/pipelines/inpainter.py index af72c09a1292dc9169215e120dac5041216b45d2..378d62234743f1d5317c7fc8ce2556ac1ac2174e 100644 --- a/internals/pipelines/inpainter.py +++ b/internals/pipelines/inpainter.py @@ -1,18 +1,27 @@ from typing import List, Union import torch -from diffusers import StableDiffusionInpaintPipeline, StableDiffusionXLInpaintPipeline +from diffusers import ( + StableDiffusionInpaintPipeline, + StableDiffusionXLInpaintPipeline, + UNet2DConditionModel, +) from internals.pipelines.commons import AbstractPipeline +from internals.pipelines.high_res import HighRes +from internals.pipelines.inpaint_imageprocessor import VaeImageProcessor +from internals.util import get_generators from internals.util.cache import clear_cuda_and_gc from internals.util.commons import disable_safety_checker, download_image from internals.util.config import ( + get_base_inpaint_model_revision, get_base_inpaint_model_variant, get_hf_cache_dir, get_hf_token, get_inpaint_model_path, get_is_sdxl, get_model_dir, + get_num_return_sequences, ) @@ -32,13 +41,27 @@ class InPainter(AbstractPipeline): return if get_is_sdxl(): - self.pipe = StableDiffusionXLInpaintPipeline.from_pretrained( + # only take UNet from the repo + unet = UNet2DConditionModel.from_pretrained( get_inpaint_model_path(), torch_dtype=torch.float16, cache_dir=get_hf_cache_dir(), token=get_hf_token(), + subfolder="unet", variant=get_base_inpaint_model_variant(), + revision=get_base_inpaint_model_revision(), ).to("cuda") + kwargs = {**self.__base.pipe.components, "unet": unet} + self.pipe = StableDiffusionXLInpaintPipeline(**kwargs).to("cuda") + self.pipe.mask_processor = VaeImageProcessor( + vae_scale_factor=self.pipe.vae_scale_factor, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.pipe.image_processor = VaeImageProcessor( + vae_scale_factor=self.pipe.vae_scale_factor + ) else: self.pipe = StableDiffusionInpaintPipeline.from_pretrained( get_inpaint_model_path(), @@ -90,11 +113,18 @@ class InPainter(AbstractPipeline): num_inference_steps: int, **kwargs, ): - torch.manual_seed(seed) + generator = get_generators(seed, get_num_return_sequences()) input_img = download_image(image_url).resize((width, height)) mask_img = download_image(mask_image_url).resize((width, height)) + if get_is_sdxl(): + width, height = HighRes.find_closest_sdxl_aspect_ratio(width, height) + mask_img = self.pipe.mask_processor.blur(mask_img, blur_factor=33) + + kwargs["strength"] = 0.999 + kwargs["padding_mask_crop"] = 1000 + kwargs = { "prompt": prompt, "image": input_img, @@ -104,6 +134,7 @@ class InPainter(AbstractPipeline): "negative_prompt": negative_prompt, "num_inference_steps": num_inference_steps, "strength": 1.0, + "generator": generator, **kwargs, } - return self.pipe.__call__(**kwargs).images + return self.pipe.__call__(**kwargs).images, mask_img diff --git a/internals/pipelines/prompt_modifier.py b/internals/pipelines/prompt_modifier.py index 9892786086d79a2798d3816075e4ec9b9f6faece..22fe48911513f7e4497a6130de4d6f4d906376dc 100644 --- a/internals/pipelines/prompt_modifier.py +++ b/internals/pipelines/prompt_modifier.py @@ -2,6 +2,8 @@ from typing import List, Optional from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig +from internals.util.config import get_num_return_sequences + class PromptModifier: __loaded = False @@ -38,7 +40,7 @@ class PromptModifier: do_sample=False, max_new_tokens=75, num_beams=4, - num_return_sequences=num_of_sequences, + num_return_sequences=get_num_return_sequences(), eos_token_id=eos_id, pad_token_id=eos_id, length_penalty=-1.0, diff --git a/internals/pipelines/realtime_draw.py b/internals/pipelines/realtime_draw.py index a5c36a9a8e9cdaa37fa5e075cba5f3e931e5d25b..1d58220b2b738dcb3fd09b8888fc61f071195346 100644 --- a/internals/pipelines/realtime_draw.py +++ b/internals/pipelines/realtime_draw.py @@ -9,7 +9,13 @@ from internals.pipelines.commons import AbstractPipeline from internals.pipelines.controlnets import ControlNet from internals.pipelines.high_res import HighRes from internals.pipelines.sdxl_llite_pipeline import SDXLLLiteImg2ImgPipeline -from internals.util.config import get_base_dimension, get_hf_cache_dir, get_is_sdxl +from internals.util import get_generators +from internals.util.config import ( + get_base_dimension, + get_hf_cache_dir, + get_is_sdxl, + get_num_return_sequences, +) class RealtimeDraw(AbstractPipeline): @@ -60,7 +66,7 @@ class RealtimeDraw(AbstractPipeline): if get_is_sdxl(): raise Exception("SDXL is not supported for this method") - torch.manual_seed(seed) + generator = get_generators(seed, get_num_return_sequences()) image = ImageUtil.resize_image(image, 512) @@ -70,6 +76,7 @@ class RealtimeDraw(AbstractPipeline): prompt=prompt, num_inference_steps=15, negative_prompt=negative_prompt, + generator=generator, guidance_scale=10, strength=0.8, ).images[0] @@ -84,7 +91,7 @@ class RealtimeDraw(AbstractPipeline): image: Optional[Image.Image] = None, image2: Optional[Image.Image] = None, ): - torch.manual_seed(seed) + generator = get_generators(seed, get_num_return_sequences()) b_dimen = get_base_dimension() @@ -104,6 +111,8 @@ class RealtimeDraw(AbstractPipeline): size = HighRes.find_closest_sdxl_aspect_ratio(image.size[0], image.size[1]) image = image.resize(size) + torch.manual_seed(seed) + images = self.pipe.__call__( image=image, condition_image=image, @@ -129,6 +138,7 @@ class RealtimeDraw(AbstractPipeline): num_inference_steps=15, negative_prompt=negative_prompt, guidance_scale=10, + generator=generator, strength=0.9, width=image.size[0], height=image.size[1], diff --git a/internals/pipelines/remove_background.py b/internals/pipelines/remove_background.py index e7c35e8a4d056291e9c121247298805e3dee907b..5cade8fc757d39feeeb08fe6c92e76fc1c88341e 100644 --- a/internals/pipelines/remove_background.py +++ b/internals/pipelines/remove_background.py @@ -1,20 +1,22 @@ import io from pathlib import Path from typing import Union -import numpy as np -import cv2 +import cv2 +import huggingface_hub +import numpy as np +import onnxruntime as rt import torch import torch.nn.functional as F +from briarmbg import BriaRMBG # pyright: ignore from PIL import Image from rembg import remove -from internals.data.task import ModelType +from torchvision.transforms.functional import normalize import internals.util.image as ImageUtil from carvekit.api.high import HiInterface +from internals.data.task import ModelType from internals.util.commons import download_image, read_url -import onnxruntime as rt -import huggingface_hub class RemoveBackground: @@ -94,3 +96,51 @@ class RemoveBackgroundV2: img = np.concatenate([img, mask], axis=2, dtype=np.uint8) mask = mask.repeat(3, axis=2) return mask, img + + +class RemoveBackgroundV3: + def __init__(self): + net = BriaRMBG.from_pretrained("briaai/RMBG-1.4") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + net.to(device) + self.net = net + + def remove(self, image: Union[str, Image.Image]) -> Image.Image: + if type(image) is str: + image = download_image(image, mode="RGBA") + + orig_image = image + w, h = orig_im_size = orig_image.size + image = self.__resize_image(orig_image) + im_np = np.array(image) + im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1) + im_tensor = torch.unsqueeze(im_tensor, 0) + im_tensor = torch.divide(im_tensor, 255.0) + im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0]) + if torch.cuda.is_available(): + im_tensor = im_tensor.cuda() + + # inference + result = self.net(im_tensor) + # post process + result = torch.squeeze( + F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0 + ) + ma = torch.max(result) + mi = torch.min(result) + result = (result - mi) / (ma - mi) + # image to pil + im_array = (result * 255).cpu().data.numpy().astype(np.uint8) + pil_im = Image.fromarray(np.squeeze(im_array)) + # paste the mask on the original image + new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0)) + new_im.paste(orig_image, mask=pil_im) + # new_orig_image = orig_image.convert('RGBA') + + return new_im + + def __resize_image(self, image): + image = image.convert("RGB") + model_input_size = (1024, 1024) + image = image.resize(model_input_size, Image.BILINEAR) + return image diff --git a/internals/pipelines/replace_background.py b/internals/pipelines/replace_background.py index ae314716065fcac042fcda49938764ec9ad4e4b5..8e9239bc6d370e64516327ece670d6296f5e97ed 100644 --- a/internals/pipelines/replace_background.py +++ b/internals/pipelines/replace_background.py @@ -16,11 +16,12 @@ import internals.util.image as ImageUtil from internals.data.result import Result from internals.data.task import ModelType from internals.pipelines.commons import AbstractPipeline -from internals.pipelines.controlnets import ControlNet +from internals.pipelines.controlnets import ControlNet, load_network_model_by_key from internals.pipelines.high_res import HighRes from internals.pipelines.inpainter import InPainter from internals.pipelines.remove_background import RemoveBackgroundV2 from internals.pipelines.upscaler import Upscaler +from internals.util import get_generators from internals.util.cache import clear_cuda_and_gc from internals.util.commons import download_image from internals.util.config import ( @@ -28,6 +29,7 @@ from internals.util.config import ( get_hf_token, get_inpaint_model_path, get_model_dir, + get_num_return_sequences, ) @@ -43,11 +45,9 @@ class ReplaceBackground(AbstractPipeline): ): if self.__loaded: return - controlnet_model = ControlNetModel.from_pretrained( - "lllyasviel/control_v11p_sd15_canny", - torch_dtype=torch.float16, - cache_dir=get_hf_cache_dir(), - ).to("cuda") + controlnet_model = load_network_model_by_key( + "lllyasviel/control_v11p_sd15_canny", "controlnet" + ) if base: pipe = StableDiffusionControlNetPipeline( **base.pipe.components, @@ -109,8 +109,7 @@ class ReplaceBackground(AbstractPipeline): if type(image) is str: image = download_image(image) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) + generator = get_generators(seed, get_num_return_sequences()) image = image.convert("RGB") if max(image.size) > 1024: @@ -148,6 +147,7 @@ class ReplaceBackground(AbstractPipeline): guidance_scale=9, height=height, num_inference_steps=steps, + generator=generator, width=width, ) result = Result.from_result(result) diff --git a/internals/pipelines/safety_checker.py b/internals/pipelines/safety_checker.py index 562eb0bfcd930230985c57a05c08cb81d3cddab1..844c63d76cd9f9dcf70db4732d7d49e27f1001c7 100644 --- a/internals/pipelines/safety_checker.py +++ b/internals/pipelines/safety_checker.py @@ -31,10 +31,11 @@ class SafetyChecker: self.__loaded = True def apply(self, pipeline: AbstractPipeline): - model = self.model if not get_nsfw_access() else None - if model: + if not get_nsfw_access(): self.load() + model = self.model if not get_nsfw_access() else None + if not pipeline: return if hasattr(pipeline, "pipe"): diff --git a/internals/pipelines/sdxl_llite_pipeline.py b/internals/pipelines/sdxl_llite_pipeline.py index 9c0baa52a460fad76100ca2d01d85ad2371727de..5033271ab80a20b0bf5eb3e9521f5a4240933b37 100644 --- a/internals/pipelines/sdxl_llite_pipeline.py +++ b/internals/pipelines/sdxl_llite_pipeline.py @@ -1251,6 +1251,8 @@ class PipelineLike: class SDXLLLiteImg2ImgPipeline: + from diffusers import UNet2DConditionModel + def __init__(self): self.SCHEDULER_LINEAR_START = 0.00085 self.SCHEDULER_LINEAR_END = 0.0120 @@ -1261,7 +1263,7 @@ class SDXLLLiteImg2ImgPipeline: def replace_unet_modules( self, - unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, + unet: UNet2DConditionModel, mem_eff_attn, xformers, sdpa, diff --git a/internals/pipelines/sdxl_tile_upscale.py b/internals/pipelines/sdxl_tile_upscale.py index 038065830b34ea9a681307e85c90efc2a4ba5a9a..ee2c5d7080ff18246a04f31d2e4055ed51a91a11 100644 --- a/internals/pipelines/sdxl_tile_upscale.py +++ b/internals/pipelines/sdxl_tile_upscale.py @@ -4,8 +4,10 @@ from PIL import Image from torchvision import transforms import internals.util.image as ImageUtils +import internals.util.image as ImageUtil from carvekit.api import high from internals.data.result import Result +from internals.data.task import TaskType from internals.pipelines.commons import AbstractPipeline, Text2Img from internals.pipelines.controlnets import ControlNet from internals.pipelines.demofusion_sdxl import DemoFusionSDXLControlNetPipeline @@ -19,18 +21,16 @@ controlnet = ControlNet() class SDXLTileUpscaler(AbstractPipeline): __loaded = False + __current_process_mode = None def create(self, high_res: HighRes, pipeline: Text2Img, model_id: int): if self.__loaded: return # temporal hack for upscale model till multicontrolnet support is added - model = ( - "thibaud/controlnet-openpose-sdxl-1.0" - if int(model_id) == 2000293 - else "diffusers/controlnet-canny-sdxl-1.0" - ) - controlnet = ControlNetModel.from_pretrained(model, torch_dtype=torch.float16) + controlnet = ControlNetModel.from_pretrained( + "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16 + ) pipe = DemoFusionSDXLControlNetPipeline( **pipeline.pipe.components, controlnet=controlnet ) @@ -43,6 +43,7 @@ class SDXLTileUpscaler(AbstractPipeline): self.pipe = pipe + self.__current_process_mode = TaskType.CANNY.name self.__loaded = True def unload(self): @@ -52,6 +53,26 @@ class SDXLTileUpscaler(AbstractPipeline): clear_cuda_and_gc() + def __reload_controlnet(self, process_mode: str): + if self.__current_process_mode == process_mode: + return + + model = ( + "thibaud/controlnet-openpose-sdxl-1.0" + if process_mode == TaskType.POSE.name + else "diffusers/controlnet-canny-sdxl-1.0" + ) + controlnet = ControlNetModel.from_pretrained( + model, torch_dtype=torch.float16 + ).to("cuda") + + if hasattr(self, "pipe"): + self.pipe.controlnet = controlnet + + self.__current_process_mode = process_mode + + clear_cuda_and_gc() + def process( self, prompt: str, @@ -61,21 +82,36 @@ class SDXLTileUpscaler(AbstractPipeline): width: int, height: int, model_id: int, + seed: int, + process_mode: str, ): - if int(model_id) == 2000293: + generator = torch.manual_seed(seed) + + self.__reload_controlnet(process_mode) + + if process_mode == TaskType.POSE.name: + print("Running POSE") condition_image = controlnet.detect_pose(imageUrl) else: + print("Running CANNY") condition_image = download_image(imageUrl) condition_image = ControlNet.canny_detect_edge(condition_image) - img = download_image(imageUrl).resize((width, height)) + width, height = HighRes.find_closest_sdxl_aspect_ratio(width, height) - img = ImageUtils.resize_image(img, get_base_dimension()) + img = download_image(imageUrl).resize((width, height)) condition_image = condition_image.resize(img.size) img2 = self.__resize_for_condition_image(img, resize_dimension) + img = self.pad_image(img) image_lr = self.load_and_process_image(img) - print("img", img2.size, img.size) + + out_img = self.pad_image(img2) + condition_image = self.pad_image(condition_image) + + print("img", img.size) + print("img2", img2.size) + print("condition", condition_image.size) if int(model_id) == 2000173: kwargs = { "prompt": prompt, @@ -83,6 +119,7 @@ class SDXLTileUpscaler(AbstractPipeline): "image": img2, "strength": 0.3, "num_inference_steps": 30, + "generator": generator, } images = self.high_res.pipe.__call__(**kwargs).images else: @@ -90,20 +127,24 @@ class SDXLTileUpscaler(AbstractPipeline): image_lr=image_lr, prompt=prompt, condition_image=condition_image, - negative_prompt="blurry, ugly, duplicate, poorly drawn, deformed, mosaic", + negative_prompt="blurry, ugly, duplicate, poorly drawn, deformed, mosaic, " + + negative_prompt, guidance_scale=11, sigma=0.8, num_inference_steps=24, - width=img2.size[0], - height=img2.size[1], + controlnet_conditioning_scale=0.5, + generator=generator, + width=out_img.size[0], + height=out_img.size[1], ) images = images[::-1] + iv = ImageUtil.resize_image(img2, images[0].size[0]) + images = [self.unpad_image(images[0], iv.size)] return images, False def load_and_process_image(self, pil_image): transform = transforms.Compose( [ - transforms.Resize((1024, 1024)), transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ] @@ -113,6 +154,36 @@ class SDXLTileUpscaler(AbstractPipeline): image = image.to("cuda") return image + def pad_image(self, image): + w, h = image.size + if w == h: + return image + elif w > h: + new_image = Image.new(image.mode, (w, w), (0, 0, 0)) + pad_w = 0 + pad_h = (w - h) // 2 + new_image.paste(image, (0, pad_h)) + return new_image + else: + new_image = Image.new(image.mode, (h, h), (0, 0, 0)) + pad_w = (h - w) // 2 + pad_h = 0 + new_image.paste(image, (pad_w, 0)) + return new_image + + def unpad_image(self, padded_image, original_size): + w, h = original_size + if w == h: + return padded_image + elif w > h: + pad_h = (w - h) // 2 + unpadded_image = padded_image.crop((0, pad_h, w, h + pad_h)) + return unpadded_image + else: + pad_w = (h - w) // 2 + unpadded_image = padded_image.crop((pad_w, 0, w + pad_w, h)) + return unpadded_image + def __resize_for_condition_image(self, image: Image.Image, resolution: int): input_image = image.convert("RGB") W, H = input_image.size diff --git a/internals/pipelines/upscaler.py b/internals/pipelines/upscaler.py index 5c9b5dee2611c127d68e98e27a507f621494c698..6b11dbbea3bf563b59ef18e459c882a0aee55f61 100644 --- a/internals/pipelines/upscaler.py +++ b/internals/pipelines/upscaler.py @@ -1,7 +1,8 @@ +import io import math import os from pathlib import Path -from typing import Union +from typing import Optional, Union import cv2 import numpy as np @@ -10,7 +11,7 @@ from basicsr.archs.srvgg_arch import SRVGGNetCompact from basicsr.utils.download_util import load_file_from_url from gfpgan import GFPGANer from PIL import Image -from realesrgan import RealESRGANer +from realesrgan import RealESRGANer # pyright: ignore import internals.util.image as ImageUtil from internals.util.commons import download_image @@ -55,8 +56,12 @@ class Upscaler: width: int, height: int, face_enhance: bool, - resize_dimension: int, + resize_dimension: Optional[int] = None, ) -> bytes: + "if resize dimension is not provided, use the smaller of width and height" + + self.load() + model = SRVGGNetCompact( num_in_ch=3, num_out_ch=3, @@ -67,7 +72,7 @@ class Upscaler: ) return self.__internal_upscale( image, - resize_dimension, + resize_dimension, # type: ignore face_enhance, width, height, @@ -83,6 +88,10 @@ class Upscaler: face_enhance: bool, resize_dimension: int, ) -> bytes: + "if resize dimension is not provided, use the smaller of width and height" + + self.load() + model = RRDBNet( num_in_ch=3, num_out_ch=3, @@ -124,18 +133,22 @@ class Upscaler: model, ) -> bytes: if type(image) is str: - image = download_image(image) + image = download_image(image, mode="RGBA") w, h = image.size - if max(w, h) > 1024: - image = ImageUtil.resize_image(image, dimension=1024) + # if max(w, h) > 1024: + # image = ImageUtil.resize_image(image, dimension=1024) in_path = str(Path.home() / ".cache" / "input_upscale.png") image.save(in_path) input_image = cv2.imread(in_path, cv2.IMREAD_UNCHANGED) - dimension = min(input_image.shape[0], input_image.shape[1]) + dimension = max(input_image.shape[0], input_image.shape[1]) + if not resize_dimension: + resize_dimension = max(width, height) scale = max(math.floor(resize_dimension / dimension), 2) + print("Upscaling by: ", scale) + os.chdir(str(Path.home() / ".cache")) if scale == 4: print("Using 4x-Ultrasharp") @@ -174,3 +187,7 @@ class Upscaler: cv2.imwrite("out.png", output) out_bytes = cv2.imencode(".png", output)[1].tobytes() return out_bytes + + @staticmethod + def to_pil(buffer: bytes, mode="RGB") -> Image.Image: + return Image.open(io.BytesIO(buffer)).convert(mode) diff --git a/internals/util/__init__.py b/internals/util/__init__.py index f5bfce6542b96760057d85ee2db19848f6d508c5..d8e560ca89cd708639fd563d26d3ed623b8d5831 100644 --- a/internals/util/__init__.py +++ b/internals/util/__init__.py @@ -1,7 +1,13 @@ import os +import torch + from internals.util.config import get_root_dir def getcwd(): return get_root_dir() + + +def get_generators(seed, num_generators=1): + return [torch.Generator().manual_seed(seed + i) for i in range(num_generators)] diff --git a/internals/util/cache.py b/internals/util/cache.py index ca2f85040ba8eca5d4807dbe8806f9302dc0fe5b..cf9ec942262bf3d2fdea151d5ee0c25c72b1ef3d 100644 --- a/internals/util/cache.py +++ b/internals/util/cache.py @@ -1,5 +1,6 @@ import gc import os + import psutil import torch @@ -7,6 +8,7 @@ import torch def print_memory_usage(): process = psutil.Process(os.getpid()) print(f"Memory usage: {process.memory_info().rss / 1024 ** 2:2f} MB") + print(f"GPU usage: {torch.cuda.memory_allocated() / 1024 ** 2:2f} MB") def clear_cuda_and_gc(): diff --git a/internals/util/commons.py b/internals/util/commons.py index 823426599ee71e03f43c541a6c9e533d85d69da3..124967d6d967c9e10d0d72c326842f272d65166c 100644 --- a/internals/util/commons.py +++ b/internals/util/commons.py @@ -11,7 +11,7 @@ from typing import Any, Optional, Union import boto3 import requests -from internals.util.config import api_endpoint, api_headers +from internals.util.config import api_endpoint, api_headers, elb_endpoint s3 = boto3.client("s3") import io @@ -103,7 +103,7 @@ def upload_images(images, processName: str, taskId: str): img_io.seek(0) key = "crecoAI/{}{}_{}.png".format(taskId, processName, i) res = requests.post( - api_endpoint() + elb_endpoint() + "/autodraft-content/v1.0/upload/crecoai-assets-2?fileName=" + "{}{}_{}.png".format(taskId, processName, i), headers=api_headers(), @@ -129,12 +129,12 @@ def upload_image(image: Union[Image.Image, BytesIO], out_path): image.seek(0) print( - api_endpoint() + elb_endpoint() + "/autodraft-content/v1.0/upload/crecoai-assets-2?fileName=" + str(out_path).replace("crecoAI/", ""), ) res = requests.post( - api_endpoint() + elb_endpoint() + "/autodraft-content/v1.0/upload/crecoai-assets-2?fileName=" + str(out_path).replace("crecoAI/", ""), headers=api_headers(), diff --git a/internals/util/config.py b/internals/util/config.py index bf5773da130db8579d38a6a5a0e40bf4212c97a5..9d1083dbbfd7b2cf4560162a40220c3c7230bf75 100644 --- a/internals/util/config.py +++ b/internals/util/config.py @@ -13,7 +13,7 @@ access_token = "" root_dir = "" model_config = None hf_token = base64.b64decode( - b"aGZfVFZCTHNUam1tT3d6T0h1dlVZWkhEbEZ4WVdOSUdGamVCbA==" + b"aGZfaXRvVVJzTmN1RHZab1hXZ3hIeFRRRGdvSHdrQ2VNUldGbA==" ).decode() hf_cache_dir = "/tmp/hf_hub" @@ -46,7 +46,7 @@ def set_model_config(config: ModelConfig): def set_configs_from_task(task: Task): global env, nsfw_threshold, nsfw_access, access_token, base_dimension, num_return_sequences - name = task.get_queue_name() + name = task.get_environment() if name.startswith("gamma"): env = "gamma" else: @@ -120,14 +120,25 @@ def get_base_model_variant(): return model_config.base_model_variant # pyright: ignore +def get_base_model_revision(): + global model_config + return model_config.base_model_revision # pyright: ignore + + def get_base_inpaint_model_variant(): global model_config return model_config.base_inpaint_model_variant # pyright: ignore +def get_base_inpaint_model_revision(): + global model_config + return model_config.base_inpaint_model_revision # pyright: ignore + + def api_headers(): return { "Access-Token": access_token, + "Host": "api.autodraft.in" if env == "prod" else "gamma-api.autodraft.in", } @@ -138,8 +149,11 @@ def api_endpoint(): return "https://gamma-api.autodraft.in" -def comic_url(): +def elb_endpoint(): + # We use the ELB endpoint for uploading images since + # cloudflare has a hard limit of 100mb when the + # DNS is proxied if env == "prod": - return "http://internal-k8s-prod-internal-bb9c57a6bb-1524739074.ap-south-1.elb.amazonaws.com:80" + return "http://k8s-prod-ingresse-8ba91151af-2105029163.ap-south-1.elb.amazonaws.com" else: - return "http://internal-k8s-gamma-internal-ea8e32da94-1997933257.ap-south-1.elb.amazonaws.com:80" + return "http://k8s-gamma-ingresse-fc1051bc41-1227070426.ap-south-1.elb.amazonaws.com" diff --git a/internals/util/failure_hander.py b/internals/util/failure_hander.py index 8bdfa2660675b41b0d32ffd8c93c90562a6590a1..293f5707b1f81b7d81d56f9e6137ce38016f38a4 100644 --- a/internals/util/failure_hander.py +++ b/internals/util/failure_hander.py @@ -16,10 +16,13 @@ class FailureHandler: path = FailureHandler.__task_path path.parent.mkdir(parents=True, exist_ok=True) if path.exists(): - task = Task(json.loads(path.read_text())) - set_configs_from_task(task) - # Slack().error_alert(task, Exception("CATASTROPHIC FAILURE")) - updateSource(task.get_sourceId(), task.get_userId(), "FAILED") + try: + task = Task(json.loads(path.read_text())) + set_configs_from_task(task) + # Slack().error_alert(task, Exception("CATASTROPHIC FAILURE")) + updateSource(task.get_sourceId(), task.get_userId(), "FAILED") + except Exception as e: + print("Failed to handle task", e) os.remove(path) @staticmethod diff --git a/internals/util/image.py b/internals/util/image.py index cba7712eec8e93f14807fb34f653aa99ac04640d..2381d8542861dd62617c87adcc6ad5d4dcfc5406 100644 --- a/internals/util/image.py +++ b/internals/util/image.py @@ -48,3 +48,21 @@ def padd_image(image: Image.Image, to_width: int, to_height: int) -> Image.Image img = Image.new("RGBA", (to_width, to_height), (0, 0, 0, 0)) img.paste(image, ((to_width - iw) // 2, (to_height - ih) // 2)) return img + + +def alpha_to_white(img: Image.Image) -> Image.Image: + if img.mode == "RGBA": + data = img.getdata() + + new_data = [] + + for item in data: + if item[3] == 0: + new_data.append((255, 255, 255, 255)) + else: + new_data.append(item) + + img.putdata(new_data) + + img = img.convert("RGB") + return img diff --git a/internals/util/lora_style.py b/internals/util/lora_style.py index 8c4fb52bc1ac833cdd4a9d07215e063a5cf236a4..d58eeaea26b68d56709da5028cd0d52c0677aadb 100644 --- a/internals/util/lora_style.py +++ b/internals/util/lora_style.py @@ -52,9 +52,18 @@ class LoraStyle: def patch(self): def run(pipe): path = self.__style["path"] - pipe.load_lora_weights( - os.path.dirname(path), weight_name=os.path.basename(path) - ) + name = str(self.__style["tag"]).replace(" ", "_") + weight = self.__style.get("weight", 1.0) + if name not in pipe.get_list_adapters().get("unet", []): + print( + f"Loading lora {os.path.basename(path)} with weights {weight}, name: {name}" + ) + pipe.load_lora_weights( + os.path.dirname(path), + weight_name=os.path.basename(path), + adapter_name=name, + ) + pipe.set_adapters([name], adapter_weights=[weight]) for p in self.pipe: run(p) @@ -105,7 +114,17 @@ class LoraStyle: def prepend_style_to_prompt(self, prompt: str, key: str) -> str: if key in self.__styles: style = self.__styles[key] - return f"{', '.join(style['text'])}, {prompt}" + prompt = f"{', '.join(style['text'])}, {prompt}" + prompt = prompt.replace(", ", "") + return prompt + + def append_style_to_prompt(self, prompt: str, key: str) -> str: + if key in self.__styles and "text_append" in self.__styles[key]: + style = self.__styles[key] + if prompt.endswith(","): + prompt = prompt[:-1] + prompt = f"{prompt}, {', '.join(style['text_append'])}" + prompt = prompt.replace(", ", "") return prompt def get_patcher( @@ -140,7 +159,9 @@ class LoraStyle: "path": str(file_path), "weight": attr["weight"], "type": attr["type"], + "tag": item["tag"], "text": attr["text"], + "text_append": attr.get("text_append", []), "negativePrompt": attr["negativePrompt"], } return styles @@ -159,4 +180,7 @@ class LoraStyle: @staticmethod def unload_lora_weights(pipe): - pipe.unload_lora_weights() + # we keep the lora layers in the adapters and unset it whenever + # not required instead of completely unloading it + pipe.set_adapters([]) + # pipe.unload_lora_weights() diff --git a/internals/util/model_loader.py b/internals/util/model_loader.py index 62bf3b3b09ff158b021b0d2154f69060d37711e7..6a3eced8cc08c428a8aedc11ab6be196daa1c3d6 100644 --- a/internals/util/model_loader.py +++ b/internals/util/model_loader.py @@ -18,7 +18,9 @@ class ModelConfig: base_dimension: int = 512 low_gpu_mem: bool = False base_model_variant: Optional[str] = None + base_model_revision: Optional[str] = None base_inpaint_model_variant: Optional[str] = None + base_inpaint_model_revision: Optional[str] = None def load_model_from_config(path): @@ -31,7 +33,11 @@ def load_model_from_config(path): is_sdxl = config.get("is_sdxl", False) base_dimension = config.get("base_dimension", 512) base_model_variant = config.get("base_model_variant", None) + base_model_revision = config.get("base_model_revision", None) base_inpaint_model_variant = config.get("base_inpaint_model_variant", None) + base_inpaint_model_revision = config.get( + "base_inpaint_model_revision", None + ) m_config.base_model_path = model_path m_config.base_inpaint_model_path = inpaint_model_path @@ -39,7 +45,9 @@ def load_model_from_config(path): m_config.base_dimension = base_dimension m_config.low_gpu_mem = config.get("low_gpu_mem", False) m_config.base_model_variant = base_model_variant + m_config.base_model_revision = base_model_revision m_config.base_inpaint_model_variant = base_inpaint_model_variant + m_config.base_inpaint_model_revision = base_inpaint_model_revision # # if config.get("model_type") == "huggingface": diff --git a/internals/util/prompt.py b/internals/util/prompt.py index 688d5e89f7d917918ae1ebf42727d155e84ec4ce..0334564f7991c6ebcd95337d5aa207ddc38db427 100644 --- a/internals/util/prompt.py +++ b/internals/util/prompt.py @@ -21,6 +21,7 @@ def get_patched_prompt( for i in range(len(prompt)): prompt[i] = avatar.add_code_names(prompt[i]) prompt[i] = lora_style.prepend_style_to_prompt(prompt[i], task.get_style()) + prompt[i] = lora_style.append_style_to_prompt(prompt[i], task.get_style()) if additional: prompt[i] = additional + " " + prompt[i] @@ -51,6 +52,7 @@ def get_patched_prompt_text2img( def add_style_and_character(prompt: str, prepend: str = ""): prompt = avatar.add_code_names(prompt) prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style()) + prompt = lora_style.append_style_to_prompt(prompt, task.get_style()) prompt = prepend + prompt return prompt @@ -102,6 +104,7 @@ def get_patched_prompt_tile_upscale( lora_style: LoraStyle, img_classifier: ImageClassifier, img2text: Image2Text, + is_sdxl=False, ): if task.get_prompt(): prompt = task.get_prompt() @@ -114,10 +117,12 @@ def get_patched_prompt_tile_upscale( prompt = task.PROMPT.merge_blip(blip) # remove anomalies in prompt - prompt = remove_colors(prompt) + if not is_sdxl: + prompt = remove_colors(prompt) prompt = avatar.add_code_names(prompt) prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style()) + prompt = lora_style.append_style_to_prompt(prompt, task.get_style()) if not task.get_style(): class_name = img_classifier.classify( diff --git a/internals/util/sdxl_lightning.py b/internals/util/sdxl_lightning.py new file mode 100644 index 0000000000000000000000000000000000000000..282c86acce43536e6858002a9a6ff87c6a38ed73 --- /dev/null +++ b/internals/util/sdxl_lightning.py @@ -0,0 +1,74 @@ +from pathlib import Path +from re import S +from typing import List, Union + +from diffusers import EulerDiscreteScheduler, StableDiffusionXLPipeline +from diffusers.loaders.lora import StableDiffusionXLLoraLoaderMixin +from torchvision.datasets.utils import download_url + + +class LightningMixin: + LORA_8_STEP_URL = "https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_8step_lora.safetensors" + + __scheduler_old = None + __pipe: StableDiffusionXLPipeline = None + __scheduler = None + + def configure_sdxl_lightning(self, pipe: StableDiffusionXLPipeline): + lora_path = Path.home() / ".cache" / "lora_8_step.safetensors" + + download_url(self.LORA_8_STEP_URL, str(lora_path.parent), lora_path.name) + + pipe.load_lora_weights(str(lora_path), adapter_name="8step_lora") + pipe.set_adapters([]) + + self.__scheduler = EulerDiscreteScheduler.from_config( + pipe.scheduler.config, timestep_spacing="trailing" + ) + self.__scheduler_old = pipe.scheduler + self.__pipe = pipe + + def enable_sdxl_lightning(self): + pipe = self.__pipe + pipe.scheduler = self.__scheduler + + current = pipe.get_active_adapters() + current.extend(["8step_lora"]) + + weights = self.__find_adapter_weights(current) + pipe.set_adapters(current, adapter_weights=weights) + + return {"guidance_scale": 0, "num_inference_steps": 8} + + def disable_sdxl_lightning(self): + pipe = self.__pipe + pipe.scheduler = self.__scheduler_old + + current = pipe.get_active_adapters() + current = [adapter for adapter in current if adapter != "8step_lora"] + + weights = self.__find_adapter_weights(current) + pipe.set_adapters(current, adapter_weights=weights) + + def __find_adapter_weights(self, names: List[str]): + pipe = self.__pipe + + model = pipe.unet + + from peft.tuners.tuners_utils import BaseTunerLayer + + weights = [] + for adapter_name in names: + weight = 1.0 + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + if adapter_name in module.scaling: + weight = ( + module.scaling[adapter_name] + * module.r[adapter_name] + / module.lora_alpha[adapter_name] + ) + + weights.append(weight) + + return weights diff --git a/internals/util/slack.py b/internals/util/slack.py index b50fbe8d1743f3135e99d127020eed46aabdf390..2e22ce95b824a035f5d503ef3998ed6a9d7d924f 100644 --- a/internals/util/slack.py +++ b/internals/util/slack.py @@ -14,6 +14,8 @@ class Slack: self.error_webhook = "https://hooks.slack.com/services/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" def send_alert(self, task: Task, args: Optional[dict]): + if task.get_slack_url(): + self.webhook_url = task.get_slack_url() raw = task.get_raw().copy() raw["environment"] = get_environment() @@ -23,6 +25,7 @@ class Slack: raw.pop("task_id", None) raw.pop("maskImageUrl", None) raw.pop("aux_imageUrl", None) + raw.pop("slack_url", None) if args is not None: raw.update(args.items()) diff --git a/models/ultrasharp/model.py b/models/ultrasharp/model.py index 26c0924adcf5c7d88344e619a0d19e6f1cc4b863..fe10ae8933881a1b1979407b8f482e72f163aa57 100644 --- a/models/ultrasharp/model.py +++ b/models/ultrasharp/model.py @@ -1,5 +1,7 @@ from typing import List +import cv2 +import numpy as np import torch import models.ultrasharp.arch as arch @@ -25,5 +27,22 @@ class Ultrasharp: model.to("cuda") + if img.shape[2] == 4: # RGBA image with alpha channel + img_mode = "RGBA" + alpha = img[:, :, 3] + img = img[:, :, 0:3] + alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB) + else: + img_mode = "RGB" img = upscale(model, img, self.tile_pad, self.tile) + + # process alpha channel if necessary + if img_mode == "RGBA": + output_alpha = upscale(model, alpha, self.tile_pad, self.tile) + output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY) + # output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0)) + + # merge the alpha channel + img = cv2.cvtColor(img, cv2.COLOR_BGR2BGRA) + img[:, :, 3] = output_alpha return img, None diff --git a/requirements.txt b/requirements.txt index cd1d82f321366d5fa425bc74f1c7e98435bce1df..67cf0bf40a708f547f01b14259112d28c6831d56 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ boto3==1.24.61 triton==2.0.0 -diffusers==0.25.0 +diffusers==0.26.0 +huggingface-hub==0.25.1 fastapi==0.87.0 Pillow==9.3.0 redis==4.3.4 @@ -21,10 +22,7 @@ easydict==1.9.0 albumentations kornia==0.5.0 pytorch-lightning==1.2.9 -mmpose==0.29.0 -mmdet==2.28.2 https://comic-assets.s3.ap-south-1.amazonaws.com/packages/v1/lora-diffusion-0.1.7.zip -mmengine==0.8.4 pydash==7.0.6 scikit-learn==1.3.0 accelerate==0.22.0 @@ -34,7 +32,6 @@ scikit-image==0.19.3 omegaconf==2.3.0 webdataset==0.2.48 invisible-watermark -https://comic-assets.s3.ap-south-1.amazonaws.com/packages/mmcv_full-1.7.0-cp39-cp39-linux_x86_64.whl python-dateutil==2.8.2 PyYAML invisible-watermark @@ -44,3 +41,4 @@ onnxruntime-gpu imgaug==0.4.0 tqdm==4.64.1 toml +peft