|
import argparse |
|
import os |
|
|
|
import torch |
|
from safetensors import safe_open |
|
from safetensors.torch import load_file, save_file |
|
from tqdm import tqdm |
|
|
|
|
|
def is_unet_key(key):
    """Return True when ``key`` does not belong to a non-UNet sub-model.

    A key counts as non-UNet if it contains any of the substrings
    ``first_stage_model``, ``cond_stage_model`` or ``conditioner.``
    (presumably the VAE / text-encoder prefixes of SD-style checkpoints —
    confirm against the checkpoints being merged). Every other key,
    including the empty string, is treated as a UNet key.
    """
    non_unet_markers = ("first_stage_model", "cond_stage_model", "conditioner.")
    return all(marker not in key for marker in non_unet_markers)
|
|
|
|
|
# (old_prefix, new_prefix) pairs: keys starting with old_prefix are rewritten
# to start with new_prefix, inserting the "text_model." path segment.
TEXT_ENCODER_KEY_REPLACEMENTS = [
    ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
    ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
    ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
]


def replace_text_encoder_key(key):
    """Normalize a text-encoder key using ``TEXT_ENCODER_KEY_REPLACEMENTS``.

    The first entry whose old prefix starts ``key`` wins; only that prefix is
    swapped and the rest of the key is kept.

    Returns:
        ``(True, rewritten_key)`` if a prefix matched, otherwise
        ``(False, key)`` with the key unchanged.
    """
    matched = next(
        ((old_prefix, new_prefix) for old_prefix, new_prefix in TEXT_ENCODER_KEY_REPLACEMENTS if key.startswith(old_prefix)),
        None,
    )
    if matched is None:
        return False, key
    old_prefix, new_prefix = matched
    return True, new_prefix + key[len(old_prefix):]
|
|
|
|
|
def _precision_to_dtype(precision):
    """Return the torch dtype for a ``--precision`` / ``--saving_precision`` choice.

    ``"fp16"`` maps to ``torch.float16`` and ``"bf16"`` to ``torch.bfloat16``;
    any other value (the CLI's ``"float"``) falls back to ``torch.float``.
    """
    if precision == "fp16":
        return torch.float16
    if precision == "bf16":
        return torch.bfloat16
    return torch.float


def _validate_inputs(models, ratios, output):
    """Abort with a non-zero exit status if the requested merge cannot run.

    Raises:
        SystemExit: when no models or no output path are given, a model path
            does not end in ``safetensors`` or is not an existing file, or
            ``ratios`` is given with a different length than ``models``.
    """
    if not models:
        raise SystemExit("No models to merge; pass at least one path with --models")
    if not output:
        raise SystemExit("No output path given; pass --output")
    for model in models:
        if not model.endswith("safetensors"):
            raise SystemExit(f"Model {model} is not a safetensors model")
        if not os.path.isfile(model):
            raise SystemExit(f"Model {model} does not exist")
    # Explicit check rather than ``assert``: asserts are stripped under ``python -O``.
    if ratios is not None and len(models) != len(ratios):
        raise SystemExit("ratios must be the same length as models")


def _load_first_model(model, ratio, dtype, device, unet_only):
    """Load the first model as the starting point of the weighted sum.

    Returns:
        ``(merged_sd, first_model_keys, supplementary_key_ratios)``:
        ``merged_sd`` maps normalized keys to ``ratio * tensor`` cast to
        ``dtype``; ``first_model_keys`` holds every normalized key in the
        file; ``supplementary_key_ratios`` maps each key rejected by
        ``is_unet_key`` to 1.0 when ``unet_only`` is set (empty otherwise),
        so those tensors are later copied from this model unchanged.
    """
    print(f"Loading model {model}, ratio = {ratio}...")
    merged_sd = {}
    first_model_keys = set()
    supplementary_key_ratios = {}
    with safe_open(model, framework="pt", device=device) as f:
        for key in tqdm(f.keys()):
            _, new_key = replace_text_encoder_key(key)
            first_model_keys.add(new_key)
            if unet_only and not is_unet_key(new_key):
                # Not merged: the first model's value is added back verbatim at the end.
                supplementary_key_ratios[new_key] = 1.0
                continue
            # Read the tensor only once we know it is actually needed.
            merged_sd[new_key] = ratio * f.get_tensor(key).to(dtype)
    print(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if unet_only else ""))
    return merged_sd, first_model_keys, supplementary_key_ratios


def _accumulate_model(merged_sd, supplementary_key_ratios, first_model_keys, model, ratio, dtype, device, show_skipped):
    """Add ``ratio * tensor`` from ``model`` into ``merged_sd`` for every shared key.

    Keys of ``model`` absent from ``merged_sd`` are skipped (and printed when
    ``show_skipped`` is set and the key was not in the first model). Keys of
    ``merged_sd`` that ``model`` lacks get ``ratio`` added to
    ``supplementary_key_ratios`` so the first model's value fills the gap.
    Both dicts are updated in place.
    """
    print(f"Loading model {model}, ratio = {ratio}...")
    present_keys = set()
    with safe_open(model, framework="pt", device=device) as f:
        for key in tqdm(f.keys()):
            _, new_key = replace_text_encoder_key(key)
            # Record the *normalized* key: merged_sd is keyed by normalized names,
            # so comparing against raw file keys would flag renamed text-encoder
            # keys as missing and add the first model's value a second time.
            present_keys.add(new_key)
            if new_key not in merged_sd:
                if show_skipped and new_key not in first_model_keys:
                    print(f"Skip: {new_key}")
                continue
            merged_sd[new_key] = merged_sd[new_key] + ratio * f.get_tensor(key).to(dtype)

    for key in merged_sd:
        if key in present_keys:
            continue
        print(f"Key {key} not in model {model}, use first model's value")
        supplementary_key_ratios[key] = supplementary_key_ratios.get(key, 0.0) + ratio


def _add_supplementary_values(merged_sd, supplementary_key_ratios, first_model, dtype, device):
    """Add ``ratio * first-model tensor`` for every key in ``supplementary_key_ratios``.

    Keys already in ``merged_sd`` are accumulated into; others are inserted.
    Does nothing (and does not open the file) when the dict is empty.
    """
    if not supplementary_key_ratios:
        return
    print("add first model's value")
    with safe_open(first_model, framework="pt", device=device) as f:
        for key in tqdm(f.keys()):
            _, new_key = replace_text_encoder_key(key)
            key_ratio = supplementary_key_ratios.get(new_key)
            if key_ratio is None:
                continue
            if is_unet_key(new_key):
                print(f"Key {new_key} not in all models, ratio = {key_ratio}")
            value = key_ratio * f.get_tensor(key).to(dtype)
            if new_key in merged_sd:
                merged_sd[new_key] = merged_sd[new_key] + value
            else:
                merged_sd[new_key] = value


def _save(merged_sd, output, save_dtype):
    """Cast every tensor in ``merged_sd`` to ``save_dtype`` and write it with ``save_file``.

    ``.safetensors`` is appended to ``output`` when it does not already end with it.
    """
    output_file = output if output.endswith(".safetensors") else output + ".safetensors"
    print(f"Saving to {output_file}...")
    for key in merged_sd:
        merged_sd[key] = merged_sd[key].to(save_dtype)
    save_file(merged_sd, output_file)
    print("Done!")


def merge(args):
    """Merge ``args.models`` into one safetensors file as a weighted sum.

    Reads these attributes of ``args``: ``models`` (safetensors paths; the
    first one is the base), ``ratios`` (per-model weights, or ``None`` for
    equal weights ``1 / len(models)``; not normalized), ``output``,
    ``precision`` / ``saving_precision`` (``"float"``, ``"fp16"``, ``"bf16"``),
    ``device``, ``unet_only`` and ``show_skipped``.

    Keys missing from a later model take the first model's value with that
    model's ratio. With ``unet_only``, keys rejected by ``is_unet_key`` are
    copied from the first model unchanged.

    Raises:
        SystemExit: if the inputs are invalid (see ``_validate_inputs``).
    """
    dtype = _precision_to_dtype(args.precision)
    save_dtype = _precision_to_dtype(args.saving_precision)
    _validate_inputs(args.models, args.ratios, args.output)

    models = args.models
    ratios = args.ratios if args.ratios is not None else [1.0 / len(models)] * len(models)

    merged_sd, first_model_keys, supplementary_key_ratios = _load_first_model(
        models[0], ratios[0], dtype, args.device, args.unet_only
    )
    for model, ratio in zip(models[1:], ratios[1:]):
        _accumulate_model(
            merged_sd,
            supplementary_key_ratios,
            first_model_keys,
            model,
            ratio,
            dtype,
            args.device,
            args.show_skipped,
        )

    _add_supplementary_values(merged_sd, supplementary_key_ratios, models[0], dtype, args.device)
    _save(merged_sd, args.output, save_dtype)
|
|
|
|
|
def setup_parser():
    """Build the command-line parser for the model-merge script.

    ``--models`` and ``--output`` are required. Without them the merge
    cannot run, so argparse should report a usage error rather than let
    ``merge`` crash later.
    """
    parser = argparse.ArgumentParser(description="Merge models")
    parser.add_argument("--models", nargs="+", type=str, required=True, help="Models to merge")
    parser.add_argument("--output", type=str, required=True, help="Output model")
    parser.add_argument("--ratios", nargs="+", type=float, help="Ratios of models, default is equal, total = 1.0")
    parser.add_argument("--unet_only", action="store_true", help="Only merge unet")
    parser.add_argument("--device", type=str, default="cpu", help="Device to use, default is cpu")
    parser.add_argument(
        "--precision", type=str, default="float", choices=["float", "fp16", "bf16"], help="Calculation precision, default is float"
    )
    parser.add_argument(
        "--saving_precision",
        type=str,
        default="float",
        choices=["float", "fp16", "bf16"],
        help="Saving precision, default is float",
    )
    parser.add_argument("--show_skipped", action="store_true", help="Show skipped keys (keys not in first model)")
    return parser


if __name__ == "__main__":
    parser = setup_parser()
    args = parser.parse_args()
    merge(args)
|
|