Skip to content

anypinn.problems.ode

ODE constraint types, properties, and the ODEInverseProblem convenience class.

PredictDataFn: TypeAlias = Callable[[Tensor, FieldsRegistry, ParamsRegistry], Tensor] module-attribute

DataConstraint

Bases: Constraint

Constraint enforcing fit to observed data. Minimizes ||y_hat - y||^2.

Parameters:

Name Type Description Default
fields FieldsRegistry

Fields registry.

required
params ParamsRegistry

Parameters registry.

required
predict_data PredictDataFn

Function to predict data values from fields.

required
weight float

Weight for this loss term.

1.0
Source code in src/anypinn/problems/ode.py
class DataConstraint(Constraint):
    """
    Data-fitting constraint: penalizes the mismatch between predicted and
    observed values, i.e. minimizes ``||y_hat - y||^2``.

    Args:
        fields: Fields registry.
        params: Parameters registry.
        predict_data: Function to predict data values from fields.
        weight: Weight for this loss term.
    """

    def __init__(
        self,
        fields: FieldsRegistry,
        params: ParamsRegistry,
        predict_data: PredictDataFn,
        weight: float = 1.0,
    ):
        self.fields, self.params = fields, params
        self.predict_data, self.weight = predict_data, weight

    @override
    def loss(
        self,
        batch: TrainingBatch,
        criterion: nn.Module,
        log: LogFn | None = None,
    ) -> Tensor:
        """
        Return the weighted criterion between predicted and observed data.

        Args:
            batch: Pair whose first element is the ``(x, y)`` observation pair;
                the second element is not used here.
            criterion: Callable comparing predictions with observations.
            log: Optional logging callback, invoked with ``"loss/data"``.

        Returns:
            ``weight * criterion(prediction, observation)``.
        """
        observations, _ = batch
        x_obs, y_obs = observations

        predicted = self.predict_data(x_obs, self.fields, self.params)
        weighted: Tensor = self.weight * criterion(predicted, y_obs)

        if log is not None:
            log("loss/data", weighted)

        return weighted

fields = fields instance-attribute

params = params instance-attribute

predict_data = predict_data instance-attribute

weight = weight instance-attribute

__init__(fields: FieldsRegistry, params: ParamsRegistry, predict_data: PredictDataFn, weight: float = 1.0)

Source code in src/anypinn/problems/ode.py
def __init__(
    self,
    fields: FieldsRegistry,
    params: ParamsRegistry,
    predict_data: PredictDataFn,
    weight: float = 1.0,
):
    """Store the registries, the prediction function and the loss weight."""
    self.fields, self.params = fields, params
    self.predict_data, self.weight = predict_data, weight

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Compute the data-fitting loss.

Source code in src/anypinn/problems/ode.py
@override
def loss(
    self,
    batch: TrainingBatch,
    criterion: nn.Module,
    log: LogFn | None = None,
) -> Tensor:
    """Compute the data-fitting loss."""
    (x_data, y_data), _ = batch

    y_data_pred = self.predict_data(x_data, self.fields, self.params)

    loss: Tensor = criterion(y_data_pred, y_data)
    loss = self.weight * loss

    if log is not None:
        log("loss/data", loss)

    return loss

ICConstraint

Bases: Constraint

Constraint enforcing Initial Conditions (IC). Minimizes ||y(t0) - Y0||^2.

Parameters:

Name Type Description Default
props ODEProperties

ODE properties; y0, dy0 and order are read from it.

required
fields FieldsRegistry

Fields registry.

required
weight float

Weight for this loss term.

1.0
Source code in src/anypinn/problems/ode.py
class ICConstraint(Constraint):
    """
    Constraint enforcing Initial Conditions (IC).
    Minimizes ``||y(t0) - Y0||^2``; when ``props.order >= 2`` the derivative
    initial conditions from ``props.dy0`` are added to the same loss.

    ``inject_context`` must be called before ``loss`` so that ``t0`` is set.

    Args:
        props: ODE properties; ``y0``, ``dy0`` and ``order`` are read from it.
        fields: Fields registry.
        weight: Weight for this loss term.

    Raises:
        ValueError: If the number of fields differs from ``len(props.y0)``.
    """

    def __init__(
        self,
        props: ODEProperties,
        fields: FieldsRegistry,
        weight: float = 1.0,
    ):
        if len(fields) != len(props.y0):
            raise ValueError(
                f"Number of fields ({len(fields)}) must match number of initial conditions "
                f"in y0 ({len(props.y0)}). Field keys: {list(fields)}."
            )

        # Targets are cloned so later changes to ``props`` do not leak in, and
        # reshaped with (-1, 1, 1) to line up with the stacked field outputs.
        self.Y0 = props.y0.clone().reshape(-1, 1, 1)
        self.dY0 = [dy.clone().reshape(-1, 1, 1) for dy in props.dy0]
        self.order = props.order
        self.fields = fields
        self.weight = weight

    @override
    def inject_context(self, context: InferredContext) -> None:
        """
        Inject the context into the constraint.

        Stores ``context.domain.x0`` as the ``(1, 1)`` float32 tensor ``t0``.
        """
        self.t0 = torch.tensor(context.domain.x0, dtype=torch.float32).reshape(1, 1)

    @override
    def loss(
        self,
        batch: TrainingBatch,
        criterion: nn.Module,
        log: LogFn | None = None,
    ) -> Tensor:
        """
        Compute the initial-condition loss.

        Args:
            batch: Training batch; only the device of its second element is used.
            criterion: Callable comparing predictions with the IC targets.
            log: Optional logging callback, invoked with ``"loss/ic"``.

        Returns:
            The weighted IC loss (value ICs plus derivative ICs for order >= 2).

        Raises:
            RuntimeError: If ``inject_context`` has not been called yet.
        """
        # Fail with an actionable message instead of a bare AttributeError.
        if getattr(self, "t0", None) is None:
            raise RuntimeError(
                "ICConstraint.loss() called before inject_context(); "
                "the initial time t0 is not set."
            )

        device = batch[1].device

        # t0 and the targets are moved together, so checking t0 is enough.
        if self.t0.device != device:
            self.t0 = self.t0.to(device)
            self.Y0 = self.Y0.to(device)
            self.dY0 = [d.to(device) for d in self.dY0]

        if self.order == 1:
            # Fast path: no requires_grad needed, identical to original behaviour
            Y0_preds = torch.stack([f(self.t0) for f in self.fields.values()])
            loss: Tensor = self.weight * criterion(Y0_preds, self.Y0)
        else:
            # A fresh leaf tensor so the field outputs can be differentiated w.r.t. it.
            x0 = self.t0.detach().requires_grad_(True)
            currents = [f(x0) for f in self.fields.values()]
            total = criterion(torch.stack(currents), self.Y0)
            # Enforce derivative ICs by chaining from previous level
            for k in range(self.order - 1):
                currents = [diff_grad(current, x0) for current in currents]
                total = total + criterion(torch.stack(currents), self.dY0[k])
            loss = self.weight * total

        if log is not None:
            log("loss/ic", loss)

        return loss

Y0 = props.y0.clone().reshape(-1, 1, 1) instance-attribute

dY0 = [(dy.clone().reshape(-1, 1, 1)) for dy in (props.dy0)] instance-attribute

fields = fields instance-attribute

order = props.order instance-attribute

weight = weight instance-attribute

__init__(props: ODEProperties, fields: FieldsRegistry, weight: float = 1.0)

Source code in src/anypinn/problems/ode.py
def __init__(
    self,
    props: ODEProperties,
    fields: FieldsRegistry,
    weight: float = 1.0,
):
    """
    Validate the field count against ``props.y0`` and store the IC targets.

    Raises:
        ValueError: If ``len(fields)`` differs from ``len(props.y0)``.
    """
    if len(fields) != len(props.y0):
        raise ValueError(
            f"Number of fields ({len(fields)}) must match number of initial conditions "
            f"in y0 ({len(props.y0)}). Field keys: {list(fields)}."
        )

    # Every target is a cloned copy reshaped with (-1, 1, 1).
    target_shape = (-1, 1, 1)
    initial_values = props.y0.clone().reshape(target_shape)
    derivative_values = []
    for dy in props.dy0:
        derivative_values.append(dy.clone().reshape(target_shape))

    self.Y0 = initial_values
    self.dY0 = derivative_values
    self.order = props.order
    self.fields = fields
    self.weight = weight

inject_context(context: InferredContext) -> None

Inject the context into the constraint.

Source code in src/anypinn/problems/ode.py
@override
def inject_context(self, context: InferredContext) -> None:
    """
    Inject the context into the constraint.
    """
    self.t0 = torch.tensor(context.domain.x0, dtype=torch.float32).reshape(1, 1)

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Compute the initial-condition loss.

Source code in src/anypinn/problems/ode.py
@override
def loss(
    self,
    batch: TrainingBatch,
    criterion: nn.Module,
    log: LogFn | None = None,
) -> Tensor:
    """Compute the initial-condition loss."""
    device = batch[1].device

    if self.t0.device != device:
        self.t0 = self.t0.to(device)
        self.Y0 = self.Y0.to(device)
        self.dY0 = [d.to(device) for d in self.dY0]

    n_fields = len(self.fields)

    if self.order == 1:
        # Fast path: no requires_grad needed, identical to original behaviour
        Y0_preds = torch.stack([f(self.t0) for f in self.fields.values()])
        loss: Tensor = self.weight * criterion(Y0_preds, self.Y0)
    else:
        x0 = self.t0.detach().requires_grad_(True)
        preds = [f(x0) for f in self.fields.values()]
        Y0_preds = torch.stack(preds)
        total = criterion(Y0_preds, self.Y0)
        # Enforce derivative ICs by chaining from previous level
        currents = list(preds)
        for k in range(self.order - 1):
            next_level = [diff_grad(currents[i], x0) for i in range(n_fields)]
            dY0_k_pred = torch.stack(next_level)  # (n_fields, 1, 1)
            total = total + criterion(dY0_k_pred, self.dY0[k])
            currents = next_level
        loss = self.weight * total

    if log is not None:
        log("loss/ic", loss)

    return loss

ODECallable

Bases: Protocol

Protocol for ODE right-hand side callables.

First-order callables (ODEProperties.order == 1) receive three positional arguments and must match this Protocol exactly::

def my_ode(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor: ...

Higher-order callables (order >= 2) receive a fourth positional argument derivs: list[Tensor], where derivs[k] is the (k+1)-th derivative of all fields stacked as (n_fields, m, 1)::

def my_ode(x: Tensor, y: Tensor, args: ArgsRegistry,
           derivs: list[Tensor] = []) -> Tensor: ...

The Protocol is intentionally kept to three arguments so that existing first-order callables remain valid ODECallable implementations. ResidualsConstraint uses _ODECallableN internally to call higher-order functions with the correct signature.

Source code in src/anypinn/problems/ode.py
class ODECallable(Protocol):
    """
    Protocol for ODE right-hand side callables.

    **First-order** callables (``ODEProperties.order == 1``) receive three
    positional arguments and must match this Protocol exactly::

        def my_ode(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor: ...

    **Higher-order** callables (``order >= 2``) receive a fourth positional
    argument ``derivs: list[Tensor]``, where ``derivs[k]`` is the
    ``(k+1)``-th derivative of all fields stacked as ``(n_fields, m, 1)``::

        def my_ode(x: Tensor, y: Tensor, args: ArgsRegistry,
                   derivs: list[Tensor] = []) -> Tensor: ...

    The Protocol is intentionally kept to three arguments so that existing
    first-order callables remain valid ``ODECallable`` implementations.
    ``ResidualsConstraint`` uses ``_ODECallableN`` internally to call
    higher-order functions with the correct signature.

    NOTE(review): the ``= []`` default in the higher-order example is a shared
    mutable default. ``ResidualsConstraint`` always passes ``derivs`` explicitly
    for ``order >= 2``, so the default is presumably there only to keep the
    function compatible with this three-argument Protocol; do not mutate it.
    """

    # Three-argument call signature: input ``x``, stacked field values ``y``
    # and the argument registry ``args``.
    def __call__(self, x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor: ...

__call__(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor

Source code in src/anypinn/problems/ode.py
# Three-argument call signature of the Protocol: input ``x``, stacked field
# values ``y`` and the argument registry ``args``; the body is a stub.
def __call__(self, x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor: ...

ODEHyperparameters dataclass

Bases: PINNHyperparameters

Hyperparameters for ODE inverse problems.

Extends PINNHyperparameters with per-constraint loss weights used by ODEInverseProblem.

Attributes:

Name Type Description
pde_weight float

Weight for the ODE residual loss term.

ic_weight float

Weight for the initial-condition loss term.

data_weight float

Weight for the data-fitting loss term.

Source code in src/anypinn/problems/ode.py
@dataclass(kw_only=True)
class ODEHyperparameters(PINNHyperparameters):
    """
    Hyperparameters for ODE inverse problems.

    Extends ``PINNHyperparameters`` with per-constraint loss weights
    used by ``ODEInverseProblem``. All fields are keyword-only and default
    to ``1.0``.

    Attributes:
        pde_weight: Weight for the ODE residual loss term.
        ic_weight: Weight for the initial-condition loss term.
        data_weight: Weight for the data-fitting loss term.
    """

    # Passed by ODEInverseProblem as ``weight`` to ResidualsConstraint.
    pde_weight: float = 1.0
    # Passed by ODEInverseProblem as ``weight`` to ICConstraint.
    ic_weight: float = 1.0
    # Passed by ODEInverseProblem as ``weight`` to DataConstraint.
    data_weight: float = 1.0

data_weight: float = 1.0 class-attribute instance-attribute

ic_weight: float = 1.0 class-attribute instance-attribute

pde_weight: float = 1.0 class-attribute instance-attribute

__init__(*, lr: float, training_data: IngestionConfig | GenerationConfig, fields_config: MLPConfig, params_config: MLPConfig | ScalarConfig, max_epochs: int | None = None, gradient_clip_val: float | None = None, criterion: Criteria = 'mse', optimizer: AdamConfig | LBFGSConfig | None = None, scheduler: ReduceLROnPlateauConfig | CosineAnnealingConfig | None = None, early_stopping: EarlyStoppingConfig | None = None, smma_stopping: SMMAStoppingConfig | None = None, pde_weight: float = 1.0, ic_weight: float = 1.0, data_weight: float = 1.0) -> None

ODEInverseProblem

Bases: Problem

Convenience class composing Residuals + IC + Data constraints.

Wires together ResidualsConstraint, ICConstraint, and DataConstraint with the loss criterion from hyperparameters. For forward-only problems (no data fitting), compose the constraints manually instead.

Example

>>> problem = ODEInverseProblem(
...     props=props,
...     hp=ODEHyperparameters(...),
...     fields={"S": field_s, "I": field_i, "R": field_r},
...     params={"beta": Parameter(ScalarConfig(init_value=0.3))},
...     predict_data=lambda t, f, p: torch.stack(
...         [f["S"](t), f["I"](t), f["R"](t)], dim=1
...     ).squeeze(-1),
... )

Source code in src/anypinn/problems/ode.py
class ODEInverseProblem(Problem):
    """
    Ready-made problem bundling the residual, initial-condition and data
    constraints of an ODE inverse problem.

    Builds a ``ResidualsConstraint``, an ``ICConstraint`` and a
    ``DataConstraint`` (weighted by ``hp.pde_weight``, ``hp.ic_weight`` and
    ``hp.data_weight``) and pairs them with the criterion built from
    ``hp.criterion``. For forward-only problems (no data fitting), compose
    the constraints manually instead.

    Example:
        >>> problem = ODEInverseProblem(
        ...     props=props,
        ...     hp=ODEHyperparameters(...),
        ...     fields={"S": field_s, "I": field_i, "R": field_r},
        ...     params={"beta": Parameter(ScalarConfig(init_value=0.3))},
        ...     predict_data=lambda t, f, p: torch.stack(
        ...         [f["S"](t), f["I"](t), f["R"](t)], dim=1
        ...     ).squeeze(-1),
        ... )
    """

    def __init__(
        self,
        props: ODEProperties,
        hp: ODEHyperparameters,
        fields: FieldsRegistry,
        params: ParamsRegistry,
        predict_data: PredictDataFn,
    ) -> None:
        # Constraints are created in the order residuals -> IC -> data.
        residuals = ResidualsConstraint(
            props=props, fields=fields, params=params, weight=hp.pde_weight
        )
        initial_conditions = ICConstraint(props=props, fields=fields, weight=hp.ic_weight)
        data_fit = DataConstraint(
            fields=fields,
            params=params,
            predict_data=predict_data,
            weight=hp.data_weight,
        )
        constraints: list[Constraint] = [residuals, initial_conditions, data_fit]

        super().__init__(
            constraints=constraints,
            criterion=build_criterion(hp.criterion),
            fields=fields,
            params=params,
        )

__init__(props: ODEProperties, hp: ODEHyperparameters, fields: FieldsRegistry, params: ParamsRegistry, predict_data: PredictDataFn) -> None

Source code in src/anypinn/problems/ode.py
def __init__(
    self,
    props: ODEProperties,
    hp: ODEHyperparameters,
    fields: FieldsRegistry,
    params: ParamsRegistry,
    predict_data: PredictDataFn,
) -> None:
    """Build the three weighted constraints and hand them to the parent initializer."""
    # Constraints are created in the order residuals -> IC -> data.
    residuals = ResidualsConstraint(
        props=props, fields=fields, params=params, weight=hp.pde_weight
    )
    initial_conditions = ICConstraint(props=props, fields=fields, weight=hp.ic_weight)
    data_fit = DataConstraint(
        fields=fields,
        params=params,
        predict_data=predict_data,
        weight=hp.data_weight,
    )
    constraints: list[Constraint] = [residuals, initial_conditions, data_fit]

    super().__init__(
        constraints=constraints,
        criterion=build_criterion(hp.criterion),
        fields=fields,
        params=params,
    )

ODEProperties dataclass

Properties defining an Ordinary Differential Equation problem.

Attributes:

Name Type Description
ode ODECallable

The ODE function (callable).

args ArgsRegistry

Arguments/Parameters for the ODE.

y0 Tensor

Initial conditions.

order int

Order of the ODE (default 1). For order >= 2, the ODE callable receives derivs as its last argument: derivs[k] is the (k+1)-th derivative, with order-1 entries in total. First-order callables are called without derivs.

dy0 list[Tensor]

Initial conditions for lower-order derivatives, length = order-1. dy0[k] is the IC for the (k+1)-th derivative, shape (n_fields,).

expected_args frozenset[str] | None

Optional set of arg keys the ODE function accesses. When provided, validated against the merged args+params at construction time.

Example

>>> def sir_ode(t, y, args):
...     S, I, R = y
...     beta, gamma = args["beta"](t), args["gamma"](t)
...     N = S + I + R
...     dS = -beta * S * I / N
...     dI = beta * S * I / N - gamma * I
...     dR = gamma * I
...     return torch.stack([dS, dI, dR])
>>> props = ODEProperties(
...     ode=sir_ode,
...     args={"beta": Argument(0.3), "gamma": Argument(0.1)},
...     y0=torch.tensor([0.99, 0.01, 0.0]),
... )

Source code in src/anypinn/problems/ode.py
@dataclass
class ODEProperties:
    """
    Description of an Ordinary Differential Equation problem.

    Attributes:
        ode: The ODE function (callable).
        args: Arguments/Parameters for the ODE.
        y0: Initial conditions.
        order: Order of the ODE (default 1). For order >= 2 the ODE callable
            receives derivs as its last argument: derivs[k] is the (k+1)-th derivative.
        dy0: Initial conditions for lower-order derivatives, length = order-1.
            dy0[k] is the IC for the (k+1)-th derivative, shape (n_fields,).
        expected_args: Optional set of arg keys the ODE function accesses.
            When provided, validated against the merged args+params at construction time.

    Raises:
        ValueError: If ``order < 1`` or ``len(dy0) != order - 1``.

    Example:
        >>> def sir_ode(t, y, args):
        ...     S, I, R = y
        ...     beta, gamma = args["beta"](t), args["gamma"](t)
        ...     N = S + I + R
        ...     dS = -beta * S * I / N
        ...     dI = beta * S * I / N - gamma * I
        ...     dR = gamma * I
        ...     return torch.stack([dS, dI, dR])
        >>> props = ODEProperties(
        ...     ode=sir_ode,
        ...     args={"beta": Argument(0.3), "gamma": Argument(0.1)},
        ...     y0=torch.tensor([0.99, 0.01, 0.0]),
        ... )
    """

    ode: ODECallable
    args: ArgsRegistry
    y0: Tensor
    order: int = 1
    dy0: list[Tensor] = dc_field(default_factory=list)
    expected_args: frozenset[str] | None = None

    def __post_init__(self) -> None:
        """Check that ``order`` is at least 1 and ``dy0`` has ``order - 1`` entries."""
        if self.order < 1:
            raise ValueError(f"order must be >= 1, got {self.order}")

        required = self.order - 1
        provided = len(self.dy0)
        if provided != required:
            raise ValueError(f"dy0 must have length order-1={required}, got {provided}")

args: ArgsRegistry instance-attribute

dy0: list[Tensor] = dc_field(default_factory=list) class-attribute instance-attribute

expected_args: frozenset[str] | None = None class-attribute instance-attribute

ode: ODECallable instance-attribute

order: int = 1 class-attribute instance-attribute

y0: Tensor instance-attribute

__init__(ode: ODECallable, args: ArgsRegistry, y0: Tensor, order: int = 1, dy0: list[Tensor] = list(), expected_args: frozenset[str] | None = None) -> None

__post_init__() -> None

Source code in src/anypinn/problems/ode.py
def __post_init__(self) -> None:
    """Check that ``order`` is at least 1 and ``dy0`` has ``order - 1`` entries."""
    if self.order < 1:
        raise ValueError(f"order must be >= 1, got {self.order}")

    required = self.order - 1
    provided = len(self.dy0)
    if provided != required:
        raise ValueError(f"dy0 must have length order-1={required}, got {provided}")

ResidualsConstraint

Bases: Constraint

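Constraint enforcing the ODE residuals. Minimizes ||dy/dt - f(t, y)||^2; for an ODE of order n, the n-th derivative takes the place of dy/dt and the lower-order derivatives are passed to f.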

Parameters:

Name Type Description Default
props ODEProperties

ODE properties.

required
fields FieldsRegistry

List of fields.

required
params ParamsRegistry

List of parameters.

required
weight float

Weight for this loss term.

1.0
Source code in src/anypinn/problems/ode.py
class ResidualsConstraint(Constraint):
    """
    Constraint enforcing the ODE residuals.
    Minimizes ``||dy/dt - f(t, y)||^2``; for an ODE of order ``n`` the
    ``n``-th derivative takes the place of ``dy/dt`` and the lower-order
    derivatives are passed to ``f`` as ``derivs``.

    Args:
        props: ODE properties.
        fields: Fields registry.
        params: Parameters registry; merged over ``props.args`` so the ODE
            callable sees trainable parameters as ordinary args.
        weight: Weight for this loss term.

    Raises:
        ValueError: If the number of fields differs from ``len(props.y0)``, or
            if ``props.expected_args`` names keys found in neither
            ``props.args`` nor ``params``.
    """

    def __init__(
        self,
        props: ODEProperties,
        fields: FieldsRegistry,
        params: ParamsRegistry,
        weight: float = 1.0,
    ):
        if len(fields) != len(props.y0):
            raise ValueError(
                f"Number of fields ({len(fields)}) must match number of initial conditions "
                f"in y0 ({len(props.y0)}). Field keys: {list(fields)}."
            )

        # Merge once: a copy of the fixed args with the trainable params layered
        # on top (params win on key clashes). The same mapping is used for the
        # validation below and stored as ``self.args``.
        merged_args = props.args.copy()
        merged_args.update(params)

        if props.expected_args is not None:
            missing = props.expected_args - merged_args.keys()
            if missing:
                raise ValueError(
                    f"ODE function expects args {sorted(missing)!r} but they are not in "
                    f"props.args or params. Available keys: {sorted(merged_args.keys())!r}."
                )

        self.fields = fields
        self.weight = weight
        self.order = props.order
        self.ode = props.ode
        self.args = merged_args

    @override
    def loss(
        self,
        batch: TrainingBatch,
        criterion: nn.Module,
        log: LogFn | None = None,
    ) -> Tensor:
        """
        Compute the ODE residual loss over collocation points.

        Args:
            batch: Training batch; its second element holds the collocation points.
            criterion: Callable comparing the highest derivative with the ODE output.
            log: Optional logging callback, invoked with ``"loss/res"``.

        Returns:
            ``weight * criterion(order-th derivative, ode(...))``.
        """
        _, x_coll = batch

        field_fns = list(self.fields.values())
        # Each field gets its own detached copy of the collocation points that
        # requires grad, so derivatives are taken per field.
        x_copies = [x_coll.detach().clone().requires_grad_(True) for _ in field_fns]
        preds = [f(x) for f, x in zip(field_fns, x_copies)]
        y = torch.stack(preds)

        # Build all derivative levels by chaining (each level differentiates the previous)
        # deriv_levels[k][i] = (k+1)-th derivative of field i
        deriv_levels: list[list[Tensor]] = []
        currents = preds
        for _ in range(self.order):
            currents = [diff_grad(c, x) for c, x in zip(currents, x_copies)]
            deriv_levels.append(currents)

        # derivs[k] = (k+1)-th derivative stacked across fields, passed to the ODE callable
        derivs = [torch.stack(level) for level in deriv_levels[:-1]]
        # The order-th derivative is the LHS to compare against f_out
        high_deriv = torch.stack(deriv_levels[-1])

        if self.order == 1:
            f_out = self.ode(x_coll, y, self.args)
        else:
            f_out = cast(_ODECallableN, self.ode)(x_coll, y, self.args, derivs)

        loss: Tensor = self.weight * criterion(high_deriv, f_out)

        if log is not None:
            log("loss/res", loss)

        return loss

args = props.args.copy() instance-attribute

fields = fields instance-attribute

ode = props.ode instance-attribute

order = props.order instance-attribute

weight = weight instance-attribute

__init__(props: ODEProperties, fields: FieldsRegistry, params: ParamsRegistry, weight: float = 1.0)

Source code in src/anypinn/problems/ode.py
def __init__(
    self,
    props: ODEProperties,
    fields: FieldsRegistry,
    params: ParamsRegistry,
    weight: float = 1.0,
):
    """
    Validate the inputs and merge the trainable params into the ODE args.

    Raises:
        ValueError: If the number of fields differs from ``len(props.y0)``, or
            if ``props.expected_args`` names keys found in neither
            ``props.args`` nor ``params``.
    """
    if len(fields) != len(props.y0):
        raise ValueError(
            f"Number of fields ({len(fields)}) must match number of initial conditions "
            f"in y0 ({len(props.y0)}). Field keys: {list(fields)}."
        )

    # Merge once: a copy of the fixed args with the trainable params layered
    # on top (params win on key clashes). The same mapping is used for the
    # validation below and stored as ``self.args``.
    merged_args = props.args.copy()
    merged_args.update(params)

    if props.expected_args is not None:
        missing = props.expected_args - merged_args.keys()
        if missing:
            raise ValueError(
                f"ODE function expects args {sorted(missing)!r} but they are not in "
                f"props.args or params. Available keys: {sorted(merged_args.keys())!r}."
            )

    self.fields = fields
    self.weight = weight
    self.order = props.order
    self.ode = props.ode
    self.args = merged_args

loss(batch: TrainingBatch, criterion: nn.Module, log: LogFn | None = None) -> Tensor

Compute the ODE residual loss over collocation points.

Source code in src/anypinn/problems/ode.py
@override
def loss(
    self,
    batch: TrainingBatch,
    criterion: nn.Module,
    log: LogFn | None = None,
) -> Tensor:
    """
    Compute the ODE residual loss over collocation points.

    The second element of ``batch`` holds the collocation points. The
    ``order``-th derivative of every field is compared, through
    ``criterion``, with the output of the ODE callable; when ``log`` is
    given it is called with the key ``"loss/res"``.
    """
    _, x_coll = batch

    field_fns = list(self.fields.values())
    # Each field gets its own detached copy of the collocation points that
    # requires grad, so derivatives are taken per field.
    x_copies = [x_coll.detach().clone().requires_grad_(True) for _ in field_fns]
    preds = [f(x) for f, x in zip(field_fns, x_copies)]
    y = torch.stack(preds)

    # Chain the derivative levels: level k differentiates level k-1, so
    # deriv_levels[k][i] is the (k+1)-th derivative of field i.
    deriv_levels: list[list[Tensor]] = []
    currents = preds
    for _ in range(self.order):
        currents = [diff_grad(c, x) for c, x in zip(currents, x_copies)]
        deriv_levels.append(currents)

    # All levels below the highest one are stacked per level and handed to the
    # ODE callable; the highest level is the left-hand side of the residual.
    derivs = [torch.stack(level) for level in deriv_levels[:-1]]
    high_deriv = torch.stack(deriv_levels[-1])

    if self.order == 1:
        f_out = self.ode(x_coll, y, self.args)
    else:
        f_out = cast(_ODECallableN, self.ode)(x_coll, y, self.args, derivs)

    loss: Tensor = self.weight * criterion(high_deriv, f_out)

    if log is not None:
        log("loss/res", loss)

    return loss