File size: 3,336 Bytes
19b3da3
cd51d32
19b3da3
10230ea
 
19b3da3
a3d6c18
 
19b3da3
 
10230ea
19b3da3
42ef134
a3d6c18
f1235a4
10230ea
 
19b3da3
 
 
 
 
 
 
 
 
a3d6c18
 
 
 
10230ea
 
 
 
 
a3d6c18
 
 
 
 
 
 
 
 
 
 
 
 
10230ea
 
 
a3d6c18
f1235a4
a3d6c18
10230ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42ef134
10230ea
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import io
from pathlib import Path
from typing import Union
import numpy as np
import cv2

import torch
import torch.nn.functional as F
from PIL import Image
from rembg import remove
from internals.data.task import ModelType

import internals.util.image as ImageUtil
from carvekit.api.high import HiInterface
from internals.util.commons import download_image, read_url
import onnxruntime as rt
import huggingface_hub


class RemoveBackground:
    """Background remover backed by the ``rembg`` package's ``remove`` function."""

    def remove(self, image: Union[str, Image.Image]) -> Image.Image:
        """Strip the background from ``image``.

        Args:
            image: Either a PIL image, or a string that is handed to
                ``read_url`` to fetch the image content (presumably a URL
                returning raw image bytes -- confirm against ``read_url``).

        Returns:
            Whatever ``rembg.remove`` produces for the (loaded) PIL image.
        """
        # isinstance (rather than an exact ``type(...) is str`` check) also
        # accepts str subclasses, which would otherwise be passed to rembg
        # as if they were images.
        if isinstance(image, str):
            image = Image.open(io.BytesIO(read_url(image)))

        # NOTE: inside this method the bare name ``remove`` resolves to the
        # module-level rembg import, not to this method.
        return remove(image)


class RemoveBackgroundV2:
    """Background remover that selects a model based on the image style.

    Anime/comic images are segmented with the ``skytnt/anime-seg`` ONNX model;
    every other model type goes through carvekit's ``HiInterface``.
    """

    def __init__(self):
        """Load the anime segmentation session and build the carvekit interface."""
        # Fetch ``isnetis.onnx`` from the ``skytnt/anime-seg`` Hub repository
        # and get its local path.
        model_path = huggingface_hub.hf_hub_download("skytnt/anime-seg", "isnetis.onnx")
        # Providers are listed in priority order: CUDA first, CPU as fallback.
        self.anime_rembg = rt.InferenceSession(
            model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
        )

        self.interface = HiInterface(
            object_type="object",  # Can be "object" or "hairs-like".
            batch_size_seg=5,
            batch_size_matting=1,
            device="cuda" if torch.cuda.is_available() else "cpu",
            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_dilation=30,
            trimap_erosion_iters=5,
            fp16=False,
        )

    def remove(
        self, image: Union[str, Image.Image], model_type: ModelType = ModelType.REAL
    ) -> Image.Image:
        """Remove the background from ``image``.

        Args:
            image: A PIL image, or a string passed to ``download_image`` to
                obtain one (presumably a URL -- confirm against the helper).
            model_type: ``ModelType.ANIME`` / ``ModelType.COMIC`` select the
                ONNX anime segmenter; any other value selects carvekit.

        Returns:
            For anime/comic: a 4-channel PIL image built from the array
            returned by ``__rmbg_fn``. Otherwise: the first element of the
            carvekit interface's output.
        """
        if isinstance(image, str):
            image = download_image(image)

        if model_type in (ModelType.ANIME, ModelType.COMIC):
            print("Using Anime Background remover")
            # The ONNX pre-processing writes into a fixed (s, s, 3) buffer, so
            # RGBA / L / P inputs would otherwise fail with a shape error.
            if image.mode != "RGB":
                image = image.convert("RGB")
            _, img = self.__rmbg_fn(np.array(image))

            return Image.fromarray(img)

        print("Using Real Background remover")
        img_path = Path.home() / ".cache" / "rm_bg.png"
        # ``~/.cache`` is not guaranteed to exist (fresh container / user);
        # without this the save below raises FileNotFoundError.
        img_path.parent.mkdir(parents=True, exist_ok=True)

        w, h = image.size
        if max(w, h) > 1536:
            # Large inputs are shrunk via the project helper before matting.
            image = ImageUtil.resize_image(image, dimension=1024)

        # NOTE(review): the path is shared by every call, so concurrent calls
        # in one process/user would overwrite each other's file -- confirm
        # callers are serialized.
        image.save(img_path)
        images_without_background = self.interface([img_path])
        return images_without_background[0]

    def __get_mask(self, img, s=1024):
        """Run the anime segmenter and return a mask at the input resolution.

        Args:
            img: Image array indexed as (height, width, 3); values are divided
                by 255 before inference (assumes a 0-255 range -- TODO confirm
                with callers other than ``remove``).
            s: Side length of the square, zero-padded model input.

        Returns:
            Array of shape (height, width, 1) holding the model's mask resized
            back to the original resolution.
        """
        img = (img / 255).astype(np.float32)
        h0, w0 = img.shape[:2]
        # Scale the longer side to ``s`` keeping the aspect ratio. The shorter
        # side is clamped to >= 1 so extreme aspect ratios cannot truncate to a
        # zero-sized resize target.
        if h0 > w0:
            h, w = s, max(1, int(s * w0 / h0))
        else:
            h, w = max(1, int(s * h0 / w0)), s
        # Offsets that centre the resized image inside the s x s canvas.
        top, left = (s - h) // 2, (s - w) // 2

        img_input = np.zeros([s, s, 3], dtype=np.float32)
        img_input[top : top + h, left : left + w] = cv2.resize(img, (w, h))
        # HWC -> CHW, then add a leading batch axis.
        img_input = np.transpose(img_input, (2, 0, 1))
        img_input = img_input[np.newaxis, :]

        # First output, first batch element.
        mask = self.anime_rembg.run(None, {"img": img_input})[0][0]
        # CHW -> HWC, then crop away the padding added above.
        mask = np.transpose(mask, (1, 2, 0))
        mask = mask[top : top + h, left : left + w]
        # Resize back to the source size and restore the trailing channel axis.
        mask = cv2.resize(mask, (w0, h0))[:, :, np.newaxis]
        return mask

    def __rmbg_fn(self, img):
        """Blend ``img`` onto white using the predicted mask and append the mask.

        Args:
            img: Image array indexed as (height, width, 3).

        Returns:
            Tuple ``(mask, img)``: ``mask`` is the uint8 mask repeated to three
            channels; ``img`` is the uint8 white-blended image with the mask
            concatenated as a fourth channel.
        """
        mask = self.__get_mask(img)
        # Where mask -> 0 the pixel is pushed towards 255 (white).
        img = (mask * img + 255 * (1 - mask)).astype(np.uint8)
        mask = (mask * 255).astype(np.uint8)
        img = np.concatenate([img, mask], axis=2, dtype=np.uint8)
        mask = mask.repeat(3, axis=2)
        return mask, img