Skip to content

anypinn.cli.scaffold.sir.ode

SIR epidemic model — mathematical definition.

BETA_KEY = 'beta' module-attribute

C = 1000000.0 module-attribute

DELTA = 1 / 5 module-attribute

DELTA_KEY = 'delta' module-attribute

I_KEY = 'I' module-attribute

N_KEY = 'N' module-attribute

N_POP = 56000000.0 module-attribute

S_KEY = 'S' module-attribute

T = 90 module-attribute

TRUE_BETA = 0.6 module-attribute

validation_csv: ValidationRegistry = {BETA_KEY: lambda x: torch.full_like(x, TRUE_BETA)} module-attribute

validation_synthetic: ValidationRegistry = {BETA_KEY: lambda x: torch.full_like(x, TRUE_BETA)} module-attribute

SIR(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor

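Scaled SIR ODE system: states are expressed in units of C (the initial condition is divided by C), and the right-hand side is multiplied by T to rescale time.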

Source code in src/anypinn/cli/scaffold/sir/ode.py
def SIR(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor:
    """Scaled SIR ODE system.

    Right-hand side of the SIR model written for rescaled states and time.

    States are expected to be divided by ``C``. With ``s = S / C`` and
    ``i = I / C``, the usual flux ``beta * S * I / N`` becomes
    ``beta * s * i * C / N`` in scaled units.

    Both derivatives are multiplied by ``T`` (chain rule). This presumably
    means ``x`` is a normalized time coordinate with ``t = T * x``; confirm
    against the data module.

    Args:
        x: Input at which the argument callables are evaluated.
        y: Stacked states, unpacked along the first dimension as ``(S, I)``.
        args: Registry mapping ``BETA_KEY``, ``DELTA_KEY`` and ``N_KEY`` to
            callables of ``x``.

    Returns:
        ``torch.stack([dS, dI])``, i.e. the derivatives stacked along a new
        leading dimension.
    """
    S, I = y
    beta = args[BETA_KEY](x)
    delta = args[DELTA_KEY](x)
    N = args[N_KEY](x)

    # Evaluate every argument once and share the S -> I flux between the two
    # equations. beta may be a learnable Parameter, so this avoids a second
    # forward pass, and the transfer term is guaranteed identical in dS and dI.
    infection = beta * I * S * C / N

    dS = -infection
    dI = infection - delta * I

    # Rescale the time derivative by T (chain rule for the time scaling).
    return torch.stack([dS, dI]) * T

create_data_module_csv(hp: ODEHyperparameters)

Source code in src/anypinn/cli/scaffold/sir/ode.py
def create_data_module_csv(hp: ODEHyperparameters):
    """Build the SIR inverse-problem data module for CSV-backed data.

    Validation uses ``validation_csv``. Observations are rescaled by
    ``1 / C`` through a ``DataScaling`` callback, consistent with the ``1 / C``
    state scaling used by the scaled SIR ODE.

    Args:
        hp: Hyperparameters forwarded unchanged to ``SIRInvDataModule``.

    Returns:
        A configured ``SIRInvDataModule`` instance.
    """
    from anypinn.catalog.sir import SIRInvDataModule

    scaling = DataScaling(y_scale=1 / C)
    return SIRInvDataModule(hp=hp, validation=validation_csv, callbacks=[scaling])

create_data_module_synthetic(hp: ODEHyperparameters)

Source code in src/anypinn/cli/scaffold/sir/ode.py
def create_data_module_synthetic(hp: ODEHyperparameters):
    """Build the SIR inverse-problem data module backed by synthetic data.

    Data are generated from the *unscaled* SIR equations (no ``C`` or ``T``
    scaling) using the ground-truth arguments ``TRUE_BETA``, ``DELTA`` and
    ``N_POP``. The run starts from ``N_POP - 1`` susceptible and 1 infected.

    Observations are then rescaled by ``1 / C`` through a ``DataScaling``
    callback. Validation uses ``validation_synthetic``.

    Args:
        hp: Hyperparameters forwarded unchanged to ``SIRInvDataModule``.

    Returns:
        A configured ``SIRInvDataModule`` instance.
    """
    from anypinn.catalog.sir import SIRInvDataModule

    # Unscaled SIR for data generation (no C or T scaling)
    def SIR_unscaled(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor:
        """Unscaled SIR right-hand side; returns ``stack([dS, dI])``."""
        S, I = y
        beta = args[BETA_KEY](x)
        delta = args[DELTA_KEY](x)
        N = args[N_KEY](x)
        # Evaluate each argument once and share the S -> I flux between the
        # two equations, mirroring the scaled SIR ODE.
        infection = beta * S * I / N
        return torch.stack([-infection, infection - delta * I])

    gen_props = ODEProperties(
        ode=SIR_unscaled,
        # NOTE(review): if torch's default dtype is float32, N_POP - 1 is not
        # exactly representable and rounds to N_POP; negligible at this scale.
        y0=torch.tensor([N_POP - 1, 1]),
        args={
            BETA_KEY: Argument(TRUE_BETA),
            DELTA_KEY: Argument(DELTA),
            N_KEY: Argument(N_POP),
        },
    )

    return SIRInvDataModule(
        hp=hp,
        gen_props=gen_props,
        validation=validation_synthetic,
        callbacks=[DataScaling(y_scale=1 / C)],
    )

create_problem(hp: ODEHyperparameters) -> ODEInverseProblem

Source code in src/anypinn/cli/scaffold/sir/ode.py
def create_problem(hp: ODEHyperparameters) -> ODEInverseProblem:
    """Assemble the SIR inverse problem.

    ``S`` and ``I`` are modelled as fields and ``beta`` is the unknown
    parameter to recover. ``delta`` and ``N`` are supplied as fixed arguments.

    The initial condition is ``N_POP - 1`` susceptible and 1 infected,
    divided by ``C`` to match the scaled state space of the ``SIR`` ODE.

    Args:
        hp: Hyperparameters; ``hp.fields_config`` configures each field and
            ``hp.params_config`` configures the ``beta`` parameter.

    Returns:
        The configured ``ODEInverseProblem``.
    """
    # Known quantities only: beta is omitted here because it is a learned parameter.
    known_args = {
        DELTA_KEY: Argument(DELTA),
        N_KEY: Argument(N_POP),
    }
    props = ODEProperties(
        ode=SIR,
        y0=torch.tensor([N_POP - 1, 1]) / C,
        args=known_args,
    )

    fields = FieldsRegistry({key: Field(config=hp.fields_config) for key in (S_KEY, I_KEY)})
    params = ParamsRegistry({BETA_KEY: Parameter(config=hp.params_config)})

    def predict_data(x_data: Tensor, fields: FieldsRegistry, _params: ParamsRegistry) -> Tensor:
        """Map the model to observations: only the I field is observed."""
        # unsqueeze(1) inserts an axis at dim 1 -- presumably turning (batch,)
        # into (batch, 1); confirm against the field output shape.
        return fields[I_KEY](x_data).unsqueeze(1)

    return ODEInverseProblem(
        props=props,
        hp=hp,
        fields=fields,
        params=params,
        predict_data=predict_data,
    )