# anime_controlnet/quickstart_train.py
from argparse import Namespace
from multiprocessing import cpu_count
from src.lab import Lab
# runs on a 10GB VRAM GPU (RTX 3080)
args = Namespace(
    # model args
    pretrained_model_name_or_path="lint/liquidfix",
    controlnet_weights_path="lint/anime_control/anime_merge",
    # controlnet_weights_path=None,
    vae_path="lint/anime_vae",

    # dataset args
    train_data_dir="lint/anybooru",
    valid_data_dir="",
    resolution=512,
    from_hf_hub=True,
    controlnet_hint_key="canny",  # set to "canny" to train with canny edge hints, or None to train without a hint

    # training args
    # options are ["zero convolutions", "input hint blocks"]; anything else trains the whole controlnet
    training_stage="",
    learning_rate=5e-6,
    num_train_epochs=1000,
    max_train_steps=None,
    seed=3434554,
    max_grad_norm=1.0,
    gradient_accumulation_steps=1,

    # VRAM args
    batch_size=1,
    mixed_precision="fp16",  # set to "fp16" for mixed-precision training
    gradient_checkpointing=True,  # set to True to lower memory usage
    use_8bit_adam=True,  # use the 8-bit optimizer from bitsandbytes
    enable_xformers_memory_efficient_attention=True,
    allow_tf32=True,
    dataloader_num_workers=cpu_count(),

    # logging args
    output_dir="./models",
    report_to='tensorboard',
    image_logging_steps=600,  # disabled when 0; logging images costs additional VRAM
    save_whole_pipeline=True,
    checkpointing_steps=6000,
)
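
# Illustrative sketch only: when controlnet_hint_key="canny", the Lab class is
# expected to build the canny hint images itself during training. The helper
# below shows how such a hint is typically computed with OpenCV; the function
# name, thresholds, and use of cv2 here are assumptions for illustration, not
# part of this repo's API.
def example_canny_hint(image_path, low_threshold=100, high_threshold=200):
    import cv2  # optional dependency, imported lazily so the quickstart runs without it
    import numpy as np

    image = cv2.imread(image_path)  # HxWx3 uint8
    edges = cv2.Canny(image, low_threshold, high_threshold)  # HxW edge map
    # stack to three channels so the hint matches the image layout controlnets usually expect
    return np.stack([edges] * 3, axis=-1)
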
if __name__ == '__main__':
    lab = Lab(args)
    lab.train(args.num_train_epochs)
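
# Example usage (a sketch; exact log/checkpoint layout depends on the Lab implementation):
#   python quickstart_train.py
#   tensorboard --logdir ./models   # assumes tensorboard event files land under output_dir
# With save_whole_pipeline=True, checkpoints written every checkpointing_steps steps
# should include the full pipeline under output_dir, but verify against src/lab.py.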