"""Convert PyTorch ``.ckpt`` checkpoints to the ``.safetensors`` format.

Command line: a single positional ``path`` argument, either one checkpoint
file or a directory. For a directory, every ``*.ckpt`` file directly inside
it is converted. The output is written next to each input with the suffix
replaced by ``.safetensors``; inputs whose output already exists are skipped.
"""

import argparse
from collections.abc import Mapping
from pathlib import Path

import torch
from safetensors.torch import save_file


def _storage_ptr(tensor: torch.Tensor) -> int:
    """Return the base address of the storage backing *tensor*."""
    # untyped_storage() is the PyTorch >= 2.0 API; fall back for older versions.
    if hasattr(tensor, "untyped_storage"):
        return tensor.untyped_storage().data_ptr()
    return tensor.storage().data_ptr()


def _tensors_only(state_dict: Mapping) -> dict:
    """Return a new dict holding only the tensor entries of *state_dict*, ready for ``save_file``.

    Non-tensor values (step counters, hyperparameters, ...) are dropped.
    safetensors refuses to serialize non-contiguous tensors and tensors that
    share storage (e.g. tied weights), so every tensor is made contiguous, and
    a tensor whose storage is already used by an earlier entry is cloned.
    Cloning means formerly tied weights are stored as independent copies.
    """
    tensors = {}
    seen_storages = set()
    for key, value in state_dict.items():
        if not isinstance(value, torch.Tensor):
            continue
        tensor = value.detach().contiguous()
        ptr = _storage_ptr(tensor)
        # ptr == 0 means empty storage (e.g. zero-element tensors), which is never "shared".
        if ptr != 0 and ptr in seen_storages:
            tensor = tensor.clone()
        else:
            seen_storages.add(ptr)
        tensors[key] = tensor
    return tensors


def convert(path: Path) -> str:
    """Convert one checkpoint file to ``.safetensors`` next to it.

    If the loaded checkpoint contains a ``"state_dict"`` key, that entry is
    converted instead of the top-level object.

    Args:
        path: Checkpoint file to read.

    Returns:
        The output path (``path`` with its suffix replaced by
        ``.safetensors``) as a POSIX-style string.

    Raises:
        TypeError: If the checkpoint, or its ``"state_dict"`` entry, is not a mapping.
    """
    # SECURITY: depending on the installed PyTorch version's ``weights_only``
    # default, torch.load may fully unpickle the file, which can execute
    # arbitrary code. Only convert checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location="cpu")
    if not isinstance(checkpoint, Mapping):
        raise TypeError(f"Expected a mapping in {path}, got {type(checkpoint).__name__}")
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
        if not isinstance(checkpoint, Mapping):
            raise TypeError(
                f"Expected 'state_dict' in {path} to be a mapping, "
                f"got {type(checkpoint).__name__}"
            )

    output_path = path.with_suffix(".safetensors").as_posix()
    save_file(_tensors_only(checkpoint), output_path)
    return output_path


def main(path: str) -> None:
    """Convert *path*: a single checkpoint file, or every ``*.ckpt`` file in a directory.

    Args:
        path: File or directory path as given on the command line.

    Raises:
        ValueError: If *path* does not exist.
    """
    path_ = Path(path).resolve()
    if not path_.exists():
        raise ValueError(f"Invalid path: {path}")

    if path_.is_file():
        to_convert = [path_]
    else:
        # sorted() gives a deterministic order; the is_file() filter skips directories named *.ckpt.
        to_convert = sorted(p for p in path_.glob("*.ckpt") if p.is_file())

    for file in to_convert:
        output = file.with_suffix(".safetensors")
        if output.exists():
            print(f"Skipping, output already exists: {output}")
            continue
        print(f"Converting... {file}")
        convert(file)


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments (a single positional ``path``)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("path", type=str, help="Path to checkpoint file or directory.")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    main(args.path)