| # lightning.pytorch==2.1.1 | |
| seed_everything: 0 | |
| ### Trainer configuration | |
| trainer: | |
| accelerator: auto | |
| strategy: auto | |
| devices: auto | |
| num_nodes: 1 | |
| # precision: 16-mixed | |
| logger: | |
| # You can switch to TensorBoard for logging by uncommenting the line below and commenting out the CSVLogger class_path line that follows it | |
| #class_path: TensorBoardLogger | |
| class_path: lightning.pytorch.loggers.csv_logs.CSVLogger | |
| init_args: | |
| save_dir: ./experiments | |
| name: fine_tune_suhi | |
| callbacks: | |
| - class_path: RichProgressBar | |
| - class_path: LearningRateMonitor | |
| init_args: | |
| logging_interval: epoch | |
| - class_path: EarlyStopping | |
| init_args: | |
| monitor: val/loss | |
| patience: 600 | |
| max_epochs: 600 | |
| check_val_every_n_epoch: 1 | |
| log_every_n_steps: 10 | |
| enable_checkpointing: true | |
| default_root_dir: ./experiments | |
| out_dtype: float32 | |
| ### Data configuration | |
| data: | |
| class_path: GenericNonGeoPixelwiseRegressionDataModule | |
| init_args: | |
| batch_size: 1 | |
| num_workers: 8 | |
| train_transform: | |
| - class_path: albumentations.HorizontalFlip | |
| init_args: | |
| p: 0.5 | |
| - class_path: albumentations.Rotate | |
| init_args: | |
| limit: 30 | |
| border_mode: 0 # cv2.BORDER_CONSTANT | |
| value: 0 | |
| mask_value: 1 | |
| p: 0.5 | |
| - class_path: ToTensorV2 | |
| # Specify all bands which are in the input data. | |
| dataset_bands: | |
| # 6 HLS bands | |
| - BLUE | |
| - GREEN | |
| - RED | |
| - NIR_NARROW | |
| - SWIR_1 | |
| - SWIR_2 | |
| # ERA5-Land t2m_spatial_avg | |
| - 7 | |
| # ERA5-Land t2m_sunrise_avg | |
| - 8 | |
| # ERA5-Land t2m_midnight_avg | |
| - 9 | |
| # ERA5-Land t2m_delta_avg | |
| - 10 | |
| # cos_tod | |
| - 11 | |
| # sin_tod | |
| - 12 | |
| # cos_doy | |
| - 13 | |
| # sin_doy | |
| - 14 | |
| # Specify the bands which are used from the input data. | |
| # Bands 8 - 14 were discarded in the final model | |
| output_bands: | |
| - BLUE | |
| - GREEN | |
| - RED | |
| - NIR_NARROW | |
| - SWIR_1 | |
| - SWIR_2 | |
| - 7 | |
| rgb_indices: | |
| - 2 | |
| - 1 | |
| - 0 | |
| # Root directories of the training, validation and test data splits: | |
| train_data_root: train/inputs | |
| train_label_data_root: train/targets | |
| val_data_root: val/inputs | |
| val_label_data_root: val/targets | |
| test_data_root: test/inputs | |
| test_label_data_root: test/targets | |
| img_grep: "*.inputs.tif" | |
| label_grep: "*.lst.tif" | |
| # Nodata value in the input data | |
| no_data_replace: 0 | |
| # Nodata value in label (target) data | |
| no_label_replace: -9999 | |
| # Mean value of the training dataset per band | |
| means: | |
| - 702.4754028320312 | |
| - 1023.23291015625 | |
| - 1118.8924560546875 | |
| - 2440.750732421875 | |
| - 2052.705810546875 | |
| - 1514.15087890625 | |
| - 21.031919479370117 | |
| # Standard deviation of the training dataset per band | |
| stds: | |
| - 554.8255615234375 | |
| - 613.5565185546875 | |
| - 745.929443359375 | |
| - 715.0111083984375 | |
| - 761.47607421875 | |
| - 734.991943359375 | |
| - 8.66781997680664 | |
| ### Model configuration | |
| model: | |
| class_path: terratorch.tasks.PixelwiseRegressionTask | |
| init_args: | |
| model_args: | |
| decoder: UperNetDecoder | |
| pretrained: false | |
| backbone: prithvi_swin_L | |
| #img_size: 224 | |
| backbone_drop_path_rate: 0.3 | |
| decoder_channels: 256 | |
| in_channels: 7 | |
| bands: | |
| - BLUE | |
| - GREEN | |
| - RED | |
| - NIR_NARROW | |
| - SWIR_1 | |
| - SWIR_2 | |
| - 7 | |
| num_frames: 1 | |
| loss: rmse | |
| aux_heads: | |
| - name: aux_head | |
| decoder: IdentityDecoder | |
| decoder_args: | |
| head_dropout: 0.5 | |
| head_channel_list: | |
| - 1 | |
| head_final_act: torch.nn.LazyLinear | |
| aux_loss: | |
| aux_head: 0.4 | |
| ignore_index: -9999 | |
| freeze_backbone: false | |
| freeze_decoder: false | |
| model_factory: PrithviModelFactory | |
| # This block is commented out when inferencing on full tiles. | |
| # It is possible to run inference on full tiles with this parameter enabled; the benefit is that the compute requirement is smaller. | |
| # However, using this to inference on a full tile will introduce artefacting/"patchy" predictions. | |
| # tiled_inference_parameters: | |
| # h_crop: 224 | |
| # h_stride: 224 | |
| # w_crop: 224 | |
| # w_stride: 224 | |
| # average_patches: true | |
| optimizer: | |
| class_path: torch.optim.AdamW | |
| init_args: | |
| lr: 0.0001 | |
| weight_decay: 0.05 | |
| lr_scheduler: | |
| class_path: ReduceLROnPlateau | |
| init_args: | |
| monitor: val/loss | |