# CM2000112 / internals/pipelines/object_remove.py
# Uploaded by jayparmr ("Upload 118 files"), commit 19b3da3, 2.65 kB.
import os
from pathlib import Path
from typing import List
import cv2
import numpy as np
import torch
import tqdm
from omegaconf import OmegaConf
from PIL import Image
from torch.utils.data._utils.collate import default_collate
from internals.util.commons import download_file, download_image
from internals.util.config import get_root_dir
from saicinpainting.evaluation.utils import move_to_device
from saicinpainting.training.data.datasets import make_default_val_dataset
from saicinpainting.training.trainers import load_checkpoint
class ObjectRemoval:
    """Object removal via LaMa inpainting: erase the masked region of an image."""

    def load(self, model_dir):
        """Fetch the LaMa checkpoint and load it onto the GPU for inference.

        Args:
            model_dir: Currently unused; the checkpoint is stored under
                ``~/.cache/lama/models`` instead.
        """
        print("Downloading LAMA model...")
        self.lama_path = Path.home() / ".cache" / "lama"
        out_file = self.lama_path / "models" / "best.ckpt"
        os.makedirs(os.path.dirname(out_file), exist_ok=True)
        # NOTE(review): download_file is invoked on every load(); whether it
        # skips an already-present checkpoint is not visible here — confirm.
        download_file(
            "https://huggingface.co/akhaliq/lama/resolve/main/best.ckpt", out_file
        )

        config = OmegaConf.load(get_root_dir() + "/config.yml")
        config.training_model.predict_only = True
        self.model = load_checkpoint(
            config, str(out_file), strict=False, map_location="cuda"
        )
        self.model.freeze()
        self.model.to("cuda")

    @staticmethod
    def _inpainted_to_array(inpainted, unpad_to_size=None) -> np.ndarray:
        """Convert one inpainted image tensor into a uint8 (H, W, C) RGB array.

        Args:
            inpainted: Channels-first (C, H, W) tensor. Values are assumed to
                lie in [0, 1]; they are scaled by 255 and clipped.
            unpad_to_size: Optional (height, width) pair, where each entry is an
                int or a one-element tensor (as produced by ``default_collate``).
                When given, the array is cropped to this size to undo padding.

        Returns:
            A C-contiguous uint8 array of shape (H, W, C).
        """
        arr = inpainted.permute(1, 2, 0).detach().cpu().numpy()
        if unpad_to_size is not None:
            # int() accepts both plain ints and one-element tensors.
            orig_height, orig_width = (int(v) for v in unpad_to_size)
            arr = arr[:orig_height, :orig_width]
        arr = np.clip(arr * 255, 0, 255).astype("uint8")
        # Cropping yields a strided view; make it contiguous for PIL.
        return np.ascontiguousarray(arr)

    @torch.no_grad()
    def process(
        self,
        image_url: str,
        mask_image_url: str,
        seed: int,
        width: int,
        height: int,
    ) -> List:
        """Remove the masked object from an image with the loaded LaMa model.

        ``load`` must have been called first: it sets ``self.lama_path`` and
        ``self.model``.

        Args:
            image_url: URL of the source image.
            mask_image_url: URL of the mask image. Any non-zero mask pixel is
                inpainted (see the ``> 0`` threshold below).
            seed: Seed passed to ``torch.manual_seed`` before inference.
            width: Width the image and mask are resized to.
            height: Height the image and mask are resized to.

        Returns:
            A list of RGB ``PIL.Image`` results, one per dataset item (a single
            image/mask pair is written, so normally one entry).
        """
        torch.manual_seed(seed)

        img_folder = self.lama_path / "images"
        indir = img_folder / "input"
        indir.mkdir(parents=True, exist_ok=True)

        # NOTE(review): "data.png" / "data_mask.png" presumably follow the
        # image/mask naming pairing make_default_val_dataset expects — confirm.
        download_image(image_url).resize((width, height)).save(indir / "data.png")
        download_image(mask_image_url).resize((width, height)).save(
            indir / "data_mask.png"
        )

        dataset = make_default_val_dataset(
            indir, img_suffix=".png", pad_out_to_modulo=8
        )

        out_images = []
        for img_i in tqdm.trange(len(dataset)):
            batch = move_to_device(default_collate([dataset[img_i]]), "cuda")
            # Binarize the mask: every strictly positive pixel is inpainted.
            batch["mask"] = (batch["mask"] > 0) * 1
            batch = self.model(batch)
            # Inputs are padded to a multiple of 8 (pad_out_to_modulo=8). The
            # LaMa dataset records the pre-padding size under "unpad_to_size";
            # crop back to it so the output matches the requested size.
            result = self._inpainted_to_array(
                batch["inpainted"][0], batch.get("unpad_to_size")
            )
            # The array is already RGB uint8, so build the PIL image directly
            # instead of round-tripping through a temporary PNG on disk.
            out_images.append(Image.fromarray(result))
        return out_images