import io | |
from typing import Union | |
import torch | |
import torch.nn.functional as F | |
from PIL import Image | |
from rembg import remove | |
from carvekit.api.high import HiInterface | |
from internals.util.commons import read_url | |
class RemoveBackground:
    """Background remover backed by ``rembg``."""

    def remove(self, image: Union[str, Image.Image]) -> Image.Image:
        """Return *image* with its background removed using ``rembg``.

        Args:
            image: A PIL image, or a string that is fetched with ``read_url``
                and decoded with ``PIL.Image.open``.

        Returns:
            The result of ``rembg.remove`` applied to the (decoded) image.
        """
        if isinstance(image, str):
            # Presumably read_url returns the raw encoded image bytes for a URL
            # — TODO confirm against internals.util.commons.
            image = Image.open(io.BytesIO(read_url(image)))
        # The bare name ``remove`` is the module-level ``rembg.remove`` import,
        # not this method: method names live on the class and are not visible
        # as bare names inside the method body.
        return remove(image)
class RemoveBackgroundV2:
    """Background remover backed by carvekit's ``HiInterface``."""

    def __init__(self):
        # Build the carvekit pipeline once per instance; it is reused by
        # every call to ``remove``. Runs on CUDA when torch reports it is
        # available, otherwise on CPU.
        self.interface = HiInterface(
            object_type="object",  # Can be "object" or "hairs-like".
            batch_size_seg=5,
            batch_size_matting=1,
            device="cuda" if torch.cuda.is_available() else "cpu",
            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_dilation=30,
            trimap_erosion_iters=5,
            fp16=False,
        )

    def remove(self, image: Union[str, Image.Image]) -> Image.Image:
        """Return *image* with its background removed using carvekit.

        Args:
            image: A PIL image, or a string that is fetched with ``read_url``
                and decoded with ``PIL.Image.open``.

        Returns:
            The first (and only) image produced by ``self.interface`` for the
            single-element input batch.
        """
        if isinstance(image, str):
            # Presumably read_url returns the raw encoded image bytes for a URL
            # — TODO confirm against internals.util.commons.
            image = Image.open(io.BytesIO(read_url(image)))
        # Hand the PIL image straight to carvekit instead of round-tripping it
        # through a fixed "rm_bg.png" in the working directory. The shared
        # file raced between concurrent calls, was left on disk, required a
        # writable CWD, and failed for modes PNG cannot encode (e.g. CMYK).
        results = self.interface([image])
        return results[0]