Skip to content

anypinn.cli.scaffold.damped_oscillator.ode_csv

Damped oscillator — mathematical definition.

NOISE_STD = 0.02 module-attribute

OMEGA_KEY = 'omega0' module-attribute

TRUE_OMEGA0 = 2 * math.pi module-attribute

TRUE_ZETA = 0.15 module-attribute

T_TOTAL = 5 module-attribute

V0 = 0.0 module-attribute

V_KEY = 'v' module-attribute

X0 = 1.0 module-attribute

X_KEY = 'x' module-attribute

ZETA_KEY = 'zeta' module-attribute

validation: ValidationRegistry = {} module-attribute

create_data_module(hp: ODEHyperparameters)

Source code in src/anypinn/cli/scaffold/damped_oscillator/ode_csv.py
def create_data_module(hp: ODEHyperparameters):
    """Build the data module that generates synthetic damped-oscillator observations.

    The generating ODE is evaluated with the ground-truth constants ``TRUE_ZETA``
    and ``TRUE_OMEGA0``. Unlike ``oscillator``, its right-hand side is *not*
    multiplied by ``T_TOTAL``. The module is created with ``noise_std=NOISE_STD``,
    the module-level ``validation`` registry, and a ``DataScaling`` callback with
    ``y_scale=1.0``.

    Args:
        hp: Hyperparameters forwarded unchanged to ``DampedOscillatorDataModule``.

    Returns:
        A configured ``DampedOscillatorDataModule``.
    """
    from anypinn.catalog.damped_oscillator import DampedOscillatorDataModule

    def unscaled_rhs(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor:
        # Same system as ``oscillator`` but without the T_TOTAL factor:
        # dx = v, dv = -2 * zeta * omega0 * v - omega0**2 * x.
        position, velocity = y
        zeta_fn, omega_fn = args[ZETA_KEY], args[OMEGA_KEY]
        accel = -2 * zeta_fn(x) * omega_fn(x) * velocity - omega_fn(x) ** 2 * position
        return torch.stack([velocity, accel])

    ground_truth = {
        ZETA_KEY: Argument(TRUE_ZETA),
        OMEGA_KEY: Argument(TRUE_OMEGA0),
    }
    gen_props = ODEProperties(ode=unscaled_rhs, y0=torch.tensor([X0, V0]), args=ground_truth)

    scaling = [DataScaling(y_scale=1.0)]
    return DampedOscillatorDataModule(
        hp=hp,
        gen_props=gen_props,
        noise_std=NOISE_STD,
        validation=validation,
        callbacks=scaling,
    )

create_problem(hp: ODEHyperparameters) -> ODEInverseProblem

Source code in src/anypinn/cli/scaffold/damped_oscillator/ode_csv.py
def create_problem(hp: ODEHyperparameters) -> ODEInverseProblem:
    """Assemble the inverse problem in which the damping ratio is the unknown.

    The natural frequency is fixed at ``TRUE_OMEGA0`` through the ODE arguments.
    The damping ratio (``ZETA_KEY``) is registered as a learnable ``Parameter``.
    Position and velocity are each modelled by a ``Field`` built from
    ``hp.fields_config``. Only the position field is used when predicting data.

    Args:
        hp: Hyperparameters supplying ``fields_config`` and ``params_config``;
            also forwarded to ``ODEInverseProblem``.

    Returns:
        The configured ``ODEInverseProblem``.
    """
    # NOTE(review): ``oscillator`` reads args[ZETA_KEY], yet only OMEGA_KEY is
    # supplied here. Presumably ODEInverseProblem injects the learnable params
    # into the ODE args — confirm against anypinn.
    fixed_args = {OMEGA_KEY: Argument(TRUE_OMEGA0)}
    ode_props = ODEProperties(
        ode=oscillator,
        y0=torch.tensor([X0, V0]),
        args=fixed_args,
    )

    # Keep construction order (x field, v field, zeta parameter): these objects
    # may draw from the global RNG while initialising.
    field_registry = FieldsRegistry(
        {X_KEY: Field(config=hp.fields_config), V_KEY: Field(config=hp.fields_config)}
    )
    learnable = ParamsRegistry({ZETA_KEY: Parameter(config=hp.params_config)})

    def predict_data(x_data: Tensor, fields: FieldsRegistry, _params: ParamsRegistry) -> Tensor:
        """Return the position-field prediction at ``x_data`` with a new axis at dim 1.

        The parameter registry is unused here.
        """
        return fields[X_KEY](x_data).unsqueeze(1)

    return ODEInverseProblem(
        props=ode_props,
        hp=hp,
        fields=field_registry,
        params=learnable,
        predict_data=predict_data,
    )

oscillator(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor

Scaled damped oscillator ODE: the right-hand side of \(dx/dt = v\), \(dv/dt = -2\zeta\omega_0 v - \omega_0^2 x\), with both components multiplied by `T_TOTAL`.

Source code in src/anypinn/cli/scaffold/damped_oscillator/ode_csv.py
def oscillator(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor:
    r"""Time-scaled damped oscillator ODE.

    Right-hand side of $\dot{x} = v$, $\dot{v} = -2\zeta\omega_0 v - \omega_0^2 x$,
    with both components multiplied by ``T_TOTAL`` ($T$). By the chain rule this
    is the same system written with respect to $\tau = t / T$:
    $dx/d\tau = T\,v$ and $dv/d\tau = T\,(-2\zeta\omega_0 v - \omega_0^2 x)$.

    Args:
        x: Independent variable at which the ``args`` callables are evaluated
            (presumably the normalized time input — confirm).
        y: State, unpacked along its first dimension as ``(position, velocity)``.
        args: Registry providing callables under ``ZETA_KEY`` and ``OMEGA_KEY``.

    Returns:
        ``torch.stack([dx, dv])``: the ``T_TOTAL``-scaled derivatives of
        position and velocity.
    """
    pos, vel = y
    zeta = args[ZETA_KEY](x)
    # Evaluate omega0 once; it appears in both the damping and restoring terms.
    omega0 = args[OMEGA_KEY](x)

    dx = vel
    dv = -2 * zeta * omega0 * vel - omega0**2 * pos

    # Rescale from physical time t to tau = t / T_TOTAL (d/dtau = T_TOTAL * d/dt).
    return torch.stack([dx * T_TOTAL, dv * T_TOTAL])