SmilingWolf
commited on
Commit
•
b079c7b
1
Parent(s):
4148c50
Update app.py
Browse filesAdd newly released ConvNextV2 model
app.py
CHANGED
@@ -21,6 +21,7 @@ DESCRIPTION = """
|
|
21 |
Demo for:
|
22 |
- [SmilingWolf/wd-v1-4-swinv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
|
23 |
- [SmilingWolf/wd-v1-4-convnext-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
|
|
|
24 |
- [SmilingWolf/wd-v1-4-vit-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger-v2)
|
25 |
|
26 |
Includes "ready to copy" prompt and a prompt analyzer.
|
@@ -36,6 +37,7 @@ Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
|
|
36 |
HF_TOKEN = os.environ["HF_TOKEN"]
|
37 |
SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
|
38 |
CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
|
|
|
39 |
VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
|
40 |
MODEL_FILENAME = "model.onnx"
|
41 |
LABEL_FILENAME = "selected_tags.csv"
|
@@ -65,6 +67,8 @@ def change_model(model_name):
|
|
65 |
model = load_model(SWIN_MODEL_REPO, MODEL_FILENAME)
|
66 |
elif model_name == "ConvNext":
|
67 |
model = load_model(CONV_MODEL_REPO, MODEL_FILENAME)
|
|
|
|
|
68 |
elif model_name == "ViT":
|
69 |
model = load_model(VIT_MODEL_REPO, MODEL_FILENAME)
|
70 |
|
@@ -74,7 +78,7 @@ def change_model(model_name):
|
|
74 |
|
75 |
def load_labels() -> list[str]:
|
76 |
path = huggingface_hub.hf_hub_download(
|
77 |
-
|
78 |
)
|
79 |
df = pd.read_csv(path)
|
80 |
|
@@ -209,11 +213,11 @@ def predict(
|
|
209 |
|
210 |
def main():
|
211 |
global loaded_models
|
212 |
-
loaded_models = {"SwinV2": None, "ConvNext": None, "ViT": None}
|
213 |
|
214 |
args = parse_args()
|
215 |
|
216 |
-
change_model("
|
217 |
|
218 |
tag_names, rating_indexes, general_indexes, character_indexes = load_labels()
|
219 |
|
@@ -229,7 +233,7 @@ def main():
|
|
229 |
fn=func,
|
230 |
inputs=[
|
231 |
gr.Image(type="pil", label="Input"),
|
232 |
-
gr.Radio(["SwinV2", "ConvNext", "ViT"], value="
|
233 |
gr.Slider(
|
234 |
0,
|
235 |
1,
|
@@ -253,7 +257,7 @@ def main():
|
|
253 |
gr.Label(label="Output (tags)"),
|
254 |
gr.HTML(),
|
255 |
],
|
256 |
-
examples=[["power.jpg", "
|
257 |
title=TITLE,
|
258 |
description=DESCRIPTION,
|
259 |
allow_flagging="never",
|
|
|
21 |
Demo for:
|
22 |
- [SmilingWolf/wd-v1-4-swinv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
|
23 |
- [SmilingWolf/wd-v1-4-convnext-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
|
24 |
+
- [SmilingWolf/wd-v1-4-convnextv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnextv2-tagger-v2)
|
25 |
- [SmilingWolf/wd-v1-4-vit-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger-v2)
|
26 |
|
27 |
Includes "ready to copy" prompt and a prompt analyzer.
|
|
|
37 |
HF_TOKEN = os.environ["HF_TOKEN"]
|
38 |
SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
|
39 |
CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
|
40 |
+
CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
|
41 |
VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
|
42 |
MODEL_FILENAME = "model.onnx"
|
43 |
LABEL_FILENAME = "selected_tags.csv"
|
|
|
67 |
model = load_model(SWIN_MODEL_REPO, MODEL_FILENAME)
|
68 |
elif model_name == "ConvNext":
|
69 |
model = load_model(CONV_MODEL_REPO, MODEL_FILENAME)
|
70 |
+
elif model_name == "ConvNextV2":
|
71 |
+
model = load_model(CONV2_MODEL_REPO, MODEL_FILENAME)
|
72 |
elif model_name == "ViT":
|
73 |
model = load_model(VIT_MODEL_REPO, MODEL_FILENAME)
|
74 |
|
|
|
78 |
|
79 |
def load_labels() -> list[str]:
|
80 |
path = huggingface_hub.hf_hub_download(
|
81 |
+
CONV2_MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN
|
82 |
)
|
83 |
df = pd.read_csv(path)
|
84 |
|
|
|
213 |
|
214 |
def main():
|
215 |
global loaded_models
|
216 |
+
loaded_models = {"SwinV2": None, "ConvNext": None, "ConvNextV2": None, "ViT": None}
|
217 |
|
218 |
args = parse_args()
|
219 |
|
220 |
+
change_model("ConvNextV2")
|
221 |
|
222 |
tag_names, rating_indexes, general_indexes, character_indexes = load_labels()
|
223 |
|
|
|
233 |
fn=func,
|
234 |
inputs=[
|
235 |
gr.Image(type="pil", label="Input"),
|
236 |
+
gr.Radio(["SwinV2", "ConvNext", "ConvNextV2", "ViT"], value="ConvNextV2", label="Model"),
|
237 |
gr.Slider(
|
238 |
0,
|
239 |
1,
|
|
|
257 |
gr.Label(label="Output (tags)"),
|
258 |
gr.HTML(),
|
259 |
],
|
260 |
+
examples=[["power.jpg", "ConvNextV2", 0.35, 0.85]],
|
261 |
title=TITLE,
|
262 |
description=DESCRIPTION,
|
263 |
allow_flagging="never",
|