# Experiment configuration using the class_path / init_args layout.
# NOTE(review): indentation was reconstructed from a flattened copy of this
# file - confirm the nesting against the CLI that consumes it.

# Global random seed for the run.
seed_everything: 42

# Trainer settings: hardware selection, logging, callbacks and run length.
trainer:
  # "auto" leaves accelerator / strategy / device selection to the framework.
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  logger:
    class_path: TensorBoardLogger
    init_args:
      save_dir: ./experiments
      name: finetune_region
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    # Monitors val/loss, the same metric the lr_scheduler section monitors.
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 100
  max_epochs: 300
  check_val_every_n_epoch: 1
  log_every_n_steps: 20
  enable_checkpointing: true
  # Same directory as the logger's save_dir above.
  default_root_dir: ./experiments

# Data module settings: loading, augmentation, band selection, normalisation.
data:
  class_path: GenericNonGeoPixelwiseRegressionDataModule
  init_args:
    batch_size: 64
    num_workers: 8
    # Training-time augmentations; each flip/rotation is applied with p = 0.5.
    train_transform:
      - class_path: albumentations.HorizontalFlip
        init_args:
          p: 0.5
      - class_path: albumentations.RandomRotate90
        init_args:
          p: 0.5
      - class_path: albumentations.VerticalFlip
        init_args:
          p: 0.5
      - class_path: ToTensorV2
    # Eleven entries describing the bands in the input data.
    # NOTE(review): -1 presumably marks bands that are present in the files
    # but unused - confirm against the data module documentation.
    dataset_bands:
      - -1
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - -1
      - -1
      - -1
      - -1
    # The six named bands from dataset_bands.
    output_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    # NOTE(review): presumably indices into output_bands (2, 1, 0 would be
    # RED, GREEN, BLUE) - confirm.
    rgb_indices:
      - 2
      - 1
      - 0
    # Image and label roots for the train / val / test splits.
    train_data_root: train_images
    train_label_data_root: train_labels
    val_data_root: val_images
    val_label_data_root: val_labels
    test_data_root: test_images
    test_label_data_root: test_labels
    # Six values each, matching the number of output_bands entries.
    # NOTE(review): presumably listed in output_bands order - confirm.
    means:
      - 556.025024
      - 910.020020
      - 1039.141968
      - 2665.447266
      - 2361.062256
      - 1633.309326
    stds:
      - 413.787903
      - 562.086670
      - 819.830444
      - 816.528381
      - 1120.049438
      - 1072.057861
    # -1 is also the model section's ignore_index value.
    no_label_replace: -1
    no_data_replace: 0

# Model / task settings for pixel-wise regression.
model:
  class_path: terratorch.tasks.PixelwiseRegressionTask
  init_args:
    # NOTE(review): the split between model_args and task-level init_args was
    # reconstructed from key names - confirm against the task's signature.
    model_args:
      decoder: UperNetDecoder
      pretrained: false
      backbone: prithvi_swin_B
      backbone_drop_path_rate: 0.3
      decoder_channels: 32
      # Six input channels, one per entry in bands below.
      in_channels: 6
      bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
      num_frames: 1
      head_dropout: 0.16
      head_final_act: torch.nn.ReLU
      head_learned_upscale_layers: 2
    loss: rmse
    # Same value as the data section's no_label_replace.
    ignore_index: -1
    freeze_backbone: false
    freeze_decoder: false
    model_factory: PrithviModelFactory

# Optimizer: AdamW with an explicit learning rate and weight decay.
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 5.0e-05
    weight_decay: 0.3

# Learning-rate scheduler, monitoring val/loss (the metric EarlyStopping
# also monitors in the trainer section).
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss

# NOTE(review): placed at top level because it follows the lr_scheduler
# section and is not an argument of that scheduler - confirm where the
# consuming CLI expects this key.
out_dtype: float32