"""Run anime instance segmentation over one image or a folder of images.

Usage:
    python seg_script.py Genshin_Impact_Images Genshin_Impact_Images_Seg

For every input image the following files are written to the output directory:
    <stem>_cropped_<i>.png    RGB crop of instance i's bounding box
    <stem>_segmented_<i>.png  the same crop as RGBA, alpha taken from the mask
    <stem>_drawed.png         the input with boxes and translucent masks drawn
    <stem>_mask.png           union of all instance masks (white on black)

When the input is a folder, its sub-directory layout is mirrored under the
output directory so that images with the same file name in different
sub-folders do not overwrite each other's results.
"""
import argparse
import os
from pathlib import Path

import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm

from animeinsseg import AnimeInsSeg, AnimeInstances
from animeinsseg.anime_instances import get_color

# Model checkpoint and thresholds.
ckpt = r'models/AnimeInstanceSegmentation/rtmdetl_e60.ckpt'
mask_thres = 0.3
instance_thres = 0.3
# Set refine_kwargs = None to run without the refinenet refinement step.
refine_kwargs = {'refine_method': 'refinenet_isnet'}

# The model is built once at import time and reused by every process_image call.
net = AnimeInsSeg(ckpt, mask_thr=mask_thres, refine_kwargs=refine_kwargs)

# File extensions (compared case-insensitively) collected in folder mode.
IMAGE_SUFFIXES = {'.png', '.jpg', '.jpeg'}


def _read_rgb(image_path):
    """Read *image_path* and return it as an RGB uint8 array.

    ``cv2.imread`` is tried first. If it fails, the raw bytes are read with
    numpy and decoded with ``cv2.imdecode``. This covers paths that
    ``cv2.imread`` cannot open, such as non-ASCII paths on Windows.

    Raises:
        OSError: if the file cannot be read.
        ValueError: if the file is empty or cannot be decoded as an image.
    """
    path_str = str(image_path)
    img = cv2.imread(path_str)
    if img is None:
        data = np.fromfile(path_str, dtype=np.uint8)
        if data.size:
            img = cv2.imdecode(data, cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError(f"Failed to read image: {image_path}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def process_image(image_path, output_dir):
    """Segment one image and save per-instance crops plus an overlay and a mask.

    Args:
        image_path: path (str or Path) of the image to segment.
        output_dir: directory the results are written to. It is created if
            missing.

    Nothing is written when the model reports no detections
    (``instances.bboxes is None``).

    Raises:
        OSError: if the image file cannot be read.
        ValueError: if the image cannot be decoded.
    """
    img = _read_rgb(image_path)

    instances: AnimeInstances = net.infer(
        img,
        output_type='numpy',
        pred_score_thr=instance_thres,
    )

    # No detections: nothing to save.
    if instances.bboxes is None:
        return

    im_h, im_w = img.shape[:2]
    drawed = img.copy()
    # Union of all instance masks: 255 where any instance is present, else 0.
    mask_image = np.zeros((im_h, im_w), dtype=np.uint8)

    # Loop invariants.
    mask_alpha = 0.5
    linewidth = max(round(sum(img.shape) / 2 * 0.003), 2)

    # These are set up before the loop so that the summary images below can
    # still be saved when the model returns an empty (non-None) box list.
    base_name = Path(image_path).stem
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    for ii, (xywh, mask) in enumerate(zip(instances.bboxes, instances.masks)):
        color = get_color(ii)

        # Bounding box is (x, y, w, h); draw it from its two corner points.
        p1 = (int(xywh[0]), int(xywh[1]))
        p2 = (int(xywh[2] + xywh[0]), int(xywh[3] + xywh[1]))
        cv2.rectangle(drawed, p1, p2, color, thickness=linewidth, lineType=cv2.LINE_AA)

        # Alpha-blend the instance colour over the overlay, weighted by the mask.
        alpha_msk = (mask_alpha * mask.astype(np.float32))[..., None]
        blend_color = np.asarray(color, dtype=np.float32)  # broadcasts over (H, W, 1)
        drawed = (drawed * (1 - alpha_msk) + alpha_msk * blend_color).astype(np.uint8)

        mask_image[mask > 0] = 255

        # Clip the crop to the image. A box reaching past an edge would
        # otherwise give negative slice indices (which wrap around) or an
        # empty crop that PIL cannot save.
        x1 = min(max(int(xywh[0]), 0), im_w)
        y1 = min(max(int(xywh[1]), 0), im_h)
        x2 = min(max(int(xywh[0] + xywh[2]), 0), im_w)
        y2 = min(max(int(xywh[1] + xywh[3]), 0), im_h)
        if x2 <= x1 or y2 <= y1:
            continue  # degenerate box: nothing to crop

        cropped_img = img[y1:y2, x1:x2]
        cropped_mask = mask[y1:y2, x1:x2]
        # RGBA cut-out: the cropped mask becomes the alpha channel.
        alpha_channel = (cropped_mask * 255).astype(np.uint8)
        rgba_image = np.dstack((cropped_img, alpha_channel))

        Image.fromarray(cropped_img).save(output_path / f"{base_name}_cropped_{ii}.png")
        # A (H, W, 4) uint8 array is loaded as RGBA automatically.
        Image.fromarray(rgba_image).save(output_path / f"{base_name}_segmented_{ii}.png")

    Image.fromarray(drawed).save(output_path / f"{base_name}_drawed.png")
    Image.fromarray(mask_image).save(output_path / f"{base_name}_mask.png")


def _find_images(folder):
    """Return every image file under *folder*, searched recursively and sorted."""
    return sorted(
        path for path in folder.rglob("*")
        if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES
    )


def main():
    """Parse the command line and segment a single image or a whole folder."""
    parser = argparse.ArgumentParser(description="Anime Instance Segmentation")
    parser.add_argument("input_path", type=str, help="Path to the input image or folder")
    parser.add_argument("output_dir", type=str, help="Path to the output directory")
    args = parser.parse_args()

    input_path = Path(args.input_path)
    output_dir = Path(args.output_dir)

    if input_path.is_file():
        process_image(input_path, output_dir)
    elif input_path.is_dir():
        image_paths = _find_images(input_path)
        for image_path in tqdm(image_paths, desc="Processing images"):
            # Mirror the sub-folder layout. Files directly in input_path map
            # to output_dir itself.
            target_dir = output_dir / image_path.parent.relative_to(input_path)
            try:
                process_image(image_path, target_dir)
            except (OSError, ValueError) as err:
                # Report the failure and move on instead of aborting the batch.
                tqdm.write(f"Skipping {image_path}: {err}")
    else:
        # Prints usage to stderr and exits with a non-zero status.
        parser.error("Invalid input path. Please provide a valid image or folder path.")


if __name__ == "__main__":
    main()