import os
import sys
import importlib.util

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from PIL import Image
from huggingface_hub import hf_hub_download

import folder_paths
import comfy.utils


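# REBNCONV: a Conv2d -> BatchNorm2d -> ReLU block; `dirate` sets the dilation
# (and matching padding) of the 3x3 convolution.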
class REBNCONV(nn.Module):
    def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
        super(REBNCONV, self).__init__()

        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
        return xout


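# Bilinearly resize `src` to the spatial size of `tar`.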
def _upsample_like(src, tar):
    src = F.interpolate(src, size=tar.shape[2:], mode='bilinear')
    return src


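# RSU-7: seven-level residual U-block (an encoder/decoder with a residual
# connection from the block input), as used in U^2-Net-style architectures.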
class RSU7(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
        super(RSU7, self).__init__()

        self.in_ch = in_ch
        self.mid_ch = mid_ch
        self.out_ch = out_ch

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        b, c, h, w = x.shape

        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


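# RSU-6: the same residual U-block pattern with six levels.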
class RSU6(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


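# RSU-5: five-level residual U-block.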
class RSU5(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


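# RSU-4: four-level residual U-block.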
class RSU4(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


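# RSU-4F: dilation-based variant of RSU-4 that keeps full resolution
# (dilated convolutions instead of pooling and upsampling).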
class RSU4F(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))

        return hx1d + hxin


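# Plain Conv2d -> BatchNorm2d -> ReLU helper with explicit conv arguments.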
class myrebnconv(nn.Module):
    def __init__(self, in_ch=3,
                 out_ch=1,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1):
        super(myrebnconv, self).__init__()

        self.conv = nn.Conv2d(in_ch,
                              out_ch,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              dilation=dilation,
                              groups=groups)
        self.bn = nn.BatchNorm2d(out_ch)
        self.rl = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.rl(self.bn(self.conv(x)))


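# BriaRMBG: U^2-Net-style segmentation network (six RSU encoder stages, five
# decoder stages, and six side outputs), used here with the briaai/RMBG-1.4
# background-removal weights.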
class BriaRMBG(nn.Module):

    def __init__(self, in_ch=3, out_ch=1):
        super(BriaRMBG, self).__init__()

        self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
        self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage1 = RSU7(64, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

    def forward(self, x):
        hx = x

        hxin = self.conv_in(hx)

        hx1 = self.stage1(hxin)
        hx = self.pool12(hx1)

        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        d1 = self.side1(hx1d)
        d1 = _upsample_like(d1, x)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, x)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, x)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, x)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, x)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, x)

        return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3),
                torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]


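# Resolve the local directory for rembg/U2NET model weights, preferring a
# configured ComfyUI 'rembg' folder and falling back to models/rembg.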
def get_U2NET_model_path():
    try:
        return folder_paths.get_folder_paths('rembg')[0]
    except Exception:
        return os.path.join(folder_paths.models_dir, "rembg")


U2NET_HOME = get_U2NET_model_path()
os.environ["U2NET_HOME"] = U2NET_HOME

# Whether the rembg package could be imported (or installed) successfully.
_available = False


def get_rembg_models(path):
    """Collect model files from a directory and return their base names.

    Args:
        path: directory to scan.

    Returns:
        List of file names without extensions.
    """
    filenames = []
    for root, _, files in os.walk(path):
        for filename in files:
            if not filename.startswith('.'):
                name, ext = os.path.splitext(os.path.basename(filename))
                filenames.append(name)
    return filenames


def is_installed(package):
    try:
        spec = importlib.util.find_spec(package)
    except ModuleNotFoundError:
        return False
    return spec is not None


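# Try to import rembg, installing it with pip on the fly if it is missing;
# _available records whether the import ultimately succeeded.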
try:
    if not is_installed('rembg'):
        import subprocess

        print('#pip install rembg[gpu]')
        result = subprocess.run([sys.executable, '-s', '-m', 'pip', 'install', 'rembg[gpu]'],
                                capture_output=True, text=True)

        if result.returncode == 0:
            print("#install success")
            from rembg import new_session, remove
            _available = True
        else:
            print("#install error")

    else:
        from rembg import new_session, remove
        _available = True

except Exception:
    _available = False


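# Run BRIA RMBG-1.4 on a list of PIL images: download model.pth on first use
# (via the hf-mirror endpoint), infer a matte at 1024x1024, and return
# per-image grayscale masks, RGBA cut-outs, and RGB composites on black.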
def run_briarmbg(images=[]):
    mroot = U2NET_HOME
    m = os.path.join(mroot, 'briarmbg.pth')
    if not os.path.exists(m):
        m1 = hf_hub_download("briaai/RMBG-1.4",
                             local_dir=mroot,
                             filename='model.pth',
                             local_dir_use_symlinks=False,
                             endpoint='https://hf-mirror.com')
        os.rename(m1, m)

    net = BriaRMBG()
    if torch.cuda.is_available():
        net.load_state_dict(torch.load(m))
        net = net.cuda()
    else:
        net.load_state_dict(torch.load(m, map_location="cpu"))
    net.eval()

    masks = []
    rgba_images = []
    rgb_images = []
    for orig_image in images:
        w, h = orig_im_size = orig_image.size

        image = orig_image.convert('RGB')
        model_input_size = (1024, 1024)
        image = image.resize(model_input_size, Image.BILINEAR)

        im_np = np.array(image)
        im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
        im_tensor = torch.unsqueeze(im_tensor, 0)
        im_tensor = torch.divide(im_tensor, 255.0)
        im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
        if torch.cuda.is_available():
            im_tensor = im_tensor.cuda()

        result = net(im_tensor)
        result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
        ma = torch.max(result)
        mi = torch.min(result)
        result = (result - mi) / (ma - mi)
        im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
        mask = Image.fromarray(np.squeeze(im_array))

        mask = mask.convert('L')
        masks.append(mask)

        image_rgba = orig_image.convert("RGBA")
        image_rgba.putalpha(mask)
        rgba_images.append(image_rgba)

        rgb_image = Image.new("RGB", image_rgba.size, (0, 0, 0))
        rgb_image.paste(image_rgba, mask=image_rgba.split()[3])
        rgb_images.append(rgb_image)
    return (masks, rgba_images, rgb_images)


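# Run a rembg session over a list of PIL images. For "u2net_cloth_seg" the
# returned mask stacks one slice per garment region, so it is split back into
# per-slice masks; every mask also yields an RGBA cut-out and an RGB-on-black
# composite. `callback` is expected to behave like comfy.utils.ProgressBar.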
def run_rembg(model_name="unet", images=[], callback=None):
    rembg_session = new_session(model_name)
    masks = []
    rgba_images = []
    rgb_images = []

    pbar = callback
    for img in images:
        mask = remove(img, session=rembg_session, only_mask=True, post_process_mask=True)

        if model_name == "u2net_cloth_seg":
            width, original_height = mask.size
            num_slices = original_height // img.height
            for i in range(num_slices):
                top = i * img.height
                bottom = (i + 1) * img.height
                slice_image = mask.crop((0, top, width, bottom))
                slice_mask = slice_image.convert('L')
                masks.append(slice_mask)

                image_rgba = img.convert("RGBA")
                image_rgba.putalpha(slice_mask)
                rgba_images.append(image_rgba)

                rgb_image = Image.new("RGB", image_rgba.size, (0, 0, 0))
                rgb_image.paste(image_rgba, mask=image_rgba.split()[3])
                rgb_images.append(rgb_image)

        else:
            mask = mask.convert('L')
            masks.append(mask)

            image_rgba = img.convert("RGBA")
            image_rgba.putalpha(mask)
            rgba_images.append(image_rgba)

            rgb_image = Image.new("RGB", image_rgba.size, (0, 0, 0))
            rgb_image.paste(image_rgba, mask=image_rgba.split()[3])
            rgb_images.append(rgb_image)

        if pbar:
            pbar.update(1)
    return (masks, rgba_images, rgb_images)


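# Convert between ComfyUI image tensors (float, 0..1) and PIL images.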
def tensor2pil(image):
    return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))


def pil2tensor(image):
    return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)


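# ComfyUI node: runs either BRIA RMBG-1.4 ('briarmbg') or a rembg model on a
# list of images and returns masks, RGB composites, and RGBA cut-outs.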
class RembgNode_:

    available = _available

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "image": ("IMAGE",),
            "model_name": (get_rembg_models(U2NET_HOME),),
        },
        }

    RETURN_TYPES = ("MASK", "IMAGE", "RGBA",)
    RETURN_NAMES = ("masks", "images", "RGBAs")

    FUNCTION = "run"

    CATEGORY = "♾️Mixlab/Mask"
    OUTPUT_NODE = True
    INPUT_IS_LIST = True
    OUTPUT_IS_LIST = (True, True, True,)

    def run(self, image, model_name):
        model_name = model_name[0]

        images = []
        for ims in image:
            for im in ims:
                im = tensor2pil(im)
                images.append(im)

        if model_name == 'briarmbg':
            masks, rgba_images, rgb_images = run_briarmbg(images)
        else:
            masks, rgba_images, rgb_images = run_rembg(model_name, images, comfy.utils.ProgressBar(len(images)))

        masks = [pil2tensor(m) for m in masks]
        rgba_images = [pil2tensor(m) for m in rgba_images]
        rgb_images = [pil2tensor(m) for m in rgb_images]

        return (masks, rgb_images, rgba_images,)
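

# Minimal usage sketch outside of ComfyUI (assumes the rembg package and a
# model such as "u2net" are available locally; file names are placeholders):
#
#   from PIL import Image
#   img = Image.open("input.png")
#   masks, rgba_images, rgb_images = run_rembg("u2net", [img])
#   rgba_images[0].save("cutout.png")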