
The LoRA state dicts in this repository were generated with the following script:

```python
from diffusers import DiffusionPipeline
import torch
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from huggingface_hub import create_repo, upload_file
import tempfile
import os


# Base checkpoints to generate dummy LoRAs for (both UNet- and transformer-based pipelines).
ckpts = [
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    "stabilityai/stable-diffusion-xl-base-1.0",
    "black-forest-labs/FLUX.1-dev",
]

# LoRA ranks to sweep over.
ranks = [16, 32, 128]

repo_id = create_repo(repo_id="sayakpaul/dummy-lora-state-dicts", exist_ok=True).repo_id

def get_lora_config(rank=16):
    # LoRA config targeting the attention projection layers; alpha is kept equal to the rank.
    return LoraConfig(
        r=rank,
        lora_alpha=rank,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_v", "to_q", "to_out.0"],
    )


def load_pipeline_and_obtain_lora(ckpt, rank):
    pipeline = DiffusionPipeline.from_pretrained(ckpt, torch_dtype=torch.bfloat16)
    pipeline_cls = pipeline.__class__

    lora_config = get_lora_config(rank=rank)
    weight_name = f"r@{rank}-{ckpt.split('/')[-1]}.safetensors"

    with tempfile.TemporaryDirectory() as tmpdir:
        save_kwargs = {"weight_name": weight_name}
        # SD/SDXL pipelines expose a UNet; FLUX exposes a transformer.
        if hasattr(pipeline, "unet"):
            pipeline.unet.add_adapter(lora_config)
            save_kwargs.update({"unet_lora_layers": get_peft_model_state_dict(pipeline.unet)})
        else:
            pipeline.transformer.add_adapter(lora_config)
            save_kwargs.update({"transformer_lora_layers": get_peft_model_state_dict(pipeline.transformer)})

        # Serialize the LoRA layers in the diffusers format and push the file to the Hub.
        pipeline_cls.save_lora_weights(save_directory=tmpdir, **save_kwargs)
        upload_file(repo_id=repo_id, path_or_fileobj=os.path.join(tmpdir, weight_name), path_in_repo=weight_name)


# Generate and upload a LoRA state dict for every (checkpoint, rank) combination.
for ckpt in ckpts:
    for rank in ranks:
        load_pipeline_and_obtain_lora(ckpt=ckpt, rank=rank)
```
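To use one of these state dicts, load it back into the corresponding pipeline with `load_lora_weights`. Below is a minimal sketch for the rank-16 SD v1.5 file; the `weight_name` follows the naming scheme from the script above (`r@{rank}-{checkpoint name}.safetensors`), and the other files in this repo work the same way.

```python
from diffusers import DiffusionPipeline
import torch

# Load the base pipeline the dummy LoRA was generated for.
pipeline = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.bfloat16
)

# Load the rank-16 dummy LoRA from this repo. Note that the weights are
# randomly (Gaussian) initialized, so they are only useful for testing.
pipeline.load_lora_weights(
    "sayakpaul/dummy-lora-state-dicts",
    weight_name="r@16-stable-diffusion-v1-5.safetensors",
)
```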