Spaces:
No application file
No application file
<!--Copyright 2023 The HuggingFace Team. All rights reserved. | |
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | |
the License. You may obtain a copy of the License at | |
http://www.apache.org/licenses/LICENSE-2.0 | |
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | |
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | |
specific language governing permissions and limitations under the License. | |
--> | |
# Low-Rank Adaptation of Large Language Models (LoRA) | |
[[open-in-colab]] | |
<Tip warning={true}> | |
ํ์ฌ LoRA๋ [`UNet2DConditionalModel`]์ ์ดํ ์ ๋ ์ด์ด์์๋ง ์ง์๋ฉ๋๋ค. | |
</Tip> | |
[LoRA(Low-Rank Adaptation of Large Language Models)](https://arxiv.org/abs/2106.09685)๋ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ ๊ฒ ์ฌ์ฉํ๋ฉด์ ๋๊ท๋ชจ ๋ชจ๋ธ์ ํ์ต์ ๊ฐ์ํํ๋ ํ์ต ๋ฐฉ๋ฒ์ ๋๋ค. ์ด๋ rank-decomposition weight ํ๋ ฌ ์(**์ ๋ฐ์ดํธ ํ๋ ฌ**์ด๋ผ๊ณ ํจ)์ ์ถ๊ฐํ๊ณ ์๋ก ์ถ๊ฐ๋ ๊ฐ์ค์น**๋ง** ํ์ตํฉ๋๋ค. ์ฌ๊ธฐ์๋ ๋ช ๊ฐ์ง ์ฅ์ ์ด ์์ต๋๋ค. | |
- ์ด์ ์ ๋ฏธ๋ฆฌ ํ์ต๋ ๊ฐ์ค์น๋ ๊ณ ์ ๋ ์ํ๋ก ์ ์ง๋๋ฏ๋ก ๋ชจ๋ธ์ด [์น๋ช ์ ์ธ ๋ง๊ฐ](https://www.pnas.org/doi/10.1073/pnas.1611835114) ๊ฒฝํฅ์ด ์์ต๋๋ค. | |
- Rank-decomposition ํ๋ ฌ์ ์๋ ๋ชจ๋ธ๋ณด๋ค ํ๋ผ๋ฉํฐ ์๊ฐ ํจ์ฌ ์ ์ผ๋ฏ๋ก ํ์ต๋ LoRA ๊ฐ์ค์น๋ฅผ ์ฝ๊ฒ ๋ผ์๋ฃ์ ์ ์์ต๋๋ค. | |
- LoRA ๋งคํธ๋ฆญ์ค๋ ์ผ๋ฐ์ ์ผ๋ก ์๋ณธ ๋ชจ๋ธ์ ์ดํ ์ ๋ ์ด์ด์ ์ถ๊ฐ๋ฉ๋๋ค. ๐งจ Diffusers๋ [`~diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs`] ๋ฉ์๋๋ฅผ ์ ๊ณตํ์ฌ LoRA ๊ฐ์ค์น๋ฅผ ๋ชจ๋ธ์ ์ดํ ์ ๋ ์ด์ด๋ก ๋ถ๋ฌ์ต๋๋ค. `scale` ๋งค๊ฐ๋ณ์๋ฅผ ํตํด ๋ชจ๋ธ์ด ์๋ก์ด ํ์ต ์ด๋ฏธ์ง์ ๋ง๊ฒ ์กฐ์ ๋๋ ๋ฒ์๋ฅผ ์ ์ดํ ์ ์์ต๋๋ค. | |
- ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ฑ์ด ํฅ์๋์ด Tesla T4, RTX 3080 ๋๋ RTX 2080 Ti์ ๊ฐ์ ์๋น์์ฉ GPU์์ ํ์ธํ๋์ ์คํํ ์ ์์ต๋๋ค! T4์ ๊ฐ์ GPU๋ ๋ฌด๋ฃ์ด๋ฉฐ Kaggle ๋๋ Google Colab ๋ ธํธ๋ถ์์ ์ฝ๊ฒ ์ก์ธ์คํ ์ ์์ต๋๋ค. | |
<Tip> | |
๐ก LoRA๋ ์ดํ ์ ๋ ์ด์ด์๋ง ํ์ ๋์ง๋ ์์ต๋๋ค. ์ ์๋ ์ธ์ด ๋ชจ๋ธ์ ์ดํ ์ ๋ ์ด์ด๋ฅผ ์์ ํ๋ ๊ฒ์ด ๋งค์ฐ ํจ์จ์ ์ผ๋ก ์ฃป์ ์ฑ๋ฅ์ ์ป๊ธฐ์ ์ถฉ๋ถํ๋ค๋ ๊ฒ์ ๋ฐ๊ฒฌํ์ต๋๋ค. ์ด๊ฒ์ด LoRA ๊ฐ์ค์น๋ฅผ ๋ชจ๋ธ์ ์ดํ ์ ๋ ์ด์ด์ ์ถ๊ฐํ๋ ๊ฒ์ด ์ผ๋ฐ์ ์ธ ์ด์ ์ ๋๋ค. LoRA ์๋ ๋ฐฉ์์ ๋ํ ์์ธํ ๋ด์ฉ์ [Using LoRA for effective Stable Diffusion fine-tuning](https://huggingface.co/blog/lora) ๋ธ๋ก๊ทธ๋ฅผ ํ์ธํ์ธ์! | |
</Tip> | |
[cloneofsimo](https://github.com/cloneofsimo)๋ ์ธ๊ธฐ ์๋ [lora](https://github.com/cloneofsimo/lora) GitHub ๋ฆฌํฌ์งํ ๋ฆฌ์์ Stable Diffusion์ ์ํ LoRA ํ์ต์ ์ต์ด๋ก ์๋ํ์ต๋๋ค. ๐งจ Diffusers๋ [text-to-image ์์ฑ](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#training-with-lora) ๋ฐ [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#training-with-low-rank-adaptation-of-large-language-models-lora)์ ์ง์ํฉ๋๋ค. ์ด ๊ฐ์ด๋๋ ๋ ๊ฐ์ง๋ฅผ ๋ชจ๋ ์ํํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ค๋๋ค. | |
๋ชจ๋ธ์ ์ ์ฅํ๊ฑฐ๋ ์ปค๋ฎค๋ํฐ์ ๊ณต์ ํ๋ ค๋ฉด Hugging Face ๊ณ์ ์ ๋ก๊ทธ์ธํ์ธ์(์์ง ๊ณ์ ์ด ์๋ ๊ฒฝ์ฐ [์์ฑ](hf.co/join)ํ์ธ์): | |
```bash | |
huggingface-cli login | |
``` | |
## Text-to-image | |
์์ญ์ต ๊ฐ์ ํ๋ผ๋ฉํฐ๋ค์ด ์๋ Stable Diffusion๊ณผ ๊ฐ์ ๋ชจ๋ธ์ ํ์ธํ๋ํ๋ ๊ฒ์ ๋๋ฆฌ๊ณ ์ด๋ ค์ธ ์ ์์ต๋๋ค. LoRA๋ฅผ ์ฌ์ฉํ๋ฉด diffusion ๋ชจ๋ธ์ ํ์ธํ๋ํ๋ ๊ฒ์ด ํจ์ฌ ์ฝ๊ณ ๋น ๋ฆ ๋๋ค. 8๋นํธ ์ตํฐ๋ง์ด์ ์ ๊ฐ์ ํธ๋ฆญ์ ์์กดํ์ง ์๊ณ ๋ 11GB์ GPU RAM์ผ๋ก ํ๋์จ์ด์์ ์คํํ ์ ์์ต๋๋ค. | |
### ํ์ต[[dreambooth-training]] | |
[Pokรฉmon BLIP ์บก์ ](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) ๋ฐ์ดํฐ์ ์ผ๋ก [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5)๋ฅผ ํ์ธํ๋ํด ๋๋ง์ ํฌ์ผ๋ชฌ์ ์์ฑํด ๋ณด๊ฒ ์ต๋๋ค. | |
์์ํ๋ ค๋ฉด `MODEL_NAME` ๋ฐ `DATASET_NAME` ํ๊ฒฝ ๋ณ์๊ฐ ์ค์ ๋์ด ์๋์ง ํ์ธํ์ญ์์ค. `OUTPUT_DIR` ๋ฐ `HUB_MODEL_ID` ๋ณ์๋ ์ ํ ์ฌํญ์ด๋ฉฐ ํ๋ธ์์ ๋ชจ๋ธ์ ์ ์ฅํ ์์น๋ฅผ ์ง์ ํฉ๋๋ค. | |
```bash | |
export MODEL_NAME="runwayml/stable-diffusion-v1-5" | |
export OUTPUT_DIR="/sddata/finetune/lora/pokemon" | |
export HUB_MODEL_ID="pokemon-lora" | |
export DATASET_NAME="lambdalabs/pokemon-blip-captions" | |
``` | |
ํ์ต์ ์์ํ๊ธฐ ์ ์ ์์์ผ ํ ๋ช ๊ฐ์ง ํ๋๊ทธ๊ฐ ์์ต๋๋ค. | |
* `--push_to_hub`๋ฅผ ๋ช ์ํ๋ฉด ํ์ต๋ LoRA ์๋ฒ ๋ฉ์ ํ๋ธ์ ์ ์ฅํฉ๋๋ค. | |
* `--report_to=wandb`๋ ํ์ต ๊ฒฐ๊ณผ๋ฅผ ๊ฐ์ค์น ๋ฐ ํธํฅ ๋์๋ณด๋์ ๋ณด๊ณ ํ๊ณ ๊ธฐ๋กํฉ๋๋ค(์๋ฅผ ๋ค์ด, ์ด [๋ณด๊ณ ์](https://wandb.ai/pcuenq/text2image-fine-tune/run/b4k1w0tn?workspace=user-pcuenq)๋ฅผ ์ฐธ์กฐํ์ธ์). | |
* `--learning_rate=1e-04`, ์ผ๋ฐ์ ์ผ๋ก LoRA์์ ์ฌ์ฉํ๋ ๊ฒ๋ณด๋ค ๋ ๋์ ํ์ต๋ฅ ์ ์ฌ์ฉํ ์ ์์ต๋๋ค. | |
์ด์ ํ์ต์ ์์ํ ์ค๋น๊ฐ ๋์์ต๋๋ค (์ ์ฒด ํ์ต ์คํฌ๋ฆฝํธ๋ [์ฌ๊ธฐ](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py)์์ ์ฐพ์ ์ ์์ต๋๋ค). | |
```bash | |
accelerate launch train_dreambooth_lora.py \ | |
--pretrained_model_name_or_path=$MODEL_NAME \ | |
--instance_data_dir=$INSTANCE_DIR \ | |
--output_dir=$OUTPUT_DIR \ | |
--instance_prompt="a photo of sks dog" \ | |
--resolution=512 \ | |
--train_batch_size=1 \ | |
--gradient_accumulation_steps=1 \ | |
--checkpointing_steps=100 \ | |
--learning_rate=1e-4 \ | |
--report_to="wandb" \ | |
--lr_scheduler="constant" \ | |
--lr_warmup_steps=0 \ | |
--max_train_steps=500 \ | |
--validation_prompt="A photo of sks dog in a bucket" \ | |
--validation_epochs=50 \ | |
--seed="0" \ | |
--push_to_hub | |
``` | |
### ์ถ๋ก [[dreambooth-inference]] | |
์ด์ [`StableDiffusionPipeline`]์์ ๊ธฐ๋ณธ ๋ชจ๋ธ์ ๋ถ๋ฌ์ ์ถ๋ก ์ ์ํด ๋ชจ๋ธ์ ์ฌ์ฉํ ์ ์์ต๋๋ค: | |
```py | |
>>> import torch | |
>>> from diffusers import StableDiffusionPipeline | |
>>> model_base = "runwayml/stable-diffusion-v1-5" | |
>>> pipe = StableDiffusionPipeline.from_pretrained(model_base, torch_dtype=torch.float16) | |
``` | |
*๊ธฐ๋ณธ ๋ชจ๋ธ์ ๊ฐ์ค์น ์์* ํ์ธํ๋๋ DreamBooth ๋ชจ๋ธ์์ LoRA ๊ฐ์ค์น๋ฅผ ๋ถ๋ฌ์จ ๋ค์, ๋ ๋น ๋ฅธ ์ถ๋ก ์ ์ํด ํ์ดํ๋ผ์ธ์ GPU๋ก ์ด๋ํฉ๋๋ค. LoRA ๊ฐ์ค์น๋ฅผ ํ๋ฆฌ์ง๋ ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ ๊ฐ์ค์น์ ๋ณํฉํ ๋, ์ ํ์ ์ผ๋ก 'scale' ๋งค๊ฐ๋ณ์๋ก ์ด๋ ์ ๋์ ๊ฐ์ค์น๋ฅผ ๋ณํฉํ ์ง ์กฐ์ ํ ์ ์์ต๋๋ค: | |
<Tip> | |
๐ก `0`์ `scale` ๊ฐ์ LoRA ๊ฐ์ค์น๋ฅผ ์ฌ์ฉํ์ง ์์ ์๋ ๋ชจ๋ธ์ ๊ฐ์ค์น๋ง ์ฌ์ฉํ ๊ฒ๊ณผ ๊ฐ๊ณ , `1`์ `scale` ๊ฐ์ ํ์ธํ๋๋ LoRA ๊ฐ์ค์น๋ง ์ฌ์ฉํจ์ ์๋ฏธํฉ๋๋ค. 0๊ณผ 1 ์ฌ์ด์ ๊ฐ๋ค์ ๋ ๊ฒฐ๊ณผ๋ค ์ฌ์ด๋ก ๋ณด๊ฐ๋ฉ๋๋ค. | |
</Tip> | |
```py | |
>>> pipe.unet.load_attn_procs(model_path) | |
>>> pipe.to("cuda") | |
# LoRA ํ์ธํ๋๋ ๋ชจ๋ธ์ ๊ฐ์ค์น ์ ๋ฐ๊ณผ ๊ธฐ๋ณธ ๋ชจ๋ธ์ ๊ฐ์ค์น ์ ๋ฐ ์ฌ์ฉ | |
>>> image = pipe( | |
... "A picture of a sks dog in a bucket.", | |
... num_inference_steps=25, | |
... guidance_scale=7.5, | |
... cross_attention_kwargs={"scale": 0.5}, | |
... ).images[0] | |
# ์์ ํ ํ์ธํ๋๋ LoRA ๋ชจ๋ธ์ ๊ฐ์ค์น ์ฌ์ฉ | |
>>> image = pipe("A picture of a sks dog in a bucket.", num_inference_steps=25, guidance_scale=7.5).images[0] | |
>>> image.save("bucket-dog.png") | |
``` |