import os
import tqdm
import torch
import safetensors.torch
from torch import Tensor
from modules import shared
from modules import sd_models, sd_vae
# position_ids in clip is int64. model_ema.num_updates is int32
dtypes_to_fp16 = {torch.float32, torch.float64, torch.bfloat16}
dtypes_to_bf16 = {torch.float32, torch.float64, torch.float16}
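# Integer tensors are deliberately excluded from both sets (converting them
# would corrupt values such as CLIP's int64 position_ids, per the note above).
# The sets are also asymmetric on purpose: fp16 conversion accepts bfloat16
# inputs and bf16 conversion accepts float16 inputs, so either target can
# re-cast the other half-precision format.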
class MockModelInfo:
    def __init__(self, model_path: str) -> None:
        self.filepath = model_path
        self.filename: str = os.path.basename(model_path)
        self.model_name: str = self.filename.split(".")[0]
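# MockModelInfo mimics the few CheckpointInfo fields do_convert needs, so the
# converter can also accept models that are not registered in
# sd_models.checkpoints_list, e.g.:
#   info = MockModelInfo("/models/Stable-diffusion/foo.ckpt")
#   info.filename    # "foo.ckpt"
#   info.model_name  # "foo"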
def conv_fp16(t: Tensor):
    return t.half() if t.dtype in dtypes_to_fp16 else t


def conv_bf16(t: Tensor):
    return t.bfloat16() if t.dtype in dtypes_to_bf16 else t


def conv_full(t):
    return t
_g_precision_func = {
    "full": conv_full,
    "fp32": conv_full,
    "fp16": conv_fp16,
    "bf16": conv_bf16,
}
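# A minimal sketch of how the precision map is used further down:
#   conv_func = _g_precision_func["fp16"]
#   conv_func(torch.zeros(1, dtype=torch.float32)).dtype  # torch.float16
#   conv_func(torch.zeros(1, dtype=torch.int64)).dtype    # torch.int64 (untouched)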
def check_weight_type(k: str) -> str:
    if k.startswith("model.diffusion_model"):
        return "unet"
    elif k.startswith("first_stage_model"):
        return "vae"
    elif k.startswith("cond_stage_model"):
        return "clip"
    return "other"
def load_model(path):
    if path.endswith(".safetensors"):
        m = safetensors.torch.load_file(path, device="cpu")
    else:
        m = torch.load(path, map_location="cpu")
    state_dict = m["state_dict"] if "state_dict" in m else m
    return state_dict
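# Note: .ckpt files produced by training scripts usually nest the weights
# under a "state_dict" key, while safetensors files store a flat mapping;
# load_model normalizes both cases to a flat key -> tensor dict.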
def fix_model(model, fix_clip=False, force_position_id=False):
    # code from model-toolkit
    nai_keys = {
        'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
        'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
        'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.'
    }
    position_id_key = "cond_stage_model.transformer.text_model.embeddings.position_ids"
    for k in list(model.keys()):
        for r in nai_keys:
            if isinstance(k, str) and k.startswith(r):
                new_key = k.replace(r, nai_keys[r])
                model[new_key] = model[k]
                del model[k]
                print(f"[Converter] Fixed NovelAI error key {k}")
                break
    if force_position_id and position_id_key in model:
        model[position_id_key] = model[position_id_key].to(torch.int64)
    if fix_clip:
        if position_id_key in model:
            correct = torch.Tensor([list(range(77))]).to(torch.int64)
            now = model[position_id_key].to(torch.int64)
            broken = correct.ne(now)
            broken = [i for i in range(77) if broken[0][i]]
            if len(broken) != 0:
                model[position_id_key] = correct
                print(f"[Converter] Fixed broken CLIP position ids:\n{broken}")
            else:
                print("[Converter] CLIP in this model is fine, skipping fix...")
        else:
            print("[Converter] Missing position ids in model, adding them...")
            model[position_id_key] = torch.Tensor([list(range(77))]).to(torch.int64)
    return model
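# Background: the CLIP text encoder uses a fixed 77-token context, so
# position_ids must be exactly [[0, 1, ..., 76]] as int64. A known issue with
# some merged or pruned checkpoints is that these indices get corrupted (often
# by being stored in a float dtype and truncated on cast), which silently
# shifts token embeddings; fix_clip above rebuilds the tensor from scratch.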
def convert_warp(
        model_name, model_path, directory,
        *args
):
    if sum(map(bool, [model_name, model_path, directory])) != 1:
        print("[Converter] Check your inputs: set exactly one of model name, model path, or directory")
        return
    if directory != "":
        if not os.path.exists(directory) or not os.path.isdir(directory):
            return "Error: path does not exist or is not a directory"
        files = [f for f in os.listdir(directory) if f.endswith(".ckpt") or f.endswith(".safetensors")]
        if len(files) == 0:
            return "Error: no model found in directory"
        # drop the custom filename in batch processing, so each model keeps its own name
        _args = list(args)
        _args[3] = ""
        for m in files:
            do_convert(MockModelInfo(os.path.join(directory, m)), *_args)
    elif model_path != "":
        if os.path.exists(model_path):
            return do_convert(MockModelInfo(model_path), *args)
    elif model_name != "":
        model_info = sd_models.checkpoints_list[model_name]
        return do_convert(MockModelInfo(model_info.filename), *args)
    else:
        return "Error: must choose a model"
def do_convert(model_info: MockModelInfo,
               checkpoint_formats,
               precision, conv_type, custom_name,
               bake_in_vae,
               unet_conv, text_encoder_conv, vae_conv, others_conv,
               fix_clip, force_position_id, delete_known_junk_data):
    if len(checkpoint_formats) == 0:
        return "Error: choose at least one model save format"
    extra_opt = {
        "unet": unet_conv,
        "clip": text_encoder_conv,
        "vae": vae_conv,
        "other": others_conv
    }
    shared.state.begin()
    shared.state.job = 'model-convert'
    shared.state.textinfo = f"Loading {model_info.filename}..."
    print(f"[Converter] Loading {model_info.filename}...")
    ok = {}
    state_dict = load_model(model_info.filepath)
    fix_model(state_dict, fix_clip=fix_clip, force_position_id=force_position_id)
    conv_func = _g_precision_func[precision]

    def _hf(wk: str, t: Tensor):
        if not isinstance(t, Tensor):
            return
        weight_type = check_weight_type(wk)
        conv_t = extra_opt[weight_type]
        if conv_t == "convert":
            ok[wk] = conv_func(t)
        elif conv_t == "copy":
            ok[wk] = t
        elif conv_t == "delete":
            return

    print("[Converter] Converting model...")
    if conv_type == "ema-only":
        for k in tqdm.tqdm(state_dict):
            # EMA weights live under "model_ema." with the "model." prefix
            # stripped and all dots removed, e.g.
            # "model.diffusion_model.x.y" -> "model_ema.diffusion_modelxy"
            ema_k = "___"
            try:
                ema_k = "model_ema." + k[6:].replace(".", "")
            except Exception:
                pass
            if ema_k in state_dict:
                _hf(k, state_dict[ema_k])
            elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]:
                _hf(k, state_dict[k])
    elif conv_type == "no-ema":
        for k, v in tqdm.tqdm(state_dict.items()):
            if "model_ema." not in k:
                _hf(k, v)
    else:
        for k, v in tqdm.tqdm(state_dict.items()):
            _hf(k, v)
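    # At this point `ok` holds the filtered/converted tensors: "ema-only"
    # wrote EMA weights back under the live "model." key names so the result
    # loads like a normal checkpoint, while "no-ema" simply dropped the
    # "model_ema." entries.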
    if delete_known_junk_data:
        known_junk_data_prefix = [
            "embedding_manager.embedder.",
            "lora_te_text_model",
            "control_model."
        ]
        need_delete = []
        for key in ok.keys():
            for jk in known_junk_data_prefix:
                if key.startswith(jk):
                    need_delete.append(key)
        for k in need_delete:
            del ok[k]
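    # These prefixes are likely leftovers from training/merging tools:
    # textual-inversion embedder state ("embedding_manager.embedder."),
    # merged-in LoRA text-encoder weights ("lora_te_text_model"), and
    # ControlNet weights ("control_model."); none are needed for inference.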
    bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
    if bake_in_vae_filename is not None:
        print(f"[Converter] Baking in VAE from {bake_in_vae_filename}")
        vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
        for k, v in vae_dict.items():
            _hf(k, v)
        del vae_dict
output = ""
ckpt_dir = os.path.dirname(model_info.filepath)
save_name = f"{model_info.model_name}-{precision}"
if conv_type != "disabled":
save_name += f"-{conv_type}"
if fix_clip:
save_name += f"-clip-fix"
if custom_name != "":
save_name = custom_name
for fmt in checkpoint_formats:
ext = ".safetensors" if fmt == "safetensors" else ".ckpt"
_save_name = save_name + ext
save_path = os.path.join(ckpt_dir, _save_name)
print(f"[Converter] Saving to {save_path}...")
if fmt == "safetensors":
safetensors.torch.save_file(ok, save_path)
else:
torch.save({"state_dict": ok}, save_path)
output += f"Checkpoint saved to {save_path}\n"
shared.state.end()
return output[:-1]
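# Example invocation (hypothetical values; in the extension these arguments
# are normally wired up from the Gradio UI):
#   do_convert(MockModelInfo("/models/Stable-diffusion/foo.ckpt"),
#              checkpoint_formats=["safetensors"],
#              precision="fp16", conv_type="no-ema", custom_name="",
#              bake_in_vae="None",
#              unet_conv="convert", text_encoder_conv="convert",
#              vae_conv="convert", others_conv="convert",
#              fix_clip=True, force_position_id=False,
#              delete_known_junk_data=True)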