yeq6x commited on
Commit
a6e21a3
·
verified ·
1 Parent(s): 1476306

Upload folder using huggingface_hub

Browse files
Files changed (37) hide show
  1. output/checkpoint-10000/controlnet/config.json +57 -0
  2. output/checkpoint-10000/controlnet/diffusion_pytorch_model.safetensors +3 -0
  3. output/checkpoint-10000/optimizer.bin +3 -0
  4. output/checkpoint-10000/random_states_0.pkl +3 -0
  5. output/checkpoint-10000/scaler.pt +3 -0
  6. output/checkpoint-10000/scheduler.bin +3 -0
  7. output/checkpoint-20000/controlnet/config.json +57 -0
  8. output/checkpoint-20000/controlnet/diffusion_pytorch_model.safetensors +3 -0
  9. output/checkpoint-20000/optimizer.bin +3 -0
  10. output/checkpoint-20000/random_states_0.pkl +3 -0
  11. output/checkpoint-20000/scaler.pt +3 -0
  12. output/checkpoint-20000/scheduler.bin +3 -0
  13. output/checkpoint-30000/controlnet/config.json +57 -0
  14. output/checkpoint-30000/controlnet/diffusion_pytorch_model.safetensors +3 -0
  15. output/checkpoint-30000/optimizer.bin +3 -0
  16. output/checkpoint-30000/random_states_0.pkl +3 -0
  17. output/checkpoint-30000/scaler.pt +3 -0
  18. output/checkpoint-30000/scheduler.bin +3 -0
  19. output/checkpoint-40000/controlnet/config.json +57 -0
  20. output/checkpoint-40000/controlnet/diffusion_pytorch_model.safetensors +3 -0
  21. output/checkpoint-40000/optimizer.bin +3 -0
  22. output/checkpoint-40000/random_states_0.pkl +3 -0
  23. output/checkpoint-40000/scaler.pt +3 -0
  24. output/checkpoint-40000/scheduler.bin +3 -0
  25. output/checkpoint-50000/controlnet/config.json +57 -0
  26. output/checkpoint-50000/controlnet/diffusion_pytorch_model.safetensors +3 -0
  27. output/checkpoint-50000/optimizer.bin +3 -0
  28. output/checkpoint-50000/random_states_0.pkl +3 -0
  29. output/checkpoint-50000/scaler.pt +3 -0
  30. output/checkpoint-50000/scheduler.bin +3 -0
  31. output/logs/fill50k_custom_v1_classification_v2/1732146579.9035878/events.out.tfevents.1732146579.f41554fe6d06.14666.1 +3 -0
  32. output/logs/fill50k_custom_v1_classification_v2/1732146579.9052796/hparams.yml +54 -0
  33. output/logs/fill50k_custom_v1_classification_v2/1732147449.7879949/events.out.tfevents.1732147449.f41554fe6d06.19394.1 +3 -0
  34. output/logs/fill50k_custom_v1_classification_v2/1732147449.7897315/hparams.yml +54 -0
  35. output/logs/fill50k_custom_v1_classification_v2/events.out.tfevents.1732146579.f41554fe6d06.14666.0 +3 -0
  36. output/logs/fill50k_custom_v1_classification_v2/events.out.tfevents.1732147449.f41554fe6d06.19394.0 +3 -0
  37. output/train_controlnet_sdxl.py +1404 -0
output/checkpoint-10000/controlnet/config.json ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ControlNetModel",
3
+ "_diffusers_version": "0.32.0.dev0",
4
+ "act_fn": "silu",
5
+ "addition_embed_type": "text_time",
6
+ "addition_embed_type_num_heads": 64,
7
+ "addition_time_embed_dim": 256,
8
+ "attention_head_dim": [
9
+ 5,
10
+ 10,
11
+ 20
12
+ ],
13
+ "block_out_channels": [
14
+ 320,
15
+ 640,
16
+ 1280
17
+ ],
18
+ "class_embed_type": null,
19
+ "conditioning_channels": 3,
20
+ "conditioning_embedding_out_channels": [
21
+ 16,
22
+ 32,
23
+ 96,
24
+ 256
25
+ ],
26
+ "controlnet_conditioning_channel_order": "rgb",
27
+ "cross_attention_dim": 2048,
28
+ "down_block_types": [
29
+ "DownBlock2D",
30
+ "CrossAttnDownBlock2D",
31
+ "CrossAttnDownBlock2D"
32
+ ],
33
+ "downsample_padding": 1,
34
+ "encoder_hid_dim": null,
35
+ "encoder_hid_dim_type": null,
36
+ "flip_sin_to_cos": true,
37
+ "freq_shift": 0,
38
+ "global_pool_conditions": false,
39
+ "in_channels": 4,
40
+ "layers_per_block": 2,
41
+ "mid_block_scale_factor": 1,
42
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
43
+ "norm_eps": 1e-05,
44
+ "norm_num_groups": 32,
45
+ "num_attention_heads": null,
46
+ "num_class_embeds": null,
47
+ "only_cross_attention": false,
48
+ "projection_class_embeddings_input_dim": 2816,
49
+ "resnet_time_scale_shift": "default",
50
+ "transformer_layers_per_block": [
51
+ 1,
52
+ 2,
53
+ 10
54
+ ],
55
+ "upcast_attention": null,
56
+ "use_linear_projection": true
57
+ }
output/checkpoint-10000/controlnet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bafd13f4d2417e5b3822d1a059cbc196d2abd300768b123941e023ebad915333
3
+ size 5004167864
output/checkpoint-10000/optimizer.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bd4904554628f145f67896b54aa980489ad58b54767f41d7ef74ea798673c1b5
3
+ size 10008841510
output/checkpoint-10000/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:25975d039a2387e37c1aecad5978e16cfe0ff0d0cf34ed90051653adb4e44888
3
+ size 14344
output/checkpoint-10000/scaler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d49367aa0f19981d0484e3ef5c1b009e6bdd8e5c0d8ef08f2f8d235b9f843816
3
+ size 988
output/checkpoint-10000/scheduler.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d31ad9a5150e341b638f838a47b3cb3db0d48efa493bcc8d1f8bf787b2c8cee3
3
+ size 1000
output/checkpoint-20000/controlnet/config.json ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ControlNetModel",
3
+ "_diffusers_version": "0.32.0.dev0",
4
+ "act_fn": "silu",
5
+ "addition_embed_type": "text_time",
6
+ "addition_embed_type_num_heads": 64,
7
+ "addition_time_embed_dim": 256,
8
+ "attention_head_dim": [
9
+ 5,
10
+ 10,
11
+ 20
12
+ ],
13
+ "block_out_channels": [
14
+ 320,
15
+ 640,
16
+ 1280
17
+ ],
18
+ "class_embed_type": null,
19
+ "conditioning_channels": 3,
20
+ "conditioning_embedding_out_channels": [
21
+ 16,
22
+ 32,
23
+ 96,
24
+ 256
25
+ ],
26
+ "controlnet_conditioning_channel_order": "rgb",
27
+ "cross_attention_dim": 2048,
28
+ "down_block_types": [
29
+ "DownBlock2D",
30
+ "CrossAttnDownBlock2D",
31
+ "CrossAttnDownBlock2D"
32
+ ],
33
+ "downsample_padding": 1,
34
+ "encoder_hid_dim": null,
35
+ "encoder_hid_dim_type": null,
36
+ "flip_sin_to_cos": true,
37
+ "freq_shift": 0,
38
+ "global_pool_conditions": false,
39
+ "in_channels": 4,
40
+ "layers_per_block": 2,
41
+ "mid_block_scale_factor": 1,
42
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
43
+ "norm_eps": 1e-05,
44
+ "norm_num_groups": 32,
45
+ "num_attention_heads": null,
46
+ "num_class_embeds": null,
47
+ "only_cross_attention": false,
48
+ "projection_class_embeddings_input_dim": 2816,
49
+ "resnet_time_scale_shift": "default",
50
+ "transformer_layers_per_block": [
51
+ 1,
52
+ 2,
53
+ 10
54
+ ],
55
+ "upcast_attention": null,
56
+ "use_linear_projection": true
57
+ }
output/checkpoint-20000/controlnet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c2345c59cd1295084f1e8ed2fa331371806cb7450532f6f813a0d294814f43d1
3
+ size 5004167864
output/checkpoint-20000/optimizer.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b234645d7f7a0b52965799e39e7e858aa22f281da49f3545aea48e3da6c9fda
3
+ size 10008841510
output/checkpoint-20000/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37e2fbb65cb9637df6850b605e724fa082972239ff87a4f9918779aa2b77e3d6
3
+ size 14344
output/checkpoint-20000/scaler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a33b95288c044d4b8931227377713cfbcc8d73a1b4e184632c562d7fdaf703f9
3
+ size 988
output/checkpoint-20000/scheduler.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2fb423915b7f7abbe2676b44b9381dc20a834b3ca26d3e2d87b59cdd536701fe
3
+ size 1000
output/checkpoint-30000/controlnet/config.json ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ControlNetModel",
3
+ "_diffusers_version": "0.32.0.dev0",
4
+ "act_fn": "silu",
5
+ "addition_embed_type": "text_time",
6
+ "addition_embed_type_num_heads": 64,
7
+ "addition_time_embed_dim": 256,
8
+ "attention_head_dim": [
9
+ 5,
10
+ 10,
11
+ 20
12
+ ],
13
+ "block_out_channels": [
14
+ 320,
15
+ 640,
16
+ 1280
17
+ ],
18
+ "class_embed_type": null,
19
+ "conditioning_channels": 3,
20
+ "conditioning_embedding_out_channels": [
21
+ 16,
22
+ 32,
23
+ 96,
24
+ 256
25
+ ],
26
+ "controlnet_conditioning_channel_order": "rgb",
27
+ "cross_attention_dim": 2048,
28
+ "down_block_types": [
29
+ "DownBlock2D",
30
+ "CrossAttnDownBlock2D",
31
+ "CrossAttnDownBlock2D"
32
+ ],
33
+ "downsample_padding": 1,
34
+ "encoder_hid_dim": null,
35
+ "encoder_hid_dim_type": null,
36
+ "flip_sin_to_cos": true,
37
+ "freq_shift": 0,
38
+ "global_pool_conditions": false,
39
+ "in_channels": 4,
40
+ "layers_per_block": 2,
41
+ "mid_block_scale_factor": 1,
42
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
43
+ "norm_eps": 1e-05,
44
+ "norm_num_groups": 32,
45
+ "num_attention_heads": null,
46
+ "num_class_embeds": null,
47
+ "only_cross_attention": false,
48
+ "projection_class_embeddings_input_dim": 2816,
49
+ "resnet_time_scale_shift": "default",
50
+ "transformer_layers_per_block": [
51
+ 1,
52
+ 2,
53
+ 10
54
+ ],
55
+ "upcast_attention": null,
56
+ "use_linear_projection": true
57
+ }
output/checkpoint-30000/controlnet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c2f682210750d2b1f9884f885692856d9bf5559834a9c18f906cb17d956e8ab0
3
+ size 5004167864
output/checkpoint-30000/optimizer.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:99e653650c3f6bde37b99cee8e088773793438e910a0cc437cab759a18b2b6e2
3
+ size 10008841510
output/checkpoint-30000/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:01a525adca3ed55a55b0e32f3565e03eb3de12c488d72632c54c9405dd858a28
3
+ size 14344
output/checkpoint-30000/scaler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f4794b5f9540306c90a772b3b6a73944ea7ab02d45c760ac7be8ea8a7e65d86
3
+ size 988
output/checkpoint-30000/scheduler.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:673ee106494e80bd7254452368b4a8a9ef1ab34af463e19284b555963ef6efb4
3
+ size 1000
output/checkpoint-40000/controlnet/config.json ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ControlNetModel",
3
+ "_diffusers_version": "0.32.0.dev0",
4
+ "act_fn": "silu",
5
+ "addition_embed_type": "text_time",
6
+ "addition_embed_type_num_heads": 64,
7
+ "addition_time_embed_dim": 256,
8
+ "attention_head_dim": [
9
+ 5,
10
+ 10,
11
+ 20
12
+ ],
13
+ "block_out_channels": [
14
+ 320,
15
+ 640,
16
+ 1280
17
+ ],
18
+ "class_embed_type": null,
19
+ "conditioning_channels": 3,
20
+ "conditioning_embedding_out_channels": [
21
+ 16,
22
+ 32,
23
+ 96,
24
+ 256
25
+ ],
26
+ "controlnet_conditioning_channel_order": "rgb",
27
+ "cross_attention_dim": 2048,
28
+ "down_block_types": [
29
+ "DownBlock2D",
30
+ "CrossAttnDownBlock2D",
31
+ "CrossAttnDownBlock2D"
32
+ ],
33
+ "downsample_padding": 1,
34
+ "encoder_hid_dim": null,
35
+ "encoder_hid_dim_type": null,
36
+ "flip_sin_to_cos": true,
37
+ "freq_shift": 0,
38
+ "global_pool_conditions": false,
39
+ "in_channels": 4,
40
+ "layers_per_block": 2,
41
+ "mid_block_scale_factor": 1,
42
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
43
+ "norm_eps": 1e-05,
44
+ "norm_num_groups": 32,
45
+ "num_attention_heads": null,
46
+ "num_class_embeds": null,
47
+ "only_cross_attention": false,
48
+ "projection_class_embeddings_input_dim": 2816,
49
+ "resnet_time_scale_shift": "default",
50
+ "transformer_layers_per_block": [
51
+ 1,
52
+ 2,
53
+ 10
54
+ ],
55
+ "upcast_attention": null,
56
+ "use_linear_projection": true
57
+ }
output/checkpoint-40000/controlnet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d5192eb2d2a16f448e6d9f2f4cf0bba9e70f5ff82f704cfb4f3c5fafb709dee
3
+ size 5004167864
output/checkpoint-40000/optimizer.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2fb5f26dd982c7dbd0d9c6d0b34e90dde235380560292773f127f6a49fd19d54
3
+ size 10008841510
output/checkpoint-40000/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d644c26fba3d3cb700c68307ca8cb9ac04f54d82b614f513e80db8806db6812d
3
+ size 14344
output/checkpoint-40000/scaler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f2bdba071ad840f3d59eece2f6e711860f58323f128e62b5b857210cb6d9e78
3
+ size 988
output/checkpoint-40000/scheduler.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ef8d9e6dd1cc417151cc565c35c30235164cd222d36b119a405afcf1c1b3336a
3
+ size 1000
output/checkpoint-50000/controlnet/config.json ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ControlNetModel",
3
+ "_diffusers_version": "0.32.0.dev0",
4
+ "act_fn": "silu",
5
+ "addition_embed_type": "text_time",
6
+ "addition_embed_type_num_heads": 64,
7
+ "addition_time_embed_dim": 256,
8
+ "attention_head_dim": [
9
+ 5,
10
+ 10,
11
+ 20
12
+ ],
13
+ "block_out_channels": [
14
+ 320,
15
+ 640,
16
+ 1280
17
+ ],
18
+ "class_embed_type": null,
19
+ "conditioning_channels": 3,
20
+ "conditioning_embedding_out_channels": [
21
+ 16,
22
+ 32,
23
+ 96,
24
+ 256
25
+ ],
26
+ "controlnet_conditioning_channel_order": "rgb",
27
+ "cross_attention_dim": 2048,
28
+ "down_block_types": [
29
+ "DownBlock2D",
30
+ "CrossAttnDownBlock2D",
31
+ "CrossAttnDownBlock2D"
32
+ ],
33
+ "downsample_padding": 1,
34
+ "encoder_hid_dim": null,
35
+ "encoder_hid_dim_type": null,
36
+ "flip_sin_to_cos": true,
37
+ "freq_shift": 0,
38
+ "global_pool_conditions": false,
39
+ "in_channels": 4,
40
+ "layers_per_block": 2,
41
+ "mid_block_scale_factor": 1,
42
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
43
+ "norm_eps": 1e-05,
44
+ "norm_num_groups": 32,
45
+ "num_attention_heads": null,
46
+ "num_class_embeds": null,
47
+ "only_cross_attention": false,
48
+ "projection_class_embeddings_input_dim": 2816,
49
+ "resnet_time_scale_shift": "default",
50
+ "transformer_layers_per_block": [
51
+ 1,
52
+ 2,
53
+ 10
54
+ ],
55
+ "upcast_attention": null,
56
+ "use_linear_projection": true
57
+ }
output/checkpoint-50000/controlnet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:697916c805339c10cc915c99b64574ade0d02f078cc4b07791bc462b36eba503
3
+ size 5004167864
output/checkpoint-50000/optimizer.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ccd8799b95d5519fa9b7d67cd3dac4546375e588b88e83642d57beb264889107
3
+ size 10008841510
output/checkpoint-50000/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4379e91ff6e90233f72224afc7dbf7de2293cb359d55e6f013b1161b32e86f6d
3
+ size 14344
output/checkpoint-50000/scaler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8492ca5c6199ddead96f76b53048d3fbbbc4e32cd1d4f41abe48be425993b1fb
3
+ size 988
output/checkpoint-50000/scheduler.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:884c2374c79400966c8ad3459e73aad51c02318b552c4ffc9fd1792596d58237
3
+ size 1000
output/logs/fill50k_custom_v1_classification_v2/1732146579.9035878/events.out.tfevents.1732146579.f41554fe6d06.14666.1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e34905b217b6bc4510060179e060e5bbaa215b51d6fcf0084c52fb677e422ae3
3
+ size 2702
output/logs/fill50k_custom_v1_classification_v2/1732146579.9052796/hparams.yml ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ adam_beta1: 0.9
2
+ adam_beta2: 0.999
3
+ adam_epsilon: 1.0e-08
4
+ adam_weight_decay: 0.01
5
+ allow_tf32: false
6
+ cache_dir: null
7
+ caption_column: text
8
+ checkpointing_steps: 10000
9
+ checkpoints_total_limit: null
10
+ conditioning_image_column: conditioning_image
11
+ controlnet_model_name_or_path: null
12
+ crops_coords_top_left_h: 0
13
+ crops_coords_top_left_w: 0
14
+ dataloader_num_workers: 0
15
+ dataset_config_name: null
16
+ dataset_name: null
17
+ enable_npu_flash_attention: false
18
+ enable_xformers_memory_efficient_attention: true
19
+ gradient_accumulation_steps: 1
20
+ gradient_checkpointing: false
21
+ hub_model_id: null
22
+ hub_token: null
23
+ image_column: image
24
+ learning_rate: 1.0e-05
25
+ logging_dir: logs
26
+ lr_num_cycles: 1
27
+ lr_power: 1.0
28
+ lr_scheduler: constant
29
+ lr_warmup_steps: 500
30
+ max_grad_norm: 1.0
31
+ max_train_samples: null
32
+ max_train_steps: 100000
33
+ mixed_precision: fp16
34
+ num_train_epochs: 2
35
+ num_validation_images: 1
36
+ output_dir: output
37
+ pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0
38
+ pretrained_vae_model_name_or_path: madebyollin/sdxl-vae-fp16-fix
39
+ proportion_empty_prompts: 0
40
+ push_to_hub: false
41
+ report_to: tensorboard
42
+ resolution: 512
43
+ resume_from_checkpoint: null
44
+ revision: null
45
+ scale_lr: false
46
+ seed: 43
47
+ set_grads_to_none: false
48
+ tokenizer_name: null
49
+ tracker_project_name: fill50k_custom_v1_classification_v2
50
+ train_batch_size: 1
51
+ train_data_dir: yeq6x/fill50k_custom
52
+ use_8bit_adam: false
53
+ validation_steps: 1000
54
+ variant: null
output/logs/fill50k_custom_v1_classification_v2/1732147449.7879949/events.out.tfevents.1732147449.f41554fe6d06.19394.1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:decb931a69bff1fb816694fb00b74ad9cdf3817240a934912f4c9c781889561a
3
+ size 2702
output/logs/fill50k_custom_v1_classification_v2/1732147449.7897315/hparams.yml ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ adam_beta1: 0.9
2
+ adam_beta2: 0.999
3
+ adam_epsilon: 1.0e-08
4
+ adam_weight_decay: 0.01
5
+ allow_tf32: false
6
+ cache_dir: null
7
+ caption_column: text
8
+ checkpointing_steps: 10000
9
+ checkpoints_total_limit: null
10
+ conditioning_image_column: conditioning_image
11
+ controlnet_model_name_or_path: null
12
+ crops_coords_top_left_h: 0
13
+ crops_coords_top_left_w: 0
14
+ dataloader_num_workers: 0
15
+ dataset_config_name: null
16
+ dataset_name: null
17
+ enable_npu_flash_attention: false
18
+ enable_xformers_memory_efficient_attention: true
19
+ gradient_accumulation_steps: 1
20
+ gradient_checkpointing: false
21
+ hub_model_id: null
22
+ hub_token: null
23
+ image_column: image
24
+ learning_rate: 1.0e-05
25
+ logging_dir: logs
26
+ lr_num_cycles: 1
27
+ lr_power: 1.0
28
+ lr_scheduler: constant
29
+ lr_warmup_steps: 500
30
+ max_grad_norm: 1.0
31
+ max_train_samples: null
32
+ max_train_steps: 100000
33
+ mixed_precision: fp16
34
+ num_train_epochs: 2
35
+ num_validation_images: 1
36
+ output_dir: output
37
+ pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0
38
+ pretrained_vae_model_name_or_path: madebyollin/sdxl-vae-fp16-fix
39
+ proportion_empty_prompts: 0
40
+ push_to_hub: false
41
+ report_to: tensorboard
42
+ resolution: 512
43
+ resume_from_checkpoint: null
44
+ revision: null
45
+ scale_lr: false
46
+ seed: 43
47
+ set_grads_to_none: false
48
+ tokenizer_name: null
49
+ tracker_project_name: fill50k_custom_v1_classification_v2
50
+ train_batch_size: 1
51
+ train_data_dir: yeq6x/fill50k_custom
52
+ use_8bit_adam: false
53
+ validation_steps: 1000
54
+ variant: null
output/logs/fill50k_custom_v1_classification_v2/events.out.tfevents.1732146579.f41554fe6d06.14666.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:94a3ea9d98ade653492259433559d4c345e8cc64abedbf58e62ee45475b35c18
3
+ size 145561
output/logs/fill50k_custom_v1_classification_v2/events.out.tfevents.1732147449.f41554fe6d06.19394.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:933390ca81864cb6ac5f30e6a6d48f8f936bfcca025e57688f59c50802754a39
3
+ size 28954054
output/train_controlnet_sdxl.py ADDED
@@ -0,0 +1,1404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+
16
+ import argparse
17
+ import functools
18
+ import gc
19
+ import logging
20
+ import math
21
+ import os
22
+ import random
23
+ import shutil
24
+ from contextlib import nullcontext
25
+ from pathlib import Path
26
+
27
+ import accelerate
28
+ import numpy as np
29
+ import torch
30
+ import torch.nn.functional as F
31
+ import torch.utils.checkpoint
32
+ import transformers
33
+ from accelerate import Accelerator
34
+ from accelerate.logging import get_logger
35
+ from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
36
+ from datasets import load_dataset
37
+ from huggingface_hub import create_repo, upload_folder
38
+ from packaging import version
39
+ from PIL import Image
40
+ from torchvision import transforms
41
+ from tqdm.auto import tqdm
42
+ from transformers import AutoTokenizer, PretrainedConfig
43
+
44
+ import diffusers
45
+ from diffusers import (
46
+ AutoencoderKL,
47
+ ControlNetModel,
48
+ DDPMScheduler,
49
+ StableDiffusionXLControlNetPipeline,
50
+ UNet2DConditionModel,
51
+ UniPCMultistepScheduler,
52
+ )
53
+ from diffusers.optimization import get_scheduler
54
+ from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
55
+ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
56
+ from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
57
+ from diffusers.utils.torch_utils import is_compiled_module
58
+
59
+
60
+ if is_wandb_available():
61
+ import wandb
62
+
63
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
64
+ check_min_version("0.32.0.dev0")
65
+
66
+ logger = get_logger(__name__)
67
+ if is_torch_npu_available():
68
+ torch.npu.config.allow_internal_format = False
69
+
70
+
71
+ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
72
+ logger.info("Running validation... ")
73
+
74
+ if not is_final_validation:
75
+ controlnet = accelerator.unwrap_model(controlnet)
76
+ pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
77
+ args.pretrained_model_name_or_path,
78
+ vae=vae,
79
+ unet=unet,
80
+ controlnet=controlnet,
81
+ revision=args.revision,
82
+ variant=args.variant,
83
+ torch_dtype=weight_dtype,
84
+ )
85
+ else:
86
+ controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
87
+ if args.pretrained_vae_model_name_or_path is not None:
88
+ vae = AutoencoderKL.from_pretrained(args.pretrained_vae_model_name_or_path, torch_dtype=weight_dtype)
89
+ else:
90
+ vae = AutoencoderKL.from_pretrained(
91
+ args.pretrained_model_name_or_path, subfolder="vae", torch_dtype=weight_dtype
92
+ )
93
+
94
+ pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
95
+ args.pretrained_model_name_or_path,
96
+ vae=vae,
97
+ controlnet=controlnet,
98
+ revision=args.revision,
99
+ variant=args.variant,
100
+ torch_dtype=weight_dtype,
101
+ )
102
+
103
+ pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
104
+ pipeline = pipeline.to(accelerator.device)
105
+ pipeline.set_progress_bar_config(disable=True)
106
+
107
+ if args.enable_xformers_memory_efficient_attention:
108
+ pipeline.enable_xformers_memory_efficient_attention()
109
+
110
+ if args.seed is None:
111
+ generator = None
112
+ else:
113
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
114
+
115
+ if len(args.validation_image) == len(args.validation_prompt):
116
+ validation_images = args.validation_image
117
+ validation_prompts = args.validation_prompt
118
+ elif len(args.validation_image) == 1:
119
+ validation_images = args.validation_image * len(args.validation_prompt)
120
+ validation_prompts = args.validation_prompt
121
+ elif len(args.validation_prompt) == 1:
122
+ validation_images = args.validation_image
123
+ validation_prompts = args.validation_prompt * len(args.validation_image)
124
+ else:
125
+ raise ValueError(
126
+ "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
127
+ )
128
+
129
+ image_logs = []
130
+ if is_final_validation or torch.backends.mps.is_available():
131
+ autocast_ctx = nullcontext()
132
+ else:
133
+ autocast_ctx = torch.autocast(accelerator.device.type)
134
+
135
+ for validation_prompt, validation_image in zip(validation_prompts, validation_images):
136
+ validation_image = Image.open(validation_image).convert("RGB")
137
+ validation_image = validation_image.resize((args.resolution, args.resolution))
138
+
139
+ images = []
140
+
141
+ for _ in range(args.num_validation_images):
142
+ with autocast_ctx:
143
+ image = pipeline(
144
+ prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
145
+ ).images[0]
146
+ images.append(image)
147
+
148
+ image_logs.append(
149
+ {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
150
+ )
151
+
152
+ tracker_key = "test" if is_final_validation else "validation"
153
+ for tracker in accelerator.trackers:
154
+ if tracker.name == "tensorboard":
155
+ for log in image_logs:
156
+ images = log["images"]
157
+ validation_prompt = log["validation_prompt"]
158
+ validation_image = log["validation_image"]
159
+
160
+ formatted_images = []
161
+
162
+ formatted_images.append(np.asarray(validation_image))
163
+
164
+ for image in images:
165
+ formatted_images.append(np.asarray(image))
166
+
167
+ formatted_images = np.stack(formatted_images)
168
+
169
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
170
+ elif tracker.name == "wandb":
171
+ formatted_images = []
172
+
173
+ for log in image_logs:
174
+ images = log["images"]
175
+ validation_prompt = log["validation_prompt"]
176
+ validation_image = log["validation_image"]
177
+
178
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
179
+
180
+ for image in images:
181
+ image = wandb.Image(image, caption=validation_prompt)
182
+ formatted_images.append(image)
183
+
184
+ tracker.log({tracker_key: formatted_images})
185
+ else:
186
+ logger.warning(f"image logging not implemented for {tracker.name}")
187
+
188
+ del pipeline
189
+ gc.collect()
190
+ torch.cuda.empty_cache()
191
+
192
+ return image_logs
193
+
194
+
195
+ def import_model_class_from_model_name_or_path(
196
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
197
+ ):
198
+ text_encoder_config = PretrainedConfig.from_pretrained(
199
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
200
+ )
201
+ model_class = text_encoder_config.architectures[0]
202
+
203
+ if model_class == "CLIPTextModel":
204
+ from transformers import CLIPTextModel
205
+
206
+ return CLIPTextModel
207
+ elif model_class == "CLIPTextModelWithProjection":
208
+ from transformers import CLIPTextModelWithProjection
209
+
210
+ return CLIPTextModelWithProjection
211
+ else:
212
+ raise ValueError(f"{model_class} is not supported.")
213
+
214
+
215
+ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
216
+ img_str = ""
217
+ if image_logs is not None:
218
+ img_str = "You can find some example images below.\n\n"
219
+ for i, log in enumerate(image_logs):
220
+ images = log["images"]
221
+ validation_prompt = log["validation_prompt"]
222
+ validation_image = log["validation_image"]
223
+ validation_image.save(os.path.join(repo_folder, "image_control.png"))
224
+ img_str += f"prompt: {validation_prompt}\n"
225
+ images = [validation_image] + images
226
+ make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
227
+ img_str += f"![images_{i})](./images_{i}.png)\n"
228
+
229
+ model_description = f"""
230
+ # controlnet-{repo_id}
231
+
232
+ These are controlnet weights trained on {base_model} with new type of conditioning.
233
+ {img_str}
234
+ """
235
+
236
+ model_card = load_or_create_model_card(
237
+ repo_id_or_path=repo_id,
238
+ from_training=True,
239
+ license="openrail++",
240
+ base_model=base_model,
241
+ model_description=model_description,
242
+ inference=True,
243
+ )
244
+
245
+ tags = [
246
+ "stable-diffusion-xl",
247
+ "stable-diffusion-xl-diffusers",
248
+ "text-to-image",
249
+ "diffusers",
250
+ "controlnet",
251
+ "diffusers-training",
252
+ ]
253
+ model_card = populate_model_card(model_card, tags=tags)
254
+
255
+ model_card.save(os.path.join(repo_folder, "README.md"))
256
+
257
+
258
+ def parse_args(input_args=None):
259
+ parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
260
+ parser.add_argument(
261
+ "--pretrained_model_name_or_path",
262
+ type=str,
263
+ default=None,
264
+ required=True,
265
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
266
+ )
267
+ parser.add_argument(
268
+ "--pretrained_vae_model_name_or_path",
269
+ type=str,
270
+ default=None,
271
+ help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
272
+ )
273
+ parser.add_argument(
274
+ "--controlnet_model_name_or_path",
275
+ type=str,
276
+ default=None,
277
+ help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
278
+ " If not specified controlnet weights are initialized from unet.",
279
+ )
280
+ parser.add_argument(
281
+ "--variant",
282
+ type=str,
283
+ default=None,
284
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
285
+ )
286
+ parser.add_argument(
287
+ "--revision",
288
+ type=str,
289
+ default=None,
290
+ required=False,
291
+ help="Revision of pretrained model identifier from huggingface.co/models.",
292
+ )
293
+ parser.add_argument(
294
+ "--tokenizer_name",
295
+ type=str,
296
+ default=None,
297
+ help="Pretrained tokenizer name or path if not the same as model_name",
298
+ )
299
+ parser.add_argument(
300
+ "--output_dir",
301
+ type=str,
302
+ default="controlnet-model",
303
+ help="The output directory where the model predictions and checkpoints will be written.",
304
+ )
305
+ parser.add_argument(
306
+ "--cache_dir",
307
+ type=str,
308
+ default=None,
309
+ help="The directory where the downloaded models and datasets will be stored.",
310
+ )
311
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
312
+ parser.add_argument(
313
+ "--resolution",
314
+ type=int,
315
+ default=512,
316
+ help=(
317
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
318
+ " resolution"
319
+ ),
320
+ )
321
+ parser.add_argument(
322
+ "--crops_coords_top_left_h",
323
+ type=int,
324
+ default=0,
325
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
326
+ )
327
+ parser.add_argument(
328
+ "--crops_coords_top_left_w",
329
+ type=int,
330
+ default=0,
331
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
332
+ )
333
+ parser.add_argument(
334
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
335
+ )
336
+ parser.add_argument("--num_train_epochs", type=int, default=1)
337
+ parser.add_argument(
338
+ "--max_train_steps",
339
+ type=int,
340
+ default=None,
341
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
342
+ )
343
+ parser.add_argument(
344
+ "--checkpointing_steps",
345
+ type=int,
346
+ default=500,
347
+ help=(
348
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
349
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
350
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
351
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
352
+ "instructions."
353
+ ),
354
+ )
355
+ parser.add_argument(
356
+ "--checkpoints_total_limit",
357
+ type=int,
358
+ default=None,
359
+ help=("Max number of checkpoints to store."),
360
+ )
361
+ parser.add_argument(
362
+ "--resume_from_checkpoint",
363
+ type=str,
364
+ default=None,
365
+ help=(
366
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
367
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
368
+ ),
369
+ )
370
+ parser.add_argument(
371
+ "--gradient_accumulation_steps",
372
+ type=int,
373
+ default=1,
374
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
375
+ )
376
+ parser.add_argument(
377
+ "--gradient_checkpointing",
378
+ action="store_true",
379
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
380
+ )
381
+ parser.add_argument(
382
+ "--learning_rate",
383
+ type=float,
384
+ default=5e-6,
385
+ help="Initial learning rate (after the potential warmup period) to use.",
386
+ )
387
+ parser.add_argument(
388
+ "--scale_lr",
389
+ action="store_true",
390
+ default=False,
391
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
392
+ )
393
+ parser.add_argument(
394
+ "--lr_scheduler",
395
+ type=str,
396
+ default="constant",
397
+ help=(
398
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
399
+ ' "constant", "constant_with_warmup"]'
400
+ ),
401
+ )
402
+ parser.add_argument(
403
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
404
+ )
405
+ parser.add_argument(
406
+ "--lr_num_cycles",
407
+ type=int,
408
+ default=1,
409
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
410
+ )
411
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
412
+ parser.add_argument(
413
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
414
+ )
415
+ parser.add_argument(
416
+ "--dataloader_num_workers",
417
+ type=int,
418
+ default=0,
419
+ help=(
420
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
421
+ ),
422
+ )
423
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
424
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
425
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
426
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
427
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
428
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
429
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
430
+ parser.add_argument(
431
+ "--hub_model_id",
432
+ type=str,
433
+ default=None,
434
+ help="The name of the repository to keep in sync with the local `output_dir`.",
435
+ )
436
+ parser.add_argument(
437
+ "--logging_dir",
438
+ type=str,
439
+ default="logs",
440
+ help=(
441
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
442
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
443
+ ),
444
+ )
445
+ parser.add_argument(
446
+ "--allow_tf32",
447
+ action="store_true",
448
+ help=(
449
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
450
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
451
+ ),
452
+ )
453
+ parser.add_argument(
454
+ "--report_to",
455
+ type=str,
456
+ default="tensorboard",
457
+ help=(
458
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
459
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
460
+ ),
461
+ )
462
+ parser.add_argument(
463
+ "--mixed_precision",
464
+ type=str,
465
+ default=None,
466
+ choices=["no", "fp16", "bf16"],
467
+ help=(
468
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
469
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
470
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
471
+ ),
472
+ )
473
+ parser.add_argument(
474
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
475
+ )
476
+ parser.add_argument(
477
+ "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
478
+ )
479
+ parser.add_argument(
480
+ "--set_grads_to_none",
481
+ action="store_true",
482
+ help=(
483
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
484
+ " behaviors, so disable this argument if it causes any problems. More info:"
485
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
486
+ ),
487
+ )
488
+ parser.add_argument(
489
+ "--dataset_name",
490
+ type=str,
491
+ default=None,
492
+ help=(
493
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
494
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
495
+ " or to a folder containing files that 🤗 Datasets can understand."
496
+ ),
497
+ )
498
+ parser.add_argument(
499
+ "--dataset_config_name",
500
+ type=str,
501
+ default=None,
502
+ help="The config of the Dataset, leave as None if there's only one config.",
503
+ )
504
+ parser.add_argument(
505
+ "--train_data_dir",
506
+ type=str,
507
+ default=None,
508
+ help=(
509
+ "A folder containing the training data. Folder contents must follow the structure described in"
510
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
511
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
512
+ ),
513
+ )
514
+ parser.add_argument(
515
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
516
+ )
517
+ parser.add_argument(
518
+ "--conditioning_image_column",
519
+ type=str,
520
+ default="conditioning_image",
521
+ help="The column of the dataset containing the controlnet conditioning image.",
522
+ )
523
+ parser.add_argument(
524
+ "--caption_column",
525
+ type=str,
526
+ default="text",
527
+ help="The column of the dataset containing a caption or a list of captions.",
528
+ )
529
+ parser.add_argument(
530
+ "--max_train_samples",
531
+ type=int,
532
+ default=None,
533
+ help=(
534
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
535
+ "value if set."
536
+ ),
537
+ )
538
+ parser.add_argument(
539
+ "--proportion_empty_prompts",
540
+ type=float,
541
+ default=0,
542
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
543
+ )
544
+ parser.add_argument(
545
+ "--validation_prompt",
546
+ type=str,
547
+ default=None,
548
+ nargs="+",
549
+ help=(
550
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
551
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
552
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
553
+ ),
554
+ )
555
+ parser.add_argument(
556
+ "--validation_image",
557
+ type=str,
558
+ default=None,
559
+ nargs="+",
560
+ help=(
561
+ "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
562
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
563
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
564
+ " `--validation_image` that will be used with all `--validation_prompt`s."
565
+ ),
566
+ )
567
+ parser.add_argument(
568
+ "--num_validation_images",
569
+ type=int,
570
+ default=4,
571
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
572
+ )
573
+ parser.add_argument(
574
+ "--validation_steps",
575
+ type=int,
576
+ default=100,
577
+ help=(
578
+ "Run validation every X steps. Validation consists of running the prompt"
579
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
580
+ " and logging the images."
581
+ ),
582
+ )
583
+ parser.add_argument(
584
+ "--tracker_project_name",
585
+ type=str,
586
+ default="sd_xl_train_controlnet",
587
+ help=(
588
+ "The `project_name` argument passed to Accelerator.init_trackers for"
589
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
590
+ ),
591
+ )
592
+
593
+ if input_args is not None:
594
+ args = parser.parse_args(input_args)
595
+ else:
596
+ args = parser.parse_args()
597
+
598
+ if args.dataset_name is None and args.train_data_dir is None:
599
+ raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
600
+
601
+ if args.dataset_name is not None and args.train_data_dir is not None:
602
+ raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
603
+
604
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
605
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
606
+
607
+ if args.validation_prompt is not None and args.validation_image is None:
608
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
609
+
610
+ if args.validation_prompt is None and args.validation_image is not None:
611
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
612
+
613
+ if (
614
+ args.validation_image is not None
615
+ and args.validation_prompt is not None
616
+ and len(args.validation_image) != 1
617
+ and len(args.validation_prompt) != 1
618
+ and len(args.validation_image) != len(args.validation_prompt)
619
+ ):
620
+ raise ValueError(
621
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
622
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
623
+ )
624
+
625
+ if args.resolution % 8 != 0:
626
+ raise ValueError(
627
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
628
+ )
629
+
630
+ return args
631
+
632
+
633
+ def get_train_dataset(args, accelerator):
634
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
635
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
636
+
637
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
638
+ # download the dataset.
639
+ if args.dataset_name is not None:
640
+ # Downloading and loading a dataset from the hub.
641
+ dataset = load_dataset(
642
+ args.dataset_name,
643
+ args.dataset_config_name,
644
+ cache_dir=args.cache_dir,
645
+ )
646
+ else:
647
+ if args.train_data_dir is not None:
648
+ dataset = load_dataset(
649
+ args.train_data_dir,
650
+ cache_dir=args.cache_dir,
651
+ )
652
+ # See more about loading custom images at
653
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
654
+
655
+ # Preprocessing the datasets.
656
+ # We need to tokenize inputs and targets.
657
+ column_names = dataset["train"].column_names
658
+
659
+ # 6. Get the column names for input/target.
660
+ if args.image_column is None:
661
+ image_column = column_names[0]
662
+ logger.info(f"image column defaulting to {image_column}")
663
+ else:
664
+ image_column = args.image_column
665
+ if image_column not in column_names:
666
+ raise ValueError(
667
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
668
+ )
669
+
670
+ if args.caption_column is None:
671
+ caption_column = column_names[1]
672
+ logger.info(f"caption column defaulting to {caption_column}")
673
+ else:
674
+ caption_column = args.caption_column
675
+ if caption_column not in column_names:
676
+ raise ValueError(
677
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
678
+ )
679
+
680
+ if args.conditioning_image_column is None:
681
+ conditioning_image_column = column_names[2]
682
+ logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
683
+ else:
684
+ conditioning_image_column = args.conditioning_image_column
685
+ if conditioning_image_column not in column_names:
686
+ raise ValueError(
687
+ f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
688
+ )
689
+
690
+ with accelerator.main_process_first():
691
+ train_dataset = dataset["train"].shuffle(seed=args.seed)
692
+ if args.max_train_samples is not None:
693
+ train_dataset = train_dataset.select(range(args.max_train_samples))
694
+ return train_dataset
695
+
696
+
697
+ # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
698
+ def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):
699
+ prompt_embeds_list = []
700
+
701
+ captions = []
702
+ for caption in prompt_batch:
703
+ if random.random() < proportion_empty_prompts:
704
+ captions.append("")
705
+ elif isinstance(caption, str):
706
+ captions.append(caption)
707
+ elif isinstance(caption, (list, np.ndarray)):
708
+ # take a random caption if there are multiple
709
+ captions.append(random.choice(caption) if is_train else caption[0])
710
+
711
+ with torch.no_grad():
712
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
713
+ text_inputs = tokenizer(
714
+ captions,
715
+ padding="max_length",
716
+ max_length=tokenizer.model_max_length,
717
+ truncation=True,
718
+ return_tensors="pt",
719
+ )
720
+ text_input_ids = text_inputs.input_ids
721
+ prompt_embeds = text_encoder(
722
+ text_input_ids.to(text_encoder.device),
723
+ output_hidden_states=True,
724
+ )
725
+
726
+ # We are only ALWAYS interested in the pooled output of the final text encoder
727
+ pooled_prompt_embeds = prompt_embeds[0]
728
+ prompt_embeds = prompt_embeds.hidden_states[-2]
729
+ bs_embed, seq_len, _ = prompt_embeds.shape
730
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
731
+ prompt_embeds_list.append(prompt_embeds)
732
+
733
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
734
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
735
+ return prompt_embeds, pooled_prompt_embeds
736
+
737
+
738
+ def prepare_train_dataset(dataset, accelerator):
739
+ image_transforms = transforms.Compose(
740
+ [
741
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
742
+ transforms.CenterCrop(args.resolution),
743
+ transforms.ToTensor(),
744
+ transforms.Normalize([0.5], [0.5]),
745
+ ]
746
+ )
747
+
748
+ conditioning_image_transforms = transforms.Compose(
749
+ [
750
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
751
+ transforms.CenterCrop(args.resolution),
752
+ transforms.ToTensor(),
753
+ ]
754
+ )
755
+
756
+ def preprocess_train(examples):
757
+ images = [image.convert("RGB") for image in examples[args.image_column]]
758
+ images = [image_transforms(image) for image in images]
759
+
760
+ conditioning_images = [image.convert("RGB") for image in examples[args.conditioning_image_column]]
761
+ conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
762
+
763
+ examples["pixel_values"] = images
764
+ examples["conditioning_pixel_values"] = conditioning_images
765
+
766
+ return examples
767
+
768
+ with accelerator.main_process_first():
769
+ dataset = dataset.with_transform(preprocess_train)
770
+
771
+ return dataset
772
+
773
+
774
+ def collate_fn(examples):
775
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
776
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
777
+
778
+ conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
779
+ conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
780
+
781
+ prompt_ids = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples])
782
+
783
+ add_text_embeds = torch.stack([torch.tensor(example["text_embeds"]) for example in examples])
784
+ add_time_ids = torch.stack([torch.tensor(example["time_ids"]) for example in examples])
785
+
786
+ return {
787
+ "pixel_values": pixel_values,
788
+ "conditioning_pixel_values": conditioning_pixel_values,
789
+ "prompt_ids": prompt_ids,
790
+ "unet_added_conditions": {"text_embeds": add_text_embeds, "time_ids": add_time_ids},
791
+ "text": [example["text"] for example in examples]
792
+ }
793
+
794
+
795
+ def main(args):
796
+ if args.report_to == "wandb" and args.hub_token is not None:
797
+ raise ValueError(
798
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
799
+ " Please use `huggingface-cli login` to authenticate with the Hub."
800
+ )
801
+
802
+ logging_dir = Path(args.output_dir, args.logging_dir)
803
+
804
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
805
+ # due to pytorch#99272, MPS does not yet support bfloat16.
806
+ raise ValueError(
807
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
808
+ )
809
+
810
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
811
+
812
+ accelerator = Accelerator(
813
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
814
+ mixed_precision=args.mixed_precision,
815
+ log_with=args.report_to,
816
+ project_config=accelerator_project_config,
817
+ )
818
+
819
+ # Disable AMP for MPS.
820
+ if torch.backends.mps.is_available():
821
+ accelerator.native_amp = False
822
+
823
+ # Make one log on every process with the configuration for debugging.
824
+ logging.basicConfig(
825
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
826
+ datefmt="%m/%d/%Y %H:%M:%S",
827
+ level=logging.INFO,
828
+ )
829
+ logger.info(accelerator.state, main_process_only=False)
830
+ if accelerator.is_local_main_process:
831
+ transformers.utils.logging.set_verbosity_warning()
832
+ diffusers.utils.logging.set_verbosity_info()
833
+ else:
834
+ transformers.utils.logging.set_verbosity_error()
835
+ diffusers.utils.logging.set_verbosity_error()
836
+
837
+ # If passed along, set the training seed now.
838
+ if args.seed is not None:
839
+ set_seed(args.seed)
840
+
841
+ # Handle the repository creation
842
+ if accelerator.is_main_process:
843
+ if args.output_dir is not None:
844
+ os.makedirs(args.output_dir, exist_ok=True)
845
+
846
+ if args.push_to_hub:
847
+ repo_id = create_repo(
848
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
849
+ ).repo_id
850
+
851
+ # Load the tokenizers
852
+ tokenizer_one = AutoTokenizer.from_pretrained(
853
+ args.pretrained_model_name_or_path,
854
+ subfolder="tokenizer",
855
+ revision=args.revision,
856
+ use_fast=False,
857
+ )
858
+ tokenizer_two = AutoTokenizer.from_pretrained(
859
+ args.pretrained_model_name_or_path,
860
+ subfolder="tokenizer_2",
861
+ revision=args.revision,
862
+ use_fast=False,
863
+ )
864
+
865
+ # import correct text encoder classes
866
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
867
+ args.pretrained_model_name_or_path, args.revision
868
+ )
869
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
870
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
871
+ )
872
+
873
+ # Load scheduler and models
874
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
875
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
876
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
877
+ )
878
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
879
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
880
+ )
881
+ vae_path = (
882
+ args.pretrained_model_name_or_path
883
+ if args.pretrained_vae_model_name_or_path is None
884
+ else args.pretrained_vae_model_name_or_path
885
+ )
886
+ vae = AutoencoderKL.from_pretrained(
887
+ vae_path,
888
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
889
+ revision=args.revision,
890
+ variant=args.variant,
891
+ )
892
+ unet = UNet2DConditionModel.from_pretrained(
893
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
894
+ )
895
+
896
+ if args.controlnet_model_name_or_path:
897
+ logger.info("Loading existing controlnet weights")
898
+ controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
899
+ else:
900
+ logger.info("Initializing controlnet weights from unet")
901
+ controlnet = ControlNetModel.from_unet(unet)
902
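As a rough sanity check (a debugging sketch, not part of the training flow): when the ControlNet is initialized from the UNet in the branch above, its encoder-side layers are copied from the UNet while the conditioning embedding and the zero-initialized projection convolutions are new, so a copied layer such as conv_in should match exactly at this point.

n_unet = sum(p.numel() for p in unet.parameters())
n_controlnet = sum(p.numel() for p in controlnet.parameters())
print(f"unet parameters:       {n_unet / 1e6:.1f}M")
print(f"controlnet parameters: {n_controlnet / 1e6:.1f}M")
print(torch.allclose(controlnet.conv_in.weight, unet.conv_in.weight))  # expected True for from_unet init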
+
903
+ def unwrap_model(model):
904
+ model = accelerator.unwrap_model(model)
905
+ model = model._orig_mod if is_compiled_module(model) else model
906
+ return model
907
+
908
+ # `accelerate` 0.16.0 will have better support for customized saving
909
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
910
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
911
+ def save_model_hook(models, weights, output_dir):
912
+ if accelerator.is_main_process:
913
+ i = len(weights) - 1
914
+
915
+ while len(weights) > 0:
916
+ weights.pop()
917
+ model = models[i]
918
+
919
+ sub_dir = "controlnet"
920
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
921
+
922
+ i -= 1
923
+
924
+ def load_model_hook(models, input_dir):
925
+ while len(models) > 0:
926
+ # pop models so that they are not loaded again
927
+ model = models.pop()
928
+
929
+ # load diffusers style into model
930
+ load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
931
+ model.register_to_config(**load_model.config)
932
+
933
+ model.load_state_dict(load_model.state_dict())
934
+ del load_model
935
+
936
+ accelerator.register_save_state_pre_hook(save_model_hook)
937
+ accelerator.register_load_state_pre_hook(load_model_hook)
938
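Because save_model_hook writes the ControlNet into a controlnet/ subfolder of each checkpoint, an intermediate state can later be reloaded on its own; a minimal sketch, with a hypothetical checkpoint directory name:

from diffusers import ControlNetModel

checkpoint_dir = "output/checkpoint-10000"  # illustrative: <output_dir>/checkpoint-<global_step>
reloaded_controlnet = ControlNetModel.from_pretrained(checkpoint_dir, subfolder="controlnet")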
+
939
+ vae.requires_grad_(False)
940
+ unet.requires_grad_(False)
941
+ text_encoder_one.requires_grad_(False)
942
+ text_encoder_two.requires_grad_(False)
943
+ controlnet.train()
944
+
945
+ if args.enable_npu_flash_attention:
946
+ if is_torch_npu_available():
947
+ logger.info("npu flash attention enabled.")
948
+ unet.enable_npu_flash_attention()
949
+ else:
950
+ raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
951
+
952
+ if args.enable_xformers_memory_efficient_attention:
953
+ if is_xformers_available():
954
+ import xformers
955
+
956
+ xformers_version = version.parse(xformers.__version__)
957
+ if xformers_version == version.parse("0.0.16"):
958
+ logger.warning(
959
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
960
+ )
961
+ unet.enable_xformers_memory_efficient_attention()
962
+ controlnet.enable_xformers_memory_efficient_attention()
963
+ else:
964
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
965
+
966
+ if args.gradient_checkpointing:
967
+ controlnet.enable_gradient_checkpointing()
968
+ unet.enable_gradient_checkpointing()
969
+
970
+ # Check that all trainable models are in full precision
971
+ low_precision_error_string = (
972
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
973
+ " doing mixed precision training, copy of the weights should still be float32."
974
+ )
975
+
976
+ if unwrap_model(controlnet).dtype != torch.float32:
977
+ raise ValueError(
978
+ f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}"
979
+ )
980
+
981
+ # Enable TF32 for faster training on Ampere GPUs,
982
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
983
+ if args.allow_tf32:
984
+ torch.backends.cuda.matmul.allow_tf32 = True
985
+
986
+ if args.scale_lr:
987
+ args.learning_rate = (
988
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
989
+ )
990
+
991
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
992
+ if args.use_8bit_adam:
993
+ try:
994
+ import bitsandbytes as bnb
995
+ except ImportError:
996
+ raise ImportError(
997
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
998
+ )
999
+
1000
+ optimizer_class = bnb.optim.AdamW8bit
1001
+ else:
1002
+ optimizer_class = torch.optim.AdamW
1003
+
1004
+ # Optimizer creation
1005
+ params_to_optimize = controlnet.parameters()
1006
+ optimizer = optimizer_class(
1007
+ params_to_optimize,
1008
+ lr=args.learning_rate,
1009
+ betas=(args.adam_beta1, args.adam_beta2),
1010
+ weight_decay=args.adam_weight_decay,
1011
+ eps=args.adam_epsilon,
1012
+ )
1013
+
1014
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
1015
+ # as these models are only used for inference, keeping weights in full precision is not required.
1016
+ weight_dtype = torch.float32
1017
+ if accelerator.mixed_precision == "fp16":
1018
+ weight_dtype = torch.float16
1019
+ elif accelerator.mixed_precision == "bf16":
1020
+ weight_dtype = torch.bfloat16
1021
+
1022
+ # Move vae, unet and text_encoder to device and cast to weight_dtype
1023
+ # The VAE is in float32 to avoid NaN losses.
1024
+ if args.pretrained_vae_model_name_or_path is not None:
1025
+ vae.to(accelerator.device, dtype=weight_dtype)
1026
+ else:
1027
+ vae.to(accelerator.device, dtype=torch.float32)
1028
+ unet.to(accelerator.device, dtype=weight_dtype)
1029
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
1030
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
1031
+
1032
+ # Here, we compute not just the text embeddings but also the additional embeddings
1033
+ # needed for the SD XL UNet to operate.
1034
+ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizers, is_train=True):
1035
+ original_size = (args.resolution, args.resolution)
1036
+ target_size = (args.resolution, args.resolution)
1037
+ crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
1038
+ prompt_batch = batch[args.caption_column]
1039
+
1040
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
1041
+ prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train
1042
+ )
1043
+ add_text_embeds = pooled_prompt_embeds
1044
+
1045
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
1046
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
1047
+ add_time_ids = torch.tensor([add_time_ids])
1048
+
1049
+ prompt_embeds = prompt_embeds.to(accelerator.device)
1050
+ add_text_embeds = add_text_embeds.to(accelerator.device)
1051
+ add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
1052
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
1053
+ unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1054
+
1055
+ return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
1056
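For a concrete picture of the micro-conditioning assembled here: with --resolution 1024 and zero crop offsets, add_time_ids is simply the following six numbers (illustrative values).

original_size = (1024, 1024)
crops_coords_top_left = (0, 0)
target_size = (1024, 1024)
print(list(original_size + crops_coords_top_left + target_size))  # [1024, 1024, 0, 0, 1024, 1024]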
+
1057
+ # Let's first compute all the embeddings so that we can free up the text encoders
1058
+ # from memory.
1059
+ text_encoders = [text_encoder_one, text_encoder_two]
1060
+ tokenizers = [tokenizer_one, tokenizer_two]
1061
+ train_dataset = get_train_dataset(args, accelerator)
1062
+ compute_embeddings_fn = functools.partial(
1063
+ compute_embeddings,
1064
+ text_encoders=text_encoders,
1065
+ tokenizers=tokenizers,
1066
+ proportion_empty_prompts=args.proportion_empty_prompts,
1067
+ )
1068
+ with accelerator.main_process_first():
1069
+ from datasets.fingerprint import Hasher
1070
+
1071
+ # fingerprint used by the cache for the other processes to load the result
1072
+ # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
1073
+ new_fingerprint = Hasher.hash(args)
1074
+ train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
1075
+
1076
+ del text_encoders, tokenizers
1077
+ gc.collect()
1078
+ torch.cuda.empty_cache()
1079
+
1080
+ # Then get the training dataset ready to be passed to the dataloader.
1081
+ train_dataset = prepare_train_dataset(train_dataset, accelerator)
1082
+
1083
+ train_dataloader = torch.utils.data.DataLoader(
1084
+ train_dataset,
1085
+ shuffle=True,
1086
+ collate_fn=collate_fn,
1087
+ batch_size=args.train_batch_size,
1088
+ num_workers=args.dataloader_num_workers,
1089
+ )
1090
+
1091
+ # Scheduler and math around the number of training steps.
1092
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
1093
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
1094
+ if args.max_train_steps is None:
1095
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
1096
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
1097
+ num_training_steps_for_scheduler = (
1098
+ args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
1099
+ )
1100
+ else:
1101
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
1102
+
1103
+ lr_scheduler = get_scheduler(
1104
+ args.lr_scheduler,
1105
+ optimizer=optimizer,
1106
+ num_warmup_steps=num_warmup_steps_for_scheduler,
1107
+ num_training_steps=num_training_steps_for_scheduler,
1108
+ num_cycles=args.lr_num_cycles,
1109
+ power=args.lr_power,
1110
+ )
1111
+
1112
+ # Prepare everything with our `accelerator`.
1113
+ controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1114
+ controlnet, optimizer, train_dataloader, lr_scheduler
1115
+ )
1116
+
1117
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1118
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1119
+ if args.max_train_steps is None:
1120
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1121
+ if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
1122
+ logger.warning(
1123
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
1124
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
1125
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
1126
+ )
1127
+ # Afterwards we recalculate our number of training epochs
1128
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1129
+
1130
+ # We need to initialize the trackers we use, and also store our configuration.
1131
+    # The trackers initialize automatically on the main process.
1132
+ if accelerator.is_main_process:
1133
+ tracker_config = dict(vars(args))
1134
+
1135
+ # tensorboard cannot handle list types for config
1136
+ tracker_config.pop("validation_prompt")
1137
+ tracker_config.pop("validation_image")
1138
+
1139
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
1140
+
1141
+ # Train!
1142
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1143
+
1144
+ logger.info("***** Running training *****")
1145
+ logger.info(f" Num examples = {len(train_dataset)}")
1146
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1147
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1148
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1149
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1150
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1151
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1152
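To make the step bookkeeping concrete, here is the same arithmetic for a hypothetical single-GPU run with 50,000 examples, train_batch_size=1 and gradient_accumulation_steps=4 (numbers chosen purely for illustration):

import math

num_examples = 50_000
train_batch_size, num_processes, gradient_accumulation_steps = 1, 1, 4

batches_per_epoch = math.ceil(num_examples / (train_batch_size * num_processes))   # 50000
updates_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)     # 12500
total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps  # 4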
+ global_step = 0
1153
+ first_epoch = 0
1154
+
1155
+ # Potentially load in the weights and states from a previous save
1156
+ if args.resume_from_checkpoint:
1157
+ if args.resume_from_checkpoint != "latest":
1158
+ path = os.path.basename(args.resume_from_checkpoint)
1159
+ else:
1160
+ # Get the most recent checkpoint
1161
+ dirs = os.listdir(args.output_dir)
1162
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1163
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1164
+ path = dirs[-1] if len(dirs) > 0 else None
1165
+
1166
+ if path is None:
1167
+ accelerator.print(
1168
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1169
+ )
1170
+ args.resume_from_checkpoint = None
1171
+ initial_global_step = 0
1172
+ else:
1173
+ accelerator.print(f"Resuming from checkpoint {path}")
1174
+ accelerator.load_state(os.path.join(args.output_dir, path))
1175
+ global_step = int(path.split("-")[1])
1176
+
1177
+ initial_global_step = global_step
1178
+ first_epoch = global_step // num_update_steps_per_epoch
1179
+ else:
1180
+ initial_global_step = 0
1181
+
1182
+ progress_bar = tqdm(
1183
+ range(0, args.max_train_steps),
1184
+ initial=initial_global_step,
1185
+ desc="Steps",
1186
+ # Only show the progress bar once on each machine.
1187
+ disable=not accelerator.is_local_main_process,
1188
+ )
1189
+
1190
+    from torch import nn
+
+    class MultiLayerClassifierHead(nn.Module):
+        def __init__(self, input_dim, hidden_dim, num_classes):
+            super(MultiLayerClassifierHead, self).__init__()
+            self.fc1 = nn.Linear(input_dim, hidden_dim)  # input -> hidden layer
+            self.fc2 = nn.Linear(hidden_dim, num_classes)  # hidden layer -> output
+            self.activation = nn.ReLU()  # non-linear activation
+
+        def forward(self, x):
+            # Global Average Pooling
+            x = F.adaptive_avg_pool2d(x, (1, 1))  # [batch, channels, h, w] -> [batch, channels, 1, 1]
+            x = x.view(x.size(0), -1)  # Flatten: [batch, channels]
+
+            # Two fully connected layers
+            x = self.fc1(x)  # hidden layer
+            x = self.activation(x)  # activation
+            x = self.fc2(x)  # output layer
+            return x
+
+    # Define the set of classes
+    class_types = ["circle", "star", "octagon"]
+
+    # Dictionary mapping each class name to an index
+    class_to_idx = {cls: idx for idx, cls in enumerate(class_types)}
+    print(class_to_idx)  # {'circle': 0, 'star': 1, 'octagon': 2}
+
+    # Initialize the classification head
+    classification_head = MultiLayerClassifierHead(input_dim=1280, hidden_dim=512, num_classes=len(class_types))
+    classification_head = classification_head.to(accelerator.device)
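The head is agnostic to the spatial size of the mid-block feature map because the adaptive average pooling collapses whatever H x W it receives; only the channel count (1280) is fixed. A standalone probe with a dummy tensor (shape chosen for illustration):

probe_head = MultiLayerClassifierHead(input_dim=1280, hidden_dim=512, num_classes=len(class_types))
dummy_mid_block = torch.randn(2, 1280, 16, 16)  # [batch, channels, h, w]
print(probe_head(dummy_mid_block).shape)        # torch.Size([2, 3])

Note that only controlnet.parameters() were handed to the optimizer earlier, so as written the head's two linear layers keep their random initialization; the classification loss still influences the ControlNet through the gradients that flow into mid_block_res_sample. If the head itself is meant to be trained, its parameters would also need to be added to params_to_optimize.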
+
1221
+ image_logs = None
1222
+ for epoch in range(first_epoch, args.num_train_epochs):
1223
+ for step, batch in enumerate(train_dataloader):
1224
+ with accelerator.accumulate(controlnet):
1225
+ # Convert images to latent space
1226
+ if args.pretrained_vae_model_name_or_path is not None:
1227
+ pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
1228
+ else:
1229
+ pixel_values = batch["pixel_values"]
1230
+ latents = vae.encode(pixel_values).latent_dist.sample()
1231
+ latents = latents * vae.config.scaling_factor
1232
+ if args.pretrained_vae_model_name_or_path is None:
1233
+ latents = latents.to(weight_dtype)
1234
+
1235
+ # Sample noise that we'll add to the latents
1236
+ noise = torch.randn_like(latents)
1237
+ bsz = latents.shape[0]
1238
+
1239
+ # Sample a random timestep for each image
1240
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
1241
+ timesteps = timesteps.long()
1242
+
1243
+ # Add noise to the latents according to the noise magnitude at each timestep
1244
+ # (this is the forward diffusion process)
1245
+ noisy_latents = noise_scheduler.add_noise(latents.float(), noise.float(), timesteps).to(
1246
+ dtype=weight_dtype
1247
+ )
1248
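Schematically, add_noise applies the closed-form DDPM forward process: the clean latents and the sampled noise are mixed according to the cumulative alphas of the scheduler. A sketch of the same computation using the variables in scope (illustration only, not a replacement for the scheduler call):

alphas_cumprod = noise_scheduler.alphas_cumprod.to(latents.device)
a_t = alphas_cumprod[timesteps].view(-1, 1, 1, 1)
noisy_latents_manual = a_t.sqrt() * latents.float() + (1.0 - a_t).sqrt() * noise.float()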
+
1249
+ # ControlNet conditioning.
1250
+ controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
1251
+ down_block_res_samples, mid_block_res_sample = controlnet(
1252
+ noisy_latents,
1253
+ timesteps,
1254
+ encoder_hidden_states=batch["prompt_ids"],
1255
+ added_cond_kwargs=batch["unet_added_conditions"],
1256
+ controlnet_cond=controlnet_image,
1257
+ return_dict=False,
1258
+ )
1259
+
1260
+ # Predict the noise residual
1261
+ model_pred = unet(
1262
+ noisy_latents,
1263
+ timesteps,
1264
+ encoder_hidden_states=batch["prompt_ids"],
1265
+ added_cond_kwargs=batch["unet_added_conditions"],
1266
+ down_block_additional_residuals=[
1267
+ sample.to(dtype=weight_dtype) for sample in down_block_res_samples
1268
+ ],
1269
+ mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
1270
+ return_dict=False,
1271
+ )[0]
1272
+
1273
+ # Get the target for loss depending on the prediction type
1274
+ if noise_scheduler.config.prediction_type == "epsilon":
1275
+ target = noise
1276
+ elif noise_scheduler.config.prediction_type == "v_prediction":
1277
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
1278
+ else:
1279
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1280
+                mse_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+                # print(mid_block_res_sample.shape)  # torch.Size([1, 1280, 16, 16])
+                # Compute the loss for the classification task
+                batch_labels = []  # list that will hold the ground-truth labels
+                for current_prompt in batch["text"]:  # process each prompt in the batch
+                    label = None  # default value
+                    for class_name in class_types:
+                        if class_name in current_prompt:
+                            label = class_to_idx[class_name]  # convert the class name to its index
+                            break
+                    if label is not None:
+                        batch_labels.append(label)
+                    else:
+                        raise ValueError(f"No class matching prompt '{current_prompt}' was found.")
+                batch_labels = torch.tensor(batch_labels).to(accelerator.device)
+
+                classification_logits = classification_head(mid_block_res_sample)
+                classification_loss = F.cross_entropy(classification_logits, batch_labels)
+
+                # Combine the main denoising loss with the classification loss
+                alpha = 0.1
+                loss = mse_loss + alpha * classification_loss  # alpha weights the classification task
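If a quick training-time signal for the auxiliary task is useful, the batch accuracy can be derived from the same logits (an optional addition using the variables already in scope):

with torch.no_grad():
    classification_accuracy = (classification_logits.argmax(dim=-1) == batch_labels).float().mean().item()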
+
1305
+ accelerator.backward(loss)
1306
+ if accelerator.sync_gradients:
1307
+ params_to_clip = controlnet.parameters()
1308
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1309
+ optimizer.step()
1310
+ lr_scheduler.step()
1311
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
1312
+
1313
+ # Checks if the accelerator has performed an optimization step behind the scenes
1314
+ if accelerator.sync_gradients:
1315
+ progress_bar.update(1)
1316
+ global_step += 1
1317
+
1318
+ # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
1319
+ if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
1320
+ if global_step % args.checkpointing_steps == 0:
1321
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1322
+ if args.checkpoints_total_limit is not None:
1323
+ checkpoints = os.listdir(args.output_dir)
1324
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1325
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1326
+
1327
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1328
+ if len(checkpoints) >= args.checkpoints_total_limit:
1329
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1330
+ removing_checkpoints = checkpoints[0:num_to_remove]
1331
+
1332
+ logger.info(
1333
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1334
+ )
1335
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1336
+
1337
+ for removing_checkpoint in removing_checkpoints:
1338
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1339
+ shutil.rmtree(removing_checkpoint)
1340
+
1341
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1342
+ accelerator.save_state(save_path)
1343
+ logger.info(f"Saved state to {save_path}")
1344
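Since accelerator.save_state captures the optimizer, the LR scheduler, and the RNG states along with the ControlNet weights, an interrupted run can continue from the newest of these directories by relaunching the script with --resume_from_checkpoint latest (or an explicit checkpoint-<step> name) together with the original training arguments.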
+
1345
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
1346
+ image_logs = log_validation(
1347
+ vae=vae,
1348
+ unet=unet,
1349
+ controlnet=controlnet,
1350
+ args=args,
1351
+ accelerator=accelerator,
1352
+ weight_dtype=weight_dtype,
1353
+ step=global_step,
1354
+ )
1355
+
1356
+ logs = {"mse_loss": mse_loss.detach().item(), "classification_loss": classification_loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1357
+ accelerator.log(logs, step=global_step)
1358
+ logs["batch_labels"] = batch_labels.detach().item()
1359
+ progress_bar.set_postfix(**logs)
1360
+
1361
+ if global_step >= args.max_train_steps:
1362
+ break
1363
+
1364
+    # Create the pipeline using the trained modules and save it.
1365
+ accelerator.wait_for_everyone()
1366
+ if accelerator.is_main_process:
1367
+ controlnet = unwrap_model(controlnet)
1368
+ controlnet.save_pretrained(args.output_dir)
1369
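Once the final weights are written to args.output_dir, they can be used with the standard SDXL ControlNet pipeline; a minimal inference sketch in which the base model id, the output directory, the prompt, and the conditioning image path are all placeholders:

import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained("output", torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

conditioning = load_image("conditioning_image.png")
image = pipe(
    "a blue circle on a white background", image=conditioning, num_inference_steps=30
).images[0]
image.save("result.png")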
+
1370
+ # Run a final round of validation.
1371
+ # Setting `vae`, `unet`, and `controlnet` to None to load automatically from `args.output_dir`.
1372
+ image_logs = None
1373
+ if args.validation_prompt is not None:
1374
+ image_logs = log_validation(
1375
+ vae=None,
1376
+ unet=None,
1377
+ controlnet=None,
1378
+ args=args,
1379
+ accelerator=accelerator,
1380
+ weight_dtype=weight_dtype,
1381
+ step=global_step,
1382
+ is_final_validation=True,
1383
+ )
1384
+
1385
+ if args.push_to_hub:
1386
+ save_model_card(
1387
+ repo_id,
1388
+ image_logs=image_logs,
1389
+ base_model=args.pretrained_model_name_or_path,
1390
+ repo_folder=args.output_dir,
1391
+ )
1392
+ upload_folder(
1393
+ repo_id=repo_id,
1394
+ folder_path=args.output_dir,
1395
+ commit_message="End of training",
1396
+ ignore_patterns=["step_*", "epoch_*"],
1397
+ )
1398
+
1399
+ accelerator.end_training()
1400
+
1401
+
1402
+ if __name__ == "__main__":
1403
+ args = parse_args()
1404
+ main(args)