Spaces:
No application file
No application file
<!--Copyright 2023 The HuggingFace Team. All rights reserved. | |
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | |
the License. You may obtain a copy of the License at | |
http://www.apache.org/licenses/LICENSE-2.0 | |
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | |
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | |
specific language governing permissions and limitations under the License. | |
--> | |
# Text-to-image | |
<Tip warning={true}> | |
text-to-image ํ์ธํ๋ ์คํฌ๋ฆฝํธ๋ experimental ์ํ์ ๋๋ค. ๊ณผ์ ํฉํ๊ธฐ ์ฝ๊ณ ์น๋ช ์ ์ธ ๋ง๊ฐ๊ณผ ๊ฐ์ ๋ฌธ์ ์ ๋ถ๋ชํ๊ธฐ ์ฝ์ต๋๋ค. ์์ฒด ๋ฐ์ดํฐ์ ์์ ์ต์์ ๊ฒฐ๊ณผ๋ฅผ ์ป์ผ๋ ค๋ฉด ๋ค์ํ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ํ์ํ๋ ๊ฒ์ด ์ข์ต๋๋ค. | |
</Tip> | |
Stable Diffusion๊ณผ ๊ฐ์ text-to-image ๋ชจ๋ธ์ ํ ์คํธ ํ๋กฌํํธ์์ ์ด๋ฏธ์ง๋ฅผ ์์ฑํฉ๋๋ค. ์ด ๊ฐ์ด๋๋ PyTorch ๋ฐ Flax๋ฅผ ์ฌ์ฉํ์ฌ ์์ฒด ๋ฐ์ดํฐ์ ์์ [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) ๋ชจ๋ธ๋ก ํ์ธํ๋ํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ค๋๋ค. ์ด ๊ฐ์ด๋์ ์ฌ์ฉ๋ text-to-image ํ์ธํ๋์ ์ํ ๋ชจ๋ ํ์ต ์คํฌ๋ฆฝํธ์ ๊ด์ฌ์ด ์๋ ๊ฒฝ์ฐ ์ด [๋ฆฌํฌ์งํ ๋ฆฌ](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image)์์ ์์ธํ ์ฐพ์ ์ ์์ต๋๋ค. | |
์คํฌ๋ฆฝํธ๋ฅผ ์คํํ๊ธฐ ์ ์, ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ํ์ต dependency๋ค์ ์ค์นํด์ผ ํฉ๋๋ค: | |
```bash | |
pip install git+https://github.com/huggingface/diffusers.git | |
pip install -U -r requirements.txt | |
``` | |
๊ทธ๋ฆฌ๊ณ [๐คAccelerate](https://github.com/huggingface/accelerate/) ํ๊ฒฝ์ ์ด๊ธฐํํฉ๋๋ค: | |
```bash | |
accelerate config | |
``` | |
๋ฆฌํฌ์งํ ๋ฆฌ๋ฅผ ์ด๋ฏธ ๋ณต์ ํ ๊ฒฝ์ฐ, ์ด ๋จ๊ณ๋ฅผ ์ํํ ํ์๊ฐ ์์ต๋๋ค. ๋์ , ๋ก์ปฌ ์ฒดํฌ์์ ๊ฒฝ๋ก๋ฅผ ํ์ต ์คํฌ๋ฆฝํธ์ ๋ช ์ํ ์ ์์ผ๋ฉฐ ๊ฑฐ๊ธฐ์์ ๋ก๋๋ฉ๋๋ค. | |
### ํ๋์จ์ด ์๊ตฌ ์ฌํญ | |
`gradient_checkpointing` ๋ฐ `mixed_precision`์ ์ฌ์ฉํ๋ฉด ๋จ์ผ 24GB GPU์์ ๋ชจ๋ธ์ ํ์ธํ๋ํ ์ ์์ต๋๋ค. ๋ ๋์ `batch_size`์ ๋ ๋น ๋ฅธ ํ๋ จ์ ์ํด์๋ GPU ๋ฉ๋ชจ๋ฆฌ๊ฐ 30GB ์ด์์ธ GPU๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ์ข์ต๋๋ค. TPU ๋๋ GPU์์ ํ์ธํ๋์ ์ํด JAX๋ Flax๋ฅผ ์ฌ์ฉํ ์๋ ์์ต๋๋ค. ์์ธํ ๋ด์ฉ์ [์๋](#flax-jax-finetuning)๋ฅผ ์ฐธ์กฐํ์ธ์. | |
xFormers๋ก memory efficient attention์ ํ์ฑํํ์ฌ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋ ํจ์ฌ ๋ ์ค์ผ ์ ์์ต๋๋ค. [xFormers๊ฐ ์ค์น](./optimization/xformers)๋์ด ์๋์ง ํ์ธํ๊ณ `--enable_xformers_memory_efficient_attention`๋ฅผ ํ์ต ์คํฌ๋ฆฝํธ์ ๋ช ์ํฉ๋๋ค. | |
xFormers๋ Flax์ ์ฌ์ฉํ ์ ์์ต๋๋ค. | |
## Hub์ ๋ชจ๋ธ ์ ๋ก๋ํ๊ธฐ | |
ํ์ต ์คํฌ๋ฆฝํธ์ ๋ค์ ์ธ์๋ฅผ ์ถ๊ฐํ์ฌ ๋ชจ๋ธ์ ํ๋ธ์ ์ ์ฅํฉ๋๋ค: | |
```bash | |
--push_to_hub | |
``` | |
## ์ฒดํฌํฌ์ธํธ ์ ์ฅ ๋ฐ ๋ถ๋ฌ์ค๊ธฐ | |
ํ์ต ์ค ๋ฐ์ํ ์ ์๋ ์ผ์ ๋๋นํ์ฌ ์ ๊ธฐ์ ์ผ๋ก ์ฒดํฌํฌ์ธํธ๋ฅผ ์ ์ฅํด ๋๋ ๊ฒ์ด ์ข์ต๋๋ค. ์ฒดํฌํฌ์ธํธ๋ฅผ ์ ์ฅํ๋ ค๋ฉด ํ์ต ์คํฌ๋ฆฝํธ์ ๋ค์ ์ธ์๋ฅผ ๋ช ์ํฉ๋๋ค. | |
```bash | |
--checkpointing_steps=500 | |
``` | |
500์คํ ๋ง๋ค ์ ์ฒด ํ์ต state๊ฐ 'output_dir'์ ํ์ ํด๋์ ์ ์ฅ๋ฉ๋๋ค. ์ฒดํฌํฌ์ธํธ๋ 'checkpoint-'์ ์ง๊ธ๊น์ง ํ์ต๋ step ์์ ๋๋ค. ์๋ฅผ ๋ค์ด 'checkpoint-1500'์ 1500 ํ์ต step ํ์ ์ ์ฅ๋ ์ฒดํฌํฌ์ธํธ์ ๋๋ค. | |
ํ์ต์ ์ฌ๊ฐํ๊ธฐ ์ํด ์ฒดํฌํฌ์ธํธ๋ฅผ ๋ถ๋ฌ์ค๋ ค๋ฉด '--resume_from_checkpoint' ์ธ์๋ฅผ ํ์ต ์คํฌ๋ฆฝํธ์ ๋ช ์ํ๊ณ ์ฌ๊ฐํ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ง์ ํ์ญ์์ค. ์๋ฅผ ๋ค์ด ๋ค์ ์ธ์๋ 1500๊ฐ์ ํ์ต step ํ์ ์ ์ฅ๋ ์ฒดํฌํฌ์ธํธ์์๋ถํฐ ํ๋ จ์ ์ฌ๊ฐํฉ๋๋ค. | |
```bash | |
--resume_from_checkpoint="checkpoint-1500" | |
``` | |
## ํ์ธํ๋ | |
<frameworkcontent> | |
<pt> | |
๋ค์๊ณผ ๊ฐ์ด [Pokรฉmon BLIP ์บก์ ](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) ๋ฐ์ดํฐ์ ์์ ํ์ธํ๋ ์คํ์ ์ํด [PyTorch ํ์ต ์คํฌ๋ฆฝํธ](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py)๋ฅผ ์คํํฉ๋๋ค: | |
```bash | |
export MODEL_NAME="CompVis/stable-diffusion-v1-4" | |
export dataset_name="lambdalabs/pokemon-blip-captions" | |
accelerate launch train_text_to_image.py \ | |
--pretrained_model_name_or_path=$MODEL_NAME \ | |
--dataset_name=$dataset_name \ | |
--use_ema \ | |
--resolution=512 --center_crop --random_flip \ | |
--train_batch_size=1 \ | |
--gradient_accumulation_steps=4 \ | |
--gradient_checkpointing \ | |
--mixed_precision="fp16" \ | |
--max_train_steps=15000 \ | |
--learning_rate=1e-05 \ | |
--max_grad_norm=1 \ | |
--lr_scheduler="constant" --lr_warmup_steps=0 \ | |
--output_dir="sd-pokemon-model" | |
``` | |
์์ฒด ๋ฐ์ดํฐ์ ์ผ๋ก ํ์ธํ๋ํ๋ ค๋ฉด ๐ค [Datasets](https://huggingface.co/docs/datasets/index)์์ ์๊ตฌํ๋ ํ์์ ๋ฐ๋ผ ๋ฐ์ดํฐ์ ์ ์ค๋นํ์ธ์. [๋ฐ์ดํฐ์ ์ ํ๋ธ์ ์ ๋ก๋](https://huggingface.co/docs/datasets/image_dataset#upload-dataset-to-the-hub)ํ๊ฑฐ๋ [ํ์ผ๋ค์ด ์๋ ๋ก์ปฌ ํด๋๋ฅผ ์ค๋น](https ://huggingface.co/docs/datasets/image_dataset#imagefolder)ํ ์ ์์ต๋๋ค. | |
์ฌ์ฉ์ ์ปค์คํ loading logic์ ์ฌ์ฉํ๋ ค๋ฉด ์คํฌ๋ฆฝํธ๋ฅผ ์์ ํ์ญ์์ค. ๋์์ด ๋๋๋ก ์ฝ๋์ ์ ์ ํ ์์น์ ํฌ์ธํฐ๋ฅผ ๋จ๊ฒผ์ต๋๋ค. ๐ค ์๋ ์์ ์คํฌ๋ฆฝํธ๋ `TRAIN_DIR`์ ๋ก์ปฌ ๋ฐ์ดํฐ์ ์ผ๋ก๋ฅผ ํ์ธํ๋ํ๋ ๋ฐฉ๋ฒ๊ณผ `OUTPUT_DIR`์์ ๋ชจ๋ธ์ ์ ์ฅํ ์์น๋ฅผ ๋ณด์ฌ์ค๋๋ค: | |
```bash | |
export MODEL_NAME="CompVis/stable-diffusion-v1-4" | |
export TRAIN_DIR="path_to_your_dataset" | |
export OUTPUT_DIR="path_to_save_model" | |
accelerate launch train_text_to_image.py \ | |
--pretrained_model_name_or_path=$MODEL_NAME \ | |
--train_data_dir=$TRAIN_DIR \ | |
--use_ema \ | |
--resolution=512 --center_crop --random_flip \ | |
--train_batch_size=1 \ | |
--gradient_accumulation_steps=4 \ | |
--gradient_checkpointing \ | |
--mixed_precision="fp16" \ | |
--max_train_steps=15000 \ | |
--learning_rate=1e-05 \ | |
--max_grad_norm=1 \ | |
--lr_scheduler="constant" --lr_warmup_steps=0 \ | |
--output_dir=${OUTPUT_DIR} | |
``` | |
</pt> | |
<jax> | |
[@duongna211](https://github.com/duongna21)์ ๊ธฐ์ฌ๋ก, Flax๋ฅผ ์ฌ์ฉํด TPU ๋ฐ GPU์์ Stable Diffusion ๋ชจ๋ธ์ ๋ ๋น ๋ฅด๊ฒ ํ์ตํ ์ ์์ต๋๋ค. ์ด๋ TPU ํ๋์จ์ด์์ ๋งค์ฐ ํจ์จ์ ์ด์ง๋ง GPU์์๋ ํ๋ฅญํ๊ฒ ์๋ํฉ๋๋ค. Flax ํ์ต ์คํฌ๋ฆฝํธ๋ gradient checkpointing๋ gradient accumulation๊ณผ ๊ฐ์ ๊ธฐ๋ฅ์ ์์ง ์ง์ํ์ง ์์ผ๋ฏ๋ก ๋ฉ๋ชจ๋ฆฌ๊ฐ 30GB ์ด์์ธ GPU ๋๋ TPU v3๊ฐ ํ์ํฉ๋๋ค. | |
์คํฌ๋ฆฝํธ๋ฅผ ์คํํ๊ธฐ ์ ์ ์๊ตฌ ์ฌํญ์ด ์ค์น๋์ด ์๋์ง ํ์ธํ์ญ์์ค: | |
```bash | |
pip install -U -r requirements_flax.txt | |
``` | |
๊ทธ๋ฌ๋ฉด ๋ค์๊ณผ ๊ฐ์ด [Flax ํ์ต ์คํฌ๋ฆฝํธ](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_flax.py)๋ฅผ ์คํํ ์ ์์ต๋๋ค. | |
```bash | |
export MODEL_NAME="runwayml/stable-diffusion-v1-5" | |
export dataset_name="lambdalabs/pokemon-blip-captions" | |
python train_text_to_image_flax.py \ | |
--pretrained_model_name_or_path=$MODEL_NAME \ | |
--dataset_name=$dataset_name \ | |
--resolution=512 --center_crop --random_flip \ | |
--train_batch_size=1 \ | |
--max_train_steps=15000 \ | |
--learning_rate=1e-05 \ | |
--max_grad_norm=1 \ | |
--output_dir="sd-pokemon-model" | |
``` | |
์์ฒด ๋ฐ์ดํฐ์ ์ผ๋ก ํ์ธํ๋ํ๋ ค๋ฉด ๐ค [Datasets](https://huggingface.co/docs/datasets/index)์์ ์๊ตฌํ๋ ํ์์ ๋ฐ๋ผ ๋ฐ์ดํฐ์ ์ ์ค๋นํ์ธ์. [๋ฐ์ดํฐ์ ์ ํ๋ธ์ ์ ๋ก๋](https://huggingface.co/docs/datasets/image_dataset#upload-dataset-to-the-hub)ํ๊ฑฐ๋ [ํ์ผ๋ค์ด ์๋ ๋ก์ปฌ ํด๋๋ฅผ ์ค๋น](https ://huggingface.co/docs/datasets/image_dataset#imagefolder)ํ ์ ์์ต๋๋ค. | |
์ฌ์ฉ์ ์ปค์คํ loading logic์ ์ฌ์ฉํ๋ ค๋ฉด ์คํฌ๋ฆฝํธ๋ฅผ ์์ ํ์ญ์์ค. ๋์์ด ๋๋๋ก ์ฝ๋์ ์ ์ ํ ์์น์ ํฌ์ธํฐ๋ฅผ ๋จ๊ฒผ์ต๋๋ค. ๐ค ์๋ ์์ ์คํฌ๋ฆฝํธ๋ `TRAIN_DIR`์ ๋ก์ปฌ ๋ฐ์ดํฐ์ ์ผ๋ก๋ฅผ ํ์ธํ๋ํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ค๋๋ค: | |
```bash | |
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax" | |
export TRAIN_DIR="path_to_your_dataset" | |
python train_text_to_image_flax.py \ | |
--pretrained_model_name_or_path=$MODEL_NAME \ | |
--train_data_dir=$TRAIN_DIR \ | |
--resolution=512 --center_crop --random_flip \ | |
--train_batch_size=1 \ | |
--mixed_precision="fp16" \ | |
--max_train_steps=15000 \ | |
--learning_rate=1e-05 \ | |
--max_grad_norm=1 \ | |
--output_dir="sd-pokemon-model" | |
``` | |
</jax> | |
</frameworkcontent> | |
## LoRA | |
Text-to-image ๋ชจ๋ธ ํ์ธํ๋์ ์ํด, ๋๊ท๋ชจ ๋ชจ๋ธ ํ์ต์ ๊ฐ์ํํ๊ธฐ ์ํ ํ์ธํ๋ ๊ธฐ์ ์ธ LoRA(Low-Rank Adaptation of Large Language Models)๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค. ์์ธํ ๋ด์ฉ์ [LoRA ํ์ต](lora#text-to-image) ๊ฐ์ด๋๋ฅผ ์ฐธ์กฐํ์ธ์. | |
## ์ถ๋ก | |
ํ๋ธ์ ๋ชจ๋ธ ๊ฒฝ๋ก ๋๋ ๋ชจ๋ธ ์ด๋ฆ์ [`StableDiffusionPipeline`]์ ์ ๋ฌํ์ฌ ์ถ๋ก ์ ์ํด ํ์ธ ํ๋๋ ๋ชจ๋ธ์ ๋ถ๋ฌ์ฌ ์ ์์ต๋๋ค: | |
<frameworkcontent> | |
<pt> | |
```python | |
from diffusers import StableDiffusionPipeline | |
model_path = "path_to_saved_model" | |
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16) | |
pipe.to("cuda") | |
image = pipe(prompt="yoda").images[0] | |
image.save("yoda-pokemon.png") | |
``` | |
</pt> | |
<jax> | |
```python | |
import jax | |
import numpy as np | |
from flax.jax_utils import replicate | |
from flax.training.common_utils import shard | |
from diffusers import FlaxStableDiffusionPipeline | |
model_path = "path_to_saved_model" | |
pipe, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16) | |
prompt = "yoda pokemon" | |
prng_seed = jax.random.PRNGKey(0) | |
num_inference_steps = 50 | |
num_samples = jax.device_count() | |
prompt = num_samples * [prompt] | |
prompt_ids = pipeline.prepare_inputs(prompt) | |
# shard inputs and rng | |
params = replicate(params) | |
prng_seed = jax.random.split(prng_seed, jax.device_count()) | |
prompt_ids = shard(prompt_ids) | |
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images | |
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) | |
image.save("yoda-pokemon.png") | |
``` | |
</jax> | |
</frameworkcontent> |