Problems API
anypinn.problems provides ready-made constraint types for ODE and PDE
problems, plus convenience classes that wire them together.
ODE Constraints
ODECallable
Bases: Protocol
Protocol for ODE right-hand side callables.
First-order callables (ODEProperties.order == 1) receive three
positional arguments and must match this Protocol exactly:
def my_ode(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor: ...
Higher-order callables (order >= 2) receive a fourth positional
argument, derivs: list[Tensor], where derivs[k] is the
(k+1)-th derivative of all fields, stacked with shape (n_fields, m, 1):
def my_ode(x: Tensor, y: Tensor, args: ArgsRegistry,
           derivs: list[Tensor] = []) -> Tensor: ...
The Protocol is intentionally kept to three arguments so that existing
first-order callables remain valid ODECallable implementations.
ResidualsConstraint uses _ODECallableN internally to call
higher-order functions with the correct signature.
__call__(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor
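To make the calling convention concrete, the sketch below shows a first-order and a second-order right-hand side side by side, plus a dispatcher that passes derivs only when order >= 2. It is a minimal illustration under stated assumptions, not anypinn code: call_rhs is a hypothetical stand-in for the role the text above gives _ODECallableN, and plain lists of floats replace Tensor and ArgsRegistry so the example runs without torch.

```python
from typing import Callable, Protocol, Sequence

# Stand-ins (assumptions): a list of floats plays the role of a Tensor and a
# dict of callables plays the role of ArgsRegistry.
Vec = list[float]
Args = dict[str, Callable[[Vec], float]]


class ODECallable(Protocol):
    """Three-argument signature that every right-hand side must accept."""

    def __call__(self, x: Vec, y: Vec, args: Args) -> Vec: ...


def call_rhs(ode: Callable[..., Vec], order: int, x: Vec, y: Vec,
             args: Args, derivs: Sequence[Vec] = ()) -> Vec:
    """Hypothetical dispatcher: derivs is passed only when order >= 2."""
    if order == 1:
        return ode(x, y, args)
    return ode(x, y, args, list(derivs))


def decay(x: Vec, y: Vec, args: Args) -> Vec:
    """First-order RHS: dy/dx = -k * y."""
    k = args["k"](x)
    return [-k * yi for yi in y]


def oscillator(x: Vec, y: Vec, args: Args, derivs: Sequence[Vec] = ()) -> Vec:
    """Second-order RHS: d2y/dx2 = -omega**2 * y - c * dy/dx (derivs[0] is dy/dx)."""
    omega, c = args["omega"](x), args["c"](x)
    dy = derivs[0] if derivs else [0.0] * len(y)
    return [-(omega ** 2) * yi - c * dyi for yi, dyi in zip(y, dy)]


# A plain 3-argument function satisfies the Protocol structurally.
rhs: ODECallable = decay
```

Because oscillator gives derivs a default, it is still callable with three arguments, which is the same reason the Protocol can stay at three parameters.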
ODEProperties
Properties defining an Ordinary Differential Equation problem.
Attributes:

| Name | Type | Description |
|---|---|---|
| ode | ODECallable | The ODE right-hand-side function. |
| args | ArgsRegistry | Arguments/parameters for the ODE. |
| y0 | Tensor | Initial conditions. |
| order | int | Order of the ODE (default 1). For order=n, the ODE callable receives derivs as its last argument, where derivs[k] is the (k+1)-th derivative. |
| dy0 | list[Tensor] | Initial conditions for the lower-order derivatives (length order-1). dy0[k] is the IC for the (k+1)-th derivative, with shape (n_fields,). |
| expected_args | frozenset[str] \| None | Optional set of arg keys the ODE function accesses. When provided, it is validated against the merged args and params at construction time. |
Example
import torch

# ODEProperties and Argument come from anypinn (import paths not shown here).

def sir_ode(t, y, args):
    S, I, R = y  # y is stacked per field: (n_fields, m, 1)
    beta, gamma = args["beta"](t), args["gamma"](t)
    N = S + I + R
    dS = -beta * S * I / N
    dI = beta * S * I / N - gamma * I
    dR = gamma * I
    return torch.stack([dS, dI, dR])

props = ODEProperties(
    ode=sir_ode,
    args={"beta": Argument(0.3), "gamma": Argument(0.1)},
    y0=torch.tensor([0.99, 0.01, 0.0]),
)
args: ArgsRegistry
instance-attribute
dy0: list[Tensor] = dc_field(default_factory=list)
class-attribute
instance-attribute
expected_args: frozenset[str] | None = None
class-attribute
instance-attribute
ode: ODECallable
instance-attribute
order: int = 1
class-attribute
instance-attribute
y0: Tensor
instance-attribute
__init__(ode: ODECallable, args: ArgsRegistry, y0: Tensor, order: int = 1, dy0: list[Tensor] = list(), expected_args: frozenset[str] | None = None) -> None
__post_init__() -> None
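The two consistency rules stated in the attribute table (dy0 has length order-1, and expected_args must be covered by the merged args and params) can be illustrated with a small dataclass. This is a hypothetical sketch: ODEPropsSketch and validate are illustrative names, plain dicts stand in for the registries, and the page does not say how or where anypinn performs these checks.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ODEPropsSketch:
    """Hypothetical stand-in for ODEProperties; only validated fields are kept."""

    args: dict[str, object]
    order: int = 1
    dy0: list[list[float]] = field(default_factory=list)
    expected_args: Optional[frozenset[str]] = None

    def validate(self, params: dict[str, object]) -> None:
        # One initial condition per derivative of order 1 .. order-1.
        if len(self.dy0) != self.order - 1:
            raise ValueError(
                f"order={self.order} needs {self.order - 1} dy0 entries, "
                f"got {len(self.dy0)}"
            )
        if self.expected_args is not None:
            available = set(self.args) | set(params)  # merged args + params
            missing = self.expected_args - available
            if missing:
                raise KeyError(f"ODE expects args not provided: {sorted(missing)}")
```

Failing early on a missing key turns a typo in the ODE function into a construction-time error rather than a failure partway through training.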
ResidualsConstraint
Bases: Constraint
Constraint enforcing the ODE residuals.
Minimizes ||dy/dt - f(t, y)||^2.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| props | ODEProperties | ODE properties. | required |
| fields | FieldsRegistry | Fields registry. | required |
| params | ParamsRegistry | Parameters registry. | required |
| weight | float | Weight for this loss term. | 1.0 |
args = props.args.copy()
instance-attribute
fields = fields
instance-attribute
ode = props.ode
instance-attribute
order = props.order
instance-attribute
weight = weight
instance-attribute
__init__(props: ODEProperties, fields: FieldsRegistry, params: ParamsRegistry, weight: float = 1.0)
loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor
Compute the ODE residual loss over collocation points.
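The residual idea can be checked on a problem with a known solution. In this minimal sketch a central finite difference stands in for the autograd derivative a PINN would take, and a scalar function replaces the neural field; residual_loss, decay_rhs and the decay rate K are illustrative assumptions, not anypinn code. The exact solution drives the loss to roughly zero while a wrong trial solution does not.

```python
import math
from typing import Callable


def residual_loss(
    y: Callable[[float], float],
    f: Callable[[float, float], float],
    ts: list[float],
    h: float = 1e-5,
) -> float:
    """Mean of (dy/dt - f(t, y))**2 over the collocation points ts."""
    total = 0.0
    for t in ts:
        # Central difference stands in for the autograd derivative of a network.
        dydt = (y(t + h) - y(t - h)) / (2 * h)
        r = dydt - f(t, y(t))
        total += r * r
    return total / len(ts)


K = 0.7


def decay_rhs(t: float, y: float) -> float:
    return -K * y  # dy/dt = -K y


def exact(t: float) -> float:
    return math.exp(-K * t)  # solves the ODE, so its residual is ~0


def too_fast(t: float) -> float:
    return math.exp(-2 * K * t)  # wrong decay rate, nonzero residual


ts = [i / 10 for i in range(11)]
```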
ICConstraint
Bases: Constraint
Constraint enforcing Initial Conditions (IC).
Minimizes ||y(t0) - Y0||^2.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| props | ODEProperties | ODE properties; supplies y0, dy0 and order. | required |
| fields | FieldsRegistry | Fields registry. | required |
| weight | float | Weight for this loss term. | 1.0 |
Y0 = props.y0.clone().reshape(-1, 1, 1)
instance-attribute
dY0 = [dy.clone().reshape(-1, 1, 1) for dy in props.dy0]
instance-attribute
fields = fields
instance-attribute
order = props.order
instance-attribute
weight = weight
instance-attribute
__init__(props: ODEProperties, fields: FieldsRegistry, weight: float = 1.0)
inject_context(context: InferredContext) -> None
Inject the context into the constraint.
loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor
Compute the initial-condition loss.
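For an order-2 problem the IC constraint has to pin down both y(t0) and dy/dt(t0). The sketch below assumes the loss is one mean-squared term for Y0 plus one per entry of dY0, all scaled by the weight; anypinn's exact combination is not documented on this page, and ic_loss and mse are hypothetical helpers working on plain lists instead of tensors.

```python
def mse(pred: list[float], target: list[float]) -> float:
    """Mean squared error between two equal-length lists."""
    return sum((p - q) ** 2 for p, q in zip(pred, target)) / len(pred)


def ic_loss(
    y_t0: list[float],
    Y0: list[float],
    dy_t0: list[list[float]],
    dY0: list[list[float]],
    weight: float = 1.0,
) -> float:
    """Hypothetical IC loss: value mismatch plus one term per derivative IC.

    y_t0[i] is field i evaluated at t0; dy_t0[k][i] is its (k+1)-th derivative.
    """
    if len(dy_t0) != len(dY0):
        raise ValueError("need one predicted derivative per dY0 entry")
    loss = mse(y_t0, Y0)
    for pred, target in zip(dy_t0, dY0):
        loss += mse(pred, target)
    return weight * loss
```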
DataConstraint
Bases: Constraint
Constraint enforcing fit to observed data.
Minimizes ||y_hat - y||^2.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| fields | FieldsRegistry | Fields registry. | required |
| params | ParamsRegistry | Parameters registry. | required |
| predict_data | PredictDataFn | Function to predict data values from fields. | required |
| weight | float | Weight for this loss term. | 1.0 |
fields = fields
instance-attribute
params = params
instance-attribute
predict_data = predict_data
instance-attribute
weight = weight
instance-attribute
__init__(fields: FieldsRegistry, params: ParamsRegistry, predict_data: PredictDataFn, weight: float = 1.0)
loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor
Compute the data-fitting loss.
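predict_data is where the model is mapped onto whatever was actually measured. As a hypothetical example, suppose only infection counts are observed for the SIR model: the predictor scales the infected fraction by a population parameter. In this sketch, functions of t stand in for fields and parameters, and predict_cases, data_loss and the parameter name "N" are illustrative assumptions; the MSE matches the page's description of minimizing ||y_hat - y||^2.

```python
from typing import Callable

Series = Callable[[float], float]  # stand-in: a field or parameter as a function of t


def predict_cases(ts: list[float], fields: dict[str, Series],
                  params: dict[str, Series]) -> list[float]:
    """Hypothetical PredictDataFn: only infections are observed, as counts,
    so the infected fraction I(t) is scaled by a population parameter N."""
    return [params["N"](t) * fields["I"](t) for t in ts]


def data_loss(ts: list[float], observed: list[float],
              fields: dict[str, Series], params: dict[str, Series],
              predict_data, weight: float = 1.0) -> float:
    """weight * mean squared error between predictions and observations."""
    pred = predict_data(ts, fields, params)
    return weight * sum((p - o) ** 2 for p, o in zip(pred, observed)) / len(ts)
```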
ODEInverseProblem
Bases: Problem
Convenience class composing Residuals + IC + Data constraints.
Wires together ResidualsConstraint, ICConstraint, and DataConstraint,
using the loss criterion and the per-constraint weights (pde_weight,
ic_weight, data_weight) from the hyperparameters.
For forward-only problems (no data fitting), compose the constraints
manually instead.
Example
__init__(props: ODEProperties, hp: ODEHyperparameters, fields: FieldsRegistry, params: ParamsRegistry, predict_data: PredictDataFn) -> None
ODEHyperparameters
Bases: PINNHyperparameters
Hyperparameters for ODE inverse problems.
Extends PINNHyperparameters with per-constraint loss weights
used by ODEInverseProblem.
Attributes:

| Name | Type | Description |
|---|---|---|
| pde_weight | float | Weight for the ODE residual loss term. |
| ic_weight | float | Weight for the initial-condition loss term. |
| data_weight | float | Weight for the data-fitting loss term. |
data_weight: float = 1.0
class-attribute
instance-attribute
ic_weight: float = 1.0
class-attribute
instance-attribute
pde_weight: float = 1.0
class-attribute
instance-attribute
__init__(*, lr: float, training_data: IngestionConfig | GenerationConfig, fields_config: MLPConfig, params_config: MLPConfig | ScalarConfig, max_epochs: int | None = None, gradient_clip_val: float | None = None, criterion: Criteria = 'mse', optimizer: AdamConfig | LBFGSConfig | None = None, scheduler: ReduceLROnPlateauConfig | CosineAnnealingConfig | None = None, early_stopping: EarlyStoppingConfig | None = None, smma_stopping: SMMAStoppingConfig | None = None, pde_weight: float = 1.0, ic_weight: float = 1.0, data_weight: float = 1.0) -> None
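ODEInverseProblem presumably passes each of these weights to the weight argument of the matching constraint. Assuming the problem then adds the weighted terms into a single scalar objective (this page does not spell out the reduction), the composition looks like the sketch below; LossWeights and total_loss are hypothetical names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LossWeights:
    """Mirrors the three ODEHyperparameters weight fields (defaults 1.0)."""

    pde_weight: float = 1.0
    ic_weight: float = 1.0
    data_weight: float = 1.0


def total_loss(residual: float, ic: float, data: float, w: LossWeights) -> float:
    """Assumed composition: each term is scaled by its weight, then summed."""
    return w.pde_weight * residual + w.ic_weight * ic + w.data_weight * data
```

Raising one weight relative to the others shifts the optimizer's attention toward that term, which is the usual reason to tune them per problem.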
PDE Constraints
BoundaryCondition
Pairs a boundary region sampler with a prescribed value function.
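Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| sampler | Callable[[int], Tensor] | Callable that takes a point count n and returns n points on the boundary region. | required |
| value | BCValueFn | Callable mapping boundary points to the prescribed boundary values. | required |
| n_pts | int | Number of boundary points sampled per training step. | 100 |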
Example
n_pts = n_pts
instance-attribute
sampler = sampler
instance-attribute
value = value
instance-attribute
__init__(sampler: Callable[[int], Tensor], value: BCValueFn, n_pts: int = 100)
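A sampler and a value function are just two callables. The sketch below assumes a 2D domain [0, 1] x [0, 1] with coordinates (x, t) and a homogeneous condition on the face x = 0; tuples of floats stand in for tensor rows, and left_face_sampler, zero_value, BoundaryConditionSketch and its draw method are hypothetical names that only mirror the documented (sampler, value, n_pts) trio.

```python
import random
from typing import Callable

Point = tuple[float, float]  # (x, t); stand-in for one row of a Tensor


def left_face_sampler(n: int) -> list[Point]:
    """Hypothetical sampler: n random points on the face x = 0 of [0, 1] x [0, 1]."""
    return [(0.0, random.random()) for _ in range(n)]


def zero_value(points: list[Point]) -> list[float]:
    """Hypothetical BCValueFn: homogeneous condition g(x_bc) = 0."""
    return [0.0] * len(points)


class BoundaryConditionSketch:
    """Stand-in mirroring BoundaryCondition(sampler, value, n_pts=100)."""

    def __init__(self, sampler: Callable[[int], list[Point]],
                 value: Callable[[list[Point]], list[float]], n_pts: int = 100):
        self.sampler = sampler
        self.value = value
        self.n_pts = n_pts

    def draw(self) -> tuple[list[Point], list[float]]:
        """Sample n_pts fresh boundary points and their target values."""
        pts = self.sampler(self.n_pts)
        return pts, self.value(pts)
```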
DirichletBCConstraint
Bases: Constraint
Enforces the Dirichlet boundary condition: u(x_bc) = g(x_bc).
Minimizes weight * criterion(u(x_bc), g(x_bc)).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| bc | BoundaryCondition | Boundary condition (sampler + target value function). | required |
| field | Field | The neural field to enforce the condition on. | required |
| log_key | str | Key used when logging the loss value. | 'loss/bc_dirichlet' |
| weight | float | Loss term weight. | 1.0 |
bc = bc
instance-attribute
field = field
instance-attribute
log_key = log_key
instance-attribute
weight = weight
instance-attribute
__init__(bc: BoundaryCondition, field: Field, log_key: str = 'loss/bc_dirichlet', weight: float = 1.0)
loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor
Compute the Dirichlet boundary condition loss.
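With MSE as the criterion, the Dirichlet loss is just the mean squared gap between the field and g on freshly sampled boundary points, times the weight. In this minimal sketch plain functions of (x, t) stand in for the neural field, and dirichlet_loss, left_face, zeros and the two trial functions are hypothetical; one trial function satisfies u(0, t) = 0 exactly and the other does not.

```python
import math
import random
from typing import Callable

Point = tuple[float, float]  # (x, t)


def dirichlet_loss(u: Callable[[Point], float],
                   sampler: Callable[[int], list[Point]],
                   g: Callable[[list[Point]], list[float]],
                   n_pts: int = 100, weight: float = 1.0) -> float:
    """weight * MSE(u(x_bc), g(x_bc)) on freshly sampled boundary points."""
    pts = sampler(n_pts)
    target = g(pts)
    return weight * sum((u(p) - y) ** 2 for p, y in zip(pts, target)) / n_pts


def left_face(n: int) -> list[Point]:
    return [(0.0, random.random()) for _ in range(n)]  # points with x = 0


def zeros(pts: list[Point]) -> list[float]:
    return [0.0] * len(pts)  # g = 0 on the face


def satisfies_bc(p: Point) -> float:
    x, t = p
    return math.sin(math.pi * x) * math.exp(-t)  # exactly 0 at x = 0


def violates_bc(p: Point) -> float:
    x, t = p
    return x + 1.0  # equals 1 at x = 0
```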
NeumannBCConstraint
Bases: Constraint
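Enforces the Neumann boundary condition du/dn(x_bc) = h(x_bc).
For a face of a rectangular domain whose outward normal is aligned with
dimension normal_dim, du/dn = du/dx[normal_dim]. Strictly, this holds when
the outward normal points in the positive normal_dim direction; on the
opposite face the outward normal derivative is -du/dx[normal_dim].
Minimizes weight * criterion(du_dn(x_bc), h(x_bc)).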
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| bc | BoundaryCondition | Boundary condition (sampler + target normal-derivative function). | required |
| field | Field | The neural field to enforce the condition on. | required |
| normal_dim | int | Index of the spatial dimension the boundary normal points along. | required |
| log_key | str | Key used when logging the loss value. | 'loss/bc_neumann' |
| weight | float | Loss term weight. | 1.0 |
bc = bc
instance-attribute
field = field
instance-attribute
log_key = log_key
instance-attribute
normal_dim = normal_dim
instance-attribute
weight = weight
instance-attribute
__init__(bc: BoundaryCondition, field: Field, normal_dim: int, log_key: str = 'loss/bc_neumann', weight: float = 1.0)
loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor
Compute the Neumann boundary condition loss.
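The Neumann loss compares the partial derivative along normal_dim with h. In this sketch a central finite difference replaces the autograd derivative, plain functions of (x, t) replace the field, and partial_fd and neumann_loss are hypothetical helpers. The trial function cos(pi x) exp(-t) has zero slope at x = 1 (an insulated end), so a zero target gives a near-zero loss there.

```python
import math
from typing import Callable

Point = tuple[float, float]  # (x, t)


def partial_fd(u: Callable[[Point], float], p: Point, dim: int,
               h: float = 1e-5) -> float:
    """Central finite-difference estimate of du/dx[dim] at p."""
    plus, minus = list(p), list(p)
    plus[dim] += h
    minus[dim] -= h
    return (u(tuple(plus)) - u(tuple(minus))) / (2 * h)


def neumann_loss(u: Callable[[Point], float], points: list[Point],
                 h_fn: Callable[[Point], float], normal_dim: int,
                 weight: float = 1.0) -> float:
    """weight * MSE(du/dx[normal_dim], h) over boundary points."""
    errs = [(partial_fd(u, p, normal_dim) - h_fn(p)) ** 2 for p in points]
    return weight * sum(errs) / len(points)


def u(p: Point) -> float:
    x, t = p
    # du/dx = -pi * sin(pi x) * exp(-t), which vanishes at x = 1.
    return math.cos(math.pi * x) * math.exp(-t)


right_face = [(1.0, 0.0), (1.0, 0.5), (1.0, 1.0)]
```

On the face x = 0 the outward normal points in the negative x direction, so an outward flux target would be passed with its sign flipped.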
PeriodicBCConstraint
Bases: Constraint
Enforces periodic boundary conditions: u(x_left, t) = u(x_right, t),
together with matching derivatives along match_dim.
The two boundary samplers must produce paired points (identical coordinates in every dimension except the periodic one) so that the value- and derivative-matching losses are meaningful.
Minimizes weight * [criterion(u_left, u_right) + criterion(du_left, du_right)].
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| bc_left | BoundaryCondition | Left boundary sampler (sampler + dummy value function). | required |
| bc_right | BoundaryCondition | Right boundary sampler (sampler + dummy value function). | required |
| field | Field | The neural field to enforce the condition on. | required |
| match_dim | int | Spatial dimension index for the derivative matching. | 0 |
| log_key | str | Key used when logging the loss value. | 'loss/bc_periodic' |
| weight | float | Loss term weight. | 1.0 |
bc_left = bc_left
instance-attribute
bc_right = bc_right
instance-attribute
field = field
instance-attribute
log_key = log_key
instance-attribute
match_dim = match_dim
instance-attribute
weight = weight
instance-attribute
__init__(bc_left: BoundaryCondition, bc_right: BoundaryCondition, field: Field, match_dim: int = 0, log_key: str = 'loss/bc_periodic', weight: float = 1.0)
loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor
Compute the periodic boundary condition loss.
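One way to get the paired points the text requires is to give both faces their own random generator with the same seed, so that calling them in lockstep with the same n yields identical t values. This is a hypothetical sketch (make_face_sampler, periodic_loss and the two trial functions are illustrative; finite differences replace autograd) for a domain periodic in x with period 2*pi.

```python
import math
import random
from typing import Callable

Point = tuple[float, float]  # (x, t); x is the periodic dimension


def make_face_sampler(x_face: float, seed: int) -> Callable[[int], list[Point]]:
    """Hypothetical factory: samplers sharing a seed and called in lockstep
    with the same n draw identical t values."""
    rng = random.Random(seed)

    def sampler(n: int) -> list[Point]:
        return [(x_face, rng.random()) for _ in range(n)]

    return sampler


def periodic_loss(u: Callable[[Point], float], left: list[Point],
                  right: list[Point], match_dim: int = 0, h: float = 1e-5,
                  weight: float = 1.0) -> float:
    """weight * [MSE(u_left, u_right) + MSE(du_left, du_right)],
    derivatives along match_dim by central finite differences."""

    def d(p: Point) -> float:
        plus, minus = list(p), list(p)
        plus[match_dim] += h
        minus[match_dim] -= h
        return (u(tuple(plus)) - u(tuple(minus))) / (2 * h)

    n = len(left)
    value = sum((u(a) - u(b)) ** 2 for a, b in zip(left, right)) / n
    deriv = sum((d(a) - d(b)) ** 2 for a, b in zip(left, right)) / n
    return weight * (value + deriv)


def periodic_u(p: Point) -> float:
    return math.sin(p[0]) * math.exp(-p[1])  # 2*pi-periodic in x


def ramp_u(p: Point) -> float:
    return p[0] * math.exp(-p[1])  # not periodic in x


left_sampler = make_face_sampler(0.0, seed=42)
right_sampler = make_face_sampler(2 * math.pi, seed=42)
```

The lockstep trick breaks if one sampler is ever called a different number of times or with a different n, so a deterministic shared grid is a simpler alternative when that cannot be guaranteed.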
PDEResidualConstraint
Bases: Constraint
Enforces a PDE interior residual: residual_fn(x, fields, params) ≈ 0.
Minimizes weight * criterion(residual_fn(x_coll, fields, params), 0).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| fields | FieldsRegistry | Registry of neural fields the residual function operates on. Pass only the subset needed; other fields in the Problem are ignored. | required |
| params | ParamsRegistry | Registry of parameters the residual function uses. | required |
| residual_fn | PDEResidualFn | Callable (x, fields, params) → Tensor of residuals. Should use … | required |
| log_key | str | Key used when logging the loss value. | 'loss/pde_residual' |
| weight | float | Loss term weight. | 1.0 |
Example
from anypinn.lib.diff import partial

# Heat equation u_t = alpha * u_xx; columns of x are (x, t): dim 0 is space, dim 1 is time.
def heat_residual(x, fields, params):
    u = fields["u"]
    u_pred = u(x)
    u_t = partial(u_pred, x, dim=1)  # du/dt
    u_x = partial(u_pred, x, dim=0)  # du/dx
    u_xx = partial(u_x, x, dim=0)    # d2u/dx2
    alpha = params["alpha"](x)
    return u_t - alpha * u_xx        # driven toward 0 by the loss

constraint = PDEResidualConstraint(
    fields=fields, params=params,
    residual_fn=heat_residual,
)
fields = fields
instance-attribute
log_key = log_key
instance-attribute
params = params
instance-attribute
residual_fn = residual_fn
instance-attribute
weight = weight
instance-attribute
__init__(fields: FieldsRegistry, params: ParamsRegistry, residual_fn: PDEResidualFn, log_key: str = 'loss/pde_residual', weight: float = 1.0)
loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor
Compute the PDE interior residual loss.
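A useful sanity check for any residual function is to plug in a known analytical solution and confirm the residual vanishes. The sketch below does this for the heat equation above using u = exp(-alpha pi^2 t) sin(pi x); central finite differences stand in for anypinn.lib.diff.partial, and heat_residual_fd, residual_mse and the collocation points are hypothetical. With the wrong alpha the mean squared residual is clearly nonzero.

```python
import math
from typing import Callable


def heat_residual_fd(u: Callable[[float, float], float], x: float, t: float,
                     alpha: float, h: float = 1e-3) -> float:
    """u_t - alpha * u_xx at (x, t), derivatives by central finite differences."""
    u_t = (u(x, t + h) - u(x, t - h)) / (2 * h)
    u_xx = (u(x + h, t) - 2 * u(x, t) + u(x - h, t)) / h ** 2
    return u_t - alpha * u_xx


def residual_mse(u: Callable[[float, float], float],
                 points: list[tuple[float, float]], alpha: float) -> float:
    """Mean squared residual over collocation points (MSE against 0)."""
    return sum(heat_residual_fd(u, x, t, alpha) ** 2 for x, t in points) / len(points)


ALPHA = 0.1


def exact(x: float, t: float) -> float:
    # Separable solution of u_t = ALPHA * u_xx with u(0, t) = u(1, t) = 0.
    return math.exp(-ALPHA * math.pi ** 2 * t) * math.sin(math.pi * x)


collocation = [(0.25, 0.1), (0.5, 0.5), (0.75, 1.0)]
```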
Type Aliases
| Alias | Definition | Purpose |
|---|---|---|
| PredictDataFn | Callable[[Tensor, FieldsRegistry, ParamsRegistry], Tensor] | Maps fields/params to observed quantities |
| BCValueFn | Callable[[Tensor], Tensor] | Boundary value function |
| PDEResidualFn | Callable[[Tensor, FieldsRegistry, ParamsRegistry], Tensor] | PDE residual function |