Upload 3 files
- aesthetic_predictor_ava.pth +3 -0
- app.py +73 -17
- cls_predictor.pth +1 -1
aesthetic_predictor_ava.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f4363c5bfea88c84ae55a55be5ba4c11de4853a87cedb1253373e81b592e2598
+size 29545526
app.py
CHANGED
@@ -8,7 +8,15 @@ import gradio as gr
 
 TITLE = "Danbooru Tagger"
 DESCRIPTION = """
-
+## Dataset
+- Source: Cleaned Danbooru
+- Last Update: December 28, 2024
+
+## Metrics
+- Validation Split: 10% of images
+- Validation Results (Macro F1 Score):
+  - General & Character: 0.4916
+  - Artist: 0.6677
 """
 
 kaomojis = [
@@ -112,6 +120,52 @@ mlp_artist.load_state_dict(artist_s)
 mlp_artist.to(device)
 mlp_artist.eval()
 
+
+class AES(nn.Module):
+    def __init__(self, input_size):
+        super().__init__()
+        self.layers0 = nn.Sequential(
+            nn.Linear(input_size, 1280),
+            nn.LayerNorm(1280),
+            nn.Mish()
+        )
+        self.layers1 = nn.Sequential(
+            nn.Sigmoid()
+        )
+        self.layers2 = nn.Sequential(
+            nn.Linear(1280, 640),
+            nn.LayerNorm(640),
+            nn.Mish(),
+            nn.Dropout(0.2),
+            nn.Linear(640, 1)
+        )
+        self.layers3 = nn.Sequential(
+            nn.Linear(1280, 640),
+            nn.LayerNorm(640),
+            nn.Mish(),
+            nn.Dropout(0.2),
+            nn.Linear(640, 1)
+        )
+        self.layers4 = nn.Sequential(
+            nn.Linear(1280, 640),
+            nn.LayerNorm(640),
+            nn.Mish(),
+            nn.Dropout(0.2),
+            nn.Linear(640, 1)
+        )
+
+    def forward(self, x):
+        out = self.layers0(x)
+        out = self.layers2(out) + self.layers3(out) + self.layers4(out)
+        out = self.layers1(out)
+        return out * 10
+
+mlp_ava = AES(3840)
+ava_s = torch.load("aesthetic_predictor_ava.pth", map_location=device)
+mlp_ava.load_state_dict(ava_s)
+mlp_ava.to(device)
+mlp_ava.eval()
+
 def prediction_to_tag(prediction, tag_dict, class_num, general_threshold, character_threshold, artist_threshold):
     prediction = prediction.view(class_num)
     predicted_ids = (prediction>= 0.2).nonzero(as_tuple=True)[0].cpu().numpy() + 1
@@ -137,17 +191,15 @@ def prediction_to_tag(prediction, tag_dict, class_num, general_threshold, character_threshold, artist_threshold):
 
     general = dict(sorted(general.items(), key=lambda item: item[1], reverse=True))
     character = dict(sorted(character.items(), key=lambda item: item[1], reverse=True))
+    artist = dict(sorted(artist.items(), key=lambda item: item[1], reverse=True))
 
     if date:
         date = {max(date, key=date.get): date[max(date, key=date.get)]}
     if rating:
         rating = {max(rating, key=rating.get): rating[max(rating, key=rating.get)]}
-    if artist:
-        artist = {max(artist, key=artist.get): artist[max(artist, key=artist.get)]}
 
     return general, character, artist, date, rating
 
-
 def process_image(image, general_threshold, character_threshold, artist_threshold):
     try:
         image = image.convert('RGBA')
@@ -177,7 +229,7 @@ def process_image(image, general_threshold, character_threshold, artist_threshold):
         print(f"Error opening image: {e}")
         return
 
-    with torch.no_grad():
+    with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
         summary, features = model(pixel_values)
         outputs = summary.to(torch.float32)
 
@@ -195,7 +247,9 @@ def process_image(image, general_threshold, character_threshold, artist_threshold):
     artist_tags = artist_[2]
     date = artist_[3]
 
-    combined_tags = {**character_tags, **general_tags}
+    ava_score = round(mlp_ava(outputs).item(), 3)
+
+    combined_tags = {**character_tags, **general_tags}
 
     tags_list = [tag for tag in combined_tags]
     remove_list = []
@@ -208,12 +262,12 @@ def process_image(image, general_threshold, character_threshold, artist_threshold):
 
     tags_str = ", ".join(tags_list).replace("(", "\(").replace(")", "\)")
 
-    return tags_str, artist_tags, character_tags, general_tags, rating, date
+    return tags_str, artist_tags, character_tags, general_tags, rating, date, ava_score
 
 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
     parser.add_argument("--slider-step", type=float, default=0.01)
-    parser.add_argument("--general-threshold", type=float, default=0.
+    parser.add_argument("--general-threshold", type=float, default=0.61)
     parser.add_argument("--character-threshold", type=float, default=0.8)
     parser.add_argument("--artist-threshold", type=float, default=0.68)
     return parser.parse_args()
@@ -226,9 +280,9 @@ def main():
         gr.Markdown(
             value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
         )
-        gr.Markdown(value=DESCRIPTION)
         with gr.Row():
             with gr.Column(variant="panel"):
+                submit = gr.Button(value="Submit", variant="primary", size="lg")
                 image = gr.Image(type="pil", image_mode="RGBA", label="Input")
                 with gr.Row():
                     general_threshold = gr.Slider(
@@ -239,7 +293,6 @@ def main():
                         label="General Threshold",
                        scale=3,
                     )
-                with gr.Row():
                     character_threshold = gr.Slider(
                         0,
                         1,
@@ -248,7 +301,6 @@ def main():
                         label="Character Threshold",
                        scale=3,
                     )
-                with gr.Row():
                     artist_threshold = gr.Slider(
                         0,
                         1,
@@ -265,13 +317,16 @@ def main():
                     variant="secondary",
                     size="lg",
                 )
-
+                gr.Markdown(value=DESCRIPTION)
             with gr.Column(variant="panel"):
                 tags_str = gr.Textbox(label="Output")
+                with gr.Row():
+                    ava_score = gr.Textbox(label="Aesthetic Score (AVA)")
+                with gr.Row():
+                    rating = gr.Label(label="Rating")
+                    date = gr.Label(label="Year")
                 artist_tags = gr.Label(label="Artist")
-                character_tags = gr.Label(label="Character")
-                rating = gr.Label(label="Rating")
-                date = gr.Label(label="Year")
+                character_tags = gr.Label(label="Character")
                 general_tags = gr.Label(label="General")
                 clear.add(
                     [
@@ -280,7 +335,8 @@ def main():
                         general_tags,
                         character_tags,
                         rating,
-                        date
+                        date,
+                        ava_score
                     ]
                 )
 
@@ -292,7 +348,7 @@ def main():
                 character_threshold,
                 artist_threshold
             ],
-            outputs=[tags_str, artist_tags, character_tags, general_tags, rating, date],
+            outputs=[tags_str, artist_tags, character_tags, general_tags, rating, date, ava_score],
         )
 
     demo.queue(max_size=10)
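For readers skimming the diff, here is a minimal, self-contained sketch of what the newly added AES head computes: a shared 3840-to-1280 stem, three parallel 1280-to-640-to-1 branches whose outputs are summed, and a sigmoid scaled to a 0-10 score. The class name `AESSketch` and the random input are illustrative only; in the Space the input is the backbone's 3840-dim summary embedding (`outputs` in app.py) and the trained weights come from aesthetic_predictor_ava.pth.

```python
import torch
import torch.nn as nn

class AESSketch(nn.Module):
    """Illustrative re-statement of the AES head added in this commit."""
    def __init__(self, input_size: int = 3840):
        super().__init__()
        # Shared stem: summary embedding -> 1280-dim hidden state
        self.stem = nn.Sequential(nn.Linear(input_size, 1280), nn.LayerNorm(1280), nn.Mish())
        # Three parallel regression branches whose scalar outputs are summed
        self.branches = nn.ModuleList([
            nn.Sequential(nn.Linear(1280, 640), nn.LayerNorm(640), nn.Mish(),
                          nn.Dropout(0.2), nn.Linear(640, 1))
            for _ in range(3)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.stem(x)
        logit = sum(branch(h) for branch in self.branches)
        return torch.sigmoid(logit) * 10  # score bounded to (0, 10)

# Untrained weights and a random embedding, purely to show shapes and range;
# the Space loads the trained state dict instead.
head = AESSketch().eval()
with torch.no_grad():
    score = head(torch.randn(1, 3840)).item()
print(f"aesthetic score: {score:.3f}")
```

Summing independently regularized branches and squashing the result with a single sigmoid is what keeps the reported value in the AVA-style 0-10 range that the new "Aesthetic Score (AVA)" textbox displays.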
cls_predictor.pth
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:f5a6373053dad15af8b8cc2a6830bd04f67d35ff04acc5f071c34cb5d8c05305
 size 54599508
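Both .pth entries above are Git LFS pointer files rather than the weights themselves; the `oid` line records the SHA-256 of the actual binary. Assuming the checkpoints have already been fetched (for example with `git lfs pull`), a quick sketch for checking a downloaded file against its pointer:

```python
import hashlib

# Expected hash copied from the cls_predictor.pth pointer in this commit.
EXPECTED = "f5a6373053dad15af8b8cc2a6830bd04f67d35ff04acc5f071c34cb5d8c05305"

sha = hashlib.sha256()
with open("cls_predictor.pth", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # read in 1 MiB chunks
        sha.update(chunk)

print("checksum OK" if sha.hexdigest() == EXPECTED else "checksum mismatch")
```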