import argparse
import glob
import os
import cv2
from diffusers import AutoencoderKL
from typing import Dict, List
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from PIL import Image
class ResidualBlock(nn.Module):
    """Two conv + batch-norm stages wrapped in an identity skip connection.

    ``forward`` computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``.  The
    input is added back unchanged, so the block only works for configurations
    in which both convolutions preserve the tensor shape.  The defaults do
    (``out_channels == in_channels``, ``kernel_size=3``, ``stride=1``,
    ``padding=1``).

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by both convolutions;
            defaults to ``in_channels``.
        kernel_size: Kernel size used by both convolutions.
        stride: Stride used by both convolutions.
        padding: Padding used by both convolutions.
    """

    def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
        super().__init__()

        if out_channels is None:
            out_channels = in_channels

        # The convolutions carry no bias: each one is followed by a BatchNorm
        # layer that has its own learnable shift.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU(inplace=True)

        # Without this call the layers keep PyTorch's default initialisation
        # and `_initialize_weights` is dead code.
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialise conv layers with Kaiming-normal and norm layers to identity."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Apply the two conv stages and add the unchanged input before the final ReLU."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # In-place add on the freshly computed tensor; `x` itself is not modified.
        out += residual
        out = self.relu2(out)
        return out
class Upscaler(nn.Module):
    """Residual CNN that refines 4-channel latents of already-resized images.

    Layout: a 4->128 stem convolution, twenty ``ResidualBlock(128)`` modules
    arranged as five groups of four (each group wrapped in an extra skip
    connection), then 128->64->64 convolutions and a 1x1 convolution back to
    4 channels.  The network input is added to the result, so the model
    predicts a correction to its input.
    """

    def __init__(self):
        super().__init__()

        # Stem: 4 latent channels -> 128 feature channels.
        self.conv1 = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(128)
        self.relu1 = nn.ReLU(inplace=True)

        # Registered as resblock1 .. resblock20.  The attribute names are part
        # of the state-dict keys, so they must stay exactly like this for
        # existing weight files to load.
        for index in range(1, 21):
            setattr(self, f"resblock{index}", ResidualBlock(128))

        # Head: 128 -> 64 -> 64 -> 4 channels.
        self.conv2 = nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)

        self.conv3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(64)
        # NOTE: relu3 is registered but `forward` does not apply it.  Applying
        # it would change the output of already-trained weights, so the
        # forward pass is deliberately left as it was.
        self.relu3 = nn.ReLU(inplace=True)

        self.conv_final = nn.Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))

        # Without this call the zero-initialised output layer (identity at
        # start) never takes effect and `_initialize_weights` is dead code.
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialise all layers; zero the output conv so the model starts as identity."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

        # With zero weight (and the zero bias set above) conv_final outputs 0,
        # so a freshly built network returns its input unchanged.
        nn.init.constant_(self.conv_final.weight, 0)

    def forward(self, x):
        """Return ``x`` plus the correction predicted by the network."""
        inp = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        # Five groups of four residual blocks; each group has its own skip.
        for group_start in range(1, 21, 4):
            residual = x
            for index in range(group_start, group_start + 4):
                x = getattr(self, f"resblock{index}")(x)
            x = x + residual

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = self.bn3(x)

        # No activation between bn3 and conv_final (see the relu3 note in __init__).
        x = self.conv_final(x)

        x = x + inp
        return x

    def support_latents(self) -> bool:
        """Return False: `upscale` works from images and ignores `lowreso_latents`."""
        return False

    def upscale(
        self,
        vae: AutoencoderKL,
        lowreso_images: List[Image.Image],
        lowreso_latents: torch.Tensor,
        dtype: torch.dtype,
        width: int,
        height: int,
        batch_size: int = 1,
        vae_batch_size: int = 1,
    ):
        """Upscale images to ``width`` x ``height`` and return refined latents.

        Each image is resized with Lanczos, encoded by ``vae`` and passed
        through this network.

        Args:
            vae: VAE used to encode the resized images.
            lowreso_images: Low-resolution PIL images; must not be None.
                Assumed to be multi-channel (e.g. RGB) so that ``np.array``
                yields an ``H x W x C`` array.
            lowreso_latents: Not used by this implementation.
            dtype: Tensor dtype the pixel data is converted to before encoding.
            width: Target width in pixels.
            height: Target height in pixels.
            batch_size: Number of latents per forward pass of this network.
            vae_batch_size: Number of images per VAE encode call.

        Returns:
            The refined latents multiplied by 0.18215.

        Raises:
            AssertionError: If ``lowreso_images`` is None.
        """
        assert lowreso_images is not None, "Upscaler requires lowreso image"

        # Resize every image to the target size with Lanczos resampling.
        upsampled_images = []
        for lowreso_image in lowreso_images:
            upsampled_image = np.array(lowreso_image.resize((width, height), Image.LANCZOS))
            upsampled_images.append(upsampled_image)

        # HWC uint8 arrays -> one NCHW tensor with values mapped from [0, 255] to [-1, 1].
        upsampled_images = [torch.from_numpy(upsampled_image).permute(2, 0, 1).float() for upsampled_image in upsampled_images]
        upsampled_images = torch.stack(upsampled_images, dim=0)
        upsampled_images = upsampled_images.to(dtype)
        upsampled_images = upsampled_images / 127.5 - 1.0

        # Encode in chunks of `vae_batch_size`.
        upsampled_latents = []
        for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
            batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
            with torch.no_grad():
                batch = vae.encode(batch).latent_dist.sample()
            upsampled_latents.append(batch)
        upsampled_latents = torch.cat(upsampled_latents, dim=0)

        # Refine the latents in chunks of `batch_size`.
        print("Upscaling latents...")
        upscaled_latents = []
        for i in range(0, upsampled_latents.shape[0], batch_size):
            with torch.no_grad():
                upscaled_latents.append(self.forward(upsampled_latents[i : i + batch_size]))
        upscaled_latents = torch.cat(upscaled_latents, dim=0)

        # Presumably the Stable Diffusion VAE latent scaling factor; callers
        # in this file divide by the same constant before decoding.
        return upscaled_latents * 0.18215
def create_upscaler(**kwargs):
    """Build an `Upscaler` and load its parameters from a weight file.

    Args:
        **kwargs: Must contain ``weights``, the path of the weight file.  A
            ``.safetensors`` extension selects the safetensors loader; any
            other file is read with ``torch.load`` onto the CPU.

    Returns:
        The `Upscaler` with the loaded state dict applied.

    Raises:
        KeyError: If ``weights`` is not given.
    """
    weights = kwargs["weights"]
    model = Upscaler()

    print(f"Loading weights from {weights}...")
    if os.path.splitext(weights)[1].lower() == ".safetensors":
        # Imported lazily so safetensors is only required for that format.
        from safetensors.torch import load_file

        sd = load_file(weights)
    else:
        # NOTE(security): torch.load unpickles the file and can execute
        # arbitrary code; only load checkpoints from trusted sources.
        sd = torch.load(weights, map_location=torch.device("cpu"))

    # Reading the file alone leaves the model at its initial weights; the
    # state dict has to be applied explicitly.
    model.load_state_dict(sd)
    return model
def upscale_images(args: argparse.Namespace):
    """Upscale every image matching ``args.image_pattern`` by 2x and save it.

    The images are loaded, cropped so both sides are multiples of 8, run
    through the latent upscaler, decoded with the VAE and written to
    ``args.output_dir`` as ``<name>_upscaled<ext>``.  With ``args.debug`` a
    plain Lanczos 2x resize is additionally saved as ``<name>_lanczos4<ext>``.

    The output size is derived from the last loaded image, so all inputs are
    expected to share one size; differently sized images are resized to it.

    Args:
        args: Parsed options providing ``vae_path``, ``weights``,
            ``image_pattern``, ``output_dir``, ``batch_size``,
            ``vae_batch_size`` and ``debug``.

    Raises:
        AssertionError: If ``args.vae_path`` is None.
        ValueError: If no file matches ``args.image_pattern``.
    """
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Half precision on GPU; fall back to float32 on CPU, where fp16
    # convolutions may be unsupported or very slow depending on the build.
    us_dtype = torch.float16 if DEVICE.type == "cuda" else torch.float32
    os.makedirs(args.output_dir, exist_ok=True)

    # Load the VAE.
    assert args.vae_path is not None, "VAE path is required"
    print(f"Loading VAE from {args.vae_path}...")
    vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
    vae.to(DEVICE, dtype=us_dtype)

    # Load the upscaler.
    print("Preparing model...")
    upscaler: Upscaler = create_upscaler(weights=args.weights)
    # Inference only: BatchNorm must use its stored running statistics rather
    # than the statistics of the current batch.
    upscaler.eval()
    upscaler.to(DEVICE, dtype=us_dtype)

    # Load the images, cropping each side down to a multiple of 8.
    image_paths = glob.glob(args.image_pattern)
    if not image_paths:
        raise ValueError(f"No images match pattern: {args.image_pattern}")

    images = []
    for image_path in image_paths:
        # convert() returns an independent image, so the file can be closed right away.
        with Image.open(image_path) as source:
            image = source.convert("RGB")

        width = image.width - (image.width % 8)
        height = image.height - (image.height % 8)
        if width != image.width or height != image.height:
            image = image.crop((0, 0, width, height))

        images.append(image)

    # Debug: also save a plain Lanczos resize for visual comparison.
    if args.debug:
        for debug_source, image_path in zip(images, image_paths):
            image_debug = debug_source.resize((debug_source.width * 2, debug_source.height * 2), Image.LANCZOS)

            basename = os.path.basename(image_path)
            basename_wo_ext, ext = os.path.splitext(basename)
            dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_lanczos4{ext}")
            image_debug.save(dest_file_name)

    # Upscale in latent space (width/height are those of the last loaded image).
    upscaled_latents = upscaler.upscale(
        vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size
    )
    # Undo the scaling applied by `upscale` before handing the latents to the VAE.
    upscaled_latents /= 0.18215

    # Decode in chunks of `vae_batch_size`, collecting the results on the CPU.
    upscaled_images = []
    for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)):
        with torch.no_grad():
            batch = vae.decode(upscaled_latents[i : i + args.vae_batch_size]).sample
        batch = batch.to("cpu")
        upscaled_images.append(batch)
    upscaled_images = torch.cat(upscaled_images, dim=0)

    # NCHW in [-1, 1] -> NHWC uint8 in [0, 255].
    upscaled_images = upscaled_images.permute(0, 2, 3, 1).numpy()
    upscaled_images = (upscaled_images + 1.0) * 127.5
    upscaled_images = upscaled_images.clip(0, 255).astype(np.uint8)

    # cv2 expects BGR channel order.
    upscaled_images = upscaled_images[..., ::-1]

    # Save the results next to the chosen output directory.
    for image_path, upscaled_image in zip(image_paths, upscaled_images):
        basename = os.path.basename(image_path)
        basename_wo_ext, ext = os.path.splitext(basename)
        dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_upscaled{ext}")
        # The channel flip above yields a negatively-strided view; pass cv2 a contiguous copy.
        cv2.imwrite(dest_file_name, np.ascontiguousarray(upscaled_image))
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the upscaling job.
    parser = argparse.ArgumentParser()
    parser.add_argument("--vae_path", type=str, default=None, help="VAE path")
    parser.add_argument("--weights", type=str, default=None, help="Weights path")
    parser.add_argument("--image_pattern", type=str, default=None, help="Image pattern")
    parser.add_argument("--output_dir", type=str, default=".", help="Output directory")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
    parser.add_argument("--vae_batch_size", type=int, default=1, help="VAE batch size")
    parser.add_argument("--debug", action="store_true", help="Debug mode")

    args = parser.parse_args()
    # Parsing alone does nothing; hand the options to the actual job.
    upscale_images(args)