Skip to content

anypinn.problems

Problem templates and implementations.

BCValueFn: TypeAlias = Callable[[Tensor], Tensor] module-attribute

A callable that maps boundary coordinates (n_pts, d) → target values (n_pts, out_dim).

PDEResidualFn: TypeAlias = Callable[[Tensor, FieldsRegistry, ParamsRegistry], Tensor] module-attribute

A callable (x, fields, params) → residual tensor, expected to be zero at the solution.

PredictDataFn: TypeAlias = Callable[[Tensor, FieldsRegistry, ParamsRegistry], Tensor] module-attribute

__all__ = ['BCValueFn', 'BoundaryCondition', 'DataConstraint', 'DirichletBCConstraint', 'ICConstraint', 'NeumannBCConstraint', 'ODECallable', 'ODEHyperparameters', 'ODEInverseProblem', 'ODEProperties', 'PDEResidualConstraint', 'PDEResidualFn', 'PeriodicBCConstraint', 'PredictDataFn', 'ResidualsConstraint'] module-attribute

BoundaryCondition

Pairs a boundary region sampler with a prescribed value function.

Parameters:

Name Type Description Default
sampler Callable[[int], Tensor]

Callable (n_pts: int) -> Tensor of shape (n_pts, d). Called each training step to produce fresh boundary sample points.

required
value BCValueFn

Callable Tensor -> Tensor giving the target value or normal derivative at boundary coordinates.

required
n_pts int

Number of boundary points sampled per training step.

100
Source code in src/anypinn/problems/pde.py
class BoundaryCondition:
    """
    Couples a sampler for a boundary region with the values prescribed on it.

    Args:
        sampler: Callable ``(n_pts: int) -> Tensor`` of shape ``(n_pts, d)``.
            Invoked every training step so boundary points are freshly drawn.
        value: Callable ``Tensor -> Tensor`` returning the target value (or
            normal derivative) at the given boundary coordinates.
        n_pts: How many boundary points to draw per training step.
    """

    def __init__(
        self,
        sampler: Callable[[int], Tensor],
        value: BCValueFn,
        n_pts: int = 100,
    ):
        # Plain storage; sampling happens lazily inside the constraints.
        self.sampler, self.value, self.n_pts = sampler, value, n_pts

n_pts = n_pts instance-attribute

sampler = sampler instance-attribute

value = value instance-attribute

__init__(sampler: Callable[[int], Tensor], value: BCValueFn, n_pts: int = 100)

Source code in src/anypinn/problems/pde.py
def __init__(
    self,
    sampler: Callable[[int], Tensor],
    value: BCValueFn,
    n_pts: int = 100,
):
    # Store the sampler/value pair and the per-step sample count.
    self.sampler, self.value, self.n_pts = sampler, value, n_pts

DataConstraint

Bases: Constraint

Constraint enforcing fit to observed data. Minimizes \(\lVert \hat{{y}} - y \rVert^2\).

Parameters:

Name Type Description Default
fields FieldsRegistry

Fields registry.

required
params ParamsRegistry

Parameters registry.

required
predict_data PredictDataFn

Function to predict data values from fields.

required
weight float

Weight for this loss term.

1.0
Source code in src/anypinn/problems/ode.py
class DataConstraint(Constraint):
    """
    Penalizes the mismatch between model predictions and observations.
    Minimizes $\\lVert \\hat{{y}} - y \\rVert^2$.

    Args:
        fields: Fields registry.
        params: Parameters registry.
        predict_data: Function mapping inputs plus registries to predicted data.
        weight: Multiplier applied to this loss term.
    """

    def __init__(
        self,
        fields: FieldsRegistry,
        params: ParamsRegistry,
        predict_data: PredictDataFn,
        weight: float = 1.0,
    ):
        # Keep references to the registries and the data-prediction hook.
        self.fields, self.params = fields, params
        self.predict_data, self.weight = predict_data, weight

    @override
    def loss(
        self,
        batch: TrainingBatch,
        criterion: nn.Module,
        log: LogFn | None = None,
    ) -> Tensor:
        # The batch packs ((x_data, y_data), collocation); only the data
        # half is used by this constraint.
        (x_obs, y_obs), _ = batch
        prediction = self.predict_data(x_obs, self.fields, self.params)
        weighted: Tensor = self.weight * criterion(prediction, y_obs)
        if log is not None:
            log("loss/data", weighted)
        return weighted

fields = fields instance-attribute

params = params instance-attribute

predict_data = predict_data instance-attribute

weight = weight instance-attribute

__init__(fields: FieldsRegistry, params: ParamsRegistry, predict_data: PredictDataFn, weight: float = 1.0)

Source code in src/anypinn/problems/ode.py
def __init__(
    self,
    fields: FieldsRegistry,
    params: ParamsRegistry,
    predict_data: PredictDataFn,
    weight: float = 1.0,
):
    # Keep references to the registries and the data-prediction hook.
    self.fields, self.params = fields, params
    self.predict_data, self.weight = predict_data, weight

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Source code in src/anypinn/problems/ode.py
@override
def loss(
    self,
    batch: TrainingBatch,
    criterion: nn.Module,
    log: LogFn | None = None,
) -> Tensor:
    # Unpack the observed data; the collocation half of the batch is unused.
    (x_obs, y_obs), _ = batch
    prediction = self.predict_data(x_obs, self.fields, self.params)
    weighted: Tensor = self.weight * criterion(prediction, y_obs)
    if log is not None:
        log("loss/data", weighted)
    return weighted

DirichletBCConstraint

Bases: Constraint

Enforces the Dirichlet boundary condition: u(x_bc) = g(x_bc). Minimizes weight * criterion(u(x_bc), g(x_bc)).

Parameters:

Name Type Description Default
bc BoundaryCondition

Boundary condition (sampler + target value function).

required
field Field

The neural field to enforce the condition on.

required
log_key str

Key used when logging the loss value.

'loss/bc_dirichlet'
weight float

Loss term weight.

1.0
Source code in src/anypinn/problems/pde.py
class DirichletBCConstraint(Constraint):
    """
    Imposes a Dirichlet boundary condition u(x_bc) = g(x_bc) on a field.
    Minimizes ``weight * criterion(u(x_bc), g(x_bc))``.

    Args:
        bc: Boundary condition (sampler + target value function).
        field: The neural field the condition is imposed on.
        log_key: Logging key for this loss term.
        weight: Loss term weight.
    """

    def __init__(
        self,
        bc: BoundaryCondition,
        field: Field,
        log_key: str = "loss/bc_dirichlet",
        weight: float = 1.0,
    ):
        # Bind the boundary spec to the field plus logging/weighting options.
        self.bc, self.field = bc, field
        self.log_key, self.weight = log_key, weight

    @override
    def loss(
        self,
        batch: TrainingBatch,
        criterion: nn.Module,
        log: LogFn | None = None,
    ) -> Tensor:
        # Draw fresh boundary samples on the field's device each call.
        device = next(self.field.parameters()).device
        pts = self.bc.sampler(self.bc.n_pts).to(device)
        prediction = self.field(pts)
        target = self.bc.value(pts).to(device)
        total: Tensor = self.weight * criterion(prediction, target)
        if log is not None:
            log(self.log_key, total)
        return total

bc = bc instance-attribute

field = field instance-attribute

log_key = log_key instance-attribute

weight = weight instance-attribute

__init__(bc: BoundaryCondition, field: Field, log_key: str = 'loss/bc_dirichlet', weight: float = 1.0)

Source code in src/anypinn/problems/pde.py
def __init__(
    self,
    bc: BoundaryCondition,
    field: Field,
    log_key: str = "loss/bc_dirichlet",
    weight: float = 1.0,
):
    # Bind the boundary spec to the field plus logging/weighting options.
    self.bc, self.field = bc, field
    self.log_key, self.weight = log_key, weight

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Source code in src/anypinn/problems/pde.py
@override
def loss(
    self,
    batch: TrainingBatch,
    criterion: nn.Module,
    log: LogFn | None = None,
) -> Tensor:
    # Draw fresh boundary samples on the field's device each call.
    device = next(self.field.parameters()).device
    pts = self.bc.sampler(self.bc.n_pts).to(device)
    prediction = self.field(pts)
    target = self.bc.value(pts).to(device)
    total: Tensor = self.weight * criterion(prediction, target)
    if log is not None:
        log(self.log_key, total)
    return total

ICConstraint

Bases: Constraint

Constraint enforcing Initial Conditions (IC). Minimizes \(\lVert y(t_0) - Y_0 \rVert^2\); for ODEs with order ≥ 2, analogous terms additionally match the prescribed initial derivatives in dy0.

Parameters:

Name Type Description Default
fields FieldsRegistry

Fields registry.

required
weight float

Weight for this loss term.

1.0
Source code in src/anypinn/problems/ode.py
class ICConstraint(Constraint):
    """
    Constraint enforcing Initial Conditions (IC).
    Minimizes $\\lVert y(t_0) - Y_0 \\rVert^2$; for ODEs with ``order >= 2``
    the prescribed initial derivatives in ``dy0`` are matched as well.

    Args:
        props: ODE properties supplying ``y0``, ``dy0`` and ``order``.
        fields: Fields registry.
        weight: Weight for this loss term.

    Raises:
        ValueError: If the number of fields differs from ``len(props.y0)``.

    Note:
        ``inject_context`` must be called before ``loss``: it is what defines
        ``self.t0`` (the left endpoint of the time domain).
    """

    def __init__(
        self,
        props: ODEProperties,
        fields: FieldsRegistry,
        weight: float = 1.0,
    ):
        if len(fields) != len(props.y0):
            raise ValueError(
                f"Number of fields ({len(fields)}) must match number of initial conditions "
                f"in y0 ({len(props.y0)}). Field keys: {list(fields)}."
            )

        # Reshape ICs to (n_fields, 1, 1) so they line up with the stacked
        # per-field predictions produced in ``loss``.
        self.Y0 = props.y0.clone().reshape(-1, 1, 1)
        self.dY0 = [dy.clone().reshape(-1, 1, 1) for dy in props.dy0]
        self.order = props.order
        self.fields = fields
        self.weight = weight

    @override
    def inject_context(self, context: InferredContext) -> None:
        """
        Inject the context into the constraint.

        Stores the domain's left endpoint as ``self.t0`` with shape ``(1, 1)``;
        ``loss`` depends on this attribute existing.
        """
        self.t0 = torch.tensor(context.domain.x0, dtype=torch.float32).reshape(1, 1)

    @override
    def loss(
        self,
        batch: TrainingBatch,
        criterion: nn.Module,
        log: LogFn | None = None,
    ) -> Tensor:
        device = batch[1].device

        # Migrate the cached IC tensors to the batch's device on first use
        # (and again if the device ever changes).
        if self.t0.device != device:
            self.t0 = self.t0.to(device)
            self.Y0 = self.Y0.to(device)
            self.dY0 = [d.to(device) for d in self.dY0]

        n_fields = len(self.fields)

        if self.order == 1:
            # Fast path: no requires_grad needed, identical to original behaviour
            Y0_preds = torch.stack([f(self.t0) for f in self.fields.values()])
            loss: Tensor = self.weight * criterion(Y0_preds, self.Y0)
        else:
            # Higher-order ODE: t0 must track gradients so derivative ICs
            # can be obtained via autograd.
            x0 = self.t0.detach().requires_grad_(True)
            preds = [f(x0) for f in self.fields.values()]
            Y0_preds = torch.stack(preds)
            total = criterion(Y0_preds, self.Y0)
            # Enforce derivative ICs by chaining from previous level
            currents = list(preds)
            for k in range(self.order - 1):
                next_level = [diff_grad(currents[i], x0) for i in range(n_fields)]
                dY0_k_pred = torch.stack(next_level)  # (n_fields, 1, 1)
                total = total + criterion(dY0_k_pred, self.dY0[k])
                currents = next_level
            loss = self.weight * total

        if log is not None:
            log("loss/ic", loss)

        return loss

Y0 = props.y0.clone().reshape(-1, 1, 1) instance-attribute

dY0 = [(dy.clone().reshape(-1, 1, 1)) for dy in (props.dy0)] instance-attribute

fields = fields instance-attribute

order = props.order instance-attribute

weight = weight instance-attribute

__init__(props: ODEProperties, fields: FieldsRegistry, weight: float = 1.0)

Source code in src/anypinn/problems/ode.py
def __init__(
    self,
    props: ODEProperties,
    fields: FieldsRegistry,
    weight: float = 1.0,
):
    # One field is required per scalar initial condition.
    n_ics = len(props.y0)
    if n_ics != len(fields):
        raise ValueError(
            f"Number of fields ({len(fields)}) must match number of initial conditions "
            f"in y0 ({len(props.y0)}). Field keys: {list(fields)}."
        )

    # Broadcast-friendly column shape: (n_fields, 1, 1).
    self.Y0 = props.y0.clone().reshape(-1, 1, 1)
    self.dY0 = [d.clone().reshape(-1, 1, 1) for d in props.dy0]
    self.order = props.order
    self.fields = fields
    self.weight = weight

inject_context(context: InferredContext) -> None

Inject the context into the constraint.

Source code in src/anypinn/problems/ode.py
@override
def inject_context(self, context: InferredContext) -> None:
    """
    Inject the context into the constraint.

    Stores the domain's left endpoint as ``self.t0`` with shape ``(1, 1)``;
    ``loss`` depends on this attribute, so this must run before training.
    """
    self.t0 = torch.tensor(context.domain.x0, dtype=torch.float32).reshape(1, 1)

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Source code in src/anypinn/problems/ode.py
@override
def loss(
    self,
    batch: TrainingBatch,
    criterion: nn.Module,
    log: LogFn | None = None,
) -> Tensor:
    # Note: self.t0 is created by inject_context, which must have run first.
    device = batch[1].device

    # Migrate the cached IC tensors to the batch's device on first use
    # (and again if the device ever changes).
    if self.t0.device != device:
        self.t0 = self.t0.to(device)
        self.Y0 = self.Y0.to(device)
        self.dY0 = [d.to(device) for d in self.dY0]

    n_fields = len(self.fields)

    if self.order == 1:
        # Fast path: no requires_grad needed, identical to original behaviour
        Y0_preds = torch.stack([f(self.t0) for f in self.fields.values()])
        loss: Tensor = self.weight * criterion(Y0_preds, self.Y0)
    else:
        # Higher-order ODE: t0 must track gradients so the derivative ICs
        # can be computed via autograd.
        x0 = self.t0.detach().requires_grad_(True)
        preds = [f(x0) for f in self.fields.values()]
        Y0_preds = torch.stack(preds)
        total = criterion(Y0_preds, self.Y0)
        # Enforce derivative ICs by chaining from previous level
        currents = list(preds)
        for k in range(self.order - 1):
            next_level = [diff_grad(currents[i], x0) for i in range(n_fields)]
            dY0_k_pred = torch.stack(next_level)  # (n_fields, 1, 1)
            total = total + criterion(dY0_k_pred, self.dY0[k])
            currents = next_level
        loss = self.weight * total

    if log is not None:
        log("loss/ic", loss)

    return loss

NeumannBCConstraint

Bases: Constraint

Enforces the Neumann boundary condition: \(\partial u / \partial n (x_{bc}) = h(x_{bc})\).

For a rectangular domain face whose outward normal is axis-aligned with dimension normal_dim, we have \(\partial u / \partial n = \partial u / \partial x_{\mathrm{normal\_dim}}\). Minimizes weight * criterion(du_dn(x_bc), h(x_bc)).

Parameters:

Name Type Description Default
bc BoundaryCondition

Boundary condition (sampler + target normal-derivative function).

required
field Field

The neural field to enforce the condition on.

required
normal_dim int

Index of the spatial dimension the boundary normal points along.

required
log_key str

Key used when logging the loss value.

'loss/bc_neumann'
weight float

Loss term weight.

1.0
Source code in src/anypinn/problems/pde.py
class NeumannBCConstraint(Constraint):
    """
    Imposes a Neumann boundary condition:
    $\\partial u / \\partial n (x_{bc}) = h(x_{bc})$.

    On a rectangular-domain face whose outward normal is aligned with axis
    ``normal_dim``, the normal derivative reduces to
    $\\partial u / \\partial x_{\\mathrm{normal\\_dim}}$.
    Minimizes
    ``weight * criterion(du_dn(x_bc), h(x_bc))``.

    Args:
        bc: Boundary condition (sampler + target normal-derivative function).
        field: The neural field the condition is imposed on.
        normal_dim: Index of the spatial dimension the boundary normal points along.
        log_key: Logging key for this loss term.
        weight: Loss term weight.
    """

    def __init__(
        self,
        bc: BoundaryCondition,
        field: Field,
        normal_dim: int,
        log_key: str = "loss/bc_neumann",
        weight: float = 1.0,
    ):
        # Bind the boundary spec, target field, and loss bookkeeping.
        self.bc, self.field = bc, field
        self.normal_dim = normal_dim
        self.log_key, self.weight = log_key, weight

    @override
    def loss(
        self,
        batch: TrainingBatch,
        criterion: nn.Module,
        log: LogFn | None = None,
    ) -> Tensor:
        device = next(self.field.parameters()).device
        # Boundary samples must track gradients for the normal derivative.
        pts = self.bc.sampler(self.bc.n_pts).to(device).detach().requires_grad_(True)
        normal_grad = diff_partial(self.field(pts), pts, dim=self.normal_dim)
        target = self.bc.value(pts.detach()).to(device)
        total: Tensor = self.weight * criterion(normal_grad, target)
        if log is not None:
            log(self.log_key, total)
        return total

bc = bc instance-attribute

field = field instance-attribute

log_key = log_key instance-attribute

normal_dim = normal_dim instance-attribute

weight = weight instance-attribute

__init__(bc: BoundaryCondition, field: Field, normal_dim: int, log_key: str = 'loss/bc_neumann', weight: float = 1.0)

Source code in src/anypinn/problems/pde.py
def __init__(
    self,
    bc: BoundaryCondition,
    field: Field,
    normal_dim: int,
    log_key: str = "loss/bc_neumann",
    weight: float = 1.0,
):
    # Bind the boundary spec, target field, and loss bookkeeping.
    self.bc, self.field = bc, field
    self.normal_dim = normal_dim
    self.log_key, self.weight = log_key, weight

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Source code in src/anypinn/problems/pde.py
@override
def loss(
    self,
    batch: TrainingBatch,
    criterion: nn.Module,
    log: LogFn | None = None,
) -> Tensor:
    device = next(self.field.parameters()).device
    # Boundary samples must track gradients for the normal derivative.
    pts = self.bc.sampler(self.bc.n_pts).to(device).detach().requires_grad_(True)
    normal_grad = diff_partial(self.field(pts), pts, dim=self.normal_dim)
    target = self.bc.value(pts.detach()).to(device)
    total: Tensor = self.weight * criterion(normal_grad, target)
    if log is not None:
        log(self.log_key, total)
    return total

ODECallable

Bases: Protocol

Protocol for ODE right-hand side callables.

First-order callables (ODEProperties.order == 1) receive three positional arguments and must match this Protocol exactly::

def my_ode(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor: ...

Higher-order callables (order >= 2) receive a fourth positional argument derivs: list[Tensor], where derivs[k] is the (k+1)-th derivative of all fields stacked as (n_fields, m, 1)::

def my_ode(x: Tensor, y: Tensor, args: ArgsRegistry,
           derivs: list[Tensor] = []) -> Tensor: ...

The Protocol is intentionally kept to three arguments so that existing first-order callables remain valid ODECallable implementations. ResidualsConstraint uses _ODECallableN internally to call higher-order functions with the correct signature.

Source code in src/anypinn/problems/ode.py
class ODECallable(Protocol):
    """
    Protocol for ODE right-hand side callables.

    **First-order** callables (``ODEProperties.order == 1``) receive three
    positional arguments and must match this Protocol exactly::

        def my_ode(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor: ...

    **Higher-order** callables (``order >= 2``) receive a fourth positional
    argument ``derivs: list[Tensor]``, where ``derivs[k]`` is the
    ``(k+1)``-th derivative of all fields stacked as ``(n_fields, m, 1)``::

        def my_ode(x: Tensor, y: Tensor, args: ArgsRegistry,
                   derivs: list[Tensor] = []) -> Tensor: ...

    The Protocol is intentionally kept to three arguments so that existing
    first-order callables remain valid ``ODECallable`` implementations.
    ``ResidualsConstraint`` uses ``_ODECallableN`` internally to call
    higher-order functions with the correct signature.
    """

    # Signature-only member: any plain function with this call shape
    # satisfies the Protocol structurally — no inheritance required.
    def __call__(self, x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor: ...

__call__(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor

Source code in src/anypinn/problems/ode.py
# Signature-only Protocol member: defines the expected call shape only.
def __call__(self, x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor: ...

ODEHyperparameters dataclass

Bases: PINNHyperparameters

Hyperparameters for ODE inverse problems.

Source code in src/anypinn/problems/ode.py
@dataclass(kw_only=True)
class ODEHyperparameters(PINNHyperparameters):
    """
    Hyperparameters for ODE inverse problems.

    Extends ``PINNHyperparameters`` with per-constraint loss weights.
    """

    # Relative weights of the three loss terms composed by ODEInverseProblem:
    # residuals, initial conditions, and data fit, respectively.
    pde_weight: float = 1.0
    ic_weight: float = 1.0
    data_weight: float = 1.0

data_weight: float = 1.0 class-attribute instance-attribute

ic_weight: float = 1.0 class-attribute instance-attribute

pde_weight: float = 1.0 class-attribute instance-attribute

__init__(*, lr: float, training_data: IngestionConfig | GenerationConfig, fields_config: MLPConfig, params_config: MLPConfig | ScalarConfig, max_epochs: int | None = None, gradient_clip_val: float | None = None, criterion: Criteria = 'mse', optimizer: AdamConfig | LBFGSConfig | None = None, scheduler: ReduceLROnPlateauConfig | CosineAnnealingConfig | None = None, early_stopping: EarlyStoppingConfig | None = None, smma_stopping: SMMAStoppingConfig | None = None, pde_weight: float = 1.0, ic_weight: float = 1.0, data_weight: float = 1.0) -> None

ODEInverseProblem

Bases: Problem

Generic ODE Inverse Problem. Composes Residuals + IC + Data constraints with the criterion selected by hp.criterion (MSE by default).

Source code in src/anypinn/problems/ode.py
class ODEInverseProblem(Problem):
    """
    Generic ODE Inverse Problem.
    Composes Residuals + IC + Data constraints with the criterion selected
    by ``hp.criterion``.
    """

    def __init__(
        self,
        props: ODEProperties,
        hp: ODEHyperparameters,
        fields: FieldsRegistry,
        params: ParamsRegistry,
        predict_data: PredictDataFn,
    ) -> None:
        # One constraint per loss term, each weighted via the hyperparameters.
        residuals = ResidualsConstraint(
            props=props,
            fields=fields,
            params=params,
            weight=hp.pde_weight,
        )
        ic = ICConstraint(props=props, fields=fields, weight=hp.ic_weight)
        data = DataConstraint(
            fields=fields,
            params=params,
            predict_data=predict_data,
            weight=hp.data_weight,
        )

        super().__init__(
            constraints=[residuals, ic, data],
            criterion=build_criterion(hp.criterion),
            fields=fields,
            params=params,
        )

__init__(props: ODEProperties, hp: ODEHyperparameters, fields: FieldsRegistry, params: ParamsRegistry, predict_data: PredictDataFn) -> None

Source code in src/anypinn/problems/ode.py
def __init__(
    self,
    props: ODEProperties,
    hp: ODEHyperparameters,
    fields: FieldsRegistry,
    params: ParamsRegistry,
    predict_data: PredictDataFn,
) -> None:
    # One constraint per loss term, each weighted via the hyperparameters.
    residuals = ResidualsConstraint(
        props=props,
        fields=fields,
        params=params,
        weight=hp.pde_weight,
    )
    ic = ICConstraint(props=props, fields=fields, weight=hp.ic_weight)
    data = DataConstraint(
        fields=fields,
        params=params,
        predict_data=predict_data,
        weight=hp.data_weight,
    )

    super().__init__(
        constraints=[residuals, ic, data],
        criterion=build_criterion(hp.criterion),
        fields=fields,
        params=params,
    )

ODEProperties dataclass

Properties defining an Ordinary Differential Equation problem.

Attributes:

Name Type Description
ode ODECallable

The ODE function (callable).

args ArgsRegistry

Arguments/Parameters for the ODE.

y0 Tensor

Initial conditions.

order int

Order of the ODE (default 1). For order=n, the ODE callable receives derivs as its last argument: derivs[k] is the (k+1)-th derivative.

dy0 list[Tensor]

Initial conditions for lower-order derivatives, length = order-1. dy0[k] is the IC for the (k+1)-th derivative, shape (n_fields,).

expected_args frozenset[str] | None

Optional set of arg keys the ODE function accesses. When provided, validated against the merged args+params at construction time.

Source code in src/anypinn/problems/ode.py
@dataclass
class ODEProperties:
    """
    Properties defining an Ordinary Differential Equation problem.

    Attributes:
        ode: The right-hand-side callable of the ODE.
        args: Arguments/Parameters for the ODE.
        y0: Initial conditions, one entry per field.
        order: Order of the ODE (default 1). For order=n, the ODE callable
            receives derivs as its last argument: derivs[k] is the (k+1)-th
            derivative.
        dy0: Initial conditions for lower-order derivatives, length = order-1.
            dy0[k] is the IC for the (k+1)-th derivative, shape (n_fields,).
        expected_args: Optional set of arg keys the ODE function accesses.
            When provided, validated against the merged args+params at
            construction time.
    """

    ode: ODECallable
    args: ArgsRegistry
    y0: Tensor
    order: int = 1
    dy0: list[Tensor] = dc_field(default_factory=list)
    expected_args: frozenset[str] | None = None

    def __post_init__(self) -> None:
        # Sanity-check the order / derivative-IC pairing up front.
        if self.order < 1:
            raise ValueError(f"order must be >= 1, got {self.order}")
        expected = self.order - 1
        if len(self.dy0) != expected:
            raise ValueError(f"dy0 must have length order-1={expected}, got {len(self.dy0)}")

args: ArgsRegistry instance-attribute

dy0: list[Tensor] = dc_field(default_factory=list) class-attribute instance-attribute

expected_args: frozenset[str] | None = None class-attribute instance-attribute

ode: ODECallable instance-attribute

order: int = 1 class-attribute instance-attribute

y0: Tensor instance-attribute

__init__(ode: ODECallable, args: ArgsRegistry, y0: Tensor, order: int = 1, dy0: list[Tensor] = list(), expected_args: frozenset[str] | None = None) -> None

__post_init__() -> None

Source code in src/anypinn/problems/ode.py
def __post_init__(self) -> None:
    # Sanity-check the order / derivative-IC pairing up front.
    if self.order < 1:
        raise ValueError(f"order must be >= 1, got {self.order}")
    expected = self.order - 1
    if len(self.dy0) != expected:
        raise ValueError(f"dy0 must have length order-1={expected}, got {len(self.dy0)}")

PDEResidualConstraint

Bases: Constraint

Enforces a PDE interior residual: residual_fn(x, fields, params) ≈ 0. Minimizes weight * criterion(residual_fn(x_coll, fields, params), 0).

Parameters:

Name Type Description Default
fields FieldsRegistry

Registry of neural fields the residual function operates on. Pass only the subset needed — other fields in the Problem are ignored.

required
params ParamsRegistry

Registry of parameters the residual function uses.

required
residual_fn PDEResidualFn

Callable (x, fields, params) → Tensor of residuals. Should use anypinn.lib.diff operators for derivatives. The returned tensor is compared against zeros.

required
log_key str

Key used when logging the loss value.

'loss/pde_residual'
weight float

Loss term weight.

1.0
Source code in src/anypinn/problems/pde.py
class PDEResidualConstraint(Constraint):
    """
    Drives a PDE interior residual towards zero:
    ``residual_fn(x, fields, params) ≈ 0``.
    Minimizes ``weight * criterion(residual_fn(x_coll, fields, params), 0)``.

    Args:
        fields: Registry of neural fields the residual function operates on.
            Pass only the subset needed — other fields in the Problem are ignored.
        params: Registry of parameters the residual function uses.
        residual_fn: Callable (x, fields, params) → Tensor of residuals.
            Should use ``anypinn.lib.diff`` operators for derivatives.
            The returned tensor is compared against zeros.
        log_key: Logging key for this loss term.
        weight: Loss term weight.
    """

    def __init__(
        self,
        fields: FieldsRegistry,
        params: ParamsRegistry,
        residual_fn: PDEResidualFn,
        log_key: str = "loss/pde_residual",
        weight: float = 1.0,
    ):
        # Bind registries, the residual callable, and loss bookkeeping.
        self.fields, self.params = fields, params
        self.residual_fn = residual_fn
        self.log_key, self.weight = log_key, weight

    @override
    def loss(
        self,
        batch: TrainingBatch,
        criterion: nn.Module,
        log: LogFn | None = None,
    ) -> Tensor:
        # Collocation points are the second element of the batch; they need
        # grad tracking so residual_fn can differentiate through the fields.
        x_coll = batch[1].detach().requires_grad_(True)
        res = self.residual_fn(x_coll, self.fields, self.params)
        total: Tensor = self.weight * criterion(res, torch.zeros_like(res))
        if log is not None:
            log(self.log_key, total)
        return total

fields = fields instance-attribute

log_key = log_key instance-attribute

params = params instance-attribute

residual_fn = residual_fn instance-attribute

weight = weight instance-attribute

__init__(fields: FieldsRegistry, params: ParamsRegistry, residual_fn: PDEResidualFn, log_key: str = 'loss/pde_residual', weight: float = 1.0)

Source code in src/anypinn/problems/pde.py
def __init__(
    self,
    fields: FieldsRegistry,
    params: ParamsRegistry,
    residual_fn: PDEResidualFn,
    log_key: str = "loss/pde_residual",
    weight: float = 1.0,
):
    # Bind registries, the residual callable, and loss bookkeeping.
    self.fields, self.params = fields, params
    self.residual_fn = residual_fn
    self.log_key, self.weight = log_key, weight

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Source code in src/anypinn/problems/pde.py
@override
def loss(
    self,
    batch: TrainingBatch,
    criterion: nn.Module,
    log: LogFn | None = None,
) -> Tensor:
    # Collocation points are the second element of the batch; they need
    # grad tracking so residual_fn can differentiate through the fields.
    x_coll = batch[1].detach().requires_grad_(True)
    res = self.residual_fn(x_coll, self.fields, self.params)
    total: Tensor = self.weight * criterion(res, torch.zeros_like(res))
    if log is not None:
        log(self.log_key, total)
    return total

PeriodicBCConstraint

Bases: Constraint

Enforces periodic boundary conditions: u(x_left, t) = u(x_right, t) and ∂u/∂x(x_left, t) = ∂u/∂x(x_right, t).

The two boundary samplers must produce paired points — identical coordinates in every dimension except the periodic one — so that the value- and derivative-matching losses are meaningful.

Minimizes weight * [criterion(u_left, u_right) + criterion(du_left, du_right)].

Parameters:

Name Type Description Default
bc_left BoundaryCondition

Left boundary sampler (sampler + dummy value function).

required
bc_right BoundaryCondition

Right boundary sampler (sampler + dummy value function).

required
field Field

The neural field to enforce the condition on.

required
match_dim int

Spatial dimension index for the derivative matching.

0
log_key str

Key used when logging the loss value.

'loss/bc_periodic'
weight float

Loss term weight.

1.0
Source code in src/anypinn/problems/pde.py
class PeriodicBCConstraint(Constraint):
    """
    Imposes periodic boundary conditions:
    ``u(x_left, t) = u(x_right, t)`` and
    ``∂u/∂x(x_left, t) = ∂u/∂x(x_right, t)``.

    The two samplers must return **paired** points — identical coordinates
    in every dimension except the periodic one — otherwise matching values
    and derivatives point-by-point is meaningless.

    Minimizes
    ``weight * [criterion(u_left, u_right) + criterion(du_left, du_right)]``.

    Args:
        bc_left: Left boundary sampler (sampler + dummy value function).
        bc_right: Right boundary sampler (sampler + dummy value function).
        field: The neural field the condition is imposed on.
        match_dim: Spatial dimension index for the derivative matching.
        log_key: Logging key for this loss term.
        weight: Loss term weight.
    """

    def __init__(
        self,
        bc_left: BoundaryCondition,
        bc_right: BoundaryCondition,
        field: Field,
        match_dim: int = 0,
        log_key: str = "loss/bc_periodic",
        weight: float = 1.0,
    ):
        # Bind the paired boundary specs, field, and loss bookkeeping.
        self.bc_left, self.bc_right = bc_left, bc_right
        self.field = field
        self.match_dim = match_dim
        self.log_key, self.weight = log_key, weight

    @override
    def loss(
        self,
        batch: TrainingBatch,
        criterion: nn.Module,
        log: LogFn | None = None,
    ) -> Tensor:
        device = next(self.field.parameters()).device
        count = self.bc_left.n_pts

        # Paired samples from both faces, grad-enabled for the derivatives.
        pts_l = self.bc_left.sampler(count).to(device).detach().requires_grad_(True)
        pts_r = self.bc_right.sampler(count).to(device).detach().requires_grad_(True)

        u_l = self.field(pts_l)
        u_r = self.field(pts_r)

        # u must agree across the periodic faces...
        value_term: Tensor = criterion(u_l, u_r)

        # ...and so must its derivative along ``match_dim``.
        deriv_term: Tensor = criterion(
            diff_partial(u_l, pts_l, dim=self.match_dim),
            diff_partial(u_r, pts_r, dim=self.match_dim),
        )

        total: Tensor = self.weight * (value_term + deriv_term)
        if log is not None:
            log(self.log_key, total)
        return total

bc_left = bc_left instance-attribute

bc_right = bc_right instance-attribute

field = field instance-attribute

log_key = log_key instance-attribute

match_dim = match_dim instance-attribute

weight = weight instance-attribute

__init__(bc_left: BoundaryCondition, bc_right: BoundaryCondition, field: Field, match_dim: int = 0, log_key: str = 'loss/bc_periodic', weight: float = 1.0)

Source code in src/anypinn/problems/pde.py
def __init__(
    self,
    bc_left: BoundaryCondition,
    bc_right: BoundaryCondition,
    field: Field,
    match_dim: int = 0,
    log_key: str = "loss/bc_periodic",
    weight: float = 1.0,
):
    """Store the two boundary regions and loss configuration.

    ``bc_left``/``bc_right`` sample the paired periodic edges, ``field`` is
    the network matched across them, ``match_dim`` selects the coordinate
    for derivative matching, and ``weight`` scales the resulting loss.
    """
    self.bc_left = bc_left
    self.bc_right = bc_right
    self.field = field
    self.match_dim = match_dim
    self.log_key = log_key
    self.weight = weight

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Source code in src/anypinn/problems/pde.py
@override
def loss(
    self,
    batch: TrainingBatch,
    criterion: nn.Module,
    log: LogFn | None = None,
) -> Tensor:
    """Return the weighted periodic BC loss (value + derivative mismatch)."""
    device = next(self.field.parameters()).device
    # Both edges are sampled with bc_left's point count so the tensors
    # line up elementwise in the criterion.
    n_pts = self.bc_left.n_pts

    # Fresh leaves: detach from any sampler graph, then enable autograd
    # so diff_partial can differentiate the field w.r.t. the coordinates.
    x_left = self.bc_left.sampler(n_pts).to(device).detach().requires_grad_(True)
    x_right = self.bc_right.sampler(n_pts).to(device).detach().requires_grad_(True)

    u_left = self.field(x_left)
    u_right = self.field(x_right)

    # Value matching: u(x_left, t) = u(x_right, t)
    loss_val: Tensor = criterion(u_left, u_right)

    # Derivative matching: du/dx(x_left, t) = du/dx(x_right, t)
    du_left = diff_partial(u_left, x_left, dim=self.match_dim)
    du_right = diff_partial(u_right, x_right, dim=self.match_dim)
    loss_deriv: Tensor = criterion(du_left, du_right)

    loss: Tensor = self.weight * (loss_val + loss_deriv)
    if log is not None:
        log(self.log_key, loss)
    return loss

ResidualsConstraint

Bases: Constraint

Constraint enforcing the ODE residuals. Minimizes \(\lVert \partial y / \partial t - f(t, y) \rVert^2\).

Parameters:

Name Type Description Default
props ODEProperties

ODE properties.

required
fields FieldsRegistry

List of fields.

required
params ParamsRegistry

List of parameters.

required
weight float

Weight for this loss term.

1.0
Source code in src/anypinn/problems/ode.py
class ResidualsConstraint(Constraint):
    """
    Constraint enforcing the ODE residuals.
    Minimizes $\\lVert \\partial y / \\partial t - f(t, y) \\rVert^2$.

    Args:
        props: ODE properties (callable, order, initial conditions, fixed args).
        fields: Registry of fields, one per ODE state component; its length
            must match ``len(props.y0)``.
        params: Registry of trainable parameters, merged into the args passed
            to the ODE callable.
        weight: Weight for this loss term.
    """

    def __init__(
        self,
        props: ODEProperties,
        fields: FieldsRegistry,
        params: ParamsRegistry,
        weight: float = 1.0,
    ):
        # One field per state component: fail fast on mismatch.
        if len(fields) != len(props.y0):
            raise ValueError(
                f"Number of fields ({len(fields)}) must match number of initial conditions "
                f"in y0 ({len(props.y0)}). Field keys: {list(fields)}."
            )

        # Validate up front that the ODE callable will find every arg it
        # declares, either in the fixed args or in the trainable params.
        merged_args: dict[str, object] = {**props.args, **params}
        if props.expected_args is not None:
            missing = props.expected_args - merged_args.keys()
            if missing:
                raise ValueError(
                    f"ODE function expects args {sorted(missing)!r} but they are not in "
                    f"props.args or params. Available keys: {sorted(merged_args.keys())!r}."
                )

        self.fields = fields
        self.weight = weight
        self.order = props.order

        self.ode = props.ode

        # add the trainable params as args
        # (copy so props.args itself is never mutated; params override on clash)
        self.args = props.args.copy()
        self.args.update(params)

    @override
    def loss(
        self,
        batch: TrainingBatch,
        criterion: nn.Module,
        log: LogFn | None = None,
    ) -> Tensor:
        """
        Compute the ODE residual loss on the batch's collocation points.

        Args:
            batch: ``(data, collocation)`` pair; only the collocation points
                are used here.
            criterion: Reduction module comparing the order-th derivative to
                the ODE right-hand side.
            log: Optional logging callback; receives ``("loss/res", loss)``.

        Returns:
            Weighted residual loss.
        """
        _, x_coll = batch

        # One grad-enabled copy of the collocation points per field, so each
        # field's autograd graph is differentiated w.r.t. its own leaf.
        n_fields = len(self.fields)
        x_copies = [x_coll.detach().clone().requires_grad_(True) for _ in range(n_fields)]
        preds = [f(x_copies[i]) for i, f in enumerate(self.fields.values())]
        y = torch.stack(preds)

        # Build all derivative levels by chaining (each level differentiates the previous)
        # deriv_levels[k][i] = (k+1)-th derivative of field i
        deriv_levels: list[list[Tensor]] = []
        currents = list(preds)
        for _ in range(self.order):
            next_level = [diff_grad(currents[i], x_copies[i]) for i in range(n_fields)]
            deriv_levels.append(next_level)
            currents = next_level

        # derivs[k] = (k+1)-th derivative stacked across fields, passed to the ODE callable
        # (empty for a first-order ODE, where no lower derivatives exist).
        derivs = [torch.stack(deriv_levels[k]) for k in range(self.order - 1)]
        # The order-th derivative is the LHS to compare against f_out
        high_deriv = torch.stack(deriv_levels[self.order - 1])

        # The RHS is evaluated on the original x_coll (values only; no grads
        # through the coordinates are needed on this side).
        if self.order == 1:
            f_out = self.ode(x_coll, y, self.args)
        else:
            # Higher-order callables additionally receive the lower derivatives;
            # cast narrows the union type for the extra-argument signature.
            f_out = cast(_ODECallableN, self.ode)(x_coll, y, self.args, derivs)

        loss: Tensor = self.weight * criterion(high_deriv, f_out)

        if log is not None:
            log("loss/res", loss)

        return loss

args = props.args.copy() instance-attribute

fields = fields instance-attribute

ode = props.ode instance-attribute

order = props.order instance-attribute

weight = weight instance-attribute

__init__(props: ODEProperties, fields: FieldsRegistry, params: ParamsRegistry, weight: float = 1.0)

Source code in src/anypinn/problems/ode.py
def __init__(
    self,
    props: ODEProperties,
    fields: FieldsRegistry,
    params: ParamsRegistry,
    weight: float = 1.0,
):
    """Validate the field/args configuration and store the ODE setup."""
    # One field per state component: fail fast on mismatch.
    if len(fields) != len(props.y0):
        raise ValueError(
            f"Number of fields ({len(fields)}) must match number of initial conditions "
            f"in y0 ({len(props.y0)}). Field keys: {list(fields)}."
        )

    # Check up front that every arg the ODE callable declares is supplied
    # by the fixed args or the trainable params.
    merged_args: dict[str, object] = {**props.args, **params}
    if props.expected_args is not None:
        missing = props.expected_args - merged_args.keys()
        if missing:
            raise ValueError(
                f"ODE function expects args {sorted(missing)!r} but they are not in "
                f"props.args or params. Available keys: {sorted(merged_args.keys())!r}."
            )

    self.fields = fields
    self.weight = weight
    self.order = props.order

    self.ode = props.ode

    # add the trainable params as args
    # (copy so props.args is left untouched; params override on key clash)
    self.args = props.args.copy()
    self.args.update(params)

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Source code in src/anypinn/problems/ode.py
@override
def loss(
    self,
    batch: TrainingBatch,
    criterion: nn.Module,
    log: LogFn | None = None,
) -> Tensor:
    """Return the weighted ODE residual loss on the collocation points."""
    _, x_coll = batch

    # One grad-enabled copy of the collocation points per field, so each
    # field is differentiated w.r.t. its own leaf tensor.
    n_fields = len(self.fields)
    x_copies = [x_coll.detach().clone().requires_grad_(True) for _ in range(n_fields)]
    preds = [f(x_copies[i]) for i, f in enumerate(self.fields.values())]
    y = torch.stack(preds)

    # Build all derivative levels by chaining (each level differentiates the previous)
    # deriv_levels[k][i] = (k+1)-th derivative of field i
    deriv_levels: list[list[Tensor]] = []
    currents = list(preds)
    for _ in range(self.order):
        next_level = [diff_grad(currents[i], x_copies[i]) for i in range(n_fields)]
        deriv_levels.append(next_level)
        currents = next_level

    # derivs[k] = (k+1)-th derivative stacked across fields, passed to the ODE callable
    # (empty when order == 1).
    derivs = [torch.stack(deriv_levels[k]) for k in range(self.order - 1)]
    # The order-th derivative is the LHS to compare against f_out
    high_deriv = torch.stack(deriv_levels[self.order - 1])

    # RHS uses the original x_coll: only values are needed on this side.
    if self.order == 1:
        f_out = self.ode(x_coll, y, self.args)
    else:
        # Higher-order callables also receive the lower derivatives.
        f_out = cast(_ODECallableN, self.ode)(x_coll, y, self.args, derivs)

    loss: Tensor = self.weight * criterion(high_deriv, f_out)

    if log is not None:
        log("loss/res", loss)

    return loss