import json
import struct
from typing import Any, Dict, Optional, Union

import torch

try:
    # safetensors is optional: only the mmap-based branch of load_safetensors
    # needs it; the memory-efficient reader below is self-contained.
    from safetensors.torch import load_file
except ImportError:  # NOTE(review): keeps the module importable without safetensors
    load_file = None


def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Optional[Dict[str, Any]] = None):
    """
    Memory-efficient safetensors writer.

    Streams each tensor's raw bytes to disk right after the JSON header instead
    of materializing the whole serialized file in memory, so peak memory stays
    near the size of the largest single tensor.

    Args:
        tensors: mapping of tensor name -> tensor to serialize.
        filename: output path.
        metadata: optional metadata stored under "__metadata__". Keys must be
            strings; non-string values are converted to str with a warning.
    """
    # safetensors dtype tags; float8 keys are None on torch builds without float8
    _TYPES = {
        torch.float64: "F64",
        torch.float32: "F32",
        torch.float16: "F16",
        torch.bfloat16: "BF16",
        torch.int64: "I64",
        torch.int32: "I32",
        torch.int16: "I16",
        torch.int8: "I8",
        torch.uint8: "U8",
        torch.bool: "BOOL",
        getattr(torch, "float8_e5m2", None): "F8_E5M2",
        getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
    }
    _ALIGN = 256  # header (including its 8-byte length prefix) is space-padded to this boundary

    def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
        """Coerce metadata values to str; reject non-string keys."""
        validated = {}
        for key, value in metadata.items():
            if not isinstance(key, str):
                raise ValueError(f"Metadata key must be a string, got {type(key)}")
            if not isinstance(value, str):
                print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
                validated[key] = str(value)
            else:
                validated[key] = value
        return validated

    # print(f"Using memory efficient save file: {filename}")

    # First pass: build the JSON header with byte offsets for each tensor.
    header = {}
    offset = 0
    if metadata:
        header["__metadata__"] = validate_metadata(metadata)

    for k, v in tensors.items():
        if v.numel() == 0:  # empty tensor: zero-length data span
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
        else:
            size = v.numel() * v.element_size()
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
            offset += size

    hjson = json.dumps(header).encode("utf-8")
    # Pad so that the 8-byte length prefix + header is a multiple of _ALIGN.
    hjson += b" " * (-(len(hjson) + 8) % _ALIGN)

    # Second pass: write the length-prefixed header, then each tensor's bytes.
    with open(filename, "wb") as f:
        f.write(struct.pack("<Q", len(hjson)))  # little-endian u64 header size
        f.write(hjson)

        # NOTE(review): this write loop was corrupted in the source; reconstructed
        # from the header layout above and the reader below — confirm against upstream.
        for k, v in tensors.items():
            if v.numel() == 0:
                continue
            v = v.detach().cpu().contiguous()
            if v.dtype == torch.bfloat16:
                # numpy has no bfloat16: reinterpret the bits as int16
                v = v.view(torch.int16)
            elif v.dtype in (getattr(torch, "float8_e5m2", None), getattr(torch, "float8_e4m3fn", None)):
                # numpy has no float8 either: dump the raw bytes as uint8
                v = v.view(torch.uint8)
            f.write(v.numpy().tobytes())


class MemoryEfficientSafeOpen:
    """Minimal safetensors reader that loads one tensor at a time from disk."""

    # NOTE(review): __init__/__enter__/__exit__/keys were corrupted in the source;
    # reconstructed from the methods that survived (which use self.file,
    # self.header, self.header_size) — confirm against upstream.
    def __init__(self, filename):
        self.filename = filename
        self.file = open(filename, "rb")
        self.header, self.header_size = self._read_header()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self):
        # "__metadata__" is file-level metadata, not a tensor entry
        return [k for k in self.header.keys() if k != "__metadata__"]

    def metadata(self) -> Dict[str, str]:
        return self.header.get("__metadata__", {})

    def get_tensor(self, key):
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")

        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]

        if offset_start == offset_end:
            tensor_bytes = None  # empty tensor: nothing stored on disk
        else:
            # Data offsets are relative to the end of the header: adjust by the
            # 8-byte length prefix plus the header size.
            self.file.seek(self.header_size + 8 + offset_start)
            tensor_bytes = self.file.read(offset_end - offset_start)

        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        # Little-endian u64 header size, then the JSON header itself.
        header_size = struct.unpack("<Q", self.file.read(8))[0]
        header_json = self.file.read(header_size).decode("utf-8")
        return json.loads(header_json), header_size

    # NOTE(review): everything below was corrupted in the source; reconstructed
    # from the dtype table in mem_eff_save_file — confirm against upstream.
    def _deserialize_tensor(self, tensor_bytes, metadata):
        """Rebuild a tensor from its raw bytes and its header entry."""
        dtype = self._get_torch_dtype(metadata["dtype"])
        shape = metadata["shape"]

        if tensor_bytes is None:
            byte_tensor = torch.empty(0, dtype=torch.uint8)
        else:
            # Copy into a writable buffer so torch.frombuffer does not warn
            # about (and alias) the read-only bytes object.
            byte_tensor = torch.frombuffer(bytearray(tensor_bytes), dtype=torch.uint8)

        # float8 needs special handling because older torch builds lack the dtypes
        if metadata["dtype"] in ("F8_E5M2", "F8_E4M3"):
            return self._convert_float8(byte_tensor, metadata["dtype"], shape)

        # Reinterpret the raw bytes as the target dtype, then restore the shape.
        return byte_tensor.view(dtype).reshape(shape)

    @staticmethod
    def _get_torch_dtype(dtype_str):
        """Map a safetensors dtype tag back to a torch dtype (None if unknown)."""
        dtype_map = {
            "F64": torch.float64,
            "F32": torch.float32,
            "F16": torch.float16,
            "BF16": torch.bfloat16,
            "I64": torch.int64,
            "I32": torch.int32,
            "I16": torch.int16,
            "I8": torch.int8,
            "U8": torch.uint8,
            "BOOL": torch.bool,
        }
        if hasattr(torch, "float8_e5m2"):
            dtype_map["F8_E5M2"] = torch.float8_e5m2
        if hasattr(torch, "float8_e4m3fn"):
            dtype_map["F8_E4M3"] = torch.float8_e4m3fn
        return dtype_map.get(dtype_str)

    @staticmethod
    def _convert_float8(byte_tensor, dtype_str, shape):
        """Reinterpret raw uint8 data as a float8 tensor, if this torch supports it."""
        if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
            return byte_tensor.view(torch.float8_e5m2).reshape(shape)
        if dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
            return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
        raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to a float8-capable version)")


# NOTE(review): the signature of this function was corrupted in the source;
# parameter names/defaults reconstructed from the surviving body — confirm
# against callers.
def load_safetensors(
    path: str,
    device: Union[str, torch.device],
    disable_mmap: bool = False,
    dtype: Optional[torch.dtype] = None,
) -> dict[str, torch.Tensor]:
    """Load a safetensors state dict onto *device*, optionally cast to *dtype*.

    Args:
        path: safetensors file to read.
        device: target device for the returned tensors.
        disable_mmap: when True, use the experimental MemoryEfficientSafeOpen
            reader instead of safetensors' mmap-based loader.
        dtype: optional dtype to cast every tensor to (None keeps stored dtypes).

    Returns:
        Mapping of tensor name -> tensor.
    """
    if disable_mmap:
        # return safetensors.torch.load(open(path, "rb").read())
        # use experimental loader
        # logger.info(f"Loading without mmap (experimental)")
        state_dict = {}
        with MemoryEfficientSafeOpen(path) as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
        return state_dict
    else:
        if load_file is None:
            raise ImportError("safetensors is required unless disable_mmap=True")
        try:
            state_dict = load_file(path, device=device)
        except Exception:
            state_dict = load_file(path)  # prevent device invalid Error
        if dtype is not None:
            for key in state_dict.keys():
                state_dict[key] = state_dict[key].to(dtype=dtype)
        return state_dict