Spaces:
Runtime error
Runtime error
File size: 1,521 Bytes
6230dda |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
from argparse import Namespace
from multiprocessing import cpu_count
from src.lab import Lab
# Training configuration. Author's note: runs on a 10GB VRAM GPU (RTX 3080).
args = Namespace()

# --- model weights ---
args.pretrained_model_name_or_path = "lint/liquidfix"
args.controlnet_weights_path = "lint/anime_control/anime_merge"
# Alternative left by the author: args.controlnet_weights_path = None
args.vae_path = "lint/anime_vae"

# --- dataset ---
args.train_data_dir = "lint/anybooru"
args.valid_data_dir = ""
args.resolution = 512
args.from_hf_hub = True
# Use "canny" to train with a canny hint, or None to pass.
args.controlnet_hint_key = "canny"

# --- training ---
# Recognised stages are "zero convolutions" and "input hint blocks";
# any other value trains the whole controlnet.
args.training_stage = ""
args.learning_rate = 5e-6
args.num_train_epochs = 1000
args.max_train_steps = None
args.seed = 3434554
args.max_grad_norm = 1.0
args.gradient_accumulation_steps = 1

# --- VRAM ---
args.batch_size = 1
# "fp16" selects mixed-precision training.
args.mixed_precision = "fp16"
# True lowers memory usage.
args.gradient_checkpointing = True
# 8-bit optimizer from bitsandbytes.
args.use_8bit_adam = True
args.enable_xformers_memory_efficient_attention = True
args.allow_tf32 = True
# One dataloader worker per CPU reported by multiprocessing.cpu_count().
args.dataloader_num_workers = cpu_count()

# --- logging ---
args.output_dir = "./models"
args.report_to = "tensorboard"
# 0 disables image logging; logging images costs additional VRAM.
args.image_logging_steps = 600
args.save_whole_pipeline = True
args.checkpointing_steps = 6000
if __name__ == '__main__':
    # Run only when executed as a script, so importing this module for its
    # `args` does not start a training run.
    lab = Lab(args)
    # Train for the number of epochs configured in `args`.
    lab.train(args.num_train_epochs)
|