Update app.py
Browse files
app.py
CHANGED
@@ -117,28 +117,16 @@ transform = transforms.Compose([
|
|
117 |
transforms.CenterCrop((384, 384)),
|
118 |
])
|
119 |
|
120 |
-
model_file = huggingface_hub.hf_hub_download(
|
121 |
-
repo_id="RedRocket/JointTaggerProject",
|
122 |
-
filename="JTP_PILOT-e4-vit_so400m_patch14_siglip_384.safetensors",
|
123 |
-
subfolder="JTP_PILOT"
|
124 |
-
)
|
125 |
-
|
126 |
model = timm.create_model(
|
127 |
"vit_so400m_patch14_siglip_384.webli",
|
128 |
pretrained=False,
|
129 |
num_classes=9083,
|
130 |
) # type: VisionTransformer
|
131 |
|
132 |
-
safetensors.torch.load_model(model,
|
133 |
model.eval()
|
134 |
|
135 |
-
|
136 |
-
repo_id="RedRocket/JointTaggerProject",
|
137 |
-
filename="tags.json",
|
138 |
-
subfolder="JTP_PILOT"
|
139 |
-
)
|
140 |
-
|
141 |
-
with open(tags_file, "r") as file:
|
142 |
tags = json.load(file) # type: dict
|
143 |
allowed_tags = tags.keys()
|
144 |
|
|
|
117 |
transforms.CenterCrop((384, 384)),
|
118 |
])
|
119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
model = timm.create_model(
|
121 |
"vit_so400m_patch14_siglip_384.webli",
|
122 |
pretrained=False,
|
123 |
num_classes=9083,
|
124 |
) # type: VisionTransformer
|
125 |
|
126 |
+
safetensors.torch.load_model(model, "JTP_PILOT-e4-vit_so400m_patch14_siglip_384.safetensors")
|
127 |
model.eval()
|
128 |
|
129 |
+
with open("tags.json", "r") as file:
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
tags = json.load(file) # type: dict
|
131 |
allowed_tags = tags.keys()
|
132 |
|