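"""Drift-correct SFT cycles in the dipole (MBI) datasets.

The datasets are generated in ~/cernbox/hysteresis/dipole/datasets/v5/mbi_dataset_v5.yml.
They are a copy of the dipole datasets v4, except that SFT-type cycles are
drift-corrected, but only if the preceding cycle is an MD1 cycle (with a
maximum measured current above 500 A). In all other cases, e.g. when the
preceding cycle is a ZERO or another MD cycle, the SFT cycle is left untouched.

Each dataset is inspected manually before saving to avoid artifacts.

Usage:

    python drift_correct_sftpro.py datasets

where ``datasets`` is one or more parquet files or directories (searched
recursively); files with "preprocessed" in their name are skipped.
"""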

from __future__ import annotations
 
import argparse
import os
import pathlib
import shutil
import typing
import sys
 
import matplotlib.figure
import numpy as np
import pandas as pd
import warnings
import matplotlib.pyplot as plt
 
from hysteresis_scripts.data import BTrainDataset, BTrainCycle
 
 
# Location of the calibration table (CSV with a single header line).
CALIBRATION_FN_PATH = (
    pathlib.Path("~/cernbox/hysteresis/calibration_fn")
    / "SPS_MB_I2B_CALIBRATION_FN_v7.csv"
).expanduser()

# Read once at import time. It is consumed as ``np.interp(i, *calibration_fn.T)``,
# so column 0 is the interpolation abscissa (presumably current) and column 1
# the interpolated value (presumably field) — TODO confirm the CSV has exactly
# two columns, otherwise the extra ones are passed to np.interp positionally.
calibration_fn = np.loadtxt(CALIBRATION_FN_PATH, skiprows=1, delimiter=",")
 
 
def marker_idx(
    i_meas: np.ndarray,
    b_meas: np.ndarray,
    marker_value: float = 0.11,
    *,
    calibration: np.ndarray | None = None,
) -> int:
    """Locate the field marker in one cycle and return its sample index.

    The candidate index is the first sample where ``b_meas`` exceeds
    *marker_value*. To decide on which side of the step the marker sample
    lies, the gradient of the measured field is compared with the gradient of
    the field expected from the current (``calibration`` interpolated at
    ``i_meas``):

    * if the largest positive gradient deviation is within one sample of the
      candidate, the candidate is returned unchanged;
    * otherwise, if the largest negative deviation is within one sample of the
      candidate, the index one sample later is returned;
    * otherwise the candidate is returned unchanged.

    Parameters
    ----------
    i_meas, b_meas
        Measured current and field of a single cycle (same length).
    marker_value
        Field threshold that the marker crosses.
    calibration
        Calibration table used as ``np.interp(i_meas, *calibration.T)``.
        Defaults to the module-level ``calibration_fn``.

    Returns
    -------
    int
        Index of the marker sample.

    Raises
    ------
    IndexError
        If ``b_meas`` never exceeds *marker_value*. Callers catch this to skip
        cycles without a marker.
    """
    if calibration is None:
        calibration = calibration_fn

    b_ref = np.interp(i_meas, *calibration.T)
    diff = np.gradient(b_meas) - np.gradient(b_ref)

    # First threshold crossing; raises IndexError when there is none.
    idx = int(np.where(b_meas > marker_value)[0][0])

    # Which extremum sits at the crossing depends on whether the marker step
    # went up or down; only the "minimum next to the crossing" case shifts.
    near_max = abs(int(np.argmax(diff)) - idx) <= 1
    near_min = abs(int(np.argmin(diff)) - idx) <= 1
    if not near_max and near_min:
        idx += 1

    return idx
 
 
def _matches(
    constraint: typing.Callable[[BTrainCycle], bool] | str, cycle: BTrainCycle
) -> bool:
    """Return whether *cycle* passes a ``drift_correct`` cycle filter.

    A callable filter is called with the cycle and its result is used as a
    truth value. A non-empty string must equal ``cycle.user``. An empty string
    accepts every cycle.
    """
    if callable(constraint):
        return bool(constraint(cycle))
    return not constraint or cycle.user == constraint


def drift_correct(
    dataset: BTrainDataset,
    calibration_fn: np.ndarray,
    *,
    fix_only: typing.Callable[[BTrainCycle], bool] | str = "",
    prev_cycle_constr: typing.Callable[[BTrainCycle], bool] | str = "",
) -> BTrainDataset:
    """Remove the field drift visible as a jump at a cycle's marker.

    For every cycle except the first that passes *fix_only*, and whose
    preceding cycle passes *prev_cycle_constr*, the jump in ``B_meas_T``
    between the sample before the marker and the marker sample is measured.
    The part of the jump not explained by the change in ``I_ref_A`` (converted
    with *calibration_fn*) is treated as drift. The drift is spread as a linear
    ramp from the previous cycle's marker to the current cycle's marker.

    Parameters
    ----------
    dataset
        Dataset to correct. It is not modified; a corrected copy is returned.
    calibration_fn
        Calibration table used as ``np.interp(I_ref_A, *calibration_fn.T)``.
        ``marker_idx`` is called without it, so marker detection uses
        ``marker_idx``'s own default calibration.
    fix_only
        Filter selecting the cycles to correct: a predicate, or a ``user``
        name (empty string: all cycles).
    prev_cycle_constr
        Same kind of filter, applied to the preceding cycle.

    Returns
    -------
    BTrainDataset
        A new dataset with the corrected ``B_meas_T`` column. A line is
        printed for every corrected cycle; the drift is reported in gauss,
        i.e. ``B_meas_T`` is assumed to be in tesla.
    """
    dataset = BTrainDataset(dataset.df.copy())

    cycles = dataset.aslist()
    slices = list(dataset._slices.values())  # noqa: SLF001

    # All corrections accumulate in one writable copy of the column, written
    # back once at the end. This avoids pandas chained assignment (so it works
    # under Copy-on-Write / pandas 3). It also guarantees that a cycle
    # corrected both as "current" (head) and later as "previous" (tail) keeps
    # both corrections, instead of a stale per-cycle array clobbering the
    # first one.
    b_all = dataset.df["B_meas_T"].to_numpy(copy=True)
    positions = np.arange(len(b_all))

    for i, (cycle, s) in enumerate(zip(cycles, slices)):
        if i == 0 or not _matches(fix_only, cycle):
            continue

        try:
            idx = marker_idx(cycle.i_meas, cycle.b_meas)
        except IndexError:
            continue

        iref = cycle.df["I_ref_A"].to_numpy()
        # idx - 1 and idx must both be samples of this cycle: idx == 0 would
        # silently wrap idx - 1 to the last sample.
        if not 0 < idx < len(iref):
            continue

        # Field change expected from the reference-current change across the marker.
        b0 = np.interp(iref[idx - 1], *calibration_fn.T)
        b1 = np.interp(iref[idx], *calibration_fn.T)
        db = b1 - b0

        # Drift: measured jump at the marker minus the current-induced change.
        drift = cycle.b_meas[idx] - cycle.b_meas[idx - 1]
        drift -= db

        # The previous cycle defines where the correction ramp starts.
        prev_cycle = cycles[i - 1]
        if not _matches(prev_cycle_constr, prev_cycle):
            continue

        try:
            idx_prev = marker_idx(prev_cycle.i_meas, prev_cycle.b_meas)
        except IndexError:
            continue

        n_tail = len(prev_cycle.i_meas) - idx_prev
        dt = n_tail + idx  # ramp length: previous marker -> current marker

        # Previous cycle, from its marker to its end: drift * (0 .. n_tail-1) / dt.
        prev_pos = positions[slices[i - 1]][idx_prev:]
        b_all[prev_pos] += drift * np.arange(len(prev_pos)) / dt

        # Current cycle, up to (excluding) its marker: drift * (n_tail+1 .. dt) / dt,
        # so the sample just before the marker is shifted by the full drift.
        # NOTE(review): the ramp value n_tail/dt is skipped at the cycle
        # boundary — confirm this offset is intended.
        cur_pos = positions[s][:idx]
        b_all[cur_pos] += drift * np.arange(n_tail + 1, dt + 1) / dt

        print(f"Drift corrected cycle {cycle.cycle} with drift {drift * 1e4:.2f} G")

    dataset.df["B_meas_T"] = b_all
    return dataset
 
 
def parse_args(argv: list[str]) -> argparse.Namespace:
    """Parse the command line.

    A single positional argument, ``datasets``, accepts one or more paths
    (files or directories), each converted to :class:`pathlib.Path`.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("datasets", nargs="+", type=pathlib.Path)
    namespace = cli.parse_args(argv)
    return namespace
 
 
def plot_phase_space(
    dataset: BTrainDataset, dataset_ref: BTrainDataset
) -> matplotlib.figure.Figure:
    """Plot field residual vs. current for SFT cycles, before and after correction.

    For every cycle whose name starts with "SFT", the residual
    ``B_meas - calibration(I_meas)`` is plotted against ``I_meas``. The
    reference dataset is drawn solid, the corrected dataset dashed, and both
    share a colour per cycle. Black dashed lines mark the marker of the last
    SFT cycle; the marker is located on the reference field.

    Parameters
    ----------
    dataset
        Corrected dataset.
    dataset_ref
        Uncorrected reference dataset.

    Returns
    -------
    matplotlib.figure.Figure
        The figure. It is empty, with a message printed, if there are no SFT
        cycles.

    Raises
    ------
    ValueError
        If the datasets contain a different number of SFT cycles. This is
        checked before any figure is created.
    """
    cycles = [cycle for cycle in dataset.aslist() if cycle.cycle.startswith("SFT")]
    cycles_ref = [
        cycle for cycle in dataset_ref.aslist() if cycle.cycle.startswith("SFT")
    ]
    if len(cycles) != len(cycles_ref):
        raise ValueError("Different number of cycles in datasets")

    fig, ax = plt.subplots()
    if not cycles:
        print("No SFT cycles found in dataset")
        return fig

    def residual(i_meas: np.ndarray, b_meas: np.ndarray) -> np.ndarray:
        # Measured field minus the field expected from the calibration.
        return b_meas - np.interp(i_meas, calibration_fn[:, 0], calibration_fn[:, 1])

    for k, (cycle, cycle_ref) in enumerate(zip(cycles, cycles_ref)):
        ax.plot(
            cycle.i_meas,
            residual(cycle.i_meas, cycle_ref.b_meas),
            label="Reference",
            color=f"C{k}",
            linewidth=0.5,
        )
        ax.plot(
            cycle.i_meas,
            residual(cycle.i_meas, cycle.b_meas),
            label="Fixed",
            color=f"C{k}",
            linestyle="--",
            linewidth=0.5,
        )

    # Only the last cycle's marker is drawn, so it is the only one computed.
    last, last_ref = cycles[-1], cycles_ref[-1]
    idx = marker_idx(last.i_meas, last_ref.b_meas)
    ax.axvline(last.i_meas[idx], c="black", linestyle="--", label="Marker")
    ax.axhline(
        residual(last.i_meas, last.b_meas)[idx],
        c="black",
        linestyle="--",
    )

    ax.grid()

    return fig
 
 
def drift_correct_dataset(path: pathlib.Path) -> None:
    """Drift-correct one parquet dataset, show the result, and optionally save it.

    SFT cycles are corrected only when the preceding cycle belongs to user
    "MD1" and its maximum measured current exceeds 500 A. The before/after
    plot is shown. If the user answers "y", *path* is overwritten in place
    with the corrected dataset.
    """
    print(f"Drift correcting dataset {path}")

    # drift_correct works on a copy of the dataframe, so the loaded dataset
    # doubles as the uncorrected reference and the file is read only once.
    dataset_ref = BTrainDataset.from_parquet(path)

    dataset = drift_correct(
        dataset_ref,
        calibration_fn,
        fix_only=lambda cycle: cycle.cycle.startswith("SFT"),
        prev_cycle_constr=lambda cycle: cycle.user == "MD1"
        and cycle.df.I_meas_A.max() > 500,
    )

    fig = plot_phase_space(dataset, dataset_ref)
    fig.suptitle(f"Drift correction for {path.name}")
    plt.show()
    # Release the figure so figures do not accumulate over many datasets
    # (e.g. with non-blocking backends).
    plt.close(fig)

    # ask user if they want to save the corrected dataset
    if input("Save corrected dataset? [y/n] ").strip().lower() == "y":
        dataset.to_parquet(path)
 
 
def _is_raw_parquet(path: pathlib.Path) -> bool:
    """Return whether *path* is a parquet dataset that is not a preprocessed one."""
    return path.suffix == ".parquet" and "preprocessed" not in path.name


def main(argv: list[str]) -> None:
    """Drift-correct every dataset given on the command line.

    Each argument may be a ``.parquet`` dataset, which is corrected directly,
    or a directory searched recursively for ``.parquet`` datasets. Datasets
    whose name contains "preprocessed" are skipped. Non-existent paths are
    skipped with a message.
    """
    args = parse_args(argv)

    # nargs="+" guarantees a list of paths.
    for dataset_path in args.datasets:
        if not dataset_path.exists():
            print(f"Skipping {dataset_path}: path does not exist")
            continue

        if _is_raw_parquet(dataset_path):
            drift_correct_dataset(dataset_path)
        elif dataset_path.is_dir():
            # Sorted so datasets are reviewed in a reproducible order.
            for path in sorted(dataset_path.rglob("*.parquet")):
                if _is_raw_parquet(path):
                    drift_correct_dataset(path)
 
 
if __name__ == "__main__":
    # Hand the command-line arguments (without the program name) to main.
    main(argv=sys.argv[1:])