Files changed (2) hide show
  1. configs/config.yaml +0 -174
  2. configs/config_full_tile.yaml +0 -176
configs/config.yaml DELETED
@@ -1,174 +0,0 @@
1
- # lightning.pytorch==2.1.1
2
- seed_everything: 0
3
-
4
- ### Trainer configuration
5
- trainer:
6
- accelerator: auto
7
- strategy: auto
8
- devices: auto
9
- num_nodes: 1
10
- # precision: 16-mixed
11
- logger:
12
- # You can swtich to TensorBoard for logging by uncommenting the below line and commenting out the procedding line
13
- #class_path: TensorBoardLogger
14
- class_path: lightning.pytorch.loggers.csv_logs.CSVLogger
15
- init_args:
16
- save_dir: ./experiments
17
- name: fine_tune_suhi
18
- callbacks:
19
- - class_path: RichProgressBar
20
- - class_path: LearningRateMonitor
21
- init_args:
22
- logging_interval: epoch
23
- - class_path: EarlyStopping
24
- init_args:
25
- monitor: val/loss
26
- patience: 600
27
- max_epochs: 600
28
- check_val_every_n_epoch: 1
29
- log_every_n_steps: 10
30
- enable_checkpointing: true
31
- default_root_dir: ./experiments
32
- out_dtype: float32
33
-
34
- ### Data configuration
35
- data:
36
- class_path: GenericNonGeoPixelwiseRegressionDataModule
37
- init_args:
38
- batch_size: 64
39
- num_workers: 8
40
- train_transform:
41
- - class_path: albumentations.HorizontalFlip
42
- init_args:
43
- p: 0.5
44
- - class_path: albumentations.Rotate
45
- init_args:
46
- limit: 30
47
- border_mode: 0 # cv2.BORDER_CONSTANT
48
- value: 0
49
- mask_value: 1
50
- p: 0.5
51
- - class_path: ToTensorV2
52
- # Specify all bands which are in the input data.
53
- dataset_bands:
54
- # 6 HLS bands
55
- - BLUE
56
- - GREEN
57
- - RED
58
- - NIR_NARROW
59
- - SWIR_1
60
- - SWIR_2
61
- # ERA5-Land t2m_spatial_avg
62
- - 7
63
- # ERA5-Land t2m_sunrise_avg
64
- - 8
65
- # ERA5-Land t2m_midnight_avg
66
- - 9
67
- # ERA5-Land t2m_delta_avg
68
- - 10
69
- # cos_tod
70
- - 11
71
- # sin_tod
72
- - 12
73
- # cos_doy
74
- - 13
75
- # sin_doy
76
- - 14
77
- # Specify the bands which are used from the input data.
78
- # Bands 8 - 14 were discarded in the final model
79
- output_bands:
80
- - BLUE
81
- - GREEN
82
- - RED
83
- - NIR_NARROW
84
- - SWIR_1
85
- - SWIR_2
86
- - 7
87
- rgb_indices:
88
- - 2
89
- - 1
90
- - 0
91
- # Directory roots to training, validation and test datasplits:
92
- train_data_root: train/inputs
93
- train_label_data_root: train/targets
94
- val_data_root: val/inputs
95
- val_label_data_root: val/targets
96
- test_data_root: test/inputs
97
- test_label_data_root: test/targets
98
- img_grep: "*.inputs.tif"
99
- label_grep: "*.lst.tif"
100
- # Nodata value in the input data
101
- no_data_replace: 0
102
- # Nodata value in label (target) data
103
- no_label_replace: -9999
104
- # Mean value of the training dataset per band
105
- means:
106
- - 702.4754028320312
107
- - 1023.23291015625
108
- - 1118.8924560546875
109
- - 2440.750732421875
110
- - 2052.705810546875
111
- - 1514.15087890625
112
- - 21.031919479370117
113
- # Standard deviation of the training dataset per band
114
- stds:
115
- - 554.8255615234375
116
- - 613.5565185546875
117
- - 745.929443359375
118
- - 715.0111083984375
119
- - 761.47607421875
120
- - 734.991943359375
121
- - 8.66781997680664
122
-
123
- ### Model configuration
124
- model:
125
- class_path: terratorch.tasks.PixelwiseRegressionTask
126
- init_args:
127
- model_args:
128
- decoder: UperNetDecoder
129
- pretrained: false
130
- backbone: prithvi_swin_L
131
- img_size: 224
132
- backbone_drop_path_rate: 0.3
133
- decoder_channels: 256
134
- in_channels: 7
135
- bands:
136
- - BLUE
137
- - GREEN
138
- - RED
139
- - NIR_NARROW
140
- - SWIR_1
141
- - SWIR_2
142
- - 7
143
- num_frames: 1
144
- loss: rmse
145
- aux_heads:
146
- - name: aux_head
147
- decoder: IdentityDecoder
148
- decoder_args:
149
- head_dropout: 0.5
150
- head_channel_list:
151
- - 1
152
- head_final_act: torch.nn.LazyLinear
153
- aux_loss:
154
- aux_head: 0.4
155
- ignore_index: -9999
156
- freeze_backbone: false
157
- freeze_decoder: false
158
- model_factory: PrithviModelFactory
159
- # uncomment this block for tiled inference
160
- tiled_inference_parameters:
161
- h_crop: 224
162
- h_stride: 224
163
- w_crop: 224
164
- w_stride: 224
165
- average_patches: true
166
- optimizer:
167
- class_path: torch.optim.AdamW
168
- init_args:
169
- lr: 0.0001
170
- weight_decay: 0.05
171
- lr_scheduler:
172
- class_path: ReduceLROnPlateau
173
- init_args:
174
- monitor: val/loss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/config_full_tile.yaml DELETED
@@ -1,176 +0,0 @@
1
- # lightning.pytorch==2.1.1
2
- seed_everything: 0
3
-
4
- ### Trainer configuration
5
- trainer:
6
- accelerator: auto
7
- strategy: auto
8
- devices: auto
9
- num_nodes: 1
10
- # precision: 16-mixed
11
- logger:
12
- # You can swtich to TensorBoard for logging by uncommenting the below line and commenting out the procedding line
13
- #class_path: TensorBoardLogger
14
- class_path: lightning.pytorch.loggers.csv_logs.CSVLogger
15
- init_args:
16
- save_dir: ./experiments
17
- name: fine_tune_suhi
18
- callbacks:
19
- - class_path: RichProgressBar
20
- - class_path: LearningRateMonitor
21
- init_args:
22
- logging_interval: epoch
23
- - class_path: EarlyStopping
24
- init_args:
25
- monitor: val/loss
26
- patience: 600
27
- max_epochs: 600
28
- check_val_every_n_epoch: 1
29
- log_every_n_steps: 10
30
- enable_checkpointing: true
31
- default_root_dir: ./experiments
32
- out_dtype: float32
33
-
34
- ### Data configuration
35
- data:
36
- class_path: GenericNonGeoPixelwiseRegressionDataModule
37
- init_args:
38
- batch_size: 1
39
- num_workers: 8
40
- train_transform:
41
- - class_path: albumentations.HorizontalFlip
42
- init_args:
43
- p: 0.5
44
- - class_path: albumentations.Rotate
45
- init_args:
46
- limit: 30
47
- border_mode: 0 # cv2.BORDER_CONSTANT
48
- value: 0
49
- mask_value: 1
50
- p: 0.5
51
- - class_path: ToTensorV2
52
- # Specify all bands which are in the input data.
53
- dataset_bands:
54
- # 6 HLS bands
55
- - BLUE
56
- - GREEN
57
- - RED
58
- - NIR_NARROW
59
- - SWIR_1
60
- - SWIR_2
61
- # ERA5-Land t2m_spatial_avg
62
- - 7
63
- # ERA5-Land t2m_sunrise_avg
64
- - 8
65
- # ERA5-Land t2m_midnight_avg
66
- - 9
67
- # ERA5-Land t2m_delta_avg
68
- - 10
69
- # cos_tod
70
- - 11
71
- # sin_tod
72
- - 12
73
- # cos_doy
74
- - 13
75
- # sin_doy
76
- - 14
77
- # Specify the bands which are used from the input data.
78
- # Bands 8 - 14 were discarded in the final model
79
- output_bands:
80
- - BLUE
81
- - GREEN
82
- - RED
83
- - NIR_NARROW
84
- - SWIR_1
85
- - SWIR_2
86
- - 7
87
- rgb_indices:
88
- - 2
89
- - 1
90
- - 0
91
- # Directory roots to training, validation and test datasplits:
92
- train_data_root: train/inputs
93
- train_label_data_root: train/targets
94
- val_data_root: val/inputs
95
- val_label_data_root: val/targets
96
- test_data_root: test/inputs
97
- test_label_data_root: test/targets
98
- img_grep: "*.inputs.tif"
99
- label_grep: "*.lst.tif"
100
- # Nodata value in the input data
101
- no_data_replace: 0
102
- # Nodata value in label (target) data
103
- no_label_replace: -9999
104
- # Mean value of the training dataset per band
105
- means:
106
- - 702.4754028320312
107
- - 1023.23291015625
108
- - 1118.8924560546875
109
- - 2440.750732421875
110
- - 2052.705810546875
111
- - 1514.15087890625
112
- - 21.031919479370117
113
- # Standard deviation of the training dataset per band
114
- stds:
115
- - 554.8255615234375
116
- - 613.5565185546875
117
- - 745.929443359375
118
- - 715.0111083984375
119
- - 761.47607421875
120
- - 734.991943359375
121
- - 8.66781997680664
122
-
123
- ### Model configuration
124
- model:
125
- class_path: terratorch.tasks.PixelwiseRegressionTask
126
- init_args:
127
- model_args:
128
- decoder: UperNetDecoder
129
- pretrained: false
130
- backbone: prithvi_swin_L
131
- img_size: 224
132
- backbone_drop_path_rate: 0.3
133
- decoder_channels: 256
134
- in_channels: 7
135
- bands:
136
- - BLUE
137
- - GREEN
138
- - RED
139
- - NIR_NARROW
140
- - SWIR_1
141
- - SWIR_2
142
- - 7
143
- num_frames: 1
144
- loss: rmse
145
- aux_heads:
146
- - name: aux_head
147
- decoder: IdentityDecoder
148
- decoder_args:
149
- head_dropout: 0.5
150
- head_channel_list:
151
- - 1
152
- head_final_act: torch.nn.LazyLinear
153
- aux_loss:
154
- aux_head: 0.4
155
- ignore_index: -9999
156
- freeze_backbone: false
157
- freeze_decoder: false
158
- model_factory: PrithviModelFactory
159
- # This block is commented out when inferencing on full tiles.
160
- # It is possible to inference on full tiles with this paramter on, the benefit is that the compute requirement is smaller.
161
- # However, using this to inference on a full tile will introduce artefacting/"patchy" predictions.
162
- # tiled_inference_parameters:
163
- # h_crop: 224
164
- # h_stride: 224
165
- # w_crop: 224
166
- # w_stride: 224
167
- # average_patches: true
168
- optimizer:
169
- class_path: torch.optim.AdamW
170
- init_args:
171
- lr: 0.0001
172
- weight_decay: 0.05
173
- lr_scheduler:
174
- class_path: ReduceLROnPlateau
175
- init_args:
176
- monitor: val/loss