File size: 2,654 Bytes
19b3da3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import os
from pathlib import Path
from typing import List
import cv2
import numpy as np
import torch
import tqdm
from omegaconf import OmegaConf
from PIL import Image
from torch.utils.data._utils.collate import default_collate
from internals.util.commons import download_file, download_image
from internals.util.config import get_root_dir
from saicinpainting.evaluation.utils import move_to_device
from saicinpainting.training.data.datasets import make_default_val_dataset
from saicinpainting.training.trainers import load_checkpoint
class ObjectRemoval:
    """Remove masked objects from an image with the LaMa inpainting model.

    Call :meth:`load` once to fetch the checkpoint and move the model to
    CUDA, then call :meth:`process` for each (image, mask) pair.
    """

    def load(self, model_dir):
        """Download the LaMa checkpoint and load it onto the GPU.

        Args:
            model_dir: Accepted for interface compatibility but not used;
                the checkpoint is always stored at
                ``~/.cache/lama/models/best.ckpt``.
        """
        print("Downloading LAMA model...")
        self.lama_path = Path.home() / ".cache" / "lama"
        out_file = self.lama_path / "models" / "best.ckpt"
        out_file.parent.mkdir(parents=True, exist_ok=True)
        download_file(
            "https://huggingface.co/akhaliq/lama/resolve/main/best.ckpt", out_file
        )

        config = OmegaConf.load(get_root_dir() + "/config.yml")
        # Mark the model config as inference-only before building it.
        config.training_model.predict_only = True
        self.model = load_checkpoint(
            config, str(out_file), strict=False, map_location="cuda"
        )
        self.model.freeze()
        self.model.to("cuda")

    @staticmethod
    def _to_image(batch) -> "Image.Image":
        """Convert the first ``"inpainted"`` result in *batch* to a PIL image.

        The dataset pads its inputs up to a multiple of 8. When it reports
        the pre-padding size via ``unpad_to_size``, the padding is cropped
        off so the result matches the size of the input image.
        """
        # CHW float tensor -> HWC numpy array.
        result = batch["inpainted"][0].permute(1, 2, 0).detach().cpu().numpy()

        unpad_to_size = batch.get("unpad_to_size", None)
        if unpad_to_size is not None:
            # After collation each entry is a 1-element tensor; int() unwraps it.
            orig_height, orig_width = (int(v) for v in unpad_to_size)
            result = result[:orig_height, :orig_width]

        result = np.clip(result * 255, 0, 255).astype("uint8")
        # permute()/slicing leave a strided view; give PIL a contiguous buffer.
        return Image.fromarray(np.ascontiguousarray(result))

    @torch.no_grad()
    def process(
        self,
        image_url: str,
        mask_image_url: str,
        seed: int,
        width: int,
        height: int,
    ) -> List:
        """Inpaint the masked region of an image.

        :meth:`load` must have been called first; it sets ``self.lama_path``
        and ``self.model``.

        Args:
            image_url: URL of the source image.
            mask_image_url: URL of the mask image; pixels greater than zero
                mark the region to fill.
            seed: Passed to ``torch.manual_seed`` before inference.
            width: Width that both image and mask are resized to.
            height: Height that both image and mask are resized to.

        Returns:
            A list of RGB ``PIL.Image.Image`` results, one per dataset entry.
            Padding added by the dataset is cropped off when the dataset
            reports it.
        """
        torch.manual_seed(seed)
        img_folder = self.lama_path / "images"
        indir = img_folder / "input"
        indir.mkdir(parents=True, exist_ok=True)

        # The LaMa dataset finds each image through its "*_mask.png" partner.
        # NOTE(review): the file names are fixed, so concurrent calls sharing
        # this cache directory would overwrite each other's inputs.
        download_image(image_url).resize((width, height)).save(indir / "data.png")
        download_image(mask_image_url).resize((width, height)).save(
            indir / "data_mask.png"
        )

        dataset = make_default_val_dataset(
            indir, img_suffix=".png", pad_out_to_modulo=8
        )

        out_images = []
        for img_i in tqdm.trange(len(dataset)):
            batch = move_to_device(default_collate([dataset[img_i]]), "cuda")
            # Binarize: any non-zero mask value counts as "to be inpainted".
            batch["mask"] = (batch["mask"] > 0) * 1
            batch = self.model(batch)
            out_images.append(self._to_image(batch))
        return out_images
|