# danbooru_tagger / app.py (12.1 kB)
# Uploaded by Johnny-Z - commit 16cf4a9 ("Upload 3 files", verified)
import argparse
from transformers import AutoModel, CLIPImageProcessor
import torch
import json
import torch.nn as nn
from PIL import Image
import gradio as gr
# Heading text: used as the Blocks page title and in the <h1> banner in main().
TITLE = "Danbooru Tagger"

# Markdown shown in the input panel of the UI (training data and validation
# metrics of the shipped classifier heads).
DESCRIPTION = """
## Dataset
- Source: Cleaned Danbooru
- Last Update: December 28, 2024
## Metrics
- Validation Split: 10% of images
- Validation Results (Macro F1 Score):
- General & Character: 0.4916
- Artist: 0.6677
"""

# Emoticon-style tags. process_image() leaves these untouched when it turns
# the underscores of ordinary tags into spaces.
kaomojis = [
    "0_0", "(o)_(o)", "+_+", "+_-", "._.",
    "<o>_<o>", "<|>_<|>", "=_=", ">_<", "3_3",
    "6_9", ">_o", "@_@", "^_^", "o_o",
    "u_u", "x_x", "|_|", "||_||",
]
# Resize geometry used by process_image(): the shorter image side is scaled to
# `shortest_edge` pixels and both sides are snapped to multiples of `patch_size`.
shortest_edge = 512
patch_size = 16

# Everything in this script is placed on the CPU.
device = torch.device("cpu")

# Preprocessor with RGB conversion and rescaling switched on, and center-crop
# and normalization switched off. Resizing is off here, but process_image()
# re-enables it per call with do_resize=True.
image_processor = CLIPImageProcessor(
    crop_size={"height": 512, "width": 512},
    do_center_crop=False,
    do_convert_rgb=True,
    do_normalize=False,
    do_rescale=True,
    do_resize=False,
    image_processor_type="CLIPImageProcessor",
    processor_class="CLIPProcessor",
    resample=3,
    size={"shortest_edge": 512},
)

# Feature backbone. trust_remote_code=True lets transformers run the modelling
# code that ships with the 'nvidia/RADIO-H' repository.
model = AutoModel.from_pretrained("nvidia/RADIO-H", trust_remote_code=True)
model = model.to(device)
model.eval()
class MLP(nn.Module):
    """Multi-label classification head applied to a feature vector.

    The input goes through a shared stem (``layers0``), then through three
    parallel branches (``layers2``, ``layers3``, ``layers4``) whose outputs are
    summed, and finally through a linear + sigmoid output layer (``layers1``),
    so every returned score lies in [0, 1].

    The attribute names ``layers0`` .. ``layers4`` and the position of each
    module inside its ``nn.Sequential`` define the ``state_dict`` keys; they
    are kept exactly as before so existing checkpoints keep loading.

    Args:
        input_size: Number of features in each input vector.
        class_num: Number of output classes (one score per class).
        hidden_size: Width of the shared stem. Defaults to 1280.
        branch_size: Width of each parallel branch. Defaults to 640.
        dropout: Dropout probability at the end of each branch (only active in
            training mode). Defaults to 0.2.
    """

    def __init__(self, input_size, class_num, hidden_size=1280, branch_size=640, dropout=0.2):
        super().__init__()
        self.layers0 = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.Mish(),
        )
        # Output layer; registered before the branches to keep the original
        # parameter registration order.
        self.layers1 = nn.Sequential(
            nn.Linear(branch_size, class_num),
            nn.Sigmoid(),
        )
        self.layers2 = self._branch(hidden_size, branch_size, dropout)
        self.layers3 = self._branch(hidden_size, branch_size, dropout)
        self.layers4 = self._branch(hidden_size, branch_size, dropout)

    @staticmethod
    def _branch(in_features, out_features, dropout):
        """Build one Linear -> LayerNorm -> Mish -> Dropout branch."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.LayerNorm(out_features),
            nn.Mish(),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Return one sigmoid score per class for the feature vector(s) ``x``."""
        stem = self.layers0(x)
        merged = self.layers2(stem) + self.layers3(stem) + self.layers4(stem)
        return self.layers1(merged)
# Tag tables stored next to this script. prediction_to_tag() reads element [1]
# of each entry as the tag category and element [2] as its 1-based class id.
with open("general_tag_dict.json", encoding="utf-8") as f:
    general_dict = json.loads(f.read())

with open("character_tag_dict.json", encoding="utf-8") as f:
    character_dict = json.loads(f.read())

with open("artist_tag_dict.json", encoding="utf-8") as f:
    artist_dict = json.loads(f.read())

# Looked up by tag name in process_image(); each entry lists the tags to drop
# from the output string when that tag is present.
with open("implications_list.json", encoding="utf-8") as f:
    implications_list = json.loads(f.read())
# Build the three classifier heads and load their trained weights.
# Each head takes a 3840-wide input; process_image() feeds it the backbone's
# summary vector. The *_class values are the heads' output widths.
#
# weights_only=True makes torch.load refuse to unpickle anything other than
# tensors and plain containers, which is all a state dict needs, so loading a
# checkpoint cannot execute arbitrary pickled code.

general_class = 9775
mlp_general = MLP(3840, general_class)
general_s = torch.load("cls_predictor.pth", map_location=device, weights_only=True)
mlp_general.load_state_dict(general_s)
mlp_general.to(device)
mlp_general.eval()

character_class = 7568
mlp_character = MLP(3840, character_class)
character_s = torch.load("character_predictor.pth", map_location=device, weights_only=True)
mlp_character.load_state_dict(character_s)
mlp_character.to(device)
mlp_character.eval()

artist_class = 13957
mlp_artist = MLP(3840, artist_class)
artist_s = torch.load("artist_predictor.pth", map_location=device, weights_only=True)
mlp_artist.load_state_dict(artist_s)
mlp_artist.to(device)
mlp_artist.eval()
class AES(nn.Module):
    """Scalar scoring head applied to a feature vector.

    The input goes through a shared stem (``layers0``), then through three
    parallel branches (``layers2``, ``layers3``, ``layers4``) that each end in
    a single-output linear layer. The three scalars are summed, squashed by a
    sigmoid (``layers1``) and multiplied by ``max_score``, so the result lies
    in [0, max_score].

    The attribute names ``layers0`` .. ``layers4`` and the position of each
    module inside its ``nn.Sequential`` define the ``state_dict`` keys; they
    are kept exactly as before so existing checkpoints keep loading.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Width of the shared stem. Defaults to 1280.
        branch_size: Hidden width of each parallel branch. Defaults to 640.
        dropout: Dropout probability inside each branch (only active in
            training mode). Defaults to 0.2.
        max_score: Factor applied to the sigmoid output. Defaults to 10.
    """

    def __init__(self, input_size, hidden_size=1280, branch_size=640, dropout=0.2, max_score=10):
        super().__init__()
        self.max_score = max_score
        self.layers0 = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.Mish(),
        )
        # Parameter-free output activation; kept as its own Sequential so the
        # module layout matches the original.
        self.layers1 = nn.Sequential(
            nn.Sigmoid(),
        )
        self.layers2 = self._branch(hidden_size, branch_size, dropout)
        self.layers3 = self._branch(hidden_size, branch_size, dropout)
        self.layers4 = self._branch(hidden_size, branch_size, dropout)

    @staticmethod
    def _branch(in_features, mid_features, dropout):
        """Build one Linear -> LayerNorm -> Mish -> Dropout -> Linear(1) branch."""
        return nn.Sequential(
            nn.Linear(in_features, mid_features),
            nn.LayerNorm(mid_features),
            nn.Mish(),
            nn.Dropout(dropout),
            nn.Linear(mid_features, 1),
        )

    def forward(self, x):
        """Return a score in [0, max_score] for the feature vector(s) ``x``."""
        stem = self.layers0(x)
        merged = self.layers2(stem) + self.layers3(stem) + self.layers4(stem)
        return self.layers1(merged) * self.max_score
# Aesthetic scoring head; takes the same 3840-wide input as the tag heads.
mlp_ava = AES(3840)
# weights_only=True makes torch.load refuse to unpickle anything other than
# tensors and plain containers, so loading the checkpoint cannot execute
# arbitrary pickled code.
ava_s = torch.load("aesthetic_predictor_ava.pth", map_location=device, weights_only=True)
mlp_ava.load_state_dict(ava_s)
mlp_ava.to(device)
mlp_ava.eval()
def prediction_to_tag(prediction, tag_dict, class_num, general_threshold, character_threshold, artist_threshold, min_score=0.2):
    """Turn a vector of per-class scores into per-category tag dictionaries.

    Args:
        prediction: Tensor holding ``class_num`` scores; it is flattened with
            ``view(class_num)``.
        tag_dict: Mapping ``tag -> entry`` where ``entry[1]`` is the category
            (``"general"``, ``"character"``, ``"artist"``, ``"rating"`` or
            ``"date"``) and ``entry[2]`` is the 1-based class id of the tag.
            Entries with any other category are ignored.
        class_num: Number of classes in ``prediction``.
        general_threshold: Minimum score for a general tag to be kept.
        character_threshold: Minimum score for a character tag to be kept.
        artist_threshold: Minimum score for an artist tag to be kept.
        min_score: Global floor applied to every class before the
            per-category thresholds. A threshold below this floor therefore
            behaves as if it were equal to it; pass a lower ``min_score`` to
            allow lower thresholds. Defaults to 0.2.

    Returns:
        A tuple ``(general, character, artist, date, rating)`` of dicts
        mapping tag to score (rounded to 6 decimals). ``general``,
        ``character`` and ``artist`` are ordered by descending score;
        ``date`` and ``rating`` hold at most one entry, the best-scoring one.
    """
    prediction = prediction.view(class_num)

    # Zero-based indices of every class that clears the floor, gathered in a
    # single tensor operation and kept in a set for O(1) membership tests.
    candidates = set((prediction >= min_score).nonzero(as_tuple=True)[0].tolist())

    thresholds = {
        "general": general_threshold,
        "character": character_threshold,
        "artist": artist_threshold,
    }
    buckets = {"general": {}, "character": {}, "artist": {}, "date": {}, "rating": {}}

    for tag, value in tag_dict.items():
        index = value[2] - 1  # class ids in tag_dict are 1-based
        if index not in candidates:
            continue
        bucket = buckets.get(value[1])
        if bucket is None:
            continue
        score = round(prediction[index].item(), 6)
        # Rating and date have no per-category threshold; the floor is enough.
        threshold = thresholds.get(value[1])
        if threshold is None or score >= threshold:
            bucket[tag] = score

    def by_score_desc(scores):
        return dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))

    def best_only(scores):
        if not scores:
            return scores
        best = max(scores, key=scores.get)
        return {best: scores[best]}

    general = by_score_desc(buckets["general"])
    character = by_score_desc(buckets["character"])
    artist = by_score_desc(buckets["artist"])
    date = best_only(buckets["date"])
    rating = best_only(buckets["rating"])
    return general, character, artist, date, rating
def process_image(image, general_threshold, character_threshold, artist_threshold):
    """Tag one image and score its aesthetics.

    The image is flattened onto a white background, resized so its shorter
    side is ``shortest_edge`` pixels with both sides snapped to multiples of
    ``patch_size``, run through the backbone, and the backbone's summary
    vector is fed to the general, character, artist and aesthetic heads.

    Args:
        image: PIL image to tag (``None`` when nothing has been loaded).
        general_threshold: Minimum score for general tags.
        character_threshold: Minimum score for character tags.
        artist_threshold: Minimum score for artist tags.

    Returns:
        A 7-tuple ``(tags_str, artist_tags, character_tags, general_tags,
        rating, date, ava_score)``. ``tags_str`` is the comma-separated list
        of character and general tags with implied tags removed, underscores
        replaced by spaces (except in kaomoji tags) and parentheses
        backslash-escaped. The four ``*_tags``/``rating``/``date`` values are
        tag-to-score dicts, and ``ava_score`` is a float rounded to 3 decimals.

    Raises:
        gr.Error: If no image was supplied or the image cannot be read. The
            handler is wired to seven outputs, so returning a bare ``None``
            would fail inside Gradio without a useful message.
    """
    if image is None:
        raise gr.Error("Please provide an image first.")

    try:
        # Flatten any transparency onto white before dropping the alpha channel.
        image = image.convert("RGBA")
        background = Image.new("RGBA", image.size, (255, 255, 255, 255))
        image = Image.alpha_composite(background, image).convert("RGB")

        # Scale the shorter side to shortest_edge, keeping the aspect ratio ...
        width, height = image.size
        if width < height:
            height = int((shortest_edge / width) * height)
            width = shortest_edge
        else:
            width = int((shortest_edge / height) * width)
            height = shortest_edge
        # ... then snap both sides to the nearest multiple of patch_size,
        # never going below a single patch.
        height = max(int(round(height / patch_size) * patch_size), patch_size)
        width = max(int(round(width / patch_size) * patch_size), patch_size)
        image = image.resize((width, height), Image.LANCZOS)

        pixel_values = image_processor(images=image, return_tensors="pt", do_resize=True).pixel_values
        pixel_values = pixel_values.to(device).to(torch.bfloat16)
    except OSError as e:
        print(f"Error opening image: {e}")
        raise gr.Error(f"Could not read the image: {e}") from e

    # NOTE(review): autocast targets 'cuda' while `device` is the CPU; kept
    # as-is because switching it would change the numerics - confirm intent.
    with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
        summary, _ = model(pixel_values)
    outputs = summary.to(torch.float32)

    # The heads only do inference, so no autograd graph is needed.
    with torch.no_grad():
        general_prediction = mlp_general(outputs)
        character_prediction = mlp_character(outputs)
        artist_prediction = mlp_artist(outputs)
        ava_score = round(mlp_ava(outputs).item(), 3)

    # prediction_to_tag returns (general, character, artist, date, rating);
    # each head contributes only the categories its dictionary contains.
    general_ = prediction_to_tag(general_prediction, general_dict, general_class, general_threshold, character_threshold, artist_threshold)
    general_tags = general_[0]
    rating = general_[4]

    character_ = prediction_to_tag(character_prediction, character_dict, character_class, general_threshold, character_threshold, artist_threshold)
    character_tags = character_[1]

    artist_ = prediction_to_tag(artist_prediction, artist_dict, artist_class, general_threshold, character_threshold, artist_threshold)
    artist_tags = artist_[2]
    date = artist_[3]

    # Character tags first, then general tags, in score order.
    combined_tags = {**character_tags, **general_tags}

    # Collect every tag implied by a predicted tag so it can be dropped from
    # the output string; a set keeps the filtering linear.
    implied = set()
    for tag in combined_tags:
        if tag in implications_list:
            implied.update(implications_list[tag])

    tags_list = [tag for tag in combined_tags if tag not in implied]
    tags_list = [tag if tag in kaomojis else tag.replace("_", " ") for tag in tags_list]
    # Raw strings: "\(" is an invalid escape sequence in a normal literal.
    tags_str = ", ".join(tags_list).replace("(", r"\(").replace(")", r"\)")

    return tags_str, artist_tags, character_tags, general_tags, rating, date, ava_score
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the demo.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the original
            behaviour of reading the process arguments.

    Returns:
        Namespace with ``slider_step``, ``general_threshold``,
        ``character_threshold`` and ``artist_threshold`` (all floats).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--slider-step", type=float, default=0.01,
        help="step size of the threshold sliders",
    )
    parser.add_argument(
        "--general-threshold", type=float, default=0.61,
        help="initial value of the general-tag threshold slider",
    )
    parser.add_argument(
        "--character-threshold", type=float, default=0.8,
        help="initial value of the character-tag threshold slider",
    )
    parser.add_argument(
        "--artist-threshold", type=float, default=0.68,
        help="initial value of the artist-tag threshold slider",
    )
    return parser.parse_args(argv)
def main():
    """Build the Gradio UI for the tagger and start serving it.

    Layout: a title banner, then a row with an input panel (submit button,
    image, three threshold sliders, clear button, description) and an output
    panel (tag string, aesthetic score, rating/year, artist, character and
    general labels).

    NOTE(review): the source this was recovered from had lost its
    indentation, so the nesting of the ``with`` blocks below is a
    reconstruction - confirm it against the intended layout.
    """
    args = parse_args()
    with gr.Blocks(title=TITLE) as demo:
        with gr.Column():
            gr.Markdown(
                value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
            )
            with gr.Row():
                # Input panel.
                with gr.Column(variant="panel"):
                    submit = gr.Button(value="Submit", variant="primary", size="lg")
                    image = gr.Image(type="pil", image_mode="RGBA", label="Input")
                    # Sliders start at the command-line defaults and are
                    # passed to process_image on every submit.
                    with gr.Row():
                        general_threshold = gr.Slider(
                            0,
                            1,
                            step=args.slider_step,
                            value=args.general_threshold,
                            label="General Threshold",
                            scale=3,
                        )
                        character_threshold = gr.Slider(
                            0,
                            1,
                            step=args.slider_step,
                            value=args.character_threshold,
                            label="Character Threshold",
                            scale=3,
                        )
                        artist_threshold = gr.Slider(
                            0,
                            1,
                            step=args.slider_step,
                            value=args.artist_threshold,
                            label="Artist Threshold",
                            scale=3,
                        )
                    with gr.Row():
                        # Created with only the input image; the output
                        # widgets are attached below once they exist.
                        clear = gr.ClearButton(
                            components=[
                                image,
                            ],
                            variant="secondary",
                            size="lg",
                        )
                    gr.Markdown(value=DESCRIPTION)
                # Output panel.
                with gr.Column(variant="panel"):
                    tags_str = gr.Textbox(label="Output")
                    with gr.Row():
                        ava_score = gr.Textbox(label="Aesthetic Score (AVA)")
                    with gr.Row():
                        rating = gr.Label(label="Rating")
                        date = gr.Label(label="Year")
                    artist_tags = gr.Label(label="Artist")
                    character_tags = gr.Label(label="Character")
                    general_tags = gr.Label(label="General")
                    # Have the clear button reset the outputs as well.
                    clear.add(
                        [
                            tags_str,
                            artist_tags,
                            general_tags,
                            character_tags,
                            rating,
                            date,
                            ava_score
                        ]
                    )
        # The outputs list follows the order of process_image's return tuple.
        submit.click(
            process_image,
            inputs=[
                image,
                general_threshold,
                character_threshold,
                artist_threshold
            ],
            outputs=[tags_str, artist_tags, character_tags, general_tags, rating, date, ava_score],
        )
    # Queue at most 10 pending requests, then start the server.
    demo.queue(max_size=10)
    demo.launch()
# Start the demo only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()