"""Launch a ControlNet training run with a fixed, in-file configuration.

All hyper-parameters live in the module-level ``args`` namespace, which is
passed as-is to ``src.lab.Lab``. When the file is run as a script, it trains
for ``args.num_train_epochs`` epochs. The settings were chosen to run on a
10 GB VRAM GPU (RTX 3080).
"""
from argparse import Namespace
from multiprocessing import cpu_count

from src.lab import Lab

# runs on 10GB VRAM GPU (RTX 3080)
args = Namespace(
    pretrained_model_name_or_path="lint/liquidfix",
    controlnet_weights_path="lint/anime_control/anime_merge",
    # Alternative settings kept for reference; uncomment to use.
    # controlnet_weights_path=None,
    # vae_path="lint/anime_vae",

    # dataset args
    train_data_dir="lint/anybooru",
    valid_data_dir="",
    resolution=512,
    from_hf_hub=True,
    # "canny" trains with a canny-edge hint; None is the other option
    # (NOTE(review): how None is handled is defined in src.lab - confirm there).
    controlnet_hint_key="canny",

    # training args
    # Options are "zero convolutions" or "input hint blocks"; any other value
    # (including the empty string) trains the whole ControlNet.
    training_stage="",
    learning_rate=5e-6,
    num_train_epochs=1000,
    max_train_steps=None,
    seed=3434554,
    max_grad_norm=1.0,
    gradient_accumulation_steps=1,

    # VRAM args
    batch_size=1,
    mixed_precision="fp16",  # "fp16" enables mixed-precision training.
    gradient_checkpointing=True,  # True lowers memory usage.
    use_8bit_adam=True,  # use 8-bit optimizer from bitsandbytes
    enable_xformers_memory_efficient_attention=True,
    allow_tf32=True,
    # One loader worker per CPU core reported by the OS.
    dataloader_num_workers=cpu_count(),

    # logging args
    output_dir="./models",
    report_to="tensorboard",
    # Disabled when 0; logging images costs additional VRAM.
    image_logging_steps=600,
    save_whole_pipeline=True,
    checkpointing_steps=6000,
)


if __name__ == "__main__":
    lab = Lab(args)
    lab.train(args.num_train_epochs)