# LightningCLI-style fine-tuning configuration (class_path / init_args objects),
# pixelwise regression with TerraTorch.

# Global RNG seed applied before building the trainer, data and model.
seed_everything: 0
# Arguments forwarded to lightning.pytorch.Trainer.
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1

  # CSV metrics are written under save_dir/name (./experiments/fine_tune_suhi).
  logger:
    class_path: lightning.pytorch.loggers.csv_logs.CSVLogger
    init_args:
      save_dir: ./experiments
      name: fine_tune_suhi
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    # NOTE(review): patience (600) equals max_epochs (600) with validation every
    # epoch, so EarlyStopping can never trigger before training ends on its own.
    # Lower patience if early stopping is actually intended.
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 600
  max_epochs: 600
  check_val_every_n_epoch: 1
  log_every_n_steps: 10
  enable_checkpointing: true
  default_root_dir: ./experiments
# Top-level option, not a Trainer argument.
# Presumably the dtype used when writing prediction outputs; confirm against the
# CLI that consumes this file.
out_dtype: float32
data:
  class_path: GenericNonGeoPixelwiseRegressionDataModule
  init_args:
    batch_size: 64
    num_workers: 8
    # Training-only augmentations, applied to image and target together.
    train_transform:
      - class_path: albumentations.HorizontalFlip
        init_args:
          p: 0.5
      - class_path: albumentations.Rotate
        init_args:
          limit: 30
          border_mode: 0  # cv2.BORDER_CONSTANT: pad rotated-out pixels with a constant
          # NOTE(review): `value`/`mask_value` are the older albumentations argument
          # names (newer releases use `fill`/`fill_mask`); check the installed version.
          value: 0
          # NOTE(review): rotated-out target pixels are filled with 1. That is not the
          # ignore value (-9999, see no_label_replace and model ignore_index), so they
          # would count as real targets. Confirm whether -9999 was intended.
          mask_value: 1
          p: 0.5
      - class_path: ToTensorV2

    # All bands stored in the input files, in file order. Bands without a named
    # alias are referenced by integer.
    dataset_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - 7
      - 8
      - 9
      - 10
      - 11
      - 12
      - 13
      - 14

    # Subset of dataset_bands fed to the model. Same list and order as model `bands`.
    output_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - 7
    # Positions within output_bands (RED, GREEN, BLUE). Presumably used for RGB
    # visualisation.
    rgb_indices:
      - 2
      - 1
      - 0

    train_data_root: train/inputs
    train_label_data_root: train/targets
    val_data_root: val/inputs
    val_label_data_root: val/targets
    test_data_root: test/inputs
    test_label_data_root: test/targets
    # Glob patterns selecting input images and target rasters.
    img_grep: "*.inputs.tif"
    label_grep: "*.lst.tif"

    # Presumably the replacement values for no-data pixels in inputs and targets.
    no_data_replace: 0
    no_label_replace: -9999

    # One entry per band in output_bands (7 values, same order). Presumably
    # per-band normalisation statistics.
    means:
      - 702.4754028320312
      - 1023.23291015625
      - 1118.8924560546875
      - 2440.750732421875
      - 2052.705810546875
      - 1514.15087890625
      - 21.031919479370117
    stds:
      - 554.8255615234375
      - 613.5565185546875
      - 745.929443359375
      - 715.0111083984375
      - 761.47607421875
      - 734.991943359375
      - 8.66781997680664
model:
  class_path: terratorch.tasks.PixelwiseRegressionTask
  init_args:
    # Arguments for the model factory (backbone + decoder construction).
    model_args:
      decoder: UperNetDecoder
      pretrained: false
      backbone: prithvi_swin_L
      img_size: 224
      backbone_drop_path_rate: 0.3
      decoder_channels: 256
      # Equals the number of entries in `bands` below (7).
      in_channels: 7
      # Same list and order as data.init_args.output_bands.
      bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
        - 7
      num_frames: 1
    loss: rmse
    aux_heads:
      - name: aux_head
        decoder: IdentityDecoder
        decoder_args:
          head_dropout: 0.5
          head_channel_list:
            - 1
          head_final_act: torch.nn.LazyLinear
    # Keyed by aux head name. Presumably the weight of that head's loss.
    aux_loss:
      aux_head: 0.4
    # Same value as data.init_args.no_label_replace (-9999). Presumably this
    # excludes replaced no-data target pixels from the loss.
    ignore_index: -9999
    freeze_backbone: false
    freeze_decoder: false
    model_factory: PrithviModelFactory

    # Crop and stride are equal (224, matching img_size), so tiles do not overlap.
    tiled_inference_parameters:
      h_crop: 224
      h_stride: 224
      w_crop: 224
      w_stride: 224
      average_patches: true
# Optimizer instantiated by the CLI for the task's parameters.
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 0.0001
    weight_decay: 0.05
# Reduce the learning rate when the monitored metric (val/loss) stops improving.
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss