# File size: 5,292 Bytes
# cfb7702
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
---
# Diffusion model definition. Every `target` is a dotted class path; the
# sibling `params` mapping is presumably passed to that class as keyword
# arguments (the loader is not part of this file -- confirm).
model:
  base_learning_rate: 1.0e-4
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: true
    # `cls` is the same key the class embedder below reads as its input_key.
    log_keys:
      - cls

    # Learning-rate schedule. Each setting is a one-element list; f_max and
    # f_min are both 1.0, with f_start at 1.0e-6 and 10000 warm-up steps.
    scheduler_config:
      target: sgm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [10000]
        cycle_lengths: [10000000000000]
        f_start: [1.0e-6]
        f_max: [1.0]
        f_min: [1.0]

    # Discrete denoiser with 1000 indices. The sigma sampler under
    # loss_fn_config uses the same num_idx and the same discretization class.
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    # UNet backbone: 4 channels in and out, matching z_channels / embed_dim of
    # the first-stage autoencoder configured further down.
    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: true
        in_channels: 4
        out_channels: 4
        model_channels: 256
        channel_mult: [1, 2, 4]
        num_res_blocks: 2
        attention_resolutions: [1, 2, 4]
        num_head_channels: 64
        transformer_depth: 1
        spatial_transformer_attn_type: softmax-xformers
        # context_dim equals the class embedder's embed_dim (1024).
        context_dim: 1024
        num_classes: sequential
        # NOTE(review): presumably has to equal the combined width of the two
        # ConcatTimestepEmbedderND outputs below -- confirm before changing.
        adm_in_channels: 1024

    # Conditioning inputs. The list order is kept exactly as authored; do not
    # reorder entries without checking how the conditioner combines them.
    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          # Trainable class-label embedding over 1000 classes.
          - input_key: cls
            is_trainable: true
            ucg_rate: 0.2
            target: sgm.modules.encoders.modules.ClassEmbedder
            params:
              add_sequence_dim: true
              embed_dim: 1024
              n_classes: 1000

          # Frozen embedding of the original image size.
          - input_key: original_size_as_tuple
            is_trainable: false
            ucg_rate: 0.2
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

          # Frozen embedding of the crop's top-left coordinates.
          - input_key: crop_coords_top_left
            is_trainable: false
            ucg_rate: 0.2
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

    # First-stage autoencoder that provides the 4-channel latent space.
    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKL
      params:
        # USER: CKPT_PATH looks like a placeholder; replace it with the path
        # to your autoencoder checkpoint.
        ckpt_path: CKPT_PATH
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        # torch.nn.Identity is configured in the loss slot of the autoencoder.
        lossconfig:
          target: torch.nn.Identity

    # Training loss: eps weighting with discretely sampled sigmas.
    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
        loss_weighting_config:
          target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
          params:
            num_idx: 1000
            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    # Sampler configuration: 50 steps, VanillaCFG guider with scale 5.0.
    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
          params:
            scale: 5.0

# Training data module. Postprocessors are listed in the order authored;
# keep that order unless the pipeline's semantics have been checked.
data:
  target: sgm.data.dataset.StableDataModuleFromConfig
  params:
    train:
      datapipeline:
        urls:
          # USER: adapt this path to the root of your custom dataset
          - DATA_PATH

        pipeline_config:
          shardshuffle: 10000
          # USER: you may want to adapt this depending on your available RAM
          sample_shuffle: 10000

        decoders:
          - pil

        postprocessors:
          # Resize to 256, then convert to a tensor.
          - target: sdata.mappers.TorchVisionImageTransforms
            params:
              # USER: you may want to adapt this key for your custom dataset
              key: jpg
              transforms:
                - target: torchvision.transforms.Resize
                  params:
                    size: 256
                    # NOTE(review): integer interpolation code; presumably
                    # bicubic -- confirm against the torchvision version used.
                    interpolation: 3
                - target: torchvision.transforms.ToTensor

          - target: sdata.mappers.Rescaler

          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
            params:
              # USER: you may want to adapt both keys for your custom dataset
              h_key: height
              w_key: width

      loader:
        batch_size: 64
        num_workers: 6

# Training-loop settings: checkpoint cadence, callbacks, and trainer flags.
lightning:
  # Checkpoint every 5000 training steps.
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    # Second checkpoint cadence, every 25000 training steps.
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000

    # Image logging: enabled, every 1000 batches, at most 8 images.
    image_logger:
      target: main.ImageLogger
      params:
        disabled: false
        enable_autocast: false
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: true
        log_first_step: false
        log_images_kwargs:
          use_ema_scope: false
          N: 8
          n_rows: 2

  trainer:
    # The trailing comma makes this the string "0," rather than the integer 0.
    # NOTE(review): presumably read as a device list holding index 0 -- confirm.
    devices: '0,'
    benchmark: true
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_epochs: 1000