For some reason, training on ml3 ends up with the following error:

transformertf fit -c ~/cernbox/hysteresis/configs/base/trainer.yml -c ~/cernbox/hysteresis/configs/base/checkpoint.yml -c ~/cernbox/hysteresis/configs/base/reduce_lr_on_plateau.yml -c ~/cernbox/hysteresis/dipole/datasets/v4/mbi_dataset_v4.yml -c ~/cernbox/hysteresis/configs/mbi/mbi_transforms.yml -c tft_pretftmbi_25.yml
2025-03-13 13:19:42.827052: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-03-13 13:19:42.827181: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-03-13 13:19:42.828516: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-13 13:19:42.835017: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-03-13 13:19:43.771004: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
2025-03-13 13:19:52 WARNING  The normalizer is frozen and cannot be fitted.                                      _scaler.py:156
                    WARNING  The normalizer is frozen and cannot be fitted.                                      _scaler.py:156
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃    ┃ Name                         ┃ Type                            ┃ Params ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ 0  │ criterion                    │ QuantileLoss                    │      0 │
│ 1  │ model                        │ TemporalFusionTransformerModel  │  7.8 M │
│ 2  │ model.static_vs              │ VariableSelection               │  205 K │
│ 3  │ model.enc_vs                 │ VariableSelection               │  825 K │
│ 4  │ model.dec_vs                 │ VariableSelection               │  825 K │
│ 5  │ model.static_ctxt_vs         │ GatedResidualNetwork            │  361 K │
│ 6  │ model.static_ctxt_enrichment │ GatedResidualNetwork            │  361 K │
│ 7  │ model.lstm_init_hidden       │ GatedResidualNetwork            │  361 K │
│ 8  │ model.lstm_init_cell         │ GatedResidualNetwork            │  361 K │
│ 9  │ model.enc_lstm               │ LSTM                            │  1.4 M │
│ 10 │ model.dec_lstm               │ LSTM                            │  1.4 M │
│ 11 │ model.enc_gate1              │ GatedLinearUnit                 │  180 K │
│ 12 │ model.enc_norm1              │ AddNorm                         │    600 │
│ 13 │ model.static_enrichment      │ GatedResidualNetwork            │  451 K │
│ 14 │ model.attn                   │ InterpretableMultiHeadAttention │  225 K │
│ 15 │ model.attn_gate1             │ GatedLinearUnit                 │  180 K │
│ 16 │ model.attn_norm1             │ AddNorm                         │    600 │
│ 17 │ model.attn_grn               │ GatedResidualNetwork            │  361 K │
│ 18 │ model.attn_gate2             │ GatedLinearUnit                 │  180 K │
│ 19 │ model.attn_norm2             │ AddNorm                         │    600 │
│ 20 │ model.output_layer           │ Linear                          │  2.1 K │
└────┴──────────────────────────────┴─────────────────────────────────┴────────┘
Trainable params: 7.8 M
Non-trainable params: 0
Total params: 7.8 M
Total estimated model params size (MB): 31
Sanity Checking DataLoader 0:   0%|                                                                      | 0/2 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/opt/home/lua/.venvs/train-hysteresis/bin/transformertf", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/opt/data/lua/code/transformertf/transformertf/main.py", line 333, in main
    LightningCLI(
  File "/opt/data/lua/code/transformertf/transformertf/main.py", line 79, in __init__
    super().__init__(*args, parser_kwargs={"parser_mode": "omegaconf"}, **kwargs)
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/lightning/pytorch/cli.py", line 388, in __init__
    self._run_subcommand(self.subcommand)
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/lightning/pytorch/cli.py", line 679, in _run_subcommand
    fn(**fn_kwargs)
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 543, in fit
    call._call_and_handle_interrupt(
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 579, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 986, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1030, in _run_stage
    self._run_sanity_check()
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1059, in _run_sanity_check
    val_loop.run()
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/lightning/pytorch/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 135, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 396, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 412, in validation_step
    return self.lightning_module.validation_step(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/data/lua/code/transformertf/transformertf/models/temporal_fusion_transformer/_lightning.py", line 101, in validation_step
    model_output = self(batch)
                   ^^^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/data/lua/code/transformertf/transformertf/models/temporal_fusion_transformer/_lightning.py", line 60, in forward
    return self.model(
           ^^^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/data/lua/code/transformertf/transformertf/models/temporal_fusion_transformer/_model.py", line 282, in forward
    attn_output = self.attn_gate2(attn_output)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1164, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
           ^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 662, in transform
    tracer.run()
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
    super().run()
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ^^^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 654, in wrapper
    speculation = self.speculate()
                  ^^^^^^^^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2575, in speculate
    return self.speculation_log.next(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 204, in next
    entry.instruction_pointer == instruction_pointer
AssertionError:
SpeculationLog diverged at index 2 (log had 7 entries):
- Expected: /opt/data/lua/code/transformertf/transformertf/nn/_glu.py:55 (LOAD_FAST at ip=7)
- Actual: /opt/data/lua/code/transformertf/transformertf/nn/_glu.py:55 (CALL at ip=11)
Previous instruction: /opt/home/lua/.venvs/train-hysteresis/lib/python3.11/site-packages/torch/nn/modules/dropout.py:70(CALL @ 11)
 
There are two usual reasons why this may have occured:
- When Dynamo analysis restarted, the second run took a different path than
  the first.  If this occurred, the previous instruction is the critical instruction that
  behaved differently.
- Speculation entries are only added under certain conditions (as seen in
  step()), e.g., there must exist operators in the graph; those conditions may
  have changed on restart.
 
If this divergence was intentional, clear the speculation log before restarting (do NOT
do this for graph breaks, you will infinite loop).
 
Otherwise, please submit a bug report, ideally including the contents of TORCH_LOGS=+dynamo
 
 
from user code:
   File "/opt/data/lua/code/transformertf/transformertf/nn/_glu.py", line 55, in forward
    x = self.fc1(x)
 
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
 
 
You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
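
If losing compilation is acceptable for a run, the eager fallback suggested at the end of the traceback can be applied as a stopgap; a minimal sketch, assuming it is executed at the top of the training entry point before any model is compiled (it hides the Dynamo failure rather than fixing it):

    import torch._dynamo

    # Fall back to eager execution when Dynamo compilation fails,
    # instead of aborting the run (workaround only; the divergence remains).
    torch._dynamo.config.suppress_errors = True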

The issue is fixed by downgrading to the following package versions:

pip install triton==3.1.0 jsonargparse==4.32 torch==2.5.1 lightning==2.2.2
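
After the downgrade, a quick check that the active virtualenv actually picked up the pinned versions (the expected versions are taken from the pip command above; the snippet itself is only illustrative):

    from importlib.metadata import version

    # Compare installed package versions against the known-good pins.
    expected = {
        "torch": "2.5.1",
        "lightning": "2.2.2",
        "triton": "3.1.0",
        "jsonargparse": "4.32",
    }
    for pkg, wanted in expected.items():
        print(f"{pkg}: installed {version(pkg)}, expected {wanted}")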