"""Gradio demo: tag Danbooru-style images with NVIDIA RADIO features + MLP heads.

Pipeline: RADIO-H vision backbone -> 3840-d summary embedding -> four MLP
heads (general tags, characters, artists, aesthetic score), all on CPU.
"""

import argparse
import json

import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
from transformers import AutoModel, CLIPImageProcessor

TITLE = "Danbooru Tagger"
DESCRIPTION = """
## Dataset
- Source: Cleaned Danbooru
- Last Update: December 28, 2024

## Metrics
- Validation Split: 10% of images
- Validation Results (Macro F1 Score):
  - General & Character: 0.4916
  - Artist: 0.6677
"""

# Emoticon-style tags whose underscores are meaningful and must NOT be
# replaced with spaces when building the output string.
kaomojis = [
    "0_0", "(o)_(o)", "+_+", "+_-", "._.", "_", "<|>_<|>", "=_=",
    ">_<", "3_3", "6_9", ">_o", "@_@", "^_^", "o_o", "u_u", "x_x",
    "|_|", "||_||",
]

shortest_edge = 512   # target size for the image's shorter side before encoding
patch_size = 16       # ViT patch size; both image dims are rounded to a multiple
device = torch.device('cpu')

# Resizing is done manually in process_image (aspect-preserving + patch
# alignment), so the processor only rescales/converts — no resize/crop/norm.
image_processor = CLIPImageProcessor(
    crop_size={'height': 512, 'width': 512},
    do_center_crop=False,
    do_convert_rgb=True,
    do_normalize=False,
    do_rescale=True,
    do_resize=False,
    image_processor_type="CLIPImageProcessor",
    processor_class="CLIPProcessor",
    resample=3,
    size={"shortest_edge": 512},
)

model = AutoModel.from_pretrained('nvidia/RADIO-H', trust_remote_code=True).to(device)
model.eval()


class MLP(nn.Module):
    """Multi-label tag head: shared trunk, 3 parallel branches summed, sigmoid out.

    NOTE: attribute names (layers0..layers4) must match the checkpoint's
    state_dict keys exactly — do not rename.
    """

    def __init__(self, input_size, class_num):
        super().__init__()
        self.layers0 = nn.Sequential(
            nn.Linear(input_size, 1280),
            nn.LayerNorm(1280),
            nn.Mish()
        )
        self.layers1 = nn.Sequential(
            nn.Linear(640, class_num),
            nn.Sigmoid()
        )
        self.layers2 = nn.Sequential(
            nn.Linear(1280, 640),
            nn.LayerNorm(640),
            nn.Mish(),
            nn.Dropout(0.2)
        )
        self.layers3 = nn.Sequential(
            nn.Linear(1280, 640),
            nn.LayerNorm(640),
            nn.Mish(),
            nn.Dropout(0.2)
        )
        self.layers4 = nn.Sequential(
            nn.Linear(1280, 640),
            nn.LayerNorm(640),
            nn.Mish(),
            nn.Dropout(0.2)
        )

    def forward(self, x):
        """Return per-class probabilities of shape (..., class_num)."""
        out = self.layers0(x)
        # Three parallel bottleneck branches, summed (ensemble-like).
        out = self.layers2(out) + self.layers3(out) + self.layers4(out)
        out = self.layers1(out)
        return out


# Tag dictionaries map tag name -> [?, category, 1-based class index]
# (only indices [1] and [2] are read below).
with open('general_tag_dict.json', 'r', encoding='utf-8') as f:
    general_dict = json.load(f)
with open('character_tag_dict.json', 'r', encoding='utf-8') as f:
    character_dict = json.load(f)
with open('artist_tag_dict.json', 'r', encoding='utf-8') as f:
    artist_dict = json.load(f)
# Maps tag -> list of tags it implies (implied tags are dropped from output).
with open('implications_list.json', 'r', encoding='utf-8') as f:
    implications_list = json.load(f)

# weights_only=True: checkpoints are plain state dicts; avoids unpickling
# arbitrary code from the file.
general_class = 9775
mlp_general = MLP(3840, general_class)
general_s = torch.load("cls_predictor.pth", map_location=device, weights_only=True)
mlp_general.load_state_dict(general_s)
mlp_general.to(device)
mlp_general.eval()

character_class = 7568
mlp_character = MLP(3840, character_class)
character_s = torch.load("character_predictor.pth", map_location=device, weights_only=True)
mlp_character.load_state_dict(character_s)
mlp_character.to(device)
mlp_character.eval()

artist_class = 13957
mlp_artist = MLP(3840, artist_class)
artist_s = torch.load("artist_predictor.pth", map_location=device, weights_only=True)
mlp_artist.load_state_dict(artist_s)
mlp_artist.to(device)
mlp_artist.eval()


class AES(nn.Module):
    """Aesthetic score regressor: same trunk/branch layout as MLP, scalar out in [0, 10].

    NOTE: attribute names (layers0..layers4) must match the checkpoint's
    state_dict keys exactly — do not rename.
    """

    def __init__(self, input_size):
        super().__init__()
        self.layers0 = nn.Sequential(
            nn.Linear(input_size, 1280),
            nn.LayerNorm(1280),
            nn.Mish()
        )
        self.layers1 = nn.Sequential(
            nn.Sigmoid()
        )
        self.layers2 = nn.Sequential(
            nn.Linear(1280, 640),
            nn.LayerNorm(640),
            nn.Mish(),
            nn.Dropout(0.2),
            nn.Linear(640, 1)
        )
        self.layers3 = nn.Sequential(
            nn.Linear(1280, 640),
            nn.LayerNorm(640),
            nn.Mish(),
            nn.Dropout(0.2),
            nn.Linear(640, 1)
        )
        self.layers4 = nn.Sequential(
            nn.Linear(1280, 640),
            nn.LayerNorm(640),
            nn.Mish(),
            nn.Dropout(0.2),
            nn.Linear(640, 1)
        )

    def forward(self, x):
        """Return an aesthetic score tensor scaled to [0, 10]."""
        out = self.layers0(x)
        out = self.layers2(out) + self.layers3(out) + self.layers4(out)
        out = self.layers1(out)
        return out * 10


mlp_ava = AES(3840)
ava_s = torch.load("aesthetic_predictor_ava.pth", map_location=device, weights_only=True)
mlp_ava.load_state_dict(ava_s)
mlp_ava.to(device)
mlp_ava.eval()


def prediction_to_tag(prediction, tag_dict, class_num, general_threshold,
                      character_threshold, artist_threshold):
    """Convert a probability vector into per-category {tag: score} dicts.

    A coarse 0.2 prefilter limits the candidate set; each category then
    applies its own threshold. Returns (general, character, artist, date,
    rating); date and rating keep only their single best entry.
    """
    prediction = prediction.view(class_num)
    # 1-based class ids passing the prefilter; set gives O(1) membership below.
    predicted_ids = set(
        ((prediction >= 0.2).nonzero(as_tuple=True)[0] + 1).tolist()
    )

    general, character, artist, date, rating = {}, {}, {}, {}, {}
    for tag, value in tag_dict.items():
        if value[2] in predicted_ids:
            tag_value = round(prediction[value[2] - 1].item(), 6)
            if value[1] == "general" and tag_value >= general_threshold:
                general[tag] = tag_value
            elif value[1] == "character" and tag_value >= character_threshold:
                character[tag] = tag_value
            elif value[1] == "artist" and tag_value >= artist_threshold:
                artist[tag] = tag_value
            elif value[1] == "rating":
                rating[tag] = tag_value
            elif value[1] == "date":
                date[tag] = tag_value

    general = dict(sorted(general.items(), key=lambda item: item[1], reverse=True))
    character = dict(sorted(character.items(), key=lambda item: item[1], reverse=True))
    artist = dict(sorted(artist.items(), key=lambda item: item[1], reverse=True))

    # Keep only the single most probable date / rating.
    if date:
        best_date = max(date, key=date.get)
        date = {best_date: date[best_date]}
    if rating:
        best_rating = max(rating, key=rating.get)
        rating = {best_rating: rating[best_rating]}
    return general, character, artist, date, rating


def process_image(image, general_threshold, character_threshold, artist_threshold):
    """Run the full tagging pipeline on a PIL image.

    Returns (tags_str, artist_tags, character_tags, general_tags, rating,
    date, ava_score), or None if the image cannot be read.
    """
    try:
        # Flatten transparency onto a white background.
        image = image.convert('RGBA')
        background = Image.new('RGBA', image.size, (255, 255, 255, 255))
        image = Image.alpha_composite(background, image).convert('RGB')

        # Aspect-preserving resize: shorter side -> shortest_edge, then round
        # both sides to a multiple of patch_size (minimum one patch).
        width, height = image.size
        if width < height:
            height = int((shortest_edge / width) * height)
            width = shortest_edge
        else:
            width = int((shortest_edge / height) * width)
            height = shortest_edge
        height = max(int(round(height / patch_size) * patch_size), patch_size)
        width = max(int(round(width / patch_size) * patch_size), patch_size)
        image = image.resize((width, height), Image.LANCZOS)

        pixel_values = image_processor(images=image, return_tensors='pt',
                                       do_resize=True).pixel_values
        pixel_values = pixel_values.to(device).to(torch.bfloat16)
    except OSError as e:  # IOError is an alias of OSError
        print(f"Error opening image: {e}")
        return

    # Autocast on the actual device type (original used 'cuda' on a CPU model).
    with torch.no_grad(), torch.autocast(device.type, dtype=torch.bfloat16):
        summary, features = model(pixel_values)
        outputs = summary.to(torch.float32)

        general_prediction = mlp_general(outputs)
        general_ = prediction_to_tag(general_prediction, general_dict, general_class,
                                     general_threshold, character_threshold, artist_threshold)
        general_tags = general_[0]
        rating = general_[4]

        character_prediction = mlp_character(outputs)
        character_ = prediction_to_tag(character_prediction, character_dict, character_class,
                                       general_threshold, character_threshold, artist_threshold)
        character_tags = character_[1]

        artist_prediction = mlp_artist(outputs)
        artist_ = prediction_to_tag(artist_prediction, artist_dict, artist_class,
                                    general_threshold, character_threshold, artist_threshold)
        artist_tags = artist_[2]
        date = artist_[3]

        ava_score = round(mlp_ava(outputs).item(), 3)

    # Character tags first, then general; drop tags implied by another
    # predicted tag.
    combined_tags = {**character_tags, **general_tags}
    tags_list = list(combined_tags)
    implied = set()
    for tag in tags_list:
        if tag in implications_list:
            implied.update(implications_list[tag])
    tags_list = [tag for tag in tags_list if tag not in implied]

    # Underscores -> spaces except for kaomoji tags; escape parens for
    # prompt-style consumers.
    tags_list = [tag.replace("_", " ") if tag not in kaomojis else tag
                 for tag in tags_list]
    tags_str = ", ".join(tags_list).replace("(", "\\(").replace(")", "\\)")
    return tags_str, artist_tags, character_tags, general_tags, rating, date, ava_score


def parse_args() -> argparse.Namespace:
    """CLI options for slider granularity and default thresholds."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--slider-step", type=float, default=0.01)
    parser.add_argument("--general-threshold", type=float, default=0.61)
    parser.add_argument("--character-threshold", type=float, default=0.8)
    parser.add_argument("--artist-threshold", type=float, default=0.68)
    return parser.parse_args()


def main():
    """Build and launch the Gradio UI."""
    args = parse_args()
    with gr.Blocks(title=TITLE) as demo:
        with gr.Column():
            # NOTE(review): the original title markup was garbled in the
            # source; rendered here as a markdown heading.
            gr.Markdown(value=f"# {TITLE}")
            with gr.Row():
                with gr.Column(variant="panel"):
                    submit = gr.Button(value="Submit", variant="primary", size="lg")
                    image = gr.Image(type="pil", image_mode="RGBA", label="Input")
                    with gr.Row():
                        general_threshold = gr.Slider(
                            0,
                            1,
                            step=args.slider_step,
                            value=args.general_threshold,
                            label="General Threshold",
                            scale=3,
                        )
                        character_threshold = gr.Slider(
                            0,
                            1,
                            step=args.slider_step,
                            value=args.character_threshold,
                            label="Character Threshold",
                            scale=3,
                        )
                        artist_threshold = gr.Slider(
                            0,
                            1,
                            step=args.slider_step,
                            value=args.artist_threshold,
                            label="Artist Threshold",
                            scale=3,
                        )
                    with gr.Row():
                        clear = gr.ClearButton(
                            components=[
                                image,
                            ],
                            variant="secondary",
                            size="lg",
                        )
                    gr.Markdown(value=DESCRIPTION)
                with gr.Column(variant="panel"):
                    tags_str = gr.Textbox(label="Output")
                    with gr.Row():
                        ava_score = gr.Textbox(label="Aesthetic Score (AVA)")
                    with gr.Row():
                        rating = gr.Label(label="Rating")
                        date = gr.Label(label="Year")
                    artist_tags = gr.Label(label="Artist")
                    character_tags = gr.Label(label="Character")
                    general_tags = gr.Label(label="General")
                    clear.add(
                        [
                            tags_str,
                            artist_tags,
                            general_tags,
                            character_tags,
                            rating,
                            date,
                            ava_score
                        ]
                    )
        submit.click(
            process_image,
            inputs=[
                image,
                general_threshold,
                character_threshold,
                artist_threshold
            ],
            outputs=[tags_str, artist_tags, character_tags, general_tags,
                     rating, date, ava_score],
        )
    demo.queue(max_size=10)
    demo.launch()


if __name__ == "__main__":
    main()