anypinn.lightning
Lightning integration for PINN training.
__all__ = ['AdaptiveCollocationCallback', 'FormattedProgressBar', 'PINNModule', 'PredictionsWriter', 'SMMAStopping']
module-attribute
AdaptiveCollocationCallback
Bases: Callback
Refreshes the collocation pool every N epochs using AdaptiveSampler.
Requires the DataModule to be configured with collocation_sampler="adaptive"
and a ResidualScorer passed to PINNDataModule(residual_scorer=...).
Because the scorer typically closes over the Problem, it automatically uses
the model's current weights on every call — no additional injection step needed.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| every_n_epochs | int | How often (in epochs) to resample. | 1 |
Source code in src/anypinn/lightning/callbacks.py
__init__(every_n_epochs: int = 1) -> None
on_fit_start(trainer: Trainer, pl_module: LightningModule) -> None
Validate at fit start that the DataModule uses AdaptiveSampler.
Source code in src/anypinn/lightning/callbacks.py
on_train_epoch_end(trainer: Trainer, pl_module: LightningModule) -> None
Resample collocation points using the current model weights.
Source code in src/anypinn/lightning/callbacks.py
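The refresh cadence implied by every_n_epochs can be sketched in plain Python. The gating condition below is an assumption about how the epoch check might work; the real callback delegates the actual resampling to the DataModule's AdaptiveSampler, which is not reproduced here:

```python
def should_resample(epoch: int, every_n_epochs: int = 1) -> bool:
    """Return True on epochs where the collocation pool would be refreshed.

    Hypothetical gating condition, not taken from the library source.
    """
    return (epoch + 1) % every_n_epochs == 0

# With every_n_epochs=3, refreshes land after epochs 2, 5, 8, ...
refresh_epochs = [e for e in range(10) if should_resample(e, every_n_epochs=3)]
```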
FormattedProgressBar
Bases: TQDMProgressBar
Custom progress bar for training that formats metrics for better readability.
This class extends the TQDMProgressBar to provide custom formatting for training metrics, particularly for the total loss and beta values.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| format | FormatFn | Function to format the metric values. | required |
Source code in src/anypinn/lightning/callbacks.py
format = format
instance-attribute
__init__(*args: Any, format: FormatFn, **kwargs: Any)
get_metrics(*args: Any, **kwargs: Any) -> dict[str, Any]
Format metrics for display in the progress bar.
Returns:

| Type | Description |
|---|---|
| dict[str, Any] | Dictionary of formatted metrics. |
Source code in src/anypinn/lightning/callbacks.py
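A plausible FormatFn might map a metric value to a compact display string. The name sci_format and the exact signature below are hypothetical illustrations, not part of the library:

```python
def sci_format(value: float) -> str:
    """Hypothetical FormatFn: render metric values in scientific notation."""
    return f"{value:.2e}"

# Apply the formatter to raw metric values, as get_metrics might before display.
metrics = {"loss": 0.000123456, "beta": 1.5}
formatted = {k: sci_format(v) for k, v in metrics.items()}
# formatted["loss"] == "1.23e-04"
```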
PINNModule
Bases: LightningModule
Generic PINN Lightning module. Expects external Problem + Sampler + optimizer config.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| problem | Problem | The PINN problem definition (constraints, fields, etc.). | required |
| hp | PINNHyperparameters | Hyperparameters for training. | required |
Source code in src/anypinn/lightning/module.py
hp = hp
instance-attribute
problem = problem
instance-attribute
__init__(problem: Problem, hp: PINNHyperparameters)
Source code in src/anypinn/lightning/module.py
configure_optimizers() -> OptimizerLRScheduler
Configures the optimizer and learning rate scheduler.
Source code in src/anypinn/lightning/module.py
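Lightning modules commonly return the optimizer together with an lr_scheduler config dict from configure_optimizers. A shape-only sketch, with placeholder strings standing in for real torch optimizer/scheduler objects (whether PINNModule uses exactly this shape is an assumption):

```python
def configure_optimizers_sketch(optimizer, scheduler):
    """Common Lightning return shape: optimizer plus an lr_scheduler config.

    The "interval" key tells Lightning whether to step the scheduler per
    epoch or per step; "epoch" is the Lightning default.
    """
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"},
    }

cfg = configure_optimizers_sketch("adam-placeholder", "cosine-placeholder")
```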
on_fit_start() -> None
Called when fit begins. Resolves validation sources using loaded data.
on_predict_start() -> None
Called when predict begins. Resolves validation sources using loaded data.
predict_step(batch: PredictionBatch, batch_idx: int) -> Predictions
Performs a prediction step.
Source code in src/anypinn/lightning/module.py
training_step(batch: TrainingBatch, batch_idx: int) -> Tensor
Performs a single training step. Calculates total loss from the problem.
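A minimal sketch of the kind of aggregation training_step performs, assuming the total loss is a weighted sum of per-constraint losses. The constraint names and weighting scheme here are hypothetical:

```python
def total_loss(constraint_losses: dict[str, float],
               weights: dict[str, float]) -> float:
    """Weighted sum of per-constraint losses; unlisted constraints weigh 1.0."""
    return sum(weights.get(name, 1.0) * loss
               for name, loss in constraint_losses.items())

losses = {"pde_residual": 0.04, "boundary": 0.01, "data": 0.02}
weights = {"pde_residual": 1.0, "boundary": 10.0}
total = total_loss(losses, weights)
# total = 1.0*0.04 + 10.0*0.01 + 1.0*0.02 ≈ 0.16
```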
PredictionsWriter
Bases: BasePredictionWriter
Callback to write predictions to disk at the end of an epoch.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| predictions_path | Path \| None | Path to save the predictions tensor/object. | None |
| batch_indices_path | Path \| None | Path to save the batch indices. | None |
| on_prediction | HookFn \| None | Optional hook function called when predictions are ready. | None |
| write_interval | Literal['batch', 'epoch', 'batch_and_epoch'] | Interval at which to write predictions. | 'epoch' |
Source code in src/anypinn/lightning/callbacks.py
batch_indices_path = batch_indices_path
instance-attribute
on_prediction = on_prediction
instance-attribute
predictions_path = predictions_path
instance-attribute
__init__(predictions_path: Path | None = None, batch_indices_path: Path | None = None, on_prediction: HookFn | None = None, write_interval: Literal['batch', 'epoch', 'batch_and_epoch'] = 'epoch')
Source code in src/anypinn/lightning/callbacks.py
write_on_epoch_end(trainer: Trainer, pl_module: LightningModule, predictions: Sequence[Predictions], batch_indices: Sequence[Any]) -> None
Writes predictions to disk or calls the hook at the end of the epoch.
Source code in src/anypinn/lightning/callbacks.py
SMMAStopping
Bases: Callback
Early stopping callback based on the Smoothed Moving Average (SMMA) of the loss. Stops training if the relative improvement of the SMMA drops below a threshold.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | SMMAStoppingConfig | Configuration for SMMA stopping (window, threshold, lookback). | required |
| loss_key | str | The metric key to monitor (e.g. 'loss'). | required |
| log_key | str | Key under which to log the computed SMMA value. | SMMA_KEY |
Source code in src/anypinn/lightning/callbacks.py
config = config
instance-attribute
log_key = log_key
instance-attribute
loss_buffer: list[float] = []
instance-attribute
loss_key = loss_key
instance-attribute
smma_buffer: deque[float] = deque(maxlen=self.config.lookback)
instance-attribute
__init__(config: SMMAStoppingConfig, loss_key: str, log_key: str = SMMA_KEY)
Source code in src/anypinn/lightning/callbacks.py
on_train_epoch_end(trainer: Trainer, pl_module: LightningModule) -> None
Called when the train epoch ends. Updates SMMA and checks stopping condition.