|
The dummy LoRA state dicts were generated with the following script:
|
|
|
```py |
|
from diffusers import DiffusionPipeline |
|
import torch |
|
from peft import LoraConfig |
|
from peft.utils import get_peft_model_state_dict |
|
from huggingface_hub import create_repo, upload_file |
|
import tempfile |
|
import os |
|
|
|
|
|
# Base checkpoints that each receive a dummy LoRA: two UNet-based
# Stable Diffusion models and one transformer-based Flux model.
ckpts = [
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    "stabilityai/stable-diffusion-xl-base-1.0",
    "black-forest-labs/FLUX.1-dev",
]

# One state dict is produced per (checkpoint, rank) pair.
ranks = [16, 32, 128]

# Destination Hub repository. ``exist_ok=True`` lets the script be re-run
# without failing on an already-created repo.
repo_id = create_repo(
    repo_id="sayakpaul/dummy-lora-state-dicts",
    exist_ok=True,
).repo_id
|
|
|
def get_lora_config(rank=16):
    """Return a PEFT ``LoraConfig`` of the given rank for the attention projections.

    ``lora_alpha`` is tied to ``rank``, and the adapter weights use PEFT's
    ``"gaussian"`` initialisation. The targeted modules are those named
    ``to_q``, ``to_k``, ``to_v`` and ``to_out.0`` (diffusers' attention
    projection layers).

    Args:
        rank: LoRA rank ``r``; also used as ``lora_alpha``.

    Returns:
        A freshly constructed ``peft.LoraConfig``.
    """
    # Order matches the original list; PEFT only uses it for name matching.
    projection_names = ["to_k", "to_v", "to_q", "to_out.0"]
    config_kwargs = dict(
        r=rank,
        lora_alpha=rank,
        init_lora_weights="gaussian",
        target_modules=projection_names,
    )
    return LoraConfig(**config_kwargs)
|
|
|
|
|
def load_pipeline_and_obtain_lora(ckpt, rank):
    """Attach a randomly initialised LoRA to ``ckpt``'s denoiser and upload it.

    The pipeline is loaded in bfloat16. A rank-``rank`` adapter from
    ``get_lora_config`` is added to the pipeline's ``unet`` if it has one,
    otherwise to its ``transformer``. The adapter's PEFT state dict is saved
    with the pipeline class's ``save_lora_weights`` into a temporary directory.
    It is then uploaded to the module-level ``repo_id`` as
    ``r@{rank}-{checkpoint name}.safetensors``.

    Args:
        ckpt: Identifier passed to ``DiffusionPipeline.from_pretrained``; its
            last ``/``-separated component is used in the uploaded file name.
        rank: LoRA rank forwarded to ``get_lora_config``.

    Returns:
        None. The only result is the file uploaded to the Hub.
    """
    pipe = DiffusionPipeline.from_pretrained(ckpt, torch_dtype=torch.bfloat16)
    config = get_lora_config(rank=rank)
    model_name = ckpt.split("/")[-1]
    filename = f"r@{rank}-{model_name}.safetensors"

    # UNet-based pipelines expose ``unet``; anything else is expected to expose
    # ``transformer``. The keyword must match ``save_lora_weights``' parameter.
    if hasattr(pipe, "unet"):
        denoiser, layers_kwarg = pipe.unet, "unet_lora_layers"
    else:
        denoiser, layers_kwarg = pipe.transformer, "transformer_lora_layers"

    with tempfile.TemporaryDirectory() as out_dir:
        denoiser.add_adapter(config)
        lora_layers = get_peft_model_state_dict(denoiser)
        # ``save_lora_weights`` is called on the class, as in the pipeline's
        # own LoRA-loader mixin.
        type(pipe).save_lora_weights(
            save_directory=out_dir,
            weight_name=filename,
            **{layers_kwarg: lora_layers},
        )
        upload_file(
            repo_id=repo_id,
            path_or_fileobj=os.path.join(out_dir, filename),
            path_in_repo=filename,
        )
|
|
|
|
|
if __name__ == "__main__":
    # Generate and upload one LoRA state dict per (checkpoint, rank) pair.
    # Each call loads its own pipeline via ``DiffusionPipeline.from_pretrained``.
    # The guard keeps these downloads/uploads from running when the module is
    # merely imported.
    for ckpt in ckpts:
        for rank in ranks:
            load_pipeline_and_obtain_lora(ckpt=ckpt, rank=rank)
|
``` |