
Problems API

anypinn.problems provides ready-made constraint types for ODE and PDE problems, plus convenience classes that wire them together.


ODE Constraints

ODECallable

Bases: Protocol

Protocol for ODE right-hand side callables.

First-order callables (ODEProperties.order == 1) receive three positional arguments and must match this Protocol exactly:

def my_ode(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor: ...

Higher-order callables (order >= 2) receive a fourth positional argument derivs: list[Tensor], where derivs[k] is the (k+1)-th derivative of all fields stacked as (n_fields, m, 1):

def my_ode(x: Tensor, y: Tensor, args: ArgsRegistry,
           derivs: list[Tensor] = []) -> Tensor: ...

The Protocol is intentionally kept to three arguments so that existing first-order callables remain valid ODECallable implementations. ResidualsConstraint uses _ODECallableN internally to call higher-order functions with the correct signature.
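As an illustration, a higher-order callable for a damped oscillator might look like the sketch below. This is hypothetical example code, not library source: the args entries are assumed to be callables evaluated at x (as in the SIR example under ODEProperties), and the tensor shapes follow the (n_fields, m, 1) convention stated above.

```python
import torch
from torch import Tensor

def oscillator_ode(x: Tensor, y: Tensor, args, derivs: list[Tensor]) -> Tensor:
    """Second-order ODE y'' = -omega^2 * y - c * y'.

    derivs[0] holds the first derivatives of all fields, stacked as
    (n_fields, m, 1) just like y.
    """
    omega, c = args["omega"](x), args["c"](x)
    return -omega**2 * y - c * derivs[0]
```

With order=2, ResidualsConstraint would supply derivs=[dy] where dy matches y's shape.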

__call__(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor

ODEProperties

Properties defining an Ordinary Differential Equation problem.

Attributes:

ode (ODECallable): The ODE function (callable).
args (ArgsRegistry): Arguments/parameters for the ODE.
y0 (Tensor): Initial conditions.
order (int): Order of the ODE (default 1). For order=n, the ODE callable receives derivs as its last argument: derivs[k] is the (k+1)-th derivative.
dy0 (list[Tensor]): Initial conditions for lower-order derivatives, length order-1. dy0[k] is the IC for the (k+1)-th derivative, shape (n_fields,).
expected_args (frozenset[str] | None): Optional set of arg keys the ODE function accesses. When provided, validated against the merged args+params at construction time.

Example
def sir_ode(t, y, args):
    S, I, R = y
    beta, gamma = args["beta"](t), args["gamma"](t)
    N = S + I + R
    dS = -beta * S * I / N
    dI = beta * S * I / N - gamma * I
    dR = gamma * I
    return torch.stack([dS, dI, dR])

props = ODEProperties(
    ode=sir_ode,
    args={"beta": Argument(0.3), "gamma": Argument(0.1)},
    y0=torch.tensor([0.99, 0.01, 0.0]),
)

args: ArgsRegistry instance-attribute

dy0: list[Tensor] = dc_field(default_factory=list) class-attribute instance-attribute

expected_args: frozenset[str] | None = None class-attribute instance-attribute

ode: ODECallable instance-attribute

order: int = 1 class-attribute instance-attribute

y0: Tensor instance-attribute

__init__(ode: ODECallable, args: ArgsRegistry, y0: Tensor, order: int = 1, dy0: list[Tensor] = list(), expected_args: frozenset[str] | None = None) -> None

__post_init__() -> None

ResidualsConstraint

Bases: Constraint

Constraint enforcing the ODE residuals. Minimizes ||dy/dt - f(t, y)||^2.

Parameters:

props (ODEProperties): ODE properties. Required.
fields (FieldsRegistry): Fields registry. Required.
params (ParamsRegistry): Parameters registry. Required.
weight (float): Weight for this loss term. Default: 1.0.

args = props.args.copy() instance-attribute

fields = fields instance-attribute

ode = props.ode instance-attribute

order = props.order instance-attribute

weight = weight instance-attribute

__init__(props: ODEProperties, fields: FieldsRegistry, params: ParamsRegistry, weight: float = 1.0)

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Compute the ODE residual loss over collocation points.

ICConstraint

Bases: Constraint

Constraint enforcing Initial Conditions (IC). Minimizes ||y(t0) - Y0||^2.

Parameters:

props (ODEProperties): ODE properties (supplies y0, dy0, and order). Required.
fields (FieldsRegistry): Fields registry. Required.
weight (float): Weight for this loss term. Default: 1.0.

Y0 = props.y0.clone().reshape(-1, 1, 1) instance-attribute

dY0 = [dy.clone().reshape(-1, 1, 1) for dy in props.dy0] instance-attribute

fields = fields instance-attribute

order = props.order instance-attribute

weight = weight instance-attribute

__init__(props: ODEProperties, fields: FieldsRegistry, weight: float = 1.0)

inject_context(context: InferredContext) -> None

Inject the context into the constraint.

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Compute the initial-condition loss.

DataConstraint

Bases: Constraint

Constraint enforcing fit to observed data. Minimizes ||y_hat - y||^2.

Parameters:

fields (FieldsRegistry): Fields registry. Required.
params (ParamsRegistry): Parameters registry. Required.
predict_data (PredictDataFn): Function to predict data values from fields. Required.
weight (float): Weight for this loss term. Default: 1.0.

fields = fields instance-attribute

params = params instance-attribute

predict_data = predict_data instance-attribute

weight = weight instance-attribute

__init__(fields: FieldsRegistry, params: ParamsRegistry, predict_data: PredictDataFn, weight: float = 1.0)

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Compute the data-fitting loss.

ODEInverseProblem

Bases: Problem

Convenience class composing Residuals + IC + Data constraints.

Wires together ResidualsConstraint, ICConstraint, and DataConstraint with the loss criterion from hyperparameters. For forward-only problems (no data fitting), compose the constraints manually instead.

Example
problem = ODEInverseProblem(
    props=props,
    hp=ODEHyperparameters(...),
    fields={"S": field_s, "I": field_i, "R": field_r},
    params={"beta": Parameter(ScalarConfig(init_value=0.3))},
    predict_data=lambda t, f, p: torch.stack(
        [f["S"](t), f["I"](t), f["R"](t)], dim=1
    ).squeeze(-1),
)

__init__(props: ODEProperties, hp: ODEHyperparameters, fields: FieldsRegistry, params: ParamsRegistry, predict_data: PredictDataFn) -> None

ODEHyperparameters

Bases: PINNHyperparameters

Hyperparameters for ODE inverse problems.

Extends PINNHyperparameters with per-constraint loss weights used by ODEInverseProblem.

Attributes:

pde_weight (float): Weight for the ODE residual loss term.
ic_weight (float): Weight for the initial-condition loss term.
data_weight (float): Weight for the data-fitting loss term.

data_weight: float = 1.0 class-attribute instance-attribute

ic_weight: float = 1.0 class-attribute instance-attribute

pde_weight: float = 1.0 class-attribute instance-attribute

__init__(*, lr: float, training_data: IngestionConfig | GenerationConfig, fields_config: MLPConfig, params_config: MLPConfig | ScalarConfig, max_epochs: int | None = None, gradient_clip_val: float | None = None, criterion: Criteria = 'mse', optimizer: AdamConfig | LBFGSConfig | None = None, scheduler: ReduceLROnPlateauConfig | CosineAnnealingConfig | None = None, early_stopping: EarlyStoppingConfig | None = None, smma_stopping: SMMAStoppingConfig | None = None, pde_weight: float = 1.0, ic_weight: float = 1.0, data_weight: float = 1.0) -> None


PDE Constraints

BoundaryCondition

Pairs a boundary region sampler with a prescribed value function.

Parameters:

sampler (Callable[[int], Tensor]): Callable (n_pts: int) -> Tensor of shape (n_pts, d). Called each training step to produce fresh boundary sample points. Required.
value (BCValueFn): Callable Tensor -> Tensor giving the target value or normal derivative at boundary coordinates. Required.
n_pts (int): Number of boundary points sampled per training step. Default: 100.
Example
# Left boundary of a 1-D+time domain at x=0
bc_left = BoundaryCondition(
    sampler=lambda n: torch.stack([
        torch.zeros(n),            # x = 0
        torch.rand(n) * T_max,     # t in [0, T]
    ], dim=-1),
    value=lambda coords: torch.zeros(coords.shape[0], 1),
    n_pts=50,
)

n_pts = n_pts instance-attribute

sampler = sampler instance-attribute

value = value instance-attribute

__init__(sampler: Callable[[int], Tensor], value: BCValueFn, n_pts: int = 100)

DirichletBCConstraint

Bases: Constraint

Enforces the Dirichlet boundary condition: u(x_bc) = g(x_bc). Minimizes weight * criterion(u(x_bc), g(x_bc)).

Parameters:

bc (BoundaryCondition): Boundary condition (sampler + target value function). Required.
field (Field): The neural field to enforce the condition on. Required.
log_key (str): Key used when logging the loss value. Default: 'loss/bc_dirichlet'.
weight (float): Loss term weight. Default: 1.0.

bc = bc instance-attribute

field = field instance-attribute

log_key = log_key instance-attribute

weight = weight instance-attribute

__init__(bc: BoundaryCondition, field: Field, log_key: str = 'loss/bc_dirichlet', weight: float = 1.0)

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Compute the Dirichlet boundary condition loss.
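For intuition, the quantity this constraint minimizes can be reproduced with plain tensors. This is an illustrative sketch only: field stands in for a trained neural field, and sampler/value mirror a BoundaryCondition fixing u = 0 on the x = 0 face.

```python
import torch
import torch.nn as nn

field = lambda coords: 2.0 * coords[:, :1]              # toy field u(x, t) = 2x
sampler = lambda n: torch.stack([torch.zeros(n),        # x = 0 face
                                 torch.rand(n)], dim=-1)
value = lambda coords: torch.zeros(coords.shape[0], 1)  # target g = 0

coords = sampler(50)
loss = nn.MSELoss()(field(coords), value(coords))       # zero: u(0, t) = 0 = g
```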

NeumannBCConstraint

Bases: Constraint

Enforces the Neumann boundary condition: du/dn(x_bc) = h(x_bc).

For a rectangular domain face whose outward normal is axis-aligned with dimension normal_dim, we have du/dn = du/dx[normal_dim]. Minimizes weight * criterion(du_dn(x_bc), h(x_bc)).

Parameters:

bc (BoundaryCondition): Boundary condition (sampler + target normal-derivative function). Required.
field (Field): The neural field to enforce the condition on. Required.
normal_dim (int): Index of the spatial dimension the boundary normal points along. Required.
log_key (str): Key used when logging the loss value. Default: 'loss/bc_neumann'.
weight (float): Loss term weight. Default: 1.0.

bc = bc instance-attribute

field = field instance-attribute

log_key = log_key instance-attribute

normal_dim = normal_dim instance-attribute

weight = weight instance-attribute

__init__(bc: BoundaryCondition, field: Field, normal_dim: int, log_key: str = 'loss/bc_neumann', weight: float = 1.0)

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Compute the Neumann boundary condition loss.
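The normal derivative du/dn that this constraint compares against h(x_bc) can be sketched with autograd. This is not the library's internal code, just a minimal reproduction of the axis-aligned case described above, where normal_dim indexes a column of the coordinate tensor.

```python
import torch

def du_dn(u_fn, coords: torch.Tensor, normal_dim: int) -> torch.Tensor:
    """Derivative of u along the axis-aligned outward normal."""
    coords = coords.clone().requires_grad_(True)
    u = u_fn(coords)
    (grads,) = torch.autograd.grad(u.sum(), coords, create_graph=True)
    return grads[:, normal_dim : normal_dim + 1]

u_fn = lambda c: 3.0 * c[:, :1] + c[:, 1:2] ** 2        # u(x, t) = 3x + t^2
dn = du_dn(u_fn, torch.rand(10, 2), normal_dim=0)       # du/dx = 3 everywhere
```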

PeriodicBCConstraint

Bases: Constraint

Enforces periodic boundary conditions: u(x_left, t) = u(x_right, t) and matching spatial derivatives.

The two boundary samplers must produce paired points — identical coordinates in every dimension except the periodic one — so that the value- and derivative-matching losses are meaningful.

Minimizes weight * [criterion(u_left, u_right) + criterion(du_left, du_right)].

Parameters:

bc_left (BoundaryCondition): Left boundary sampler (sampler + dummy value function). Required.
bc_right (BoundaryCondition): Right boundary sampler (sampler + dummy value function). Required.
field (Field): The neural field to enforce the condition on. Required.
match_dim (int): Spatial dimension index for the derivative matching. Default: 0.
log_key (str): Key used when logging the loss value. Default: 'loss/bc_periodic'.
weight (float): Loss term weight. Default: 1.0.

bc_left = bc_left instance-attribute

bc_right = bc_right instance-attribute

field = field instance-attribute

log_key = log_key instance-attribute

match_dim = match_dim instance-attribute

weight = weight instance-attribute

__init__(bc_left: BoundaryCondition, bc_right: BoundaryCondition, field: Field, match_dim: int = 0, log_key: str = 'loss/bc_periodic', weight: float = 1.0)

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Compute the periodic boundary condition loss.
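One way to produce the required paired points is to share the time draw between the two samplers, so the boundary points differ only in the periodic dimension. The PairedSampler class below is hypothetical glue code; it assumes the left sampler is invoked before the right one each step, which is an assumption about evaluation order, not a documented guarantee.

```python
import torch

class PairedSampler:
    """Paired boundary points: same t draw, x pinned to the left/right face."""

    def __init__(self, x_left: float, x_right: float, t_max: float):
        self.x_left, self.x_right, self.t_max = x_left, x_right, t_max
        self._t = None

    def left(self, n: int) -> torch.Tensor:
        self._t = torch.rand(n) * self.t_max            # fresh shared time draw
        return torch.stack([torch.full((n,), self.x_left), self._t], dim=-1)

    def right(self, n: int) -> torch.Tensor:
        return torch.stack([torch.full((n,), self.x_right), self._t], dim=-1)

pair = PairedSampler(x_left=0.0, x_right=1.0, t_max=2.0)
```

The two callables would then be wrapped as bc_left=BoundaryCondition(sampler=pair.left, ...) and bc_right=BoundaryCondition(sampler=pair.right, ...) with dummy value functions.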

PDEResidualConstraint

Bases: Constraint

Enforces a PDE interior residual: residual_fn(x, fields, params) ≈ 0. Minimizes weight * criterion(residual_fn(x_coll, fields, params), 0).

Parameters:

fields (FieldsRegistry): Registry of neural fields the residual function operates on. Pass only the subset needed; other fields in the Problem are ignored. Required.
params (ParamsRegistry): Registry of parameters the residual function uses. Required.
residual_fn (PDEResidualFn): Callable (x, fields, params) -> Tensor of residuals. Should use anypinn.lib.diff operators for derivatives. The returned tensor is compared against zeros. Required.
log_key (str): Key used when logging the loss value. Default: 'loss/pde_residual'.
weight (float): Loss term weight. Default: 1.0.
Example
from anypinn.lib.diff import partial

def heat_residual(x, fields, params):
    u = fields["u"]
    u_pred = u(x)
    u_t = partial(u_pred, x, dim=1)   # du/dt
    u_x = partial(u_pred, x, dim=0)   # du/dx
    u_xx = partial(u_x, x, dim=0)     # d2u/dx2
    alpha = params["alpha"](x)
    return u_t - alpha * u_xx          # residual = 0

constraint = PDEResidualConstraint(
    fields=fields, params=params,
    residual_fn=heat_residual,
)

fields = fields instance-attribute

log_key = log_key instance-attribute

params = params instance-attribute

residual_fn = residual_fn instance-attribute

weight = weight instance-attribute

__init__(fields: FieldsRegistry, params: ParamsRegistry, residual_fn: PDEResidualFn, log_key: str = 'loss/pde_residual', weight: float = 1.0)

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Compute the PDE interior residual loss.


Type Aliases

PredictDataFn = Callable[[Tensor, FieldsRegistry, ParamsRegistry], Tensor]
    Maps fields/params to observed quantities.

BCValueFn = Callable[[Tensor], Tensor]
    Boundary value function.

PDEResidualFn = Callable[[Tensor, FieldsRegistry, ParamsRegistry], Tensor]
    PDE residual function.
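For concreteness, a PredictDataFn might look like the following sketch. The registries are treated as plain mappings of name to callable, mirroring the usage in the ODEInverseProblem example; the "I" field and "rate" parameter names are hypothetical.

```python
import torch
from torch import Tensor

def predict_data(t: Tensor, fields, params) -> Tensor:
    """Observed quantity: the I field scaled by a reporting-rate parameter."""
    return params["rate"](t) * fields["I"](t)

# Stand-ins for real registries, for illustration only.
fields = {"I": lambda t: 2.0 * t}
params = {"rate": lambda t: torch.full_like(t, 0.5)}
out = predict_data(torch.ones(4, 1), fields, params)   # 0.5 * 2.0 = 1.0 each
```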