File size: 2,760 Bytes
a3d6c18 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
from pathlib import Path
from typing import Union, List, Optional
from PIL import Image
from carvekit.ml.wrap.basnet import BASNET
from carvekit.ml.wrap.deeplab_v3 import DeepLabV3
from carvekit.ml.wrap.u2net import U2NET
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
from carvekit.pipelines.preprocessing import PreprocessingStub
from carvekit.pipelines.postprocessing import MattingMethod
from carvekit.utils.image_utils import load_image
from carvekit.utils.mask_utils import apply_mask
from carvekit.utils.pool_utils import thread_pool_processing
class Interface:
    """High-level entry point that chains CarveKit pipelines to remove image backgrounds.

    The flow is:
      1. load the inputs;
      2. produce masks, via the pre-processing pipeline if one was given,
         otherwise via the segmentation network directly;
      3. combine images and masks, via the post-processing pipeline if one
         was given, otherwise via ``apply_mask``.
    """

    def __init__(
        self,
        seg_pipe: Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7],
        pre_pipe: Optional[Union[PreprocessingStub]] = None,
        post_pipe: Optional[Union[MattingMethod]] = None,
        device: str = "cpu",
    ):
        """
        Initializes an object for interacting with pipelines and other components of the CarveKit framework.

        Args:
            seg_pipe: Initialized segmentation network object.
            pre_pipe: Initialized pre-processing pipeline object. When None,
                the segmentation network is called directly.
            post_pipe: Initialized post-processing pipeline object. When None,
                masks are applied to the images with ``apply_mask``.
            device: The processing device that will be used to apply the masks to the images.
        """
        self.device = device
        self.preprocessing_pipeline = pre_pipe
        self.segmentation_pipeline = seg_pipe
        self.postprocessing_pipeline = post_pipe

    def __call__(
        self, images: List[Union[str, Path, Image.Image]]
    ) -> List[Image.Image]:
        """
        Removes the background from the specified images.

        Args:
            images: list of input images (file paths or PIL images)

        Returns:
            List of images without background as PIL.Image.Image instances
        """
        # Load every input (path or PIL image) concurrently.
        loaded = thread_pool_processing(load_image, images)

        masks: List[Image.Image]
        if self.preprocessing_pipeline is not None:
            # The pre-processing pipeline is handed the interface itself
            # (presumably so it can drive the segmentation pipeline); its
            # return value is treated as the masks.
            masks = self.preprocessing_pipeline(interface=self, images=loaded)
        else:
            masks = self.segmentation_pipeline(images=loaded)

        if self.postprocessing_pipeline is not None:
            return self.postprocessing_pipeline(images=loaded, masks=masks)

        # NOTE: assumes the pipeline returns exactly one mask per image, in
        # input order; zip pairs them positionally.
        return [
            apply_mask(image=image, mask=mask, device=self.device)
            for image, mask in zip(loaded, masks)
        ]
|