|
|
|
|
|
|
|
import argparse
|
|
import glob
|
|
import os
|
|
import cv2
|
|
from diffusers import AutoencoderKL
|
|
|
|
from typing import Dict, List
|
|
import numpy as np
|
|
|
|
import torch
|
|
from torch import nn
|
|
from tqdm import tqdm
|
|
from PIL import Image
|
|
|
|
|
|
class ResidualBlock(nn.Module):
|
|
def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
|
|
super(ResidualBlock, self).__init__()
|
|
|
|
if out_channels is None:
|
|
out_channels = in_channels
|
|
|
|
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
|
|
self.bn1 = nn.BatchNorm2d(out_channels)
|
|
self.relu1 = nn.ReLU(inplace=True)
|
|
|
|
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=False)
|
|
self.bn2 = nn.BatchNorm2d(out_channels)
|
|
|
|
self.relu2 = nn.ReLU(inplace=True)
|
|
|
|
|
|
self._initialize_weights()
|
|
|
|
def _initialize_weights(self):
|
|
for m in self.modules():
|
|
if isinstance(m, nn.Conv2d):
|
|
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
|
if m.bias is not None:
|
|
nn.init.constant_(m.bias, 0)
|
|
elif isinstance(m, nn.BatchNorm2d):
|
|
nn.init.constant_(m.weight, 1)
|
|
nn.init.constant_(m.bias, 0)
|
|
elif isinstance(m, nn.Linear):
|
|
nn.init.normal_(m.weight, 0, 0.01)
|
|
nn.init.constant_(m.bias, 0)
|
|
|
|
def forward(self, x):
|
|
residual = x
|
|
|
|
out = self.conv1(x)
|
|
out = self.bn1(out)
|
|
out = self.relu1(out)
|
|
|
|
out = self.conv2(out)
|
|
out = self.bn2(out)
|
|
|
|
out += residual
|
|
|
|
out = self.relu2(out)
|
|
|
|
return out
|
|
|
|
|
|
class Upscaler(nn.Module):
|
|
def __init__(self):
|
|
super(Upscaler, self).__init__()
|
|
|
|
|
|
|
|
|
|
self.conv1 = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
|
|
self.bn1 = nn.BatchNorm2d(128)
|
|
self.relu1 = nn.ReLU(inplace=True)
|
|
|
|
|
|
|
|
self.resblock1 = ResidualBlock(128)
|
|
self.resblock2 = ResidualBlock(128)
|
|
self.resblock3 = ResidualBlock(128)
|
|
self.resblock4 = ResidualBlock(128)
|
|
self.resblock5 = ResidualBlock(128)
|
|
self.resblock6 = ResidualBlock(128)
|
|
self.resblock7 = ResidualBlock(128)
|
|
self.resblock8 = ResidualBlock(128)
|
|
self.resblock9 = ResidualBlock(128)
|
|
self.resblock10 = ResidualBlock(128)
|
|
self.resblock11 = ResidualBlock(128)
|
|
self.resblock12 = ResidualBlock(128)
|
|
self.resblock13 = ResidualBlock(128)
|
|
self.resblock14 = ResidualBlock(128)
|
|
self.resblock15 = ResidualBlock(128)
|
|
self.resblock16 = ResidualBlock(128)
|
|
self.resblock17 = ResidualBlock(128)
|
|
self.resblock18 = ResidualBlock(128)
|
|
self.resblock19 = ResidualBlock(128)
|
|
self.resblock20 = ResidualBlock(128)
|
|
|
|
|
|
self.conv2 = nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
|
|
self.bn2 = nn.BatchNorm2d(64)
|
|
self.relu2 = nn.ReLU(inplace=True)
|
|
|
|
self.conv3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
|
|
self.bn3 = nn.BatchNorm2d(64)
|
|
self.relu3 = nn.ReLU(inplace=True)
|
|
|
|
|
|
self.conv_final = nn.Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
|
|
|
|
|
|
self._initialize_weights()
|
|
|
|
def _initialize_weights(self):
|
|
for m in self.modules():
|
|
if isinstance(m, nn.Conv2d):
|
|
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
|
if m.bias is not None:
|
|
nn.init.constant_(m.bias, 0)
|
|
elif isinstance(m, nn.BatchNorm2d):
|
|
nn.init.constant_(m.weight, 1)
|
|
nn.init.constant_(m.bias, 0)
|
|
elif isinstance(m, nn.Linear):
|
|
nn.init.normal_(m.weight, 0, 0.01)
|
|
nn.init.constant_(m.bias, 0)
|
|
|
|
|
|
nn.init.constant_(self.conv_final.weight, 0)
|
|
|
|
def forward(self, x):
|
|
inp = x
|
|
|
|
x = self.conv1(x)
|
|
x = self.bn1(x)
|
|
x = self.relu1(x)
|
|
|
|
|
|
residual = x
|
|
x = self.resblock1(x)
|
|
x = self.resblock2(x)
|
|
x = self.resblock3(x)
|
|
x = self.resblock4(x)
|
|
x = x + residual
|
|
residual = x
|
|
x = self.resblock5(x)
|
|
x = self.resblock6(x)
|
|
x = self.resblock7(x)
|
|
x = self.resblock8(x)
|
|
x = x + residual
|
|
residual = x
|
|
x = self.resblock9(x)
|
|
x = self.resblock10(x)
|
|
x = self.resblock11(x)
|
|
x = self.resblock12(x)
|
|
x = x + residual
|
|
residual = x
|
|
x = self.resblock13(x)
|
|
x = self.resblock14(x)
|
|
x = self.resblock15(x)
|
|
x = self.resblock16(x)
|
|
x = x + residual
|
|
residual = x
|
|
x = self.resblock17(x)
|
|
x = self.resblock18(x)
|
|
x = self.resblock19(x)
|
|
x = self.resblock20(x)
|
|
x = x + residual
|
|
|
|
x = self.conv2(x)
|
|
x = self.bn2(x)
|
|
x = self.relu2(x)
|
|
x = self.conv3(x)
|
|
x = self.bn3(x)
|
|
|
|
|
|
|
|
x = self.conv_final(x)
|
|
|
|
|
|
x = x + inp
|
|
|
|
return x
|
|
|
|
def support_latents(self) -> bool:
|
|
return False
|
|
|
|
def upscale(
|
|
self,
|
|
vae: AutoencoderKL,
|
|
lowreso_images: List[Image.Image],
|
|
lowreso_latents: torch.Tensor,
|
|
dtype: torch.dtype,
|
|
width: int,
|
|
height: int,
|
|
batch_size: int = 1,
|
|
vae_batch_size: int = 1,
|
|
):
|
|
|
|
assert lowreso_images is not None, "Upscaler requires lowreso image"
|
|
|
|
|
|
upsampled_images = []
|
|
for lowreso_image in lowreso_images:
|
|
upsampled_image = np.array(lowreso_image.resize((width, height), Image.LANCZOS))
|
|
upsampled_images.append(upsampled_image)
|
|
|
|
|
|
upsampled_images = [torch.from_numpy(upsampled_image).permute(2, 0, 1).float() for upsampled_image in upsampled_images]
|
|
upsampled_images = torch.stack(upsampled_images, dim=0)
|
|
upsampled_images = upsampled_images.to(dtype)
|
|
|
|
|
|
upsampled_images = upsampled_images / 127.5 - 1.0
|
|
|
|
|
|
|
|
upsampled_latents = []
|
|
for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
|
|
batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
|
|
with torch.no_grad():
|
|
batch = vae.encode(batch).latent_dist.sample()
|
|
upsampled_latents.append(batch)
|
|
|
|
upsampled_latents = torch.cat(upsampled_latents, dim=0)
|
|
|
|
|
|
print("Upscaling latents...")
|
|
upscaled_latents = []
|
|
for i in range(0, upsampled_latents.shape[0], batch_size):
|
|
with torch.no_grad():
|
|
upscaled_latents.append(self.forward(upsampled_latents[i : i + batch_size]))
|
|
upscaled_latents = torch.cat(upscaled_latents, dim=0)
|
|
|
|
return upscaled_latents * 0.18215
|
|
|
|
|
|
|
|
def create_upscaler(**kwargs):
|
|
weights = kwargs["weights"]
|
|
model = Upscaler()
|
|
|
|
print(f"Loading weights from {weights}...")
|
|
if os.path.splitext(weights)[1] == ".safetensors":
|
|
from safetensors.torch import load_file
|
|
|
|
sd = load_file(weights)
|
|
else:
|
|
sd = torch.load(weights, map_location=torch.device("cpu"))
|
|
model.load_state_dict(sd)
|
|
return model
|
|
|
|
|
|
|
|
def upscale_images(args: argparse.Namespace):
|
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
us_dtype = torch.float16
|
|
os.makedirs(args.output_dir, exist_ok=True)
|
|
|
|
|
|
assert args.vae_path is not None, "VAE path is required"
|
|
print(f"Loading VAE from {args.vae_path}...")
|
|
vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
|
|
vae.to(DEVICE, dtype=us_dtype)
|
|
|
|
|
|
print("Preparing model...")
|
|
upscaler: Upscaler = create_upscaler(weights=args.weights)
|
|
|
|
|
|
upscaler.eval()
|
|
upscaler.to(DEVICE, dtype=us_dtype)
|
|
|
|
|
|
image_paths = glob.glob(args.image_pattern)
|
|
images = []
|
|
for image_path in image_paths:
|
|
image = Image.open(image_path)
|
|
image = image.convert("RGB")
|
|
|
|
|
|
width = image.width
|
|
height = image.height
|
|
if width % 8 != 0:
|
|
width = width - (width % 8)
|
|
if height % 8 != 0:
|
|
height = height - (height % 8)
|
|
if width != image.width or height != image.height:
|
|
image = image.crop((0, 0, width, height))
|
|
|
|
images.append(image)
|
|
|
|
|
|
if args.debug:
|
|
for image, image_path in zip(images, image_paths):
|
|
image_debug = image.resize((image.width * 2, image.height * 2), Image.LANCZOS)
|
|
|
|
basename = os.path.basename(image_path)
|
|
basename_wo_ext, ext = os.path.splitext(basename)
|
|
dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_lanczos4{ext}")
|
|
image_debug.save(dest_file_name)
|
|
|
|
|
|
print("Upscaling...")
|
|
upscaled_latents = upscaler.upscale(
|
|
vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size
|
|
)
|
|
upscaled_latents /= 0.18215
|
|
|
|
|
|
print("Decoding...")
|
|
upscaled_images = []
|
|
for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)):
|
|
with torch.no_grad():
|
|
batch = vae.decode(upscaled_latents[i : i + args.vae_batch_size]).sample
|
|
batch = batch.to("cpu")
|
|
upscaled_images.append(batch)
|
|
upscaled_images = torch.cat(upscaled_images, dim=0)
|
|
|
|
|
|
upscaled_images = upscaled_images.permute(0, 2, 3, 1).numpy()
|
|
upscaled_images = (upscaled_images + 1.0) * 127.5
|
|
upscaled_images = upscaled_images.clip(0, 255).astype(np.uint8)
|
|
|
|
upscaled_images = upscaled_images[..., ::-1]
|
|
|
|
|
|
for i, image in enumerate(upscaled_images):
|
|
basename = os.path.basename(image_paths[i])
|
|
basename_wo_ext, ext = os.path.splitext(basename)
|
|
dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_upscaled{ext}")
|
|
cv2.imwrite(dest_file_name, image)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("--vae_path", type=str, default=None, help="VAE path")
|
|
parser.add_argument("--weights", type=str, default=None, help="Weights path")
|
|
parser.add_argument("--image_pattern", type=str, default=None, help="Image pattern")
|
|
parser.add_argument("--output_dir", type=str, default=".", help="Output directory")
|
|
parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
|
|
parser.add_argument("--vae_batch_size", type=int, default=1, help="VAE batch size")
|
|
parser.add_argument("--debug", action="store_true", help="Debug mode")
|
|
|
|
args = parser.parse_args()
|
|
upscale_images(args)
|
|
|