anypinn.lightning.callbacks
FormatFn: TypeAlias = Callable[[str, Metric], Metric]
module-attribute
A function that formats a metric for display in the progress bar. Takes the key and value of the metric, and returns the formatted metric.
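For illustration, a minimal `FormatFn` might render floats in scientific notation and pass other metric types through unchanged (the formatting choices here are arbitrary, not the library's defaults):

```python
def compact_format(key: str, value) -> object:
    """Format floats in scientific notation; leave other Metric values unchanged."""
    if isinstance(value, float):
        return f"{value:.3e}"
    if isinstance(value, dict):
        # Metric may be a dict of named sub-losses; format each entry.
        return {k: f"{v:.3e}" for k, v in value.items()}
    return value
```

Such a function would be passed as the `format` argument of `FormattedProgressBar`.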
HookFn: TypeAlias = Callable[[Trainer, LightningModule, Sequence[Predictions], Sequence[Any]], None]
module-attribute
A hook called with the trainer, the Lightning module, the collected predictions, and their batch indices (used, e.g., as PredictionsWriter's on_prediction).
Metric: TypeAlias = int | str | float | dict[str, float]
module-attribute
SMMA_KEY = 'loss/smma'
module-attribute
AdaptiveCollocationCallback
Bases: Callback
Refreshes the collocation pool every N epochs using AdaptiveSampler.
Requires the DataModule to be configured with collocation_sampler="adaptive"
and a ResidualScorer passed to PINNDataModule(residual_scorer=...).
Because the scorer typically closes over the Problem, it automatically uses
the model's current weights on every call — no additional injection step needed.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `every_n_epochs` | `int` | How often (in epochs) to resample. | `1` |
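The epoch gating can be sketched independently of Lightning. This toy predicate (illustrative only, not part of the library) shows when a callback configured with a given `every_n_epochs` would trigger a resample at the end of an epoch:

```python
def should_resample(epoch: int, every_n_epochs: int) -> bool:
    """Return True at the end of every N-th epoch (epochs are 0-indexed in Lightning)."""
    return (epoch + 1) % every_n_epochs == 0
```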
Source code in src/anypinn/lightning/callbacks.py
__init__(every_n_epochs: int = 1) -> None
on_fit_start(trainer: Trainer, pl_module: LightningModule) -> None
Validate at fit start that the DataModule uses AdaptiveSampler.
on_train_epoch_end(trainer: Trainer, pl_module: LightningModule) -> None
Resample collocation points using the current model weights.
DataScaling
Bases: DataCallback
Callback to transform the data and collocation points.
Scales x to [0, 1] and applies per-series scaling factors to y.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `y_scale` | `float \| Sequence[float]` | Scaling factor(s) for y data. A single float is applied to all series; a sequence of floats provides one factor per series (length must match the number of series). | required |
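The transform can be sketched with a torch-free stand-in (the real callback operates on tensors, and whether the factor multiplies or divides y is an assumption here):

```python
from collections.abc import Sequence

def scale_data(x, y, y_scale):
    """Scale x to [0, 1] and multiply each y series by its scale factor.

    x: list of floats; y: list of series (each a list of floats);
    y_scale: a single float or one float per series.
    """
    lo, hi = min(x), max(x)
    x_scaled = [(v - lo) / (hi - lo) for v in x]
    if not isinstance(y_scale, Sequence):
        y_scale = [y_scale] * len(y)
    if len(y_scale) != len(y):
        raise ValueError("y_scale length must match the number of series")
    y_scaled = [[v * s for v in series] for series, s in zip(y, y_scale)]
    return x_scaled, y_scaled
```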
__init__(y_scale: float | Sequence[float])
on_after_setup(dm: PINNDataModule) -> None
Called after setup is complete.
transform_data(data: DataBatch, coll: Tensor) -> tuple[DataBatch, Tensor]
FormattedProgressBar
Bases: TQDMProgressBar
Custom progress bar for training that formats metrics for better readability.
This class extends the TQDMProgressBar to provide custom formatting for training metrics, particularly for the total loss and beta values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `format` | `FormatFn` | Function to format the metric values. | required |
format = format
instance-attribute
__init__(*args: Any, format: FormatFn, **kwargs: Any)
get_metrics(*args: Any, **kwargs: Any) -> dict[str, Any]
Format metrics for display in the progress bar.
Returns:
| Type | Description |
|---|---|
| `dict[str, Any]` | Dictionary of formatted metrics. |
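In spirit, get_metrics applies the user-supplied FormatFn to every base metric before display; a simplified stand-in (not the actual implementation, which also handles TQDM's own entries):

```python
def format_all(metrics: dict, fmt) -> dict:
    """Apply a FormatFn (key, value) -> formatted value to every metric entry."""
    return {key: fmt(key, value) for key, value in metrics.items()}
```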
PredictionsWriter
Bases: BasePredictionWriter
Callback to write predictions to disk at the end of an epoch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `predictions_path` | `Path \| None` | Path to save the predictions tensor/object. | `None` |
| `batch_indices_path` | `Path \| None` | Path to save the batch indices. | `None` |
| `on_prediction` | `HookFn \| None` | Optional hook function called when predictions are ready. | `None` |
| `write_interval` | `Literal['batch', 'epoch', 'batch_and_epoch']` | Interval to write predictions. | `'epoch'` |
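A minimal on_prediction hook matching the HookFn signature might simply collect the predictions; trainer and pl_module are accepted but unused in this sketch:

```python
collected = []

def collect_predictions(trainer, pl_module, predictions, batch_indices):
    """HookFn: (Trainer, LightningModule, Sequence[Predictions], Sequence[Any]) -> None."""
    # Accumulate predictions for later inspection instead of writing them to disk.
    collected.extend(predictions)
```

It would then be passed as `PredictionsWriter(on_prediction=collect_predictions)`.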
batch_indices_path = batch_indices_path
instance-attribute
on_prediction = on_prediction
instance-attribute
predictions_path = predictions_path
instance-attribute
__init__(predictions_path: Path | None = None, batch_indices_path: Path | None = None, on_prediction: HookFn | None = None, write_interval: Literal['batch', 'epoch', 'batch_and_epoch'] = 'epoch')
write_on_epoch_end(trainer: Trainer, pl_module: LightningModule, predictions: Sequence[Predictions], batch_indices: Sequence[Any]) -> None
Writes predictions to disk or calls the hook at the end of the epoch.
SMMAStopping
Bases: Callback
Early stopping callback based on the Smoothed Moving Average (SMMA) of the loss. Stops training if the relative improvement of the SMMA drops below a threshold.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `SMMAStoppingConfig` | Configuration for SMMA stopping (window, threshold, lookback). | required |
| `loss_key` | `str` | The metric key to monitor (e.g., `'loss'`). | required |
| `log_key` | `str` | Key to log the computed SMMA value. | `SMMA_KEY` |
config = config
instance-attribute
log_key = log_key
instance-attribute
loss_buffer: list[float] = []
instance-attribute
loss_key = loss_key
instance-attribute
smma_buffer: deque[float] = deque(maxlen=self.config.lookback)
instance-attribute
__init__(config: SMMAStoppingConfig, loss_key: str, log_key: str = SMMA_KEY)
on_train_epoch_end(trainer: Trainer, pl_module: LightningModule) -> None
Called when the train epoch ends. Updates SMMA and checks stopping condition.
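The smoothing and stopping rule can be sketched as follows; the recurrence is the standard SMMA formula, while the exact semantics of the config's window, threshold, and lookback fields are an assumption:

```python
from collections import deque

def smma_update(prev, loss, window):
    """Standard smoothed moving average: smma_t = (smma_{t-1} * (n - 1) + loss_t) / n."""
    if prev is None:
        return loss  # seed with the first observed loss
    return (prev * (window - 1) + loss) / window

def should_stop(smma_history: deque, threshold: float) -> bool:
    """Stop when the relative SMMA improvement over the lookback falls below threshold."""
    if len(smma_history) < smma_history.maxlen:
        return False  # not enough history yet
    old, new = smma_history[0], smma_history[-1]
    return (old - new) / old < threshold
```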