"""Distill the weight difference between a fine-tuned Stable Diffusion model and
its base model into low-rank LoRA up/down factors via truncated SVD."""

import fire
import torch
import torch.nn as nn

from diffusers import StableDiffusionPipeline

from .lora import (
    save_all,
    _find_modules,
    LoraInjectedConv2d,
    LoraInjectedLinear,
    inject_trainable_lora,
    inject_trainable_lora_extended,
)


def _iter_lora(model):
    """Yield every LoRA-injected module (conv or linear) in `model`."""
    for module in model.modules():
        if isinstance(module, (LoraInjectedConv2d, LoraInjectedLinear)):
            yield module


def overwrite_base(base_model, tuned_model, rank, clamp_quantile):
    """Fill the LoRA factors of `base_model` with a rank-`rank` SVD approximation
    of the weight residual between `tuned_model` and `base_model`."""
    device = base_model.device
    dtype = base_model.dtype

    for lor_base, lor_tune in zip(_iter_lora(base_model), _iter_lora(tuned_model)):
        if isinstance(lor_base, LoraInjectedLinear):
            residual = lor_tune.linear.weight.data - lor_base.linear.weight.data
            print("Distill Linear shape", residual.shape)

            # Truncated SVD of the residual: keep the top-`rank` singular vectors
            # and fold the singular values into U.
            residual = residual.float()
            U, S, Vh = torch.linalg.svd(residual)
            U = U[:, :rank]
            S = S[:rank]
            U = U @ torch.diag(S)
            Vh = Vh[:rank, :]

            # Clamp outliers so the factors stay well-behaved in fp16.
            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(dist, clamp_quantile)
            low_val = -hi_val
            U = U.clamp(low_val, hi_val)
            Vh = Vh.clamp(low_val, hi_val)

            assert lor_base.lora_up.weight.shape == U.shape
            assert lor_base.lora_down.weight.shape == Vh.shape

            lor_base.lora_up.weight.data = U.to(device=device, dtype=dtype)
            lor_base.lora_down.weight.data = Vh.to(device=device, dtype=dtype)

        if isinstance(lor_base, LoraInjectedConv2d):
            residual = lor_tune.conv.weight.data - lor_base.conv.weight.data
            print("Distill Conv shape", residual.shape)

            # Flatten the (out, in, kh, kw) kernel to a 2-D matrix before the SVD.
            residual = residual.float()
            residual = residual.flatten(start_dim=1)

            # Truncated SVD of the residual, singular values folded into U.
            U, S, Vh = torch.linalg.svd(residual)
            U = U[:, :rank]
            S = S[:rank]
            U = U @ torch.diag(S)
            Vh = Vh[:rank, :]

            # Clamp outliers so the factors stay well-behaved in fp16.
            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(dist, clamp_quantile)
            low_val = -hi_val
            U = U.clamp(low_val, hi_val)
            Vh = Vh.clamp(low_val, hi_val)

            # Reshape the factors back to conv kernels: lora_up is a 1x1 conv,
            # lora_down keeps the original kernel size.
            U = U.reshape(U.shape[0], U.shape[1], 1, 1)
            Vh = Vh.reshape(
                Vh.shape[0],
                lor_base.conv.in_channels,
                lor_base.conv.kernel_size[0],
                lor_base.conv.kernel_size[1],
            )

            assert lor_base.lora_up.weight.shape == U.shape
            assert lor_base.lora_down.weight.shape == Vh.shape

            lor_base.lora_up.weight.data = U.to(device=device, dtype=dtype)
            lor_base.lora_down.weight.data = Vh.to(device=device, dtype=dtype)


def svd_distill(
    target_model: str,
    base_model: str,
    rank: int = 4,
    clamp_quantile: float = 0.99,
    device: str = "cuda:0",
    save_path: str = "svd_distill.safetensors",
):
    # Load the base and the fine-tuned (target) pipelines in fp16 on the requested device.
    pipe_base = StableDiffusionPipeline.from_pretrained(
        base_model, torch_dtype=torch.float16
    ).to(device)
    pipe_tuned = StableDiffusionPipeline.from_pretrained(
        target_model, torch_dtype=torch.float16
    ).to(device)

    # Inject LoRA layers into both UNets so `overwrite_base` can fill the base
    # copy's up/down factors with the SVD of the weight residual.
    _ = inject_trainable_lora_extended(pipe_base.unet, r=rank)
    _ = inject_trainable_lora_extended(pipe_tuned.unet, r=rank)

    overwrite_base(
        pipe_base.unet, pipe_tuned.unet, rank=rank, clamp_quantile=clamp_quantile
    )

    # Same procedure for the text encoders' attention layers.
    _ = inject_trainable_lora(
        pipe_base.text_encoder, r=rank, target_replace_module={"CLIPAttention"}
    )
    _ = inject_trainable_lora(
        pipe_tuned.text_encoder, r=rank, target_replace_module={"CLIPAttention"}
    )

    overwrite_base(
        pipe_base.text_encoder,
        pipe_tuned.text_encoder,
        rank=rank,
        clamp_quantile=clamp_quantile,
    )

    # Save only the LoRA weights (no textual-inversion embeddings).
    save_all(
        unet=pipe_base.unet,
        text_encoder=pipe_base.text_encoder,
        placeholder_token_ids=None,
        placeholder_tokens=None,
        save_path=save_path,
        save_lora=True,
        save_ti=False,
    )


def main():
    fire.Fire(svd_distill)
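

# Assumed entry-point guard: lets the distiller also be run as a module
# (e.g. `python -m <package>.cli_svd`) in addition to a packaged console
# script that calls `main()` directly.
if __name__ == "__main__":
    main()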