CM2000112 / internals /pipelines /object_remove.py
jayparmr's picture
Upload folder using huggingface_hub
99a0484
raw
history blame
2.75 kB
import os
from pathlib import Path
from typing import List
import cv2
import numpy as np
import torch
import tqdm
from omegaconf import OmegaConf
from PIL import Image
from torch.utils.data._utils.collate import default_collate
from internals.util.commons import download_file, download_image
from internals.util.config import get_root_dir
from saicinpainting.evaluation.utils import move_to_device
from saicinpainting.training.data.datasets import make_default_val_dataset
from saicinpainting.training.trainers import load_checkpoint
class ObjectRemoval:
    """Object removal (inpainting) pipeline built on the LaMa model.

    Call :meth:`load` once to fetch the checkpoint and move the model to CUDA,
    then call :meth:`process` to inpaint the masked region of an image.
    """

    # Set to True by load(); the double underscore name-mangles this to
    # ``_ObjectRemoval__loaded``. The class-level False is the default for
    # instances that have not been loaded yet.
    __loaded = False

    def load(self, model_dir):
        """Fetch the LaMa checkpoint and load it onto CUDA for inference.

        Once this has succeeded on an instance, further calls return
        immediately.

        Args:
            model_dir: Currently unused; the checkpoint is always stored under
                ``~/.cache/lama/models/best.ckpt``.
        """
        if self.__loaded:
            return
        print("Downloading LAMA model...")
        self.lama_path = Path.home() / ".cache" / "lama"
        out_file = self.lama_path / "models" / "best.ckpt"
        out_file.parent.mkdir(parents=True, exist_ok=True)
        # NOTE(review): whether download_file skips an existing file is
        # decided inside that helper — confirm if startup time matters.
        download_file(
            "https://huggingface.co/akhaliq/lama/resolve/main/best.ckpt", out_file
        )

        config = OmegaConf.load(os.path.join(get_root_dir(), "config.yml"))
        # Only inference is performed by this pipeline.
        config.training_model.predict_only = True
        self.model = load_checkpoint(
            config, str(out_file), strict=False, map_location="cuda"
        )
        self.model.freeze()
        self.model.to("cuda")
        self.__loaded = True

    @torch.no_grad()
    def process(
        self,
        image_url: str,
        mask_image_url: str,
        seed: int,
        width: int,
        height: int,
    ) -> List:
        """Inpaint the masked region of an image.

        Args:
            image_url: URL of the source image.
            mask_image_url: URL of the mask image; after loading, every mask
                value > 0 is treated as a region to inpaint.
            seed: Value passed to ``torch.manual_seed`` before inference.
            width: Width both images are resized to, and width of the result.
            height: Height both images are resized to, and height of the result.

        Returns:
            A list with one RGB ``PIL.Image`` of size ``(width, height)`` per
            entry of the dataset built from the input folder.

        Raises:
            RuntimeError: If :meth:`load` has not been called on this instance.
        """
        if not self.__loaded:
            raise RuntimeError("ObjectRemoval.load() must be called before process()")
        torch.manual_seed(seed)

        img_folder = self.lama_path / "images"
        indir = img_folder / "input"
        indir.mkdir(parents=True, exist_ok=True)

        # Both inputs are resized to exactly (width, height). The file names
        # "data.png" / "data_mask.png" are presumably how the saicinpainting
        # dataset pairs an image with its mask — TODO confirm.
        download_image(image_url).resize((width, height)).save(indir / "data.png")
        download_image(mask_image_url).resize((width, height)).save(
            indir / "data_mask.png"
        )

        # pad_out_to_modulo=8: the dataset pads inputs so both spatial sizes
        # are multiples of 8, which is why the output is cropped below.
        dataset = make_default_val_dataset(
            indir, img_suffix=".png", pad_out_to_modulo=8
        )

        out_images = []
        for img_i in tqdm.trange(len(dataset)):
            batch = move_to_device(default_collate([dataset[img_i]]), "cuda")
            # Binarize the mask: any positive value becomes 1.
            batch["mask"] = (batch["mask"] > 0) * 1
            batch = self.model(batch)

            # (C, H_pad, W_pad) -> (H_pad, W_pad, C) float array.
            cur_res = batch["inpainted"][0].permute(1, 2, 0).detach().cpu().numpy()
            # Drop the modulo-8 padding so the result matches the requested
            # size; a no-op when width/height are already multiples of 8.
            cur_res = cur_res[:height, :width]
            cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
            # The array is already RGB uint8, so build the PIL image directly.
            # The previous cv2 write / PIL re-read through a shared temp file
            # was lossless (PNG), so the pixels are identical.
            out_images.append(Image.fromarray(cur_res))
        return out_images