# lightning.pytorch==2.4.0
# Global random seed for the run (LightningCLI `seed_everything` option).
seed_everything: 42
### Trainer configuration
# Arguments for the Lightning Trainer: hardware selection, logging, callbacks
# and the training-loop schedule.
trainer:
  # Let Lightning pick the accelerator, distribution strategy and device count.
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  # precision: 16-mixed
  logger:
    class_path: TensorBoardLogger
    init_args:
      # NOTE(review): this is `../experiments` while `default_root_dir` below is
      # `./experiments`, so logs and other run artifacts may end up in different
      # directories. Confirm this is intended.
      save_dir: ../experiments
      name: finetune_region
  callbacks:
    - class_path: RichProgressBar
    # Record the learning rate once per epoch.
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    # Stop early when `val/loss` stops improving; with validation running every
    # epoch (see `check_val_every_n_epoch`), patience is counted in epochs.
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 100
  max_epochs: 300
  check_val_every_n_epoch: 1
  log_every_n_steps: 20
  enable_checkpointing: true
  default_root_dir: ./experiments
### Data configuration
# Data module: where the image/label splits live, which bands are read and
# kept, how inputs are normalised, and which augmentations run during training.
data:
  class_path: GenericNonGeoPixelwiseRegressionDataModule
  init_args:
    batch_size: 64
    num_workers: 8
    # Training-only augmentations: each flip/rotation is applied with
    # probability 0.5, then the sample is converted to a tensor.
    train_transform:
      - class_path: albumentations.HorizontalFlip
        init_args:
          p: 0.5
      - class_path: albumentations.RandomRotate90
        init_args:
          p: 0.5
      - class_path: albumentations.VerticalFlip
        init_args:
          p: 0.5
      - class_path: ToTensorV2
    # Specify all bands which are in the input data, in file order.
    # -1 entries are placeholders for bands that are present in the data but
    # are discarded.
    dataset_bands:
      - -1
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - -1
      - -1
      - -1
      - -1
    # Specify the bands which are used from the input data.
    output_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    # Positions of RED, GREEN and BLUE within `output_bands` (RED=2, GREEN=1,
    # BLUE=0).
    rgb_indices:
      - 2
      - 1
      - 0
    # Directory roots of the training, validation and test data splits
    # (images and their label rasters).
    train_data_root: train_images
    train_label_data_root: train_labels
    val_data_root: val_images
    val_label_data_root: val_labels
    test_data_root: test_images
    test_label_data_root: test_labels
    # Mean value of the training dataset per band (six values; presumably in
    # `output_bands` order -- confirm).
    means:
      - 556.025024
      - 910.020020
      - 1039.141968
      - 2665.447266
      - 2361.062256
      - 1633.309326
    # Standard deviation of the training dataset per band (same ordering as
    # `means`).
    stds:
      - 413.787903
      - 562.086670
      - 819.830444
      - 816.528381
      - 1120.049438
      - 1072.057861
    # Nodata value in label data. It equals the model's `ignore_index` (-1),
    # presumably so that nodata pixels are excluded from the loss -- confirm.
    no_label_replace: -1
    # Nodata value in the input data
    no_data_replace: 0
### Model configuration
# Pixel-wise regression task: backbone/decoder/head settings under
# `model_args`, loss and freezing options at task level.
model:
  class_path: terratorch.tasks.PixelwiseRegressionTask
  init_args:
    model_args:
      decoder: UperNetDecoder
      # NOTE(review): the run is named `finetune_region` but `pretrained` is
      # false, so the backbone presumably starts from random weights unless a
      # checkpoint is supplied elsewhere -- confirm this is intended.
      pretrained: false
      backbone: prithvi_swin_B
      backbone_drop_path_rate: 0.3
      decoder_channels: 32
      # Six input channels, matching the six entries in `bands` below.
      in_channels: 6
      bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
      num_frames: 1
      head_dropout: 0.16
      # ReLU as the final activation keeps the regression output non-negative.
      head_final_act: torch.nn.ReLU
      head_learned_upscale_layers: 2
    loss: rmse
    # Same value as the data module's `no_label_replace`.
    ignore_index: -1
    # Neither the backbone nor the decoder is frozen.
    freeze_backbone: false
    freeze_decoder: false
    model_factory: PrithviModelFactory
    # uncomment this block for tiled inference
    # tiled_inference_parameters:
    #   h_crop: 224
    #   h_stride: 192
    #   w_crop: 224
    #   w_stride: 192
    #   average_patches: true
# Optimizer: AdamW with the learning rate and weight decay given below.
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 5.0e-05
    weight_decay: 0.3
# Learning-rate scheduler: reduce the rate when the monitored metric plateaus.
# It watches `val/loss`, the same metric the EarlyStopping callback monitors.
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss
# NOTE(review): presumably the dtype used when writing prediction outputs --
# confirm against the CLI that consumes this file.
out_dtype: float32