
Lightning API

anypinn.lightning wraps core abstractions into PyTorch Lightning components for batteries-included training.


Training Module

PINNModule

Bases: LightningModule

LightningModule wrapper for a Problem instance.

Delegates physics computation to the Problem and handles optimizer/scheduler configuration, context injection, and prediction formatting. You rarely need to subclass this.

Parameters:

Name     Type                 Description                                               Default
problem  Problem              The PINN problem definition (constraints, fields, etc.).  required
hp       PINNHyperparameters  Hyperparameters for training.                             required

Example
# Assumes problem, hp, and data_module already exist and PyTorch Lightning is imported as pl.
module = PINNModule(problem=problem, hp=hp)
trainer = pl.Trainer(max_epochs=hp.max_epochs)
trainer.fit(module, datamodule=data_module)

hp = hp instance-attribute

problem = problem instance-attribute

__init__(problem: Problem, hp: PINNHyperparameters)

configure_optimizers() -> OptimizerLRScheduler

Configures the optimizer and learning rate scheduler.

on_fit_start() -> None

Called when fit begins. Resolves validation sources using loaded data.

on_predict_start() -> None

Called when predict begins. Resolves validation sources using loaded data.

predict_step(batch: PredictionBatch, batch_idx: int) -> Predictions

Performs a prediction step.

training_step(batch: TrainingBatch, batch_idx: int) -> Tensor

Performs a single training step. Calculates total loss from the problem.


Callbacks

SMMAStopping

Bases: Callback

Early stopping callback based on the Smoothed Moving Average (SMMA) of the loss. Stops training if the relative improvement of the SMMA drops below a threshold.

Parameters:

Name      Type                Description                                                     Default
config    SMMAStoppingConfig  Configuration for SMMA stopping (window, threshold, lookback).  required
loss_key  str                 The metric key to monitor (e.g., 'loss').                       required
log_key   str                 Key to log the computed SMMA value.                             SMMA_KEY

config = config instance-attribute

log_key = log_key instance-attribute

loss_buffer: list[float] = [] instance-attribute

loss_key = loss_key instance-attribute

smma_buffer: deque[float] = deque(maxlen=(self.config.lookback)) instance-attribute

__init__(config: SMMAStoppingConfig, loss_key: str, log_key: str = SMMA_KEY)

on_train_epoch_end(trainer: Trainer, pl_module: LightningModule) -> None

Called when the train epoch ends. Updates SMMA and checks stopping condition.
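
The stopping rule described above can be illustrated with a small standalone sketch. It is not the library's implementation: it assumes the textbook SMMA recurrence (seeded with the mean of the first window losses) and assumes that "relative improvement" compares the oldest and newest SMMA values in the lookback buffer. The function name smma_should_stop and the eps guard are hypothetical.

```python
from collections import deque


def smma_should_stop(losses, window, lookback, threshold, eps=1e-12):
    """Return the epoch index at which training would stop, or None.

    Hypothetical helper: mirrors the documented buffers, not the real callback.
    """
    loss_buffer = []                      # warm-up losses used to seed the SMMA
    smma_buffer = deque(maxlen=lookback)  # most recent SMMA values
    smma = None

    for epoch, loss in enumerate(losses):
        if smma is None:
            # Collect `window` losses, then seed the SMMA with their mean.
            loss_buffer.append(loss)
            if len(loss_buffer) < window:
                continue
            smma = sum(loss_buffer) / window
        else:
            # Standard smoothed moving average update.
            smma = (smma * (window - 1) + loss) / window
        smma_buffer.append(smma)

        # Only check once the lookback buffer is full.
        if len(smma_buffer) == lookback:
            oldest = smma_buffer[0]
            improvement = (oldest - smma) / max(abs(oldest), eps)
            if improvement < threshold:
                return epoch
    return None


flat = smma_should_stop([1.0] * 20, window=3, lookback=4, threshold=1e-3)
decaying = smma_should_stop([0.5 ** i for i in range(20)], window=3, lookback=4, threshold=1e-3)
```

With a flat loss the sketch stops as soon as the lookback buffer fills; with a steadily decaying loss it never triggers.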

FormattedProgressBar

Bases: TQDMProgressBar

Custom progress bar for training that formats metrics for better readability.

It extends TQDMProgressBar with custom formatting for training metrics, in particular the total loss and beta values.

Parameters:

Name    Type      Description                            Default
format  FormatFn  Function to format the metric values.  required

format = format instance-attribute

__init__(*args: Any, format: FormatFn, **kwargs: Any)

get_metrics(*args: Any, **kwargs: Any) -> dict[str, Any]

Format metrics for display in the progress bar.

Returns:

Type            Description
dict[str, Any]  Dictionary of formatted metrics with:
                  • Total loss in scientific notation
                  • Beta value with 4 decimal places
                  • Other metrics as provided by the parent class
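
As a rough illustration of the formatting described above, the following standalone sketch applies per-key formatting to a metrics dictionary. The (key, value) signature of format_metric, the key names "loss" and "beta", and the two-digit mantissa are assumptions made for the example; the actual FormatFn type is not documented in this section.

```python
from typing import Any


def format_metric(key: str, value: Any) -> Any:
    """Hypothetical formatter; not the library's FormatFn signature."""
    if not isinstance(value, (int, float)):
        return value              # leave strings and other objects untouched
    if key == "loss":
        return f"{value:.2e}"     # total loss in scientific notation
    if key == "beta":
        return f"{value:.4f}"     # beta with 4 decimal places
    return value                  # other metrics pass through unchanged


def format_metrics(metrics: dict[str, Any]) -> dict[str, Any]:
    # Apply the formatter to every metric reported by the parent progress bar.
    return {key: format_metric(key, value) for key, value in metrics.items()}


formatted = format_metrics({"loss": 0.000123456, "beta": 0.5, "v_num": 3})
```

In this sketch only the two named metrics are converted to strings; anything else is displayed as the parent class reports it.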

PredictionsWriter

Bases: BasePredictionWriter

Callback that writes predictions to disk, or passes them to a hook, at the end of a prediction epoch.

Parameters:

Name                Type                                          Description                                                           Default
predictions_path    Path | None                                   Path to save the predictions tensor/object.                           None
batch_indices_path  Path | None                                   Path to save the batch indices.                                       None
on_prediction       HookFn | None                                 Optional hook function called when predictions are ready.             None
write_interval      Literal['batch', 'epoch', 'batch_and_epoch']  Interval to write predictions ("batch", "epoch", "batch_and_epoch").  'epoch'

batch_indices_path = batch_indices_path instance-attribute

on_prediction = on_prediction instance-attribute

predictions_path = predictions_path instance-attribute

__init__(predictions_path: Path | None = None, batch_indices_path: Path | None = None, on_prediction: HookFn | None = None, write_interval: Literal['batch', 'epoch', 'batch_and_epoch'] = 'epoch')

write_on_epoch_end(trainer: Trainer, pl_module: LightningModule, predictions: Sequence[Predictions], batch_indices: Sequence[Any]) -> None

Writes predictions to disk or calls the hook at the end of the epoch.
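
The "write to disk or call the hook" behaviour can be sketched without Lightning as follows. This is a stand-in under stated assumptions: pickle replaces whatever serializer the callback really uses, the hook is assumed to receive the predictions and batch indices (the real HookFn signature is not given here), and write_on_epoch_end_sketch is a hypothetical name.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Callable, Optional, Sequence

# Hypothetical hook type; the real HookFn signature is not documented here.
Hook = Callable[[Sequence[Any], Sequence[Any]], None]


def write_on_epoch_end_sketch(
    predictions: Sequence[Any],
    batch_indices: Sequence[Any],
    predictions_path: Optional[Path] = None,
    batch_indices_path: Optional[Path] = None,
    on_prediction: Optional[Hook] = None,
) -> None:
    # Each output is optional: only write what was given a path.
    if predictions_path is not None:
        predictions_path.write_bytes(pickle.dumps(list(predictions)))
    if batch_indices_path is not None:
        batch_indices_path.write_bytes(pickle.dumps(list(batch_indices)))
    # The hook, if provided, receives the in-memory results.
    if on_prediction is not None:
        on_prediction(predictions, batch_indices)


seen = []
with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp) / "predictions.pkl"
    write_on_epoch_end_sketch(
        predictions=[{"u": [0.1, 0.2]}],
        batch_indices=[[0, 1]],
        predictions_path=out,
        on_prediction=lambda preds, idx: seen.append((len(preds), len(idx))),
    )
    restored = pickle.loads(out.read_bytes())
```

Leaving every argument at None makes the sketch a no-op, which matches the all-optional defaults in the parameter table.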

AdaptiveCollocationCallback

Bases: Callback

Refreshes the collocation pool every N epochs (N = every_n_epochs) using AdaptiveSampler.

Requires the DataModule to be configured with collocation_sampler="adaptive" and a ResidualScorer passed to PINNDataModule(residual_scorer=...).

Because the scorer typically closes over the Problem, it uses the model's current weights on every call, so no additional injection step is needed.

Parameters:

Name            Type  Description                                     Default
every_n_epochs  int   How often (in epochs) to resample. Default: 1.  1

__init__(every_n_epochs: int = 1) -> None

on_fit_start(trainer: Trainer, pl_module: LightningModule) -> None

Validate at fit start that the DataModule uses AdaptiveSampler.

on_train_epoch_end(trainer: Trainer, pl_module: LightningModule) -> None

Resample collocation points using the current model weights.
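
Two ideas from this callback can be shown in a toy, framework-free sketch: a scorer that closes over a model sees weight updates without being rebuilt, and resampling is gated on the epoch counter. ToyModel, make_scorer, and should_resample are hypothetical names, and the (epoch + 1) % every_n_epochs rule is an assumption about how the interval is counted.

```python
class ToyModel:
    """Stand-in for a Problem holding trainable weights."""

    def __init__(self) -> None:
        self.weight = 1.0


def make_scorer(model: ToyModel):
    # The closure keeps a reference to `model`, not a copy of its weights.
    def score(points):
        return [abs(model.weight * x) for x in points]

    return score


def should_resample(epoch: int, every_n_epochs: int) -> bool:
    # Assumed convention: epochs are 0-indexed, resample after every N-th one.
    return (epoch + 1) % every_n_epochs == 0


model = ToyModel()
scorer = make_scorer(model)

before = scorer([1.0, -2.0])
model.weight = 3.0            # simulate a training update
after = scorer([1.0, -2.0])   # same scorer object, new weights

resample_epochs = [e for e in range(6) if should_resample(e, every_n_epochs=2)]
```

The scorer is created once, yet its scores change after the weight update, which is why no extra injection step is required.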