Johnny-Z committed
Commit 16cf4a9 · verified · 1 Parent(s): fe53e5c

Upload 3 files

Files changed (3):
  1. aesthetic_predictor_ava.pth +3 -0
  2. app.py +73 -17
  3. cls_predictor.pth +1 -1
aesthetic_predictor_ava.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f4363c5bfea88c84ae55a55be5ba4c11de4853a87cedb1253373e81b592e2598
+size 29545526
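The new .pth is tracked with Git LFS, so the committed file is just the three-line pointer above; `git lfs pull` fetches the real 29.5 MB weights. A minimal sketch (not part of the repo) for checking a fetched copy against the pointer's oid and size:

```python
import hashlib
import os

# Values copied from the LFS pointer above.
PATH = "aesthetic_predictor_ava.pth"
EXPECTED_OID = "f4363c5bfea88c84ae55a55be5ba4c11de4853a87cedb1253373e81b592e2598"
EXPECTED_SIZE = 29545526

# A bare pointer file is only a few hundred bytes, so a size check
# catches a clone where `git lfs pull` has not been run yet.
assert os.path.getsize(PATH) == EXPECTED_SIZE, "run `git lfs pull` first"

# Hash in 1 MiB chunks and compare against the pointer's sha256 oid.
sha = hashlib.sha256()
with open(PATH, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha.update(chunk)
assert sha.hexdigest() == EXPECTED_OID, "checksum mismatch"
print("aesthetic_predictor_ava.pth matches its LFS pointer")
```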
app.py CHANGED
@@ -8,7 +8,15 @@ import gradio as gr
 
 TITLE = "Danbooru Tagger"
 DESCRIPTION = """
-Macro F1 (General & Character): 0.4937
+## Dataset
+- Source: Cleaned Danbooru
+- Last Update: December 28, 2024
+
+## Metrics
+- Validation Split: 10% of images
+- Validation Results (Macro F1 Score):
+  - General & Character: 0.4916
+  - Artist: 0.6677
 """
 
 kaomojis = [
@@ -112,6 +120,52 @@ mlp_artist.load_state_dict(artist_s)
 mlp_artist.to(device)
 mlp_artist.eval()
 
+
+class AES(nn.Module):
+    def __init__(self, input_size):
+        super().__init__()
+        self.layers0 = nn.Sequential(
+            nn.Linear(input_size, 1280),
+            nn.LayerNorm(1280),
+            nn.Mish()
+        )
+        self.layers1 = nn.Sequential(
+            nn.Sigmoid()
+        )
+        self.layers2 = nn.Sequential(
+            nn.Linear(1280, 640),
+            nn.LayerNorm(640),
+            nn.Mish(),
+            nn.Dropout(0.2),
+            nn.Linear(640, 1)
+        )
+        self.layers3 = nn.Sequential(
+            nn.Linear(1280, 640),
+            nn.LayerNorm(640),
+            nn.Mish(),
+            nn.Dropout(0.2),
+            nn.Linear(640, 1)
+        )
+        self.layers4 = nn.Sequential(
+            nn.Linear(1280, 640),
+            nn.LayerNorm(640),
+            nn.Mish(),
+            nn.Dropout(0.2),
+            nn.Linear(640, 1)
+        )
+
+    def forward(self, x):
+        out = self.layers0(x)
+        out = self.layers2(out) + self.layers3(out) + self.layers4(out)
+        out = self.layers1(out)
+        return out * 10
+
+mlp_ava = AES(3840)
+ava_s = torch.load("aesthetic_predictor_ava.pth", map_location=device)
+mlp_ava.load_state_dict(ava_s)
+mlp_ava.to(device)
+mlp_ava.eval()
+
 def prediction_to_tag(prediction, tag_dict, class_num, general_threshold, character_threshold, artist_threshold):
     prediction = prediction.view(class_num)
     predicted_ids = (prediction>= 0.2).nonzero(as_tuple=True)[0].cpu().numpy() + 1
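For context on the hunk above: `AES` maps the backbone's 3840-dim summary embedding through a shared 1280-dim layer and three parallel 640-dim branches, then squashes the summed logit with a sigmoid and scales it to 0-10. A sketch of scoring one embedding, assuming the `AES` class above is in scope and using random weights instead of `aesthetic_predictor_ava.pth`:

```python
import torch

# Sketch only: random weights here; the app loads aesthetic_predictor_ava.pth instead.
head = AES(3840).eval()

embedding = torch.randn(1, 3840)        # stand-in for the backbone's summary vector
with torch.no_grad():
    score = head(embedding)             # sigmoid(sum of three branch logits) * 10

print(score.shape, round(score.item(), 3))  # torch.Size([1, 1]), a value in (0, 10)
```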
@@ -137,17 +191,15 @@ def prediction_to_tag(prediction, tag_dict, class_num, general_threshold, charac
 
     general = dict(sorted(general.items(), key=lambda item: item[1], reverse=True))
     character = dict(sorted(character.items(), key=lambda item: item[1], reverse=True))
+    artist = dict(sorted(artist.items(), key=lambda item: item[1], reverse=True))
 
     if date:
         date = {max(date, key=date.get): date[max(date, key=date.get)]}
     if rating:
         rating = {max(rating, key=rating.get): rating[max(rating, key=rating.get)]}
-    if artist:
-        artist = {max(artist, key=artist.get): artist[max(artist, key=artist.get)]}
 
     return general, character, artist, date, rating
 
-
 def process_image(image, general_threshold, character_threshold, artist_threshold):
     try:
         image = image.convert('RGBA')
@@ -177,7 +229,7 @@ def process_image(image, general_threshold, character_threshold, artist_threshol
         print(f"Error opening image: {e}")
         return
 
-    with torch.no_grad():
+    with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
         summary, features = model(pixel_values)
         outputs = summary.to(torch.float32)
 
@@ -195,7 +247,9 @@ def process_image(image, general_threshold, character_threshold, artist_threshol
     artist_tags = artist_[2]
     date = artist_[3]
 
-    combined_tags = {**artist_tags, **character_tags, **general_tags}
+    ava_score = round(mlp_ava(outputs).item(), 3)
+
+    combined_tags = {**character_tags, **general_tags}
 
     tags_list = [tag for tag in combined_tags]
     remove_list = []
@@ -208,12 +262,12 @@ def process_image(image, general_threshold, character_threshold, artist_threshol
 
     tags_str = ", ".join(tags_list).replace("(", "\(").replace(")", "\)")
 
-    return tags_str, artist_tags, character_tags, general_tags, rating, date
+    return tags_str, artist_tags, character_tags, general_tags, rating, date, ava_score
 
 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
     parser.add_argument("--slider-step", type=float, default=0.01)
-    parser.add_argument("--general-threshold", type=float, default=0.5)
+    parser.add_argument("--general-threshold", type=float, default=0.61)
     parser.add_argument("--character-threshold", type=float, default=0.8)
     parser.add_argument("--artist-threshold", type=float, default=0.68)
     return parser.parse_args()
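The inference path above now runs the backbone under both `torch.no_grad()` and bfloat16 autocast, then casts the summary back to float32 before the MLP heads consume it. A self-contained sketch of that pattern with a toy backbone (shapes illustrative; the app itself autocasts on 'cuda'):

```python
import torch
import torch.nn as nn

# Toy stand-in for the real backbone; shapes are illustrative, not the app's.
backbone = nn.Linear(16, 3840)
x = torch.randn(1, 16)

# 'cpu' keeps this sketch runnable anywhere; the app uses 'cuda'
# (bf16 autocast on GPU assumes Ampere or newer hardware).
device_type = "cuda" if torch.cuda.is_available() else "cpu"
backbone = backbone.to(device_type)
x = x.to(device_type)

with torch.no_grad(), torch.autocast(device_type, dtype=torch.bfloat16):
    summary = backbone(x)                # linear layers run in bfloat16 here

outputs = summary.to(torch.float32)      # cast back before the float32 MLP heads
print(summary.dtype, outputs.dtype)      # torch.bfloat16 torch.float32
```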
@@ -226,9 +280,9 @@ def main():
         gr.Markdown(
             value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
         )
-        gr.Markdown(value=DESCRIPTION)
         with gr.Row():
             with gr.Column(variant="panel"):
+                submit = gr.Button(value="Submit", variant="primary", size="lg")
                 image = gr.Image(type="pil", image_mode="RGBA", label="Input")
                 with gr.Row():
                     general_threshold = gr.Slider(
@@ -239,7 +293,6 @@ def main():
                         label="General Threshold",
                         scale=3,
                     )
-                with gr.Row():
                     character_threshold = gr.Slider(
                         0,
                         1,
@@ -248,7 +301,6 @@ def main():
                        label="Character Threshold",
                        scale=3,
                    )
-                with gr.Row():
                     artist_threshold = gr.Slider(
                         0,
                         1,
@@ -265,13 +317,16 @@ def main():
                     variant="secondary",
                     size="lg",
                 )
-                submit = gr.Button(value="Submit", variant="primary", size="lg")
+                gr.Markdown(value=DESCRIPTION)
             with gr.Column(variant="panel"):
                 tags_str = gr.Textbox(label="Output")
+                with gr.Row():
+                    ava_score = gr.Textbox(label="Aesthetic Score (AVA)")
+                with gr.Row():
+                    rating = gr.Label(label="Rating")
+                    date = gr.Label(label="Year")
                 artist_tags = gr.Label(label="Artist")
-                character_tags = gr.Label(label="Characters")
-                rating = gr.Label(label="Rating")
-                date = gr.Label(label="Year")
+                character_tags = gr.Label(label="Character")
                 general_tags = gr.Label(label="General")
                 clear.add(
                     [
@@ -280,7 +335,8 @@ def main():
                         general_tags,
                         character_tags,
                         rating,
-                        date
+                        date,
+                        ava_score
                     ]
                 )
 
@@ -292,7 +348,7 @@ def main():
                 character_threshold,
                 artist_threshold
             ],
-            outputs=[tags_str, artist_tags, character_tags, general_tags, rating, date],
+            outputs=[tags_str, artist_tags, character_tags, general_tags, rating, date, ava_score],
         )
 
     demo.queue(max_size=10)
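Gradio maps the tuple returned by the handler onto the `outputs=` list purely by position, which is why `ava_score` was appended to both the `process_image` return and the `submit.click` outputs in the same order. A stripped-down sketch of that contract with a hypothetical handler:

```python
import gradio as gr

# Hypothetical stand-in for process_image: the returned tuple is matched to the
# outputs= list below element by element, in order.
def fake_predict(text):
    return text.upper(), len(text)

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    upper_box = gr.Textbox(label="Upper")
    length_box = gr.Number(label="Length")
    submit = gr.Button("Submit")
    submit.click(fn=fake_predict, inputs=[inp], outputs=[upper_box, length_box])

# demo.launch()  # uncomment to serve locally
```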
 
cls_predictor.pth CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:b0bb58f320b941f20d9c3b9e3af4dc87780d9cf3f9d50be7a72b684028cd7763
+oid sha256:f5a6373053dad15af8b8cc2a6830bd04f67d35ff04acc5f071c34cb5d8c05305
 size 54599508