
anypinn.problems

Problem templates and implementations.

This module provides ready-made constraint types for ODE and PDE problems, plus the ODEInverseProblem convenience class that wires them together.

ODE vs PDE constraints

ODE and PDE problems share the same Constraint base class from anypinn.core, but differ in the physics they enforce:

ODE constraints (anypinn.problems.ode):

  • ResidualsConstraint: minimizes the ODE residual ||dy/dt - f(t, y, args)||. Supports arbitrary-order ODEs via the order parameter on ODEProperties.
  • ICConstraint: enforces initial conditions y(t0) = Y0 (and derivative ICs for higher-order ODEs).
  • DataConstraint: enforces fit to observed data ||y_hat - y_obs||.

PDE constraints (anypinn.problems.pde):

  • PDEResidualConstraint: minimizes a user-defined PDE residual function over interior collocation points.
  • DirichletBCConstraint: enforces u(x_bc) = g(x_bc).
  • NeumannBCConstraint: enforces du/dn(x_bc) = h(x_bc).
  • PeriodicBCConstraint: enforces u(x_left) = u(x_right) and matching derivatives.

Forward vs inverse direction

The difference between forward and inverse problems is which values are known and which are learned:

  • Forward: the PDE/ODE parameters (e.g. diffusivity, reaction rates) are known constants passed as Argument values. The neural field learns the solution. Constraints: residual + boundary/initial conditions only.
  • Inverse: some parameters are unknown and passed as Parameter instances (learnable). A DataConstraint is added so observed data guides the optimizer toward the correct parameter values alongside the solution.

Switching direction is a matter of moving entries between args and params — see anypinn.core for the Argument/Parameter promotion pattern.
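As a minimal sketch of that promotion (using one-line stand-in classes here; the real Argument and Parameter live in anypinn.core, and Parameter is constructed from a ScalarConfig as shown in the ODEInverseProblem example below):

```python
# Stand-ins for illustration only; the real classes live in anypinn.core.
class Argument:
    """A known constant (forward problems)."""
    def __init__(self, value):
        self.value = value

class Parameter:
    """A learnable quantity (inverse problems)."""
    def __init__(self, init_value):
        self.init_value = init_value

# Forward: every physics parameter is a known Argument.
forward_args = {"beta": Argument(0.3), "gamma": Argument(0.1)}

# Inverse: promote "beta" to a learnable Parameter; "gamma" stays known.
inverse_args = {"gamma": Argument(0.1)}
inverse_params = {"beta": Parameter(init_value=0.3)}
```

An inverse setup would then add a DataConstraint so observations can steer "beta" toward its true value.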

Writing a custom Constraint

Subclass anypinn.core.Constraint and implement the loss method:

class MyConstraint(Constraint):
    def loss(self, batch, criterion, log=None):
        (x_data, y_data), x_coll = batch
        # compute your physics loss here
        loss = ...
        if log is not None:
            log("loss/my_term", loss)
        return loss

Optionally override inject_context(context) if the constraint needs domain bounds or validation data at runtime.
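A hedged sketch of that hook (the real Constraint and InferredContext come from anypinn.core; the stubs below reproduce only the `context.domain.x0` access seen in ICConstraint on this page):

```python
# Stand-ins for illustration; the real classes live in anypinn.core.
class Domain:
    def __init__(self, x0):
        self.x0 = x0  # left domain bound, accessed as context.domain.x0

class InferredContext:
    def __init__(self, domain):
        self.domain = domain

class DomainAwareConstraint:
    """Caches the left domain bound when the trainer injects the context."""

    def inject_context(self, context):
        self.t0 = context.domain.x0

    def loss(self, batch, criterion, log=None):
        # the physics loss can now reference self.t0
        ...

c = DomainAwareConstraint()
c.inject_context(InferredContext(Domain(0.0)))
```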

BCValueFn: TypeAlias = Callable[[Tensor], Tensor] module-attribute

A callable that maps boundary coordinates (n_pts, d) → target values (n_pts, out_dim).

PDEResidualFn: TypeAlias = Callable[[Tensor, FieldsRegistry, ParamsRegistry], Tensor] module-attribute

A callable (x, fields, params) → residual tensor, expected to be zero at the solution.

PredictDataFn: TypeAlias = Callable[[Tensor, FieldsRegistry, ParamsRegistry], Tensor] module-attribute

A callable (x, fields, params) → predicted data values; DataConstraint compares its output against observed data.

__all__ = ['BCValueFn', 'BoundaryCondition', 'DataConstraint', 'DirichletBCConstraint', 'ICConstraint', 'NeumannBCConstraint', 'ODECallable', 'ODEHyperparameters', 'ODEInverseProblem', 'ODEProperties', 'PDEResidualConstraint', 'PDEResidualFn', 'PeriodicBCConstraint', 'PredictDataFn', 'ResidualsConstraint'] module-attribute

BoundaryCondition

Pairs a boundary region sampler with a prescribed value function.

Parameters:

Name Type Description Default
sampler Callable[[int], Tensor]

Callable (n_pts: int) -> Tensor of shape (n_pts, d). Called each training step to produce fresh boundary sample points.

required
value BCValueFn

Callable Tensor -> Tensor giving the target value or normal derivative at boundary coordinates.

required
n_pts int

Number of boundary points sampled per training step.

100
Example

Left boundary of a 1-D+time domain at x=0

bc_left = BoundaryCondition(
    sampler=lambda n: torch.stack([
        torch.zeros(n),            # x = 0
        torch.rand(n) * T_max,     # t in [0, T]
    ], dim=-1),
    value=lambda coords: torch.zeros(coords.shape[0], 1),
    n_pts=50,
)

Source code in src/anypinn/problems/pde.py
class BoundaryCondition:
    """
    Pairs a boundary region sampler with a prescribed value function.

    Args:
        sampler: Callable ``(n_pts: int) -> Tensor`` of shape ``(n_pts, d)``.
            Called each training step to produce fresh boundary sample points.
        value: Callable ``Tensor -> Tensor`` giving the target value or normal
            derivative at boundary coordinates.
        n_pts: Number of boundary points sampled per training step.

    Example:
        >>> # Left boundary of a 1-D+time domain at x=0
        >>> bc_left = BoundaryCondition(
        ...     sampler=lambda n: torch.stack([
        ...         torch.zeros(n),            # x = 0
        ...         torch.rand(n) * T_max,     # t in [0, T]
        ...     ], dim=-1),
        ...     value=lambda coords: torch.zeros(coords.shape[0], 1),
        ...     n_pts=50,
        ... )
    """

    def __init__(
        self,
        sampler: Callable[[int], Tensor],
        value: BCValueFn,
        n_pts: int = 100,
    ):
        self.sampler = sampler
        self.value = value
        self.n_pts = n_pts

n_pts = n_pts instance-attribute

sampler = sampler instance-attribute

value = value instance-attribute

__init__(sampler: Callable[[int], Tensor], value: BCValueFn, n_pts: int = 100)

Source code in src/anypinn/problems/pde.py
def __init__(
    self,
    sampler: Callable[[int], Tensor],
    value: BCValueFn,
    n_pts: int = 100,
):
    self.sampler = sampler
    self.value = value
    self.n_pts = n_pts

DataConstraint

Bases: Constraint

Constraint enforcing fit to observed data. Minimizes ||y_hat - y||^2.

Parameters:

Name Type Description Default
fields FieldsRegistry

Fields registry.

required
params ParamsRegistry

Parameters registry.

required
predict_data PredictDataFn

Function to predict data values from fields.

required
weight float

Weight for this loss term.

1.0
Source code in src/anypinn/problems/ode.py
class DataConstraint(Constraint):
    """
    Constraint enforcing fit to observed data.
    Minimizes ``||y_hat - y||^2``.

    Args:
        fields: Fields registry.
        params: Parameters registry.
        predict_data: Function to predict data values from fields.
        weight: Weight for this loss term.
    """

    def __init__(
        self,
        fields: FieldsRegistry,
        params: ParamsRegistry,
        predict_data: PredictDataFn,
        weight: float = 1.0,
    ):
        self.fields = fields
        self.params = params
        self.predict_data = predict_data
        self.weight = weight

    @override
    def loss(
        self,
        batch: TrainingBatch,
        criterion: nn.Module,
        log: LogFn | None = None,
    ) -> Tensor:
        """Compute the data-fitting loss."""
        (x_data, y_data), _ = batch

        y_data_pred = self.predict_data(x_data, self.fields, self.params)

        loss: Tensor = criterion(y_data_pred, y_data)
        loss = self.weight * loss

        if log is not None:
            log("loss/data", loss)

        return loss

fields = fields instance-attribute

params = params instance-attribute

predict_data = predict_data instance-attribute

weight = weight instance-attribute

__init__(fields: FieldsRegistry, params: ParamsRegistry, predict_data: PredictDataFn, weight: float = 1.0)

Source code in src/anypinn/problems/ode.py
def __init__(
    self,
    fields: FieldsRegistry,
    params: ParamsRegistry,
    predict_data: PredictDataFn,
    weight: float = 1.0,
):
    self.fields = fields
    self.params = params
    self.predict_data = predict_data
    self.weight = weight

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Compute the data-fitting loss.

Source code in src/anypinn/problems/ode.py
@override
def loss(
    self,
    batch: TrainingBatch,
    criterion: nn.Module,
    log: LogFn | None = None,
) -> Tensor:
    """Compute the data-fitting loss."""
    (x_data, y_data), _ = batch

    y_data_pred = self.predict_data(x_data, self.fields, self.params)

    loss: Tensor = criterion(y_data_pred, y_data)
    loss = self.weight * loss

    if log is not None:
        log("loss/data", loss)

    return loss

DirichletBCConstraint

Bases: Constraint

Enforces the Dirichlet boundary condition: u(x_bc) = g(x_bc). Minimizes weight * criterion(u(x_bc), g(x_bc)).

Parameters:

Name Type Description Default
bc BoundaryCondition

Boundary condition (sampler + target value function).

required
field Field

The neural field to enforce the condition on.

required
log_key str

Key used when logging the loss value.

'loss/bc_dirichlet'
weight float

Loss term weight.

1.0
Source code in src/anypinn/problems/pde.py
class DirichletBCConstraint(Constraint):
    """
    Enforces the Dirichlet boundary condition: ``u(x_bc) = g(x_bc)``.
    Minimizes ``weight * criterion(u(x_bc), g(x_bc))``.

    Args:
        bc: Boundary condition (sampler + target value function).
        field: The neural field to enforce the condition on.
        log_key: Key used when logging the loss value.
        weight: Loss term weight.
    """

    def __init__(
        self,
        bc: BoundaryCondition,
        field: Field,
        log_key: str = "loss/bc_dirichlet",
        weight: float = 1.0,
    ):
        self.bc = bc
        self.field = field
        self.log_key = log_key
        self.weight = weight

    @override
    def loss(
        self,
        batch: TrainingBatch,
        criterion: nn.Module,
        log: LogFn | None = None,
    ) -> Tensor:
        """Compute the Dirichlet boundary condition loss."""
        device = next(self.field.parameters()).device
        x_bc = self.bc.sampler(self.bc.n_pts).to(device)
        u_pred = self.field(x_bc)
        g = self.bc.value(x_bc).to(device)
        loss: Tensor = self.weight * criterion(u_pred, g)
        if log is not None:
            log(self.log_key, loss)
        return loss

bc = bc instance-attribute

field = field instance-attribute

log_key = log_key instance-attribute

weight = weight instance-attribute

__init__(bc: BoundaryCondition, field: Field, log_key: str = 'loss/bc_dirichlet', weight: float = 1.0)

Source code in src/anypinn/problems/pde.py
def __init__(
    self,
    bc: BoundaryCondition,
    field: Field,
    log_key: str = "loss/bc_dirichlet",
    weight: float = 1.0,
):
    self.bc = bc
    self.field = field
    self.log_key = log_key
    self.weight = weight

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Compute the Dirichlet boundary condition loss.

Source code in src/anypinn/problems/pde.py
@override
def loss(
    self,
    batch: TrainingBatch,
    criterion: nn.Module,
    log: LogFn | None = None,
) -> Tensor:
    """Compute the Dirichlet boundary condition loss."""
    device = next(self.field.parameters()).device
    x_bc = self.bc.sampler(self.bc.n_pts).to(device)
    u_pred = self.field(x_bc)
    g = self.bc.value(x_bc).to(device)
    loss: Tensor = self.weight * criterion(u_pred, g)
    if log is not None:
        log(self.log_key, loss)
    return loss

ICConstraint

Bases: Constraint

Constraint enforcing Initial Conditions (IC). Minimizes ||y(t0) - Y0||^2.

Parameters:

Name Type Description Default
props ODEProperties

ODE properties supplying y0, dy0, and order.

required
fields FieldsRegistry

Fields registry.

required
weight float

Weight for this loss term.

1.0
Source code in src/anypinn/problems/ode.py
class ICConstraint(Constraint):
    """
    Constraint enforcing Initial Conditions (IC).
    Minimizes ``||y(t0) - Y0||^2``.

    Args:
        props: ODE properties supplying ``y0``, ``dy0``, and ``order``.
        fields: Fields registry.
        weight: Weight for this loss term.
    """

    def __init__(
        self,
        props: ODEProperties,
        fields: FieldsRegistry,
        weight: float = 1.0,
    ):
        if len(fields) != len(props.y0):
            raise ValueError(
                f"Number of fields ({len(fields)}) must match number of initial conditions "
                f"in y0 ({len(props.y0)}). Field keys: {list(fields)}."
            )

        self.Y0 = props.y0.clone().reshape(-1, 1, 1)
        self.dY0 = [dy.clone().reshape(-1, 1, 1) for dy in props.dy0]
        self.order = props.order
        self.fields = fields
        self.weight = weight

    @override
    def inject_context(self, context: InferredContext) -> None:
        """
        Inject the context into the constraint.
        """
        self.t0 = torch.tensor(context.domain.x0, dtype=torch.float32).reshape(1, 1)

    @override
    def loss(
        self,
        batch: TrainingBatch,
        criterion: nn.Module,
        log: LogFn | None = None,
    ) -> Tensor:
        """Compute the initial-condition loss."""
        device = batch[1].device

        if self.t0.device != device:
            self.t0 = self.t0.to(device)
            self.Y0 = self.Y0.to(device)
            self.dY0 = [d.to(device) for d in self.dY0]

        n_fields = len(self.fields)

        if self.order == 1:
            # Fast path: no requires_grad needed, identical to original behaviour
            Y0_preds = torch.stack([f(self.t0) for f in self.fields.values()])
            loss: Tensor = self.weight * criterion(Y0_preds, self.Y0)
        else:
            x0 = self.t0.detach().requires_grad_(True)
            preds = [f(x0) for f in self.fields.values()]
            Y0_preds = torch.stack(preds)
            total = criterion(Y0_preds, self.Y0)
            # Enforce derivative ICs by chaining from previous level
            currents = list(preds)
            for k in range(self.order - 1):
                next_level = [diff_grad(currents[i], x0) for i in range(n_fields)]
                dY0_k_pred = torch.stack(next_level)  # (n_fields, 1, 1)
                total = total + criterion(dY0_k_pred, self.dY0[k])
                currents = next_level
            loss = self.weight * total

        if log is not None:
            log("loss/ic", loss)

        return loss

Y0 = props.y0.clone().reshape(-1, 1, 1) instance-attribute

dY0 = [(dy.clone().reshape(-1, 1, 1)) for dy in (props.dy0)] instance-attribute

fields = fields instance-attribute

order = props.order instance-attribute

weight = weight instance-attribute

__init__(props: ODEProperties, fields: FieldsRegistry, weight: float = 1.0)

Source code in src/anypinn/problems/ode.py
def __init__(
    self,
    props: ODEProperties,
    fields: FieldsRegistry,
    weight: float = 1.0,
):
    if len(fields) != len(props.y0):
        raise ValueError(
            f"Number of fields ({len(fields)}) must match number of initial conditions "
            f"in y0 ({len(props.y0)}). Field keys: {list(fields)}."
        )

    self.Y0 = props.y0.clone().reshape(-1, 1, 1)
    self.dY0 = [dy.clone().reshape(-1, 1, 1) for dy in props.dy0]
    self.order = props.order
    self.fields = fields
    self.weight = weight

inject_context(context: InferredContext) -> None

Inject the context into the constraint.

Source code in src/anypinn/problems/ode.py
@override
def inject_context(self, context: InferredContext) -> None:
    """
    Inject the context into the constraint.
    """
    self.t0 = torch.tensor(context.domain.x0, dtype=torch.float32).reshape(1, 1)

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Compute the initial-condition loss.

Source code in src/anypinn/problems/ode.py
@override
def loss(
    self,
    batch: TrainingBatch,
    criterion: nn.Module,
    log: LogFn | None = None,
) -> Tensor:
    """Compute the initial-condition loss."""
    device = batch[1].device

    if self.t0.device != device:
        self.t0 = self.t0.to(device)
        self.Y0 = self.Y0.to(device)
        self.dY0 = [d.to(device) for d in self.dY0]

    n_fields = len(self.fields)

    if self.order == 1:
        # Fast path: no requires_grad needed, identical to original behaviour
        Y0_preds = torch.stack([f(self.t0) for f in self.fields.values()])
        loss: Tensor = self.weight * criterion(Y0_preds, self.Y0)
    else:
        x0 = self.t0.detach().requires_grad_(True)
        preds = [f(x0) for f in self.fields.values()]
        Y0_preds = torch.stack(preds)
        total = criterion(Y0_preds, self.Y0)
        # Enforce derivative ICs by chaining from previous level
        currents = list(preds)
        for k in range(self.order - 1):
            next_level = [diff_grad(currents[i], x0) for i in range(n_fields)]
            dY0_k_pred = torch.stack(next_level)  # (n_fields, 1, 1)
            total = total + criterion(dY0_k_pred, self.dY0[k])
            currents = next_level
        loss = self.weight * total

    if log is not None:
        log("loss/ic", loss)

    return loss

NeumannBCConstraint

Bases: Constraint

Enforces the Neumann boundary condition: du/dn(x_bc) = h(x_bc).

For a rectangular domain face whose outward normal is axis-aligned with dimension normal_dim, we have du/dn = du/dx[normal_dim]. Minimizes weight * criterion(du_dn(x_bc), h(x_bc)).

Parameters:

Name Type Description Default
bc BoundaryCondition

Boundary condition (sampler + target normal-derivative function).

required
field Field

The neural field to enforce the condition on.

required
normal_dim int

Index of the spatial dimension the boundary normal points along.

required
log_key str

Key used when logging the loss value.

'loss/bc_neumann'
weight float

Loss term weight.

1.0
Source code in src/anypinn/problems/pde.py
class NeumannBCConstraint(Constraint):
    """
    Enforces the Neumann boundary condition:
    ``du/dn(x_bc) = h(x_bc)``.

    For a rectangular domain face whose outward normal is axis-aligned with
    dimension ``normal_dim``, we have
    ``du/dn = du/dx[normal_dim]``.
    Minimizes
    ``weight * criterion(du_dn(x_bc), h(x_bc))``.

    Args:
        bc: Boundary condition (sampler + target normal-derivative function).
        field: The neural field to enforce the condition on.
        normal_dim: Index of the spatial dimension the boundary normal points along.
        log_key: Key used when logging the loss value.
        weight: Loss term weight.
    """

    def __init__(
        self,
        bc: BoundaryCondition,
        field: Field,
        normal_dim: int,
        log_key: str = "loss/bc_neumann",
        weight: float = 1.0,
    ):
        self.bc = bc
        self.field = field
        self.normal_dim = normal_dim
        self.log_key = log_key
        self.weight = weight

    @override
    def loss(
        self,
        batch: TrainingBatch,
        criterion: nn.Module,
        log: LogFn | None = None,
    ) -> Tensor:
        """Compute the Neumann boundary condition loss."""
        device = next(self.field.parameters()).device
        x_bc = self.bc.sampler(self.bc.n_pts).to(device).detach().requires_grad_(True)
        u_pred = self.field(x_bc)
        du_dn = diff_partial(u_pred, x_bc, dim=self.normal_dim)
        h = self.bc.value(x_bc.detach()).to(device)
        loss: Tensor = self.weight * criterion(du_dn, h)
        if log is not None:
            log(self.log_key, loss)
        return loss

bc = bc instance-attribute

field = field instance-attribute

log_key = log_key instance-attribute

normal_dim = normal_dim instance-attribute

weight = weight instance-attribute

__init__(bc: BoundaryCondition, field: Field, normal_dim: int, log_key: str = 'loss/bc_neumann', weight: float = 1.0)

Source code in src/anypinn/problems/pde.py
def __init__(
    self,
    bc: BoundaryCondition,
    field: Field,
    normal_dim: int,
    log_key: str = "loss/bc_neumann",
    weight: float = 1.0,
):
    self.bc = bc
    self.field = field
    self.normal_dim = normal_dim
    self.log_key = log_key
    self.weight = weight

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Compute the Neumann boundary condition loss.

Source code in src/anypinn/problems/pde.py
@override
def loss(
    self,
    batch: TrainingBatch,
    criterion: nn.Module,
    log: LogFn | None = None,
) -> Tensor:
    """Compute the Neumann boundary condition loss."""
    device = next(self.field.parameters()).device
    x_bc = self.bc.sampler(self.bc.n_pts).to(device).detach().requires_grad_(True)
    u_pred = self.field(x_bc)
    du_dn = diff_partial(u_pred, x_bc, dim=self.normal_dim)
    h = self.bc.value(x_bc.detach()).to(device)
    loss: Tensor = self.weight * criterion(du_dn, h)
    if log is not None:
        log(self.log_key, loss)
    return loss

ODECallable

Bases: Protocol

Protocol for ODE right-hand side callables.

First-order callables (ODEProperties.order == 1) receive three positional arguments and must match this Protocol exactly::

def my_ode(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor: ...

Higher-order callables (order >= 2) receive a fourth positional argument derivs: list[Tensor], where derivs[k] is the (k+1)-th derivative of all fields stacked as (n_fields, m, 1)::

def my_ode(x: Tensor, y: Tensor, args: ArgsRegistry,
           derivs: list[Tensor] = []) -> Tensor: ...

The Protocol is intentionally kept to three arguments so that existing first-order callables remain valid ODECallable implementations. ResidualsConstraint uses _ODECallableN internally to call higher-order functions with the correct signature.
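For instance, a damped oscillator y'' = -c·y' - k·y could be written as a higher-order callable like this (a sketch under the stated derivs convention; "c" and "k" are hypothetical arg keys, and plain floats stand in for tensors so the example is self-contained):

```python
# A sketch only: assumes that for order == 2 the callable receives
# derivs == [y'] and returns the value compared against y''.
def oscillator(x, y, args, derivs=[]):
    c, k = args["c"](x), args["k"](x)  # registry entries are callables of x
    return -c * derivs[0] - k * y      # y'' = -c*y' - k*y

args = {"c": lambda x: 0.1, "k": lambda x: 4.0}
rhs = oscillator(0.0, 1.0, args, derivs=[0.5])  # -0.1*0.5 - 4.0*1.0
```

Because `derivs` has a default value, the same function still satisfies the three-argument Protocol above.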

Source code in src/anypinn/problems/ode.py
class ODECallable(Protocol):
    """
    Protocol for ODE right-hand side callables.

    **First-order** callables (``ODEProperties.order == 1``) receive three
    positional arguments and must match this Protocol exactly::

        def my_ode(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor: ...

    **Higher-order** callables (``order >= 2``) receive a fourth positional
    argument ``derivs: list[Tensor]``, where ``derivs[k]`` is the
    ``(k+1)``-th derivative of all fields stacked as ``(n_fields, m, 1)``::

        def my_ode(x: Tensor, y: Tensor, args: ArgsRegistry,
                   derivs: list[Tensor] = []) -> Tensor: ...

    The Protocol is intentionally kept to three arguments so that existing
    first-order callables remain valid ``ODECallable`` implementations.
    ``ResidualsConstraint`` uses ``_ODECallableN`` internally to call
    higher-order functions with the correct signature.
    """

    def __call__(self, x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor: ...

__call__(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor

Source code in src/anypinn/problems/ode.py
def __call__(self, x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor: ...

ODEHyperparameters dataclass

Bases: PINNHyperparameters

Hyperparameters for ODE inverse problems.

Extends PINNHyperparameters with per-constraint loss weights used by ODEInverseProblem.

Attributes:

Name Type Description
pde_weight float

Weight for the ODE residual loss term.

ic_weight float

Weight for the initial-condition loss term.

data_weight float

Weight for the data-fitting loss term.

Source code in src/anypinn/problems/ode.py
@dataclass(kw_only=True)
class ODEHyperparameters(PINNHyperparameters):
    """
    Hyperparameters for ODE inverse problems.

    Extends ``PINNHyperparameters`` with per-constraint loss weights
    used by ``ODEInverseProblem``.

    Attributes:
        pde_weight: Weight for the ODE residual loss term.
        ic_weight: Weight for the initial-condition loss term.
        data_weight: Weight for the data-fitting loss term.
    """

    pde_weight: float = 1.0
    ic_weight: float = 1.0
    data_weight: float = 1.0

data_weight: float = 1.0 class-attribute instance-attribute

ic_weight: float = 1.0 class-attribute instance-attribute

pde_weight: float = 1.0 class-attribute instance-attribute

__init__(*, lr: float, training_data: IngestionConfig | GenerationConfig, fields_config: MLPConfig, params_config: MLPConfig | ScalarConfig, max_epochs: int | None = None, gradient_clip_val: float | None = None, criterion: Criteria = 'mse', optimizer: AdamConfig | LBFGSConfig | None = None, scheduler: ReduceLROnPlateauConfig | CosineAnnealingConfig | None = None, early_stopping: EarlyStoppingConfig | None = None, smma_stopping: SMMAStoppingConfig | None = None, pde_weight: float = 1.0, ic_weight: float = 1.0, data_weight: float = 1.0) -> None

ODEInverseProblem

Bases: Problem

Convenience class composing Residuals + IC + Data constraints.

Wires together ResidualsConstraint, ICConstraint, and DataConstraint with the loss criterion from hyperparameters. For forward-only problems (no data fitting), compose the constraints manually instead.

Example

problem = ODEInverseProblem(
    props=props,
    hp=ODEHyperparameters(...),
    fields={"S": field_s, "I": field_i, "R": field_r},
    params={"beta": Parameter(ScalarConfig(init_value=0.3))},
    predict_data=lambda t, f, p: torch.stack(
        [f["S"](t), f["I"](t), f["R"](t)], dim=1
    ).squeeze(-1),
)

Source code in src/anypinn/problems/ode.py
class ODEInverseProblem(Problem):
    """
    Convenience class composing Residuals + IC + Data constraints.

    Wires together ``ResidualsConstraint``, ``ICConstraint``, and
    ``DataConstraint`` with the loss criterion from hyperparameters.
    For forward-only problems (no data fitting), compose the constraints
    manually instead.

    Example:
        >>> problem = ODEInverseProblem(
        ...     props=props,
        ...     hp=ODEHyperparameters(...),
        ...     fields={"S": field_s, "I": field_i, "R": field_r},
        ...     params={"beta": Parameter(ScalarConfig(init_value=0.3))},
        ...     predict_data=lambda t, f, p: torch.stack(
        ...         [f["S"](t), f["I"](t), f["R"](t)], dim=1
        ...     ).squeeze(-1),
        ... )
    """

    def __init__(
        self,
        props: ODEProperties,
        hp: ODEHyperparameters,
        fields: FieldsRegistry,
        params: ParamsRegistry,
        predict_data: PredictDataFn,
    ) -> None:
        constraints: list[Constraint] = [
            ResidualsConstraint(
                props=props,
                fields=fields,
                params=params,
                weight=hp.pde_weight,
            ),
            ICConstraint(
                props=props,
                fields=fields,
                weight=hp.ic_weight,
            ),
            DataConstraint(
                fields=fields,
                params=params,
                predict_data=predict_data,
                weight=hp.data_weight,
            ),
        ]

        criterion = build_criterion(hp.criterion)

        super().__init__(
            constraints=constraints,
            criterion=criterion,
            fields=fields,
            params=params,
        )

__init__(props: ODEProperties, hp: ODEHyperparameters, fields: FieldsRegistry, params: ParamsRegistry, predict_data: PredictDataFn) -> None

Source code in src/anypinn/problems/ode.py
def __init__(
    self,
    props: ODEProperties,
    hp: ODEHyperparameters,
    fields: FieldsRegistry,
    params: ParamsRegistry,
    predict_data: PredictDataFn,
) -> None:
    constraints: list[Constraint] = [
        ResidualsConstraint(
            props=props,
            fields=fields,
            params=params,
            weight=hp.pde_weight,
        ),
        ICConstraint(
            props=props,
            fields=fields,
            weight=hp.ic_weight,
        ),
        DataConstraint(
            fields=fields,
            params=params,
            predict_data=predict_data,
            weight=hp.data_weight,
        ),
    ]

    criterion = build_criterion(hp.criterion)

    super().__init__(
        constraints=constraints,
        criterion=criterion,
        fields=fields,
        params=params,
    )

ODEProperties dataclass

Properties defining an Ordinary Differential Equation problem.

Attributes:

Name Type Description
ode ODECallable

The ODE function (callable).

args ArgsRegistry

Arguments/Parameters for the ODE.

y0 Tensor

Initial conditions.

order int

Order of the ODE (default 1). For order=n, the ODE callable receives derivs as its last argument: derivs[k] is the (k+1)-th derivative.

dy0 list[Tensor]

Initial conditions for lower-order derivatives, length = order-1. dy0[k] is the IC for the (k+1)-th derivative, shape (n_fields,).

expected_args frozenset[str] | None

Optional set of arg keys the ODE function accesses. When provided, validated against the merged args+params at construction time.

Example

def sir_ode(t, y, args):
    S, I, R = y
    beta, gamma = args["beta"](t), args["gamma"](t)
    N = S + I + R
    dS = -beta * S * I / N
    dI = beta * S * I / N - gamma * I
    dR = gamma * I
    return torch.stack([dS, dI, dR])

props = ODEProperties(
    ode=sir_ode,
    args={"beta": Argument(0.3), "gamma": Argument(0.1)},
    y0=torch.tensor([0.99, 0.01, 0.0]),
)

Source code in src/anypinn/problems/ode.py
@dataclass
class ODEProperties:
    """
    Properties defining an Ordinary Differential Equation problem.

    Attributes:
        ode: The ODE function (callable).
        args: Arguments/Parameters for the ODE.
        y0: Initial conditions.
        order: Order of the ODE (default 1). For order=n, the ODE callable receives
            derivs as its last argument: derivs[k] is the (k+1)-th derivative.
        dy0: Initial conditions for lower-order derivatives, length = order-1.
            dy0[k] is the IC for the (k+1)-th derivative, shape (n_fields,).
        expected_args: Optional set of arg keys the ODE function accesses.
            When provided, validated against the merged args+params at construction time.

    Example:
        >>> def sir_ode(t, y, args):
        ...     S, I, R = y
        ...     beta, gamma = args["beta"](t), args["gamma"](t)
        ...     N = S + I + R
        ...     dS = -beta * S * I / N
        ...     dI = beta * S * I / N - gamma * I
        ...     dR = gamma * I
        ...     return torch.stack([dS, dI, dR])
        >>> props = ODEProperties(
        ...     ode=sir_ode,
        ...     args={"beta": Argument(0.3), "gamma": Argument(0.1)},
        ...     y0=torch.tensor([0.99, 0.01, 0.0]),
        ... )
    """

    ode: ODECallable
    args: ArgsRegistry
    y0: Tensor
    order: int = 1
    dy0: list[Tensor] = dc_field(default_factory=list)
    expected_args: frozenset[str] | None = None

    def __post_init__(self) -> None:
        if self.order < 1:
            raise ValueError(f"order must be >= 1, got {self.order}")
        if len(self.dy0) != self.order - 1:
            raise ValueError(f"dy0 must have length order-1={self.order - 1}, got {len(self.dy0)}")

args: ArgsRegistry instance-attribute

dy0: list[Tensor] = dc_field(default_factory=list) class-attribute instance-attribute

expected_args: frozenset[str] | None = None class-attribute instance-attribute

ode: ODECallable instance-attribute

order: int = 1 class-attribute instance-attribute

y0: Tensor instance-attribute

__init__(ode: ODECallable, args: ArgsRegistry, y0: Tensor, order: int = 1, dy0: list[Tensor] = list(), expected_args: frozenset[str] | None = None) -> None

__post_init__() -> None

Source code in src/anypinn/problems/ode.py
def __post_init__(self) -> None:
    if self.order < 1:
        raise ValueError(f"order must be >= 1, got {self.order}")
    if len(self.dy0) != self.order - 1:
        raise ValueError(f"dy0 must have length order-1={self.order - 1}, got {len(self.dy0)}")

PDEResidualConstraint

Bases: Constraint

Enforces a PDE interior residual: residual_fn(x, fields, params) ≈ 0. Minimizes weight * criterion(residual_fn(x_coll, fields, params), 0).

Parameters:

Name Type Description Default
fields FieldsRegistry

Registry of neural fields the residual function operates on. Pass only the subset needed — other fields in the Problem are ignored.

required
params ParamsRegistry

Registry of parameters the residual function uses.

required
residual_fn PDEResidualFn

Callable (x, fields, params) → Tensor of residuals. Should use anypinn.lib.diff operators for derivatives. The returned tensor is compared against zeros.

required
log_key str

Key used when logging the loss value.

'loss/pde_residual'
weight float

Loss term weight.

1.0
Example

from anypinn.lib.diff import grad, partial

def heat_residual(x, fields, params):
    u = fields["u"]
    u_pred = u(x)
    u_t = partial(u_pred, x, dim=1)   # du/dt
    u_x = partial(u_pred, x, dim=0)   # du/dx
    u_xx = partial(u_x, x, dim=0)     # d2u/dx2
    alpha = params["alpha"](x)
    return u_t - alpha * u_xx          # residual = 0

constraint = PDEResidualConstraint(
    fields=fields, params=params,
    residual_fn=heat_residual,
)

Source code in src/anypinn/problems/pde.py
class PDEResidualConstraint(Constraint):
    """
    Enforces a PDE interior residual: ``residual_fn(x, fields, params) ≈ 0``.
    Minimizes ``weight * criterion(residual_fn(x_coll, fields, params), 0)``.

    Args:
        fields: Registry of neural fields the residual function operates on.
            Pass only the subset needed — other fields in the Problem are ignored.
        params: Registry of parameters the residual function uses.
        residual_fn: Callable (x, fields, params) → Tensor of residuals.
            Should use ``anypinn.lib.diff`` operators for derivatives.
            The returned tensor is compared against zeros.
        log_key: Key used when logging the loss value.
        weight: Loss term weight.

    Example:
        >>> from anypinn.lib.diff import grad, partial
        >>> def heat_residual(x, fields, params):
        ...     u = fields["u"]
        ...     u_pred = u(x)
        ...     u_t = partial(u_pred, x, dim=1)   # du/dt
        ...     u_x = partial(u_pred, x, dim=0)   # du/dx
        ...     u_xx = partial(u_x, x, dim=0)     # d2u/dx2
        ...     alpha = params["alpha"](x)
        ...     return u_t - alpha * u_xx          # residual = 0
        >>> constraint = PDEResidualConstraint(
        ...     fields=fields, params=params,
        ...     residual_fn=heat_residual,
        ... )
    """

    def __init__(
        self,
        fields: FieldsRegistry,
        params: ParamsRegistry,
        residual_fn: PDEResidualFn,
        log_key: str = "loss/pde_residual",
        weight: float = 1.0,
    ):
        self.fields = fields
        self.params = params
        self.residual_fn = residual_fn
        self.log_key = log_key
        self.weight = weight

    @override
    def loss(
        self,
        batch: TrainingBatch,
        criterion: nn.Module,
        log: LogFn | None = None,
    ) -> Tensor:
        """Compute the PDE interior residual loss."""
        _, x_coll = batch
        x_coll = x_coll.detach().requires_grad_(True)
        residual = self.residual_fn(x_coll, self.fields, self.params)
        loss: Tensor = self.weight * criterion(residual, torch.zeros_like(residual))
        if log is not None:
            log(self.log_key, loss)
        return loss

fields = fields instance-attribute

log_key = log_key instance-attribute

params = params instance-attribute

residual_fn = residual_fn instance-attribute

weight = weight instance-attribute

__init__(fields: FieldsRegistry, params: ParamsRegistry, residual_fn: PDEResidualFn, log_key: str = 'loss/pde_residual', weight: float = 1.0)

Source code in src/anypinn/problems/pde.py
def __init__(
    self,
    fields: FieldsRegistry,
    params: ParamsRegistry,
    residual_fn: PDEResidualFn,
    log_key: str = "loss/pde_residual",
    weight: float = 1.0,
):
    self.fields = fields
    self.params = params
    self.residual_fn = residual_fn
    self.log_key = log_key
    self.weight = weight

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Compute the PDE interior residual loss.

Source code in src/anypinn/problems/pde.py
@override
def loss(
    self,
    batch: TrainingBatch,
    criterion: nn.Module,
    log: LogFn | None = None,
) -> Tensor:
    """Compute the PDE interior residual loss."""
    _, x_coll = batch
    x_coll = x_coll.detach().requires_grad_(True)
    residual = self.residual_fn(x_coll, self.fields, self.params)
    loss: Tensor = self.weight * criterion(residual, torch.zeros_like(residual))
    if log is not None:
        log(self.log_key, loss)
    return loss

PeriodicBCConstraint

Bases: Constraint

Enforces periodic boundary conditions: u(x_left, t) = u(x_right, t) and matching spatial derivatives.

The two boundary samplers must produce paired points — identical coordinates in every dimension except the periodic one — so that the value- and derivative-matching losses are meaningful.

Minimizes weight * [criterion(u_left, u_right) + criterion(du_left, du_right)].

Parameters:

Name Type Description Default
bc_left BoundaryCondition

Left boundary sampler (sampler + dummy value function).

required
bc_right BoundaryCondition

Right boundary sampler (sampler + dummy value function).

required
field Field

The neural field to enforce the condition on.

required
match_dim int

Spatial dimension index for the derivative matching.

0
log_key str

Key used when logging the loss value.

'loss/bc_periodic'
weight float

Loss term weight.

1.0
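Example

The paired-points requirement can be satisfied with deterministic samplers: because `loss()` calls each boundary sampler independently, a shared deterministic time grid guarantees point i on the left matches point i on the right in every coordinate except the periodic one. A minimal sketch, assuming 2-D inputs with columns (x, t) and periodicity in dim 0 (the helper names are hypothetical, not part of anypinn):

```python
import torch

T = 2.0  # hypothetical time horizon


def _boundary(x_value: float, n_pts: int) -> torch.Tensor:
    # Deterministic time grid shared by both boundaries, so the samplers
    # stay paired even though PeriodicBCConstraint calls them separately.
    t = torch.linspace(0.0, T, n_pts).unsqueeze(1)
    x = torch.full((n_pts, 1), x_value)
    return torch.cat([x, t], dim=1)          # columns: (x, t)


def left_sampler(n_pts: int) -> torch.Tensor:
    return _boundary(0.0, n_pts)             # x = 0 boundary


def right_sampler(n_pts: int) -> torch.Tensor:
    return _boundary(1.0, n_pts)             # x = 1 boundary


xl, xr = left_sampler(8), right_sampler(8)
# Times match pairwise; only the periodic dim (0) differs.
assert torch.equal(xl[:, 1], xr[:, 1])
```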
Source code in src/anypinn/problems/pde.py
class PeriodicBCConstraint(Constraint):
    """
    Enforces periodic boundary conditions:
    ``u(x_left, t) = u(x_right, t)`` and matching spatial derivatives.

    The two boundary samplers must produce **paired** points — identical
    coordinates in every dimension except the periodic one — so that
    the value- and derivative-matching losses are meaningful.

    Minimizes
    ``weight * [criterion(u_left, u_right) + criterion(du_left, du_right)]``.

    Args:
        bc_left: Left boundary sampler (sampler + dummy value function).
        bc_right: Right boundary sampler (sampler + dummy value function).
        field: The neural field to enforce the condition on.
        match_dim: Spatial dimension index for the derivative matching.
        log_key: Key used when logging the loss value.
        weight: Loss term weight.
    """

    def __init__(
        self,
        bc_left: BoundaryCondition,
        bc_right: BoundaryCondition,
        field: Field,
        match_dim: int = 0,
        log_key: str = "loss/bc_periodic",
        weight: float = 1.0,
    ):
        self.bc_left = bc_left
        self.bc_right = bc_right
        self.field = field
        self.match_dim = match_dim
        self.log_key = log_key
        self.weight = weight

    @override
    def loss(
        self,
        batch: TrainingBatch,
        criterion: nn.Module,
        log: LogFn | None = None,
    ) -> Tensor:
        """Compute the periodic boundary condition loss."""
        device = next(self.field.parameters()).device
        n_pts = self.bc_left.n_pts

        x_left = self.bc_left.sampler(n_pts).to(device).detach().requires_grad_(True)
        x_right = self.bc_right.sampler(n_pts).to(device).detach().requires_grad_(True)

        u_left = self.field(x_left)
        u_right = self.field(x_right)

        # Value matching: u(x_left, t) = u(x_right, t)
        loss_val: Tensor = criterion(u_left, u_right)

        # Derivative matching: du/dx(x_left, t) = du/dx(x_right, t)
        du_left = diff_partial(u_left, x_left, dim=self.match_dim)
        du_right = diff_partial(u_right, x_right, dim=self.match_dim)
        loss_deriv: Tensor = criterion(du_left, du_right)

        loss: Tensor = self.weight * (loss_val + loss_deriv)
        if log is not None:
            log(self.log_key, loss)
        return loss

bc_left = bc_left instance-attribute

bc_right = bc_right instance-attribute

field = field instance-attribute

log_key = log_key instance-attribute

match_dim = match_dim instance-attribute

weight = weight instance-attribute

__init__(bc_left: BoundaryCondition, bc_right: BoundaryCondition, field: Field, match_dim: int = 0, log_key: str = 'loss/bc_periodic', weight: float = 1.0)

Source code in src/anypinn/problems/pde.py
def __init__(
    self,
    bc_left: BoundaryCondition,
    bc_right: BoundaryCondition,
    field: Field,
    match_dim: int = 0,
    log_key: str = "loss/bc_periodic",
    weight: float = 1.0,
):
    self.bc_left = bc_left
    self.bc_right = bc_right
    self.field = field
    self.match_dim = match_dim
    self.log_key = log_key
    self.weight = weight

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Compute the periodic boundary condition loss.

Source code in src/anypinn/problems/pde.py
@override
def loss(
    self,
    batch: TrainingBatch,
    criterion: nn.Module,
    log: LogFn | None = None,
) -> Tensor:
    """Compute the periodic boundary condition loss."""
    device = next(self.field.parameters()).device
    n_pts = self.bc_left.n_pts

    x_left = self.bc_left.sampler(n_pts).to(device).detach().requires_grad_(True)
    x_right = self.bc_right.sampler(n_pts).to(device).detach().requires_grad_(True)

    u_left = self.field(x_left)
    u_right = self.field(x_right)

    # Value matching: u(x_left, t) = u(x_right, t)
    loss_val: Tensor = criterion(u_left, u_right)

    # Derivative matching: du/dx(x_left, t) = du/dx(x_right, t)
    du_left = diff_partial(u_left, x_left, dim=self.match_dim)
    du_right = diff_partial(u_right, x_right, dim=self.match_dim)
    loss_deriv: Tensor = criterion(du_left, du_right)

    loss: Tensor = self.weight * (loss_val + loss_deriv)
    if log is not None:
        log(self.log_key, loss)
    return loss

ResidualsConstraint

Bases: Constraint

Constraint enforcing the ODE residuals. Minimizes weight * criterion(d^order y/dt^order, f(t, y, args)); for order=1 this reduces to ||dy/dt - f(t, y, args)||.

Parameters:

Name Type Description Default
props ODEProperties

ODE properties.

required
fields FieldsRegistry

Registry of neural fields, one per ODE component.

required
params ParamsRegistry

Registry of learnable parameters, merged into the args.

required
weight float

Weight for this loss term.

1.0
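Example

A hedged sketch of an order-1 ODE callable with the (t, y, args) signature this constraint consumes. The `decay_ode` function and the plain-float `"k"` entry are illustrative assumptions; in a real problem the entry would come from the Argument/Parameter registry:

```python
import torch
from torch import Tensor


# Hypothetical order-1 ODE callable dy/dt = f(t, y, args) for exponential
# decay. y stacks all field predictions: shape (n_fields, n_pts, 1).
def decay_ode(t: Tensor, y: Tensor, args: dict) -> Tensor:
    k = args["k"]          # decay rate (illustrative: a plain float here)
    return -k * y          # dy/dt = -k * y, same shape as y


t = torch.linspace(0.0, 1.0, 5).unsqueeze(1)   # (n_pts, 1) collocation times
y = torch.ones(2, 5, 1)                        # two fields, five points
out = decay_ode(t, y, {"k": 0.5})
assert out.shape == y.shape
```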
Source code in src/anypinn/problems/ode.py
class ResidualsConstraint(Constraint):
    """
    Constraint enforcing the ODE residuals.
    Minimizes ``weight * criterion(d^order y/dt^order, f(t, y, args))``;
    for ``order=1`` this reduces to ``||dy/dt - f(t, y, args)||``.

    Args:
        props: ODE properties.
        fields: Registry of neural fields, one per ODE component.
        params: Registry of learnable parameters, merged into the args.
        weight: Weight for this loss term.
    """

    def __init__(
        self,
        props: ODEProperties,
        fields: FieldsRegistry,
        params: ParamsRegistry,
        weight: float = 1.0,
    ):
        if len(fields) != len(props.y0):
            raise ValueError(
                f"Number of fields ({len(fields)}) must match number of initial conditions "
                f"in y0 ({len(props.y0)}). Field keys: {list(fields)}."
            )

        merged_args: dict[str, object] = {**props.args, **params}
        if props.expected_args is not None:
            missing = props.expected_args - merged_args.keys()
            if missing:
                raise ValueError(
                    f"ODE function expects args {sorted(missing)!r} but they are not in "
                    f"props.args or params. Available keys: {sorted(merged_args.keys())!r}."
                )

        self.fields = fields
        self.weight = weight
        self.order = props.order

        self.ode = props.ode

        # add the trainable params as args
        self.args = props.args.copy()
        self.args.update(params)

    @override
    def loss(
        self,
        batch: TrainingBatch,
        criterion: nn.Module,
        log: LogFn | None = None,
    ) -> Tensor:
        """Compute the ODE residual loss over collocation points."""
        _, x_coll = batch

        n_fields = len(self.fields)
        x_copies = [x_coll.detach().clone().requires_grad_(True) for _ in range(n_fields)]
        preds = [f(x_copies[i]) for i, f in enumerate(self.fields.values())]
        y = torch.stack(preds)

        # Build all derivative levels by chaining (each level differentiates the previous)
        # deriv_levels[k][i] = (k+1)-th derivative of field i
        deriv_levels: list[list[Tensor]] = []
        currents = list(preds)
        for _ in range(self.order):
            next_level = [diff_grad(currents[i], x_copies[i]) for i in range(n_fields)]
            deriv_levels.append(next_level)
            currents = next_level

        # derivs[k] = (k+1)-th derivative stacked across fields, passed to the ODE callable
        derivs = [torch.stack(deriv_levels[k]) for k in range(self.order - 1)]
        # The order-th derivative is the LHS to compare against f_out
        high_deriv = torch.stack(deriv_levels[self.order - 1])

        if self.order == 1:
            f_out = self.ode(x_coll, y, self.args)
        else:
            f_out = cast(_ODECallableN, self.ode)(x_coll, y, self.args, derivs)

        loss: Tensor = self.weight * criterion(high_deriv, f_out)

        if log is not None:
            log("loss/res", loss)

        return loss

args = props.args.copy() instance-attribute

fields = fields instance-attribute

ode = props.ode instance-attribute

order = props.order instance-attribute

weight = weight instance-attribute

__init__(props: ODEProperties, fields: FieldsRegistry, params: ParamsRegistry, weight: float = 1.0)

Source code in src/anypinn/problems/ode.py
def __init__(
    self,
    props: ODEProperties,
    fields: FieldsRegistry,
    params: ParamsRegistry,
    weight: float = 1.0,
):
    if len(fields) != len(props.y0):
        raise ValueError(
            f"Number of fields ({len(fields)}) must match number of initial conditions "
            f"in y0 ({len(props.y0)}). Field keys: {list(fields)}."
        )

    merged_args: dict[str, object] = {**props.args, **params}
    if props.expected_args is not None:
        missing = props.expected_args - merged_args.keys()
        if missing:
            raise ValueError(
                f"ODE function expects args {sorted(missing)!r} but they are not in "
                f"props.args or params. Available keys: {sorted(merged_args.keys())!r}."
            )

    self.fields = fields
    self.weight = weight
    self.order = props.order

    self.ode = props.ode

    # add the trainable params as args
    self.args = props.args.copy()
    self.args.update(params)

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Compute the ODE residual loss over collocation points.

Source code in src/anypinn/problems/ode.py
@override
def loss(
    self,
    batch: TrainingBatch,
    criterion: nn.Module,
    log: LogFn | None = None,
) -> Tensor:
    """Compute the ODE residual loss over collocation points."""
    _, x_coll = batch

    n_fields = len(self.fields)
    x_copies = [x_coll.detach().clone().requires_grad_(True) for _ in range(n_fields)]
    preds = [f(x_copies[i]) for i, f in enumerate(self.fields.values())]
    y = torch.stack(preds)

    # Build all derivative levels by chaining (each level differentiates the previous)
    # deriv_levels[k][i] = (k+1)-th derivative of field i
    deriv_levels: list[list[Tensor]] = []
    currents = list(preds)
    for _ in range(self.order):
        next_level = [diff_grad(currents[i], x_copies[i]) for i in range(n_fields)]
        deriv_levels.append(next_level)
        currents = next_level

    # derivs[k] = (k+1)-th derivative stacked across fields, passed to the ODE callable
    derivs = [torch.stack(deriv_levels[k]) for k in range(self.order - 1)]
    # The order-th derivative is the LHS to compare against f_out
    high_deriv = torch.stack(deriv_levels[self.order - 1])

    if self.order == 1:
        f_out = self.ode(x_coll, y, self.args)
    else:
        f_out = cast(_ODECallableN, self.ode)(x_coll, y, self.args, derivs)

    loss: Tensor = self.weight * criterion(high_deriv, f_out)

    if log is not None:
        log("loss/res", loss)

    return loss
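The derivative-chaining loop in `loss` can be sketched with raw `torch.autograd.grad` standing in for anypinn's `diff_grad` (an illustrative stand-in, assuming one scalar output per collocation point): each level differentiates the previous one, so level k holds the (k+1)-th derivative.

```python
import torch


def grad(out: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
    # Stand-in for diff_grad: per-point derivative of out w.r.t. inp,
    # with create_graph=True so the result can be differentiated again.
    return torch.autograd.grad(
        out, inp, grad_outputs=torch.ones_like(out), create_graph=True
    )[0]


x = torch.linspace(0.0, 1.0, 5).unsqueeze(1).requires_grad_(True)
y = x ** 3                     # stand-in for a field prediction

order = 2
deriv_levels = []              # deriv_levels[k] = (k+1)-th derivative
current = y
for _ in range(order):
    current = grad(current, x)
    deriv_levels.append(current)

# d/dx x^3 = 3x^2, d2/dx2 x^3 = 6x
assert torch.allclose(deriv_levels[0], 3 * x ** 2)
assert torch.allclose(deriv_levels[1], 6 * x)
```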