# lightning.pytorch==2.4.0
seed_everything: 42

### Trainer configuration
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  # precision: 16-mixed
  logger:
    class_path: TensorBoardLogger
    init_args:
      # NOTE(review): this points one directory above `default_root_dir`
      # (./experiments) set further down -- confirm the difference is intended.
      save_dir: ../experiments
      name: finetune_region
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 100
  max_epochs: 300
  check_val_every_n_epoch: 1
  log_every_n_steps: 20
  enable_checkpointing: true
  default_root_dir: ./experiments

### Data configuration
data:
  class_path: GenericNonGeoPixelwiseRegressionDataModule
  init_args:
    batch_size: 64
    num_workers: 8
    # Augmentations applied to the training split; each flip/rotation is
    # configured with probability p = 0.5.
    train_transform:
      - class_path: albumentations.HorizontalFlip
        init_args:
          p: 0.5
      - class_path: albumentations.RandomRotate90
        init_args:
          p: 0.5
      - class_path: albumentations.VerticalFlip
        init_args:
          p: 0.5
      - class_path: ToTensorV2
    # List every band present in the input data, in order.
    # A -1 entry is a placeholder for a band that exists in the data
    # but will be discarded.
    dataset_bands:
      - -1
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - -1
      - -1
      - -1
      - -1
    # Bands from the input data that are actually used.
    output_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    rgb_indices: [2, 1, 0]
    # Root directories of the training, validation and test data splits
    # (images and their labels).
    train_data_root: train_images
    train_label_data_root: train_labels
    val_data_root: val_images
    val_label_data_root: val_labels
    test_data_root: test_images
    test_label_data_root: test_labels
    # Per-band mean of the training dataset.
    means:
      - 556.025024
      - 910.020020
      - 1039.141968
      - 2665.447266
      - 2361.062256
      - 1633.309326
    # Per-band standard deviation of the training dataset.
    stds:
      - 413.787903
      - 562.086670
      - 819.830444
      - 816.528381
      - 1120.049438
      - 1072.057861
    # Nodata value in the label data.
    no_label_replace: -1
    # Nodata value in the input data.
    no_data_replace: 0

### Model configuration
model:
  class_path: terratorch.tasks.PixelwiseRegressionTask
  init_args:
    model_args:
      decoder: UperNetDecoder
      pretrained: false
      backbone: prithvi_swin_B
      backbone_drop_path_rate: 0.3
      decoder_channels: 32
      in_channels: 6
      bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
      num_frames: 1
      head_dropout: 0.16
      head_final_act: torch.nn.ReLU
      head_learned_upscale_layers: 2
    loss: rmse
    # Same value as `no_label_replace` in the data section above.
    ignore_index: -1
    freeze_backbone: false
    freeze_decoder: false
    model_factory: PrithviModelFactory
    # Uncomment this block for tiled inference.
    # tiled_inference_parameters:
    #   h_crop: 224
    #   h_stride: 192
    #   w_crop: 224
    #   w_stride: 192
    #   average_patches: true

optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 5.0e-05
    weight_decay: 0.3

lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss

out_dtype: float32