<!--Copyright 2023 Custom Diffusion authors The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Custom Diffusion training example
[Custom Diffusion](https://arxiv.org/abs/2212.04488) is a method for customizing text-to-image models such as Stable Diffusion given only a few (4 to 5) images of a subject.
The `train_custom_diffusion.py` script shows how to implement the training procedure and adapt it for Stable Diffusion.
This training example was contributed by [Nupur Kumari](https://nupurkmr9.github.io/), one of the authors of Custom Diffusion.
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, as we update the example scripts frequently and they install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd into the [example folder](https://github.com/huggingface/diffusers/tree/main/examples/custom_diffusion):
```bash
cd examples/custom_diffusion
```
Now run:
```bash
pip install -r requirements.txt
pip install clip-retrieval
```
And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or, for a default Accelerate configuration without answering questions about your environment:
```bash
accelerate config default
```
Or, if your environment doesn't support an interactive shell (e.g., a Jupyter notebook):
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
### Cat example 😺
Now let's get our dataset. Download the dataset from [here](https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip) and unzip it. To use your own dataset, see the [Create a dataset for training](create_dataset) guide.
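If you prefer to fetch the data from Python, a minimal standard-library sketch could look like this. The function names are hypothetical; the training command below points `--instance_data_dir` at `./data/cat`, so extract the archive into the directory you launch training from.

```python
import urllib.request
import zipfile
from pathlib import Path

DATA_URL = "https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip"


def extract_dataset(zip_path, dest_dir="."):
    """Unzip an already-downloaded archive into dest_dir and return its top-level entries."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest)
        # Top-level folder names inside the archive (useful to locate the instance images).
        return sorted({name.split("/")[0] for name in zf.namelist() if name})


def download_and_extract(url=DATA_URL, dest_dir="."):
    """Download the archive into dest_dir, then extract it there."""
    zip_path = Path(dest_dir) / "data.zip"
    zip_path.parent.mkdir(parents=True, exist_ok=True)
    urllib.request.urlretrieve(url, zip_path)
    return extract_dataset(zip_path, dest_dir)
```

Calling `download_and_extract()` downloads `data.zip` into the current directory and extracts it next to it.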
We also collect 200 real images using `clip-retrieval`, which are combined with the target images in the training dataset as regularization. This prevents overfitting to the given target images. In the training command below, the flags `--with_prior_preservation --real_prior --prior_loss_weight=1.0` enable this regularization.
The `class_prompt` should be the same category name as the target images. The collected real images have text captions similar to the `class_prompt`. The retrieved images are saved in `class_data_dir`. You can disable `real_prior` to use generated images as regularization instead. To collect the real images, run this command before training:
```bash
pip install clip-retrieval
python retrieve.py --class_prompt cat --class_data_dir real_reg/samples_cat --num_class_images 200
```
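A minimal sketch of how prior-preservation regularization is commonly combined with the main objective in diffusers' DreamBooth-style scripts: each batch is assumed to hold the instance examples followed by the same number of class (regularization) examples, the predictions are split in half, and the class-image loss is scaled by `prior_loss_weight`. Plain Python lists stand in for tensors, and details such as any masking the script applies to the instance loss are left out.

```python
def mse(pred, target):
    """Mean squared error between two equal-length lists of floats."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def _flatten(batch):
    """Flatten a list of per-example vectors into one list."""
    return [v for example in batch for v in example]


def prior_preservation_loss(model_pred, target, prior_loss_weight=1.0):
    """Instance loss plus weighted class (prior) loss.

    Assumption: the first half of the batch are instance examples,
    the second half are class examples.
    """
    half = len(model_pred) // 2
    instance_loss = mse(_flatten(model_pred[:half]), _flatten(target[:half]))
    prior_loss = mse(_flatten(model_pred[half:]), _flatten(target[half:]))
    return instance_loss + prior_loss_weight * prior_loss
```

With `prior_loss_weight=1.0`, as in the commands here, both halves contribute equally; lowering it weakens the regularization.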
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
The script creates and saves model checkpoints and a `pytorch_custom_diffusion_weights.bin` file in your output directory.
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export OUTPUT_DIR="path-to-save-model"
export INSTANCE_DIR="./data/cat"
accelerate launch train_custom_diffusion.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--class_data_dir=./real_reg/samples_cat/ \
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
--class_prompt="cat" --num_class_images=200 \
--instance_prompt="photo of a <new1> cat" \
--resolution=512 \
--train_batch_size=2 \
--learning_rate=1e-5 \
--lr_warmup_steps=0 \
--max_train_steps=250 \
--scale_lr --hflip \
--modifier_token "<new1>" \
--push_to_hub
```
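The inference examples in this guide load `pytorch_custom_diffusion_weights.bin` and one `<newN>.bin` embedding file per modifier token from the output directory, so after training you may want to confirm those files exist. The helper names below are hypothetical; the expected file names are taken from the inference examples.

```python
from pathlib import Path


def expected_output_files(modifier_tokens=("<new1>",)):
    """Attention weights plus one learned-embedding file per modifier token."""
    return ["pytorch_custom_diffusion_weights.bin"] + [f"{tok}.bin" for tok in modifier_tokens]


def missing_output_files(output_dir, modifier_tokens=("<new1>",)):
    """Return the expected files that are not present in output_dir."""
    out = Path(output_dir)
    return [name for name in expected_output_files(modifier_tokens) if not (out / name).is_file()]
```

For example, `missing_output_files("path-to-save-model")` returns an empty list once training has written both files.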
**๋” ๋‚ฎ์€ VRAM ์š”๊ตฌ ์‚ฌํ•ญ(GPU๋‹น 16GB)์œผ๋กœ ๋” ๋น ๋ฅด๊ฒŒ ํ›ˆ๋ จํ•˜๋ ค๋ฉด `--enable_xformers_memory_efficient_attention`์„ ์‚ฌ์šฉํ•˜์„ธ์š”. ์„ค์น˜ ๋ฐฉ๋ฒ•์€ [๊ฐ€์ด๋“œ](https://github.com/facebookresearch/xformers)๋ฅผ ๋”ฐ๋ฅด์„ธ์š”.**
๊ฐ€์ค‘์น˜ ๋ฐ ํŽธํ–ฅ(`wandb`)์„ ์‚ฌ์šฉํ•˜์—ฌ ์‹คํ—˜์„ ์ถ”์ ํ•˜๊ณ  ์ค‘๊ฐ„ ๊ฒฐ๊ณผ๋ฅผ ์ €์žฅํ•˜๋ ค๋ฉด(๊ฐ•๋ ฅํžˆ ๊ถŒ์žฅํ•ฉ๋‹ˆ๋‹ค) ๋‹ค์Œ ๋‹จ๊ณ„๋ฅผ ๋”ฐ๋ฅด์„ธ์š”:
* `wandb` ์„ค์น˜: `pip install wandb`.
* ๋กœ๊ทธ์ธ : `wandb login`.
* ๊ทธ๋Ÿฐ ๋‹ค์Œ ํŠธ๋ ˆ์ด๋‹์„ ์‹œ์ž‘ํ•˜๋Š” ๋™์•ˆ `validation_prompt`๋ฅผ ์ง€์ •ํ•˜๊ณ  `report_to`๋ฅผ `wandb`๋กœ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค. ๋‹ค์Œ๊ณผ ๊ฐ™์€ ๊ด€๋ จ ์ธ์ˆ˜๋ฅผ ๊ตฌ์„ฑํ•  ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค:
* `num_validation_images`
* `validation_steps`
```bash
accelerate launch train_custom_diffusion.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--class_data_dir=./real_reg/samples_cat/ \
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
--class_prompt="cat" --num_class_images=200 \
--instance_prompt="photo of a <new1> cat" \
--resolution=512 \
--train_batch_size=2 \
--learning_rate=1e-5 \
--lr_warmup_steps=0 \
--max_train_steps=250 \
--scale_lr --hflip \
--modifier_token "<new1>" \
--validation_prompt="<new1> cat sitting in a bucket" \
--report_to="wandb" \
--push_to_hub
```
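To see how `validation_steps` and `num_validation_images` interact, here is a small sketch assuming the script generates `num_validation_images` images from `validation_prompt` whenever the global optimization step is a multiple of `validation_steps`. The values 50 and 2 below are illustrative, not necessarily the script's defaults.

```python
def validation_schedule(max_train_steps, validation_steps, num_validation_images):
    """Return (global_step, num_images) for every step at which validation is assumed to run."""
    if validation_steps <= 0:
        raise ValueError("validation_steps must be positive")
    return [
        (step, num_validation_images)
        for step in range(1, max_train_steps + 1)
        if step % validation_steps == 0
    ]


# Illustrative run matching --max_train_steps=250 from the command above.
schedule = validation_schedule(max_train_steps=250, validation_steps=50, num_validation_images=2)
total_images = sum(n for _, n in schedule)
```

Under these assumptions, validation runs 5 times and logs 10 images in total; smaller `validation_steps` gives finer-grained progress at the cost of extra sampling time.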
๋‹ค์Œ์€ [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/26ghrcau)์˜ ์˜ˆ์‹œ์ด๋ฉฐ, ์—ฌ๋Ÿฌ ํ•™์Šต ์„ธ๋ถ€ ์ •๋ณด์™€ ํ•จ๊ป˜ ์ค‘๊ฐ„ ๊ฒฐ๊ณผ๋“ค์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
`--push_to_hub`๋ฅผ ์ง€์ •ํ•˜๋ฉด ํ•™์Šต๋œ ํŒŒ๋ผ๋ฏธํ„ฐ๊ฐ€ ํ—ˆ๊น… ํŽ˜์ด์Šค ํ—ˆ๋ธŒ์˜ ๋ฆฌํฌ์ง€ํ† ๋ฆฌ์— ํ‘ธ์‹œ๋ฉ๋‹ˆ๋‹ค. ๋‹ค์Œ์€ [์˜ˆ์ œ ๋ฆฌํฌ์ง€ํ† ๋ฆฌ](https://huggingface.co/sayakpaul/custom-diffusion-cat)์ž…๋‹ˆ๋‹ค.
### ๋ฉ€ํ‹ฐ ์ปจ์…‰์— ๋Œ€ํ•œ ํ•™์Šต ๐Ÿฑ๐Ÿชต
[this](https://github.com/ShivamShrirao/diffusers/blob/main/examples/dreambooth/train_dreambooth.py)์™€ ์œ ์‚ฌํ•˜๊ฒŒ ๊ฐ ์ปจ์…‰์— ๋Œ€ํ•œ ์ •๋ณด๊ฐ€ ํฌํ•จ๋œ [json](https://github.com/adobe-research/custom-diffusion/blob/main/assets/concept_list.json) ํŒŒ์ผ์„ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค.
์‹ค์ œ ์ด๋ฏธ์ง€๋ฅผ ์ˆ˜์ง‘ํ•˜๋ ค๋ฉด json ํŒŒ์ผ์˜ ๊ฐ ์ปจ์…‰์— ๋Œ€ํ•ด ์ด ๋ช…๋ น์„ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค.
```bash
pip install clip-retrieval
python retrieve.py --class_prompt {} --class_data_dir {} --num_class_images 200
```
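A sketch, assuming the concepts list is a JSON array whose entries carry `instance_prompt`, `class_prompt`, `instance_data_dir`, and `class_data_dir` as in the linked concept_list.json: it writes the file passed to `--concepts_list` and fills in the `{}` placeholders of the retrieval command for each concept. The prompts and directory names are placeholders for your own data.

```python
import json
import shlex

# Hypothetical concepts: replace the prompts and paths with your own.
concepts = [
    {
        "instance_prompt": "photo of a <new1> cat",
        "class_prompt": "cat",
        "instance_data_dir": "./data/cat",
        "class_data_dir": "./real_reg/samples_cat/",
    },
    {
        "instance_prompt": "photo of a <new2> wooden pot",
        "class_prompt": "wooden pot",
        "instance_data_dir": "./data/wooden_pot",
        "class_data_dir": "./real_reg/samples_wooden_pot/",
    },
]


def write_concepts_list(concepts, path="concept_list.json"):
    """Write the concepts list as pretty-printed JSON."""
    with open(path, "w") as f:
        json.dump(concepts, f, indent=4)


def retrieval_commands(concepts, num_class_images=200):
    """Build one retrieve.py command per concept, shell-quoting prompts and paths."""
    return [
        " ".join([
            "python", "retrieve.py",
            "--class_prompt", shlex.quote(c["class_prompt"]),
            "--class_data_dir", shlex.quote(c["class_data_dir"]),
            "--num_class_images", str(num_class_images),
        ])
        for c in concepts
    ]
```

Printing `retrieval_commands(concepts)` gives commands you can paste into a shell; multi-word class prompts such as `wooden pot` are quoted automatically.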
And then we're ready to train!
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export OUTPUT_DIR="path-to-save-model"
accelerate launch train_custom_diffusion.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--output_dir=$OUTPUT_DIR \
--concepts_list=./concept_list.json \
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
--resolution=512 \
--train_batch_size=2 \
--learning_rate=1e-5 \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--num_class_images=200 \
--scale_lr --hflip \
--modifier_token "<new1>+<new2>" \
--push_to_hub
```
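With several concepts, the command above passes one modifier token per concept joined by `+` (`"<new1>+<new2>"`). A hypothetical sanity check, assuming each token should appear in exactly one `instance_prompt` of the concepts list, could look like this:

```python
def check_modifier_tokens(modifier_token, concepts):
    """Split a '+'-joined --modifier_token value and report tokens not used by exactly one concept."""
    tokens = [tok for tok in modifier_token.split("+") if tok]
    problems = []
    for tok in tokens:
        hits = [c["instance_prompt"] for c in concepts if tok in c["instance_prompt"]]
        if len(hits) != 1:
            problems.append(f"{tok!r} appears in {len(hits)} instance prompts")
    return tokens, problems
```

An empty `problems` list means every token maps to a single concept; a token that appears in no prompt would presumably never receive a trained embedding.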
๋‹ค์Œ์€ [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/3990tzkg)์˜ ์˜ˆ์‹œ์ด๋ฉฐ, ๋‹ค๋ฅธ ํ•™์Šต ์„ธ๋ถ€ ์ •๋ณด์™€ ํ•จ๊ป˜ ์ค‘๊ฐ„ ๊ฒฐ๊ณผ๋“ค์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
### ์‚ฌ๋žŒ ์–ผ๊ตด์— ๋Œ€ํ•œ ํ•™์Šต
์‚ฌ๋žŒ ์–ผ๊ตด์— ๋Œ€ํ•œ ํŒŒ์ธํŠœ๋‹์„ ์œ„ํ•ด ๋‹ค์Œ๊ณผ ๊ฐ™์€ ์„ค์ •์ด ๋” ํšจ๊ณผ์ ์ด๋ผ๋Š” ๊ฒƒ์„ ํ™•์ธํ–ˆ์Šต๋‹ˆ๋‹ค: `learning_rate=5e-6`, `max_train_steps=1000 to 2000`, `freeze_model=crossattn`์„ ์ตœ์†Œ 15~20๊ฐœ์˜ ์ด๋ฏธ์ง€๋กœ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค.
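The commands in this guide pass `--scale_lr`, so the value given to `--learning_rate` is presumably not the rate the optimizer ends up using. The sketch below assumes the convention of diffusers training scripts (multiply by gradient accumulation steps, per-device batch size, and number of processes) and, as we believe this script does, a further doubling when prior preservation is enabled. Treat the factor of 2 as an assumption and check `train_custom_diffusion.py`.

```python
def effective_learning_rate(learning_rate, train_batch_size, gradient_accumulation_steps=1,
                            num_processes=1, with_prior_preservation=True, scale_lr=True):
    """Sketch of how --scale_lr is assumed to rescale the base learning rate."""
    if not scale_lr:
        return learning_rate
    lr = learning_rate * gradient_accumulation_steps * train_batch_size * num_processes
    if with_prior_preservation:
        # Assumption: doubled because each batch also carries class (regularization) images.
        lr *= 2.0
    return lr
```

Under these assumptions, the face recipe (`--learning_rate=5e-6`, `--train_batch_size=2`, one GPU) trains at about `2e-5`.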
To collect the real images, run this command before training:
```bash
pip install clip-retrieval
python retrieve.py --class_prompt person --class_data_dir real_reg/samples_person --num_class_images 200
```
Now start training!
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export OUTPUT_DIR="path-to-save-model"
export INSTANCE_DIR="path-to-images"
accelerate launch train_custom_diffusion.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--class_data_dir=./real_reg/samples_person/ \
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
--class_prompt="person" --num_class_images=200 \
--instance_prompt="photo of a <new1> person" \
--resolution=512 \
--train_batch_size=2 \
--learning_rate=5e-6 \
--lr_warmup_steps=0 \
--max_train_steps=1000 \
--scale_lr --hflip --noaug \
--freeze_model crossattn \
--modifier_token "<new1>" \
--enable_xformers_memory_efficient_attention \
--push_to_hub
```
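The face recipe sets `--freeze_model crossattn`. Our understanding, an assumption to verify against `train_custom_diffusion.py`, is that the default `crossattn_kv` fine-tunes only the key and value projections of the UNet's cross-attention layers, while `crossattn` also fine-tunes the query and output projections; the modifier-token embedding is trained in both cases. The hypothetical helper below uses diffusers-style parameter names in which cross-attention modules are called `attn2`.

```python
# Hypothetical mapping from --freeze_model to the cross-attention projections that are trained.
TRAINED_PROJECTIONS = {
    "crossattn_kv": ("to_k", "to_v"),
    "crossattn": ("to_q", "to_k", "to_v", "to_out"),
}


def is_trainable(param_name, freeze_model="crossattn_kv"):
    """True if a UNet parameter belongs to a cross-attention (attn2) projection that is fine-tuned."""
    if freeze_model not in TRAINED_PROJECTIONS:
        raise ValueError(f"unknown freeze_model: {freeze_model!r}")
    parts = param_name.split(".")
    if "attn2" not in parts:
        # Self-attention (attn1) and all other UNet weights stay frozen.
        return False
    i = parts.index("attn2")
    return i + 1 < len(parts) and parts[i + 1] in TRAINED_PROJECTIONS[freeze_model]
```

Training more projections gives the model more capacity to fit a face, which may be why more images and steps are recommended for this setting.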
## Inference
Once you have trained a model using the commands above, you can run inference with the code below. Make sure to include the modifier token (e.g., \<new1\> in the example above) in your prompt.
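```python
import torch
from diffusers import DiffusionPipeline

# Load the base model that was fine-tuned (the --pretrained_model_name_or_path used for training).
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16).to("cuda")
# Load the fine-tuned Custom Diffusion attention weights into the UNet.
pipe.unet.load_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
# Load the learned embedding for the modifier token <new1>.
pipe.load_textual_inversion("path-to-save-model", weight_name="<new1>.bin")

# The prompt must contain the modifier token.
image = pipe(
    "<new1> cat sitting in a bucket",
    num_inference_steps=100,
    guidance_scale=6.0,
    eta=1.0,
).images[0]
image.save("cat.png")
```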
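Because a prompt without the modifier token never invokes the learned token embedding, a hypothetical pre-flight check like the following can catch that mistake before generating. It is plain string matching and does not touch the pipeline.

```python
def missing_modifier_tokens(prompt, modifier_tokens):
    """Return the modifier tokens that do not appear in the prompt."""
    return [tok for tok in modifier_tokens if tok not in prompt]


def require_modifier_tokens(prompt, modifier_tokens):
    """Return the prompt unchanged, or raise if any modifier token is missing."""
    missing = missing_modifier_tokens(prompt, modifier_tokens)
    if missing:
        raise ValueError(f"prompt {prompt!r} is missing modifier token(s): {', '.join(missing)}")
    return prompt
```

For example, wrap the prompt as `pipe(require_modifier_tokens("<new1> cat sitting in a bucket", ["<new1>"]), ...)`.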
You can also load these parameters directly from a Hub repository:
```python
import torch
from huggingface_hub.repocard import RepoCard
from diffusers import DiffusionPipeline
model_id = "sayakpaul/custom-diffusion-cat"
card = RepoCard.load(model_id)
base_model_id = card.data.to_dict()["base_model"]
pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin")
pipe.load_textual_inversion(model_id, weight_name="<new1>.bin")
image = pipe(
"<new1> cat sitting in a bucket",
num_inference_steps=100,
guidance_scale=6.0,
eta=1.0,
).images[0]
image.save("cat.png")
```
๋‹ค์Œ์€ ์—ฌ๋Ÿฌ ์ปจ์…‰์œผ๋กœ ์ถ”๋ก ์„ ์ˆ˜ํ–‰ํ•˜๋Š” ์˜ˆ์ œ์ž…๋‹ˆ๋‹ค:
```python
import torch
from huggingface_hub.repocard import RepoCard
from diffusers import DiffusionPipeline
model_id = "sayakpaul/custom-diffusion-cat-wooden-pot"
card = RepoCard.load(model_id)
base_model_id = card.data.to_dict()["base_model"]
pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin")
pipe.load_textual_inversion(model_id, weight_name="<new1>.bin")
pipe.load_textual_inversion(model_id, weight_name="<new2>.bin")
image = pipe(
"the <new1> cat sculpture in the style of a <new2> wooden pot",
num_inference_steps=100,
guidance_scale=6.0,
eta=1.0,
).images[0]
image.save("multi-subject.png")
```
Here, `cat` and `wooden pot` refer to the multiple concepts.
### Inference from a training checkpoint
You can also perform inference from one of the complete checkpoints saved during training, if you used the `--checkpointing_steps` argument.
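To pick a checkpoint, a sketch like the following lists them by step. It assumes the script saves each checkpoint as a `checkpoint-<step>` folder under `--output_dir`, a convention common to diffusers training scripts; the helper names are hypothetical, and how you load weights from the chosen folder depends on the script version.

```python
import re
from pathlib import Path


def list_checkpoints(output_dir):
    """Return (step, path) pairs for checkpoint-<step> folders, sorted by step."""
    found = []
    for p in Path(output_dir).iterdir():
        m = re.fullmatch(r"checkpoint-(\d+)", p.name)
        if m and p.is_dir():
            found.append((int(m.group(1)), p))
    return sorted(found)


def latest_checkpoint(output_dir):
    """Path of the checkpoint with the highest step, or None if there is none."""
    ckpts = list_checkpoints(output_dir)
    return ckpts[-1][1] if ckpts else None
```

Sorting on the parsed integer matters: a plain string sort would place `checkpoint-1000` before `checkpoint-250`.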
## Set grads to none
To save even more memory, pass the `--set_grads_to_none` argument to the script. This sets the gradients to `None` instead of zero. However, be aware that it changes certain behaviors, so if you run into any problems, remove this argument.
More info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
## Experimental results
Refer to [our webpage](https://www.cs.cmu.edu/~custom-diffusion/) for details on our experiments.