The `transformertf` package exposes the entrypoint `transformertf`, implemented with Lightning CLI, which allows the user to call `transformertf fit -c [config.yml]` to train a model.
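Since the entrypoint is a standard Lightning CLI application, the usual Lightning CLI options are available as well; for example:

```bash
# Train a model from a configuration file
transformertf fit -c config.yml

# Print the fully resolved configuration (standard Lightning CLI option),
# useful as a starting point for writing a new configuration file
transformertf fit -c config.yml --print_config
```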
## Sample configuration
A sample configuration to train a TFT is the following:
```yaml
# lightning.pytorch==2.2.2
seed_everything: 0
trainer:
  logger:
    class_path: lightning.pytorch.loggers.NeptuneLogger
    init_args:
      api_key: null
      project: lua/Pretrain-TFT-MBI
      run: null
      log_model_checkpoints: true
    dict_kwargs:
      dependencies: infer
  callbacks:
    - class_path: lightning.pytorch.callbacks.EarlyStopping
      init_args:
        monitor: validation/loss
        patience: 10
        mode: min
    - class_path: transformertf.callbacks.SetOptimizerLRCallback
      init_args:
        lr_file: /tmp/lr.txt
        'on': step
    - class_path: transformertf.callbacks.LogHparamsCallback
      init_args:
        monitor: validation/loss
    - class_path: lightning.pytorch.callbacks.RichProgressBar
      init_args:
        refresh_rate: 1
        leave: false
        theme:
          metrics_format: .2e
  max_epochs: 50
  min_epochs: 25
  val_check_interval: 0.1
  check_val_every_n_epoch: 1
  num_sanity_val_steps: null
  log_every_n_steps: 10
  gradient_clip_val: 1.0
  use_distributed_sampler: false
verbose: 0
transfer_ckpt: null
lr_step_interval: epoch
lr_monitor:
  logging_interval: epoch
  log_momentum: false
  log_weight_decay: false
model_summary:
  max_depth: 2
fit:
  checkpoint_every:
    dirpath: checkpoints
    filename: epoch={epoch}-RMSE={validation/RMSE:.4f}
    monitor: validation/RMSE
    every_n_epochs: 50
    auto_insert_metric_name: false
  checkpoint_best:
    dirpath: checkpoints
    filename: epoch={epoch}-RMSE={validation/RMSE:.4f}
    monitor: validation/RMSE
    auto_insert_metric_name: false
ckpt_path: null
data:
  init_args:
    known_covariates:
      - I_sim_noise_A
      - I_sim_A_dot
    target_covariate: B_sim_eddy_noise_T
    train_df_paths:
      - ~/cernbox/hysteresis/dipole/datasets/pretraining/pretrain_train_24h.parquet
    val_df_paths:
      - ~/cernbox/hysteresis/dipole/datasets/pretraining/pretrain_validation_1h.parquet
    normalize: false
    ctxt_seq_len: 600
    tgt_seq_len: 200
    min_ctxt_seq_len: 100
    min_tgt_seq_len: 100
    randomize_seq_len: true
    stride: 1
    downsample: 20
    downsample_method: interval
    target_depends_on: I_sim_noise_A
    time_column: time_ms
    time_format: relative
    extra_transforms:
      I_sim_noise_A:
        - class_path: transformertf.data.RunningNormalizer
          init_args:
            num_features_: 1
            center_: 1820.0
            scale_: 1740.0
            frozen_: true
      I_sim_A_dot:
        - class_path: transformertf.data.RunningNormalizer
          init_args:
            num_features_: 1
            center_: 0.0
            scale_: 1300.0
            frozen_: true
      B_sim_eddy_noise_T:
        - class_path: transformertf.data.DiscreteFunctionTransform
          init_args:
            xs_: ~/cernbox/hysteresis/calibration_fn/SPS_MB_I2B_CALIBRATION_FN_v7.csv
            ys_: null
        - class_path: transformertf.data.RunningNormalizer
          init_args:
            num_features_: 1
            center_: 0.0
            scale_: 0.00103
            frozen_: true
    batch_size: 64
    num_workers: 4
    shuffle: true
  class_path: transformertf.data.EncoderDecoderDataModule
lr_scheduler:
  class_path: lightning.pytorch.cli.ReduceLROnPlateau
  init_args:
    monitor: validation/loss
    factor: 0.5
    patience: 2
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 0.00045
    weight_decay: 0.0001
model:
  class_path: transformertf.models.temporal_fusion_transformer.TemporalFusionTransformer
  init_args:
    n_dim_model: 300
    hidden_continuous_dim: 64
    num_heads: 2
    num_lstm_layers: 2
    dropout: 0.12
    output_dim: 7
    criterion:
      class_path: transformertf.nn.QuantileLoss
      init_args:
        quantiles:
          - 0.25
          - 0.5
          - 0.75
    casual_attention: true
    log_grad_norm: false
```
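Assuming the configuration above is saved to a file, e.g. `pretrain_tft.yml` (a hypothetical name), training is launched with:

```bash
transformertf fit -c pretrain_tft.yml

# ckpt_path can also be overridden from the command line to resume training,
# e.g. from a hypothetical checkpoint file
transformertf fit -c pretrain_tft.yml --ckpt_path checkpoints/last.ckpt
```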
## Merging multiple configurations
The Lightning CLI API allows multiple configuration files to be combined by passing the `-c/--config` option multiple times. Configs are parsed from left to right and merged in that order to produce the final configuration, from which classes are instantiated; values in later files override those in earlier ones.
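For example, given two hypothetical files `a.yml` and `b.yml`:

```yaml
# a.yml
trainer:
  max_epochs: 100
  gradient_clip_val: 1.0
```

```yaml
# b.yml
trainer:
  max_epochs: 50
```

running `transformertf fit -c a.yml -c b.yml` yields an effective configuration with `max_epochs: 50` (taken from the later file) while `gradient_clip_val: 1.0` is kept from the earlier one.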
The following configuration files are provided:

### Base configurations
- `~/cernbox/hysteresis/configs/base/trainer.yml`: configures a default trainer with a `TensorBoardLogger` saving logs to `logs`, an `EarlyStopping` callback monitoring `validation/loss`, and a `gradient_clip_val` of 1.0. The default `max_epochs` and `min_epochs` are expected to be overridden by the user configuration. A rough sketch of such a file is shown after this list.
- `~/cernbox/hysteresis/configs/base/checkpoint.yml`: configures checkpointing callbacks with the default filename `epoch={epoch}-RMSE={validation/RMSE:.4f}.ckpt`, monitoring `validation/RMSE`.
- `~/cernbox/hysteresis/configs/base/reduce_lr_on_plateau.yml`: configures a `ReduceLROnPlateau` learning rate scheduler that by default monitors `validation/loss` and has a `patience` of 5 epochs.
- `~/cernbox/hysteresis/configs/base/transfer_tft.yml`: freezes all TFT model parameters except for `attn_grn`, `attn_gate2`, `attn_norm2`, and `output_layer`; it should be used as a first step for transfer learning and fine-tuning.
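As an illustration, below is a hypothetical sketch of `trainer.yml` based only on the description above; the actual file may differ in keys and values.

```yaml
# Hypothetical sketch of trainer.yml (not the actual file contents)
trainer:
  logger:
    class_path: lightning.pytorch.loggers.TensorBoardLogger
    init_args:
      save_dir: logs
  callbacks:
    - class_path: lightning.pytorch.callbacks.EarlyStopping
      init_args:
        monitor: validation/loss
        mode: min
  gradient_clip_val: 1.0
  # placeholders; expected to be overridden by the user configuration
  max_epochs: 100
  min_epochs: 1
```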
### Transformation configurations
- `~/cernbox/hysteresis/configs/mbi/mbi_transforms.yml`: configures scalers for the columns `I_meas_A_filtered`, `I_meas_A_filtered_dot`, and `B_meas_T_filtered`, as well as a calibration function for the target `B_meas_T_filtered`. The configuration additionally disables automatic normalization by the datamodule (since the scalers have already been defined/fitted).
- `~/cernbox/hysteresis/configs/mbi/mbi_transforms_.yml`: configures, in addition to the previous one, a scaler for `B_meas_T_filtered_`, which is the target variable but without the calibration function subtracted.
- `~/cernbox/hysteresis/configs/mbi/sim_preisach_transforms.yml`: configures scalers for the columns `I_sim_A`, `I_sim_A_dot`, and `B_sim_eddy_T`, and a calibration function for `B_sim_eddy_T`.
- `~/cernbox/hysteresis/configs/mbi/sim_preisach_noise_transforms.yml`: same as above, but scales `I_sim_noise_A`, `I_sim_A_dot`, and `B_sim_eddy_noise_T` instead.
- `~/cernbox/hysteresis/configs/mbi/sim_ja_transforms.yml`: same as above, but scales `I_sim_A`, `I_sim_A_dot`, and `B_sim_ja_eddy_T`.
- `~/cernbox/hysteresis/configs/mbi/relative_time_transforms.yml`: disables normalization, sets `time_format` to `relative`, and adds `DeltaTransform` and `StandardScaler` as extra transforms for `__time__`. This configuration must be used if the transforms above are used and `time_column` is specified (otherwise time is not normalized). A rough sketch of such a file is shown after this list.
- `~/cernbox/hysteresis/configs/mbi/absolute_time_transforms.yml`: disables normalization, sets `time_format` to `absolute`, and adds a `MaxScaler` as an extra transform for `__time__`. This configuration must be used if the transforms above are used and `time_column` is specified (otherwise time is not normalized).
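Below is a hypothetical sketch of what `relative_time_transforms.yml` could look like, mirroring the `extra_transforms` layout of the sample configuration above; the exact module paths of `DeltaTransform` and `StandardScaler` are assumed and may differ.

```yaml
# Hypothetical sketch of relative_time_transforms.yml (module paths assumed)
data:
  init_args:
    normalize: false
    time_format: relative
    extra_transforms:
      __time__:
        - class_path: transformertf.data.DeltaTransform
        - class_path: transformertf.data.StandardScaler
```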
### Datasets

- `~/cernbox/hysteresis/configs/datasets/pretrain_preisach_rdp.yaml`
- `~/cernbox/hysteresis/configs/datasets/pretrain_ja_rdp.yaml`
- `~/cernbox/hysteresis/dipole/datasets/v3/mbi_dataset_v3.yml`
- `~/cernbox/hysteresis/dipole/datasets/v2/train_v2.yml`
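Such a dataset configuration typically only needs to point the datamodule at the training and validation files; a hypothetical sketch, reusing the paths from the sample configuration above:

```yaml
# Hypothetical sketch of a dataset configuration (actual files may differ)
data:
  init_args:
    train_df_paths:
      - ~/cernbox/hysteresis/dipole/datasets/pretraining/pretrain_train_24h.parquet
    val_df_paths:
      - ~/cernbox/hysteresis/dipole/datasets/pretraining/pretrain_validation_1h.parquet
```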
## Typical training calls
To train a temporal fusion transformer with the Dipole dataset v3, a `ReduceLROnPlateau` scheduler, and the default checkpoint callbacks, use the following:
```bash
transformertf fit \
  -c ~/cernbox/hysteresis/configs/base/trainer.yml \
  -c ~/cernbox/hysteresis/configs/base/checkpoint.yml \
  -c ~/cernbox/hysteresis/configs/base/reduce_lr_on_plateau.yml \
  -c ~/cernbox/hysteresis/configs/mbi/mbi_transforms_.yml \
  -c ~/cernbox/hysteresis/configs/mbi/relative_time_transforms.yml \
  -c ~/cernbox/hysteresis/dipole/datasets/v3/mbi_dataset_v3.yml \
  -c tft_config.yml
```

An example `tft_config.yml` with settings that override some of the defaults from the base YAML files:
```yaml
seed_everything: false
trainer:
  enable_progress_bar: false
  max_epochs: 50
  min_epochs: 10
  val_check_interval: 0.5
  logger:
    class_path: lightning.pytorch.loggers.neptune.NeptuneLogger
    init_args:
      project: lua/TFT-MBI
      log_model_checkpoints: true
      prefix: ""
    dict_kwargs:
      dependencies: infer
      proxies:
        http_proxy: "http://cs-513-ml001:8080"
        https_proxy: "http://cs-513-ml001:8080"
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: validation/loss
    patience: 2
    min_lr: 1e-7
    factor: 0.5
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 5e-4
    weight_decay: 1e-4
model:
  class_path: transformertf.models.temporal_fusion_transformer.TemporalFusionTransformer
  init_args:
    n_dim_model: 500
    num_heads: 4
    hidden_continuous_dim: 64
    num_lstm_layers: 1
    dropout: 0.1
    log_grad_norm: false
    compile_model: true
    criterion:
      class_path: transformertf.nn.QuantileLoss
      init_args:
        quantiles: [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]
data:
  class_path: transformertf.data.EncoderDecoderDataModule
  init_args:
    known_covariates:
      - "I_meas_A_filtered"
      - "I_meas_A_filtered_dot"
    target_covariate: B_meas_T_filtered
    known_past_covariates:
      - B_meas_T_filtered_
    time_column: time_ms
    time_format: relative
    ctxt_seq_len: 1020
    tgt_seq_len: 540
    min_ctxt_seq_len: 180
    min_tgt_seq_len: 180
    randomize_seq_len: true
    batch_size: 64
    num_workers: 4
    downsample: 1
    stride: 20
```
**Warning**

If you are using Neptune, don't forget to export the `NEPTUNE_API_TOKEN` environment variable prior to calling `transformertf fit`. The token is not set in the YAML files so that it is not uploaded to Neptune in plaintext.
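For example (the token value below is a placeholder):

```bash
# Set the Neptune API token in the environment, then launch training as usual
export NEPTUNE_API_TOKEN="<your-neptune-api-token>"
transformertf fit -c tft_config.yml
```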