# Spaces: Runtime error
"""Training configuration for ControlNet fine-tuning with ``src.lab.Lab``.

``args`` bundles the model, dataset, training, VRAM and logging settings.
When run as a script, the file builds a ``Lab`` from ``args`` and trains
for ``args.num_train_epochs`` epochs.
"""
from argparse import Namespace
from multiprocessing import cpu_count

from src.lab import Lab

# This configuration runs on a 10 GB VRAM GPU (RTX 3080).
args = Namespace(
    # model args
    pretrained_model_name_or_path="lint/liquidfix",
    controlnet_weights_path="lint/anime_control/anime_merge",
    # Alternative setting; how None is handled lives in src.lab (confirm there):
    # controlnet_weights_path=None,
    vae_path="lint/anime_vae",
    # dataset args
    train_data_dir="lint/anybooru",
    valid_data_dir="",
    resolution=512,
    from_hf_hub=True,
    # "canny" trains with a canny hint; None skips the hint (see src.lab).
    controlnet_hint_key="canny",
    # training args
    # training_stage: "zero convolutions" or "input hint blocks" restricts
    # training to those layers; any other value (e.g. "") trains the whole
    # ControlNet.
    training_stage="",
    learning_rate=5e-6,
    num_train_epochs=1000,
    max_train_steps=None,
    seed=3434554,
    max_grad_norm=1.0,
    gradient_accumulation_steps=1,
    # VRAM args
    batch_size=1,
    mixed_precision="fp16",  # set to "fp16" for mixed-precision training
    gradient_checkpointing=True,  # True lowers memory usage
    use_8bit_adam=True,  # use the 8-bit optimizer from bitsandbytes
    enable_xformers_memory_efficient_attention=True,
    allow_tf32=True,
    # NOTE(review): cpu_count() returns the number of CPUs in the system,
    # which is not necessarily the number this process may use (e.g. inside
    # a container quota); cap it if DataLoader workers exhaust memory.
    dataloader_num_workers=cpu_count(),
    # logging args
    output_dir="./models",
    report_to="tensorboard",
    image_logging_steps=600,  # 0 disables image logging; logging images costs extra VRAM
    save_whole_pipeline=True,
    checkpointing_steps=6000,
)
if __name__ == "__main__":
    # Only train when executed as a script, so importing this module just
    # exposes `args` without starting a training run.
    lab = Lab(args)
    lab.train(args.num_train_epochs)