anypinn.lightning
Lightning integration for PINN training.
This module wraps anypinn.core abstractions into PyTorch Lightning
components: a training module, stopping criteria, and data callbacks.
PINNModule
PINNModule is a thin LightningModule that delegates to a Problem
instance. It handles optimizer/scheduler configuration, context injection at
fit start, and prediction output formatting. You rarely need to subclass it
— all physics live in the Problem and its Constraint list.
Data scaling
When ODE/PDE state variables span very different magnitudes (e.g. the Lorenz
system where variables reach ~40), raw values can destabilize training.
DataScaling is a DataCallback (not a Lightning Callback) that
rescales data before the dataset is constructed:
- x is normalized to [0, 1].
- y is multiplied by per-series scale factors you provide.
Pass it via PINNDataModule(callbacks=[DataScaling(y_scale=...)]).
Validation functions are automatically rescaled so that logged validation
losses remain comparable.
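The rescaling described above can be sketched in plain Python. This is an illustration of the behavior, not anypinn's actual `DataScaling` implementation; the function name `scale_data` and its argument layout are hypothetical.

```python
# Illustrative sketch of the rescaling DataScaling performs (hypothetical
# helper, not anypinn source): x is mapped to [0, 1] and each y series is
# multiplied by a user-provided per-series scale factor.

def scale_data(x, y_series, y_scale):
    """x: list of floats; y_series: dict name -> list of floats;
    y_scale: dict name -> float scale factor."""
    x_min, x_max = min(x), max(x)
    span = x_max - x_min
    # Normalize the independent variable to [0, 1].
    x_scaled = [(xi - x_min) / span for xi in x]
    # Multiply each dependent series by its scale factor.
    y_scaled = {
        name: [yi * y_scale[name] for yi in values]
        for name, values in y_series.items()
    }
    return x_scaled, y_scaled

# Example: a Lorenz-like variable reaching ~40 is brought to O(1).
x_s, y_s = scale_data(
    x=[0.0, 5.0, 10.0],
    y_series={"z": [0.0, 20.0, 40.0]},
    y_scale={"z": 1 / 40},
)
```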
Stopping criteria
Two stopping strategies are available:
- SMMAStopping: monitors the Smoothed Moving Average of the loss and stops when relative improvement over a lookback window drops below a threshold. This is the default in most catalog examples — it adapts to the loss trajectory rather than requiring a fixed patience count.
- Lightning's built-in EarlyStopping: monitors a metric and stops after
`patience` epochs without improvement. Simpler to reason about, but less tolerant of the noisy loss curves common in PINN training.
Use SMMAStopping when loss decreases slowly and erratically (typical for
PINNs). Use EarlyStopping when loss curves are smooth and you want a
hard patience bound.
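The SMMA stopping rule can be sketched as follows. This is a minimal illustration of the idea, not anypinn's source; the exact smoothing update and improvement criterion used by `SMMAStopping` may differ, and the `window`/`threshold`/`lookback` names simply mirror `SMMAStoppingConfig`.

```python
from collections import deque

# Sketch of an SMMA-based stopping rule (illustrative; anypinn's exact
# update and criterion may differ).

class SMMAStopper:
    def __init__(self, window: int, threshold: float, lookback: int):
        self.window = window
        self.threshold = threshold
        self.smma = None
        self.history = deque(maxlen=lookback)

    def update(self, loss: float) -> bool:
        """Feed one epoch's loss; return True when training should stop."""
        # Wilder-style smoothed moving average of the loss.
        if self.smma is None:
            self.smma = loss
        else:
            self.smma = self.smma + (loss - self.smma) / self.window
        self.history.append(self.smma)
        if len(self.history) < self.history.maxlen:
            return False  # not enough history over the lookback window yet
        oldest, newest = self.history[0], self.history[-1]
        # Stop when relative improvement over the lookback window is tiny.
        rel_improvement = (oldest - newest) / abs(oldest)
        return rel_improvement < self.threshold
```

Because the rule compares smoothed values over a lookback window, a single noisy epoch cannot trigger a premature stop, unlike a fixed patience count.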
Adaptive collocation
AdaptiveCollocationCallback resamples collocation points every N epochs
using the current model weights. It requires the data module to be
configured with collocation_sampler="adaptive" and a ResidualScorer.
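The resampling step can be illustrated with a toy residual-weighted sampler. This is a sketch of the general idea only; anypinn's `AdaptiveSampler` and `ResidualScorer` interfaces are not reproduced here, and `resample_collocation` is a hypothetical name.

```python
import random

# Toy sketch of residual-weighted collocation resampling (illustrative;
# not anypinn's AdaptiveSampler). Points where the PDE residual is large
# under the current model are sampled more often.

def resample_collocation(candidates, residual_fn, n_points,
                         rng=random.Random(0)):
    """Draw n_points from candidates with probability proportional
    to |residual|."""
    weights = [abs(residual_fn(x)) for x in candidates]
    return rng.choices(candidates, weights=weights, k=n_points)

# Toy residual: large near x = 0.5, so new points should cluster there.
candidates = [i / 100 for i in range(101)]
new_points = resample_collocation(
    candidates,
    residual_fn=lambda x: 1.0 if abs(x - 0.5) < 0.1 else 0.01,
    n_points=200,
)
```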
__all__ = ['AdaptiveCollocationCallback', 'FormattedProgressBar', 'PINNModule', 'PredictionsWriter', 'SMMAStopping']
module-attribute
AdaptiveCollocationCallback
Bases: Callback
Refreshes the collocation pool every N epochs using AdaptiveSampler.
Requires the DataModule to be configured with collocation_sampler="adaptive"
and a ResidualScorer passed to PINNDataModule(residual_scorer=...).
Because the scorer typically closes over the Problem, it automatically uses
the model's current weights on every call — no additional injection step needed.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `every_n_epochs` | `int` | How often (in epochs) to resample. | `1` |
Source code in src/anypinn/lightning/callbacks.py
__init__(every_n_epochs: int = 1) -> None
on_fit_start(trainer: Trainer, pl_module: LightningModule) -> None
Validate at fit start that the DataModule uses AdaptiveSampler.
Source code in src/anypinn/lightning/callbacks.py
on_train_epoch_end(trainer: Trainer, pl_module: LightningModule) -> None
Resample collocation points using the current model weights.
Source code in src/anypinn/lightning/callbacks.py
FormattedProgressBar
Bases: TQDMProgressBar
Custom progress bar for training that formats metrics for better readability.
This class extends the TQDMProgressBar to provide custom formatting for training metrics, particularly for the total loss and beta values.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `format` | `FormatFn` | Function to format the metric values. | *required* |
Source code in src/anypinn/lightning/callbacks.py
format = format
instance-attribute
__init__(*args: Any, format: FormatFn, **kwargs: Any)
get_metrics(*args: Any, **kwargs: Any) -> dict[str, Any]
Format metrics for display in the progress bar.
Returns:

| Type | Description |
|---|---|
| `dict[str, Any]` | Dictionary of formatted metrics. |
Source code in src/anypinn/lightning/callbacks.py
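A formatter of the kind `FormattedProgressBar` accepts might look like the following. The exact `FormatFn` signature is not shown on this page, so this sketch assumes it maps a single metric value to a display string; `sci_format` is a hypothetical name.

```python
# A plausible FormatFn (assumption: it maps one metric value to a string;
# check the callback source for the exact signature).

def sci_format(value):
    """Render float metrics (e.g. total loss, beta) in scientific
    notation; pass other values through unchanged."""
    if isinstance(value, float):
        return f"{value:.3e}"
    return str(value)
```

It could then be passed as `FormattedProgressBar(format=sci_format)`.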
PINNModule
Bases: LightningModule
LightningModule wrapper for a Problem instance.
Delegates physics computation to the Problem and handles
optimizer/scheduler configuration, context injection, and
prediction formatting. You rarely need to subclass this.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `problem` | `Problem` | The PINN problem definition (constraints, fields, etc.). | *required* |
| `hp` | `PINNHyperparameters` | Hyperparameters for training. | *required* |
Example

```python
module = PINNModule(problem=problem, hp=hp)
trainer = pl.Trainer(max_epochs=hp.max_epochs)
trainer.fit(module, datamodule=data_module)
```
Source code in src/anypinn/lightning/module.py
hp = hp
instance-attribute
problem = problem
instance-attribute
__init__(problem: Problem, hp: PINNHyperparameters)
Source code in src/anypinn/lightning/module.py
configure_optimizers() -> OptimizerLRScheduler
Configures the optimizer and learning rate scheduler.
Source code in src/anypinn/lightning/module.py
on_fit_start() -> None
Called when fit begins. Resolves validation sources using loaded data.
on_predict_start() -> None
Called when predict begins. Resolves validation sources using loaded data.
predict_step(batch: PredictionBatch, batch_idx: int) -> Predictions
Performs a prediction step.
Source code in src/anypinn/lightning/module.py
training_step(batch: TrainingBatch, batch_idx: int) -> Tensor
Performs a single training step. Calculates total loss from the problem.
PredictionsWriter
Bases: BasePredictionWriter
Callback to write predictions to disk at the end of an epoch.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `predictions_path` | `Path \| None` | Path to save the predictions tensor/object. | `None` |
| `batch_indices_path` | `Path \| None` | Path to save the batch indices. | `None` |
| `on_prediction` | `HookFn \| None` | Optional hook function called when predictions are ready. | `None` |
| `write_interval` | `Literal['batch', 'epoch', 'batch_and_epoch']` | Interval to write predictions (`'batch'`, `'epoch'`, or `'batch_and_epoch'`). | `'epoch'` |
Source code in src/anypinn/lightning/callbacks.py
batch_indices_path = batch_indices_path
instance-attribute
on_prediction = on_prediction
instance-attribute
predictions_path = predictions_path
instance-attribute
__init__(predictions_path: Path | None = None, batch_indices_path: Path | None = None, on_prediction: HookFn | None = None, write_interval: Literal['batch', 'epoch', 'batch_and_epoch'] = 'epoch')
Source code in src/anypinn/lightning/callbacks.py
write_on_epoch_end(trainer: Trainer, pl_module: LightningModule, predictions: Sequence[Predictions], batch_indices: Sequence[Any]) -> None
Writes predictions to disk or calls the hook at the end of the epoch.
Source code in src/anypinn/lightning/callbacks.py
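The epoch-end behavior can be sketched as a plain function: write the predictions if a path was given, and invoke the hook if one was provided. This is an illustration of the branching only, not anypinn's source; the standalone `write_on_epoch_end` helper and the use of `pickle` are assumptions.

```python
import pickle
from pathlib import Path

# Sketch of the write-or-hook behavior described above (hypothetical
# helper; anypinn's actual serialization may differ).

def write_on_epoch_end(predictions, predictions_path=None,
                       on_prediction=None):
    if predictions_path is not None:
        # Persist the predictions object to disk.
        Path(predictions_path).write_bytes(pickle.dumps(predictions))
    if on_prediction is not None:
        # Hand the in-memory predictions to the user's hook.
        on_prediction(predictions)
```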
SMMAStopping
Bases: Callback
Early stopping callback based on the Smoothed Moving Average (SMMA) of the loss. Stops training if the relative improvement of the SMMA drops below a threshold.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `SMMAStoppingConfig` | Configuration for SMMA stopping (window, threshold, lookback). | *required* |
| `loss_key` | `str` | The metric key to monitor (e.g., `'loss'`). | *required* |
| `log_key` | `str` | Key to log the computed SMMA value. | `SMMA_KEY` |
Source code in src/anypinn/lightning/callbacks.py
config = config
instance-attribute
log_key = log_key
instance-attribute
loss_buffer: list[float] = []
instance-attribute
loss_key = loss_key
instance-attribute
smma_buffer: deque[float] = deque(maxlen=(self.config.lookback))
instance-attribute
__init__(config: SMMAStoppingConfig, loss_key: str, log_key: str = SMMA_KEY)
Source code in src/anypinn/lightning/callbacks.py
on_train_epoch_end(trainer: Trainer, pl_module: LightningModule) -> None
Called when the train epoch ends. Updates SMMA and checks stopping condition.