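# Spaces: Runtime error
# (The "Spaces: Runtime error" lines appear to be a Hugging Face Spaces status
# banner captured along with this script; they are not part of the program.)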
import argparse
import os

import torch
from huggingface_hub import HfApi

from model_oml import EmbeddingModelOML
# Export the `extractor` of EmbeddingModelOML to a TorchScript file and,
# optionally, upload that file to a Hugging Face model repository.
#
# Usage: python <this script> [--path-to-save DIR] [--model-name FILE]
#        [--device DEVICE] [--upload-to-hf [--repo-id OWNER/NAME]]


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the export/upload script.

    Returns:
        An ``argparse.ArgumentParser`` exposing ``--upload-to-hf``,
        ``--path-to-save``, ``--model-name``, ``--repo-id`` and ``--device``.
    """
    # Pass the text as `description=`: argparse treats the first positional
    # argument as `prog` (the program name printed in the usage line).
    parser = argparse.ArgumentParser(
        description="Packing checkpoint to JIT and serving to HF repo",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--upload-to-hf",
        action="store_true",
        help="Whether to upload model to hf hub, REQUIRES LOGGING IN "
        "(e.g. `huggingface-cli login`)",
    )
    parser.add_argument(
        "--path-to-save",
        help="Where to save the model file",
        default="../models/",
    )
    parser.add_argument(
        "--model-name",
        help="Which model name to save in folder",
        default="vits8stamp-torchscript.pth",
    )
    parser.add_argument(
        "--repo-id",
        help="repository id on huggingface",
        default="stamps-labs/vits8-stamp",
    )
    # Previously the extractor was always moved with `.cuda()`, which fails on
    # machines without a GPU; default to CUDA only when it is available.
    parser.add_argument(
        "--device",
        help="Device to move the extractor to before scripting",
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    return parser


def output_path(path_to_save: str, model_name: str) -> str:
    """Return the file path the scripted model is written to.

    Uses ``os.path.join`` so ``--path-to-save`` works with or without a
    trailing separator (plain concatenation turned ``../models`` +
    ``m.pth`` into ``../modelsm.pth``).
    """
    return os.path.join(path_to_save, model_name)


def export_torchscript(save_path: str, device: str = "cuda") -> None:
    """Script the ``extractor`` of ``EmbeddingModelOML`` and save it.

    Args:
        save_path: Destination file; missing parent directories are created.
        device: Device the extractor is moved to before scripting.
    """
    model = EmbeddingModelOML().extractor.to(device)
    model.eval()  # inference mode before scripting
    with torch.no_grad():
        model_ts = torch.jit.script(model)
    directory = os.path.dirname(save_path)
    if directory:  # dirname is "" for a bare file name; makedirs("") raises
        os.makedirs(directory, exist_ok=True)
    model_ts.save(save_path)


def upload_to_hub(local_path: str, path_in_repo: str, repo_id: str) -> None:
    """Upload ``local_path`` to the Hugging Face model repo ``repo_id``.

    Requires valid Hugging Face credentials on this machine (e.g. via
    ``huggingface-cli login``).
    """
    api = HfApi()
    api.upload_file(
        path_or_fileobj=local_path,
        path_in_repo=path_in_repo,
        repo_id=repo_id,
        repo_type="model",
    )


def main(argv=None) -> None:
    """Parse arguments, export the model, and optionally upload it.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    # Parse here rather than at import time so importing this module does not
    # consume (or choke on) the importer's command line.
    args = build_parser().parse_args(argv)
    save_path = output_path(args.path_to_save, args.model_name)
    export_torchscript(save_path, device=args.device)
    if args.upload_to_hf:
        upload_to_hub(save_path, args.model_name, args.repo_id)


if __name__ == "__main__":
    main()