import os
from pathlib import Path
from typing import List

import cv2
import numpy as np
import torch
import tqdm
from omegaconf import OmegaConf
from PIL import Image
from torch.utils.data._utils.collate import default_collate

from internals.util.commons import download_file, download_image
from internals.util.config import get_root_dir
from saicinpainting.evaluation.utils import move_to_device
from saicinpainting.training.data.datasets import make_default_val_dataset
from saicinpainting.training.trainers import load_checkpoint


class ObjectRemoval:
    """Remove masked objects from an image with the LaMa inpainting model.

    Call :meth:`load` once to fetch the checkpoint and place the model on
    CUDA, then call :meth:`process` for each image/mask pair.
    """

    __loaded = False

    def load(self, model_dir):
        """Fetch the LaMa checkpoint and load it onto CUDA.

        Returns immediately if this instance has already been loaded.

        Args:
            model_dir: Unused. The checkpoint is stored under
                ``~/.cache/lama/models``. Kept for interface compatibility.
        """
        if self.__loaded:
            return

        print("Downloading LAMA model...")

        self.lama_path = Path.home() / ".cache" / "lama"
        out_file = self.lama_path / "models" / "best.ckpt"
        out_file.parent.mkdir(parents=True, exist_ok=True)
        download_file(
            "https://huggingface.co/akhaliq/lama/resolve/main/best.ckpt", out_file
        )

        config = OmegaConf.load(get_root_dir() + "/config.yml")
        # Only inference is needed; skip building training-only components.
        config.training_model.predict_only = True

        self.model = load_checkpoint(
            config, str(out_file), strict=False, map_location="cuda"
        )
        self.model.freeze()
        self.model.to("cuda")

        self.__loaded = True

    @torch.no_grad()
    def process(
        self,
        image_url: str,
        mask_image_url: str,
        seed: int,
        width: int,
        height: int,
    ) -> List:
        """Inpaint the masked region of an image.

        Requires :meth:`load` to have been called first (it sets
        ``self.lama_path`` and ``self.model``).

        Args:
            image_url: URL of the source image.
            mask_image_url: URL of the mask image. Any pixel > 0 after
                loading is treated as the region to inpaint.
            seed: Torch RNG seed, for reproducibility.
            width: Width both images are resized to before inference.
            height: Height both images are resized to before inference.

        Returns:
            A list of RGB ``PIL.Image`` results, one per dataset item. The
            dataset here contains a single image/mask pair.
        """
        torch.manual_seed(seed)

        indir = self.lama_path / "images" / "input"
        indir.mkdir(parents=True, exist_ok=True)

        # NOTE(review): the LaMa val dataset presumably pairs "data.png" with
        # "data_mask.png" by name; keep these file names in sync.
        download_image(image_url).resize((width, height)).save(indir / "data.png")
        download_image(mask_image_url).resize((width, height)).save(
            indir / "data_mask.png"
        )

        dataset = make_default_val_dataset(
            indir, img_suffix=".png", pad_out_to_modulo=8
        )

        out_images = []
        for img_i in tqdm.trange(len(dataset)):
            batch = move_to_device(default_collate([dataset[img_i]]), "cuda")
            # Binarise the mask: any non-zero pixel is part of the hole.
            batch["mask"] = (batch["mask"] > 0) * 1
            batch = self.model(batch)

            cur_res = batch["inpainted"][0].permute(1, 2, 0).detach().cpu().numpy()

            # The dataset pads inputs up to a multiple of 8. When it records the
            # pre-padding size, crop the result back so the output matches the
            # requested size, as LaMa's own predict script does.
            unpad_to_size = batch.get("unpad_to_size")
            if unpad_to_size is not None:
                orig_height, orig_width = (int(v) for v in unpad_to_size)
                cur_res = cur_res[:orig_height, :orig_width]

            cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
            # cur_res is already an RGB uint8 array. Build the PIL image
            # directly instead of round-tripping through a BGR PNG on disk.
            image = Image.fromarray(np.ascontiguousarray(cur_res)).convert("RGB")
            out_images.append(image)

        return out_images