Datasets are generated in ~/cernbox/hysteresis/dipole/datasets/v5/mbi_dataset_v5.yml, and are a copy of the Dipole datasets v4, but with SFT-type cycles drift-corrected only if the preceding cycle is an MD1. In other cases, such as a preceding ZERO or other MD cycle, the SFT cycle is left untouched.
Datasets are inspected manually before saving to avoid artifacts.
python drift_correct_sftpro.py datasets
from __future__ import annotations
import argparse
import os
import pathlib
import shutil
import typing
import sys
import matplotlib.figure
import numpy as np
import pandas as pd
import warnings
import matplotlib.pyplot as plt
from hysteresis_scripts.data import BTrainDataset, BTrainCycle
# CSV file holding the calibration table; "~" is expanded to the home directory
# of the user running the script.
CALIBRATION_FN_PATH = pathlib.Path(
    "~/cernbox/hysteresis/calibration_fn/SPS_MB_I2B_CALIBRATION_FN_v7.csv"
).expanduser()

# Loaded once, at import time. The first CSV row is skipped; the remaining rows
# form a 2-D array whose first column is used as the current axis and whose
# second column is used as the field axis by the ``np.interp`` calls in this
# module.
calibration_fn = np.loadtxt(fname=CALIBRATION_FN_PATH, skiprows=1, delimiter=",")
def marker_idx(
    i_meas: np.ndarray, b_meas: np.ndarray, marker_value: float = 0.11
) -> int:
    """Locate the marker sample of a measured cycle.

    The candidate is the first sample whose measured field exceeds
    ``marker_value``. It is moved one sample later when the most negative
    mismatch between the measured field gradient and the gradient expected
    from the module-level calibration table lies within one sample of the
    candidate, while the most positive mismatch does not.

    Args:
        i_meas: Measured current samples of the cycle.
        b_meas: Measured field samples of the cycle.
        marker_value: Field threshold used to find the candidate sample.

    Returns:
        Index of the marker sample.

    Raises:
        IndexError: If no sample of ``b_meas`` exceeds ``marker_value``.
    """
    # Field expected from the measured current according to the calibration.
    expected_b = np.interp(i_meas, *calibration_fn.T)
    # Slope mismatch between the measured and the expected field.
    slope_mismatch = np.gradient(b_meas) - np.gradient(expected_b)

    above_threshold = np.where(b_meas > marker_value)[0]
    candidate = int(above_threshold[0])

    most_negative = np.argmin(slope_mismatch)
    most_positive = np.argmax(slope_mismatch)

    # The correction depends on whether the marker step went up or down.
    near_positive = abs(most_positive - candidate) <= 1
    near_negative = abs(most_negative - candidate) <= 1
    if near_negative and not near_positive:
        return candidate + 1
    return candidate
def _cycle_matches(
    cycle: BTrainCycle,
    constraint: typing.Callable[[BTrainCycle], bool] | str,
) -> bool:
    """Tell whether ``cycle`` passes a predicate or a user-name constraint."""
    if callable(constraint):
        return bool(constraint(cycle))
    # An empty string disables the constraint; otherwise the user must match.
    return not constraint or cycle.user == constraint


def drift_correct(
    dataset: BTrainDataset,
    calibration_fn: np.ndarray,
    *,
    fix_only: typing.Callable[[BTrainCycle], bool] | str = "",
    prev_cycle_constr: typing.Callable[[BTrainCycle], bool] | str = "",
) -> BTrainDataset:
    """Spread the field jump seen at each cycle's marker over the preceding samples.

    For every selected cycle (never the first one) the jump of the measured
    field across the marker sample is compared with the jump expected from the
    reference current through ``calibration_fn``; the difference is the drift.
    That drift is added as a linear ramp that starts at the marker of the
    previous cycle and ends on the sample just before the marker of the current
    cycle. The ramp is written to the ``B_meas_T`` column of a new dataset built
    from a copy of ``dataset.df``. One line is printed per corrected cycle.

    Cycles are skipped when they fail ``fix_only``, when their predecessor fails
    ``prev_cycle_constr``, when either cycle has no sample above the marker
    threshold, or when the marker has no sample before it (or lies past the
    end of the cycle).

    Note that marker detection goes through ``marker_idx``, which reads the
    module-level calibration table rather than the ``calibration_fn`` argument.

    Args:
        dataset: Dataset to correct.
        calibration_fn: 2-D array; unpacked as ``(current, field)`` columns for
            ``np.interp``.
        fix_only: Predicate on the cycle to correct, or a user name it must
            have. An empty string selects every cycle.
        prev_cycle_constr: Same kind of constraint, applied to the cycle that
            precedes the one being corrected.

    Returns:
        A new dataset holding the corrected field.

    Raises:
        NotImplementedError: If the installed pandas major version is above 2.
    """
    # Checked before any data is copied. Besides the historical chained
    # assignment, the in-place edits of ``to_numpy()`` results below also rely
    # on pre-3.0 semantics (copy-on-write hands out read-only arrays).
    if int(pd.__version__.split(".")[0]) > 2:
        msg = "This function is only supported for pandas <3.0.0 due to chained assignment issues"
        raise NotImplementedError(msg)

    dataset = BTrainDataset(dataset.df.copy())
    cycles = dataset.aslist()
    slices = list(dataset._slices.values())  # noqa: SLF001

    for i, (cycle, s) in enumerate(zip(cycles, slices)):
        if i == 0:
            # the first cycle has no predecessor to share the correction with
            continue
        if not _cycle_matches(cycle, fix_only):
            continue

        try:
            idx = marker_idx(cycle.i_meas, cycle.b_meas)
        except IndexError:
            # no sample above the marker threshold in this cycle
            continue

        iref = cycle.df["I_ref_A"].to_numpy()
        if not 0 < idx < len(iref):
            # idx == 0 would make ``idx - 1`` wrap around to the last sample and
            # yield a meaningless drift; idx past the end cannot be indexed.
            continue

        # estimate the field change due to current change
        b0 = np.interp(iref[idx - 1], *calibration_fn.T)
        b1 = np.interp(iref[idx], *calibration_fn.T)
        db = b1 - b0

        # drift is the gap between the marker and the previous value, minus the
        # field change due to current change
        drift = cycle.b_meas[idx] - cycle.b_meas[idx - 1] - db

        # the previous cycle determines how long the correction ramp is
        prev_cycle = cycles[i - 1]
        if not _cycle_matches(prev_cycle, prev_cycle_constr):
            continue
        try:
            idx_prev = marker_idx(prev_cycle.i_meas, prev_cycle.b_meas)
        except IndexError:
            continue

        # number of samples from the previous marker up to the current marker
        dt = len(prev_cycle.i_meas) - idx_prev + idx

        # Positional column index so that rows and column are set in a single
        # ``iloc`` call instead of a chained ``df[col].iloc[rows] = ...``, which
        # pandas does not guarantee to write through to the frame.
        b_col = dataset.df.columns.get_loc("B_meas_T")

        # fix the previous cycle: ramp from 0 at its marker to its last sample
        b_meas_prev = prev_cycle.df["B_meas_T"].to_numpy()
        corr = drift * np.arange(len(b_meas_prev) - idx_prev) / dt
        b_meas_prev[idx_prev:] += corr
        dataset.df.iloc[slices[i - 1], b_col] = b_meas_prev

        # fix the current cycle: continue the ramp up to the full drift on the
        # sample just before the marker
        # NOTE(review): the ramp index goes from ``n - 1`` on the last sample of
        # the previous cycle to ``n + 1`` on the first sample of this one
        # (``n = len(prev) - idx_prev``), i.e. it skips one step at the
        # boundary. Kept exactly as before; confirm whether this is intended.
        b_meas_corrected = cycle.df["B_meas_T"].to_numpy()
        corr = drift * np.arange(len(prev_cycle.i_meas) - idx_prev + 1, dt + 1) / dt
        b_meas_corrected[:idx] += corr
        dataset.df.iloc[s, b_col] = b_meas_corrected

        print(f"Drift corrected cycle {cycle.cycle} with drift {drift * 1e4:.2f} G")

    return dataset
def parse_args(argv: list[str]) -> argparse.Namespace:
    """Parse the command line of the script.

    Args:
        argv: Command-line arguments without the program name.

    Returns:
        Namespace whose ``datasets`` attribute is a list of one or more
        ``pathlib.Path`` objects.
    """
    cli = argparse.ArgumentParser()
    # one or more positional paths, each converted to a Path
    cli.add_argument("datasets", nargs="+", type=pathlib.Path)
    namespace = cli.parse_args(argv)
    return namespace
def plot_phase_space(
    dataset: BTrainDataset, dataset_ref: BTrainDataset
) -> matplotlib.figure.Figure:
    """Plot corrected against reference SFT cycles, relative to the calibration.

    Only cycles whose name starts with ``"SFT"`` are drawn. For each of them
    the measured field minus the field interpolated from the module-level
    calibration table is plotted against the measured current: the reference
    as a solid line and the corrected cycle as a dashed line of the same
    colour. A black dashed crosshair marks the marker sample, which is located
    on the reference field and drawn at the corrected value. Cycles for which
    no marker can be located are still plotted, without a crosshair.

    Args:
        dataset: Corrected dataset.
        dataset_ref: Dataset to compare against; must contain the same number
            of SFT cycles as ``dataset``.

    Returns:
        The figure holding the plot. It has empty axes when there are no SFT
        cycles.

    Raises:
        ValueError: If the two datasets hold a different number of SFT cycles.
    """
    cycles = [c for c in dataset.aslist() if c.cycle.startswith("SFT")]
    cycles_ref = [c for c in dataset_ref.aslist() if c.cycle.startswith("SFT")]

    # Validate before creating the figure so that a failure leaves none behind.
    if len(cycles) != len(cycles_ref):
        raise ValueError("Different number of cycles in datasets")

    fig, ax = plt.subplots()
    if not cycles:
        print("No SFT cycles found in dataset")
        return fig

    for n, (cycle, cycle_ref) in enumerate(zip(cycles, cycles_ref)):
        # Calibration field at the measured current; shared by both traces.
        b_calib = np.interp(cycle.i_meas, calibration_fn[:, 0], calibration_fn[:, 1])
        fixed_residual = cycle.b_meas - b_calib
        color = f"C{n}"

        ax.plot(
            cycle.i_meas,
            cycle_ref.b_meas - b_calib,
            label="Reference",
            color=color,
            linewidth=0.5,
        )
        ax.plot(
            cycle.i_meas,
            fixed_residual,
            label="Fixed",
            color=color,
            linestyle="--",
            linewidth=0.5,
        )

        try:
            idx = marker_idx(cycle.i_meas, cycle_ref.b_meas)
        except IndexError:
            # the reference never exceeds the marker threshold: no crosshair
            continue
        if idx >= len(cycle.i_meas):
            # marker shifted past the last sample: nothing to point at
            continue

        ax.axvline(cycle.i_meas[idx], c="black", linestyle="--", label="Marker")
        ax.axhline(fixed_residual[idx], c="black", linestyle="--")

    # ax.legend()
    ax.grid()
    return fig
def drift_correct_dataset(path: pathlib.Path) -> None:
    """Drift-correct one parquet dataset and optionally overwrite it.

    The dataset is loaded from ``path`` and its cycles whose name starts with
    ``"SFT"`` are corrected, provided the preceding cycle belongs to user
    ``"MD1"`` and its measured current peaks above 500. The result is plotted
    against a fresh load of the same file and shown; it is written back to
    ``path`` only when the user answers ``y`` (in either case) at the prompt.

    Args:
        path: Parquet file to correct; overwritten in place on confirmation.
    """
    print(f"Drift correcting dataset {path}")

    def _is_sft(cycle: BTrainCycle) -> bool:
        return cycle.cycle.startswith("SFT")

    def _follows_high_md1(cycle: BTrainCycle) -> bool:
        return cycle.user == "MD1" and cycle.df.I_meas_A.max() > 500

    corrected = drift_correct(
        BTrainDataset.from_parquet(path),
        calibration_fn,
        fix_only=_is_sft,
        prev_cycle_constr=_follows_high_md1,
    )

    # The file is read a second time to obtain the uncorrected reference.
    reference = BTrainDataset.from_parquet(path)
    figure = plot_phase_space(corrected, reference)
    figure.suptitle(f"Drift correction for {path.name}")
    plt.show()

    # ask the user whether the corrected dataset should be saved
    answer = input("Save corrected dataset? [y/n] ")
    if answer.lower() != "y":
        return
    corrected.to_parquet(path)
def main(argv: list[str]) -> None:
    """Drift-correct every raw parquet dataset named on the command line.

    Paths that do not exist are dropped. A path with a ``.parquet`` suffix and
    without ``"preprocessed"`` in its name is processed directly; a directory
    is searched recursively for such paths. Everything else is ignored.

    Args:
        argv: Command-line arguments without the program name.
    """
    args = parse_args(argv)

    requested = args.datasets
    if isinstance(requested, pathlib.Path):
        requested = [requested]

    def _is_raw_parquet(candidate: pathlib.Path) -> bool:
        return candidate.suffix == ".parquet" and "preprocessed" not in candidate.name

    # Existence is checked for all paths up front, before any is processed.
    existing = []
    for candidate in requested:
        if candidate.exists():
            existing.append(candidate)

    for dataset_path in existing:
        if _is_raw_parquet(dataset_path):
            drift_correct_dataset(dataset_path)
        elif dataset_path.is_dir():
            for path in dataset_path.rglob("*"):
                if _is_raw_parquet(path):
                    drift_correct_dataset(path)
if __name__ == "__main__":
main(sys.argv[1:])