# narugo's picture
# Update app.py
# 9f181f5 verified
from os import getenv
from pathlib import Path
import gradio as gr
from PIL import Image
from rich.traceback import install as traceback_install
from tagger.common import Heatmap, ImageLabels, LabelData, load_labels_hf, preprocess_image
from tagger.model import load_model_and_transform, process_heatmap
# Page title shown in the browser tab and a short one-line description.
TITLE = "WD Tagger Heatmap For More Models"
DESCRIPTION = """WD Tagger v3 Heatmap Generator."""

# Hugging Face access token read from the environment (None when unset).
# NOTE(review): not referenced elsewhere in this file; hub helpers may read the
# environment variable themselves — confirm before removing.
HF_TOKEN = getenv("HF_TOKEN")

# Model picked in the dropdown when the page first loads.
MODEL_REPO = "SmilingWolf/wd-vit-tagger-v3"

# Every tagger checkpoint offered in the model dropdown (Hugging Face repo ids).
AVAILABLE_MODEL_REPOS = [
    "SmilingWolf/wd-convnext-tagger-v3",
    "SmilingWolf/wd-swinv2-tagger-v3",
    "SmilingWolf/wd-vit-tagger-v3",
    "SmilingWolf/wd-vit-large-tagger-v3",
    "SmilingWolf/wd-eva02-large-tagger-v3",
]
# Repository root: the directory holding this file, or the current working
# directory when no __file__ exists (e.g. an interactive IPython session).
if "__file__" in globals():
    WORK_DIR = Path(__file__).parent.resolve()
else:
    WORK_DIR = Path().resolve()

# Lower-case file suffixes (with leading dot) accepted as example images.
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]
# Route uncaught exceptions through rich's traceback renderer. show_locals=True
# prints frame locals; locals_max_length=0 makes rich abbreviate container
# values (presumably to keep the output short — confirm if more detail is wanted).
_ = traceback_install(show_locals=True, locals_max_length=0)


def _list_example_images(examples_dir: Path, root: Path, extensions: list[str]) -> list[str]:
    """Return the sorted image files in *examples_dir*, as strings relative to *root*.

    A file is kept when its lower-cased suffix is one of *extensions*. A missing
    (or non-directory) *examples_dir* yields an empty list instead of raising
    FileNotFoundError, so the app can still start without bundled examples.
    """
    if not examples_dir.is_dir():
        return []
    allowed = {ext.lower() for ext in extensions}
    return sorted(
        str(path.relative_to(root))
        for path in examples_dir.iterdir()
        if path.is_file() and path.suffix.lower() in allowed
    )


# Example image paths (relative to WORK_DIR) used to populate the examples row.
example_images = _list_example_images(WORK_DIR.joinpath("examples"), WORK_DIR, IMAGE_EXTENSIONS)
def predict(
    image: Image.Image,
    model_repo: str,
    threshold: float = 0.5,
):
    """Tag one image with a WD tagger and build per-tag heatmaps.

    Args:
        image: Input picture from the UI (the image component uses ``type="pil"``);
            ``None`` when the user pressed Submit without providing one.
        model_repo: Hugging Face repo id of the tagger; must be one of
            ``AVAILABLE_MODEL_REPOS``.
        threshold: Tag confidence threshold forwarded to ``process_heatmap``.

    Returns:
        A 7-tuple in the order of the ``submit.click`` outputs: gallery items as
        ``(image, label)`` pairs, the heatmap grid, the caption, the booru-style
        tags, and the rating / character / general label data.

    Raises:
        gr.Error: if no image was given or ``model_repo`` is not an offered model.
    """
    # Submit can be clicked before an image is uploaded; show a clear message
    # instead of failing inside preprocessing.
    if image is None:
        raise gr.Error("Please upload an image first.")
    # This function is exposed as an API endpoint (api_name="predict"), so do not
    # load arbitrary repo ids sent by a caller — only the models this app offers.
    if model_repo not in AVAILABLE_MODEL_REPOS:
        raise gr.Error(f"Unsupported model: {model_repo!r}")

    model, transform = load_model_and_transform(model_repo)
    labels: LabelData = load_labels_hf(model_repo)

    # Resize/pad via preprocess_image to (448, 448) — presumably the tagger input
    # resolution; confirm per model — then apply the model transform and add a
    # leading batch dimension.
    pixels = preprocess_image(image, (448, 448))
    batch = transform(pixels).unsqueeze(0)

    heatmaps: list[Heatmap]
    image_labels: ImageLabels
    heatmaps, heatmap_grid, image_labels = process_heatmap(model, batch, labels, threshold)

    # gr.Gallery accepts (image, caption) pairs.
    gallery_items = [(hm.image, hm.label) for hm in heatmaps]
    return (
        gallery_items,
        heatmap_grid,
        image_labels.caption,
        image_labels.booru,
        image_labels.rating,
        image_labels.character,
        image_labels.general,
    )
# Extra stylesheet passed to gr.Blocks(css=...).
# The former "#use_mcut, #char_mcut" rule was dropped: no component in this
# app uses those elem_ids.
# NOTE(review): nothing in this file adds the "dimmed" class to the #threshold
# slider, so this rule is inert unless some other code toggles it.
css = """
#threshold.dimmed {
    filter: brightness(75%);
}
"""
# Default tag threshold, shared by the slider and the example rows so they cannot drift apart.
_DEFAULT_THRESHOLD = 0.35

with gr.Blocks(theme="NoCrypt/miku", analytics_enabled=False, title=TITLE, css=css) as demo:
    with gr.Row(equal_height=False):
        # Left column: input image, settings and action buttons.
        with gr.Column(min_width=720):
            with gr.Group():
                img_input = gr.Image(
                    label="Input",
                    type="pil",
                    image_mode="RGB",
                    sources=["upload", "clipboard"],
                )
            with gr.Group():
                with gr.Row():
                    threshold = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=_DEFAULT_THRESHOLD,
                        step=0.01,
                        label="Tag Threshold",
                        scale=5,
                        elem_id="threshold",
                    )
                    model_to_use = gr.Dropdown(
                        choices=AVAILABLE_MODEL_REPOS,
                        value=MODEL_REPO,
                        label="Model",
                    )
                with gr.Row():
                    # Components to clear are attached below, once the outputs exist.
                    clear = gr.ClearButton(
                        components=[],
                        variant="secondary",
                        size="lg",
                    )
                    submit = gr.Button(value="Submit", variant="primary", size="lg")
        # Right column: prediction results.
        with gr.Column(min_width=720):
            with gr.Tab(label="Heatmaps"):
                heatmap_gallery = gr.Gallery(columns=3, show_label=False)
            with gr.Tab(label="Grid"):
                heatmap_grid = gr.Image(show_label=False)
            with gr.Tab(label="Tags"):
                with gr.Group():
                    caption = gr.Textbox(label="Caption", show_copy_button=True)
                    tags = gr.Textbox(label="Tags", show_copy_button=True)
                with gr.Group():
                    rating = gr.Label(label="Rating")
                with gr.Group():
                    character = gr.Label(label="Character")
                with gr.Group():
                    general = gr.Label(label="General")

    # Example rows only fill in the inputs (no fn/outputs given); skip the row
    # entirely when no example files were found.
    if example_images:
        with gr.Row():
            gr.Examples(
                examples=[[img_path, MODEL_REPO, _DEFAULT_THRESHOLD] for img_path in example_images],
                inputs=[img_input, model_to_use, threshold],
            )

    # Tell the clear button which components to reset.
    clear.add([img_input, heatmap_gallery, heatmap_grid, caption, tags, rating, character, general])
    # Output order must match the tuple returned by predict().
    submit.click(
        predict,
        inputs=[img_input, model_to_use, threshold],
        outputs=[heatmap_gallery, heatmap_grid, caption, tags, rating, character, general],
        api_name="predict",
    )
if __name__ == "__main__":
    # Allow at most 10 pending requests in the queue.
    demo.queue(max_size=10)
    # On Hugging Face Spaces (SPACE_ID is set), let the platform decide host and
    # port; when run locally, listen on all interfaces at port 7871 in debug mode.
    running_on_space = getenv("SPACE_ID") is not None
    launch_kwargs = (
        {}
        if running_on_space
        else {"server_name": "0.0.0.0", "server_port": 7871, "debug": True}
    )
    demo.launch(**launch_kwargs)