Skip to content

anypinn.cli.scaffold.seir.ode_csv

SEIR epidemic model — mathematical definition.

BETA_KEY = 'beta' module-attribute

E0 = 0.01 module-attribute

E_KEY = 'E' module-attribute

GAMMA_KEY = 'gamma' module-attribute

I0 = 0.001 module-attribute

I_KEY = 'I' module-attribute

NOISE_STD = 0.0005 module-attribute

S0 = 0.99 module-attribute

SIGMA_KEY = 'sigma' module-attribute

S_KEY = 'S' module-attribute

TRUE_BETA = 0.5 module-attribute

TRUE_GAMMA = 1 / 10 module-attribute

TRUE_SIGMA = 1 / 5.2 module-attribute

T_DAYS = 160 module-attribute

validation: ValidationRegistry = {} module-attribute

SEIR(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor

Scaled SEIR ODE system: the SEIR right-hand side (dS, dE, dI) with each derivative multiplied by `T_DAYS` to rescale time.

Source code in src/anypinn/cli/scaffold/seir/ode_csv.py
def SEIR(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor:
    """Right-hand side of the SEIR system, rescaled by ``T_DAYS``.

    Args:
        x: Points at which the rate callables are evaluated.
        y: State unpacked along its first dimension as (S, E, I).
        args: Registry holding callables under ``BETA_KEY``, ``SIGMA_KEY``
            and ``GAMMA_KEY``.

    Returns:
        ``torch.stack`` of (dS, dE, dI), each multiplied by ``T_DAYS``.
    """
    susceptible, exposed, infected = y
    # Look up all three rate callables before evaluating any of them.
    beta, sigma, gamma = (args[key] for key in (BETA_KEY, SIGMA_KEY, GAMMA_KEY))

    d_susceptible = -beta(x) * susceptible * infected
    d_exposed = beta(x) * susceptible * infected - sigma(x) * exposed
    d_infected = sigma(x) * exposed - gamma(x) * infected

    # Multiplying by T_DAYS rescales the time axis of the system.
    return torch.stack(
        [deriv * T_DAYS for deriv in (d_susceptible, d_exposed, d_infected)]
    )

create_data_module(hp: ODEHyperparameters)

Source code in src/anypinn/cli/scaffold/seir/ode_csv.py
def create_data_module(hp: ODEHyperparameters):
    """Build the SEIR data module used to generate the training data.

    The generating ODE is the SEIR system *without* the ``T_DAYS`` rescaling.
    It starts from ``(S0, E0, I0)`` and uses the ground-truth rates
    ``TRUE_BETA``, ``TRUE_SIGMA`` and ``TRUE_GAMMA``. ``NOISE_STD`` is passed
    to the data module as ``noise_std``.

    Args:
        hp: Hyperparameters forwarded to ``SEIRDataModule``.

    Returns:
        A configured ``SEIRDataModule``.
    """
    from anypinn.catalog.seir import SEIRDataModule

    def SEIR_unscaled(x: Tensor, y: Tensor, args: ArgsRegistry) -> Tensor:
        # Same equations as the module-level SEIR, minus the T_DAYS factor.
        susceptible, exposed, infected = y
        beta, sigma, gamma = (args[key] for key in (BETA_KEY, SIGMA_KEY, GAMMA_KEY))
        return torch.stack(
            [
                -beta(x) * susceptible * infected,
                beta(x) * susceptible * infected - sigma(x) * exposed,
                sigma(x) * exposed - gamma(x) * infected,
            ]
        )

    true_rates = {
        BETA_KEY: TRUE_BETA,
        SIGMA_KEY: TRUE_SIGMA,
        GAMMA_KEY: TRUE_GAMMA,
    }
    gen_props = ODEProperties(
        ode=SEIR_unscaled,
        y0=torch.tensor([S0, E0, I0]),
        args={key: Argument(rate) for key, rate in true_rates.items()},
    )

    scaling = DataScaling(y_scale=1.0)
    return SEIRDataModule(
        hp=hp,
        gen_props=gen_props,
        noise_std=NOISE_STD,
        validation=validation,
        callbacks=[scaling],
    )

create_problem(hp: ODEHyperparameters) -> ODEInverseProblem

Source code in src/anypinn/cli/scaffold/seir/ode_csv.py
def create_problem(hp: ODEHyperparameters) -> ODEInverseProblem:
    """Assemble the SEIR inverse problem in which ``beta`` is estimated.

    ``sigma`` and ``gamma`` are supplied as fixed ``Argument`` values.
    ``beta`` is registered as a trainable ``Parameter``. Each compartment
    (S, E, I) gets its own ``Field``. Only the I field is used to predict
    the observed data.

    Args:
        hp: Hyperparameters providing ``fields_config`` and ``params_config``.

    Returns:
        The configured ``ODEInverseProblem``.
    """
    fixed_rates = {
        SIGMA_KEY: Argument(TRUE_SIGMA),
        GAMMA_KEY: Argument(TRUE_GAMMA),
    }
    props = ODEProperties(
        ode=SEIR,
        y0=torch.tensor([S0, E0, I0]),
        args=fixed_rates,
    )

    fields = FieldsRegistry(
        {key: Field(config=hp.fields_config) for key in (S_KEY, E_KEY, I_KEY)}
    )
    params = ParamsRegistry({BETA_KEY: Parameter(config=hp.params_config)})

    def predict_data(x_data: Tensor, fields: FieldsRegistry, _params: ParamsRegistry) -> Tensor:
        # Only the infected compartment is compared with data; unsqueeze(1)
        # inserts an extra axis at position 1 of the field output.
        return fields[I_KEY](x_data).unsqueeze(1)

    return ODEInverseProblem(
        props=props,
        hp=hp,
        fields=fields,
        params=params,
        predict_data=predict_data,
    )