import argparse
import json

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from transformers import AutoModel, CLIPImageProcessor
TITLE = "Danbooru Tagger"
DESCRIPTION = """
## Dataset
- Source: Cleaned Danbooru
- Last Update: December 28, 2024

## Metrics
- Validation Split: 10% of images
- Validation Results (Macro F1 Score):
  - General & Character: 0.4916
  - Artist: 0.6677
"""
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "<o>_<o>",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]
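
# Inputs are scaled so the shortest edge is `shortest_edge`, then both sides
# are snapped to multiples of the ViT patch size before being fed to RADIO.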
shortest_edge = 512
patch_size = 16

device = torch.device('cpu')

image_processor = CLIPImageProcessor(
    crop_size={'height': 512, 'width': 512},
    do_center_crop=False,
    do_convert_rgb=True,
    do_normalize=False,
    do_rescale=True,
    do_resize=False,
    image_processor_type="CLIPImageProcessor",
    processor_class="CLIPProcessor",
    resample=3,
    size={"shortest_edge": 512},
)

model = AutoModel.from_pretrained('nvidia/RADIO-H', trust_remote_code=True).to(device)
model.eval()
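
# Multi-label tag classifier head: a shared 1280-dim trunk feeds three
# parallel 1280 -> 640 branches whose outputs are summed, followed by a
# sigmoid classifier over `class_num` tags.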
class MLP(nn.Module):
    def __init__(self, input_size, class_num):
        super().__init__()
        self.layers0 = nn.Sequential(
            nn.Linear(input_size, 1280),
            nn.LayerNorm(1280),
            nn.Mish()
        )
        self.layers1 = nn.Sequential(
            nn.Linear(640, class_num),
            nn.Sigmoid()
        )
        self.layers2 = nn.Sequential(
            nn.Linear(1280, 640),
            nn.LayerNorm(640),
            nn.Mish(),
            nn.Dropout(0.2)
        )
        self.layers3 = nn.Sequential(
            nn.Linear(1280, 640),
            nn.LayerNorm(640),
            nn.Mish(),
            nn.Dropout(0.2)
        )
        self.layers4 = nn.Sequential(
            nn.Linear(1280, 640),
            nn.LayerNorm(640),
            nn.Mish(),
            nn.Dropout(0.2)
        )

    def forward(self, x):
        out = self.layers0(x)
        out = self.layers2(out) + self.layers3(out) + self.layers4(out)
        out = self.layers1(out)
        return out
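
# Tag metadata layout (inferred from the lookups in prediction_to_tag below):
# each dictionary maps tag name -> metadata, where value[1] is the category
# ("general", "character", "artist", "rating", "date") and value[2] is the
# 1-based class index. implications_list maps a tag to the tags it implies.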
with open('general_tag_dict.json', 'r', encoding='utf-8') as f:
    general_dict = json.load(f)
with open('character_tag_dict.json', 'r', encoding='utf-8') as f:
    character_dict = json.load(f)
with open('artist_tag_dict.json', 'r', encoding='utf-8') as f:
    artist_dict = json.load(f)
with open('implications_list.json', 'r', encoding='utf-8') as f:
    implications_list = json.load(f)
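
# Three independent classifier heads, all consuming the 3840-dim RADIO
# summary embedding.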
general_class = 9775
mlp_general = MLP(3840, general_class)
general_s = torch.load("cls_predictor.pth", map_location=device)
mlp_general.load_state_dict(general_s)
mlp_general.to(device)
mlp_general.eval()

character_class = 7568
mlp_character = MLP(3840, character_class)
character_s = torch.load("character_predictor.pth", map_location=device)
mlp_character.load_state_dict(character_s)
mlp_character.to(device)
mlp_character.eval()

artist_class = 13957
mlp_artist = MLP(3840, artist_class)
artist_s = torch.load("artist_predictor.pth", map_location=device)
mlp_artist.load_state_dict(artist_s)
mlp_artist.to(device)
mlp_artist.eval()
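
# Aesthetic score head: same trunk-and-branches layout as MLP, but each
# branch regresses a single scalar; the sigmoid-squashed sum is scaled to
# a 0-10 score.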
class AES(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.layers0 = nn.Sequential(
            nn.Linear(input_size, 1280),
            nn.LayerNorm(1280),
            nn.Mish()
        )
        self.layers1 = nn.Sequential(
            nn.Sigmoid()
        )
        self.layers2 = nn.Sequential(
            nn.Linear(1280, 640),
            nn.LayerNorm(640),
            nn.Mish(),
            nn.Dropout(0.2),
            nn.Linear(640, 1)
        )
        self.layers3 = nn.Sequential(
            nn.Linear(1280, 640),
            nn.LayerNorm(640),
            nn.Mish(),
            nn.Dropout(0.2),
            nn.Linear(640, 1)
        )
        self.layers4 = nn.Sequential(
            nn.Linear(1280, 640),
            nn.LayerNorm(640),
            nn.Mish(),
            nn.Dropout(0.2),
            nn.Linear(640, 1)
        )

    def forward(self, x):
        out = self.layers0(x)
        out = self.layers2(out) + self.layers3(out) + self.layers4(out)
        out = self.layers1(out)
        return out * 10

mlp_ava = AES(3840)
ava_s = torch.load("aesthetic_predictor_ava.pth", map_location=device)
mlp_ava.load_state_dict(ava_s)
mlp_ava.to(device)
mlp_ava.eval()
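
# Splits raw sigmoid outputs into per-category tag dicts. Scores below a 0.2
# floor are discarded up front; the category-specific thresholds are applied
# on top, and only the single best "date" and "rating" tags are kept.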
def prediction_to_tag(prediction, tag_dict, class_num, general_threshold, character_threshold, artist_threshold):
    prediction = prediction.view(class_num)
    predicted_ids = (prediction >= 0.2).nonzero(as_tuple=True)[0].cpu().numpy() + 1
    general = {}
    character = {}
    artist = {}
    date = {}
    rating = {}
    for tag, value in tag_dict.items():
        if value[2] in predicted_ids:
            tag_value = round(prediction[value[2] - 1].item(), 6)
            if value[1] == "general" and tag_value >= general_threshold:
                general[tag] = tag_value
            elif value[1] == "character" and tag_value >= character_threshold:
                character[tag] = tag_value
            elif value[1] == "artist" and tag_value >= artist_threshold:
                artist[tag] = tag_value
            elif value[1] == "rating":
                rating[tag] = tag_value
            elif value[1] == "date":
                date[tag] = tag_value
    general = dict(sorted(general.items(), key=lambda item: item[1], reverse=True))
    character = dict(sorted(character.items(), key=lambda item: item[1], reverse=True))
    artist = dict(sorted(artist.items(), key=lambda item: item[1], reverse=True))
    if date:
        date = {max(date, key=date.get): date[max(date, key=date.get)]}
    if rating:
        rating = {max(rating, key=rating.get): rating[max(rating, key=rating.get)]}
    return general, character, artist, date, rating
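
# Full pipeline for one image: flatten alpha onto white, resize to a
# patch-aligned shortest-edge-512 canvas, extract the RADIO summary, run the
# four heads, then post-process tags (implication pruning, kaomoji-aware
# underscore replacement, parenthesis escaping).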
def process_image(image, general_threshold, character_threshold, artist_threshold):
    try:
        # Flatten any transparency onto a white background.
        image = image.convert('RGBA')
        background = Image.new('RGBA', image.size, (255, 255, 255, 255))
        image = Image.alpha_composite(background, image).convert('RGB')
        # Scale so the shortest edge is 512, then snap both sides to
        # multiples of the patch size (at least one patch each).
        width, height = image.size
        if width < height:
            height = int((shortest_edge / width) * height)
            width = shortest_edge
        else:
            width = int((shortest_edge / height) * width)
            height = shortest_edge
        height = int(round(height / patch_size) * patch_size)
        width = int(round(width / patch_size) * patch_size)
        height = max(height, patch_size)
        width = max(width, patch_size)
        image = image.resize((width, height), Image.LANCZOS)
        # do_resize=True is effectively a no-op here: the shortest edge is
        # already exactly 512 after the manual resize above.
        pixel_values = image_processor(images=image, return_tensors='pt', do_resize=True).pixel_values
        pixel_values = pixel_values.to(device).to(torch.bfloat16)
    except OSError as e:  # IOError is an alias of OSError
        print(f"Error opening image: {e}")
        # Return one value per Gradio output so the UI clears instead of erroring.
        return None, None, None, None, None, None, None
    with torch.no_grad():
        # Autocast follows whichever device the model lives on.
        with torch.autocast(device.type, dtype=torch.bfloat16):
            summary, features = model(pixel_values)
        outputs = summary.to(torch.float32)
        general_prediction = mlp_general(outputs)
        general_ = prediction_to_tag(general_prediction, general_dict, general_class, general_threshold, character_threshold, artist_threshold)
        general_tags = general_[0]
        rating = general_[4]
        character_prediction = mlp_character(outputs)
        character_ = prediction_to_tag(character_prediction, character_dict, character_class, general_threshold, character_threshold, artist_threshold)
        character_tags = character_[1]
        artist_prediction = mlp_artist(outputs)
        artist_ = prediction_to_tag(artist_prediction, artist_dict, artist_class, general_threshold, character_threshold, artist_threshold)
        artist_tags = artist_[2]
        date = artist_[3]
        ava_score = round(mlp_ava(outputs).item(), 3)
    combined_tags = {**character_tags, **general_tags}
    tags_list = [tag for tag in combined_tags]
    # Drop tags that are implied by another predicted tag.
    remove_list = []
    for tag in tags_list:
        if tag in implications_list:
            for implication in implications_list[tag]:
                remove_list.append(implication)
    tags_list = [tag for tag in tags_list if tag not in remove_list]
    # Kaomoji keep their underscores; every other tag becomes space-separated.
    tags_list = [tag.replace("_", " ") if tag not in kaomojis else tag for tag in tags_list]
    # Escape parentheses so the tag string is safe to paste into prompts.
    tags_str = ", ".join(tags_list).replace("(", "\\(").replace(")", "\\)")
    return tags_str, artist_tags, character_tags, general_tags, rating, date, ava_score
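
# Programmatic use (a minimal sketch; "example.png" is a hypothetical local
# file, and the thresholds shown are the CLI defaults defined below):
#
#     tags, artists, characters, general, rating, date, score = process_image(
#         Image.open("example.png"), 0.61, 0.8, 0.68)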
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--slider-step", type=float, default=0.01)
    parser.add_argument("--general-threshold", type=float, default=0.61)
    parser.add_argument("--character-threshold", type=float, default=0.8)
    parser.add_argument("--artist-threshold", type=float, default=0.68)
    return parser.parse_args()

def main():
    args = parse_args()
    with gr.Blocks(title=TITLE) as demo:
        with gr.Column():
            gr.Markdown(
                value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
            )
            with gr.Row():
                with gr.Column(variant="panel"):
                    submit = gr.Button(value="Submit", variant="primary", size="lg")
                    image = gr.Image(type="pil", image_mode="RGBA", label="Input")
                    with gr.Row():
                        general_threshold = gr.Slider(
                            0,
                            1,
                            step=args.slider_step,
                            value=args.general_threshold,
                            label="General Threshold",
                            scale=3,
                        )
                        character_threshold = gr.Slider(
                            0,
                            1,
                            step=args.slider_step,
                            value=args.character_threshold,
                            label="Character Threshold",
                            scale=3,
                        )
                        artist_threshold = gr.Slider(
                            0,
                            1,
                            step=args.slider_step,
                            value=args.artist_threshold,
                            label="Artist Threshold",
                            scale=3,
                        )
                    with gr.Row():
                        clear = gr.ClearButton(
                            components=[
                                image,
                            ],
                            variant="secondary",
                            size="lg",
                        )
                    gr.Markdown(value=DESCRIPTION)
                with gr.Column(variant="panel"):
                    tags_str = gr.Textbox(label="Output")
                    with gr.Row():
                        ava_score = gr.Textbox(label="Aesthetic Score (AVA)")
                    with gr.Row():
                        rating = gr.Label(label="Rating")
                        date = gr.Label(label="Year")
                    artist_tags = gr.Label(label="Artist")
                    character_tags = gr.Label(label="Character")
                    general_tags = gr.Label(label="General")
        clear.add(
            [
                tags_str,
                artist_tags,
                general_tags,
                character_tags,
                rating,
                date,
                ava_score,
            ]
        )
        submit.click(
            process_image,
            inputs=[
                image,
                general_threshold,
                character_threshold,
                artist_threshold,
            ],
            outputs=[tags_str, artist_tags, character_tags, general_tags, rating, date, ava_score],
        )
    demo.queue(max_size=10)
    demo.launch()

if __name__ == "__main__":
    main()