# Install the OpenMMLab dependencies at startup; this must happen before
# importing animeinsseg, which depends on mmcv/mmdet
import os
os.system("mim install mmengine")
os.system("mim install mmcv==2.1.0")
os.system("mim install mmdet==3.2.0")

import pathlib

import cv2
import gradio as gr
import numpy as np
from PIL import Image
from datasets import load_dataset

from animeinsseg import AnimeInsSeg, AnimeInstances
from animeinsseg.anime_instances import get_color
from ccip import _VALID_MODEL_NAMES, _DEFAULT_MODEL_NAMES, ccip_difference, ccip_default_threshold
# Download the segmentation model weights
if not os.path.exists("models"):
    os.mkdir("models")

os.system("huggingface-cli lfs-enable-largefiles .")
os.system("git clone https://huggingface.co/dreMaz/AnimeInstanceSegmentation models/AnimeInstanceSegmentation")
ckpt = r'models/AnimeInstanceSegmentation/rtmdetl_e60.ckpt'
mask_thres = 0.3
instance_thres = 0.3
refine_kwargs = {'refine_method': 'refinenet_isnet'} # set to None if not using refinenet
# refine_kwargs = None
net = AnimeInsSeg(ckpt, mask_thr=mask_thres, refine_kwargs=refine_kwargs)
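# Note: net.infer returns an AnimeInstances object; judging from its use below,
# instances.bboxes holds per-instance (x, y, w, h) boxes and instances.masks
# holds per-instance HxW masks (both are None when nothing is detected).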
# Load the reference dataset
Genshin_Impact_Illustration_ds = load_dataset("svjack/Genshin-Impact-Illustration")["train"]
ds_size = len(Genshin_Impact_Illustration_ds)
name_image_dict = {}
for i in range(ds_size):
    row_dict = Genshin_Impact_Illustration_ds[i]
    name_image_dict[row_dict["name"]] = row_dict["image"]
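# The "image" column of a datasets image feature is decoded to a PIL.Image,
# so the values of name_image_dict are PIL images; they are converted to
# numpy arrays before any cv2 call inside fn below.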
# Collect local PNG files to use as Gradio examples
example_images = list(map(str, pathlib.Path(".").rglob("*.png")))
def fn(image, model_name):
    # Gradio delivers the input as an RGB numpy array; OpenCV works in BGR
    img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    instances: AnimeInstances = net.infer(
        img,
        output_type='numpy',
        pred_score_thr=instance_thres
    )
    drawed = img.copy()
    im_h, im_w = img.shape[:2]
    # instances.bboxes and instances.masks are None if no object is detected
    if instances.bboxes is None:
        return Image.fromarray(drawed[..., ::-1]), "No instances detected"
    # Store the top-1 match for each bbox
    top1_results = []
    for ii, (xywh, mask) in enumerate(zip(instances.bboxes, instances.masks)):
        color = get_color(ii)
        mask_alpha = 0.5
        linewidth = max(round(sum(img.shape) / 2 * 0.003), 2)

        # Crop the bbox region (boxes are in x, y, w, h format)
        x1, y1, w, h = map(int, xywh)
        x2, y2 = x1 + w, y1 + h
        bbox_image = img[y1:y2, x1:x2]
        # Compute similarity: ccip_difference yields a distance between two
        # character images, and ccip_default_threshold(model_name) is the
        # cutoff below which they are judged to be the same character
        threshold = ccip_default_threshold(model_name)
        # Convert the crop back to RGB so both inputs share the same color order
        bbox_rgb = cv2.cvtColor(bbox_image, cv2.COLOR_BGR2RGB)
        results = []
        for name, imagey in name_image_dict.items():
            # Reference images are PIL objects; convert to numpy and resize
            # to match the bbox crop
            imagey_resized = cv2.resize(np.array(imagey), (w, h))
            diff = ccip_difference(bbox_rgb, imagey_resized)
            result = (diff, 'Same' if diff <= threshold else 'Not Same', name)
            results.append(result)

        # Sort by difference ascending; the smallest diff is the best match
        results.sort(key=lambda x: x[0])
        top1_result = results[0]
        top1_results.append(top1_result)
        # Draw the bbox
        p1, p2 = (x1, y1), (x2, y2)
        cv2.rectangle(drawed, p1, p2, color, thickness=linewidth, lineType=cv2.LINE_AA)

        # Blend the instance mask over the image
        p = mask.astype(np.float32)
        blend_mask = np.full((im_h, im_w, 3), color, dtype=np.float32)
        alpha_msk = (mask_alpha * p)[..., None]
        alpha_ori = 1 - alpha_msk
        drawed = drawed * alpha_ori + alpha_msk * blend_mask
        drawed = drawed.astype(np.uint8)

        # Annotate the bbox with its top-1 result
        text = f"Diff: {top1_result[0]:.2f}, {top1_result[1]}, Name: {top1_result[2]}"
        cv2.putText(drawed, text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1, cv2.LINE_AA)

    return Image.fromarray(drawed[..., ::-1]), "\n".join([f"Bbox {i+1}: {res}" for i, res in enumerate(top1_results)])
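# Quick sanity check outside Gradio (a sketch; "test.png" is a hypothetical
# local file, not part of this repo):
# if os.path.exists("test.png"):
#     out_img, summary = fn(np.array(Image.open("test.png").convert("RGB")), _DEFAULT_MODEL_NAMES)
#     print(summary)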
# Build the Gradio interface
iface = gr.Interface(
    fn=fn,
    inputs=[gr.Image(type="numpy"), gr.Dropdown(_VALID_MODEL_NAMES, value=_DEFAULT_MODEL_NAMES, label='Model')],
    outputs=[gr.Image(type="pil"), gr.Textbox(label="Top1 Results for Each Bbox")],
    title="Anime Subject Instance Segmentation with Similarity Comparison",
    description="Segment image subjects with the model proposed in the paper [*Instance-guided Cartoon Editing with a Large-scale Dataset*](https://cartoonsegmentation.github.io/).",
    # With two input components, each example must supply a value for both
    examples=[[img, _DEFAULT_MODEL_NAMES] for img in example_images]
)

iface.launch(share=True)
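# Note: share=True requests a public tunnel link when the app runs locally;
# on Hugging Face Spaces the flag is ignored, since the app is already served publicly.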