Skip to content

anypinn.cli.scaffold.fitzhugh_nagumo.ode_csv

FitzHugh-Nagumo neuron model — mathematical definition.

A_KEY = 'a' module-attribute

B = 0.8 module-attribute

EPSILON_KEY = 'epsilon' module-attribute

I_EXT = 0.5 module-attribute

NOISE_STD = 0.05 module-attribute

TRUE_A = 0.7 module-attribute

TRUE_EPSILON = 0.08 module-attribute

T_TOTAL = 50 module-attribute

V0 = -1.0 module-attribute

V_KEY = 'v' module-attribute

W0 = 1.0 module-attribute

W_KEY = 'w' module-attribute

validation: ValidationRegistry = {} module-attribute

create_data_module(hp: ODEHyperparameters)

Source code in src/anypinn/cli/scaffold/fitzhugh_nagumo/ode_csv.py
def create_data_module(hp: ODEHyperparameters):
    """Build the FitzHugh-Nagumo data module from ground-truth parameters.

    An ODE description is assembled from three pieces:

    * the FitzHugh-Nagumo right-hand side in unscaled time,
    * the initial state ``[V0, W0]``,
    * the true coefficients ``TRUE_EPSILON`` and ``TRUE_A``, wrapped as
      ``Argument`` objects.

    That description is passed to ``FitzHughNagumoDataModule`` as
    ``gen_props``, together with ``NOISE_STD`` and the module-level
    ``validation`` registry. NOTE(review): presumably the data module
    integrates ``gen_props`` and perturbs the result with noise of scale
    ``noise_std``; confirm in ``FitzHughNagumoDataModule``.

    Args:
        hp: Hyperparameters forwarded unchanged to the data module.

    Returns:
        The constructed ``FitzHughNagumoDataModule``.
    """
    # Deferred import; the reason (import cost or cycle avoidance) is not
    # visible from here.
    from anypinn.catalog.fitzhugh_nagumo import FitzHughNagumoDataModule

    def fhn_unscaled(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor:
        # dv/dt = v - v^3/3 - w + I_ext
        # dw/dt = epsilon * (v + a - B * w)
        v, w = y
        epsilon = args[EPSILON_KEY]
        offset = args[A_KEY]
        return torch.stack(
            [
                v - v**3 / 3 - w + I_EXT,
                epsilon(x) * (v + offset(x) - B * w),
            ]
        )

    initial_state = torch.tensor([V0, W0])
    true_args = {
        EPSILON_KEY: Argument(TRUE_EPSILON),
        A_KEY: Argument(TRUE_A),
    }
    generator = ODEProperties(ode=fhn_unscaled, y0=initial_state, args=true_args)

    return FitzHughNagumoDataModule(
        hp=hp,
        gen_props=generator,
        noise_std=NOISE_STD,
        validation=validation,
    )

create_problem(hp: ODEHyperparameters) -> ODEInverseProblem

Source code in src/anypinn/cli/scaffold/fitzhugh_nagumo/ode_csv.py
def create_problem(hp: ODEHyperparameters) -> ODEInverseProblem:
    """Assemble the FitzHugh-Nagumo inverse problem used for training.

    The problem combines the following:

    * the time-scaled right-hand side ``fhn_scaled``, with initial state
      ``[V0, W0]`` and no fixed ODE arguments;
    * one ``Field`` per state variable (``v`` then ``w``);
    * one ``Parameter`` per unknown coefficient (``epsilon`` then ``a``).

    NOTE(review): ``args`` is empty here, so ``fhn_scaled`` presumably
    receives ``epsilon`` and ``a`` from the parameter registry; confirm how
    ``ODEInverseProblem`` routes ``params`` into the ODE arguments.

    Args:
        hp: Hyperparameters. ``hp.fields_config`` configures each field and
            ``hp.params_config`` configures each parameter.

    Returns:
        The configured ``ODEInverseProblem``.
    """

    def predict_data(x_data: Tensor, fields: FieldsRegistry, _params: ParamsRegistry) -> Tensor:
        # Only v is observed (w is never read). The original author notes the
        # result as (N, 1, 1), which assumes the field returns (N, 1).
        return fields[V_KEY](x_data).unsqueeze(1)

    # Fixed iteration order keeps construction sequence v, w / epsilon, a.
    # This matters if Field/Parameter initialisation draws from an RNG
    # (not visible here).
    state_keys = (V_KEY, W_KEY)
    unknown_keys = (EPSILON_KEY, A_KEY)

    return ODEInverseProblem(
        props=ODEProperties(
            ode=fhn_scaled,
            y0=torch.tensor([V0, W0]),
            args={},
        ),
        hp=hp,
        fields=FieldsRegistry(
            {key: Field(config=hp.fields_config) for key in state_keys}
        ),
        params=ParamsRegistry(
            {key: Parameter(config=hp.params_config) for key in unknown_keys}
        ),
        predict_data=predict_data,
    )

fhn_scaled(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor

Scaled FitzHugh-Nagumo (FHN) ODE used for training; time is rescaled by T_TOTAL.

Source code in src/anypinn/cli/scaffold/fitzhugh_nagumo/ode_csv.py
def fhn_scaled(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor:
    """Scaled FHN ODE for training. Time scaled by T_TOTAL.

    Evaluates the FitzHugh-Nagumo rates and multiplies both by ``T_TOTAL``:

        dv/dx = T_TOTAL * (v - v**3 / 3 - w + I_EXT)
        dw/dx = T_TOTAL * epsilon(x) * (v + a(x) - B * w)

    This is the chain-rule form for a time coordinate ``x`` with
    ``t = T_TOTAL * x``. NOTE(review): this presumes ``x`` is time
    normalised by ``T_TOTAL``; confirm against the data module.

    Args:
        x: Time coordinate passed to the ``epsilon`` and ``a`` callables.
        y: State, unpacked along its first axis into ``v`` and ``w``.
        args: Registry holding callables under ``EPSILON_KEY`` and ``A_KEY``.

    Returns:
        ``torch.stack`` of the two scaled derivatives (v first, then w).
    """
    v, w = y
    epsilon = args[EPSILON_KEY]
    offset = args[A_KEY]

    # Stack the unscaled rates, then apply the time factor once; this is
    # elementwise identical to scaling each rate before stacking.
    rates = torch.stack(
        [
            v - v**3 / 3 - w + I_EXT,
            epsilon(x) * (v + offset(x) - B * w),
        ]
    )
    return rates * T_TOTAL