|
|
|
|
|
|
|
import bisect
import math
import random
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union

from diffusers import UNet2DConditionModel
import numpy as np
import torch
from tqdm import tqdm
from transformers import CLIPTextModel
|
|
|
|
|
|
def make_unet_conversion_map() -> Dict[str, str]:
    """Build the Stability AI (original SD/SDXL) -> Diffusers U-Net name-prefix mapping.

    Both keys and values are written the way LoRA module names are: every "." replaced by
    "_" and no trailing separator, e.g. ``"input_blocks_1_0_in_layers_0"`` ->
    ``"down_blocks_0_resnets_0_norm1"``.

    Returns:
        dict mapping SD-style prefixes to Diffusers-style prefixes.
    """
    # (stable-diffusion suffix, diffusers suffix) for the sub-layers of a ResNet block
    resnet_suffixes = (
        ("in_layers.0.", "norm1."),
        ("in_layers.2.", "conv1."),
        ("out_layers.0.", "norm2."),
        ("out_layers.3.", "conv2."),
        ("emb_layers.1.", "time_emb_proj."),
        ("skip_connection.", "conv_shortcut."),
    )

    # Block-level (sd_prefix, hf_prefix) pairs; the U-Net has 3 down/up block levels here.
    layer_pairs = []
    for block in range(3):
        for idx in range(2):
            input_index = 3 * block + idx + 1
            layer_pairs.append((f"input_blocks.{input_index}.0.", f"down_blocks.{block}.resnets.{idx}."))
            layer_pairs.append((f"input_blocks.{input_index}.1.", f"down_blocks.{block}.attentions.{idx}."))
        for idx in range(3):
            output_index = 3 * block + idx
            layer_pairs.append((f"output_blocks.{output_index}.0.", f"up_blocks.{block}.resnets.{idx}."))
            layer_pairs.append((f"output_blocks.{output_index}.1.", f"up_blocks.{block}.attentions.{idx}."))
        layer_pairs.append((f"input_blocks.{3 * (block + 1)}.0.op.", f"down_blocks.{block}.downsamplers.0.conv."))
        # the upsampler sits at sub-index 2 of the last output block of each level
        layer_pairs.append((f"output_blocks.{3 * block + 2}.2.", f"up_blocks.{block}.upsamplers.0."))

    layer_pairs.append(("middle_block.1.", "mid_block.attentions.0."))
    for idx in range(2):
        layer_pairs.append((f"middle_block.{2 * idx}.", f"mid_block.resnets.{idx}."))

    # ResNet blocks are expanded into one entry per sub-layer; everything else maps as a whole.
    pairs = []
    for sd_prefix, hf_prefix in layer_pairs:
        if "resnets" not in hf_prefix:
            pairs.append((sd_prefix, hf_prefix))
            continue
        pairs.extend((sd_prefix + sd_sfx, hf_prefix + hf_sfx) for sd_sfx, hf_sfx in resnet_suffixes)

    for idx in range(2):
        pairs.append((f"time_embed.{idx * 2}.", f"time_embedding.linear_{idx + 1}."))
    for idx in range(2):
        pairs.append((f"label_emb.0.{idx * 2}.", f"add_embedding.linear_{idx + 1}."))
    pairs += [
        ("input_blocks.0.0.", "conv_in."),
        ("out.0.", "conv_norm_out."),
        ("out.2.", "conv_out."),
    ]

    def _as_lora_name(prefix: str) -> str:
        # "a.b.c." -> "a_b_c"
        return prefix.replace(".", "_")[:-1]

    return {_as_lora_name(sd): _as_lora_name(hf) for sd, hf in pairs}
|
|
|
|
|
|
# Lookup from Stability AI (SD/SDXL) U-Net name prefixes to Diffusers name prefixes, both in the
# underscore-joined form used by LoRA module names; built once at import time.
UNET_CONVERSION_MAP: Dict[str, str] = make_unet_conversion_map()
|
|
|
|
|
|
class LoRAModule(torch.nn.Module):
    """Low-rank adapter for a single Linear / Conv2d module.

    Instead of replacing the original module, ``apply_to`` swaps the original module's
    ``forward`` for this module's ``forward``, which adds
    ``lora_up(lora_down(x)) * multiplier * scale`` to the original output. ``merge_to`` and
    ``restore_from`` instead add/subtract the equivalent dense weight delta directly to/from the
    original module's ``weight``.
    """

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
    ):
        """Create down/up projections that match ``org_module``'s shape.

        Args:
            lora_name: name identifying this adapter.
            org_module: module to adapt. ``Conv2d`` / ``LoRACompatibleConv`` (matched by class
                name) get convolutional projections; anything else is treated as Linear-like and
                must expose ``in_features`` / ``out_features``.
            multiplier: strength used by ``forward`` and the default for ``get_weight``.
            lora_dim: rank of the adapter.
            alpha: scaling numerator, ``scale = alpha / lora_dim``. If alpha == 0 or None, alpha
                is rank (no scaling). A tensor alpha is converted to a Python float.
        """
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim

        is_conv = org_module.__class__.__name__ in ("Conv2d", "LoRACompatibleConv")
        if is_conv:
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
            # lora_down reuses the original kernel/stride/padding so the spatial output size
            # matches; lora_up is a 1x1 projection back to out_dim.
            self.lora_down = torch.nn.Conv2d(
                in_dim, self.lora_dim, org_module.kernel_size, org_module.stride, org_module.padding, bias=False
            )
            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        if isinstance(alpha, torch.Tensor):
            # .item() works for CPU and CUDA tensors alike (Tensor.numpy() fails on CUDA tensors)
            alpha = alpha.detach().float().item()
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))

        # kaiming-uniform init for lora_down; zero lora_up so the adapter starts as a no-op
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        # kept in a list so the original module is not registered as a submodule of this one
        self.org_module = [org_module]
        self.enabled = True
        self.network: Optional["LoRANetwork"] = None
        self.org_forward = None

    def apply_to(self, multiplier=None):
        """Route the original module's ``forward`` through this adapter.

        Args:
            multiplier: if given, replaces ``self.multiplier``.
        """
        if multiplier is not None:
            self.multiplier = multiplier
        # capture the original forward only once so a second apply_to() does not record the
        # already-patched forward (which would recurse)
        if self.org_forward is None:
            self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward

    def unapply_to(self):
        """Put the original module's ``forward`` back if ``apply_to`` was ever called."""
        if self.org_forward is not None:
            self.org_module[0].forward = self.org_forward

    def forward(self, x, scale=1.0):
        """Original output plus the scaled low-rank update.

        ``scale`` is accepted but unused (presumably for call-compatibility with callers that
        pass it to LoRACompatible* modules).
        """
        if not self.enabled:
            return self.org_forward(x)
        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale

    def set_network(self, network):
        """Remember the owning network."""
        self.network = network

    def _add_to_org_weight(self, delta):
        """Add ``delta`` (cast to the original weight's device/dtype) to the original weight in place."""
        org_module = self.org_module[0]
        org_sd = org_module.state_dict()
        org_weight = org_sd["weight"]
        org_sd["weight"] = org_weight + delta.to(org_weight.device, dtype=org_weight.dtype)
        org_module.load_state_dict(org_sd)

    def merge_to(self, multiplier=1.0):
        """Bake ``get_weight(multiplier)`` into the original module's weight."""
        self._add_to_org_weight(self.get_weight(multiplier))

    def restore_from(self, multiplier=1.0):
        """Undo ``merge_to(multiplier)`` by subtracting the same weight delta."""
        self._add_to_org_weight(-self.get_weight(multiplier))

    def get_weight(self, multiplier=None):
        """Return the dense float32 weight delta ``multiplier * (up ∘ down) * scale``.

        Args:
            multiplier: strength to apply; ``None`` means ``self.multiplier``.

        Returns:
            Tensor with the same shape as the original module's weight.
        """
        if multiplier is None:
            multiplier = self.multiplier

        up_weight = self.lora_up.weight.to(torch.float)
        down_weight = self.lora_down.weight.to(torch.float)

        if down_weight.dim() == 2:
            # Linear: (out, r) @ (r, in)
            weight = up_weight @ down_weight
        elif down_weight.size()[2:4] == (1, 1):
            # 1x1 conv: same as linear on the squeezed weights, then restore the 1x1 kernel dims
            weight = (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
        else:
            # kxk conv followed by 1x1 conv: compose by convolving the down kernel with the up kernel
            weight = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)

        # Use the resolved `multiplier` (previously self.multiplier was used, so the argument
        # passed by merge_to/restore_from was silently ignored).
        return multiplier * weight * self.scale
|
|
|
|
|
|
|
|
def create_network_from_weights(
    text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], unet: UNet2DConditionModel, weights_sd: Dict, multiplier: float = 1.0
):
    """Instantiate a LoRANetwork whose modules are described by a LoRA state dict.

    Args:
        text_encoder: a single text encoder or a list of them.
        unet: the U-Net to attach LoRA modules to.
        weights_sd: LoRA state dict whose keys look like ``"<lora_name>.<param>"``.
        multiplier: initial strength for every created LoRA module.

    Returns:
        LoRANetwork with one module per LoRA name found in ``weights_sd``. The weights themselves
        are not loaded here; call ``load_state_dict`` on the result.
    """
    modules_dim = {}
    modules_alpha = {}
    for key, value in weights_sd.items():
        lora_name, sep, _ = key.partition(".")
        if not sep:
            # keys without a "." do not belong to any LoRA module
            continue
        if "alpha" in key:
            modules_alpha[lora_name] = value
        elif "lora_down" in key:
            # rank = first dimension of the lora_down weight
            modules_dim[lora_name] = value.size()[0]

    # modules stored without an alpha default to alpha == rank
    for lora_name, dim in modules_dim.items():
        modules_alpha.setdefault(lora_name, dim)

    return LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
|
|
|
|
|
|
def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0):
    """Permanently merge LoRA weights into a pipeline's text encoder(s) and U-Net.

    Uses ``pipe.text_encoder`` plus ``pipe.text_encoder_2`` when the pipeline has that attribute.

    Args:
        pipe: diffusers pipeline exposing ``text_encoder`` and ``unet``.
        weights_sd: LoRA state dict.
        multiplier: strength of the merged LoRA.
    """
    text_encoders = [pipe.text_encoder]
    if hasattr(pipe, "text_encoder_2"):
        text_encoders.append(pipe.text_encoder_2)

    network = create_network_from_weights(text_encoders, pipe.unet, weights_sd, multiplier=multiplier)
    network.load_state_dict(weights_sd)
    network.merge_to(multiplier=multiplier)
|
|
|
|
|
|
|
|
class LoRANetwork(torch.nn.Module):
    """Container of LoRAModule objects for one or more text encoders and a U-Net.

    Modules are created only for names present in ``modules_dim``. Module names are built as
    ``<prefix>.<module path>`` with every "." replaced by "_". U-Net names written in the
    Stability AI (SD/SDXL) layout are converted to the Diffusers layout via UNET_CONVERSION_MAP.
    """

    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    # prefixes used instead of LORA_PREFIX_TEXT_ENCODER when more than one text encoder is given
    LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
    LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"

    def __init__(
        self,
        text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
        unet: UNet2DConditionModel,
        multiplier: float = 1.0,
        modules_dim: Optional[Dict[str, int]] = None,
        modules_alpha: Optional[Dict[str, int]] = None,
        varbose: Optional[bool] = False,
    ) -> None:
        """Create one LoRAModule per name in ``modules_dim``.

        Args:
            text_encoder: one text encoder or a list. With more than one, names use the
                "lora_te1" / "lora_te2" prefixes; otherwise "lora_te".
            unet: the U-Net to attach LoRA modules to.
            multiplier: initial strength given to each LoRAModule.
            modules_dim: lora_name -> rank. SD-style U-Net names are renamed in place.
                ``None`` is treated as empty.
            modules_alpha: lora_name -> alpha, renamed in place alongside ``modules_dim``.
                Missing entries are passed as None (LoRAModule then uses alpha == rank).
                ``None`` is treated as empty.
            varbose: unused; kept for backward compatibility.

        Raises:
            AssertionError: if a name in ``modules_dim`` matches no Linear/Conv2d module, or if
                U-Net names are only partially convertible.
        """
        super().__init__()
        self.multiplier = multiplier

        # a None default previously crashed in convert_unet_modules (None.keys())
        if modules_dim is None:
            modules_dim = {}
        if modules_alpha is None:
            modules_alpha = {}

        print(f"create LoRA network from weights")

        # convert SDXL Stability AI's U-Net module names to Diffusers names
        converted = self.convert_unet_modules(modules_dim, modules_alpha)
        if converted:
            print(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")

        text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]

        # text encoder(s)
        self.text_encoder_loras: List[LoRAModule] = []
        skipped_te: List[str] = []
        for i, te in enumerate(text_encoders):
            # the 1-based index only selects lora_te1/lora_te2 when there are several encoders
            index = i + 1 if len(text_encoders) > 1 else None
            te_loras, skipped = self._create_modules(
                False, index, te, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE, modules_dim, modules_alpha
            )
            self.text_encoder_loras.extend(te_loras)
            skipped_te += skipped
        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
        if len(skipped_te) > 0:
            print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")

        # U-Net: attention blocks plus ResNet / down- / up-sample blocks
        target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
        unet_loras, skipped_un = self._create_modules(True, None, unet, target_modules, modules_dim, modules_alpha)
        self.unet_loras: List[LoRAModule] = unet_loras
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
        if len(skipped_un) > 0:
            print(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")

        # every weight in the state dict must have found a target module
        names = {lora.lora_name for lora in self.text_encoder_loras + self.unet_loras}
        for lora_name in modules_dim.keys():
            assert lora_name in names, f"{lora_name} is not found in created LoRA modules."

        # register as submodules so state_dict()/load_state_dict()/to() include them
        for lora in self.text_encoder_loras + self.unet_loras:
            self.add_module(lora.lora_name, lora)

    def _create_modules(
        self,
        is_unet: bool,
        text_encoder_idx: Optional[int],
        root_module: torch.nn.Module,
        target_replace_modules: List[str],
        modules_dim: Dict[str, int],
        modules_alpha: Dict[str, Any],
    ) -> Tuple[List[LoRAModule], List[str]]:
        """Create a LoRAModule for each Linear/Conv2d inside the targeted blocks of ``root_module``.

        Args:
            is_unet: selects the U-Net prefix instead of a text-encoder prefix.
            text_encoder_idx: None -> "lora_te", 1 -> "lora_te1", anything else -> "lora_te2".
            root_module: model to scan.
            target_replace_modules: class names of the blocks whose children are adapted.
            modules_dim: lora_name -> rank; names not in it are skipped.
            modules_alpha: lora_name -> alpha.

        Returns:
            (loras, skipped): created modules, and candidate names skipped for lack of weights.
        """
        if is_unet:
            prefix = self.LORA_PREFIX_UNET
        elif text_encoder_idx is None:
            prefix = self.LORA_PREFIX_TEXT_ENCODER
        elif text_encoder_idx == 1:
            prefix = self.LORA_PREFIX_TEXT_ENCODER1
        else:
            prefix = self.LORA_PREFIX_TEXT_ENCODER2

        loras: List[LoRAModule] = []
        skipped: List[str] = []
        for name, module in root_module.named_modules():
            if module.__class__.__name__ not in target_replace_modules:
                continue
            for child_name, child_module in module.named_modules():
                class_name = child_module.__class__.__name__
                is_linear = class_name in ("Linear", "LoRACompatibleLinear")
                is_conv2d = class_name in ("Conv2d", "LoRACompatibleConv")
                if not (is_linear or is_conv2d):
                    continue

                lora_name = (prefix + "." + name + "." + child_name).replace(".", "_")
                if lora_name not in modules_dim:
                    # no weight for this module in the state dict
                    skipped.append(lora_name)
                    continue

                lora = LoRAModule(
                    lora_name,
                    child_module,
                    self.multiplier,
                    modules_dim[lora_name],
                    modules_alpha.get(lora_name),
                )
                loras.append(lora)
        return loras, skipped

    @staticmethod
    def _convert_unet_key(key: str, map_keys: List[str]) -> Optional[str]:
        """Translate one SD-style U-Net LoRA key to its Diffusers-style name.

        ``map_keys`` must be ``sorted(UNET_CONVERSION_MAP)``. The candidate is the greatest map
        key sorting <= the part after "lora_unet_"; it is used only if it is a prefix of that
        part, and only that leading prefix is rewritten.

        Returns:
            The converted key, or None if ``key`` is not a "lora_unet_" key or has no matching prefix.
        """
        unet_prefix = LoRANetwork.LORA_PREFIX_UNET + "_"
        if not key.startswith(unet_prefix):
            return None
        search_key = key[len(unet_prefix):]
        position = bisect.bisect_right(map_keys, search_key)
        if position == 0:
            return None
        map_key = map_keys[position - 1]
        if not search_key.startswith(map_key):
            return None
        return unet_prefix + UNET_CONVERSION_MAP[map_key] + search_key[len(map_key):]

    def convert_unet_modules(self, modules_dim, modules_alpha):
        """Rename SD-style U-Net LoRA names to Diffusers-style names, in place.

        Args:
            modules_dim: lora_name -> rank; mutated in place.
            modules_alpha: lora_name -> alpha; mutated in place (entries moved when present).

        Returns:
            int: number of renamed modules.

        Raises:
            AssertionError: if only some of the U-Net names could be converted.
        """
        converted_count = 0
        not_converted_count = 0

        map_keys = sorted(UNET_CONVERSION_MAP.keys())
        unet_prefix = LoRANetwork.LORA_PREFIX_UNET + "_"

        for key in list(modules_dim.keys()):
            if not key.startswith(unet_prefix):
                continue
            new_key = self._convert_unet_key(key, map_keys)
            if new_key is None:
                not_converted_count += 1
                continue
            modules_dim[new_key] = modules_dim.pop(key)
            if key in modules_alpha:
                modules_alpha[new_key] = modules_alpha.pop(key)
            converted_count += 1

        # a mix means the file uses both naming schemes, which is not supported
        assert (
            converted_count == 0 or not_converted_count == 0
        ), f"some modules are not converted: {converted_count} converted, {not_converted_count} not converted"
        return converted_count

    def set_multiplier(self, multiplier):
        """Set the strength of every LoRA module."""
        self.multiplier = multiplier
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.multiplier = self.multiplier

    def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True):
        """Patch the selected original modules' forward methods with their LoRA modules."""
        if apply_text_encoder:
            print("enable LoRA for text encoder")
            for lora in self.text_encoder_loras:
                lora.apply_to(multiplier)
        if apply_unet:
            print("enable LoRA for U-Net")
            for lora in self.unet_loras:
                lora.apply_to(multiplier)

    def unapply_to(self):
        """Restore the original forward methods of all patched modules."""
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.unapply_to()

    def merge_to(self, multiplier=1.0):
        """Add every LoRA weight delta into the original module weights."""
        print("merge LoRA weights to original weights")
        for lora in tqdm(self.text_encoder_loras + self.unet_loras):
            lora.merge_to(multiplier)
        print(f"weights are merged")

    def restore_from(self, multiplier=1.0):
        """Subtract every LoRA weight delta from the original module weights."""
        print("restore LoRA weights from original weights")
        for lora in tqdm(self.text_encoder_loras + self.unet_loras):
            lora.restore_from(multiplier)
        print(f"weights are restored")

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        """Load LoRA weights, accepting SD-style U-Net names and reshapeable size mismatches.

        The caller's mapping is not modified; a converted copy is handed to torch's loader.
        """
        map_keys = sorted(UNET_CONVERSION_MAP.keys())

        converted_sd: Dict[str, Any] = {}
        for key, value in state_dict.items():
            new_key = self._convert_unet_key(key, map_keys)
            converted_sd[key if new_key is None else new_key] = value

        # Reshape tensors whose element layout fits this network's parameter but whose shape
        # differs — presumably Linear (out, in) vs 1x1 Conv2d (out, in, 1, 1) weights.
        # Unknown keys are left alone so torch's own `strict` handling reports them.
        my_state_dict = self.state_dict()
        for key, value in list(converted_sd.items()):
            target = my_state_dict.get(key)
            if target is not None and value.size() != target.size():
                converted_sd[key] = value.view(target.size())

        return super().load_state_dict(converted_sd, strict)
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: renders the same prompt with the base model, with LoRA applied via
    # forward patching, after unapplying, after merging, after unmerging, after restoring the
    # saved original weights, and after merge_lora_weights(); one PNG is saved per stage.
    import argparse
    import os

    from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    parser = argparse.ArgumentParser()
    # required: with a None default the script crashed later on args.model_id.replace(...)
    parser.add_argument("--model_id", type=str, required=True, help="model id for huggingface")
    parser.add_argument("--lora_weights", type=str, required=True, help="path to LoRA weights")
    parser.add_argument("--sdxl", action="store_true", help="use SDXL model")
    parser.add_argument("--prompt", type=str, default="A photo of cat", help="prompt text")
    parser.add_argument("--negative_prompt", type=str, default="", help="negative prompt text")
    parser.add_argument("--seed", type=int, default=0, help="random seed")
    args = parser.parse_args()

    image_prefix = args.model_id.replace("/", "_") + "_"

    # load Diffusers model
    print(f"load model from {args.model_id}")
    pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline]
    if args.sdxl:
        pipe = StableDiffusionXLPipeline.from_pretrained(args.model_id, variant="fp16", torch_dtype=torch.float16)
    else:
        pipe = StableDiffusionPipeline.from_pretrained(args.model_id, variant="fp16", torch_dtype=torch.float16)
    pipe.to(device)
    pipe.set_use_memory_efficient_attention_xformers(True)

    text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if args.sdxl else [pipe.text_encoder]

    # load LoRA weights
    print(f"load LoRA weights from {args.lora_weights}")
    if os.path.splitext(args.lora_weights)[1] == ".safetensors":
        from safetensors.torch import load_file

        lora_sd = load_file(args.lora_weights)
    else:
        # load onto CPU like safetensors does, independent of the device used when saving
        lora_sd = torch.load(args.lora_weights, map_location="cpu")

    # create by LoRA weights and load weights
    print(f"create LoRA network")
    lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=1.0)

    print(f"load LoRA network weights")
    lora_network.load_state_dict(lora_sd)

    lora_network.to(device, dtype=pipe.unet.dtype)

    # save original weights so they can be restored at the end
    def detach_and_move_to_cpu(state_dict):
        """Detach every tensor in ``state_dict`` and move it to CPU, in place."""
        for k, v in state_dict.items():
            state_dict[k] = v.detach().cpu()
        return state_dict

    org_unet_sd = pipe.unet.state_dict()
    detach_and_move_to_cpu(org_unet_sd)

    org_text_encoder_sd = pipe.text_encoder.state_dict()
    detach_and_move_to_cpu(org_text_encoder_sd)

    if args.sdxl:
        org_text_encoder_2_sd = pipe.text_encoder_2.state_dict()
        detach_and_move_to_cpu(org_text_encoder_2_sd)

    def seed_everything(seed):
        """Seed torch (CPU and CUDA), numpy and random for reproducible images."""
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)

    # create image with original weights
    print(f"create image with original weights")
    seed_everything(args.seed)
    image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
    image.save(image_prefix + "original.png")

    # apply LoRA network to the model: slower than merging, but can be toggled on/off
    print(f"apply LoRA network to the model")
    lora_network.apply_to(multiplier=1.0)

    print(f"create image with applied LoRA")
    seed_everything(args.seed)
    image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
    image.save(image_prefix + "applied_lora.png")

    # unapply LoRA network from the model
    print(f"unapply LoRA network to the model")
    lora_network.unapply_to()

    print(f"create image with unapplied LoRA")
    seed_everything(args.seed)
    image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
    image.save(image_prefix + "unapplied_lora.png")

    # merge LoRA network into the model weights
    print(f"merge LoRA network to the model")
    lora_network.merge_to(multiplier=1.0)

    print(f"create image with LoRA")
    seed_everything(args.seed)
    image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
    image.save(image_prefix + "merged_lora.png")

    # restore (unmerge) LoRA weights; subject to floating point error
    print(f"restore (unmerge) LoRA weights")
    lora_network.restore_from(multiplier=1.0)

    print(f"create image without LoRA")
    seed_everything(args.seed)
    image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
    image.save(image_prefix + "unmerged_lora.png")

    # restore the exact original weights saved above
    print(f"restore original weights")
    pipe.unet.load_state_dict(org_unet_sd)
    pipe.text_encoder.load_state_dict(org_text_encoder_sd)
    if args.sdxl:
        pipe.text_encoder_2.load_state_dict(org_text_encoder_2_sd)

    print(f"create image with restored original weights")
    seed_everything(args.seed)
    image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
    image.save(image_prefix + "restore_original.png")

    # use convenience function to merge LoRA weights
    print(f"merge LoRA weights with convenience function")
    merge_lora_weights(pipe, lora_sd, multiplier=1.0)

    print(f"create image with merged LoRA weights")
    seed_everything(args.seed)
    image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
    image.save(image_prefix + "convenience_merged_lora.png")
|
|
|