drhead commited on
Commit
b1fbce9
1 Parent(s): 78dd58b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -14
app.py CHANGED
@@ -117,28 +117,16 @@ transform = transforms.Compose([
117
  transforms.CenterCrop((384, 384)),
118
  ])
119
 
120
- model_file = huggingface_hub.hf_hub_download(
121
- repo_id="RedRocket/JointTaggerProject",
122
- filename="JTP_PILOT-e4-vit_so400m_patch14_siglip_384.safetensors",
123
- subfolder="JTP_PILOT"
124
- )
125
-
126
  model = timm.create_model(
127
  "vit_so400m_patch14_siglip_384.webli",
128
  pretrained=False,
129
  num_classes=9083,
130
  ) # type: VisionTransformer
131
 
132
- safetensors.torch.load_model(model, model_file)
133
  model.eval()
134
 
135
- tags_file = huggingface_hub.hf_hub_download(
136
- repo_id="RedRocket/JointTaggerProject",
137
- filename="tags.json",
138
- subfolder="JTP_PILOT"
139
- )
140
-
141
- with open(tags_file, "r") as file:
142
  tags = json.load(file) # type: dict
143
  allowed_tags = tags.keys()
144
 
 
117
  transforms.CenterCrop((384, 384)),
118
  ])
119
 
 
 
 
 
 
 
120
  model = timm.create_model(
121
  "vit_so400m_patch14_siglip_384.webli",
122
  pretrained=False,
123
  num_classes=9083,
124
  ) # type: VisionTransformer
125
 
126
+ safetensors.torch.load_model(model, "JTP_PILOT-e4-vit_so400m_patch14_siglip_384.safetensors")
127
  model.eval()
128
 
129
+ with open("tags.json", "r") as file:
 
 
 
 
 
 
130
  tags = json.load(file) # type: dict
131
  allowed_tags = tags.keys()
132