File size: 1,321 Bytes
19b3da3
 
 
a3d6c18
 
19b3da3
 
 
a3d6c18
19b3da3
 
 
 
 
 
 
 
 
 
a3d6c18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import io
from typing import Union

import torch
import torch.nn.functional as F
from PIL import Image
from rembg import remove

from carvekit.api.high import HiInterface
from internals.util.commons import read_url


class RemoveBackground:
    """Remove the background of an image using ``rembg.remove``."""

    def remove(self, image: Union[str, Image.Image]) -> Image.Image:
        """Return *image* with its background removed by ``rembg``.

        Args:
            image: A PIL image, or a URL string. A string is fetched with
                ``read_url`` (whose result is wrapped in ``io.BytesIO``, so it
                is expected to be raw image bytes) and decoded with
                ``PIL.Image.open``.

        Returns:
            Whatever ``rembg.remove`` produces for the input image.
        """
        # isinstance (rather than ``type(...) is str``) also accepts str subclasses.
        if isinstance(image, str):
            image = Image.open(io.BytesIO(read_url(image)))

        return remove(image)


class RemoveBackgroundV2:
    """Remove image backgrounds with carvekit's high-level ``HiInterface``."""

    def __init__(self):
        # Build the carvekit pipeline once; it runs on CUDA when torch
        # reports a GPU, otherwise on CPU.
        self.interface = HiInterface(
            object_type="object",  # Can be "object" or "hairs-like".
            batch_size_seg=5,
            batch_size_matting=1,
            device="cuda" if torch.cuda.is_available() else "cpu",
            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_dilation=30,
            trimap_erosion_iters=5,
            fp16=False,
        )

    def remove(self, image: Union[str, Image.Image]) -> Image.Image:
        """Return *image* with its background removed by carvekit.

        Args:
            image: A PIL image, or a URL string. A string is fetched with
                ``read_url`` (whose result is wrapped in ``io.BytesIO``, so it
                is expected to be raw image bytes) and decoded with
                ``PIL.Image.open``.

        Returns:
            The first (and only) image returned by ``HiInterface`` for the
            single-element input batch.
        """
        # isinstance (rather than ``type(...) is str``) also accepts str subclasses.
        if isinstance(image, str):
            image = Image.open(io.BytesIO(read_url(image)))

        # HiInterface accepts PIL images as well as paths, so hand the image
        # over in memory. Writing a scratch file with a fixed name in the
        # working directory would race between concurrent calls, leave a stray
        # file behind, and fail for modes PNG cannot store (e.g. CMYK).
        images_without_background = self.interface([image])
        return images_without_background[0]