# lightning.pytorch==2.1.1
seed_everything: 0

### Trainer configuration
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  # precision: 16-mixed
  logger:
    # You can switch to TensorBoard for logging by uncommenting the below line
    # and commenting out the proceeding CSVLogger line.
    # class_path: TensorBoardLogger
    class_path: lightning.pytorch.loggers.csv_logs.CSVLogger
    init_args:
      save_dir: ./experiments
      name: fine_tune_suhi
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 600
  max_epochs: 600
  check_val_every_n_epoch: 1
  log_every_n_steps: 10
  enable_checkpointing: true
  default_root_dir: ./experiments

# dtype used when writing prediction outputs
out_dtype: float32

### Data configuration
data:
  class_path: GenericNonGeoPixelwiseRegressionDataModule
  init_args:
    batch_size: 1
    num_workers: 8
    train_transform:
      - class_path: albumentations.HorizontalFlip
        init_args:
          p: 0.5
      - class_path: albumentations.Rotate
        init_args:
          limit: 30
          border_mode: 0  # cv2.BORDER_CONSTANT
          value: 0
          mask_value: 1
          p: 0.5
      - class_path: ToTensorV2
    # Specify all bands which are in the input data.
    dataset_bands:
      # 6 HLS bands
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      # ERA5-Land t2m_spatial_avg
      - 7
      # ERA5-Land t2m_sunrise_avg
      - 8
      # ERA5-Land t2m_midnight_avg
      - 9
      # ERA5-Land t2m_delta_avg
      - 10
      # cos_tod
      - 11
      # sin_tod
      - 12
      # cos_doy
      - 13
      # sin_doy
      - 14
    # Specify the bands which are used from the input data.
    # Bands 8 - 14 were discarded in the final model
    output_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - 7
    rgb_indices:
      - 2
      - 1
      - 0
    # Directory roots to training, validation and test data splits:
    train_data_root: train/inputs
    train_label_data_root: train/targets
    val_data_root: val/inputs
    val_label_data_root: val/targets
    test_data_root: test/inputs
    test_label_data_root: test/targets
    img_grep: "*.inputs.tif"
    label_grep: "*.lst.tif"
    # Nodata value in the input data
    no_data_replace: 0
    # Nodata value in label (target) data
    no_label_replace: -9999
    # Mean value of the training dataset per band
    means:
      - 702.4754028320312
      - 1023.23291015625
      - 1118.8924560546875
      - 2440.750732421875
      - 2052.705810546875
      - 1514.15087890625
      - 21.031919479370117
    # Standard deviation of the training dataset per band
    stds:
      - 554.8255615234375
      - 613.5565185546875
      - 745.929443359375
      - 715.0111083984375
      - 761.47607421875
      - 734.991943359375
      - 8.66781997680664

### Model configuration
model:
  class_path: terratorch.tasks.PixelwiseRegressionTask
  init_args:
    model_args:
      decoder: UperNetDecoder
      pretrained: false
      backbone: prithvi_swin_L
      img_size: 224
      backbone_drop_path_rate: 0.3
      decoder_channels: 256
      in_channels: 7
      bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
        - 7
      num_frames: 1
    loss: rmse
    aux_heads:
      - name: aux_head
        decoder: IdentityDecoder
        decoder_args:
          head_dropout: 0.5
          head_channel_list:
            - 1
          head_final_act: torch.nn.LazyLinear
    aux_loss:
      aux_head: 0.4
    ignore_index: -9999
    freeze_backbone: false
    freeze_decoder: false
    model_factory: PrithviModelFactory
    # This block is commented out when inferencing on full tiles.
    # It is possible to inference on full tiles with this parameter on; the
    # benefit is that the compute requirement is smaller. However, using this
    # to inference on a full tile will introduce artefacting/"patchy" predictions.
    # tiled_inference_parameters:
    #   h_crop: 224
    #   h_stride: 224
    #   w_crop: 224
    #   w_stride: 224
    #   average_patches: true

optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 0.0001
    weight_decay: 0.05

lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss