|
from dataclasses import dataclass |
|
from pathlib import Path |
|
from typing import Any |
|
|
|
import torch |
|
from PIL import Image |
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import ( |
|
MultiUpscaler, |
|
UpscalerCheckpoints, |
|
) |
|
|
|
from esrgan_model import UpscalerESRGAN |
|
|
|
|
|
@dataclass(kw_only=True)
class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
    """Checkpoint locations for ``ESRGANUpscaler``.

    Extends the parent ``UpscalerCheckpoints`` with one extra field. Like all
    fields of this dataclass, it is keyword-only. It has no default, so it is
    required.
    """

    # Path handed to ``UpscalerESRGAN`` when ``ESRGANUpscaler`` is constructed.
    esrgan: Path
|
|
|
|
|
class ESRGANUpscaler(MultiUpscaler):
    """``MultiUpscaler`` whose pre-upscale step first runs an ESRGAN model.

    ESRGAN performs a fixed-ratio enlargement (``ESRGAN_SCALE``). The parent
    class's ``pre_upscale`` is then called with the remaining share of the
    requested ``upscale_factor``.
    """

    # Enlargement ratio assumed for ``UpscalerESRGAN.upscale_with_tiling``.
    # NOTE(review): 4 is the value the original code divided by; confirm it
    # matches the loaded ESRGAN checkpoint. Override in a subclass if a model
    # with a different ratio is used.
    ESRGAN_SCALE = 4

    def __init__(
        self,
        checkpoints: ESRGANUpscalerCheckpoints,
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        """Build the parent pipeline, then load the ESRGAN model.

        Args:
            checkpoints: Parent upscaler checkpoints plus the ``esrgan`` path.
            device: Device the models are placed on.
            dtype: Dtype the models are placed in.
        """
        super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
        # Uses the device/dtype as stored on ``self`` after the parent's __init__.
        self.esrgan = UpscalerESRGAN(checkpoints.esrgan, device=self.device, dtype=self.dtype)
        # NOTE(review): possibly redundant given the constructor arguments above.
        # Kept because UpscalerESRGAN's constructor is not visible from here.
        self.esrgan.to(device=device, dtype=dtype)

    def to(self, device: torch.device, dtype: torch.dtype) -> "ESRGANUpscaler":
        """Move the ESRGAN and Stable Diffusion models to ``device``/``dtype``.

        Also records the new ``device`` and ``dtype`` on the instance.

        Args:
            device: Target device.
            dtype: Target dtype.

        Returns:
            ``self``, so ``upscaler = upscaler.to(...)`` works the same way it
            does for ``torch.nn.Module.to``.
        """
        self.esrgan.to(device=device, dtype=dtype)
        # Rebind to whatever ``self.sd.to`` returns.
        # NOTE(review): assumes any other parent-held components (e.g. an
        # injected controlnet or LoRAs) live inside ``self.sd``; verify.
        self.sd = self.sd.to(device=device, dtype=dtype)
        self.device = device
        self.dtype = dtype
        return self

    def pre_upscale(self, image: Image.Image, upscale_factor: float, **_: Any) -> Image.Image:
        """Enlarge ``image`` with ESRGAN, then delegate to the parent's ``pre_upscale``.

        Args:
            image: Input image.
            upscale_factor: Overall enlargement requested by the caller.
            **_: Extra keyword arguments; accepted and ignored.

        Returns:
            The image produced by the parent's ``pre_upscale``.
        """
        image = self.esrgan.upscale_with_tiling(image)
        # ESRGAN has already enlarged by ESRGAN_SCALE, so pass only the remainder.
        return super().pre_upscale(image=image, upscale_factor=upscale_factor / self.ESRGAN_SCALE)
|
|