Skip to content

anypinn.core.problem

Core problem abstractions for PINN.

Constraint

Bases: ABC

Abstract base class for a constraint (loss term) in the PINN. Returns a loss value for the given batch.

Source code in src/anypinn/core/problem.py
class Constraint(ABC):
    """
    Abstract base class for a constraint (loss term) in the PINN.
    Returns a loss value for the given batch.

    Each concrete constraint contributes one additive term to the total
    training loss.
    """

    def inject_context(self, context: InferredContext) -> None:
        """
        Inject the context into the constraint. This can be used by the constraint to access the
        data used to compute the loss.

        The default implementation is a no-op; subclasses that need the
        inferred data override this method.

        Args:
            context: The context to inject.
        """
        return None

    @abstractmethod
    def loss(
        self,
        batch: TrainingBatch,
        criterion: nn.Module,
        log: LogFn | None = None,
    ) -> Tensor:
        """
        Calculate the loss for this constraint.

        Args:
            batch: The current batch of data/collocation points.
            criterion: The loss function (e.g. MSE).
            log: Optional logging function.

        Returns:
            The calculated loss tensor.
        """

inject_context(context: InferredContext) -> None

Inject the context into the constraint. This can be used by the constraint to access the data used to compute the loss.

Parameters:

Name Type Description Default
context InferredContext

The context to inject.

required
Source code in src/anypinn/core/problem.py
def inject_context(self, context: InferredContext) -> None:
    """
    Inject the context into the constraint. This can be used by the constraint to access the
    data used to compute the loss.

    The default implementation is a no-op; subclasses that need the
    inferred data override this method.

    Args:
        context: The context to inject.
    """
    return None

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor abstractmethod

Calculate the loss for this constraint.

Parameters:

Name Type Description Default
batch TrainingBatch

The current batch of data/collocation points.

required
criterion Module

The loss function (e.g. MSE).

required
log LogFn | None

Optional logging function.

None

Returns:

Type Description
Tensor

The calculated loss tensor.

Source code in src/anypinn/core/problem.py
@abstractmethod
def loss(
    self,
    batch: TrainingBatch,
    criterion: nn.Module,
    log: LogFn | None = None,
) -> Tensor:
    """
    Calculate the loss for this constraint.

    Concrete subclasses must implement this; the returned tensor is an
    additive term of the total training loss.

    Args:
        batch: The current batch of data/collocation points.
        criterion: The loss function (e.g. MSE).
        log: Optional logging function.

    Returns:
        The calculated loss tensor.
    """

Problem

Bases: Module

Aggregates operator residuals and constraints into total loss. Manages fields, parameters, constraints, and validation.

Parameters:

Name Type Description Default
constraints list[Constraint]

List of constraints to enforce.

required
criterion Module

Loss function module.

required
fields FieldsRegistry

List of fields (neural networks) to solve for.

required
params ParamsRegistry

List of learnable parameters.

required
Source code in src/anypinn/core/problem.py
class Problem(nn.Module):
    """
    Aggregates operator residuals and constraints into total loss.
    Manages fields, parameters, constraints, and validation.

    Args:
        constraints: List of constraints to enforce.
        criterion: Loss function module.
        fields: List of fields (neural networks) to solve for.
        params: List of learnable parameters.
    """

    def __init__(
        self,
        constraints: list[Constraint],
        criterion: nn.Module,
        fields: FieldsRegistry,
        params: ParamsRegistry,
    ):
        super().__init__()
        self.constraints = constraints
        self.criterion = criterion
        self.fields = fields
        self.params = params

        # Wrap the registry values in ModuleLists so nn.Module registers them
        # as submodules (their parameters become visible to optimizers, .to(),
        # state_dict, etc.).
        self._fields = nn.ModuleList(fields.values())
        self._params = nn.ModuleList(params.values())

    def inject_context(self, context: InferredContext) -> None:
        """
        Inject the context into the problem.

        This should be called after data is loaded but before training starts.
        Pure function entries are passed through unchanged.

        Args:
            context: The context to inject.
        """
        self.context = context
        # Propagate the context to every constraint as well.
        for c in self.constraints:
            c.inject_context(context)

    def training_loss(self, batch: TrainingBatch, log: LogFn | None = None) -> Tensor:
        """
        Calculate the total loss from all constraints.

        Args:
            batch: Current batch.
            log: Optional logging function.

        Returns:
            Sum of losses from all constraints.
        """
        _, x_coll = batch

        if not self.constraints:
            # No constraints configured: return a zero scalar on the batch's device.
            total = torch.tensor(0.0, device=x_coll.device)
        else:
            # Seed the sum with the first constraint's loss so we never
            # allocate a superfluous zero tensor.
            losses = iter(self.constraints)
            total = next(losses).loss(batch, self.criterion, log)
            for c in losses:
                total = total + c.loss(batch, self.criterion, log)

        if log is not None:
            # Additionally log per-parameter validation losses where a
            # ground-truth source is available (metrics only; not added to total).
            for name, param in self.params.items():
                param_loss = self._param_validation_loss(name, param, x_coll)
                if param_loss is not None:
                    log(f"loss/{name}", param_loss, progress_bar=True)

            log(LOSS_KEY, total, progress_bar=True)

        return total

    def predict(self, batch: DataBatch) -> tuple[DataBatch, dict[str, Tensor]]:
        """
        Generate predictions for a given batch of data.
        Returns unscaled predictions in original domain.

        Args:
            batch: Batch of input coordinates.

        Returns:
            Tuple of (original_batch, predictions_dict).
        """

        x, y = batch

        n = x.shape[0]
        # Flatten each output to (n,) for single-output models, (n, k) otherwise.
        # Params are merged after fields, so a param with the same name wins.
        preds = {name: f(x).reshape(n, -1).squeeze(-1) for name, f in self.fields.items()}
        preds |= {name: p(x).reshape(n, -1).squeeze(-1) for name, p in self.params.items()}

        return (x.squeeze(-1), y.squeeze(-1)), preds

    def true_values(self, x: Tensor) -> dict[str, Tensor] | None:
        """
        Get the true values for the given x coordinates.
        Returns None if no validation source is configured.
        """

        # Only parameters that have a configured validation source are included;
        # an empty dict collapses to None.
        return {
            name: p_true.reshape(x.shape[0], -1).squeeze(-1)
            for name, p in self.params.items()
            if (p_true := self._get_true_param(name, x)) is not None
        } or None

    def _get_true_param(self, param_name: str, x: Tensor) -> Tensor | None:
        """
        Get the ground truth values for a parameter at given coordinates.

        Args:
            param_name: Name of the parameter.
            x: Input coordinates.

        Returns:
            Ground truth values, or None if no validation source is configured.
        """
        if param_name not in self.context.validation:
            return None

        fn = self.context.validation[param_name]

        if isinstance(fn, _ColumnLookup):
            # Column lookups are indexed by row, not by coordinate: map each
            # coordinate back to its integer grid index via the uniform spacing.
            domain = self.context.domain
            if domain.dx is None:
                raise ValueError(
                    f"Cannot perform ColumnRef lookup for '{param_name}': "
                    "domain step size (dx) is unknown. Ensure the domain was inferred from "
                    "a uniformly-spaced coordinate tensor, or use a callable validation source."
                )
            x_idx = ((x.squeeze(-1) - domain.x0) / domain.dx[0]).round().unsqueeze(-1)
            return fn(x_idx)

        # Callable validation sources are evaluated directly on the coordinates.
        return fn(x)

    @torch.no_grad()
    def _param_validation_loss(
        self, param_name: str, param: Parameter, x_coll: Tensor
    ) -> Tensor | None:
        """
        Compute validation loss for a parameter against ground truth.

        Runs under torch.no_grad(): this is a monitoring metric and must not
        contribute gradients to training.

        Args:
            param_name: Name of the parameter (used to look up its validation source).
            param: The parameter to compute validation loss for.
            x_coll: The input coordinates.

        Returns:
            Loss value (MSE), or None if no validation source is configured.
        """
        true = self._get_true_param(param_name, x_coll)
        if true is None:
            return None

        pred = param(x_coll)

        return torch.mean((true - pred) ** 2)

constraints = constraints instance-attribute

criterion = criterion instance-attribute

fields = fields instance-attribute

params = params instance-attribute

__init__(constraints: list[Constraint], criterion: nn.Module, fields: FieldsRegistry, params: ParamsRegistry)

Source code in src/anypinn/core/problem.py
def __init__(
    self,
    constraints: list[Constraint],
    criterion: nn.Module,
    fields: FieldsRegistry,
    params: ParamsRegistry,
):
    """
    Initialize the problem.

    Args:
        constraints: List of constraints to enforce.
        criterion: Loss function module.
        fields: List of fields (neural networks) to solve for.
        params: List of learnable parameters.
    """
    super().__init__()
    self.constraints = constraints
    self.criterion = criterion
    self.fields = fields
    self.params = params

    # Wrap the registry values in ModuleLists so nn.Module registers them
    # as submodules (parameters visible to optimizers, .to(), state_dict).
    self._fields = nn.ModuleList(fields.values())
    self._params = nn.ModuleList(params.values())

inject_context(context: InferredContext) -> None

Inject the context into the problem.

This should be called after data is loaded but before training starts. Pure function entries are passed through unchanged.

Parameters:

Name Type Description Default
context InferredContext

The context to inject.

required
Source code in src/anypinn/core/problem.py
def inject_context(self, context: InferredContext) -> None:
    """
    Inject the context into the problem.

    This should be called after data is loaded but before training starts.
    Pure function entries are passed through unchanged.

    Args:
        context: The context to inject.
    """
    self.context = context
    # Propagate the context to every constraint as well.
    for c in self.constraints:
        c.inject_context(context)

predict(batch: DataBatch) -> tuple[DataBatch, dict[str, Tensor]]

Generate predictions for a given batch of data. Returns unscaled predictions in original domain.

Parameters:

Name Type Description Default
batch DataBatch

Batch of input coordinates.

required

Returns:

Type Description
tuple[DataBatch, dict[str, Tensor]]

Tuple of (original_batch, predictions_dict).

Source code in src/anypinn/core/problem.py
def predict(self, batch: DataBatch) -> tuple[DataBatch, dict[str, Tensor]]:
    """
    Generate predictions for a given batch of data.
    Returns unscaled predictions in original domain.

    Args:
        batch: Batch of input coordinates.

    Returns:
        Tuple of (original_batch, predictions_dict).
    """

    x, y = batch

    n = x.shape[0]
    # Flatten each output to (n,) for single-output models, (n, k) otherwise.
    # Params are merged after fields, so a param with the same name wins.
    preds = {name: f(x).reshape(n, -1).squeeze(-1) for name, f in self.fields.items()}
    preds |= {name: p(x).reshape(n, -1).squeeze(-1) for name, p in self.params.items()}

    return (x.squeeze(-1), y.squeeze(-1)), preds

training_loss(batch: TrainingBatch, log: LogFn | None = None) -> Tensor

Calculate the total loss from all constraints.

Parameters:

Name Type Description Default
batch TrainingBatch

Current batch.

required
log LogFn | None

Optional logging function.

None

Returns:

Type Description
Tensor

Sum of losses from all constraints.

Source code in src/anypinn/core/problem.py
def training_loss(self, batch: TrainingBatch, log: LogFn | None = None) -> Tensor:
    """
    Calculate the total loss from all constraints.

    Args:
        batch: Current batch.
        log: Optional logging function.

    Returns:
        Sum of losses from all constraints.
    """
    _, x_coll = batch

    if not self.constraints:
        # No constraints configured: return a zero scalar on the batch's device.
        total = torch.tensor(0.0, device=x_coll.device)
    else:
        # Seed the sum with the first constraint's loss so we never
        # allocate a superfluous zero tensor.
        losses = iter(self.constraints)
        total = next(losses).loss(batch, self.criterion, log)
        for c in losses:
            total = total + c.loss(batch, self.criterion, log)

    if log is not None:
        # Additionally log per-parameter validation losses where a
        # ground-truth source is available (metrics only; not added to total).
        for name, param in self.params.items():
            param_loss = self._param_validation_loss(name, param, x_coll)
            if param_loss is not None:
                log(f"loss/{name}", param_loss, progress_bar=True)

        log(LOSS_KEY, total, progress_bar=True)

    return total

true_values(x: Tensor) -> dict[str, Tensor] | None

Get the true values for the given x coordinates. Returns None if no validation source is configured.

Source code in src/anypinn/core/problem.py
def true_values(self, x: Tensor) -> dict[str, Tensor] | None:
    """
    Get the true values for the given x coordinates.
    Returns None if no validation source is configured.
    """

    # Only parameters that have a configured validation source are included;
    # an empty dict collapses to None.
    return {
        name: p_true.reshape(x.shape[0], -1).squeeze(-1)
        for name, p in self.params.items()
        if (p_true := self._get_true_param(name, x)) is not None
    } or None