Skip to content

anypinn.lightning

Lightning integration for PINN training.

__all__ = ['AdaptiveCollocationCallback', 'FormattedProgressBar', 'PINNModule', 'PredictionsWriter', 'SMMAStopping'] module-attribute

AdaptiveCollocationCallback

Bases: Callback

Refreshes the collocation pool every N epochs using AdaptiveSampler.

Requires the DataModule to be configured with collocation_sampler="adaptive" and a ResidualScorer passed to PINNDataModule(residual_scorer=...).

Because the scorer typically closes over the Problem, it automatically uses the model's current weights on every call — no additional injection step needed.

Parameters:

Name Type Description Default
every_n_epochs int

How often (in epochs) to resample. Default: 1.

1
Source code in src/anypinn/lightning/callbacks.py
class AdaptiveCollocationCallback(Callback):
    """Periodically rebuilds the collocation pool via an ``AdaptiveSampler``.

    The attached DataModule must be created with ``collocation_sampler="adaptive"``
    and a ``ResidualScorer`` (``PINNDataModule(residual_scorer=...)``).

    Since the scorer usually captures the ``Problem``, each resampling call picks
    up the model's latest weights automatically — nothing extra to wire up.

    Args:
        every_n_epochs: Resampling period in epochs. Default: 1.
    """

    def __init__(self, every_n_epochs: int = 1) -> None:
        super().__init__()
        if every_n_epochs < 1:
            raise ValueError(f"every_n_epochs must be >= 1, got {every_n_epochs}.")
        self._every_n = every_n_epochs

    @override
    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Fail fast if the DataModule is not backed by an AdaptiveSampler."""
        dm = trainer.datamodule  # type: ignore[attr-defined]
        sampler = getattr(dm, "_sampler", None)
        if isinstance(sampler, AdaptiveSampler):
            return
        raise TypeError(
            "AdaptiveCollocationCallback requires the DataModule to use "
            "collocation_sampler='adaptive'. Got: "
            f"{type(sampler).__name__}."
        )

    @override
    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Draw a fresh collocation pool with the model's current weights."""
        completed_epochs = trainer.current_epoch + 1
        if completed_epochs % self._every_n == 0:
            dm = trainer.datamodule  # type: ignore[attr-defined]
            pool_size = dm.hp.training_data.collocations
            dm.pinn_ds.x_coll = dm._sampler.sample(pool_size, dm._domain)

__init__(every_n_epochs: int = 1) -> None

Source code in src/anypinn/lightning/callbacks.py
def __init__(self, every_n_epochs: int = 1) -> None:
    super().__init__()
    if every_n_epochs < 1:
        raise ValueError(f"every_n_epochs must be >= 1, got {every_n_epochs}.")
    self._every_n = every_n_epochs

on_fit_start(trainer: Trainer, pl_module: LightningModule) -> None

Validate at fit start that the DataModule uses AdaptiveSampler.

Source code in src/anypinn/lightning/callbacks.py
@override
def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Validate at fit start that the DataModule uses AdaptiveSampler."""
    dm = trainer.datamodule  # type: ignore[attr-defined]
    sampler = getattr(dm, "_sampler", None)
    if isinstance(sampler, AdaptiveSampler):
        return
    raise TypeError(
        "AdaptiveCollocationCallback requires the DataModule to use "
        "collocation_sampler='adaptive'. Got: "
        f"{type(sampler).__name__}."
    )

on_train_epoch_end(trainer: Trainer, pl_module: LightningModule) -> None

Resample collocation points using the current model weights.

Source code in src/anypinn/lightning/callbacks.py
@override
def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Resample collocation points using the current model weights."""
    if (trainer.current_epoch + 1) % self._every_n != 0:
        return
    dm = trainer.datamodule  # type: ignore[attr-defined]
    n = dm.hp.training_data.collocations
    new_coll = dm._sampler.sample(n, dm._domain)
    dm.pinn_ds.x_coll = new_coll

FormattedProgressBar

Bases: TQDMProgressBar

Custom progress bar for training that formats metrics for better readability.

This class extends the TQDMProgressBar to provide custom formatting for training metrics, particularly for the total loss and beta values.

Parameters:

Name Type Description Default
format FormatFn

Function to format the metric values.

required
Source code in src/anypinn/lightning/callbacks.py
class FormattedProgressBar(TQDMProgressBar):
    """
    Progress bar that runs every displayed metric through a formatting hook.

    Extends ``TQDMProgressBar`` so that training metrics — notably the total
    loss and beta values — are rendered in a more readable form.

    Args:
        format: Function to format the metric values.
    """

    def __init__(self, *args: Any, format: FormatFn, **kwargs: Any):
        super().__init__(*args, **kwargs)
        self.format = format

    @override
    def get_metrics(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
        """
        Format metrics for display in the progress bar.

        Returns:
            Dictionary of formatted metrics with:
            - Total loss in scientific notation
            - Beta value with 4 decimal places
            - Other metrics as provided by the parent class
        """
        metrics = super().get_metrics(*args, **kwargs)
        metrics.pop("v_num", None)  # internal version number is noise for users
        return {name: self.format(name, raw) for name, raw in metrics.items()}

format = format instance-attribute

__init__(*args: Any, format: FormatFn, **kwargs: Any)

Source code in src/anypinn/lightning/callbacks.py
def __init__(self, *args: Any, format: FormatFn, **kwargs: Any):
    """Initialize the progress bar and store the metric-formatting hook.

    Args:
        format: Function to format the metric values.
    """
    super().__init__(*args, **kwargs)
    self.format = format

get_metrics(*args: Any, **kwargs: Any) -> dict[str, Any]

Format metrics for display in the progress bar.

Returns:

Type Description
dict[str, Any]

Dictionary of formatted metrics with:
  • Total loss in scientific notation
  • Beta value with 4 decimal places
  • Other metrics as provided by the parent class
Source code in src/anypinn/lightning/callbacks.py
@override
def get_metrics(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
    """
    Format metrics for display in the progress bar.

    Returns:
        Dictionary of formatted metrics with:
        - Total loss in scientific notation
        - Beta value with 4 decimal places
        - Other metrics as provided by the parent class
    """
    metrics = super().get_metrics(*args, **kwargs)
    metrics.pop("v_num", None)  # internal version number is noise for users
    return {name: self.format(name, raw) for name, raw in metrics.items()}

PINNModule

Bases: LightningModule

Generic PINN Lightning module. Expects external Problem + Sampler + optimizer config.

Parameters:

Name Type Description Default
problem Problem

The PINN problem definition (constraints, fields, etc.).

required
hp PINNHyperparameters

Hyperparameters for training.

required
Source code in src/anypinn/lightning/module.py
class PINNModule(pl.LightningModule):
    """
    Generic PINN Lightning module.
    Expects external Problem + Sampler + optimizer config.

    Args:
        problem: The PINN problem definition (constraints, fields, etc.).
        hp: Hyperparameters for training.
    """

    def __init__(
        self,
        problem: Problem,
        hp: PINNHyperparameters,
    ):
        super().__init__()
        # Keep the (potentially heavyweight) Problem object out of the
        # checkpointed hyperparameters.
        self.save_hyperparameters(ignore=["problem"])

        self.problem = problem
        self.hp = hp

        def _log(key: str, value: Tensor, progress_bar: bool = False) -> None:
            # Epoch-level aggregation only; step-level values are not logged.
            self.log(
                key,
                value,
                on_step=False,
                on_epoch=True,
                prog_bar=progress_bar,
                batch_size=hp.training_data.batch_size,
            )

        self._log = cast(LogFn, _log)

    @override
    def on_fit_start(self) -> None:
        """
        Called when fit begins. Resolves validation sources using loaded data.
        """
        context = self.trainer.datamodule.context  # type: ignore
        self.problem.inject_context(context)

    @override
    def on_predict_start(self) -> None:
        """
        Called when predict begins. Resolves validation sources using loaded data.
        """
        context = self.trainer.datamodule.context  # type: ignore
        self.problem.inject_context(context)

    @override
    def training_step(self, batch: TrainingBatch, batch_idx: int) -> Tensor:
        """
        Performs a single training step.
        The Problem owns the loss computation; we only supply the logger.
        """
        return self.problem.training_loss(batch, self._log)

    @override
    def predict_step(self, batch: PredictionBatch, batch_idx: int) -> Predictions:
        """
        Performs a prediction step, pairing predictions with ground-truth values.
        """
        x_data, y_data = batch
        data_batch, predictions = self.problem.predict((x_data, y_data))
        return (data_batch, predictions, self.problem.true_values(x_data))

    @override
    def configure_optimizers(self) -> OptimizerLRScheduler:
        """
        Configures the optimizer and learning rate scheduler.
        """
        opt_cfg = self.hp.optimizer
        if isinstance(opt_cfg, LBFGSConfig):
            optimizer = torch.optim.LBFGS(
                self.parameters(),
                lr=opt_cfg.lr,
                max_iter=opt_cfg.max_iter,
                max_eval=opt_cfg.max_eval,
                history_size=opt_cfg.history_size,
                line_search_fn=opt_cfg.line_search_fn,
            )
        elif isinstance(opt_cfg, AdamConfig):
            optimizer = torch.optim.Adam(
                self.parameters(),
                lr=opt_cfg.lr,
                betas=opt_cfg.betas,
                weight_decay=opt_cfg.weight_decay,
            )
        else:
            # No structured optimizer config: plain Adam with the top-level lr.
            optimizer = torch.optim.Adam(self.parameters(), lr=self.hp.lr)

        sch_cfg = self.hp.scheduler
        if not sch_cfg:
            return optimizer

        # Common scheduler-dict fields; the scheduler (and, for plateau-style
        # scheduling, the monitored metric) are filled in per config type.
        scheduler_entry = {
            "name": "lr",
            "interval": "epoch",
            "frequency": 1,
        }
        if isinstance(sch_cfg, CosineAnnealingConfig):
            scheduler_entry["scheduler"] = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=sch_cfg.T_max,
                eta_min=sch_cfg.eta_min,
            )
        else:
            scheduler_entry["scheduler"] = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode=sch_cfg.mode,
                factor=sch_cfg.factor,
                patience=sch_cfg.patience,
                threshold=sch_cfg.threshold,
                min_lr=sch_cfg.min_lr,
            )
            scheduler_entry["monitor"] = LOSS_KEY

        return {"optimizer": optimizer, "lr_scheduler": scheduler_entry}

hp = hp instance-attribute

problem = problem instance-attribute

__init__(problem: Problem, hp: PINNHyperparameters)

Source code in src/anypinn/lightning/module.py
def __init__(
    self,
    problem: Problem,
    hp: PINNHyperparameters,
):
    """Store the problem and hyperparameters, and build the epoch-level logger."""
    super().__init__()
    # Keep the Problem object out of the checkpointed hyperparameters.
    self.save_hyperparameters(ignore=["problem"])

    self.hp = hp
    self.problem = problem

    def _log(key: str, value: Tensor, progress_bar: bool = False) -> None:
        # Epoch-level aggregation only; step-level values are not logged.
        self.log(
            key,
            value,
            on_step=False,
            on_epoch=True,
            prog_bar=progress_bar,
            batch_size=hp.training_data.batch_size,
        )

    self._log = cast(LogFn, _log)

configure_optimizers() -> OptimizerLRScheduler

Configures the optimizer and learning rate scheduler.

Source code in src/anypinn/lightning/module.py
@override
def configure_optimizers(self) -> OptimizerLRScheduler:
    """
    Configures the optimizer and learning rate scheduler.
    """
    opt_cfg = self.hp.optimizer
    if isinstance(opt_cfg, LBFGSConfig):
        optimizer = torch.optim.LBFGS(
            self.parameters(),
            lr=opt_cfg.lr,
            max_iter=opt_cfg.max_iter,
            max_eval=opt_cfg.max_eval,
            history_size=opt_cfg.history_size,
            line_search_fn=opt_cfg.line_search_fn,
        )
    elif isinstance(opt_cfg, AdamConfig):
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=opt_cfg.lr,
            betas=opt_cfg.betas,
            weight_decay=opt_cfg.weight_decay,
        )
    else:
        # No structured optimizer config: plain Adam with the top-level lr.
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hp.lr)

    sch_cfg = self.hp.scheduler
    if not sch_cfg:
        return optimizer

    # Common scheduler-dict fields; the scheduler (and, for plateau-style
    # scheduling, the monitored metric) are filled in per config type.
    scheduler_entry = {
        "name": "lr",
        "interval": "epoch",
        "frequency": 1,
    }
    if isinstance(sch_cfg, CosineAnnealingConfig):
        scheduler_entry["scheduler"] = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=sch_cfg.T_max,
            eta_min=sch_cfg.eta_min,
        )
    else:
        scheduler_entry["scheduler"] = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode=sch_cfg.mode,
            factor=sch_cfg.factor,
            patience=sch_cfg.patience,
            threshold=sch_cfg.threshold,
            min_lr=sch_cfg.min_lr,
        )
        scheduler_entry["monitor"] = LOSS_KEY

    return {"optimizer": optimizer, "lr_scheduler": scheduler_entry}

on_fit_start() -> None

Called when fit begins. Resolves validation sources using loaded data.

Source code in src/anypinn/lightning/module.py
@override
def on_fit_start(self) -> None:
    """
    Called when fit begins. Resolves validation sources using loaded data.
    """
    self.problem.inject_context(self.trainer.datamodule.context)  # type: ignore

on_predict_start() -> None

Called when predict begins. Resolves validation sources using loaded data.

Source code in src/anypinn/lightning/module.py
@override
def on_predict_start(self) -> None:
    """
    Called when predict begins. Resolves validation sources using loaded data.
    """
    self.problem.inject_context(self.trainer.datamodule.context)  # type: ignore

predict_step(batch: PredictionBatch, batch_idx: int) -> Predictions

Performs a prediction step.

Source code in src/anypinn/lightning/module.py
@override
def predict_step(self, batch: PredictionBatch, batch_idx: int) -> Predictions:
    """
    Performs a prediction step.
    """
    x_data, y_data = batch

    (data_batch, predictions) = self.problem.predict((x_data, y_data))
    true_values = self.problem.true_values(x_data)

    return (data_batch, predictions, true_values)

training_step(batch: TrainingBatch, batch_idx: int) -> Tensor

Performs a single training step. Calculates total loss from the problem.

Source code in src/anypinn/lightning/module.py
@override
def training_step(self, batch: TrainingBatch, batch_idx: int) -> Tensor:
    """
    Performs a single training step.
    Calculates total loss from the problem.
    """
    return self.problem.training_loss(batch, self._log)

PredictionsWriter

Bases: BasePredictionWriter

Callback to write predictions to disk at the end of an epoch.

Parameters:

Name Type Description Default
predictions_path Path | None

Path to save the predictions tensor/object.

None
batch_indices_path Path | None

Path to save the batch indices.

None
on_prediction HookFn | None

Optional hook function called when predictions are ready.

None
write_interval Literal['batch', 'epoch', 'batch_and_epoch']

Interval to write predictions ("batch", "epoch", "batch_and_epoch").

'epoch'
Source code in src/anypinn/lightning/callbacks.py
class PredictionsWriter(BasePredictionWriter):
    """
    Callback that persists predictions (and/or hands them to a hook) once an
    epoch of prediction finishes.

    Args:
        predictions_path: Path to save the predictions tensor/object.
        batch_indices_path: Path to save the batch indices.
        on_prediction: Optional hook function called when predictions are ready.
        write_interval: Interval to write predictions ("batch", "epoch", "batch_and_epoch").
    """

    def __init__(
        self,
        predictions_path: Path | None = None,
        batch_indices_path: Path | None = None,
        on_prediction: HookFn | None = None,
        write_interval: Literal["batch", "epoch", "batch_and_epoch"] = "epoch",
    ):
        super().__init__(write_interval)
        self.on_prediction = on_prediction
        self.predictions_path = predictions_path
        self.batch_indices_path = batch_indices_path

    @override
    def write_on_epoch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        predictions: Sequence[Predictions],
        batch_indices: Sequence[Any],
    ) -> None:
        """
        Writes predictions to disk or calls the hook at the end of the epoch.
        """
        hook = self.on_prediction
        if hook is not None:
            hook(trainer, pl_module, predictions, batch_indices)

        # Persist whichever outputs have a destination configured.
        for payload, destination in (
            (predictions, self.predictions_path),
            (batch_indices, self.batch_indices_path),
        ):
            if destination is not None:
                torch.save(payload, destination)

batch_indices_path = batch_indices_path instance-attribute

on_prediction = on_prediction instance-attribute

predictions_path = predictions_path instance-attribute

__init__(predictions_path: Path | None = None, batch_indices_path: Path | None = None, on_prediction: HookFn | None = None, write_interval: Literal['batch', 'epoch', 'batch_and_epoch'] = 'epoch')

Source code in src/anypinn/lightning/callbacks.py
def __init__(
    self,
    predictions_path: Path | None = None,
    batch_indices_path: Path | None = None,
    on_prediction: HookFn | None = None,
    write_interval: Literal["batch", "epoch", "batch_and_epoch"] = "epoch",
):
    """Configure output destinations and the optional prediction hook."""
    super().__init__(write_interval)
    self.on_prediction = on_prediction
    self.predictions_path = predictions_path
    self.batch_indices_path = batch_indices_path

write_on_epoch_end(trainer: Trainer, pl_module: LightningModule, predictions: Sequence[Predictions], batch_indices: Sequence[Any]) -> None

Writes predictions to disk or calls the hook at the end of the epoch.

Source code in src/anypinn/lightning/callbacks.py
@override
def write_on_epoch_end(
    self,
    trainer: Trainer,
    pl_module: LightningModule,
    predictions: Sequence[Predictions],
    batch_indices: Sequence[Any],
) -> None:
    """
    Writes predictions to disk or calls the hook at the end of the epoch.
    """
    if self.on_prediction is not None:
        self.on_prediction(trainer, pl_module, predictions, batch_indices)

    if self.predictions_path is not None:
        torch.save(predictions, self.predictions_path)

    if self.batch_indices_path is not None:
        torch.save(batch_indices, self.batch_indices_path)

SMMAStopping

Bases: Callback

Early stopping callback based on the Smoothed Moving Average (SMMA) of the loss. Stops training if the relative improvement of the SMMA drops below a threshold.

Parameters:

Name Type Description Default
config SMMAStoppingConfig

Configuration for SMMA stopping (window, threshold, lookback).

required
loss_key str

The metric key to monitor (e.g., 'loss').

required
log_key str

Key to log the computed SMMA value.

SMMA_KEY
Source code in src/anypinn/lightning/callbacks.py
class SMMAStopping(Callback):
    """
    Early stopping callback based on the Smoothed Moving Average (SMMA) of the loss.
    Stops training if the relative improvement of the SMMA drops below a threshold.

    Args:
        config: Configuration for SMMA stopping (window, threshold, lookback).
        loss_key: The metric key to monitor (e.g., 'loss').
        log_key: Key to log the computed SMMA value.
    """

    def __init__(self, config: SMMAStoppingConfig, loss_key: str, log_key: str = SMMA_KEY):
        super().__init__()
        self.config = config
        self.loss_key = loss_key
        self.log_key = log_key
        # Raw losses collected while seeding the first average.
        self.loss_buffer: list[float] = []
        # Rolling window of the most recent `lookback` SMMA values.
        self.smma_buffer: deque[float] = deque(maxlen=self.config.lookback)

    @override
    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """
        Called when the train epoch ends. Updates SMMA and checks stopping condition.
        """
        # phase 0: get the loss (metric may not be logged yet on early epochs)
        loss_t = trainer.callback_metrics.get(self.loss_key)
        if loss_t is None:
            return

        loss = loss_t.item()
        n = self.config.window

        # phase 1: collect the first `window` losses, then seed the SMMA with
        # their plain mean. (BUGFIX: previously n + 1 losses were collected but
        # the sum was divided by n, inflating the seed average; dividing by the
        # actual number of collected losses fixes the off-by-one.)
        if len(self.smma_buffer) == 0:
            self.loss_buffer.append(loss)
            if len(self.loss_buffer) >= n:
                self.smma_buffer.append(sum(self.loss_buffer) / len(self.loss_buffer))
            return

        # phase 2: fold the new loss into the running SMMA
        smma = ((n - 1) * self.smma_buffer[-1] + loss) / n
        self.smma_buffer.append(smma)

        pl_module.log(self.log_key, smma)
        if len(self.smma_buffer) < self.config.lookback:
            return

        # phase 3: relative improvement versus the SMMA from `lookback` epochs ago
        smma_lookback = self.smma_buffer[0]
        improvement = (smma_lookback - smma) / smma_lookback

        # Only stop on a genuine-but-tiny improvement; a worsening loss
        # (improvement <= 0) does not trigger stopping here.
        if 0 < improvement < self.config.threshold:
            trainer.should_stop = True
            print(
                f"\nStopping training: SMMA improvement over {self.config.lookback} "
                f"epochs ({improvement:.2%}) below threshold ({self.config.threshold:.2%})"
            )

config = config instance-attribute

log_key = log_key instance-attribute

loss_buffer: list[float] = [] instance-attribute

loss_key = loss_key instance-attribute

smma_buffer: deque[float] = deque(maxlen=self.config.lookback) instance-attribute

__init__(config: SMMAStoppingConfig, loss_key: str, log_key: str = SMMA_KEY)

Source code in src/anypinn/lightning/callbacks.py
def __init__(self, config: SMMAStoppingConfig, loss_key: str, log_key: str = SMMA_KEY):
    """Set up monitoring state for SMMA-based early stopping."""
    super().__init__()
    self.config = config
    self.loss_key = loss_key
    self.log_key = log_key
    # Raw losses collected while seeding the first average.
    self.loss_buffer: list[float] = []
    # Rolling window of the most recent `lookback` SMMA values.
    self.smma_buffer: deque[float] = deque(maxlen=config.lookback)

on_train_epoch_end(trainer: Trainer, pl_module: LightningModule) -> None

Called when the train epoch ends. Updates SMMA and checks stopping condition.

Source code in src/anypinn/lightning/callbacks.py
@override
def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """
    Called when the train epoch ends. Updates SMMA and checks stopping condition.
    """
    # phase 0: get the loss
    loss_t = trainer.callback_metrics.get(self.loss_key)
    if loss_t is None:
        return

    loss = loss_t.item()
    n = self.config.window

    # phase 1: collect first `window` losses
    if len(self.loss_buffer) <= n:
        self.loss_buffer.append(loss)
        return

    # phase 1.5: compute the first average
    if len(self.smma_buffer) == 0:
        first_smma = sum(self.loss_buffer) / n
        self.smma_buffer.append(first_smma)
        return

    # phase 2: compute the first `lookback` Smoothed Moving Average (SMMA)
    smma = self.smma_buffer[-1]
    smma = ((n - 1) * smma + loss) / n
    self.smma_buffer.append(smma)

    pl_module.log(self.log_key, smma)
    if len(self.smma_buffer) < self.config.lookback:
        return

    # phase 3: compute the improvement between the current and the `lookback` SMMA
    smma_lookback = self.smma_buffer[0]
    improvement = (smma_lookback - smma) / smma_lookback

    if 0 < improvement < self.config.threshold:
        trainer.should_stop = True
        print(
            f"\nStopping training: SMMA improvement over {self.config.lookback} "
            f"epochs ({improvement:.2%}) below threshold ({self.config.threshold:.2%})"
        )