# Lightning API

`anypinn.lightning` wraps core abstractions into PyTorch Lightning components for batteries-included training.
## Training Module

### PINNModule

Bases: `LightningModule`

LightningModule wrapper for a `Problem` instance.

Delegates physics computation to the `Problem` and handles optimizer/scheduler configuration, context injection, and prediction formatting. You rarely need to subclass this.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `problem` | `Problem` | The PINN problem definition (constraints, fields, etc.). | *required* |
| `hp` | `PINNHyperparameters` | Hyperparameters for training. | *required* |
Example
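A minimal sketch of the delegation pattern described above: the module owns no physics and simply asks its `Problem` for the loss. The toy classes and the loss arithmetic below are illustrative stand-ins, not the anypinn API.

```python
# Illustrative stand-ins only -- not the real anypinn classes.
class ToyProblem:
    """Plays the role of a Problem: it owns the physics/loss."""

    def total_loss(self, batch):
        # toy "residual": mean squared value of the batch
        return sum(x * x for x in batch) / len(batch)


class ToyPINNModule:
    """Plays the role of PINNModule: it only delegates."""

    def __init__(self, problem):
        self.problem = problem

    def training_step(self, batch, batch_idx):
        # like PINNModule.training_step, delegate the loss to the problem
        return self.problem.total_loss(batch)


loss = ToyPINNModule(ToyProblem()).training_step([1.0, 2.0], 0)  # 2.5
```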
Attributes:

- `hp = hp` *(instance-attribute)*
- `problem = problem` *(instance-attribute)*
`__init__(problem: Problem, hp: PINNHyperparameters)`

`configure_optimizers() -> OptimizerLRScheduler`

Configures the optimizer and learning rate scheduler.

`on_fit_start() -> None`

Called when fit begins. Resolves validation sources using loaded data.

`on_predict_start() -> None`

Called when predict begins. Resolves validation sources using loaded data.

`predict_step(batch: PredictionBatch, batch_idx: int) -> Predictions`

Performs a prediction step.

`training_step(batch: TrainingBatch, batch_idx: int) -> Tensor`

Performs a single training step. Calculates the total loss from the problem.
## Callbacks

### SMMAStopping

Bases: `Callback`

Early stopping callback based on the Smoothed Moving Average (SMMA) of the loss. Stops training if the relative improvement of the SMMA drops below a threshold.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `SMMAStoppingConfig` | Configuration for SMMA stopping (window, threshold, lookback). | *required* |
| `loss_key` | `str` | The metric key to monitor (e.g., `'loss'`). | *required* |
| `log_key` | `str` | Key to log the computed SMMA value. | `SMMA_KEY` |
Attributes:

- `config = config` *(instance-attribute)*
- `log_key = log_key` *(instance-attribute)*
- `loss_buffer: list[float] = []` *(instance-attribute)*
- `loss_key = loss_key` *(instance-attribute)*
- `smma_buffer: deque[float] = deque(maxlen=self.config.lookback)` *(instance-attribute)*
`__init__(config: SMMAStoppingConfig, loss_key: str, log_key: str = SMMA_KEY)`

`on_train_epoch_end(trainer: Trainer, pl_module: LightningModule) -> None`

Called when the train epoch ends. Updates SMMA and checks the stopping condition.
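The stopping rule can be sketched numerically. The arithmetic below (Wilder-style smoothing plus a relative-improvement test over the lookback window) is an assumption about how SMMAStopping behaves, not its actual implementation:

```python
from collections import deque

def smma_update(prev: float, loss: float, window: int) -> float:
    # Wilder-style smoothed moving average: a cheap EMA variant
    # (ASSUMED smoothing formula, for illustration)
    return (prev * (window - 1) + loss) / window

def should_stop(smma_history: deque, threshold: float) -> bool:
    # stop once relative improvement across the lookback window is tiny
    oldest, newest = smma_history[0], smma_history[-1]
    return (oldest - newest) / oldest < threshold

history = deque(maxlen=3)  # lookback = 3 epochs
smma = 1.0
for epoch_loss in [0.9, 0.85, 0.84, 0.839, 0.8389]:
    smma = smma_update(smma, epoch_loss, window=5)
    history.append(smma)
# late epochs barely move the SMMA, so the stopping test starts to fire
```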
### FormattedProgressBar

Bases: `TQDMProgressBar`

Custom progress bar for training that formats metrics for better readability.

This class extends `TQDMProgressBar` to provide custom formatting for training metrics, particularly for the total loss and beta values.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `format` | `FormatFn` | Function to format the metric values. | *required* |
Attributes:

- `format = format` *(instance-attribute)*

`__init__(*args: Any, format: FormatFn, **kwargs: Any)`

`get_metrics(*args: Any, **kwargs: Any) -> dict[str, Any]`

Format metrics for display in the progress bar.
Returns:

| Type | Description |
|---|---|
| `dict[str, Any]` | Dictionary of formatted metrics. |
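A `FormatFn` is not documented beyond its name; a plausible sketch, assuming it maps a raw metric value to a display string, might compact float losses into scientific notation:

```python
def sci_format(value):
    # hypothetical FormatFn: compact scientific notation for float metrics,
    # plain str() for everything else (epoch counters, step counts, ...)
    if isinstance(value, float):
        return f"{value:.2e}"
    return str(value)

raw = {"loss": 0.000123, "beta": 0.25, "epoch": 7}
formatted = {k: sci_format(v) for k, v in raw.items()}
# e.g. formatted["loss"] == "1.23e-04"
```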
### PredictionsWriter

Bases: `BasePredictionWriter`

Callback to write predictions to disk at the end of an epoch.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `predictions_path` | `Path \| None` | Path to save the predictions tensor/object. | `None` |
| `batch_indices_path` | `Path \| None` | Path to save the batch indices. | `None` |
| `on_prediction` | `HookFn \| None` | Optional hook function called when predictions are ready. | `None` |
| `write_interval` | `Literal['batch', 'epoch', 'batch_and_epoch']` | Interval to write predictions (`"batch"`, `"epoch"`, or `"batch_and_epoch"`). | `'epoch'` |
Attributes:

- `batch_indices_path = batch_indices_path` *(instance-attribute)*
- `on_prediction = on_prediction` *(instance-attribute)*
- `predictions_path = predictions_path` *(instance-attribute)*

`__init__(predictions_path: Path | None = None, batch_indices_path: Path | None = None, on_prediction: HookFn | None = None, write_interval: Literal['batch', 'epoch', 'batch_and_epoch'] = 'epoch')`

`write_on_epoch_end(trainer: Trainer, pl_module: LightningModule, predictions: Sequence[Predictions], batch_indices: Sequence[Any]) -> None`

Writes predictions to disk or calls the hook at the end of the epoch.
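The epoch-end behavior can be sketched as below; the helper name and the pickle serialization are assumptions for illustration (the real callback receives Lightning objects and may serialize differently):

```python
import pickle
import tempfile
from pathlib import Path

def write_epoch_sketch(predictions, predictions_path=None, on_prediction=None):
    # mirrors the documented behavior: persist when a path is given,
    # then fire the optional hook with the in-memory predictions
    if predictions_path is not None:
        predictions_path.write_bytes(pickle.dumps(predictions))
    if on_prediction is not None:
        on_prediction(predictions)

out = Path(tempfile.mkdtemp()) / "predictions.pkl"
seen = []
write_epoch_sketch([[0.1, 0.2]], predictions_path=out, on_prediction=seen.append)
# out now holds the pickled predictions, and the hook saw them once
```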
### AdaptiveCollocationCallback

Bases: `Callback`

Refreshes the collocation pool every N epochs using `AdaptiveSampler`.

Requires the DataModule to be configured with `collocation_sampler="adaptive"` and a `ResidualScorer` passed to `PINNDataModule(residual_scorer=...)`. Because the scorer typically closes over the `Problem`, it automatically uses the model's current weights on every call, so no additional injection step is needed.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `every_n_epochs` | `int` | How often (in epochs) to resample. | `1` |
`__init__(every_n_epochs: int = 1) -> None`

`on_fit_start(trainer: Trainer, pl_module: LightningModule) -> None`

Validate at fit start that the DataModule uses `AdaptiveSampler`.

`on_train_epoch_end(trainer: Trainer, pl_module: LightningModule) -> None`

Resample collocation points using the current model weights.
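The resampling step can be pictured as scoring a candidate pool with the residual scorer and keeping the worst offenders. This top-k selection is an illustrative assumption, not necessarily `AdaptiveSampler`'s actual strategy:

```python
def resample_pool(candidates, residual_scorer, pool_size):
    # keep the candidates where the current model's residual is largest
    # (ASSUMED top-k strategy, for illustration)
    return sorted(candidates, key=residual_scorer, reverse=True)[:pool_size]

# toy 1-D scorer: pretend the residual grows away from x = 0.5
points = [i / 10 for i in range(11)]
pool = resample_pool(points, lambda x: abs(x - 0.5), pool_size=4)
# the pool concentrates at the domain ends, where the toy residual peaks
```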