# Gradio Space: anime instance segmentation with CCIP character-similarity lookup.
import gradio as gr
import os
import cv2
from PIL import Image
import numpy as np
from animeinsseg import AnimeInsSeg, AnimeInstances
from animeinsseg.anime_instances import get_color
from ccip import _VALID_MODEL_NAMES, _DEFAULT_MODEL_NAMES, ccip_difference, ccip_default_threshold
from datasets import load_dataset
import pathlib
# Install the OpenMMLab runtime pieces the segmentation model depends on.
# NOTE(review): these installs run *after* the imports at the top of the file,
# so they can only take effect on a later launch — confirm the Space image
# already ships with mmengine/mmcv/mmdet pre-installed.
os.system("mim install mmengine")
os.system('mim install mmcv==2.1.0')
os.system("mim install mmdet==3.2.0")

# Fetch the segmentation checkpoint once; skip if the models dir already exists.
if not os.path.exists("models"):
    os.mkdir("models")
os.system("huggingface-cli lfs-enable-largefiles .")
os.system("git clone https://huggingface.co/dreMaz/AnimeInstanceSegmentation models/AnimeInstanceSegmentation")

ckpt = r'models/AnimeInstanceSegmentation/rtmdetl_e60.ckpt'
mask_thres = 0.3       # per-pixel mask binarization threshold
instance_thres = 0.3   # minimum detection score to keep an instance
refine_kwargs = {'refine_method': 'refinenet_isnet'}  # set to None if not using refinenet
# refine_kwargs = None

net = AnimeInsSeg(ckpt, mask_thr=mask_thres, refine_kwargs=refine_kwargs)

# Reference gallery for similarity lookups: character name -> dataset image
# (the `datasets` library yields PIL images here — TODO confirm).
Genshin_Impact_Illustration_ds = load_dataset("svjack/Genshin-Impact-Illustration")["train"]
ds_size = len(Genshin_Impact_Illustration_ds)
name_image_dict = {}
for i in range(ds_size):
    row_dict = Genshin_Impact_Illustration_ds[i]
    name_image_dict[row_dict["name"]] = row_dict["image"]

# Any bundled *.png files become clickable examples in the UI.
example_images = list(map(str, list(pathlib.Path(".").rglob("*.png"))))
def fn(image, model_name):
    """Segment anime character instances and rank gallery look-alikes.

    Args:
        image: RGB image as a numpy array (Gradio ``type="numpy"`` input).
        model_name: CCIP model identifier, used to pick the match threshold.

    Returns:
        (segmentation PIL image, palette PIL image) on success, or
        (original PIL image, error string) when nothing is detected.
    """
    img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    instances: AnimeInstances = net.infer(
        img,
        output_type='numpy',
        pred_score_thr=instance_thres
    )

    drawed = img.copy()
    im_h, im_w = img.shape[:2]

    # instances.bboxes / instances.masks are None when nothing is detected.
    if instances.bboxes is None:
        return Image.fromarray(drawed[..., ::-1]), "No instances detected"

    threshold = ccip_default_threshold(model_name)
    top5_results = []  # one top-5 list per detected instance

    for ii, (xywh, mask) in enumerate(zip(instances.bboxes, instances.masks)):
        color = get_color(ii)
        mask_alpha = 0.5
        linewidth = max(round(sum(img.shape) / 2 * 0.003), 2)

        # Crop the bbox region, clamped to the image bounds so the slice
        # can never be empty or out of range.
        x1, y1, w, h = map(int, xywh)
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(im_w, x1 + w), min(im_h, y1 + h)
        if x2 <= x1 or y2 <= y1:
            continue  # degenerate box: nothing to compare or draw
        bbox_image = img[y1:y2, x1:x2]
        # CCIP expects consistent channel order; the crop is BGR, the
        # gallery images are RGB — compare both in RGB.
        bbox_rgb = cv2.cvtColor(bbox_image, cv2.COLOR_BGR2RGB)

        # Score the crop against every gallery image.
        results = []
        for name, gallery_image in name_image_dict.items():
            # BUG FIX: the dataset yields PIL images; cv2.resize needs an
            # ndarray, so convert before resizing to the crop's size.
            gallery_np = np.asarray(gallery_image)
            gallery_resized = cv2.resize(gallery_np, (x2 - x1, y2 - y1))
            diff = ccip_difference(bbox_rgb, gallery_resized)
            results.append((diff, 'Same' if diff <= threshold else 'Not Same', name))

        results.sort(key=lambda x: x[0])  # smaller difference = better match
        top5_results.append(results[:5])

        # Draw the bbox.
        cv2.rectangle(drawed, (x1, y1), (x2, y2), color,
                      thickness=linewidth, lineType=cv2.LINE_AA)

        # Alpha-blend the instance mask over the image.
        p = mask.astype(np.float32)
        blend_mask = np.full((im_h, im_w, 3), color, dtype=np.float32)
        alpha_msk = (mask_alpha * p)[..., None]
        alpha_ori = 1 - alpha_msk
        drawed = drawed * alpha_ori + alpha_msk * blend_mask

    drawed = drawed.astype(np.uint8)

    # Build the palette strip: one colored column per instance, annotated
    # with that instance's top-5 gallery matches.
    palette_height = 100
    palette_width = im_w
    palette = np.zeros((palette_height, palette_width, 3), dtype=np.uint8)
    n_cols = max(len(top5_results), 1)  # guard against division by zero
    for idx, results in enumerate(top5_results):
        color = get_color(idx)
        x_start = idx * (palette_width // n_cols)
        x_end = (idx + 1) * (palette_width // n_cols)
        palette[:, x_start:x_end] = color
        for i, (diff, pred, name) in enumerate(results):
            text = f"{name}: {diff:.2f} ({pred})"
            y_pos = 20 + i * 15
            cv2.putText(palette, text, (x_start + 10, y_pos),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1, cv2.LINE_AA)

    return Image.fromarray(drawed[..., ::-1]), Image.fromarray(palette)
# Wire the demo into a Gradio interface: numpy image + model dropdown in,
# segmentation preview + top-5 palette out.
iface = gr.Interface(
    fn=fn,
    inputs=[
        gr.Image(type="numpy"),
        gr.Dropdown(_VALID_MODEL_NAMES, value=_DEFAULT_MODEL_NAMES, label='Model'),
    ],
    outputs=[
        gr.Image(type="pil", label="Segmentation Result"),
        gr.Image(type="pil", label="Top5 Results Palette"),
    ],
    title="Anime Subject Instance Segmentation with Similarity Comparison",
    description="Segment image subjects with the proposed model in the paper [*Instance-guided Cartoon Editing with a Large-scale Dataset*](https://cartoonsegmentation.github.io/).",
    examples=example_images,
)
iface.launch(share=True)