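# Provenance (page metadata captured alongside this file; not configuration):
#   tags: PyTorch
#   last change: commit 5029ebb (verified) by jubeku, "Update readme and add config file"
#   file size: 3.35 kB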
# Framework version this config targets: lightning.pytorch==2.4.0
# Global random seed handed to LightningCLI so runs are reproducible.
seed_everything: 42
### Trainer configuration
# Arguments forwarded to the Lightning Trainer.
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  # Uncomment to train with mixed precision:
  # precision: 16-mixed
  logger:
    class_path: TensorBoardLogger
    init_args:
      save_dir: ./experiments
      name: finetune_region
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        # Validation checks without val/loss improvement before stopping
        # (one check per epoch here, since check_val_every_n_epoch is 1).
        patience: 100
  max_epochs: 300
  check_val_every_n_epoch: 1
  log_every_n_steps: 20
  enable_checkpointing: true
  default_root_dir: ./experiments
### Data configuration
data:
  class_path: GenericNonGeoPixelwiseRegressionDataModule
  init_args:
    batch_size: 64
    num_workers: 8
    # Augmentations applied to training samples; ToTensorV2 must stay last.
    train_transform:
      - class_path: albumentations.HorizontalFlip
        init_args:
          p: 0.5
      - class_path: albumentations.RandomRotate90
        init_args:
          p: 0.5
      - class_path: albumentations.VerticalFlip
        init_args:
          p: 0.5
      - class_path: ToTensorV2
    # All bands present in the input data, in the order they appear there.
    # -1 is a placeholder for a band that exists in the data but is discarded.
    dataset_bands:
      - -1
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - -1
      - -1
      - -1
      - -1
    # Bands taken from the input data and fed to the model.
    # Keep in sync with model_args.bands / in_channels.
    output_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    # Positions of RED, GREEN and BLUE within output_bands
    # (presumably used for RGB visualisation).
    rgb_indices:
      - 2
      - 1
      - 0
    # Root directories of the training, validation and test splits:
    train_data_root: train_images
    train_label_data_root: train_labels
    val_data_root: val_images
    val_label_data_root: val_labels
    test_data_root: test_images
    test_label_data_root: test_labels
    # Per-band mean of the training dataset, in output_bands order.
    means:
      - 556.025024
      - 910.020020
      - 1039.141968
      - 2665.447266
      - 2361.062256
      - 1633.309326
    # Per-band standard deviation of the training dataset, in output_bands order.
    stds:
      - 413.787903
      - 562.086670
      - 819.830444
      - 816.528381
      - 1120.049438
      - 1072.057861
    # Value that missing label pixels are replaced with; kept equal to the
    # model's ignore_index (-1) so those pixels are ignored by the loss.
    no_label_replace: -1
    # Value that missing input pixels are replaced with.
    no_data_replace: 0
### Model configuration
model:
  class_path: terratorch.tasks.PixelwiseRegressionTask
  init_args:
    # Arguments passed to the model factory to build backbone, decoder and head.
    model_args:
      decoder: UperNetDecoder
      pretrained: false
      backbone: prithvi_swin_B
      backbone_drop_path_rate: 0.3
      decoder_channels: 32
      # Number of entries in `bands` below.
      in_channels: 6
      # Must match data.init_args.output_bands.
      bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
      num_frames: 1
      head_dropout: 0.16
      head_final_act: torch.nn.ReLU
      head_learned_upscale_layers: 2
    loss: rmse
    # Label value ignored in the loss; equals data no_label_replace (-1).
    ignore_index: -1
    freeze_backbone: false
    freeze_decoder: false
    model_factory: PrithviModelFactory
    # Uncomment this block for tiled inference (224 px tiles, 32 px overlap):
    # tiled_inference_parameters:
    #   h_crop: 224
    #   h_stride: 192
    #   w_crop: 224
    #   w_stride: 192
    #   average_patches: true
### Optimizer configuration
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 5.0e-05
    weight_decay: 0.3
# Lower the learning rate when val/loss stops improving
# (the same metric the EarlyStopping callback monitors).
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss
# Output data type; presumably applied when predictions are written out
# by the TerraTorch CLI (confirm).
out_dtype: 'float32'