File size: 1,280 Bytes
fd52b7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import argparse
import os

import torch
from huggingface_hub import HfApi

from model_oml import EmbeddingModelOML

# Command-line interface of the export script.
#
# The summary text is passed as ``description``. The first positional
# parameter of ArgumentParser is ``prog``, so passing the text positionally
# replaced the program name in the usage line ("usage: Packing checkpoint to
# JIT and serving to HF repo [-h] ...") and left the help without a
# description.
parser = argparse.ArgumentParser(
    description="Packing checkpoint to JIT and serving to HF repo",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

# NOTE(review): "REQUIRES LOGGING" presumably means being logged in to the
# Hugging Face Hub before uploading -- confirm and reword if so.
parser.add_argument(
    "--upload-to-hf",
    action="store_true",
    help="Whether to upload model to hf hub, REQUIRES LOGGING",
)
parser.add_argument(
    "--path-to-save",
    help="Where to save the model file",
    default="../models/",
)
parser.add_argument(
    "--model-name",
    help="Which model name to save in folder",
    default="vits8stamp-torchscript.pth",
)
parser.add_argument(
    "--repo-id",
    help="repository id on huggingface",
    default="stamps-labs/vits8-stamp",
)

# Arguments are parsed at module level, so this runs whenever the file is
# executed or imported. ``config`` is the parsed namespace viewed as a dict
# keyed by the option names with dashes turned into underscores
# (e.g. "path_to_save").
args = parser.parse_args()
config = vars(args)

if __name__ == "__main__":
    # Export pipeline: take the extractor of EmbeddingModelOML, compile it to
    # TorchScript, save it to disk and, if --upload-to-hf was given, upload
    # the saved file to the Hugging Face repository named by --repo-id.

    # Build the destination once with os.path.join so that --path-to-save
    # works with or without a trailing separator. Plain string concatenation
    # turned "../models" + "name.pth" into "../modelsname.pth".
    output_path = os.path.join(config["path_to_save"], config["model_name"])

    # The extractor is moved to the GPU before scripting, so a CUDA device
    # must be available for this script to run.
    model = EmbeddingModelOML().extractor.cuda()
    model.eval()

    with torch.no_grad():
        model_ts = torch.jit.script(model)

    # Create the target directory if it is missing; an empty directory part
    # means the current working directory and needs no creation.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    model_ts.save(output_path)

    if config["upload_to_hf"]:
        api = HfApi()
        api.upload_file(
            path_or_fileobj=output_path,
            path_in_repo=config["model_name"],
            repo_id=config["repo_id"],
            repo_type="model",
        )