Kaeya commited on
Commit
9153408
·
1 Parent(s): 668eb66

Update safetensors_converter.py

Browse files
Files changed (1) hide show
  1. safetensors_converter.py +21 -6
safetensors_converter.py CHANGED
@@ -5,7 +5,7 @@ import torch
5
  from safetensors.torch import save_file
6
 
7
 
8
- def convert(path: Path):
9
  state_dict = torch.load(path, map_location="cpu")
10
  if "state_dict" in state_dict:
11
  state_dict = state_dict["state_dict"]
@@ -14,14 +14,25 @@ def convert(path: Path):
14
  for k, v in state_dict.items():
15
  if not isinstance(v, torch.Tensor):
16
  to_remove.append(k)
 
 
 
17
  for k in to_remove:
18
  del state_dict[k]
19
 
20
- output_path = path.with_suffix(".safetensors").as_posix()
21
- save_file(state_dict, output_path)
 
 
 
 
 
 
 
 
22
 
23
 
24
- def main(path: str):
25
  path_ = Path(path).resolve()
26
 
27
  if not path_.exists():
@@ -36,15 +47,19 @@ def main(path: str):
36
  if file.with_suffix(".safetensors").exists():
37
  continue
38
  print(f"Converting... {file}")
39
- convert(file)
40
 
41
 
42
  def parse_args():
43
  parser = argparse.ArgumentParser()
44
  parser.add_argument("path", type=str, help="Path to checkpoint file or directory.")
 
 
 
 
45
  return parser.parse_args()
46
 
47
 
48
  if __name__ == "__main__":
49
  args = parse_args()
50
- main(args.path)
 
5
  from safetensors.torch import save_file
6
 
7
 
8
+ def convert(path: Path, half: bool = False, no_ema: bool = False):
9
  state_dict = torch.load(path, map_location="cpu")
10
  if "state_dict" in state_dict:
11
  state_dict = state_dict["state_dict"]
 
14
  for k, v in state_dict.items():
15
  if not isinstance(v, torch.Tensor):
16
  to_remove.append(k)
17
+ if no_ema and "ema" in k:
18
+ to_remove.append(k)
19
+
20
  for k in to_remove:
21
  del state_dict[k]
22
 
23
+ if half:
24
+ state_dict = {k: v.half() for k, v in state_dict.items()}
25
+
26
+ output_name = path.stem
27
+ if no_ema:
28
+ output_name += "-pruned"
29
+ if half:
30
+ output_name += "-fp16"
31
+ output_path = path.parent / f"{output_name}.safetensors"
32
+ save_file(state_dict, output_path.as_posix())
33
 
34
 
35
+ def main(path: str, half: bool = False, no_ema: bool = False):
36
  path_ = Path(path).resolve()
37
 
38
  if not path_.exists():
 
47
  if file.with_suffix(".safetensors").exists():
48
  continue
49
  print(f"Converting... {file}")
50
+ convert(file, half, no_ema)
51
 
52
 
53
  def parse_args():
54
  parser = argparse.ArgumentParser()
55
  parser.add_argument("path", type=str, help="Path to checkpoint file or directory.")
56
+ parser.add_argument(
57
+ "--half", action="store_true", help="Convert to half precision."
58
+ )
59
+ parser.add_argument("--no-ema", action="store_true", help="Ignore EMA weights.")
60
  return parser.parse_args()
61
 
62
 
63
  if __name__ == "__main__":
64
  args = parse_args()
65
+ main(args.path, args.half, args.no_ema)