Johnny-Z committed on
Commit
6afc5c1
1 Parent(s): cd78c66

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +335 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import json
4
+ import gradio as gr
5
+ import huggingface_hub
6
+ import numpy as np
7
+ import onnxruntime as rt
8
+ import pandas as pd
9
+ from PIL import Image
10
+
11
+ TITLE = "WaifuDiffusion Tagger"
12
+ DESCRIPTION = """
13
+ """
14
+
15
+ #HF_TOKEN = os.environ["HF_TOKEN"]
16
+
17
+ # Dataset v3 series of models:
18
+ SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
19
+ CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
20
+ VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
21
+
22
+ # Dataset v2 series of models:
23
+ MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
24
+ SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
25
+ CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
26
+ CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
27
+ VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
28
+
29
+ # Files to download from the repos
30
+ MODEL_FILENAME = "model.onnx"
31
+ LABEL_FILENAME = "selected_tags.csv"
32
+
33
+ # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
34
+ kaomojis = [
35
+ "0_0",
36
+ "(o)_(o)",
37
+ "+_+",
38
+ "+_-",
39
+ "._.",
40
+ "<o>_<o>",
41
+ "<|>_<|>",
42
+ "=_=",
43
+ ">_<",
44
+ "3_3",
45
+ "6_9",
46
+ ">_o",
47
+ "@_@",
48
+ "^_^",
49
+ "o_o",
50
+ "u_u",
51
+ "x_x",
52
+ "|_|",
53
+ "||_||",
54
+ ]
55
+
56
+
57
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the demo.

    Returns:
        argparse.Namespace with ``score_slider_step``,
        ``score_general_threshold``, ``score_character_threshold``
        and ``share``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--score-slider-step", type=float, default=0.05)
    cli.add_argument("--score-general-threshold", type=float, default=0.4)
    cli.add_argument("--score-character-threshold", type=float, default=0.9)
    cli.add_argument("--share", action="store_true")
    return cli.parse_args()
64
+
65
+
66
def load_labels(dataframe) -> tuple[list[str], list[int], list[int], list[int]]:
    """Split the tag dataframe into names and per-category index lists.

    Args:
        dataframe: pandas DataFrame with at least ``name`` and ``category``
            columns (the model repo's ``selected_tags.csv``).

    Returns:
        Tuple of ``(tag_names, rating_indexes, general_indexes,
        character_indexes)``; the index lists are positions in ``tag_names``
        for categories 9 (rating), 0 (general) and 4 (character).
    """
    name_series = dataframe["name"]
    # Underscore-to-space conversion is intentionally disabled: tags are kept
    # verbatim (it would otherwise need the kaomoji exclusion below).
    #name_series = name_series.map(
    #    lambda x: x.replace("_", " ") if x not in kaomojis else x
    #)
    tag_names = name_series.tolist()

    rating_indexes = list(np.where(dataframe["category"] == 9)[0])
    general_indexes = list(np.where(dataframe["category"] == 0)[0])
    character_indexes = list(np.where(dataframe["category"] == 4)[0])
    return tag_names, rating_indexes, general_indexes, character_indexes
77
+
78
+
79
def mcut_threshold(probs):
    """
    Maximum Cut Thresholding (MCut)
    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
    for Multi-label Classification. In 11th International Symposium, IDA 2012
    (pp. 172-183).

    Sorts the probabilities in descending order and places the threshold at
    the midpoint of the largest gap between consecutive values.
    """
    ranked = np.sort(probs)[::-1]
    gaps = ranked[:-1] - ranked[1:]
    cut = int(gaps.argmax())
    return (ranked[cut] + ranked[cut + 1]) / 2
91
+
92
+
93
class Predictor:
    """Lazy wrapper around a WD tagger ONNX model.

    Downloads the label CSV and ONNX model from the Hugging Face Hub on
    first use and keeps the loaded session cached until a different repo
    is requested.
    """

    def __init__(self):
        # Both are populated by load_model(); None until a model is loaded.
        self.model_target_size = None
        self.last_loaded_repo = None

    def download_model(self, model_repo):
        """Fetch the label CSV and ONNX model files for ``model_repo``.

        Returns:
            Tuple ``(csv_path, model_path)`` of local file paths.
        """
        csv_path = huggingface_hub.hf_hub_download(
            model_repo,
            LABEL_FILENAME,
            #use_auth_token=HF_TOKEN,
        )
        model_path = huggingface_hub.hf_hub_download(
            model_repo,
            MODEL_FILENAME,
            #use_auth_token=HF_TOKEN,
        )
        return csv_path, model_path

    def load_model(self, model_repo):
        """Load ``model_repo``; no-op when it is already the active model."""
        if model_repo == self.last_loaded_repo:
            return

        csv_path, model_path = self.download_model(model_repo)

        tags_df = pd.read_csv(csv_path)
        sep_tags = load_labels(tags_df)

        self.tag_names = sep_tags[0]
        self.rating_indexes = sep_tags[1]
        self.general_indexes = sep_tags[2]
        self.character_indexes = sep_tags[3]

        model = rt.InferenceSession(model_path)
        # Model input is NHWC; these models are square, so height is enough.
        _, height, width, _ = model.get_inputs()[0].shape
        self.model_target_size = height

        self.last_loaded_repo = model_repo
        self.model = model

    def prepare_image(self, image):
        """Convert an RGBA PIL image into a 1xHxWx3 float32 BGR batch.

        Transparency is flattened onto white, the image is padded to a
        square and resized to the model's expected input size.
        """
        target_size = self.model_target_size

        # Composite the RGBA input over a white background.
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

        # Pad image to square
        image_shape = image.size
        max_dim = max(image_shape)
        pad_left = (max_dim - image_shape[0]) // 2
        pad_top = (max_dim - image_shape[1]) // 2

        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))

        # Resize
        if max_dim != target_size:
            padded_image = padded_image.resize(
                (target_size, target_size),
                Image.BICUBIC,
            )

        # Convert to numpy array
        image_array = np.asarray(padded_image, dtype=np.float32)

        # Convert PIL-native RGB to BGR (the models were trained on BGR).
        image_array = image_array[:, :, ::-1]

        return np.expand_dims(image_array, axis=0)

    def predict(
        self,
        image,
        model_repo,
        general_thresh,
        general_mcut_enabled,
        character_thresh,
        character_mcut_enabled,
    ):
        """Run the tagger on ``image`` and return formatted results.

        Args:
            image: RGBA PIL image to tag.
            model_repo: Hub repo id of the model to use (loaded lazily).
            general_thresh: confidence cutoff for general tags.
            general_mcut_enabled: if True, derive the general cutoff via MCut.
            character_thresh: confidence cutoff for character tags.
            character_mcut_enabled: if True, derive the character cutoff via
                MCut (floored at 0.15).

        Returns:
            Tuple of (comma-joined general tag string with escaped parens,
            rating dict, character dict, general dict).
        """
        self.load_model(model_repo)

        image = self.prepare_image(image)

        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        preds = self.model.run([label_name], {input_name: image})[0]

        labels = list(zip(self.tag_names, preds[0].astype(float)))

        # First 4 labels are actually ratings: pick one with argmax
        ratings_names = [labels[i] for i in self.rating_indexes]
        rating = dict(ratings_names)

        # Then we have general tags: pick any where prediction confidence > threshold
        general_names = [labels[i] for i in self.general_indexes]

        if general_mcut_enabled:
            general_probs = np.array([x[1] for x in general_names])
            general_thresh = mcut_threshold(general_probs)

        general_res = [x for x in general_names if x[1] > general_thresh]
        general_res = dict(general_res)

        # Drop tags that are implied by another kept tag (e.g. a more
        # specific tag implying a generic one), per the local mapping file.
        with open('./implications_list.json', 'r', encoding="utf-8") as f:
            implications_list = json.load(f)

        to_delete = set()
        for key in general_res.keys():
            if key in implications_list:
                to_delete.update(implications_list[key])

        for key in to_delete:
            general_res.pop(key, None)

        # Everything else is characters: pick any where prediction confidence > threshold
        character_names = [labels[i] for i in self.character_indexes]

        if character_mcut_enabled:
            character_probs = np.array([x[1] for x in character_names])
            character_thresh = mcut_threshold(character_probs)
            # Never let the adaptive threshold drop below 0.15.
            character_thresh = max(0.15, character_thresh)

        character_res = [x for x in character_names if x[1] > character_thresh]
        character_res = dict(character_res)

        sorted_general_strings = sorted(
            general_res.items(),
            key=lambda x: x[1],
            reverse=True,
        )
        sorted_general_strings = [x[0] for x in sorted_general_strings]
        # Escape parentheses so the string is safe for SD prompt syntax.
        # (Was "\(" / "\)" — invalid escape sequences; "\\(" has the same
        # runtime value without the SyntaxWarning.)
        sorted_general_strings = (
            ", ".join(sorted_general_strings).replace("(", "\\(").replace(")", "\\)")
        )

        return sorted_general_strings, rating, character_res, general_res
229
+
230
+
231
def main():
    """Build the Gradio UI and launch the tagging demo.

    Wires the input widgets through ``Predictor.predict`` and serves the
    app on port 8000 on all interfaces.
    """
    args = parse_args()

    predictor = Predictor()

    # Models selectable from the dropdown; the first v3 model is the default.
    dropdown_list = [
        SWINV2_MODEL_DSV3_REPO,
        CONV_MODEL_DSV3_REPO,
        VIT_MODEL_DSV3_REPO,
        MOAT_MODEL_DSV2_REPO,
        SWIN_MODEL_DSV2_REPO,
        CONV_MODEL_DSV2_REPO,
        CONV2_MODEL_DSV2_REPO,
        VIT_MODEL_DSV2_REPO,
    ]

    with gr.Blocks(title=TITLE) as demo:
        with gr.Column():
            gr.Markdown(
                value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
            )
            gr.Markdown(value=DESCRIPTION)
            with gr.Row():
                # Left panel: image input and all tagging parameters.
                with gr.Column(variant="panel"):
                    image = gr.Image(type="pil", image_mode="RGBA", label="Input")
                    model_repo = gr.Dropdown(
                        dropdown_list,
                        value=SWINV2_MODEL_DSV3_REPO,
                        label="Model",
                    )
                    with gr.Row():
                        general_thresh = gr.Slider(
                            0,
                            1,
                            step=args.score_slider_step,
                            value=args.score_general_threshold,
                            label="General Tags Threshold",
                            scale=3,
                        )
                        general_mcut_enabled = gr.Checkbox(
                            value=False,
                            label="Use MCut threshold",
                            scale=1,
                        )
                    with gr.Row():
                        character_thresh = gr.Slider(
                            0,
                            1,
                            step=args.score_slider_step,
                            value=args.score_character_threshold,
                            label="Character Tags Threshold",
                            scale=3,
                        )
                        character_mcut_enabled = gr.Checkbox(
                            value=False,
                            label="Use MCut threshold",
                            scale=1,
                        )
                    with gr.Row():
                        clear = gr.ClearButton(
                            components=[
                                image,
                                model_repo,
                                general_thresh,
                                general_mcut_enabled,
                                character_thresh,
                                character_mcut_enabled,
                            ],
                            variant="secondary",
                            size="lg",
                        )
                        submit = gr.Button(value="Submit", variant="primary", size="lg")
                # Right panel: prediction outputs.
                with gr.Column(variant="panel"):
                    sorted_general_strings = gr.Textbox(label="Output (string)")
                    rating = gr.Label(label="Rating")
                    character_res = gr.Label(label="Output (characters)")
                    general_res = gr.Label(label="Output (tags)")
                    # The clear button also resets the output widgets.
                    clear.add(
                        [
                            sorted_general_strings,
                            rating,
                            character_res,
                            general_res,
                        ]
                    )

        submit.click(
            predictor.predict,
            inputs=[
                image,
                model_repo,
                general_thresh,
                general_mcut_enabled,
                character_thresh,
                character_mcut_enabled,
            ],
            outputs=[sorted_general_strings, rating, character_res, general_res],
        )

    demo.queue(max_size=10)
    demo.launch(server_port=8000, server_name="0.0.0.0")


if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
pillow>=9.0.0
onnxruntime>=1.12.0
huggingface-hub
gradio
numpy
pandas