The transformertf package exposes the transformertf entrypoint, which is implemented with Lightning CLI and allows the user to call

transformertf fit -c [config.yml]

to train a model.

Sample configuration

A sample configuration to train a TFT is the following:

# lightning.pytorch==2.2.2
seed_everything: 0
trainer:
  logger:
    class_path: lightning.pytorch.loggers.NeptuneLogger
    init_args:
      api_key: null
      project: lua/Pretrain-TFT-MBI
      run: null
      log_model_checkpoints: true
    dict_kwargs:
      dependencies: infer
  callbacks:
  - class_path: lightning.pytorch.callbacks.EarlyStopping
    init_args:
      monitor: validation/loss
      patience: 10
      mode: min
  - class_path: transformertf.callbacks.SetOptimizerLRCallback
    init_args:
      lr_file: /tmp/lr.txt
      'on': step
  - class_path: transformertf.callbacks.LogHparamsCallback
    init_args:
      monitor: validation/loss
  - class_path: lightning.pytorch.callbacks.RichProgressBar
    init_args:
      refresh_rate: 1
      leave: false
      theme:
        metrics_format: .2e

  max_epochs: 50
  min_epochs: 25
  val_check_interval: 0.1
  check_val_every_n_epoch: 1
  num_sanity_val_steps: null
  log_every_n_steps: 10
  gradient_clip_val: 1.0
  use_distributed_sampler: false
verbose: 0
transfer_ckpt: null
lr_step_interval: epoch
lr_monitor:
  logging_interval: epoch
  log_momentum: false
  log_weight_decay: false
model_summary:
  max_depth: 2
fit:
  checkpoint_every:
    dirpath: checkpoints
    filename: epoch={epoch}-RMSE={validation/RMSE:.4f}
    monitor: validation/RMSE
    every_n_epochs: 50
    auto_insert_metric_name: false
  checkpoint_best:
    dirpath: checkpoints
    filename: epoch={epoch}-RMSE={validation/RMSE:.4f}
    monitor: validation/RMSE
    auto_insert_metric_name: false
ckpt_path: null
data:
  init_args:
    known_covariates:
    - I_sim_noise_A
    - I_sim_A_dot
    target_covariate: B_sim_eddy_noise_T
    train_df_paths:
    - ~/cernbox/hysteresis/dipole/datasets/pretraining/pretrain_train_24h.parquet
    val_df_paths:
    - ~/cernbox/hysteresis/dipole/datasets/pretraining/pretrain_validation_1h.parquet
    normalize: false
    ctxt_seq_len: 600
    tgt_seq_len: 200
    min_ctxt_seq_len: 100
    min_tgt_seq_len: 100
    randomize_seq_len: true
    stride: 1
    downsample: 20
    downsample_method: interval
    target_depends_on: I_sim_noise_A
    time_column: time_ms
    time_format: relative
    extra_transforms:
      I_sim_noise_A:
      - class_path: transformertf.data.RunningNormalizer
        init_args:
          num_features_: 1
          center_: 1820.0
          scale_: 1740.0
          frozen_: true
      I_sim_A_dot:
      - class_path: transformertf.data.RunningNormalizer
        init_args:
          num_features_: 1
          center_: 0.0
          scale_: 1300.0
          frozen_: true
      B_sim_eddy_noise_T:
      - class_path: transformertf.data.DiscreteFunctionTransform
        init_args:
          xs_: ~/cernbox/hysteresis/calibration_fn/SPS_MB_I2B_CALIBRATION_FN_v7.csv
          ys_: null
      - class_path: transformertf.data.RunningNormalizer
        init_args:
          num_features_: 1
          center_: 0.0
          scale_: 0.00103
          frozen_: true
    batch_size: 64
    num_workers: 4
    shuffle: true
  class_path: transformertf.data.EncoderDecoderDataModule
lr_scheduler:
  class_path: lightning.pytorch.cli.ReduceLROnPlateau
  init_args:
    monitor: validation/loss
    factor: 0.5
    patience: 2
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 0.00045
    weight_decay: 0.0001
model:
  class_path: transformertf.models.temporal_fusion_transformer.TemporalFusionTransformer
  init_args:
    n_dim_model: 300
    hidden_continuous_dim: 64
    num_heads: 2
    num_lstm_layers: 2
    dropout: 0.12
    output_dim: 7
    criterion:
      class_path: transformertf.nn.QuantileLoss
      init_args:
        quantiles:
        - 0.25
        - 0.5
        - 0.75
    causal_attention: true
    log_grad_norm: false
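
Saved to a file, e.g. pretrain_tft.yml (the filename is arbitrary), this configuration is used for training with:

transformertf fit -c pretrain_tft.yml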

Merging multiple configurations

The Lightning CLI API allows multiple configuration files to be combined by passing the -c/--config option multiple times. Configs are parsed from left to right and merged in the same order to produce a final configuration, from which classes are instantiated.
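
For example, given two hypothetical files base.yml and override.yml, keys defined in both are taken from the right-most file on the command line:

# base.yml
trainer:
  max_epochs: 100
  gradient_clip_val: 1.0

# override.yml
trainer:
  max_epochs: 50

Invoked as

transformertf fit -c base.yml -c override.yml

the merged configuration uses max_epochs: 50 from override.yml and keeps gradient_clip_val: 1.0 from base.yml.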

The following configuration files are provided:

Base configurations

  • ~/cernbox/hysteresis/configs/base/trainer.yml Configures a default trainer with a TensorBoardLogger saving logs to logs, an EarlyStopping callback monitoring validation/loss, and gradient_clip_val set to 1.0. The default max_epochs and min_epochs are expected to be overridden by the user configuration. A sketch of what this file might contain is shown after this list.
  • ~/cernbox/hysteresis/configs/base/checkpoint.yml Configures checkpointing callbacks with the default filename epoch={epoch}-RMSE={validation/RMSE:.4f}.ckpt, and monitoring validation/RMSE.
  • ~/cernbox/hysteresis/configs/base/reduce_lr_on_plateau.yml Configures a ReduceLROnPlateau learning rate scheduler that by default monitors validation/loss and has a patience of 5 epochs.
  • ~/cernbox/hysteresis/configs/base/transfer_tft.yml Freezes all TFT model parameters except for attn_grn, attn_gate2, attn_norm2 and output_layer, and should be used as a first step for transfer learning and fine-tuning.
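
As an illustration, a minimal sketch of what trainer.yml might contain, based on the description above (the actual file may differ):

trainer:
  logger:
    class_path: lightning.pytorch.loggers.TensorBoardLogger
    init_args:
      save_dir: logs
  callbacks:
  - class_path: lightning.pytorch.callbacks.EarlyStopping
    init_args:
      monitor: validation/loss
      mode: min
  gradient_clip_val: 1.0
  max_epochs: 100  # expected to be overridden by the user configuration
  min_epochs: 10   # expected to be overridden by the user configuration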

Transformation configurations

  • ~/cernbox/hysteresis/configs/mbi/mbi_transforms.yml Configures scalers for the columns I_meas_A_filtered, I_meas_A_filtered_dot and B_meas_T_filtered, as well as a calibration function for the target B_meas_T_filtered. The configuration additionally disables automatic normalization by the datamodule (since the scalers have already been defined/fitted).

  • ~/cernbox/hysteresis/configs/mbi/mbi_transforms_.yml In addition to the previous configuration, configures a scaler for B_meas_T_filtered_, which is the target variable; no calibration function is subtracted.

  • ~/cernbox/hysteresis/configs/mbi/sim_preisach_transforms.yml Configures scalers for columns I_sim_A, I_sim_A_dot, B_sim_eddy_T, and a calibration function for B_sim_eddy_T.

  • ~/cernbox/hysteresis/configs/mbi/sim_preisach_noise_transforms.yml Same as above, but scales I_sim_noise_A, I_sim_A_dot, and B_sim_eddy_noise_T instead.

  • ~/cernbox/hysteresis/configs/mbi/sim_ja_transforms.yml Same as above, but scales I_sim_A, I_sim_A_dot and B_sim_ja_eddy_T.

  • ~/cernbox/hysteresis/configs/mbi/relative_time_transforms.yml Disables normalization, sets time_format to relative, and creates a DeltaTransform and a StandardScaler as extra transforms for __time__. This configuration must be used whenever the above transforms are used and time_column is specified (otherwise time is not normalized). A sketch of what this file might contain is shown after this list.

  • ~/cernbox/hysteresis/configs/mbi/absolute_time_transforms.yml Disables normalization, sets time_format to absolute, and creates a MaxScaler as an extra transform for __time__. This configuration must be used whenever the above transforms are used and time_column is specified (otherwise time is not normalized).
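
As an illustration, a minimal sketch of what relative_time_transforms.yml might contain; the class paths for DeltaTransform and StandardScaler are assumptions and may differ in the actual file:

data:
  init_args:
    normalize: false
    time_format: relative
    extra_transforms:
      __time__:
      - class_path: transformertf.data.DeltaTransform  # assumed class path
      - class_path: transformertf.data.StandardScaler  # assumed class path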

Datasets

  • ~/cernbox/hysteresis/configs/datasets/pretrain_preisach_rdp.yaml
  • ~/cernbox/hysteresis/configs/datasets/pretrain_ja_rdp.yaml
  • ~/cernbox/hysteresis/dipole/datasets/v3/mbi_dataset_v3.yml
  • ~/cernbox/hysteresis/dipole/datasets/v2/train_v2.yml

Typical training calls

To train a temporal fusion transformer with the Dipole dataset v3, ReduceLROnPlateau, and the default checkpoint callbacks, use the following:

transformertf fit \
	-c ~/cernbox/hysteresis/configs/base/trainer.yml \
	-c ~/cernbox/hysteresis/configs/base/checkpoint.yml \
	-c ~/cernbox/hysteresis/configs/base/reduce_lr_on_plateau.yml \
	-c ~/cernbox/hysteresis/configs/mbi/mbi_transforms_.yml \
	-c ~/cernbox/hysteresis/configs/mbi/relative_time_transforms.yml \
	-c ~/cernbox/hysteresis/dipole/datasets/v3/mbi_dataset_v3.yml \
	-c tft_config.yml

Example tft_config.yml with settings that override some of the defaults from the base YAML configurations:

seed_everything: false
trainer:
  enable_progress_bar: false
  max_epochs: 50
  min_epochs: 10
  val_check_interval: 0.5
  logger:
    class_path: lightning.pytorch.loggers.neptune.NeptuneLogger
    init_args:
      project: lua/TFT-MBI
      log_model_checkpoints: true
      prefix: ""
    dict_kwargs:
      dependencies: infer
      proxies:
        http_proxy: "http://cs-513-ml001:8080"
        https_proxy: "http://cs-513-ml001:8080"
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: validation/loss
    patience: 2
    min_lr: 1e-7
    factor: 0.5
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 5e-4
    weight_decay: 1e-4
model:
  class_path: transformertf.models.temporal_fusion_transformer.TemporalFusionTransformer
  init_args:
    n_dim_model: 500
    num_heads: 4
    hidden_continuous_dim: 64
    num_lstm_layers: 1
    dropout: 0.1
    log_grad_norm: false
    compile_model: true
    criterion:
      class_path: transformertf.nn.QuantileLoss
      init_args:
        quantiles: [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]
data:
  class_path: transformertf.data.EncoderDecoderDataModule
  init_args:
    known_covariates:
      - "I_meas_A_filtered"
      - "I_meas_A_filtered_dot"
    target_covariate: B_meas_T_filtered
    known_past_covariates:
      - B_meas_T_filtered_
    time_column: time_ms
    time_format: relative
    ctxt_seq_len: 1020
    tgt_seq_len: 540
    min_ctxt_seq_len: 180
    min_tgt_seq_len: 180
    randomize_seq_len: true
    batch_size: 64
    num_workers: 4
    downsample: 1
    stride: 20
 

Warning

If you are using Neptune, don't forget to export the NEPTUNE_API_TOKEN environment variable prior to calling transformertf fit. The token is not set in the YAML files to avoid uploading it in plaintext to Neptune.
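
For example, before starting training:

export NEPTUNE_API_TOKEN="<your-neptune-api-token>"
transformertf fit -c tft_config.yml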