Skip to content

anypinn.problems.ode

PredictDataFn: TypeAlias = Callable[[Tensor, FieldsRegistry, ParamsRegistry], Tensor] module-attribute

DataConstraint

Bases: Constraint

Constraint enforcing fit to observed data. Minimizes \(\lVert \hat{y} - y \rVert^2\).

Parameters:

Name Type Description Default
fields FieldsRegistry

Fields registry.

required
params ParamsRegistry

Parameters registry.

required
predict_data PredictDataFn

Function to predict data values from fields.

required
weight float

Weight for this loss term.

1.0
Source code in src/anypinn/problems/ode.py
class DataConstraint(Constraint):
    """
    Constraint fitting the model to observed data.
    Minimizes $\\lVert \\hat{y} - y \\rVert^2$.

    Args:
        fields: Fields registry.
        params: Parameters registry.
        predict_data: Function to predict data values from fields.
        weight: Weight for this loss term.
    """

    def __init__(
        self,
        fields: FieldsRegistry,
        params: ParamsRegistry,
        predict_data: PredictDataFn,
        weight: float = 1.0,
    ):
        # Plain storage; no validation is required for this constraint.
        self.weight = weight
        self.predict_data = predict_data
        self.params = params
        self.fields = fields

    @override
    def loss(
        self,
        batch: TrainingBatch,
        criterion: nn.Module,
        log: LogFn | None = None,
    ) -> Tensor:
        # Only the observed (x, y) pair matters here; collocation points are ignored.
        (x_obs, y_obs), _ = batch

        y_pred = self.predict_data(x_obs, self.fields, self.params)
        data_loss: Tensor = self.weight * criterion(y_pred, y_obs)

        if log is not None:
            log("loss/data", data_loss)

        return data_loss

fields = fields instance-attribute

params = params instance-attribute

predict_data = predict_data instance-attribute

weight = weight instance-attribute

__init__(fields: FieldsRegistry, params: ParamsRegistry, predict_data: PredictDataFn, weight: float = 1.0)

Source code in src/anypinn/problems/ode.py
def __init__(
    self,
    fields: FieldsRegistry,
    params: ParamsRegistry,
    predict_data: PredictDataFn,
    weight: float = 1.0,
):
    """Store the registries and the data-prediction hook."""
    # Assignments are independent; no validation is needed here.
    self.weight = weight
    self.predict_data = predict_data
    self.params = params
    self.fields = fields

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Source code in src/anypinn/problems/ode.py
@override
def loss(
    self,
    batch: TrainingBatch,
    criterion: nn.Module,
    log: LogFn | None = None,
) -> Tensor:
    """Return the weighted data-fit loss for the observed half of the batch."""
    # Collocation points (second element) are not used by this constraint.
    (x_obs, y_obs), _ = batch

    y_pred = self.predict_data(x_obs, self.fields, self.params)
    result: Tensor = self.weight * criterion(y_pred, y_obs)

    if log is not None:
        log("loss/data", result)

    return result

ICConstraint

Bases: Constraint

Constraint enforcing Initial Conditions (IC). Minimizes \(\lVert y(t_0) - Y_0 \rVert^2\).

Parameters:

Name Type Description Default
props ODEProperties

ODE properties supplying y0, dy0 and the ODE order.

required
fields FieldsRegistry

Fields registry.

required
weight float

Weight for this loss term.

1.0
Source code in src/anypinn/problems/ode.py
class ICConstraint(Constraint):
    """
    Constraint enforcing Initial Conditions (IC).
    Minimizes $\\lVert y(t_0) - Y_0 \\rVert^2$, plus derivative ICs for order >= 2.

    Args:
        props: ODE properties supplying y0, dy0 and the ODE order.
        fields: Fields registry.
        weight: Weight for this loss term.

    Raises:
        ValueError: If the number of fields does not match ``len(props.y0)``.

    Note:
        ``inject_context`` must be called before ``loss`` so that the initial
        time ``t0`` is known.
    """

    def __init__(
        self,
        props: ODEProperties,
        fields: FieldsRegistry,
        weight: float = 1.0,
    ):
        if len(fields) != len(props.y0):
            raise ValueError(
                f"Number of fields ({len(fields)}) must match number of initial conditions "
                f"in y0 ({len(props.y0)}). Field keys: {list(fields)}."
            )

        # Reshape to (n_fields, 1, 1) so stacked field outputs compare elementwise.
        self.Y0 = props.y0.clone().reshape(-1, 1, 1)
        self.dY0 = [dy.clone().reshape(-1, 1, 1) for dy in props.dy0]
        self.order = props.order
        self.fields = fields
        self.weight = weight

    @override
    def inject_context(self, context: InferredContext) -> None:
        """
        Inject the context into the constraint.

        Caches the left domain edge ``context.domain.x0`` as a float32
        tensor of shape (1, 1) — the initial time used by ``loss``.
        """
        self.t0 = torch.tensor(context.domain.x0, dtype=torch.float32).reshape(1, 1)

    @override
    def loss(
        self,
        batch: TrainingBatch,
        criterion: nn.Module,
        log: LogFn | None = None,
    ) -> Tensor:
        # Fail loudly (instead of an opaque AttributeError) if setup was skipped.
        if not hasattr(self, "t0"):
            raise RuntimeError(
                "ICConstraint.loss called before inject_context; t0 is unknown."
            )

        # batch[1] is the collocation tensor; use it only to discover the device.
        device = batch[1].device

        # Lazily move cached IC tensors on the first call / after a device change.
        if self.t0.device != device:
            self.t0 = self.t0.to(device)
            self.Y0 = self.Y0.to(device)
            self.dY0 = [d.to(device) for d in self.dY0]

        n_fields = len(self.fields)

        if self.order == 1:
            # Fast path: no requires_grad needed, identical to original behaviour
            Y0_preds = torch.stack([f(self.t0) for f in self.fields.values()])
            loss: Tensor = self.weight * criterion(Y0_preds, self.Y0)
        else:
            x0 = self.t0.detach().requires_grad_(True)
            preds = [f(x0) for f in self.fields.values()]
            Y0_preds = torch.stack(preds)
            total = criterion(Y0_preds, self.Y0)
            # Enforce derivative ICs by chaining from previous level:
            # differentiating level k yields the (k+1)-th derivative predictions.
            currents = list(preds)
            for k in range(self.order - 1):
                next_level = [diff_grad(currents[i], x0) for i in range(n_fields)]
                dY0_k_pred = torch.stack(next_level)  # (n_fields, 1, 1)
                total = total + criterion(dY0_k_pred, self.dY0[k])
                currents = next_level
            loss = self.weight * total

        if log is not None:
            log("loss/ic", loss)

        return loss

Y0 = props.y0.clone().reshape(-1, 1, 1) instance-attribute

dY0 = [(dy.clone().reshape(-1, 1, 1)) for dy in (props.dy0)] instance-attribute

fields = fields instance-attribute

order = props.order instance-attribute

weight = weight instance-attribute

__init__(props: ODEProperties, fields: FieldsRegistry, weight: float = 1.0)

Source code in src/anypinn/problems/ode.py
def __init__(
    self,
    props: ODEProperties,
    fields: FieldsRegistry,
    weight: float = 1.0,
):
    """Validate and cache the initial-condition targets from ``props``."""
    # Exactly one field is required per entry of y0.
    if len(fields) != len(props.y0):
        raise ValueError(
            f"Number of fields ({len(fields)}) must match number of initial conditions "
            f"in y0 ({len(props.y0)}). Field keys: {list(fields)}."
        )

    self.weight = weight
    self.fields = fields
    self.order = props.order
    # Shape (n_fields, 1, 1) so stacked predictions compare elementwise.
    self.Y0 = props.y0.clone().reshape(-1, 1, 1)
    self.dY0 = [dy0_k.clone().reshape(-1, 1, 1) for dy0_k in props.dy0]

inject_context(context: InferredContext) -> None

Inject the context into the constraint.

Source code in src/anypinn/problems/ode.py
@override
def inject_context(self, context: InferredContext) -> None:
    """
    Inject the context into the constraint.

    Caches the left domain edge ``context.domain.x0`` as a float32 tensor
    of shape (1, 1); ``loss`` evaluates the fields at this initial time.
    """
    self.t0 = torch.tensor(context.domain.x0, dtype=torch.float32).reshape(1, 1)

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Source code in src/anypinn/problems/ode.py
@override
def loss(
    self,
    batch: TrainingBatch,
    criterion: nn.Module,
    log: LogFn | None = None,
) -> Tensor:
    """Return the weighted initial-condition loss (value ICs, plus derivative ICs for order >= 2)."""
    # batch[1] is the collocation tensor; it is used only to discover the device.
    device = batch[1].device

    # Lazily move the cached IC tensors on the first call / after a device change.
    if self.t0.device != device:
        self.t0 = self.t0.to(device)
        self.Y0 = self.Y0.to(device)
        self.dY0 = [d.to(device) for d in self.dY0]

    n_fields = len(self.fields)

    if self.order == 1:
        # Fast path: no requires_grad needed, identical to original behaviour
        Y0_preds = torch.stack([f(self.t0) for f in self.fields.values()])
        loss: Tensor = self.weight * criterion(Y0_preds, self.Y0)
    else:
        # Fresh leaf tensor so autograd can differentiate field outputs w.r.t. t0.
        x0 = self.t0.detach().requires_grad_(True)
        preds = [f(x0) for f in self.fields.values()]
        Y0_preds = torch.stack(preds)
        total = criterion(Y0_preds, self.Y0)
        # Enforce derivative ICs by chaining from previous level
        # (differentiating level k yields the (k+1)-th derivative predictions).
        currents = list(preds)
        for k in range(self.order - 1):
            next_level = [diff_grad(currents[i], x0) for i in range(n_fields)]
            dY0_k_pred = torch.stack(next_level)  # (n_fields, 1, 1)
            total = total + criterion(dY0_k_pred, self.dY0[k])
            currents = next_level
        loss = self.weight * total

    if log is not None:
        log("loss/ic", loss)

    return loss

ODECallable

Bases: Protocol

Protocol for ODE right-hand side callables.

First-order callables (ODEProperties.order == 1) receive three positional arguments and must match this Protocol exactly::

def my_ode(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor: ...

Higher-order callables (order >= 2) receive a fourth positional argument derivs: list[Tensor], where derivs[k] is the (k+1)-th derivative of all fields stacked as (n_fields, m, 1)::

def my_ode(x: Tensor, y: Tensor, args: ArgsRegistry,
           derivs: list[Tensor] = []) -> Tensor: ...

The Protocol is intentionally kept to three arguments so that existing first-order callables remain valid ODECallable implementations. ResidualsConstraint uses _ODECallableN internally to call higher-order functions with the correct signature.

Source code in src/anypinn/problems/ode.py
class ODECallable(Protocol):
    """
    Protocol for ODE right-hand side callables.

    **First-order** callables (``ODEProperties.order == 1``) receive three
    positional arguments and must match this Protocol exactly::

        def my_ode(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor: ...

    **Higher-order** callables (``order >= 2``) receive a fourth positional
    argument ``derivs: list[Tensor]``, where ``derivs[k]`` is the
    ``(k+1)``-th derivative of all fields stacked as ``(n_fields, m, 1)``::

        def my_ode(x: Tensor, y: Tensor, args: ArgsRegistry,
                   derivs: list[Tensor] = []) -> Tensor: ...

    The Protocol is intentionally kept to three arguments so that existing
    first-order callables remain valid ``ODECallable`` implementations.
    ``ResidualsConstraint`` uses ``_ODECallableN`` internally to call
    higher-order functions with the correct signature.
    """

    # Structural typing: any callable matching this signature conforms;
    # no explicit subclassing of the Protocol is required.
    def __call__(self, x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor: ...

__call__(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor

Source code in src/anypinn/problems/ode.py
# Structural interface: matching this three-argument signature is what makes
# a callable a valid ODECallable (no explicit subclassing needed).
def __call__(self, x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor: ...

ODEHyperparameters dataclass

Bases: PINNHyperparameters

Hyperparameters for ODE inverse problems.

Source code in src/anypinn/problems/ode.py
@dataclass(kw_only=True)
class ODEHyperparameters(PINNHyperparameters):
    """
    Hyperparameters for ODE inverse problems.

    Extends ``PINNHyperparameters`` with one loss weight per constraint.
    """

    # Weight of the ODE-residual loss term (ResidualsConstraint).
    pde_weight: float = 1.0
    # Weight of the initial-condition loss term (ICConstraint).
    ic_weight: float = 1.0
    # Weight of the observed-data loss term (DataConstraint).
    data_weight: float = 1.0

data_weight: float = 1.0 class-attribute instance-attribute

ic_weight: float = 1.0 class-attribute instance-attribute

pde_weight: float = 1.0 class-attribute instance-attribute

__init__(*, lr: float, training_data: IngestionConfig | GenerationConfig, fields_config: MLPConfig, params_config: MLPConfig | ScalarConfig, max_epochs: int | None = None, gradient_clip_val: float | None = None, criterion: Criteria = 'mse', optimizer: AdamConfig | LBFGSConfig | None = None, scheduler: ReduceLROnPlateauConfig | CosineAnnealingConfig | None = None, early_stopping: EarlyStoppingConfig | None = None, smma_stopping: SMMAStoppingConfig | None = None, pde_weight: float = 1.0, ic_weight: float = 1.0, data_weight: float = 1.0) -> None

ODEInverseProblem

Bases: Problem

Generic ODE Inverse Problem. Composes Residuals + IC + Data constraints with MSELoss.

Source code in src/anypinn/problems/ode.py
class ODEInverseProblem(Problem):
    """
    Generic ODE Inverse Problem.

    Composes Residuals + IC + Data constraints with the criterion selected
    by ``hp.criterion`` (MSE by default).
    """

    def __init__(
        self,
        props: ODEProperties,
        hp: ODEHyperparameters,
        fields: FieldsRegistry,
        params: ParamsRegistry,
        predict_data: PredictDataFn,
    ) -> None:
        # Build each constraint term with its weight taken from the hyperparameters.
        residuals = ResidualsConstraint(
            props=props,
            fields=fields,
            params=params,
            weight=hp.pde_weight,
        )
        ic = ICConstraint(
            props=props,
            fields=fields,
            weight=hp.ic_weight,
        )
        data = DataConstraint(
            fields=fields,
            params=params,
            predict_data=predict_data,
            weight=hp.data_weight,
        )

        super().__init__(
            constraints=[residuals, ic, data],
            criterion=build_criterion(hp.criterion),
            fields=fields,
            params=params,
        )

__init__(props: ODEProperties, hp: ODEHyperparameters, fields: FieldsRegistry, params: ParamsRegistry, predict_data: PredictDataFn) -> None

Source code in src/anypinn/problems/ode.py
def __init__(
    self,
    props: ODEProperties,
    hp: ODEHyperparameters,
    fields: FieldsRegistry,
    params: ParamsRegistry,
    predict_data: PredictDataFn,
) -> None:
    """Wire up the residual, IC and data constraints and delegate to Problem."""
    # Each constraint carries its own weight from the hyperparameters.
    residuals = ResidualsConstraint(
        props=props,
        fields=fields,
        params=params,
        weight=hp.pde_weight,
    )
    ic = ICConstraint(
        props=props,
        fields=fields,
        weight=hp.ic_weight,
    )
    data = DataConstraint(
        fields=fields,
        params=params,
        predict_data=predict_data,
        weight=hp.data_weight,
    )

    super().__init__(
        constraints=[residuals, ic, data],
        criterion=build_criterion(hp.criterion),
        fields=fields,
        params=params,
    )

ODEProperties dataclass

Properties defining an Ordinary Differential Equation problem.

Attributes:

Name Type Description
ode ODECallable

The ODE function (callable).

args ArgsRegistry

Arguments/Parameters for the ODE.

y0 Tensor

Initial conditions.

order int

Order of the ODE (default 1). For order=n, the ODE callable receives derivs as its last argument: derivs[k] is the (k+1)-th derivative.

dy0 list[Tensor]

Initial conditions for lower-order derivatives, length = order-1. dy0[k] is the IC for the (k+1)-th derivative, shape (n_fields,).

expected_args frozenset[str] | None

Optional set of arg keys the ODE function accesses. When provided, validated against the merged args+params at construction time.

Source code in src/anypinn/problems/ode.py
@dataclass
class ODEProperties:
    """
    Properties defining an Ordinary Differential Equation problem.

    Attributes:
        ode: The ODE function (callable).
        args: Arguments/Parameters for the ODE.
        y0: Initial conditions.
        order: Order of the ODE (default 1). For order=n, the ODE callable receives
            derivs as its last argument: derivs[k] is the (k+1)-th derivative.
        dy0: Initial conditions for lower-order derivatives, length = order-1.
            dy0[k] is the IC for the (k+1)-th derivative, shape (n_fields,).
        expected_args: Optional set of arg keys the ODE function accesses.
            When provided, validated against the merged args+params at construction time.

    Raises:
        ValueError: If ``order < 1`` or ``len(dy0) != order - 1``.
    """

    ode: ODECallable
    args: ArgsRegistry
    y0: Tensor
    order: int = 1
    dy0: list[Tensor] = dc_field(default_factory=list)
    expected_args: frozenset[str] | None = None

    def __post_init__(self) -> None:
        # Check order first so the dy0-length message always reports a sane target.
        if self.order < 1:
            raise ValueError(f"order must be >= 1, got {self.order}")
        if len(self.dy0) != self.order - 1:
            raise ValueError(f"dy0 must have length order-1={self.order - 1}, got {len(self.dy0)}")

args: ArgsRegistry instance-attribute

dy0: list[Tensor] = dc_field(default_factory=list) class-attribute instance-attribute

expected_args: frozenset[str] | None = None class-attribute instance-attribute

ode: ODECallable instance-attribute

order: int = 1 class-attribute instance-attribute

y0: Tensor instance-attribute

__init__(ode: ODECallable, args: ArgsRegistry, y0: Tensor, order: int = 1, dy0: list[Tensor] = list(), expected_args: frozenset[str] | None = None) -> None

__post_init__() -> None

Source code in src/anypinn/problems/ode.py
def __post_init__(self) -> None:
    """Validate the ODE order and the number of derivative initial conditions."""
    # A zeroth- or negative-order ODE is meaningless.
    if self.order <= 0:
        raise ValueError(f"order must be >= 1, got {self.order}")
    # One derivative IC is required for each order above the first.
    if self.order - 1 != len(self.dy0):
        raise ValueError(f"dy0 must have length order-1={self.order - 1}, got {len(self.dy0)}")

ResidualsConstraint

Bases: Constraint

Constraint enforcing the ODE residuals. Minimizes \(\lVert \partial y / \partial t - f(t, y) \rVert^2\).

Parameters:

Name Type Description Default
props ODEProperties

ODE properties.

required
fields FieldsRegistry

Fields registry.

required
params ParamsRegistry

Parameters registry.

required
weight float

Weight for this loss term.

1.0
Source code in src/anypinn/problems/ode.py
class ResidualsConstraint(Constraint):
    """
    Constraint enforcing the ODE residuals.
    Minimizes $\\lVert \\partial y / \\partial t - f(t, y) \\rVert^2$.

    Args:
        props: ODE properties.
        fields: Fields registry (one field per state variable).
        params: Parameters registry (trainable parameters, merged into the ODE args).
        weight: Weight for this loss term.

    Raises:
        ValueError: If the number of fields does not match ``len(props.y0)``,
            or if ``props.expected_args`` names keys missing from args/params.
    """

    def __init__(
        self,
        props: ODEProperties,
        fields: FieldsRegistry,
        params: ParamsRegistry,
        weight: float = 1.0,
    ):
        if len(fields) != len(props.y0):
            raise ValueError(
                f"Number of fields ({len(fields)}) must match number of initial conditions "
                f"in y0 ({len(props.y0)}). Field keys: {list(fields)}."
            )

        # Fail fast at construction if the ODE declares args we cannot supply.
        merged_args: dict[str, object] = {**props.args, **params}
        if props.expected_args is not None:
            missing = props.expected_args - merged_args.keys()
            if missing:
                raise ValueError(
                    f"ODE function expects args {sorted(missing)!r} but they are not in "
                    f"props.args or params. Available keys: {sorted(merged_args.keys())!r}."
                )

        self.fields = fields
        self.weight = weight
        self.order = props.order

        self.ode = props.ode

        # add the trainable params as args
        self.args = props.args.copy()
        self.args.update(params)

    @override
    def loss(
        self,
        batch: TrainingBatch,
        criterion: nn.Module,
        log: LogFn | None = None,
    ) -> Tensor:
        # Only the collocation points drive the residual loss.
        _, x_coll = batch

        # One detached, grad-enabled copy of x per field so each field's
        # derivative chain is taken against its own leaf tensor.
        n_fields = len(self.fields)
        x_copies = [x_coll.detach().clone().requires_grad_(True) for _ in range(n_fields)]
        preds = [f(x_copies[i]) for i, f in enumerate(self.fields.values())]
        y = torch.stack(preds)

        # Build all derivative levels by chaining (each level differentiates the previous)
        # deriv_levels[k][i] = (k+1)-th derivative of field i
        deriv_levels: list[list[Tensor]] = []
        currents = list(preds)
        for _ in range(self.order):
            next_level = [diff_grad(currents[i], x_copies[i]) for i in range(n_fields)]
            deriv_levels.append(next_level)
            currents = next_level

        # derivs[k] = (k+1)-th derivative stacked across fields, passed to the ODE callable
        derivs = [torch.stack(deriv_levels[k]) for k in range(self.order - 1)]
        # The order-th derivative is the LHS to compare against f_out
        high_deriv = torch.stack(deriv_levels[self.order - 1])

        if self.order == 1:
            f_out = self.ode(x_coll, y, self.args)
        else:
            # Higher-order callables take derivs as a fourth positional argument.
            f_out = cast(_ODECallableN, self.ode)(x_coll, y, self.args, derivs)

        loss: Tensor = self.weight * criterion(high_deriv, f_out)

        if log is not None:
            log("loss/res", loss)

        return loss

args = props.args.copy() instance-attribute

fields = fields instance-attribute

ode = props.ode instance-attribute

order = props.order instance-attribute

weight = weight instance-attribute

__init__(props: ODEProperties, fields: FieldsRegistry, params: ParamsRegistry, weight: float = 1.0)

Source code in src/anypinn/problems/ode.py
def __init__(
    self,
    props: ODEProperties,
    fields: FieldsRegistry,
    params: ParamsRegistry,
    weight: float = 1.0,
):
    """Validate the field/IC counts and the declared ODE args, then store state."""
    # One field is required per state variable in y0.
    if len(fields) != len(props.y0):
        raise ValueError(
            f"Number of fields ({len(fields)}) must match number of initial conditions "
            f"in y0 ({len(props.y0)}). Field keys: {list(fields)}."
        )

    # Fail fast at construction if the ODE declares args we cannot supply.
    merged_args: dict[str, object] = {**props.args, **params}
    if props.expected_args is not None:
        missing = props.expected_args - merged_args.keys()
        if missing:
            raise ValueError(
                f"ODE function expects args {sorted(missing)!r} but they are not in "
                f"props.args or params. Available keys: {sorted(merged_args.keys())!r}."
            )

    self.ode = props.ode
    self.order = props.order
    self.weight = weight
    self.fields = fields

    # Expose the trainable params to the ODE alongside the static args.
    self.args = props.args.copy()
    self.args.update(params)

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Source code in src/anypinn/problems/ode.py
@override
def loss(
    self,
    batch: TrainingBatch,
    criterion: nn.Module,
    log: LogFn | None = None,
) -> Tensor:
    """Return the weighted ODE-residual loss evaluated at the collocation points."""
    # Only the collocation points drive the residual loss.
    _, x_coll = batch

    # One detached, grad-enabled copy of x per field so each field's derivative
    # chain is taken against its own leaf tensor.
    n_fields = len(self.fields)
    x_copies = [x_coll.detach().clone().requires_grad_(True) for _ in range(n_fields)]
    preds = [f(x_copies[i]) for i, f in enumerate(self.fields.values())]
    y = torch.stack(preds)

    # Build all derivative levels by chaining (each level differentiates the previous)
    # deriv_levels[k][i] = (k+1)-th derivative of field i
    deriv_levels: list[list[Tensor]] = []
    currents = list(preds)
    for _ in range(self.order):
        next_level = [diff_grad(currents[i], x_copies[i]) for i in range(n_fields)]
        deriv_levels.append(next_level)
        currents = next_level

    # derivs[k] = (k+1)-th derivative stacked across fields, passed to the ODE callable
    derivs = [torch.stack(deriv_levels[k]) for k in range(self.order - 1)]
    # The order-th derivative is the LHS to compare against f_out
    high_deriv = torch.stack(deriv_levels[self.order - 1])

    if self.order == 1:
        f_out = self.ode(x_coll, y, self.args)
    else:
        # Higher-order callables take derivs as a fourth positional argument.
        f_out = cast(_ODECallableN, self.ode)(x_coll, y, self.args, derivs)

    loss: Tensor = self.weight * criterion(high_deriv, f_out)

    if log is not None:
        log("loss/res", loss)

    return loss