from argparse import Namespace
from multiprocessing import cpu_count
from src.lab import Lab

# runs on 10GB VRAM GPU (RTX 3080)
args = Namespace(
    # model args
    pretrained_model_name_or_path="lint/liquidfix",
    controlnet_weights_path="lint/anime_control/anime_merge",
    # controlnet_weights_path=None,
    vae_path="lint/anime_vae",

    # dataset args
    train_data_dir="lint/anybooru",
    valid_data_dir="",
    resolution=512,
    from_hf_hub=True,
    controlnet_hint_key="canny", # set to "canny" to train with canny edge hints, or None to train without a hint

    # training args
    # options are "zero convolutions" or "input hint blocks"; any other value trains the whole controlnet
    training_stage="",
    learning_rate=5e-6,
    num_train_epochs=1000,
    max_train_steps=None,
    seed=3434554,
    max_grad_norm=1.0,
    gradient_accumulation_steps=1,

    # VRAM args
    batch_size=1,
    mixed_precision="fp16", # set to "fp16" for mixed-precision training.
    gradient_checkpointing=True, # set this to True to lower the memory usage.
    use_8bit_adam=True, # use 8bit optimizer from bitsandbytes
    enable_xformers_memory_efficient_attention=True,
    allow_tf32=True,
    dataloader_num_workers=cpu_count(),

    # logging args
    output_dir="./models",
    report_to='tensorboard',
    image_logging_steps=600, # set to 0 to disable; logging images costs additional VRAM
    save_whole_pipeline=True,
    checkpointing_steps=6000,
)

if __name__ == '__main__':
    lab = Lab(args)
    lab.train(args.num_train_epochs)
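
# A minimal launch sketch. Assumptions (not stated in this file): the script is
# saved as train.py at the repository root so the `src.lab` import resolves, and
# bitsandbytes / xformers are installed to back the use_8bit_adam and
# enable_xformers_memory_efficient_attention flags above.
#
#   python train.py
#
# With report_to='tensorboard' and output_dir="./models", training curves and any
# logged images should then be viewable via:
#
#   tensorboard --logdir ./models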