File size: 4,844 Bytes
10f36ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import gradio as gr
import os
import cv2
from PIL import Image
import numpy as np
from animeinsseg import AnimeInsSeg, AnimeInstances
from animeinsseg.anime_instances import get_color
from ccip import _VALID_MODEL_NAMES, _DEFAULT_MODEL_NAMES, ccip_difference, ccip_default_threshold
from datasets import load_dataset
import pathlib

# Install the required OpenMMLab packages via mim.
os.system("mim install mmengine")
os.system('mim install mmcv==2.1.0')
os.system("mim install mmdet==3.2.0")

# Fetch the segmentation checkpoint into ./models (cloned from the HF hub).
if not os.path.exists("models"):
    os.mkdir("models")

os.system("huggingface-cli lfs-enable-largefiles .")
os.system("git clone https://huggingface.co/dreMaz/AnimeInstanceSegmentation models/AnimeInstanceSegmentation")

ckpt = r'models/AnimeInstanceSegmentation/rtmdetl_e60.ckpt'

mask_thres = 0.3
instance_thres = 0.3
refine_kwargs = {'refine_method': 'refinenet_isnet'}  # set to None to disable refinenet

net = AnimeInsSeg(ckpt, mask_thr=mask_thres, refine_kwargs=refine_kwargs)

# Build a character-name -> illustration lookup from the reference dataset.
Genshin_Impact_Illustration_ds = load_dataset("svjack/Genshin-Impact-Illustration")["train"]
ds_size = len(Genshin_Impact_Illustration_ds)
name_image_dict = {row["name"]: row["image"] for row in Genshin_Impact_Illustration_ds}

# Collect any bundled PNG files to offer as Gradio examples.
example_images = [str(p) for p in pathlib.Path(".").rglob("*.png")]

def fn(image, model_name):
    """Segment anime character instances and rank reference characters by CCIP similarity.

    Args:
        image: input image as an RGB numpy array (Gradio ``type="numpy"``).
        model_name: CCIP model identifier, used only to pick the match threshold.

    Returns:
        A 2-tuple of PIL images: the segmentation visualization, and a palette
        listing the top-5 closest dataset characters for each detected instance.
    """
    # The segmentation network consumes BGR (OpenCV convention).
    img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    instances: AnimeInstances = net.infer(
        img,
        output_type='numpy',
        pred_score_thr=instance_thres
    )

    drawed = img.copy()
    im_h, im_w = img.shape[:2]

    # instances.bboxes / instances.masks are None when nothing is detected.
    # Return a real image for the second output too: a gr.Image(type="pil")
    # output cannot render a bare string.
    if instances.bboxes is None or len(instances.bboxes) == 0:
        empty = np.zeros((100, im_w, 3), dtype=np.uint8)
        cv2.putText(empty, "No instances detected", (10, 55),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 1, cv2.LINE_AA)
        return Image.fromarray(drawed[..., ::-1]), Image.fromarray(empty)

    # The threshold depends only on the model, so hoist it out of the bbox loop.
    threshold = ccip_default_threshold(model_name)

    # The dataset stores PIL images; cv2.resize needs ndarrays, so convert
    # each reference to an RGB array once instead of per bbox iteration.
    ref_arrays = {name: np.asarray(pil_img.convert("RGB"))
                  for name, pil_img in name_image_dict.items()}

    # Per-bbox list of (diff, 'Same'/'Not Same', name) tuples.
    top5_results = []

    for ii, (xywh, mask) in enumerate(zip(instances.bboxes, instances.masks)):
        color = get_color(ii)

        mask_alpha = 0.5
        linewidth = max(round(sum(img.shape) / 2 * 0.003), 2)

        # bbox is (x, y, width, height) — presumably; TODO confirm against
        # AnimeInsSeg output format. Clip to the image bounds so the crop is
        # never empty or out of range.
        x1, y1, w, h = map(int, xywh)
        x1, y1 = max(x1, 0), max(y1, 0)
        x2, y2 = min(x1 + w, im_w), min(y1 + h, im_h)
        if x2 <= x1 or y2 <= y1:
            continue  # degenerate bbox: nothing to compare or draw

        # Crop from the original RGB input so both operands of
        # ccip_difference share the same channel order as the RGB
        # reference images (the previous code compared a BGR crop).
        bbox_image = image[y1:y2, x1:x2]

        results = []
        for name, ref in ref_arrays.items():
            # Resize the reference to the crop size before comparison.
            ref_resized = cv2.resize(ref, (x2 - x1, y2 - y1))
            diff = ccip_difference(bbox_image, ref_resized)
            results.append((diff, 'Same' if diff <= threshold else 'Not Same', name))

        # Smallest difference first; keep the five closest matches.
        results.sort(key=lambda r: r[0])
        top5_results.append(results[:5])

        # Draw the bbox outline.
        cv2.rectangle(drawed, (x1, y1), (x2, y2), color, thickness=linewidth,
                      lineType=cv2.LINE_AA)

        # Alpha-blend the instance mask over the visualization.
        p = mask.astype(np.float32)
        blend_mask = np.full((im_h, im_w, 3), color, dtype=np.float32)
        alpha_msk = (mask_alpha * p)[..., None]
        alpha_ori = 1 - alpha_msk
        drawed = drawed * alpha_ori + alpha_msk * blend_mask

        drawed = drawed.astype(np.uint8)

    # Palette image: one colored column per instance, annotated with its
    # top-5 matches.
    palette_height = 100
    palette_width = im_w
    palette = np.zeros((palette_height, palette_width, 3), dtype=np.uint8)

    # Guard the division: every bbox may have been degenerate.
    if top5_results:
        col_w = palette_width // len(top5_results)
        for idx, results in enumerate(top5_results):
            color = get_color(idx)
            x_start = idx * col_w
            x_end = (idx + 1) * col_w

            # Fill the instance's column with its overlay color.
            palette[:, x_start:x_end] = color

            for i, (diff, pred, name) in enumerate(results):
                text = f"{name}: {diff:.2f} ({pred})"
                y_pos = 20 + i * 15
                cv2.putText(palette, text, (x_start + 10, y_pos),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1,
                            cv2.LINE_AA)

    # drawed is BGR, so flip channels for PIL; palette is built directly in
    # display order.
    return Image.fromarray(drawed[..., ::-1]), Image.fromarray(palette)

# 创建 Gradio 界面
iface = gr.Interface(
    # design titles and text descriptions
    title="Anime Subject Instance Segmentation with Similarity Comparison",
    description="Segment image subjects with the proposed model in the paper [*Instance-guided Cartoon Editing with a Large-scale Dataset*](https://cartoonsegmentation.github.io/).",
    fn=fn,
    inputs=[gr.Image(type="numpy"), gr.Dropdown(_VALID_MODEL_NAMES, value=_DEFAULT_MODEL_NAMES, label='Model')],
    outputs=[gr.Image(type="pil", label="Segmentation Result"), gr.Image(type="pil", label="Top5 Results Palette")],
    examples=example_images
)

iface.launch(share=True)