|
from os import getenv |
|
from pathlib import Path |
|
|
|
import gradio as gr |
|
from PIL import Image |
|
from rich.traceback import install as traceback_install |
|
|
|
from tagger.common import Heatmap, ImageLabels, LabelData, load_labels_hf, preprocess_image |
|
from tagger.model import load_model_and_transform, process_heatmap |
|
|
|
# Browser-tab title for the app (passed as ``title=`` to gr.Blocks).
TITLE = "WD Tagger Heatmap For More Models"

# Short description string.
# NOTE(review): not referenced anywhere else in this file as seen — confirm
# whether it was meant to be rendered (e.g. via gr.Markdown).
DESCRIPTION = "WD Tagger v3 Heatmap Generator."

# Optional Hugging Face access token read from the environment; None if unset.
# NOTE(review): not passed to any loader in this file — confirm it is needed.
HF_TOKEN = getenv("HF_TOKEN")
|
|
|
|
|
# Hugging Face Hub repository ids offered as choices in the model dropdown.
AVAILABLE_MODEL_REPOS = [
    "SmilingWolf/wd-convnext-tagger-v3",
    "SmilingWolf/wd-swinv2-tagger-v3",
    "SmilingWolf/wd-vit-tagger-v3",
    "SmilingWolf/wd-vit-large-tagger-v3",
    "SmilingWolf/wd-eva02-large-tagger-v3",
]

# Repository preselected in the dropdown; also used for the bundled example rows.
MODEL_REPO = "SmilingWolf/wd-vit-tagger-v3"
|
|
|
# Directory that relative resources (such as the examples folder) are resolved
# against: this script's own folder, or the current working directory when
# ``__file__`` is undefined (e.g. when run from an interactive session).
if "__file__" in globals():
    WORK_DIR = Path(__file__).parent.resolve()
else:
    WORK_DIR = Path().resolve()

# Lower-case file suffixes accepted as example images.
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]
|
|
|
# Install rich's pretty traceback handler for uncaught exceptions, with local
# variables shown per frame. locals_max_length=0 presumably abbreviates
# container values in those locals to keep output short — see rich docs.
_ = traceback_install(locals_max_length=0, show_locals=True)
|
|
|
|
|
# Paths (relative to WORK_DIR) of the bundled example images, sorted so the
# Examples table has a stable order. When the "examples" directory is missing
# this yields an empty list instead of raising FileNotFoundError at import
# time, which previously prevented the whole app from starting.
_EXAMPLES_DIR = WORK_DIR.joinpath("examples")
example_images = sorted(
    str(path.relative_to(WORK_DIR))
    for path in (_EXAMPLES_DIR.iterdir() if _EXAMPLES_DIR.is_dir() else ())
    if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS
)
|
|
|
|
|
def predict(
    image: Image.Image,
    model_repo: str,
    threshold: float = 0.5,
):
    """Run a tagger model on *image* and build per-tag heatmaps.

    Args:
        image: Picture from the ``gr.Image`` input (``type="pil"``). Gradio
            passes ``None`` when nothing has been uploaded.
        model_repo: Hugging Face repo id of the tagger to load.
        threshold: Tag score cutoff, forwarded to ``process_heatmap``.

    Returns:
        A 7-tuple in the same order as the Submit button's ``outputs``:
        the list of ``(heatmap_image, label)`` pairs for the gallery, the
        heatmap grid, then ``caption``, ``booru`` (the tag string),
        ``rating``, ``character`` and ``general`` from ``ImageLabels``.

    Raises:
        gr.Error: if no image or no model is provided, so the UI shows a
            readable message instead of an internal stack trace.
    """
    # Validate cheap inputs before loading a (potentially large) model.
    if image is None:
        raise gr.Error("Please upload an image before submitting.")
    if not model_repo:
        raise gr.Error("Please select a model.")

    model, transform = load_model_and_transform(model_repo)
    labels: LabelData = load_labels_hf(model_repo)

    # Bring the image to 448x448 (exact resize/pad policy lives in
    # tagger.common.preprocess_image), apply the model's transform, and add a
    # leading batch dimension of size 1.
    pixels = preprocess_image(image, (448, 448))
    batch = transform(pixels).unsqueeze(0)

    heatmaps: list[Heatmap]
    image_labels: ImageLabels
    heatmaps, heatmap_grid, image_labels = process_heatmap(model, batch, labels, threshold)

    # gr.Gallery accepts (image, caption) pairs.
    heatmap_images = [(heatmap.image, heatmap.label) for heatmap in heatmaps]

    return (
        heatmap_images,
        heatmap_grid,
        image_labels.caption,
        image_labels.booru,
        image_labels.rating,
        image_labels.character,
        image_labels.general,
    )
|
|
|
|
|
# Extra stylesheet injected into the page via gr.Blocks(css=...).
# NOTE(review): no component in this file sets elem_id "use_mcut" or
# "char_mcut", and nothing here adds a "dimmed" class to #threshold, so these
# rules look like leftovers from another UI — confirm before removing.
css = """
#use_mcut, #char_mcut {
    padding-top: var(--scale-3);
}
#threshold.dimmed {
    filter: brightness(75%);
}
"""
|
|
|
# Default tag threshold: used for the slider and for the bundled example rows,
# so the two stay in sync.
_DEFAULT_THRESHOLD = 0.35

with gr.Blocks(theme="NoCrypt/miku", analytics_enabled=False, title=TITLE, css=css) as demo:
    with gr.Row(equal_height=False):
        # Left column: image input, threshold/model controls, action buttons.
        with gr.Column(min_width=720):
            with gr.Group():
                img_input = gr.Image(
                    label="Input",
                    type="pil",
                    image_mode="RGB",
                    sources=["upload", "clipboard"],
                )
            with gr.Group():
                with gr.Row():
                    threshold = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=_DEFAULT_THRESHOLD,
                        step=0.01,
                        label="Tag Threshold",
                        scale=5,
                        elem_id="threshold",
                    )
                    # Labelled so users can tell what the dropdown selects.
                    model_to_use = gr.Dropdown(
                        choices=AVAILABLE_MODEL_REPOS,
                        value=MODEL_REPO,
                        label="Model",
                    )
                with gr.Row():
                    # The components to reset are registered with clear.add()
                    # below, once the output components exist.
                    clear = gr.ClearButton(
                        components=[],
                        variant="secondary",
                        size="lg",
                    )
                    submit = gr.Button(value="Submit", variant="primary", size="lg")

        # Right column: results, one tab per output view.
        with gr.Column(min_width=720):
            with gr.Tab(label="Heatmaps"):
                heatmap_gallery = gr.Gallery(columns=3, show_label=False)
            with gr.Tab(label="Grid"):
                heatmap_grid = gr.Image(show_label=False)
            with gr.Tab(label="Tags"):
                with gr.Group():
                    caption = gr.Textbox(label="Caption", show_copy_button=True)
                    tags = gr.Textbox(label="Tags", show_copy_button=True)
                with gr.Group():
                    rating = gr.Label(label="Rating")
                with gr.Group():
                    character = gr.Label(label="Character")
                with gr.Group():
                    general = gr.Label(label="General")

    # Only show an Examples row when example images were actually found.
    if example_images:
        with gr.Row():
            gr.Examples(
                examples=[[path, MODEL_REPO, _DEFAULT_THRESHOLD] for path in example_images],
                inputs=[img_input, model_to_use, threshold],
            )

    clear.add([img_input, heatmap_gallery, heatmap_grid, caption, tags, rating, character, general])

    # The order of `outputs` must match the tuple returned by predict().
    submit.click(
        predict,
        inputs=[img_input, model_to_use, threshold],
        outputs=[heatmap_gallery, heatmap_grid, caption, tags, rating, character, general],
        api_name="predict",
    )
|
|
|
if __name__ == "__main__":
    # Bound the request queue so a burst of submissions can't pile up unboundedly.
    demo.queue(max_size=10)
    if getenv("SPACE_ID") is not None:
        # SPACE_ID is presumably set only when running on Hugging Face Spaces,
        # where the platform decides host and port.
        demo.launch()
    else:
        # Local run. Passing server_name/server_port explicitly would override
        # Gradio's own GRADIO_SERVER_NAME / GRADIO_SERVER_PORT env vars, so
        # honor them here, keeping the previous values as defaults.
        demo.launch(
            server_name=getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
            server_port=int(getenv("GRADIO_SERVER_PORT", "7871")),
            debug=True,
        )
|
|