PyTorch
File size: 3,351 Bytes
62c1d5f
 
 
 
 
 
 
 
 
 
 
 
 
48d048e
62c1d5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# lightning.pytorch==2.4.0
# Global random seed for this run.
seed_everything: 42

### Trainer configuration
trainer:
  # Hardware: accelerator, strategy and device count are all left on "auto";
  # the run uses a single node.
  accelerator: auto
  devices: auto
  num_nodes: 1
  strategy: auto
  # precision: 16-mixed

  # Schedule: at most 300 epochs, validating after every epoch.
  max_epochs: 300
  check_val_every_n_epoch: 1

  # Outputs: checkpointing is enabled and metrics are logged every 20 steps.
  # NOTE(review): default_root_dir is ./experiments while the logger's
  # save_dir below is ../experiments -- confirm both locations are intended.
  enable_checkpointing: true
  log_every_n_steps: 20
  default_root_dir: ./experiments
  logger:
    class_path: TensorBoardLogger
    init_args:
      name: finetune_region
      save_dir: ../experiments

  callbacks:
    - class_path: RichProgressBar
    # Learning rate is recorded once per epoch.
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    # Early stopping watches val/loss with a patience of 100
    # (max_epochs is 300).
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 100

### Data configuration
data:
  class_path: GenericNonGeoPixelwiseRegressionDataModule
  init_args:
    batch_size: 64
    num_workers: 8

    # Training transforms, listed in order: horizontal flip, 90-degree
    # rotation and vertical flip (each with p: 0.5), then ToTensorV2.
    train_transform:
      - class_path: albumentations.HorizontalFlip
        init_args:
          p: 0.5
      - class_path: albumentations.RandomRotate90
        init_args:
          p: 0.5
      - class_path: albumentations.VerticalFlip
        init_args:
          p: 0.5
      - class_path: ToTensorV2

    # All bands present in the input data, in file order.
    # A -1 entry is a placeholder for a band that exists in the data but is
    # discarded.
    dataset_bands:
      - -1
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - -1
      - -1
      - -1
      - -1

    # Bands from the input data that are actually used.
    output_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2

    # NOTE(review): presumably positions within output_bands (2, 1, 0 would be
    # RED, GREEN, BLUE) -- confirm against the data module.
    rgb_indices: [2, 1, 0]

    # Root directories of the training, validation and test splits.
    train_data_root: train_images
    train_label_data_root: train_labels
    val_data_root: val_images
    val_label_data_root: val_labels
    test_data_root: test_images
    test_label_data_root: test_labels

    # Per-band mean of the training dataset (six values, one per output band).
    means:
      - 556.025024
      - 910.020020
      - 1039.141968
      - 2665.447266
      - 2361.062256
      - 1633.309326

    # Per-band standard deviation of the training dataset.
    stds:
      - 413.787903
      - 562.086670
      - 819.830444
      - 816.528381
      - 1120.049438
      - 1072.057861

    # No-data value in the label data.
    no_label_replace: -1
    # No-data value in the input data.
    no_data_replace: 0

### Model configuration
model:
  class_path: terratorch.tasks.PixelwiseRegressionTask
  init_args:
    model_factory: PrithviModelFactory
    model_args:
      # Backbone settings; in_channels equals the number of entries in bands.
      backbone: prithvi_swin_B
      pretrained: false
      backbone_drop_path_rate: 0.3
      in_channels: 6
      num_frames: 1
      bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2

      # Decoder settings.
      decoder: UperNetDecoder
      decoder_channels: 32

      # Regression head settings.
      head_dropout: 0.16
      head_final_act: torch.nn.ReLU
      head_learned_upscale_layers: 2

    # RMSE loss; ignore_index uses the same value (-1) as the data section's
    # no_label_replace.
    loss: rmse
    ignore_index: -1

    # Neither the backbone nor the decoder is frozen.
    freeze_backbone: false
    freeze_decoder: false

    # Uncomment the block below to enable tiled inference.
    # tiled_inference_parameters:
    #   h_crop: 224
    #   h_stride: 192
    #   w_crop: 224
    #   w_stride: 192
    #   average_patches: true
### Optimizer configuration
# AdamW with the learning rate and weight decay given in init_args.
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 5.0e-05
    weight_decay: 0.3
### Learning-rate scheduler configuration
# ReduceLROnPlateau monitors val/loss, the same metric the EarlyStopping
# callback in the trainer section monitors.
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss
# NOTE(review): presumably the dtype used when writing outputs -- confirm
# against the CLI that consumes this file.
out_dtype: float32