# File: image_modification/configs/latent-diffusion/lsun_churches_f8-autoencoder-ldm.yaml
# Repository page header (not part of the config): ablattmann, "add code", aea2f58, 2.6 kB
# Latent-diffusion model definition. Each `target` names a Python class by its
# dotted import path and the sibling `params` mapping holds the arguments for
# that class (presumably resolved by main.py - verify against the loader).
model:
  # set to target_lr by starting main.py with '--scale_lr False'
  base_learning_rate: 5.0e-5
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0155
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    loss_type: l1
    first_stage_key: "image"
    cond_stage_key: "image"
    # image_size / channels describe the latent, not the 256px RGB input; they
    # equal unet_config.params.image_size / in_channels and the first stage's
    # z_channels / embed_dim below.
    image_size: 32
    channels: 4
    cond_stage_trainable: False
    concat_mode: False
    scale_by_std: True
    monitor: 'val/loss_simple_ema'

    # Learning-rate schedule: 10000 warmup steps.
    # NOTE(review): f_min equals f_max and cycle_lengths is huge, so the
    # multiplier presumably stays at 1 after warm-up - confirm against
    # LambdaLinearScheduler.
    scheduler_config:
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [10000]
        cycle_lengths: [10000000000000]
        f_start: [1.e-6]
        f_max: [1.]
        f_min: [1.]

    # Denoising network operating on the 4-channel, 32x32 latent.
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 192
        attention_resolutions: [1, 2, 4, 8]  # 32, 16, 8, 4
        num_res_blocks: 2
        channel_mult: [1, 2, 2, 4, 4]  # 32, 16, 8, 4, 2
        num_heads: 8
        use_scale_shift_norm: True
        resblock_updown: True

    # First-stage autoencoder loaded from a checkpoint.
    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: "val/rec_loss"
        # NOTE(review): absolute path on a site-specific NFS mount; point this
        # at a locally available checkpoint before running elsewhere.
        ckpt_path: "/export/compvis-nfs/user/ablattma/logs/braket/2021-11-26T11-25-56_lsun_churches-convae-f8-ft_from_oi/checkpoints/step=000180071-fidfrechet_inception_distance=2.335.ckpt"
        ddconfig:
          double_z: True
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          # num_down = len(ch_mult)-1, i.e. 3 here; assuming each level halves
          # the resolution, 256 / 2**3 = 32 matches image_size above - confirm.
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        # The autoencoder's loss is replaced by an identity module.
        lossconfig:
          target: torch.nn.Identity

    # Sentinel string rather than a target/params mapping.
    cond_stage_config: "__is_unconditional__"
# Data module wiring: LSUN Churches train and validation datasets, both
# configured with size 256 (the same value as the first-stage `resolution`).
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 24  # TODO: was 96 in our experiments
    num_workers: 5
    wrap: False
    train:
      target: ldm.data.lsun.LSUNChurchesTrain
      params:
        size: 256
    validation:
      target: ldm.data.lsun.LSUNChurchesValidation
      params:
        size: 256
# Training-loop settings: callbacks plus trainer flags.
lightning:
  callbacks:
    # Image logging callback defined in main.py.
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000  # TODO 5000
        max_images: 8
        increase_log_steps: False

    # Extra checkpoint callback triggered every 20000 training steps.
    metrics_over_trainsteps_checkpoint:
      target: pytorch_lightning.callbacks.ModelCheckpoint
      params:
        every_n_train_steps: 20000

  # NOTE(review): presumably forwarded to the pytorch_lightning Trainer -
  # confirm in main.py.
  trainer:
    benchmark: True