# Hugging Face Space page header (scraped): "Spaces: Runtime error".
"""Gradio Space: anime instance segmentation with CCIP character matching.

At startup this script installs mmengine / mmcv / mmdet with ``mim``, clones
the AnimeInstanceSegmentation checkpoint from the Hugging Face Hub, builds the
segmentation network, and loads the Genshin Impact reference illustrations.
"""
import os
import pathlib

import cv2
import gradio as gr
import numpy as np
from datasets import load_dataset
from PIL import Image

# Install the OpenMMLab stack BEFORE importing animeinsseg. The checkpoint
# below is an RTMDet model, so animeinsseg presumably needs mmdet/mmcv
# (NOTE(review): confirm); importing it first fails on a fresh environment.
os.system("mim install mmengine")
os.system("mim install mmcv==2.1.0")
os.system("mim install mmdet==3.2.0")

from animeinsseg import AnimeInsSeg, AnimeInstances  # noqa: E402
from animeinsseg.anime_instances import get_color  # noqa: E402
from ccip import (  # noqa: E402
    _DEFAULT_MODEL_NAMES,
    _VALID_MODEL_NAMES,
    ccip_default_threshold,
    ccip_difference,
)

# Fetch the segmentation checkpoint repository once; later starts reuse it.
MODEL_DIR = "models/AnimeInstanceSegmentation"
os.makedirs("models", exist_ok=True)
os.system("huggingface-cli lfs-enable-largefiles .")
if not os.path.isdir(MODEL_DIR):
    os.system(f"git clone https://huggingface.co/dreMaz/AnimeInstanceSegmentation {MODEL_DIR}")

ckpt = r'models/AnimeInstanceSegmentation/rtmdetl_e60.ckpt'
mask_thres = 0.3
instance_thres = 0.3  # score threshold passed to net.infer as pred_score_thr
refine_kwargs = {'refine_method': 'refinenet_isnet'}  # set to None if not using refinenet
# refine_kwargs = None
net = AnimeInsSeg(ckpt, mask_thr=mask_thres, refine_kwargs=refine_kwargs)

# Reference illustrations keyed by character name; a later duplicate name
# overwrites an earlier one.
Genshin_Impact_Illustration_ds = load_dataset("svjack/Genshin-Impact-Illustration")["train"]
ds_size = len(Genshin_Impact_Illustration_ds)
name_image_dict = dict(
    zip(Genshin_Impact_Illustration_ds["name"], Genshin_Impact_Illustration_ds["image"])
)

# Example inputs for the UI: every PNG file found under the working directory
# (not the dataset), sorted so the order is stable between restarts.
example_images = sorted(str(path) for path in pathlib.Path(".").rglob("*.png"))
def _to_rgb_pil(picture):
    """Return *picture* as an RGB PIL image (accepts PIL images or array-likes)."""
    if isinstance(picture, Image.Image):
        return picture.convert("RGB")
    return Image.fromarray(np.asarray(picture)).convert("RGB")


def _message_palette(width, text, height=100):
    """Return a black ``height x width`` uint8 image with *text* written in white."""
    palette = np.zeros((height, width, 3), dtype=np.uint8)
    cv2.putText(palette, text, (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.4,
                (255, 255, 255), 1, cv2.LINE_AA)
    return palette


def _top5_matches(crop, model_name, threshold):
    """Compare *crop* with every reference in ``name_image_dict``.

    Returns the five ``(diff, 'Same' | 'Not Same', name)`` tuples with the
    smallest CCIP difference, in ascending order of difference.
    """
    results = []
    for name, reference in name_image_dict.items():
        # Resize the reference to the crop's (width, height), as before.
        reference = _to_rgb_pil(reference).resize(crop.size)
        # Use the selected model for the distance as well, so the diff is on
        # the same scale as that model's threshold. NOTE(review): assumes the
        # local ccip module mirrors imgutils' `model=` keyword.
        diff = ccip_difference(crop, reference, model=model_name)
        results.append((diff, 'Same' if diff <= threshold else 'Not Same', name))
    results.sort(key=lambda item: item[0])
    return results[:5]


def _render_palette(top5_results, width, height=100):
    """Draw one coloured column per instance listing its top-5 matches.

    Column ``idx`` is filled with ``get_color(idx)`` so it matches the colour
    used for instance ``idx`` on the overlay. Channel order is the same as
    the overlay's working (BGR) image.
    """
    palette = np.zeros((height, width, 3), dtype=np.uint8)
    count = len(top5_results)
    for idx, matches in enumerate(top5_results):
        # Integer split that spreads the remainder, so columns cover the full width.
        x_start = idx * width // count
        x_end = (idx + 1) * width // count
        palette[:, x_start:x_end] = get_color(idx)
        for rank, (diff, pred, name) in enumerate(matches):
            cv2.putText(palette, f"{name}: {diff:.2f} ({pred})",
                        (x_start + 10, 20 + rank * 15), cv2.FONT_HERSHEY_SIMPLEX,
                        0.4, (255, 255, 255), 1, cv2.LINE_AA)
    return palette


def fn(image, model_name):
    """Segment anime subjects in *image* and match each one against the references.

    Each detected instance gets its box outlined and its mask tinted on a copy
    of the input. Its crop is compared with every reference illustration via
    CCIP, and the five closest names per instance are listed in a palette
    image.

    Args:
        image: RGB uint8 numpy array from ``gr.Image(type="numpy")``, or
            ``None`` when nothing was uploaded.
        model_name: CCIP model chosen in the dropdown. It is used for both the
            difference and the same/not-same threshold.

    Returns:
        ``(segmentation, palette)`` as PIL images. ``(None, None)`` is returned
        when no image is given. When nothing is detected, the palette shows
        "No instances detected".
    """
    if image is None:
        return None, None

    img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    instances: AnimeInstances = net.infer(
        img,
        output_type='numpy',
        pred_score_thr=instance_thres
    )
    drawed = img.copy()
    im_h, im_w = img.shape[:2]

    # instances.bboxes / instances.masks are None when no object is detected.
    if instances.bboxes is None or len(instances.bboxes) == 0:
        # The second output is a gr.Image, so return an image rather than a str.
        return (Image.fromarray(drawed[..., ::-1]),
                Image.fromarray(_message_palette(im_w, "No instances detected")))

    # Loop-invariant drawing / matching parameters.
    mask_alpha = 0.5
    linewidth = max(round(sum(img.shape) / 2 * 0.003), 2)
    threshold = ccip_default_threshold(model_name)

    top5_results = []  # one list of matches per instance, index-aligned with colours
    for ii, (xywh, mask) in enumerate(zip(instances.bboxes, instances.masks)):
        color = get_color(ii)

        # Clip the (x, y, w, h) box to the image; negative starts would
        # otherwise wrap around when slicing.
        bx, by, bw, bh = map(int, xywh)
        x1, y1 = max(bx, 0), max(by, 0)
        x2, y2 = min(bx + bw, im_w), min(by + bh, im_h)

        if x2 > x1 and y2 > y1:
            # Crop from the RGB input so CCIP gets correctly ordered channels.
            crop = Image.fromarray(image[y1:y2, x1:x2])
            top5_results.append(_top5_matches(crop, model_name, threshold))
        else:
            top5_results.append([])  # degenerate box: keep columns aligned

        # Outline the box.
        cv2.rectangle(drawed, (x1, y1), (x2, y2), color,
                      thickness=linewidth, lineType=cv2.LINE_AA)

        # Alpha-blend the instance colour over the mask region.
        alpha_msk = (mask_alpha * mask.astype(np.float32))[..., None]
        blend_mask = np.full((im_h, im_w, 3), color, dtype=np.float32)
        drawed = (drawed * (1 - alpha_msk) + alpha_msk * blend_mask).astype(np.uint8)

    palette = _render_palette(top5_results, im_w)
    # Both images are drawn in BGR; flip both so their colours agree on screen.
    return Image.fromarray(drawed[..., ::-1]), Image.fromarray(palette[..., ::-1])
# Build the Gradio UI.
# Inputs: an image and a CCIP model dropdown.
# Outputs: the segmentation overlay and the top-5 match palette.
iface = gr.Interface(
    # design titles and text descriptions
    title="Anime Subject Instance Segmentation with Similarity Comparison",
    description="Segment image subjects with the proposed model in the paper [*Instance-guided Cartoon Editing with a Large-scale Dataset*](https://cartoonsegmentation.github.io/).",
    fn=fn,
    inputs=[gr.Image(type="numpy"), gr.Dropdown(_VALID_MODEL_NAMES, value=_DEFAULT_MODEL_NAMES, label='Model')],
    outputs=[gr.Image(type="pil", label="Segmentation Result"), gr.Image(type="pil", label="Top5 Results Palette")],
    # With two input components, each example must be an [image, model] row;
    # a flat list of paths is rejected by gr.Examples.
    examples=[[path, _DEFAULT_MODEL_NAMES] for path in example_images],
)

if __name__ == "__main__":
    iface.launch(share=True)