File size: 2,444 Bytes
2b1c7b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import base64
import pickle
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
import safetensors.torch
import torch
from .aliases import PathOrStr
# Public API: the two file-level round-trip helpers. The STKey/flatten
# machinery below is an implementation detail of the on-disk encoding.
__all__ = [
    "state_dict_to_safetensors_file",
    "safetensors_file_to_state_dict",
]
@dataclass(eq=True, frozen=True)
class STKey:
    """Hashable key describing one leaf of a flattened nested dict.

    ``eq=True, frozen=True`` makes instances usable as dict keys, which
    ``flatten_dict``/``unflatten_dict`` rely on.
    """

    # Path of nested dict keys leading to the value, outermost first.
    keys: Tuple
    # True when the stored tensor is a pickled uint8 blob rather than a
    # real tensor value (see flatten_dict's fallback branch).
    value_is_pickled: bool
def encode_key(key: STKey) -> str:
    """Serialize *key* to an ASCII string safe to use as a safetensors name.

    The (keys, value_is_pickled) pair is pickled and URL-safe base64
    encoded; ``decode_key`` is the exact inverse.
    """
    payload = pickle.dumps((key.keys, key.value_is_pickled))
    return base64.urlsafe_b64encode(payload).decode("ASCII")
def decode_key(key: str) -> STKey:
    """Inverse of ``encode_key``: rebuild an ``STKey`` from its encoded form.

    NOTE(review): this unpickles data embedded in the file being read —
    only load safetensors files from trusted sources.
    """
    keys, value_is_pickled = pickle.loads(base64.urlsafe_b64decode(key))
    return STKey(keys, value_is_pickled)
def flatten_dict(d: Dict) -> Dict[STKey, torch.Tensor]:
    """Flatten a nested dict into a single-level ``{STKey: tensor}`` mapping.

    Tensor leaves are kept as-is; sub-dicts are flattened recursively with
    their key prepended to the path; any other value is pickled into a
    uint8 tensor and marked ``value_is_pickled=True`` so ``unflatten_dict``
    can restore it.
    """
    flat: Dict[STKey, torch.Tensor] = {}
    for name, value in d.items():
        if isinstance(value, torch.Tensor):
            flat[STKey((name,), False)] = value
        elif isinstance(value, dict):
            # Recurse, then prefix every inner path with this level's key.
            for sub_key, sub_tensor in flatten_dict(value).items():
                flat[STKey((name, *sub_key.keys), sub_key.value_is_pickled)] = sub_tensor
        else:
            # Wrap in a bytearray so torch.frombuffer gets a writable
            # buffer; the tensor shares that memory.
            blob = bytearray(pickle.dumps(value))
            flat[STKey((name,), True)] = torch.frombuffer(blob, dtype=torch.uint8)
    return flat
def unflatten_dict(d: Dict[STKey, torch.Tensor]) -> Dict:
    """Rebuild the nested dict that ``flatten_dict`` produced.

    Pickled leaves (``value_is_pickled=True``) are deserialized from the
    tensor's raw bytes. NOTE(review): that unpickling executes data from
    the loaded file — only use with trusted files.
    """
    result: Dict = {}
    for key, value in d.items():
        if key.value_is_pickled:
            value = pickle.loads(value.numpy().data)
        # Walk/create the chain of parent dicts, then set the leaf.
        node = result
        for part in key.keys[:-1]:
            node = node.setdefault(part, {})
        node[key.keys[-1]] = value
    return result
def state_dict_to_safetensors_file(state_dict: Dict, filename: PathOrStr):
    """Write a (possibly nested, possibly non-tensor-valued) state dict to
    *filename* in safetensors format.

    Nested keys and pickled values are encoded into the flat string names
    safetensors requires; ``safetensors_file_to_state_dict`` reverses this.
    """
    encoded = {encode_key(k): v for k, v in flatten_dict(state_dict).items()}
    safetensors.torch.save_file(encoded, filename)
def safetensors_file_to_state_dict(filename: PathOrStr, map_location: Optional[str] = None) -> Dict:
    """Load *filename* (written by ``state_dict_to_safetensors_file``) back
    into a nested state dict.

    :param filename: path of the safetensors file to read.
    :param map_location: device string passed to safetensors; defaults to
        ``"cpu"`` when ``None``.
    """
    device = "cpu" if map_location is None else map_location
    loaded = safetensors.torch.load_file(filename, device=device)
    return unflatten_dict({decode_key(k): v for k, v in loaded.items()})
|