import os
import pathlib

import gradio as gr
import cv2
import numpy as np
from PIL import Image
from datasets import load_dataset

# Install the mm-series dependencies before importing animeinsseg, which needs them.
os.system("mim install mmengine")
os.system('mim install mmcv==2.1.0')
os.system("mim install mmdet==3.2.0")

from animeinsseg import AnimeInsSeg, AnimeInstances
from animeinsseg.anime_instances import get_color
from ccip import _VALID_MODEL_NAMES, _DEFAULT_MODEL_NAMES, ccip_difference, ccip_default_threshold

# Download the segmentation checkpoint.
if not os.path.exists("models"):
    os.mkdir("models")

os.system("huggingface-cli lfs-enable-largefiles .")
os.system("git clone https://huggingface.co/dreMaz/AnimeInstanceSegmentation models/AnimeInstanceSegmentation")

ckpt = r'models/AnimeInstanceSegmentation/rtmdetl_e60.ckpt'

mask_thres = 0.3
instance_thres = 0.3
refine_kwargs = {'refine_method': 'refinenet_isnet'}  # set to None if not using refinenet
# refine_kwargs = None

net = AnimeInsSeg(ckpt, mask_thr=mask_thres, refine_kwargs=refine_kwargs)

# Load the reference dataset; keep each illustration as an RGB numpy array
# so it can be fed to cv2 and CCIP directly.
Genshin_Impact_Illustration_ds = load_dataset("svjack/Genshin-Impact-Illustration")["train"]
ds_size = len(Genshin_Impact_Illustration_ds)
name_image_dict = {}
for i in range(ds_size):
    row_dict = Genshin_Impact_Illustration_ds[i]
    name_image_dict[row_dict["name"]] = np.array(row_dict["image"].convert("RGB"))

# Collect local .png files to use as Gradio examples.
example_images = list(map(str, pathlib.Path(".").rglob("*.png")))


def fn(image, model_name):
    # Gradio delivers the image as an RGB numpy array; cv2 draws in BGR.
    img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    instances: AnimeInstances = net.infer(
        img,
        output_type='numpy',
        pred_score_thr=instance_thres
    )

    drawed = img.copy()
    im_h, im_w = img.shape[:2]

    # instances.bboxes, instances.masks will be None, None if no obj is detected
    if instances.bboxes is None:
        return Image.fromarray(drawed[..., ::-1]), "No instances detected"

    # Store the top-1 match for each bbox.
    top1_results = []

    for ii, (xywh, mask) in enumerate(zip(instances.bboxes, instances.masks)):
        color = get_color(ii)

        mask_alpha = 0.5
        linewidth = max(round(sum(img.shape) / 2 * 0.003), 2)

        # Crop the bbox region from the RGB input so the channel order
        # matches the RGB reference images.
        x1, y1, w, h = map(int, xywh)
        x2, y2 = x1 + w, y1 + h
        bbox_image = image[y1:y2, x1:x2]

        # Compare the crop against every reference image with CCIP.
        threshold = ccip_default_threshold(model_name)
        results = []
        for name, imagey in name_image_dict.items():
            # Resize the reference image to the bbox size.
            imagey_resized = cv2.resize(imagey, (w, h))
            diff = ccip_difference(bbox_image, imagey_resized)
            result = (diff, 'Same' if diff <= threshold else 'Not Same', name)
            results.append(result)

        # Sort by difference (ascending), so the most similar image comes first.
        results.sort(key=lambda x: x[0])
        top1_result = results[0]
        top1_results.append(top1_result)

        # Draw the bbox.
        p1, p2 = (x1, y1), (x2, y2)
        cv2.rectangle(drawed, p1, p2, color, thickness=linewidth, lineType=cv2.LINE_AA)

        # Blend the instance mask into the drawing.
        p = mask.astype(np.float32)
        blend_mask = np.full((im_h, im_w, 3), color, dtype=np.float32)
        alpha_msk = (mask_alpha * p)[..., None]
        alpha_ori = 1 - alpha_msk
        drawed = drawed * alpha_ori + alpha_msk * blend_mask
        drawed = drawed.astype(np.uint8)

        # Write the top-1 result next to the bbox (clamped so it stays inside the image).
        text = f"Diff: {top1_result[0]:.2f}, {top1_result[1]}, Name: {top1_result[2]}"
        cv2.putText(drawed, text, (x1, max(y1 - 10, 15)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1, cv2.LINE_AA)

    return Image.fromarray(drawed[..., ::-1]), "\n".join([f"Bbox {i+1}: {res}" for i, res in enumerate(top1_results)])
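
# A minimal sketch of the CCIP decision rule used in `fn` above, assuming the
# local `ccip` module follows the imgutils API (ccip_difference returns a float
# distance, ccip_default_threshold a per-model cutoff). `query` and `reference`
# are hypothetical RGB numpy arrays or PIL images; this helper is illustrative
# and not used by the app.
def ccip_same_character(query, reference, model_name=_DEFAULT_MODEL_NAMES):
    # A smaller difference means the two images more likely show the same character.
    diff = ccip_difference(query, reference)
    return diff <= ccip_default_threshold(model_name), diff
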
# Build the Gradio interface.
iface = gr.Interface(
    # design titles and text descriptions
    title="Anime Subject Instance Segmentation with Similarity Comparison",
    description="Segment image subjects with the proposed model in the paper [*Instance-guided Cartoon Editing with a Large-scale Dataset*](https://cartoonsegmentation.github.io/).",
    fn=fn,
    inputs=[
        gr.Image(type="numpy"),
        gr.Dropdown(_VALID_MODEL_NAMES, value=_DEFAULT_MODEL_NAMES, label='Model'),
    ],
    outputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Top1 Results for Each Bbox"),
    ],
    examples=example_images,
)

iface.launch(share=True)
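
# A minimal sketch of exercising `fn` without the web UI (run it before
# launch, or in a separate session). "test.png" is a hypothetical local file:
#
#   test_img = np.array(Image.open("test.png").convert("RGB"))
#   out_img, report = fn(test_img, _DEFAULT_MODEL_NAMES)
#   out_img.save("segmented.png")
#   print(report)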