---
# NOTE: integers below are written without `_` digit separators. YAML 1.1
# loaders accept `1_000_000` as an int, but the YAML 1.2 core schema resolves
# it as a string, so plain digits are the portable spelling.

# Model architecture
model:
  input_dim: 3  # RGB images
  hidden_dim: 512
  num_blocks: 8
  num_heads: 8
  patch_size: 8
  patch_stride: 4
  time_freq_dim: 256
  time_max_period: 1024
  mlp_ratio: 4
  use_bias: false
  padding: "SAME"
  pos_embed_cls_token: false
  pos_embed_extra_tokens: 0

# Training parameters
training:
  learning_rate: 1.0e-4
  batch_size: 128
  num_steps: 1000000  # 1M
  warmup_pct: 0.01
  weight_decay: 0.0
  grad_clip_norm: 100.0

# Checkpointing and logging
checkpointing:
  log_every: 1000
  plot_every: 10000
  save_every: 10000
  resume_from_checkpoint: null

# Data
data:
  train_split: 0.9  # 90% for training, 10% for testing
  random_seed: 42