import os

import tqdm
import torch
import safetensors.torch
from torch import Tensor

from modules import shared
from modules import sd_models, sd_vae

# Source dtypes that are allowed to be cast down to the named target precision.
dtypes_to_fp16 = {torch.float32, torch.float64, torch.bfloat16}
dtypes_to_bf16 = {torch.float32, torch.float64, torch.float16}

class MockModelInfo:
    """Minimal stand-in for a webui checkpoint info object, built from a bare path."""

    def __init__(self, model_path: str) -> None:
        self.filepath = model_path
        self.filename: str = os.path.basename(model_path)
        self.model_name: str = self.filename.split(".")[0]

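# e.g. MockModelInfo("/models/v1-5-pruned.ckpt") yields
#   filename   = "v1-5-pruned.ckpt"
#   model_name = "v1-5-pruned"
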
def conv_fp16(t: Tensor) -> Tensor:
    return t.half() if t.dtype in dtypes_to_fp16 else t


def conv_bf16(t: Tensor) -> Tensor:
    return t.bfloat16() if t.dtype in dtypes_to_bf16 else t


def conv_full(t: Tensor) -> Tensor:
    return t

# Maps the UI precision choice to its tensor converter; "full" and "fp32"
# both leave tensors untouched.
_g_precision_func = {
    "full": conv_full,
    "fp32": conv_full,
    "fp16": conv_fp16,
    "bf16": conv_bf16,
}

def check_weight_type(k: str) -> str:
    if k.startswith("model.diffusion_model"):
        return "unet"
    elif k.startswith("first_stage_model"):
        return "vae"
    elif k.startswith("cond_stage_model"):
        return "clip"
    return "other"

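# Example classifications:
#   "model.diffusion_model.input_blocks.0.0.weight" -> "unet"
#   "first_stage_model.decoder.conv_in.weight"      -> "vae"
#   "cond_stage_model.transformer.text_model.embeddings.position_ids" -> "clip"
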
def load_model(path):
    if path.endswith(".safetensors"):
        m = safetensors.torch.load_file(path, device="cpu")
    else:
        m = torch.load(path, map_location="cpu")
    # .ckpt files usually wrap the weights in a "state_dict" entry; safetensors
    # files are the bare mapping already.
    state_dict = m["state_dict"] if "state_dict" in m else m
    return state_dict

def fix_model(model, fix_clip=False, force_position_id=False):
    # NovelAI-leaked checkpoints dropped the "text_model." segment from their
    # CLIP keys; rename those keys back to the standard layout.
    nai_keys = {
        'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
        'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
        'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.'
    }
    position_id_key = "cond_stage_model.transformer.text_model.embeddings.position_ids"
    for k in list(model.keys()):
        for r in nai_keys:
            if isinstance(k, str) and k.startswith(r):
                new_key = k.replace(r, nai_keys[r])
                model[new_key] = model[k]
                del model[k]
                print(f"[Converter] Fixed NovelAI error key {k}")
                break

    if force_position_id and position_id_key in model:
        model[position_id_key] = model[position_id_key].to(torch.int64)

    if fix_clip:
        if position_id_key in model:
            correct = torch.Tensor([list(range(77))]).to(torch.int64)
            now = model[position_id_key].to(torch.int64)

            broken = correct.ne(now)
            broken = [i for i in range(77) if broken[0][i]]
            if len(broken) != 0:
                model[position_id_key] = correct
                print(f"[Converter] Fixed broken CLIP position ids\n{broken}")
            else:
                print("[Converter] CLIP in this model is fine, skipping fix...")
        else:
            print("[Converter] Missing position ids in model, adding them...")
            model[position_id_key] = torch.Tensor([list(range(77))]).to(torch.int64)

    return model

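# Example of the NovelAI key rename above (derived from the nai_keys mapping):
#   "cond_stage_model.transformer.encoder.layers.0.self_attn.q_proj.weight"
#     -> "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight"
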
def convert_warp(
        model_name, model_path, directory,
        *args
):
    # Exactly one of the three model sources may be set.
    if sum(map(bool, [model_name, model_path, directory])) != 1:
        print("[Converter] Check your inputs: either multiple inputs were set, or none was")
        return

    if directory != "":
        if not os.path.exists(directory) or not os.path.isdir(directory):
            return "Error: path does not exist or is not a directory"

        files = [f for f in os.listdir(directory) if f.endswith(".ckpt") or f.endswith(".safetensors")]

        if len(files) == 0:
            return "Error: no model found in directory"

        # Batch mode: clear the custom name (args[3]) so every file keeps its own name.
        _args = list(args)
        _args[3] = ""

        for m in files:
            do_convert(MockModelInfo(os.path.join(directory, m)), *_args)

        return f"Converted {len(files)} models"

    elif model_path != "":
        if not os.path.exists(model_path):
            return "Error: model path does not exist"
        return do_convert(MockModelInfo(model_path), *args)

    elif model_name != "":
        model_info = sd_models.checkpoints_list[model_name]
        return do_convert(MockModelInfo(model_info.filename), *args)

    else:
        return "Error: must choose a model"

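# Argument reference for do_convert (values as used in this file):
#   checkpoint_formats: list of formats; "safetensors" writes .safetensors,
#                       anything else writes .ckpt
#   precision:          "full" | "fp32" | "fp16" | "bf16"
#   conv_type:          "disabled" | "ema-only" | "no-ema"
#   unet_conv / text_encoder_conv / vae_conv / others_conv:
#                       "convert" | "copy" | "delete" per weight group
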
def do_convert(model_info: MockModelInfo,
               checkpoint_formats,
               precision, conv_type, custom_name,
               bake_in_vae,
               unet_conv, text_encoder_conv, vae_conv, others_conv,
               fix_clip, force_position_id, delete_known_junk_data):
    if len(checkpoint_formats) == 0:
        return "Error: choose at least one model save format"

    extra_opt = {
        "unet": unet_conv,
        "clip": text_encoder_conv,
        "vae": vae_conv,
        "other": others_conv
    }
    shared.state.begin()
    shared.state.job = 'model-convert'
    shared.state.textinfo = f"Loading {model_info.filename}..."
    print(f"[Converter] Loading {model_info.filename}...")

    ok = {}  # output state dict
    state_dict = load_model(model_info.filepath)
    fix_model(state_dict, fix_clip=fix_clip, force_position_id=force_position_id)

    conv_func = _g_precision_func[precision]

    def _hf(wk: str, t: Tensor):
        # Route one weight into the output dict according to its group's
        # action: "convert" casts it, "copy" keeps it as-is, "delete" drops it.
        if not isinstance(t, Tensor):
            return
        weight_type = check_weight_type(wk)
        conv_t = extra_opt[weight_type]
        if conv_t == "convert":
            ok[wk] = conv_func(t)
        elif conv_t == "copy":
            ok[wk] = t
        elif conv_t == "delete":
            return

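    # Note for the "ema-only" path below: EMA weights are stored flat under
    # "model_ema." with the dots of the original "model."-prefixed parameter
    # path removed, e.g. (illustrative):
    #   "model.diffusion_model.out.2.weight" -> "model_ema.diffusion_modelout2weight"
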
    print("[Converter] Converting model...")

    if conv_type == "ema-only":
        for k in tqdm.tqdm(state_dict):
            ema_k = "___"
            try:
                ema_k = "model_ema." + k[6:].replace(".", "")
            except Exception:
                pass
            if ema_k in state_dict:
                # Replace the live weight with its EMA counterpart.
                _hf(k, state_dict[ema_k])
            elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]:
                _hf(k, state_dict[k])

    elif conv_type == "no-ema":
        for k, v in tqdm.tqdm(state_dict.items()):
            if "model_ema." not in k:
                _hf(k, v)
    else:
        for k, v in tqdm.tqdm(state_dict.items()):
            _hf(k, v)

    if delete_known_junk_data:
        known_junk_data_prefix = [
            "embedding_manager.embedder.",
            "lora_te_text_model",
            "control_model."
        ]
        need_delete = []
        for key in ok.keys():
            for jk in known_junk_data_prefix:
                if key.startswith(jk):
                    need_delete.append(key)

        for k in need_delete:
            del ok[k]

    bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
    if bake_in_vae_filename is not None:
        print(f"[Converter] Baking in VAE from {bake_in_vae_filename}")
        vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')

        for k, v in vae_dict.items():
            _hf(k, v)

        del vae_dict

    output = ""
    ckpt_dir = os.path.dirname(model_info.filepath)
    save_name = f"{model_info.model_name}-{precision}"
    if conv_type != "disabled":
        save_name += f"-{conv_type}"

    if fix_clip:
        save_name += "-clip-fix"

    if custom_name != "":
        save_name = custom_name

    for fmt in checkpoint_formats:
        ext = ".safetensors" if fmt == "safetensors" else ".ckpt"
        _save_name = save_name + ext

        save_path = os.path.join(ckpt_dir, _save_name)
        print(f"[Converter] Saving to {save_path}...")

        if fmt == "safetensors":
            safetensors.torch.save_file(ok, save_path)
        else:
            torch.save({"state_dict": ok}, save_path)
        output += f"Checkpoint saved to {save_path}\n"

    shared.state.end()
    return output[:-1]
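
# Minimal standalone usage sketch (illustrative only; assumes a running webui
# environment where modules.shared, modules.sd_models and modules.sd_vae are
# importable, and that the bake_in_vae name is absent from sd_vae.vae_dict so
# no VAE is baked in):
#
#   info = MockModelInfo("/path/to/model.safetensors")
#   do_convert(info,
#              ["safetensors"],                             # checkpoint_formats
#              "fp16",                                      # precision
#              "no-ema",                                    # conv_type
#              "",                                          # custom_name
#              "None",                                      # bake_in_vae
#              "convert", "convert", "convert", "convert",  # unet/clip/vae/other
#              True, False, False)                          # fix_clip, force_position_id, delete_known_junk_data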