OLMo-Bitnet-1B / safetensors_util.py
emozilla's picture
update inference code
2b1c7b3
import base64
import pickle
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
import safetensors.torch
import torch
from .aliases import PathOrStr
# Names exported by ``from <this module> import *``: the save and load
# entry points defined further down in this file.
__all__ = ["state_dict_to_safetensors_file", "safetensors_file_to_state_dict"]
@dataclass(frozen=True, eq=True)
class STKey:
    """Immutable, hashable key identifying one entry of a flattened dict.

    Instances compare by value and are frozen, so they can be used as
    dictionary keys.
    """

    # Path of dictionary keys, outermost first, leading to the stored value.
    keys: Tuple
    # True when the stored tensor holds pickled bytes instead of tensor data.
    value_is_pickled: bool
def encode_key(key: STKey) -> str:
    """Serialize ``key`` into an ASCII string.

    The ``(keys, value_is_pickled)`` pair is pickled and the resulting bytes
    are encoded with URL-safe base64.
    """
    payload = pickle.dumps((key.keys, key.value_is_pickled))
    return base64.urlsafe_b64encode(payload).decode("ASCII")
def decode_key(key: str) -> STKey:
    """Rebuild an ``STKey`` from a URL-safe base64 string of pickled data.

    NOTE(security): the decoded bytes are passed to ``pickle.loads``, which
    can execute arbitrary code. Only decode keys that come from a trusted
    source.
    """
    raw = base64.urlsafe_b64decode(key)
    path, is_pickled = pickle.loads(raw)
    return STKey(keys=path, value_is_pickled=is_pickled)
def flatten_dict(d: Dict) -> Dict[STKey, torch.Tensor]:
    """Flatten a nested dict into a single-level mapping of ``STKey`` to tensor.

    Each entry of ``d`` is handled according to the type of its value:

    * a tensor is stored unchanged;
    * a non-empty dict is flattened recursively, and this level's key is
      prepended to every inner key path;
    * anything else, including an empty dict, is pickled and stored as a
      ``uint8`` tensor whose key has ``value_is_pickled=True``.

    Empty dicts are pickled rather than recursed into: recursing yields no
    entries at all, so the key would silently vanish from the result.

    Args:
        d: The (possibly nested) dict to flatten.

    Returns:
        A flat dict keyed by ``STKey`` in which every value is a tensor.
    """
    result: Dict[STKey, torch.Tensor] = {}
    for key, value in d.items():
        if isinstance(value, torch.Tensor):
            result[STKey((key,), False)] = value
        elif isinstance(value, dict) and value:
            for inner_key, inner_value in flatten_dict(value).items():
                full_path = (key,) + inner_key.keys
                result[STKey(full_path, inner_key.value_is_pickled)] = inner_value
        else:
            # A bytearray gives ``torch.frombuffer`` a writable buffer to wrap.
            pickled = bytearray(pickle.dumps(value))
            result[STKey((key,), True)] = torch.frombuffer(pickled, dtype=torch.uint8)
    return result
def unflatten_dict(d: Dict["STKey", torch.Tensor]) -> Dict:
    """Rebuild a nested dict from a flat mapping of ``STKey`` to tensor.

    For every entry, all but the last element of ``key.keys`` name the nested
    dicts to descend into (created on demand), and the last element is the
    key under which the value is stored. Values whose key has
    ``value_is_pickled`` set are unpickled from the tensor's bytes first.

    NOTE(security): ``pickle.loads`` can execute arbitrary code. Only call
    this on data that comes from a trusted source.

    Args:
        d: Flat mapping from ``STKey`` to tensor.

    Returns:
        The nested dict described by the key paths in ``d``.
    """
    result: Dict = {}
    for key, value in d.items():
        if key.value_is_pickled:
            # ``Tensor.numpy()`` only works on host memory; ``.cpu()`` is a
            # no-op for CPU tensors and copies tensors from other devices.
            value = pickle.loads(value.cpu().numpy().data)
        target_dict = result
        for k in key.keys[:-1]:
            target_dict = target_dict.setdefault(k, {})
        target_dict[key.keys[-1]] = value
    return result
def state_dict_to_safetensors_file(state_dict: Dict, filename: PathOrStr):
    """Write a (possibly nested) state dict to ``filename`` via safetensors.

    The dict is flattened with ``flatten_dict``, each resulting key is turned
    into a string with ``encode_key``, and the string-keyed mapping is handed
    to ``safetensors.torch.save_file``.
    """
    encoded = {}
    for st_key, tensor in flatten_dict(state_dict).items():
        encoded[encode_key(st_key)] = tensor
    safetensors.torch.save_file(encoded, filename)
def safetensors_file_to_state_dict(filename: PathOrStr, map_location: Optional[str] = None) -> Dict:
    """Load a nested state dict from the safetensors file at ``filename``.

    Tensors are loaded with ``safetensors.torch.load_file`` onto
    ``map_location`` (``"cpu"`` when it is ``None``). Every tensor name is
    decoded with ``decode_key`` and the nesting is restored with
    ``unflatten_dict``.

    NOTE(security): both ``decode_key`` and ``unflatten_dict`` call
    ``pickle.loads`` on data read from the file, so only load files from a
    trusted source.
    """
    device = "cpu" if map_location is None else map_location
    loaded = safetensors.torch.load_file(filename, device=device)
    return unflatten_dict({decode_key(name): tensor for name, tensor in loaded.items()})