import importlib.util
import os
import subprocess
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision.transforms.functional import normalize

import comfy.utils
import folder_paths


# BRIA-RMBG-1.4 / briarmbg.py
class REBNCONV(nn.Module):
    """3x3 convolution followed by batch normalisation and ReLU.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        dirate: Dilation rate; also used as the padding so that, at stride 1,
            the spatial size is preserved.
        stride: Convolution stride.
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
        super().__init__()
        self.conv_s1 = nn.Conv2d(
            in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
        )
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        return self.relu_s1(self.bn_s1(self.conv_s1(x)))


def _upsample_like(src, tar):
    """Bilinearly resize ``src`` to the spatial size (dims 2 onward) of ``tar``."""
    return F.interpolate(src, size=tar.shape[2:], mode='bilinear')


### RSU-7 ###
class RSU7(nn.Module):
    """Residual U-block with seven encoder convolutions and five poolings.

    Args:
        in_ch: Channels of the block input.
        mid_ch: Channels used inside the inner encoder/decoder.
        out_ch: Channels of the block output.
        img_size: Accepted for signature compatibility; not used by the block.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
        super().__init__()

        self.in_ch = in_ch
        self.mid_ch = mid_ch
        self.out_ch = out_ch

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)  ## 1 -> 1/2

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the inner U-shaped pass and add the result to the input projection."""
        # The values are unused; the unpacking raises if ``x`` is not 4-D.
        b, c, h, w = x.shape

        hxin = self.rebnconvin(x)

        # Encoder: convolve, then halve the resolution.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        # Decoder: concatenate with the matching encoder feature, then upsample.
        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


### RSU-6 ###
class RSU6(nn.Module):
    """Residual U-block with six encoder convolutions and four poolings.

    Args:
        in_ch: Channels of the block input.
        mid_ch: Channels used inside the inner encoder/decoder.
        out_ch: Channels of the block output.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the inner U-shaped pass and add the result to the input projection."""
        hxin = self.rebnconvin(x)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


### RSU-5 ###
class RSU5(nn.Module):
    """Residual U-block with five encoder convolutions and three poolings.

    Args:
        in_ch: Channels of the block input.
        mid_ch: Channels used inside the inner encoder/decoder.
        out_ch: Channels of the block output.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the inner U-shaped pass and add the result to the input projection."""
        hxin = self.rebnconvin(x)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


### RSU-4 ###
class RSU4(nn.Module):
    """Residual U-block with four encoder convolutions and two poolings.

    Args:
        in_ch: Channels of the block input.
        mid_ch: Channels used inside the inner encoder/decoder.
        out_ch: Channels of the block output.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the inner U-shaped pass and add the result to the input projection."""
        hxin = self.rebnconvin(x)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


### RSU-4F ###
class RSU4F(nn.Module):
    """Residual block that uses dilated convolutions (rates 1, 2, 4, 8) and no pooling.

    Args:
        in_ch: Channels of the block input.
        mid_ch: Channels used inside the block.
        out_ch: Channels of the block output.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the dilated encoder/decoder and add the result to the input projection."""
        hxin = self.rebnconvin(x)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))

        return hx1d + hxin


class myrebnconv(nn.Module):
    """Configurable convolution followed by batch normalisation and ReLU.

    The constructor arguments are forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_ch=3, out_ch=1, kernel_size=3, stride=1, padding=1,
                 dilation=1, groups=1):
        super().__init__()

        self.conv = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
        self.bn = nn.BatchNorm2d(out_ch)
        self.rl = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        return self.rl(self.bn(self.conv(x)))


class BriaRMBG(nn.Module):
    """Six-stage encoder / five-stage decoder built from RSU blocks.

    Args:
        in_ch: Channels of the input image tensor.
        out_ch: Channels of every side-output map.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()

        self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
        # Registered but not applied in ``forward``.
        self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage1 = RSU7(64, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

    def forward(self, x):
        """Compute side outputs and decoder features for ``x``.

        Returns:
            A pair ``(side_outputs, features)``. ``side_outputs`` is a list of
            six sigmoid-activated maps, each resized to the spatial size of
            ``x``. ``features`` is the list
            ``[hx1d, hx2d, hx3d, hx4d, hx5d, hx6]`` of un-activated stage
            outputs.
        """
        hxin = self.conv_in(x)

        # stage 1
        hx1 = self.stage1(hxin)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side outputs, each resized to the input resolution
        d1 = _upsample_like(self.side1(hx1d), x)
        d2 = _upsample_like(self.side2(hx2d), x)
        d3 = _upsample_like(self.side3(hx3d), x)
        d4 = _upsample_like(self.side4(hx4d), x)
        d5 = _upsample_like(self.side5(hx5d), x)
        d6 = _upsample_like(self.side6(hx6), x)

        # torch.sigmoid replaces the deprecated F.sigmoid; the values are identical.
        side_outputs = [torch.sigmoid(d) for d in (d1, d2, d3, d4, d5, d6)]
        return side_outputs, [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]


def get_U2NET_model_path():
    """Return the directory that holds the rembg model files.

    Uses the first path registered under ``'rembg'`` in ``folder_paths`` and
    falls back to ``<models_dir>/rembg`` when that lookup fails.
    """
    try:
        return folder_paths.get_folder_paths('rembg')[0]
    except Exception:
        # Best-effort lookup: any failure falls back to the default location.
        return os.path.join(folder_paths.models_dir, "rembg")


U2NET_HOME = get_U2NET_model_path()
os.environ["U2NET_HOME"] = U2NET_HOME

# True once ``new_session`` and ``remove`` have been imported from rembg.
_available = False


def get_rembg_models(path):
    """Collect the extension-less names of the files found under ``path``.

    Args:
        path: Directory to walk recursively.

    Returns:
        List of file names without extension; names starting with ``.``
        (hidden files) are skipped.
    """
    return [
        os.path.splitext(os.path.basename(filename))[0]
        for _root, _dirs, files in os.walk(path)
        for filename in files
        if not filename.startswith('.')
    ]


def is_installed(package):
    """Return True when ``package`` can be located by the import system."""
    try:
        spec = importlib.util.find_spec(package)
    except ModuleNotFoundError:
        return False

    return spec is not None


try:
    if not is_installed('rembg'):
        # Install the package on first use.
        print('#pip install rembg[gpu]')

        result = subprocess.run(
            [sys.executable, '-s', '-m', 'pip', 'install', 'rembg[gpu]'],
            capture_output=True,
            text=True,
        )

        # Check the outcome of the install command.
        if result.returncode == 0:
            print("#install success")
            # A package installed while the interpreter is running may be
            # invisible to the import system until its caches are invalidated.
            importlib.invalidate_caches()
            from rembg import new_session, remove
            _available = True
        else:
            print("#install error")
    else:
        from rembg import new_session, remove
        _available = True
except Exception as err:
    # Keep the module importable without rembg, but report why it is missing.
    _available = False
    print(f"#rembg unavailable: {err}")


def _compose_cutouts(image, mask):
    """Build the RGBA and black-background RGB cut-outs of ``image``.

    Args:
        image: Source PIL image.
        mask: ``'L'``-mode PIL image used as the alpha channel.

    Returns:
        ``(rgba, rgb)``: ``rgba`` is ``image`` with ``mask`` as alpha; ``rgb``
        is ``rgba`` pasted onto a black canvas through that alpha channel.
    """
    rgba = image.convert("RGBA")
    rgba.putalpha(mask)

    rgb = Image.new("RGB", rgba.size, (0, 0, 0))
    rgb.paste(rgba, mask=rgba.split()[3])
    return rgba, rgb


def run_briarmbg(images=None):
    """Segment each image with the BRIA RMBG-1.4 network.

    The weights are read from ``<U2NET_HOME>/briarmbg.pth`` and downloaded
    first when that file does not exist.

    Args:
        images: Iterable of PIL images; ``None`` is treated as empty.

    Returns:
        ``(masks, rgba_images, rgb_images)``: three lists of PIL images with
        one entry per input image.
    """
    images = [] if images is None else images

    weights_path = os.path.join(U2NET_HOME, 'briarmbg.pth')
    if not os.path.exists(weights_path):
        # Download the checkpoint and move it to the expected file name.
        downloaded = hf_hub_download(
            "briaai/RMBG-1.4",
            local_dir=U2NET_HOME,
            filename='model.pth',
            local_dir_use_symlinks=False,
            endpoint='https://hf-mirror.com',
        )
        os.rename(downloaded, weights_path)

    use_cuda = torch.cuda.is_available()

    net = BriaRMBG()
    # NOTE(review): torch.load unpickles the downloaded file; confirm the
    # source is trusted. Loading onto the CPU first works regardless of the
    # device the checkpoint was saved from.
    net.load_state_dict(torch.load(weights_path, map_location="cpu"))
    if use_cuda:
        net = net.cuda()
    net.eval()

    model_input_size = (1024, 1024)

    masks = []
    rgba_images = []
    rgb_images = []

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for orig_image in images:
            w, h = orig_image.size

            resized = orig_image.convert('RGB').resize(model_input_size, Image.BILINEAR)
            im_tensor = torch.tensor(np.array(resized), dtype=torch.float32).permute(2, 0, 1)
            im_tensor = torch.unsqueeze(im_tensor, 0)
            im_tensor = torch.divide(im_tensor, 255.0)
            im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
            if use_cuda:
                im_tensor = im_tensor.cuda()

            side_outputs, _features = net(im_tensor)

            # The first side output is resized back to the original image size.
            result = F.interpolate(side_outputs[0], size=(h, w), mode='bilinear')
            result = torch.squeeze(result, 0)

            # Min-max normalise; a constant prediction would divide by zero.
            ma = torch.max(result)
            mi = torch.min(result)
            span = ma - mi
            if span > 0:
                result = (result - mi) / span
            else:
                result = torch.zeros_like(result)

            im_array = (result * 255).cpu().numpy().astype(np.uint8)
            # Index the leading axis instead of np.squeeze so that a height or
            # width of 1 is not dropped as well.
            mask = Image.fromarray(im_array[0]).convert('L')
            masks.append(mask)

            image_rgba, rgb_image = _compose_cutouts(orig_image, mask)
            rgba_images.append(image_rgba)
            rgb_images.append(rgb_image)

    return (masks, rgba_images, rgb_images)


def run_rembg(model_name="unet", images=None, callback=None):
    """Segment each image with a rembg session.

    Args:
        model_name: Name passed to ``rembg.new_session``.
        images: Iterable of PIL images; ``None`` is treated as empty.
        callback: Optional progress object; ``callback.update(1)`` is called
            once per input image.

    Returns:
        ``(masks, rgba_images, rgb_images)``: three lists of PIL images. For
        ``"u2net_cloth_seg"`` the returned mask is cut into slices of the
        input height, so one input image may contribute several entries.

    Raises:
        RuntimeError: If rembg could not be imported when the module loaded.
    """
    if not _available:
        raise RuntimeError("rembg is not available; install it with `pip install rembg[gpu]`")

    images = [] if images is None else images

    rembg_session = new_session(model_name)

    masks = []
    rgba_images = []
    rgb_images = []

    # Progress bar.
    pbar = callback

    for img in images:
        # use the post_process_mask argument to post process the mask to get better results.
        mask = remove(img, session=rembg_session, only_mask=True, post_process_mask=True)

        if model_name == "u2net_cloth_seg":
            # The mask is cropped into consecutive slices of the input height.
            width, original_height = mask.size
            num_slices = original_height // img.height
            slice_masks = [
                mask.crop((0, i * img.height, width, (i + 1) * img.height)).convert('L')
                for i in range(num_slices)
            ]
        else:
            slice_masks = [mask.convert('L')]

        for slice_mask in slice_masks:
            masks.append(slice_mask)

            image_rgba, rgb_image = _compose_cutouts(img, slice_mask)
            rgba_images.append(image_rgba)
            rgb_images.append(rgb_image)

        if pbar:
            pbar.update(1)

    return (masks, rgba_images, rgb_images)


# Tensor to PIL
def tensor2pil(image):
    """Convert a tensor to a PIL image.

    The values are multiplied by 255, clipped to ``[0, 255]`` and cast to
    uint8 after squeezing singleton dimensions.
    """
    return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))


# Convert PIL to Tensor
def pil2tensor(image):
    """Convert a PIL image to a float32 tensor divided by 255, with a leading axis of size 1."""
    return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)


class RembgNode_:
    """ComfyUI node that removes image backgrounds with rembg or BRIA RMBG.

    Inputs arrive as lists (``INPUT_IS_LIST``) and every output is a list
    (``OUTPUT_IS_LIST``).
    """

    # Snapshot of the rembg import status taken when the class is defined.
    available = _available

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "model_name": (get_rembg_models(U2NET_HOME),),
            },
        }

    RETURN_TYPES = ("MASK", "IMAGE", "RGBA",)
    RETURN_NAMES = ("masks", "images", "RGBAs")

    FUNCTION = "run"

    CATEGORY = "♾️Mixlab/Mask"
    OUTPUT_NODE = True
    INPUT_IS_LIST = True
    OUTPUT_IS_LIST = (True, True, True,)

    def run(self, image, model_name):
        """Segment every input image and return masks, RGB and RGBA cut-outs.

        Args:
            image: List of image batches; each item of each batch is converted
                to a PIL image.
            model_name: List whose first element selects the model.

        Returns:
            ``(masks, rgb_images, rgba_images)`` as lists of tensors.
        """
        # Accept both list input and batch input.
        model_name = model_name[0]

        images = [tensor2pil(im) for ims in image for im in ims]

        if model_name == 'briarmbg':
            masks, rgba_images, rgb_images = run_briarmbg(images)
        else:
            masks, rgba_images, rgb_images = run_rembg(
                model_name, images, comfy.utils.ProgressBar(len(images))
            )

        masks = [pil2tensor(m) for m in masks]
        rgba_images = [pil2tensor(m) for m in rgba_images]
        rgb_images = [pil2tensor(m) for m in rgb_images]

        return (masks, rgb_images, rgba_images,)