import hashlib
from io import BytesIO
from typing import Optional

import safetensors.torch
import torch

# Byte offset and length of the slice hashed by the legacy (8-hex-digit) hashes.
_LEGACY_HASH_OFFSET = 0x100000  # 1 MiB
_LEGACY_HASH_LENGTH = 0x10000  # 64 KiB
# Read size used when streaming a whole file/buffer through SHA-256.
_HASH_CHUNK_SIZE = 1024 * 1024


def _sha256_of_stream(stream) -> str:
    """Return the hex SHA-256 of everything from ``stream``'s current position to EOF.

    The stream is consumed in fixed-size chunks so large model files are never
    held in memory all at once.
    """
    digest = hashlib.sha256()
    for chunk in iter(lambda: stream.read(_HASH_CHUNK_SIZE), b""):
        digest.update(chunk)
    return digest.hexdigest()


def _hash_file(filename, hash_stream) -> str:
    """Open ``filename`` in binary mode and return ``hash_stream(file)``.

    Filesystem errors are mapped to sentinel strings instead of raising:

    * ``"NOFILE"`` if the file does not exist;
    * ``"IsADirectory"`` if the path is a directory. Linux reports this as
      ``IsADirectoryError`` and Windows as ``PermissionError``, so any
      ``PermissionError`` is reported the same way.
    """
    try:
        with open(filename, "rb") as file:
            return hash_stream(file)
    except FileNotFoundError:
        return "NOFILE"
    except (IsADirectoryError, PermissionError):
        return "IsADirectory"


def model_hash(filename):
    """Old model hash used by stable-diffusion-webui

    Hashes the 64 KiB slice starting at offset 1 MiB and returns the first
    8 hex digits of its SHA-256. A file shorter than 1 MiB hashes an empty
    slice.

    Returns:
        The 8-character hash, or ``"NOFILE"`` / ``"IsADirectory"`` on the
        corresponding filesystem errors.
    """
    return _hash_file(filename, addnet_hash_legacy)


def calculate_sha256(filename):
    """New model hash used by stable-diffusion-webui

    Returns:
        The full hex SHA-256 of the file's contents, or ``"NOFILE"`` /
        ``"IsADirectory"`` on the corresponding filesystem errors.
    """
    return _hash_file(filename, _sha256_of_stream)


def addnet_hash_legacy(b):
    """Old model hash used by sd-webui-additional-networks for .safetensors format files

    Args:
        b: seekable binary file-like object.

    Returns:
        The first 8 hex digits of the SHA-256 of the 64 KiB at offset 1 MiB.
    """
    b.seek(_LEGACY_HASH_OFFSET)
    return hashlib.sha256(b.read(_LEGACY_HASH_LENGTH)).hexdigest()[:8]


def addnet_hash_safetensors(b):
    """New model hash used by sd-webui-additional-networks for .safetensors format files

    Only the bytes after the header are hashed, so changes to the header
    (which holds the metadata) do not affect the result.

    Args:
        b: seekable binary file-like object holding a .safetensors payload.

    Returns:
        The full hex SHA-256 of everything after the header.
    """
    b.seek(0)
    # The file starts with an 8-byte little-endian header length N followed by
    # N header bytes; the data section therefore begins at offset N + 8.
    header_len = int.from_bytes(b.read(8), "little")
    b.seek(header_len + 8)
    return _sha256_of_stream(b)


def precalculate_safetensors_hashes(tensors, metadata):
    """Precalculate the model hashes needed by sd-webui-additional-networks to
    save time on indexing the model later.

    Args:
        tensors: tensors to serialize (passed to ``safetensors.torch.save``).
        metadata: metadata dict; only ``ss_``-prefixed (training) entries are
            kept for hashing.

    Returns:
        ``(new_hash, legacy_hash)`` as computed by ``addnet_hash_safetensors``
        and ``addnet_hash_legacy`` on the serialized model.
    """
    # Because writing user metadata to the file can change the result of
    # sd_models.model_hash(), only retain the training metadata for purposes of
    # calculating the hash, as they are meant to be immutable
    training_metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
    buffer = BytesIO(safetensors.torch.save(tensors, training_metadata))
    new_hash = addnet_hash_safetensors(buffer)
    legacy_hash = addnet_hash_legacy(buffer)
    return new_hash, legacy_hash


def dtype_to_str(dtype: torch.dtype) -> str:
    """Return the bare name of a torch dtype, e.g. ``torch.float32`` -> ``"float32"``."""
    return str(dtype).split(".")[-1]


# Accepted spellings -> attribute name on the ``torch`` module. Names are
# resolved lazily with getattr so torch builds lacking some float8 dtypes can
# still import this module (only requesting that dtype fails).
_DTYPE_ALIASES = {
    "bf16": "bfloat16",
    "bfloat16": "bfloat16",
    "fp16": "float16",
    "float16": "float16",
    "fp32": "float32",
    "float32": "float32",
    "float": "float32",
    "fp8_e4m3fn": "float8_e4m3fn",
    "e4m3fn": "float8_e4m3fn",
    "float8_e4m3fn": "float8_e4m3fn",
    "fp8_e4m3fnuz": "float8_e4m3fnuz",
    "e4m3fnuz": "float8_e4m3fnuz",
    "float8_e4m3fnuz": "float8_e4m3fnuz",
    "fp8_e5m2": "float8_e5m2",
    "e5m2": "float8_e5m2",
    "float8_e5m2": "float8_e5m2",
    "fp8_e5m2fnuz": "float8_e5m2fnuz",
    "e5m2fnuz": "float8_e5m2fnuz",
    "float8_e5m2fnuz": "float8_e5m2fnuz",
    "fp8": "float8_e4m3fn",  # default fp8
    "float8": "float8_e4m3fn",  # default fp8
}


def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> Optional[torch.dtype]:
    """
    Convert a string to a torch.dtype

    Args:
        s: string representation of the dtype
        default_dtype: default dtype to return if s is None

    Returns:
        torch.dtype: the corresponding torch.dtype, or ``default_dtype``
        (which may itself be None) when ``s`` is None

    Raises:
        ValueError: if the dtype is not supported

    Examples:
        >>> str_to_dtype("float32")
        torch.float32
        >>> str_to_dtype("fp32")
        torch.float32
        >>> str_to_dtype("float16")
        torch.float16
        >>> str_to_dtype("fp16")
        torch.float16
        >>> str_to_dtype("bfloat16")
        torch.bfloat16
        >>> str_to_dtype("bf16")
        torch.bfloat16
        >>> str_to_dtype("fp8")
        torch.float8_e4m3fn
        >>> str_to_dtype("fp8_e4m3fn")
        torch.float8_e4m3fn
        >>> str_to_dtype("fp8_e4m3fnuz")
        torch.float8_e4m3fnuz
        >>> str_to_dtype("fp8_e5m2")
        torch.float8_e5m2
        >>> str_to_dtype("fp8_e5m2fnuz")
        torch.float8_e5m2fnuz
    """
    if s is None:
        return default_dtype
    try:
        torch_name = _DTYPE_ALIASES[s]
    except (KeyError, TypeError):
        # TypeError covers unhashable input, which is just as unsupported.
        raise ValueError(f"Unsupported dtype: {s}") from None
    return getattr(torch, torch_name)