File size: 2,760 Bytes
a3d6c18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
from pathlib import Path
from typing import Union, List, Optional

from PIL import Image

from carvekit.ml.wrap.basnet import BASNET
from carvekit.ml.wrap.deeplab_v3 import DeepLabV3
from carvekit.ml.wrap.u2net import U2NET
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
from carvekit.pipelines.preprocessing import PreprocessingStub
from carvekit.pipelines.postprocessing import MattingMethod
from carvekit.utils.image_utils import load_image
from carvekit.utils.mask_utils import apply_mask
from carvekit.utils.pool_utils import thread_pool_processing


class Interface:
    """
    High-level entry point that chains the CarveKit processing stages.

    Calling an instance loads the input images and obtains masks for them, either
    from the pre-processing pipeline or directly from the segmentation network.
    It then combines the images with their masks, either via the post-processing
    pipeline or via ``apply_mask``.
    """

    def __init__(
        self,
        seg_pipe: Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7],
        pre_pipe: Optional[Union[PreprocessingStub]] = None,
        post_pipe: Optional[Union[MattingMethod]] = None,
        device="cpu",
    ):
        """
        Initializes an object for interacting with pipelines and other components of the CarveKit framework.

        Args:
            pre_pipe: Initialized pre-processing pipeline object. When set, it is
                called with ``interface=self`` instead of calling ``seg_pipe``
                directly, and its result is used as the list of masks.
            seg_pipe: Initialized segmentation network object. Called directly
                to produce the masks when ``pre_pipe`` is None.
            post_pipe: Initialized postprocessing pipeline object. When set, it
                receives the loaded images and the masks, and its result is
                returned as the final output.
            device: The processing device passed to ``apply_mask`` when the masks
                are applied to the images. Only used when ``post_pipe`` is None.
        """
        self.device = device
        self.preprocessing_pipeline = pre_pipe
        self.segmentation_pipeline = seg_pipe
        self.postprocessing_pipeline = post_pipe

    def __call__(
        self, images: List[Union[str, Path, Image.Image]]
    ) -> List[Image.Image]:
        """
        Removes the background from the specified images.

        Args:
            images: list of input images (paths or PIL images); each one is
                loaded with ``load_image``.

        Returns:
            List of images without background as PIL.Image.Image instances, in
            the same order as ``images``. An empty input yields an empty list.

        Raises:
            ValueError: if no post-processing pipeline is configured and the
                number of masks produced differs from the number of images.
        """
        if not images:
            # Nothing to process: skip the pipelines rather than feeding them an
            # empty batch.
            return []

        # Separate name from the parameter: after loading, every entry is a PIL
        # image, whereas the input may also contain paths.
        loaded: List[Image.Image] = thread_pool_processing(load_image, images)

        # Declared once so both branches assign the same annotated name.
        masks: List[Image.Image]
        if self.preprocessing_pipeline is not None:
            # The pre-processing pipeline is handed this interface and must
            # produce the masks itself (presumably by calling back into
            # self.segmentation_pipeline — confirm against PreprocessingStub).
            masks = self.preprocessing_pipeline(interface=self, images=loaded)
        else:
            masks = self.segmentation_pipeline(images=loaded)

        if self.postprocessing_pipeline is not None:
            return self.postprocessing_pipeline(images=loaded, masks=masks)

        # Masks are paired with images by position. Without this check a
        # mismatch surfaces as an opaque IndexError (too few masks) or silently
        # drops masks (too many).
        if len(masks) != len(loaded):
            raise ValueError(
                f"Expected {len(loaded)} masks (one per image), got {len(masks)}"
            )
        return [
            apply_mask(image=image, mask=mask, device=self.device)
            for image, mask in zip(loaded, masks)
        ]