Spaces:
No application file
No application file
<!--Copyright 2023 The HuggingFace Team. All rights reserved. | |
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | |
the License. You may obtain a copy of the License at | |
http://www.apache.org/licenses/LICENSE-2.0 | |
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | |
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | |
specific language governing permissions and limitations under the License. | |
--> | |
# Textual-Inversion | |
[[open-in-colab]] | |
[textual-inversion](https://arxiv.org/abs/2208.01618)์ ์์์ ์์ ์ด๋ฏธ์ง์์ ์๋ก์ด ์ฝ์ ํธ๋ฅผ ํฌ์ฐฉํ๋ ๊ธฐ๋ฒ์ ๋๋ค. ์ด ๊ธฐ์ ์ ์๋ [Latent Diffusion](https://github.com/CompVis/latent-diffusion)์์ ์์ฐ๋์์ง๋ง, ์ดํ [Stable Diffusion](https://huggingface.co/docs/diffusers/main/en/conceptual/stable_diffusion)๊ณผ ๊ฐ์ ์ ์ฌํ ๋ค๋ฅธ ๋ชจ๋ธ์๋ ์ ์ฉ๋์์ต๋๋ค. ํ์ต๋ ์ฝ์ ํธ๋ text-to-image ํ์ดํ๋ผ์ธ์์ ์์ฑ๋ ์ด๋ฏธ์ง๋ฅผ ๋ ์ ์ ์ดํ๋ ๋ฐ ์ฌ์ฉํ ์ ์์ต๋๋ค. ์ด ๋ชจ๋ธ์ ํ ์คํธ ์ธ์ฝ๋์ ์๋ฒ ๋ฉ ๊ณต๊ฐ์์ ์๋ก์ด '๋จ์ด'๋ฅผ ํ์ตํ์ฌ ๊ฐ์ธํ๋ ์ด๋ฏธ์ง ์์ฑ์ ์ํ ํ ์คํธ ํ๋กฌํํธ ๋ด์์ ์ฌ์ฉ๋ฉ๋๋ค. | |
![Textual Inversion example](https://textual-inversion.github.io/static/images/editing/colorful_teapot.JPG) | |
<small>By using just 3-5 images you can teach new concepts to a model such as Stable Diffusion for personalized image generation <a href="https://github.com/rinongal/textual_inversion">(image source)</a>.</small> | |
์ด ๊ฐ์ด๋์์๋ textual-inversion์ผ๋ก [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) ๋ชจ๋ธ์ ํ์ตํ๋ ๋ฐฉ๋ฒ์ ์ค๋ช ํฉ๋๋ค. ์ด ๊ฐ์ด๋์์ ์ฌ์ฉ๋ ๋ชจ๋ textual-inversion ํ์ต ์คํฌ๋ฆฝํธ๋ [์ฌ๊ธฐ](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion)์์ ํ์ธํ ์ ์์ต๋๋ค. ๋ด๋ถ์ ์ผ๋ก ์ด๋ป๊ฒ ์๋ํ๋์ง ์์ธํ ์ดํด๋ณด๊ณ ์ถ์ผ์๋ค๋ฉด ํด๋น ๋งํฌ๋ฅผ ์ฐธ์กฐํด์ฃผ์๊ธฐ ๋ฐ๋๋๋ค. | |
<Tip> | |
[Stable Diffusion Textual Inversion Concepts Library](https://huggingface.co/sd-concepts-library)์๋ ์ปค๋ฎค๋ํฐ์์ ์ ์ํ ํ์ต๋ textual-inversion ๋ชจ๋ธ๋ค์ด ์์ต๋๋ค. ์๊ฐ์ด ์ง๋จ์ ๋ฐ๋ผ ๋ ๋ง์ ์ฝ์ ํธ๋ค์ด ์ถ๊ฐ๋์ด ์ ์ฉํ ๋ฆฌ์์ค๋ก ์ฑ์ฅํ ๊ฒ์ ๋๋ค! | |
</Tip> | |
์์ํ๊ธฐ ์ ์ ํ์ต์ ์ํ ์์กด์ฑ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ค์ ์ค์นํด์ผ ํฉ๋๋ค: | |
```bash | |
pip install diffusers accelerate transformers | |
``` | |
์์กด์ฑ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ค์ ์ค์น๊ฐ ์๋ฃ๋๋ฉด, [๐คAccelerate](https://github.com/huggingface/accelerate/) ํ๊ฒฝ์ ์ด๊ธฐํ์ํต๋๋ค. | |
```bash | |
accelerate config | |
``` | |
๋ณ๋์ ์ค์ ์์ด, ๊ธฐ๋ณธ ๐คAccelerate ํ๊ฒฝ์ ์ค์ ํ๋ ค๋ฉด ๋ค์๊ณผ ๊ฐ์ด ํ์ธ์: | |
```bash | |
accelerate config default | |
``` | |
๋๋ ์ฌ์ฉ ์ค์ธ ํ๊ฒฝ์ด ๋ ธํธ๋ถ๊ณผ ๊ฐ์ ๋ํํ ์ ธ์ ์ง์ํ์ง ์๋๋ค๋ฉด, ๋ค์๊ณผ ๊ฐ์ด ์ฌ์ฉํ ์ ์์ต๋๋ค: | |
```py | |
from accelerate.utils import write_basic_config | |
write_basic_config() | |
``` | |
๋ง์ง๋ง์ผ๋ก, Memory-Efficient Attention์ ํตํด ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ์ค์ด๊ธฐ ์ํด [xFormers](https://huggingface.co/docs/diffusers/main/en/training/optimization/xformers)๋ฅผ ์ค์นํฉ๋๋ค. xFormers๋ฅผ ์ค์นํ ํ, ํ์ต ์คํฌ๋ฆฝํธ์ `--enable_xformers_memory_efficient_attention` ์ธ์๋ฅผ ์ถ๊ฐํฉ๋๋ค. xFormers๋ Flax์์ ์ง์๋์ง ์์ต๋๋ค. | |
## ํ๋ธ์ ๋ชจ๋ธ ์ ๋ก๋ํ๊ธฐ | |
๋ชจ๋ธ์ ํ๋ธ์ ์ ์ฅํ๋ ค๋ฉด, ํ์ต ์คํฌ๋ฆฝํธ์ ๋ค์ ์ธ์๋ฅผ ์ถ๊ฐํด์ผ ํฉ๋๋ค. | |
```bash | |
--push_to_hub | |
``` | |
## ์ฒดํฌํฌ์ธํธ ์ ์ฅ ๋ฐ ๋ถ๋ฌ์ค๊ธฐ | |
ํ์ต์ค์ ๋ชจ๋ธ์ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ ๊ธฐ์ ์ผ๋ก ์ ์ฅํ๋ ๊ฒ์ด ์ข์ต๋๋ค. ์ด๋ ๊ฒ ํ๋ฉด ์ด๋ค ์ด์ ๋ก๋ ํ์ต์ด ์ค๋จ๋ ๊ฒฝ์ฐ ์ ์ฅ๋ ์ฒดํฌํฌ์ธํธ์์ ํ์ต์ ๋ค์ ์์ํ ์ ์์ต๋๋ค. ํ์ต ์คํฌ๋ฆฝํธ์ ๋ค์ ์ธ์๋ฅผ ์ ๋ฌํ๋ฉด 500๋จ๊ณ๋ง๋ค ์ ์ฒด ํ์ต ์ํ๊ฐ `output_dir`์ ํ์ ํด๋์ ์ฒดํฌํฌ์ธํธ๋ก์ ์ ์ฅ๋ฉ๋๋ค. | |
```bash | |
--checkpointing_steps=500 | |
``` | |
์ ์ฅ๋ ์ฒดํฌํฌ์ธํธ์์ ํ์ต์ ์ฌ๊ฐํ๋ ค๋ฉด, ํ์ต ์คํฌ๋ฆฝํธ์ ์ฌ๊ฐํ ํน์ ์ฒดํฌํฌ์ธํธ์ ๋ค์ ์ธ์๋ฅผ ์ ๋ฌํ์ธ์. | |
```bash | |
--resume_from_checkpoint="checkpoint-1500" | |
``` | |
## ํ์ธ ํ๋ | |
ํ์ต์ฉ ๋ฐ์ดํฐ์ ์ผ๋ก [๊ณ ์์ด ์ฅ๋๊ฐ ๋ฐ์ดํฐ์ ](https://huggingface.co/datasets/diffusers/cat_toy_example)์ ๋ค์ด๋ก๋ํ์ฌ ๋๋ ํ ๋ฆฌ์ ์ ์ฅํ์ธ์. ์ฌ๋ฌ๋ถ๋ง์ ๊ณ ์ ํ ๋ฐ์ดํฐ์ ์ ์ฌ์ฉํ๊ณ ์ ํ๋ค๋ฉด, [ํ์ต์ฉ ๋ฐ์ดํฐ์ ๋ง๋ค๊ธฐ](https://huggingface.co/docs/diffusers/training/create_dataset) ๊ฐ์ด๋๋ฅผ ์ดํด๋ณด์๊ธฐ ๋ฐ๋๋๋ค. | |
```py | |
from huggingface_hub import snapshot_download | |
local_dir = "./cat" | |
snapshot_download( | |
"diffusers/cat_toy_example", local_dir=local_dir, repo_type="dataset", ignore_patterns=".gitattributes" | |
) | |
``` | |
๋ชจ๋ธ์ ๋ฆฌํฌ์งํ ๋ฆฌ ID(๋๋ ๋ชจ๋ธ ๊ฐ์ค์น๊ฐ ํฌํจ๋ ๋๋ ํฐ๋ฆฌ ๊ฒฝ๋ก)๋ฅผ `MODEL_NAME` ํ๊ฒฝ ๋ณ์์ ํ ๋นํ๊ณ , ํด๋น ๊ฐ์ [`pretrained_model_name_or_path`](https://huggingface.co/docs/diffusers/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path) ์ธ์์ ์ ๋ฌํฉ๋๋ค. ๊ทธ๋ฆฌ๊ณ ์ด๋ฏธ์ง๊ฐ ํฌํจ๋ ๋๋ ํฐ๋ฆฌ ๊ฒฝ๋ก๋ฅผ `DATA_DIR` ํ๊ฒฝ ๋ณ์์ ํ ๋นํฉ๋๋ค. | |
์ด์ [ํ์ต ์คํฌ๋ฆฝํธ](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py)๋ฅผ ์คํํ ์ ์์ต๋๋ค. ์คํฌ๋ฆฝํธ๋ ๋ค์ ํ์ผ์ ์์ฑํ๊ณ ๋ฆฌํฌ์งํ ๋ฆฌ์ ์ ์ฅํฉ๋๋ค. | |
- `learned_embeds.bin` | |
- `token_identifier.txt` | |
- `type_of_concept.txt`. | |
<Tip> | |
๐กV100 GPU 1๊ฐ๋ฅผ ๊ธฐ์ค์ผ๋ก ์ ์ฒด ํ์ต์๋ ์ต๋ 1์๊ฐ์ด ๊ฑธ๋ฆฝ๋๋ค. ํ์ต์ด ์๋ฃ๋๊ธฐ๋ฅผ ๊ธฐ๋ค๋ฆฌ๋ ๋์ ๊ถ๊ธํ ์ ์ด ์์ผ๋ฉด ์๋ ์น์ ์์ [textual-inversion์ด ์ด๋ป๊ฒ ์๋ํ๋์ง](https://huggingface.co/docs/diffusers/training/text_inversion#how-it-works) ์์ ๋กญ๊ฒ ํ์ธํ์ธ์ ! | |
</Tip> | |
<frameworkcontent> | |
<pt> | |
```bash | |
export MODEL_NAME="runwayml/stable-diffusion-v1-5" | |
export DATA_DIR="./cat" | |
accelerate launch textual_inversion.py \ | |
--pretrained_model_name_or_path=$MODEL_NAME \ | |
--train_data_dir=$DATA_DIR \ | |
--learnable_property="object" \ | |
--placeholder_token="<cat-toy>" --initializer_token="toy" \ | |
--resolution=512 \ | |
--train_batch_size=1 \ | |
--gradient_accumulation_steps=4 \ | |
--max_train_steps=3000 \ | |
--learning_rate=5.0e-04 --scale_lr \ | |
--lr_scheduler="constant" \ | |
--lr_warmup_steps=0 \ | |
--output_dir="textual_inversion_cat" \ | |
--push_to_hub | |
``` | |
<Tip> | |
๐กํ์ต ์ฑ๋ฅ์ ์ฌ๋ฆฌ๊ธฐ ์ํด, ํ๋ ์ด์คํ๋ ํ ํฐ(`<cat-toy>`)์ (๋จ์ผํ ์๋ฒ ๋ฉ ๋ฒกํฐ๊ฐ ์๋) ๋ณต์์ ์๋ฒ ๋ฉ ๋ฒกํฐ๋ก ํํํ๋ ๊ฒ ์ญ์ ๊ณ ๋ คํ ์์ต๋๋ค. ์ด๋ฌํ ํธ๋ฆญ์ด ๋ชจ๋ธ์ด ๋ณด๋ค ๋ณต์กํ ์ด๋ฏธ์ง์ ์คํ์ผ(์์ ๋งํ ์ฝ์ ํธ)์ ๋ ์ ์บก์ฒํ๋ ๋ฐ ๋์์ด ๋ ์ ์์ต๋๋ค. ๋ณต์์ ์๋ฒ ๋ฉ ๋ฒกํฐ ํ์ต์ ํ์ฑํํ๋ ค๋ฉด ๋ค์ ์ต์ ์ ์ ๋ฌํ์ญ์์ค. | |
```bash | |
--num_vectors=5 | |
``` | |
</Tip> | |
</pt> | |
<jax> | |
TPU์ ์ก์ธ์คํ ์ ์๋ ๊ฒฝ์ฐ, [Flax ํ์ต ์คํฌ๋ฆฝํธ](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion_flax.py)๋ฅผ ์ฌ์ฉํ์ฌ ๋ ๋น ๋ฅด๊ฒ ๋ชจ๋ธ์ ํ์ต์์ผ๋ณด์ธ์. (๋ฌผ๋ก GPU์์๋ ์๋ํฉ๋๋ค.) ๋์ผํ ์ค์ ์์ Flax ํ์ต ์คํฌ๋ฆฝํธ๋ PyTorch ํ์ต ์คํฌ๋ฆฝํธ๋ณด๋ค ์ต์ 70% ๋ ๋นจ๋ผ์ผ ํฉ๋๋ค! โก๏ธ | |
์์ํ๊ธฐ ์์ Flax์ ๋ํ ์์กด์ฑ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ค์ ์ค์นํด์ผ ํฉ๋๋ค. | |
```bash | |
pip install -U -r requirements_flax.txt | |
``` | |
๋ชจ๋ธ์ ๋ฆฌํฌ์งํ ๋ฆฌ ID(๋๋ ๋ชจ๋ธ ๊ฐ์ค์น๊ฐ ํฌํจ๋ ๋๋ ํฐ๋ฆฌ ๊ฒฝ๋ก)๋ฅผ `MODEL_NAME` ํ๊ฒฝ ๋ณ์์ ํ ๋นํ๊ณ , ํด๋น ๊ฐ์ [`pretrained_model_name_or_path`](https://huggingface.co/docs/diffusers/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path) ์ธ์์ ์ ๋ฌํฉ๋๋ค. | |
๊ทธ๋ฐ ๋ค์ [ํ์ต ์คํฌ๋ฆฝํธ](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion_flax.py)๋ฅผ ์์ํ ์ ์์ต๋๋ค. | |
```bash | |
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax" | |
export DATA_DIR="./cat" | |
python textual_inversion_flax.py \ | |
--pretrained_model_name_or_path=$MODEL_NAME \ | |
--train_data_dir=$DATA_DIR \ | |
--learnable_property="object" \ | |
--placeholder_token="<cat-toy>" --initializer_token="toy" \ | |
--resolution=512 \ | |
--train_batch_size=1 \ | |
--max_train_steps=3000 \ | |
--learning_rate=5.0e-04 --scale_lr \ | |
--output_dir="textual_inversion_cat" \ | |
--push_to_hub | |
``` | |
</jax> | |
</frameworkcontent> | |
### ์ค๊ฐ ๋ก๊น | |
๋ชจ๋ธ์ ํ์ต ์งํ ์ํฉ์ ์ถ์ ํ๋ ๋ฐ ๊ด์ฌ์ด ์๋ ๊ฒฝ์ฐ, ํ์ต ๊ณผ์ ์์ ์์ฑ๋ ์ด๋ฏธ์ง๋ฅผ ์ ์ฅํ ์ ์์ต๋๋ค. ํ์ต ์คํฌ๋ฆฝํธ์ ๋ค์ ์ธ์๋ฅผ ์ถ๊ฐํ์ฌ ์ค๊ฐ ๋ก๊น ์ ํ์ฑํํฉ๋๋ค. | |
- `validation_prompt` : ์ํ์ ์์ฑํ๋ ๋ฐ ์ฌ์ฉ๋๋ ํ๋กฌํํธ(๊ธฐ๋ณธ๊ฐ์ `None`์ผ๋ก ์ค์ ๋๋ฉฐ, ์ด ๋ ์ค๊ฐ ๋ก๊น ์ ๋นํ์ฑํ๋จ) | |
- `num_validation_images` : ์์ฑํ ์ํ ์ด๋ฏธ์ง ์ | |
- `validation_steps` : `validation_prompt`๋ก๋ถํฐ ์ํ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๊ธฐ ์ ์คํ ์ ์ | |
```bash | |
--validation_prompt="A <cat-toy> backpack" | |
--num_validation_images=4 | |
--validation_steps=100 | |
``` | |
## ์ถ๋ก | |
๋ชจ๋ธ์ ํ์ตํ ํ์๋, ํด๋น ๋ชจ๋ธ์ [`StableDiffusionPipeline`]์ ์ฌ์ฉํ์ฌ ์ถ๋ก ์ ์ฌ์ฉํ ์ ์์ต๋๋ค. | |
textual-inversion ์คํฌ๋ฆฝํธ๋ ๊ธฐ๋ณธ์ ์ผ๋ก textual-inversion์ ํตํด ์ป์ด์ง ์๋ฒ ๋ฉ ๋ฒกํฐ๋ง์ ์ ์ฅํฉ๋๋ค. ํด๋น ์๋ฒ ๋ฉ ๋ฒกํฐ๋ค์ ํ ์คํธ ์ธ์ฝ๋์ ์๋ฒ ๋ฉ ํ๋ ฌ์ ์ถ๊ฐ๋์ด ์์ต์ต๋๋ค. | |
<frameworkcontent> | |
<pt> | |
<Tip> | |
๐ก ์ปค๋ฎค๋ํฐ๋ [sd-concepts-library](https://huggingface.co/sd-concepts-library) ๋ผ๋ ๋๊ท๋ชจ์ textual-inversion ์๋ฒ ๋ฉ ๋ฒกํฐ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ๋ง๋ค์์ต๋๋ค. textual-inversion ์๋ฒ ๋ฉ์ ๋ฐ๋ฐ๋ฅ๋ถํฐ ํ์ตํ๋ ๋์ , ํด๋น ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ๋ณธ์ธ์ด ์ฐพ๋ textual-inversion ์๋ฒ ๋ฉ์ด ์ด๋ฏธ ์ถ๊ฐ๋์ด ์์ง ์์์ง๋ฅผ ํ์ธํ๋ ๊ฒ๋ ์ข์ ๋ฐฉ๋ฒ์ด ๋ ๊ฒ ๊ฐ์ต๋๋ค. | |
</Tip> | |
textual-inversion ์๋ฒ ๋ฉ ๋ฒกํฐ์ ๋ถ๋ฌ์ค๊ธฐ ์ํด์๋, ๋จผ์ ํด๋น ์๋ฒ ๋ฉ ๋ฒกํฐ๋ฅผ ํ์ตํ ๋ ์ฌ์ฉํ ๋ชจ๋ธ์ ๋ถ๋ฌ์์ผ ํฉ๋๋ค. ์ฌ๊ธฐ์๋ [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/docs/diffusers/training/runwayml/stable-diffusion-v1-5) ๋ชจ๋ธ์ด ์ฌ์ฉ๋์๋ค๊ณ ๊ฐ์ ํ๊ณ ๋ถ๋ฌ์ค๊ฒ ์ต๋๋ค. | |
```python | |
from diffusers import StableDiffusionPipeline | |
import torch | |
model_id = "runwayml/stable-diffusion-v1-5" | |
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") | |
``` | |
๋ค์์ผ๋ก `TextualInversionLoaderMixin.load_textual_inversion` ํจ์๋ฅผ ํตํด, textual-inversion ์๋ฒ ๋ฉ ๋ฒกํฐ๋ฅผ ๋ถ๋ฌ์์ผ ํฉ๋๋ค. ์ฌ๊ธฐ์ ์ฐ๋ฆฌ๋ ์ด์ ์ `<cat-toy>` ์์ ์ ์๋ฒ ๋ฉ์ ๋ถ๋ฌ์ฌ ๊ฒ์ ๋๋ค. | |
```python | |
pipe.load_textual_inversion("sd-concepts-library/cat-toy") | |
``` | |
์ด์ ํ๋ ์ด์คํ๋ ํ ํฐ(`<cat-toy>`)์ด ์ ๋์ํ๋์ง๋ฅผ ํ์ธํ๋ ํ์ดํ๋ผ์ธ์ ์คํํ ์ ์์ต๋๋ค. | |
```python | |
prompt = "A <cat-toy> backpack" | |
image = pipe(prompt, num_inference_steps=50).images[0] | |
image.save("cat-backpack.png") | |
``` | |
`TextualInversionLoaderMixin.load_textual_inversion`์ Diffusers ํ์์ผ๋ก ์ ์ฅ๋ ํ ์คํธ ์๋ฒ ๋ฉ ๋ฒกํฐ๋ฅผ ๋ก๋ํ ์ ์์ ๋ฟ๋ง ์๋๋ผ, [Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) ํ์์ผ๋ก ์ ์ฅ๋ ์๋ฒ ๋ฉ ๋ฒกํฐ๋ ๋ก๋ํ ์ ์์ต๋๋ค. ์ด๋ ๊ฒ ํ๋ ค๋ฉด, ๋จผ์ [civitAI](https://civitai.com/models/3036?modelVersionId=8387)์์ ์๋ฒ ๋ฉ ๋ฒกํฐ๋ฅผ ๋ค์ด๋ก๋ํ ๋ค์ ๋ก์ปฌ์์ ๋ถ๋ฌ์์ผ ํฉ๋๋ค. | |
```python | |
pipe.load_textual_inversion("./charturnerv2.pt") | |
``` | |
</pt> | |
<jax> | |
ํ์ฌ Flax์ ๋ํ `load_textual_inversion` ํจ์๋ ์์ต๋๋ค. ๋ฐ๋ผ์ ํ์ต ํ textual-inversion ์๋ฒ ๋ฉ ๋ฒกํฐ๊ฐ ๋ชจ๋ธ์ ์ผ๋ถ๋ก์ ์ ์ฅ๋์๋์ง๋ฅผ ํ์ธํด์ผ ํฉ๋๋ค. ๊ทธ๋ฐ ๋ค์์ ๋ค๋ฅธ Flax ๋ชจ๋ธ๊ณผ ๋ง์ฐฌ๊ฐ์ง๋ก ์คํํ ์ ์์ต๋๋ค. | |
```python | |
import jax | |
import numpy as np | |
from flax.jax_utils import replicate | |
from flax.training.common_utils import shard | |
from diffusers import FlaxStableDiffusionPipeline | |
model_path = "path-to-your-trained-model" | |
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16) | |
prompt = "A <cat-toy> backpack" | |
prng_seed = jax.random.PRNGKey(0) | |
num_inference_steps = 50 | |
num_samples = jax.device_count() | |
prompt = num_samples * [prompt] | |
prompt_ids = pipeline.prepare_inputs(prompt) | |
# shard inputs and rng | |
params = replicate(params) | |
prng_seed = jax.random.split(prng_seed, jax.device_count()) | |
prompt_ids = shard(prompt_ids) | |
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images | |
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) | |
image.save("cat-backpack.png") | |
``` | |
</jax> | |
</frameworkcontent> | |
## ์๋ ๋ฐฉ์ | |
![Diagram from the paper showing overview](https://textual-inversion.github.io/static/images/training/training.JPG) | |
<small>Architecture overview from the Textual Inversion <a href="https://textual-inversion.github.io/">blog post.</a></small> | |
์ผ๋ฐ์ ์ผ๋ก ํ ์คํธ ํ๋กฌํํธ๋ ๋ชจ๋ธ์ ์ ๋ฌ๋๊ธฐ ์ ์ ์๋ฒ ๋ฉ์ผ๋ก ํ ํฐํ๋ฉ๋๋ค. textual-inversion์ ๋น์ทํ ์์ ์ ์ํํ์ง๋ง, ์ ๋ค์ด์ด๊ทธ๋จ์ ํน์ ํ ํฐ `S*`๋ก๋ถํฐ ์๋ก์ด ํ ํฐ ์๋ฒ ๋ฉ `v*`๋ฅผ ํ์ตํฉ๋๋ค. ๋ชจ๋ธ์ ์์ํ์ ๋ํจ์ ๋ชจ๋ธ์ ์กฐ์ ํ๋ ๋ฐ ์ฌ์ฉ๋๋ฉฐ, ๋ํจ์ ๋ชจ๋ธ์ด ๋จ ๋ช ๊ฐ์ ์์ ์ด๋ฏธ์ง์์ ์ ์ํ๊ณ ์๋ก์ด ์ฝ์ ํธ๋ฅผ ์ดํดํ๋ ๋ฐ ๋์์ ์ค๋๋ค. | |
์ด๋ฅผ ์ํด textual-inversion์ ์ ๋๋ ์ดํฐ ๋ชจ๋ธ๊ณผ ํ์ต์ฉ ์ด๋ฏธ์ง์ ๋ ธ์ด์ฆ ๋ฒ์ ์ ์ฌ์ฉํฉ๋๋ค. ์ ๋๋ ์ดํฐ๋ ๋ ธ์ด์ฆ๊ฐ ์ ์ ๋ฒ์ ์ ์ด๋ฏธ์ง๋ฅผ ์์ธกํ๋ ค๊ณ ์๋ํ๋ฉฐ ํ ํฐ ์๋ฒ ๋ฉ `v*`์ ์ ๋๋ ์ดํฐ์ ์ฑ๋ฅ์ ๋ฐ๋ผ ์ต์ ํ๋ฉ๋๋ค. ํ ํฐ ์๋ฒ ๋ฉ์ด ์๋ก์ด ์ฝ์ ํธ๋ฅผ ์ฑ๊ณต์ ์ผ๋ก ํฌ์ฐฉํ๋ฉด ๋ํจ์ ๋ชจ๋ธ์ ๋ ์ ์ฉํ ์ ๋ณด๋ฅผ ์ ๊ณตํ๊ณ ๋ ธ์ด์ฆ๊ฐ ์ ์ ๋ ์ ๋ช ํ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ ๋ฐ ๋์์ด ๋ฉ๋๋ค. ์ด๋ฌํ ์ต์ ํ ํ๋ก์ธ์ค๋ ์ผ๋ฐ์ ์ผ๋ก ๋ค์ํ ํ๋กฌํํธ์ ์ด๋ฏธ์ง์ ์์ฒ ๋ฒ์ ๋ ธ์ถ๋จ์ผ๋ก์จ ์ด๋ฃจ์ด์ง๋๋ค. | |