Skip to content

anypinn.lightning.callbacks

FormatFn: TypeAlias = Callable[[str, Metric], Metric] module-attribute

A function that formats a metric for display in the progress bar. Takes the key and value of the metric, and returns the formatted metric.

HookFn: TypeAlias = Callable[[Trainer, LightningModule, Sequence[Predictions], Sequence[Any]], None] module-attribute

A hook function called when predictions are ready. Takes the trainer, the module, the collected predictions, and the batch indices; returns nothing.

Metric: TypeAlias = int | str | float | dict[str, float] module-attribute

SMMA_KEY = 'loss/smma' module-attribute

AdaptiveCollocationCallback

Bases: Callback

Refreshes the collocation pool every N epochs using AdaptiveSampler.

Requires the DataModule to be configured with collocation_sampler="adaptive" and a ResidualScorer passed to PINNDataModule(residual_scorer=...).

Because the scorer typically closes over the Problem, it automatically uses the model's current weights on every call — no additional injection step needed.

Parameters:

Name Type Description Default
every_n_epochs int

How often (in epochs) to resample. Default: 1.

1
Source code in src/anypinn/lightning/callbacks.py
class AdaptiveCollocationCallback(Callback):
    """Periodically regenerates the collocation pool via an ``AdaptiveSampler``.

    The DataModule must be configured with ``collocation_sampler="adaptive"``
    and a ``ResidualScorer`` passed to ``PINNDataModule(residual_scorer=...)``.

    Because the scorer typically closes over the ``Problem``, it automatically uses
    the model's current weights on every call — no additional injection step needed.

    Args:
        every_n_epochs: How often (in epochs) to resample. Default: 1.
    """

    def __init__(self, every_n_epochs: int = 1) -> None:
        """Validate and store the resampling period."""
        super().__init__()
        if every_n_epochs < 1:
            raise ValueError(f"every_n_epochs must be >= 1, got {every_n_epochs}.")
        self._every_n = every_n_epochs

    @override
    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Fail fast at fit start unless the DataModule samples adaptively."""
        datamodule = trainer.datamodule  # type: ignore[attr-defined]
        sampler = getattr(datamodule, "_sampler", None)
        if isinstance(sampler, AdaptiveSampler):
            return
        raise TypeError(
            "AdaptiveCollocationCallback requires the DataModule to use "
            "collocation_sampler='adaptive'. Got: "
            f"{type(sampler).__name__}."
        )

    @override
    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Draw a fresh collocation pool whenever the resampling period elapses."""
        epochs_done = trainer.current_epoch + 1
        if epochs_done % self._every_n:
            return
        datamodule = trainer.datamodule  # type: ignore[attr-defined]
        pool_size = datamodule.hp.training_data.collocations
        datamodule.pinn_ds.x_coll = datamodule._sampler.sample(pool_size, datamodule._domain)

__init__(every_n_epochs: int = 1) -> None

Source code in src/anypinn/lightning/callbacks.py
def __init__(self, every_n_epochs: int = 1) -> None:
    """Validate and store the resampling period.

    Args:
        every_n_epochs: How often (in epochs) to resample; must be >= 1.

    Raises:
        ValueError: If ``every_n_epochs`` is less than 1.
    """
    super().__init__()
    if every_n_epochs < 1:
        raise ValueError(f"every_n_epochs must be >= 1, got {every_n_epochs}.")
    self._every_n = every_n_epochs

on_fit_start(trainer: Trainer, pl_module: LightningModule) -> None

Validate at fit start that the DataModule uses AdaptiveSampler.

Source code in src/anypinn/lightning/callbacks.py
@override
def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Fail fast at fit start unless the DataModule samples adaptively."""
    sampler = getattr(trainer.datamodule, "_sampler", None)  # type: ignore[attr-defined]
    if isinstance(sampler, AdaptiveSampler):
        return
    raise TypeError(
        "AdaptiveCollocationCallback requires the DataModule to use "
        "collocation_sampler='adaptive'. Got: "
        f"{type(sampler).__name__}."
    )

on_train_epoch_end(trainer: Trainer, pl_module: LightningModule) -> None

Resample collocation points using the current model weights.

Source code in src/anypinn/lightning/callbacks.py
@override
def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Draw a fresh collocation pool whenever the resampling period elapses."""
    epochs_done = trainer.current_epoch + 1
    if epochs_done % self._every_n:
        return
    datamodule = trainer.datamodule  # type: ignore[attr-defined]
    pool_size = datamodule.hp.training_data.collocations
    datamodule.pinn_ds.x_coll = datamodule._sampler.sample(pool_size, datamodule._domain)

DataScaling

Bases: DataCallback

Callback to transform the data and collocation points.

Scales x to [0, 1] and applies per-series scaling factors to y.

Parameters:

Name Type Description Default
y_scale float | Sequence[float]

Scaling factor(s) for y data. Can be a single float (applied to all series) or a sequence of floats (one per series; length must match the number of series).

required
Source code in src/anypinn/lightning/callbacks.py
class DataScaling(DataCallback):
    """
    Callback that rescales the data and collocation points.

    ``x`` is normalized to [0, 1] and each series in ``y`` is multiplied by a
    per-series scaling factor.

    Args:
        y_scale: Scaling factor(s) for y data. Either a single float applied to
            every series, or one float per series (length must match the
            number of series).
    """

    def __init__(self, y_scale: float | Sequence[float]):
        # Stored raw; validated lazily in transform_data once y's shape is known.
        self._y_scale_input = y_scale

    @override
    def transform_data(self, data: DataBatch, coll: Tensor) -> tuple[DataBatch, Tensor]:
        """Normalize x (and coll) to [0, 1] and apply per-series scaling to y."""
        x, y = data

        # Normalize x; keep the span so validation fns can be mapped back later.
        lo, hi = x.min(), x.max()
        self.x_scale = hi - lo
        x_scaled = (x - lo) / self.x_scale

        # Collocation points get their own independent [0, 1] normalization.
        c_lo, c_hi = coll.min(), coll.max()
        coll_scaled = (coll - c_lo) / (c_hi - c_lo)

        n_series = y.shape[1]
        if isinstance(self._y_scale_input, (int, float)):
            factors = n_series * [float(self._y_scale_input)]
        else:
            factors = [*self._y_scale_input]
            if len(factors) != n_series:
                raise ValueError(
                    f"y_scale has {len(factors)} elements but data has {n_series} series"
                )

        self.y_scale = torch.tensor(factors, dtype=y.dtype, device=y.device)

        # Reshape scale for broadcasting against (n, k, 1)
        y_scaled = y * self.y_scale.view(1, -1, 1)

        return (x_scaled, y_scaled), coll_scaled

    @override
    def on_after_setup(self, dm: PINNDataModule) -> None:
        """Called after setup is complete."""

        # Wrap each validation fn so it receives x in the original (unscaled) units.
        def _rescaled(fn, scale):
            return lambda x: fn(x * scale)

        for key in dm.validation:
            dm.validation[key] = _rescaled(dm.validation[key], self.x_scale)
        return None

__init__(y_scale: float | Sequence[float])

Source code in src/anypinn/lightning/callbacks.py
def __init__(self, y_scale: float | Sequence[float]):
    """Store the requested y scaling; it is validated later in ``transform_data``."""
    self._y_scale_input = y_scale

on_after_setup(dm: PINNDataModule) -> None

Called after setup is complete.

Source code in src/anypinn/lightning/callbacks.py
@override
def on_after_setup(self, dm: PINNDataModule) -> None:
    """Called after setup is complete."""

    # Wrap each validation fn so it receives x in the original (unscaled) units.
    def _rescaled(fn, scale):
        return lambda x: fn(x * scale)

    for key in dm.validation:
        dm.validation[key] = _rescaled(dm.validation[key], self.x_scale)
    return None

transform_data(data: DataBatch, coll: Tensor) -> tuple[DataBatch, Tensor]

Source code in src/anypinn/lightning/callbacks.py
@override
def transform_data(self, data: DataBatch, coll: Tensor) -> tuple[DataBatch, Tensor]:
    """Normalize x (and coll) to [0, 1] and apply per-series scaling to y."""
    x, y = data

    # Normalize x; keep the span so validation fns can be mapped back later.
    lo, hi = x.min(), x.max()
    self.x_scale = hi - lo
    x_scaled = (x - lo) / self.x_scale

    # Collocation points get their own independent [0, 1] normalization.
    c_lo, c_hi = coll.min(), coll.max()
    coll_scaled = (coll - c_lo) / (c_hi - c_lo)

    n_series = y.shape[1]
    if isinstance(self._y_scale_input, (int, float)):
        factors = n_series * [float(self._y_scale_input)]
    else:
        factors = [*self._y_scale_input]
        if len(factors) != n_series:
            raise ValueError(
                f"y_scale has {len(factors)} elements but data has {n_series} series"
            )

    self.y_scale = torch.tensor(factors, dtype=y.dtype, device=y.device)

    # Reshape scale for broadcasting against (n, k, 1)
    y_scaled = y * self.y_scale.view(1, -1, 1)

    return (x_scaled, y_scaled), coll_scaled

FormattedProgressBar

Bases: TQDMProgressBar

Custom progress bar for training that formats metrics for better readability.

This class extends the TQDMProgressBar to provide custom formatting for training metrics, particularly for the total loss and beta values.

Parameters:

Name Type Description Default
format FormatFn

Function to format the metric values.

required
Source code in src/anypinn/lightning/callbacks.py
class FormattedProgressBar(TQDMProgressBar):
    """
    TQDM progress bar that pipes every displayed training metric through a
    user-supplied formatter (e.g. scientific notation for the total loss,
    fixed precision for beta values).

    Args:
        format: Function to format the metric values.
    """

    def __init__(self, *args: Any, format: FormatFn, **kwargs: Any):
        """Forward args to TQDMProgressBar and keep the formatter."""
        super().__init__(*args, **kwargs)
        self.format = format

    @override
    def get_metrics(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
        """
        Format metrics for display in the progress bar.

        Returns:
            The parent class's metrics with the ``v_num`` entry removed and
            every remaining value passed through ``self.format``.
        """
        raw = super().get_metrics(*args, **kwargs)
        raw.pop("v_num", None)
        return {key: self.format(key, value) for key, value in raw.items()}

format = format instance-attribute

__init__(*args: Any, format: FormatFn, **kwargs: Any)

Source code in src/anypinn/lightning/callbacks.py
def __init__(self, *args: Any, format: FormatFn, **kwargs: Any):
    """Forward args to TQDMProgressBar and keep the metric formatter."""
    super().__init__(*args, **kwargs)
    self.format = format

get_metrics(*args: Any, **kwargs: Any) -> dict[str, Any]

Format metrics for display in the progress bar.

Returns:

Type Description
dict[str, Any]

Dictionary of formatted metrics with:

dict[str, Any]
  • Total loss in scientific notation
dict[str, Any]
  • Beta value with 4 decimal places
dict[str, Any]
  • Other metrics as provided by the parent class
Source code in src/anypinn/lightning/callbacks.py
@override
def get_metrics(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
    """
    Format metrics for display in the progress bar.

    Returns:
        The parent class's metrics with the ``v_num`` entry removed and
        every remaining value passed through ``self.format``.
    """
    raw = super().get_metrics(*args, **kwargs)
    raw.pop("v_num", None)
    return {key: self.format(key, value) for key, value in raw.items()}

PredictionsWriter

Bases: BasePredictionWriter

Callback to write predictions to disk at the end of an epoch.

Parameters:

Name Type Description Default
predictions_path Path | None

Path to save the predictions tensor/object.

None
batch_indices_path Path | None

Path to save the batch indices.

None
on_prediction HookFn | None

Optional hook function called when predictions are ready.

None
write_interval Literal['batch', 'epoch', 'batch_and_epoch']

Interval to write predictions ("batch", "epoch", "batch_and_epoch").

'epoch'
Source code in src/anypinn/lightning/callbacks.py
class PredictionsWriter(BasePredictionWriter):
    """
    Callback that persists predictions (and optionally batch indices) to disk
    at the end of an epoch, and/or hands them to a user hook.

    Args:
        predictions_path: Path to save the predictions tensor/object.
        batch_indices_path: Path to save the batch indices.
        on_prediction: Optional hook function called when predictions are ready.
        write_interval: Interval to write predictions ("batch", "epoch", "batch_and_epoch").
    """

    def __init__(
        self,
        predictions_path: Path | None = None,
        batch_indices_path: Path | None = None,
        on_prediction: HookFn | None = None,
        write_interval: Literal["batch", "epoch", "batch_and_epoch"] = "epoch",
    ):
        """Configure destinations; ``write_interval`` is handed to the base writer."""
        super().__init__(write_interval)
        self.predictions_path = predictions_path
        self.batch_indices_path = batch_indices_path
        self.on_prediction = on_prediction

    @override
    def write_on_epoch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        predictions: Sequence[Predictions],
        batch_indices: Sequence[Any],
    ) -> None:
        """Invoke the hook (if any), then save predictions / indices to disk."""
        hook = self.on_prediction
        if hook is not None:
            hook(trainer, pl_module, predictions, batch_indices)

        if self.predictions_path is not None:
            torch.save(predictions, self.predictions_path)
        if self.batch_indices_path is not None:
            torch.save(batch_indices, self.batch_indices_path)

batch_indices_path = batch_indices_path instance-attribute

on_prediction = on_prediction instance-attribute

predictions_path = predictions_path instance-attribute

__init__(predictions_path: Path | None = None, batch_indices_path: Path | None = None, on_prediction: HookFn | None = None, write_interval: Literal['batch', 'epoch', 'batch_and_epoch'] = 'epoch')

Source code in src/anypinn/lightning/callbacks.py
def __init__(
    self,
    predictions_path: Path | None = None,
    batch_indices_path: Path | None = None,
    on_prediction: HookFn | None = None,
    write_interval: Literal["batch", "epoch", "batch_and_epoch"] = "epoch",
):
    """Configure destinations; ``write_interval`` is handed to the base writer.

    Args:
        predictions_path: Path to save the predictions tensor/object.
        batch_indices_path: Path to save the batch indices.
        on_prediction: Optional hook function called when predictions are ready.
        write_interval: Interval to write predictions ("batch", "epoch", "batch_and_epoch").
    """
    super().__init__(write_interval)
    self.predictions_path = predictions_path
    self.batch_indices_path = batch_indices_path
    self.on_prediction = on_prediction

write_on_epoch_end(trainer: Trainer, pl_module: LightningModule, predictions: Sequence[Predictions], batch_indices: Sequence[Any]) -> None

Writes predictions to disk or calls the hook at the end of the epoch.

Source code in src/anypinn/lightning/callbacks.py
@override
def write_on_epoch_end(
    self,
    trainer: Trainer,
    pl_module: LightningModule,
    predictions: Sequence[Predictions],
    batch_indices: Sequence[Any],
) -> None:
    """
    Invoke the hook (if any), then save predictions / indices to disk.
    """
    hook = self.on_prediction
    if hook is not None:
        hook(trainer, pl_module, predictions, batch_indices)

    if self.predictions_path is not None:
        torch.save(predictions, self.predictions_path)
    if self.batch_indices_path is not None:
        torch.save(batch_indices, self.batch_indices_path)

SMMAStopping

Bases: Callback

Early stopping callback based on the Smoothed Moving Average (SMMA) of the loss. Stops training if the relative improvement of the SMMA drops below a threshold.

Parameters:

Name Type Description Default
config SMMAStoppingConfig

Configuration for SMMA stopping (window, threshold, lookback).

required
loss_key str

The metric key to monitor (e.g., 'loss').

required
log_key str

Key to log the computed SMMA value.

SMMA_KEY
Source code in src/anypinn/lightning/callbacks.py
class SMMAStopping(Callback):
    """
    Early stopping callback based on the Smoothed Moving Average (SMMA) of the loss.
    Stops training if the relative improvement of the SMMA drops below a threshold.

    Args:
        config: Configuration for SMMA stopping (window, threshold, lookback).
        loss_key: The metric key to monitor (e.g., 'loss').
        log_key: Key to log the computed SMMA value.
    """

    def __init__(self, config: SMMAStoppingConfig, loss_key: str, log_key: str = SMMA_KEY):
        super().__init__()
        self.config = config
        self.loss_key = loss_key
        self.log_key = log_key
        # Warm-up losses used to seed the SMMA (holds at most `window` values).
        self.loss_buffer: list[float] = []
        # Rolling SMMA history; the oldest entry is the `lookback` reference point.
        self.smma_buffer: deque[float] = deque(maxlen=self.config.lookback)

    @override
    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """
        Called when the train epoch ends. Updates SMMA and checks stopping condition.
        """
        # phase 0: get the loss
        loss_t = trainer.callback_metrics.get(self.loss_key)
        if loss_t is None:
            return

        loss = loss_t.item()
        n = self.config.window

        # phase 1: collect the first `window` losses, then seed the SMMA with
        # their simple average. (The previous `<= n` check let the buffer grow
        # to `window + 1` entries whose sum was divided by `window`, biasing
        # the seed high, and the seeding epoch's loss was dropped entirely.)
        if len(self.smma_buffer) == 0:
            self.loss_buffer.append(loss)
            if len(self.loss_buffer) < n:
                return
            self.smma_buffer.append(sum(self.loss_buffer) / n)
            return

        # phase 2: standard SMMA recurrence from the previous value
        smma = ((n - 1) * self.smma_buffer[-1] + loss) / n
        self.smma_buffer.append(smma)

        pl_module.log(self.log_key, smma)
        if len(self.smma_buffer) < self.config.lookback:
            return

        # phase 3: relative improvement vs the SMMA from `lookback` updates ago
        smma_lookback = self.smma_buffer[0]
        improvement = (smma_lookback - smma) / smma_lookback

        if 0 < improvement < self.config.threshold:
            trainer.should_stop = True
            print(
                f"\nStopping training: SMMA improvement over {self.config.lookback} "
                f"epochs ({improvement:.2%}) below threshold ({self.config.threshold:.2%})"
            )

config = config instance-attribute

log_key = log_key instance-attribute

loss_buffer: list[float] = [] instance-attribute

loss_key = loss_key instance-attribute

smma_buffer: deque[float] = deque(maxlen=(self.config.lookback)) instance-attribute

__init__(config: SMMAStoppingConfig, loss_key: str, log_key: str = SMMA_KEY)

Source code in src/anypinn/lightning/callbacks.py
def __init__(self, config: SMMAStoppingConfig, loss_key: str, log_key: str = SMMA_KEY):
    """Set up the monitoring state.

    Args:
        config: Configuration for SMMA stopping (window, threshold, lookback).
        loss_key: The metric key to monitor (e.g., 'loss').
        log_key: Key to log the computed SMMA value.
    """
    super().__init__()
    self.config = config
    self.loss_key = loss_key
    self.log_key = log_key
    # Warm-up losses used to seed the SMMA.
    self.loss_buffer: list[float] = []
    # Rolling SMMA history; the oldest entry is the `lookback` reference point.
    self.smma_buffer: deque[float] = deque(maxlen=self.config.lookback)

on_train_epoch_end(trainer: Trainer, pl_module: LightningModule) -> None

Called when the train epoch ends. Updates SMMA and checks stopping condition.

Source code in src/anypinn/lightning/callbacks.py
@override
def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """
    Called when the train epoch ends. Updates SMMA and checks stopping condition.
    """
    # phase 0: get the loss
    loss_t = trainer.callback_metrics.get(self.loss_key)
    if loss_t is None:
        return

    loss = loss_t.item()
    n = self.config.window

    # phase 1: collect the first `window` losses, then seed the SMMA with
    # their simple average. (The previous `<= n` check let the buffer grow
    # to `window + 1` entries whose sum was divided by `window`, biasing
    # the seed high, and the seeding epoch's loss was dropped entirely.)
    if len(self.smma_buffer) == 0:
        self.loss_buffer.append(loss)
        if len(self.loss_buffer) < n:
            return
        self.smma_buffer.append(sum(self.loss_buffer) / n)
        return

    # phase 2: standard SMMA recurrence from the previous value
    smma = ((n - 1) * self.smma_buffer[-1] + loss) / n
    self.smma_buffer.append(smma)

    pl_module.log(self.log_key, smma)
    if len(self.smma_buffer) < self.config.lookback:
        return

    # phase 3: relative improvement vs the SMMA from `lookback` updates ago
    smma_lookback = self.smma_buffer[0]
    improvement = (smma_lookback - smma) / smma_lookback

    if 0 < improvement < self.config.threshold:
        trainer.should_stop = True
        print(
            f"\nStopping training: SMMA improvement over {self.config.lookback} "
            f"epochs ({improvement:.2%}) below threshold ({self.config.threshold:.2%})"
        )