diff --git a/flepimop/gempyor_pkg/src/gempyor/cli.py b/flepimop/gempyor_pkg/src/gempyor/cli.py index 405509ba7..529fd2638 100644 --- a/flepimop/gempyor_pkg/src/gempyor/cli.py +++ b/flepimop/gempyor_pkg/src/gempyor/cli.py @@ -39,74 +39,72 @@ def patch(ctx: click.Context = mock_context, **kwargs) -> None: \b ```bash - $ flepimop patch config_sample_2pop_modifiers_part.yml config_sample_2pop_outcomes_part.yml > config_sample_2pop_patched.yml - $ cat config_sample_2pop_patched.yml - outcome_modifiers_scenarios: [] - first_sim_index: 1 - jobs: 14 - stoch_traj_flag: false - seir_modifiers_scenarios: [] - write_parquet: true - write_csv: false - config_src: [config_sample_2pop_modifiers_part.yml, config_sample_2pop_outcomes_part.yml] - seir_modifiers: - scenarios: [Ro_lockdown, Ro_all] - modifiers: - Ro_lockdown: - method: SinglePeriodModifier - parameter: Ro - period_start_date: 2020-03-15 - period_end_date: 2020-05-01 - subpop: all - value: 0.4 - Ro_relax: - method: SinglePeriodModifier - parameter: Ro - period_start_date: 2020-05-01 - period_end_date: 2020-08-31 - subpop: all - value: 0.8 - Ro_all: - method: StackedModifier - modifiers: [Ro_lockdown, Ro_relax] - outcome_modifiers: - scenarios: [test_limits] - modifiers: - test_limits: - method: SinglePeriodModifier - parameter: incidCase::probability - subpop: all - period_start_date: 2020-02-01 - period_end_date: 2020-06-01 - value: 0.5 - outcomes: - method: delayframe - outcomes: - incidCase: - source: - incidence: - infection_stage: I - probability: + $ flepimop patch config_sample_2pop_modifiers_part.yml config_sample_2pop_outcomes_part.yml > config_sample_2pop_patched.yml + $ cat config_sample_2pop_patched.yml + write_csv: false + stoch_traj_flag: false + jobs: 14 + write_parquet: true + first_sim_index: 1 + config_src: [config_sample_2pop_modifiers_part.yml, config_sample_2pop_outcomes_part.yml] + seir_modifiers: + scenarios: [Ro_lockdown, Ro_all] + modifiers: + Ro_lockdown: + method: SinglePeriodModifier + parameter: Ro + period_start_date: 2020-03-15 + period_end_date: 2020-05-01 + subpop: all + value: 0.4 + Ro_relax: + method: SinglePeriodModifier + parameter: Ro + period_start_date: 2020-05-01 + period_end_date: 2020-08-31 + subpop: all + value: 0.8 + Ro_all: + method: StackedModifier + modifiers: [Ro_lockdown, Ro_relax] + outcome_modifiers: + scenarios: [test_limits] + modifiers: + test_limits: + method: SinglePeriodModifier + parameter: incidCase::probability + subpop: all + period_start_date: 2020-02-01 + period_end_date: 2020-06-01 value: 0.5 - delay: - value: 5 - incidHosp: - source: - incidence: - infection_stage: I - probability: - value: 0.05 - delay: - value: 7 - duration: - value: 10 - name: currHosp - incidDeath: - source: incidHosp - probability: - value: 0.2 - delay: - value: 14 + outcomes: + method: delayframe + outcomes: + incidCase: + source: + incidence: + infection_stage: I + probability: + value: 0.5 + delay: + value: 5 + incidHosp: + source: + incidence: + infection_stage: I + probability: + value: 0.05 + delay: + value: 7 + duration: + value: 10 + name: currHosp + incidDeath: + source: incidHosp + probability: + value: 0.2 + delay: + value: 14 ``` """ parse_config_files(config, ctx, **kwargs) diff --git a/flepimop/gempyor_pkg/src/gempyor/shared_cli.py b/flepimop/gempyor_pkg/src/gempyor/shared_cli.py index a88c599dd..367769cde 100644 --- a/flepimop/gempyor_pkg/src/gempyor/shared_cli.py +++ b/flepimop/gempyor_pkg/src/gempyor/shared_cli.py @@ -255,12 +255,14 @@ def _parse_option(param: click.Parameter, value: Any) -> Any: cfg["config_src"] = [str(k) for k in config_src] # deal with the scenario overrides - scen_args = {k for k in parsed_args if k.endswith("scenarios") and kwargs.get(k)} - for option in scen_args: + scen_args = {k for k in parsed_args if k.endswith("_scenarios")} + for option in {s for s in scen_args if kwargs.get(s)}: key = option.replace("_scenarios", "") value = _parse_option(config_file_options[option], kwargs[option]) if cfg[key].exists(): - cfg[key]["scenarios"] = as_list(value) + cfg[key]["scenarios"] = ( + list(value) if isinstance(value, tuple) else as_list(value) + ) else: raise ValueError( f"Specified {option} when no {key} in configuration file(s): {config_src}" diff --git a/flepimop/gempyor_pkg/tests/cli/test_flepimop_patch_cli.py b/flepimop/gempyor_pkg/tests/cli/test_flepimop_patch_cli.py index 35fa57470..b4ba20021 100644 --- a/flepimop/gempyor_pkg/tests/cli/test_flepimop_patch_cli.py +++ b/flepimop/gempyor_pkg/tests/cli/test_flepimop_patch_cli.py @@ -67,3 +67,346 @@ def test_overlapping_sections_value_error( assert str(result.exception) == ( "Configuration files contain overlapping keys, seir, introduced by config_two.yml." ) + + +@pytest.mark.parametrize( + ("data", "seir_modifier_scenarios", "outcome_modifier_scenarios"), + ( + ( + { + "seir_modifiers": { + "scenarios": ["Ro_lockdown", "Ro_all"], + "modifiers": { + "Ro_lockdown": { + "method": "SinglePeriodModifier", + "parameter": "Ro", + "period_start_date": "2020-03-15", + "period_end_date": "2020-05-01", + "subpop": "all", + "value": 0.4, + }, + "Ro_relax": { + "method": "SinglePeriodModifier", + "parameter": "Ro", + "period_start_date": "2020-05-01", + "period_end_date": "2020-07-01", + "subpop": "all", + "value": 0.8, + }, + "Ro_all": { + "method": "StackedModifier", + "modifiers": ["Ro_lockdown", "Ro_relax"], + }, + }, + }, + }, + [], + [], + ), + ( + { + "seir_modifiers": { + "scenarios": ["Ro_lockdown", "Ro_all"], + "modifiers": { + "Ro_lockdown": { + "method": "SinglePeriodModifier", + "parameter": "Ro", + "period_start_date": "2020-03-15", + "period_end_date": "2020-05-01", + "subpop": "all", + "value": 0.4, + }, + "Ro_relax": { + "method": "SinglePeriodModifier", + "parameter": "Ro", + "period_start_date": "2020-05-01", + "period_end_date": "2020-07-01", + "subpop": "all", + "value": 0.8, + }, + "Ro_all": { + "method": "StackedModifier", + "modifiers": ["Ro_lockdown", "Ro_relax"], + }, + }, + }, + }, + ["Ro_all"], + [], + ), + ( + { + "seir_modifiers": { + "scenarios": ["Ro_lockdown", "Ro_all"], + "modifiers": { + "Ro_lockdown": { + "method": "SinglePeriodModifier", + "parameter": "Ro", + "period_start_date": "2020-03-15", + "period_end_date": "2020-05-01", + "subpop": "all", + "value": 0.4, + }, + "Ro_relax": { + "method": "SinglePeriodModifier", + "parameter": "Ro", + "period_start_date": "2020-05-01", + "period_end_date": "2020-07-01", + "subpop": "all", + "value": 0.8, + }, + "Ro_all": { + "method": "StackedModifier", + "modifiers": ["Ro_lockdown", "Ro_relax"], + }, + }, + }, + }, + ["Ro_all", "Ro_relax", "Ro_lockdown"], + [], + ), + ( + { + "outcome_modifiers": { + "scenarios": ["test_limits"], + "modifiers": { + "test_limits": { + "method": "SinglePeriodModifier", + "parameter": "incidCase::probability", + "subpop": "all", + "period_start_date": "2020-02-01", + "period_end_date": "2020-06-01", + "value": 0.5, + }, + "test_expansion": { + "method": "SinglePeriodModifier", + "parameter": "incidCase::probability", + "period_start_date": "2020-06-01", + "period_end_date": "2020-08-01", + "subpop": "all", + "value": 0.7, + }, + "test_limits_expansion": { + "method": "StackedModifier", + "modifiers": ["test_limits", "test_expansion"], + }, + }, + }, + }, + [], + [], + ), + ( + { + "outcome_modifiers": { + "scenarios": ["test_limits"], + "modifiers": { + "test_limits": { + "method": "SinglePeriodModifier", + "parameter": "incidCase::probability", + "subpop": "all", + "period_start_date": "2020-02-01", + "period_end_date": "2020-06-01", + "value": 0.5, + }, + "test_expansion": { + "method": "SinglePeriodModifier", + "parameter": "incidCase::probability", + "period_start_date": "2020-06-01", + "period_end_date": "2020-08-01", + "subpop": "all", + "value": 0.7, + }, + "test_limits_expansion": { + "method": "StackedModifier", + "modifiers": ["test_limits", "test_expansion"], + }, + }, + }, + }, + [], + ["test_limits_expansion"], + ), + ( + { + "outcome_modifiers": { + "scenarios": ["test_limits"], + "modifiers": { + "test_limits": { + "method": "SinglePeriodModifier", + "parameter": "incidCase::probability", + "subpop": "all", + "period_start_date": "2020-02-01", + "period_end_date": "2020-06-01", + "value": 0.5, + }, + "test_expansion": { + "method": "SinglePeriodModifier", + "parameter": "incidCase::probability", + "period_start_date": "2020-06-01", + "period_end_date": "2020-08-01", + "subpop": "all", + "value": 0.7, + }, + "test_limits_expansion": { + "method": "StackedModifier", + "modifiers": ["test_limits", "test_expansion"], + }, + }, + }, + }, + [], + ["test_limits", "test_expansion", "test_limits_expansion"], + ), + ( + { + "seir_modifiers": { + "scenarios": ["Ro_lockdown", "Ro_all"], + "modifiers": { + "Ro_lockdown": { + "method": "SinglePeriodModifier", + "parameter": "Ro", + "period_start_date": "2020-03-15", + "period_end_date": "2020-05-01", + "subpop": "all", + "value": 0.4, + }, + "Ro_relax": { + "method": "SinglePeriodModifier", + "parameter": "Ro", + "period_start_date": "2020-05-01", + "period_end_date": "2020-07-01", + "subpop": "all", + "value": 0.8, + }, + "Ro_all": { + "method": "StackedModifier", + "modifiers": ["Ro_lockdown", "Ro_relax"], + }, + }, + }, + "outcome_modifiers": { + "scenarios": ["test_limits"], + "modifiers": { + "test_limits": { + "method": "SinglePeriodModifier", + "parameter": "incidCase::probability", + "subpop": "all", + "period_start_date": "2020-02-01", + "period_end_date": "2020-06-01", + "value": 0.5, + }, + "test_expansion": { + "method": "SinglePeriodModifier", + "parameter": "incidCase::probability", + "period_start_date": "2020-06-01", + "period_end_date": "2020-08-01", + "subpop": "all", + "value": 0.7, + }, + "test_limits_expansion": { + "method": "StackedModifier", + "modifiers": ["test_limits", "test_expansion"], + }, + }, + }, + }, + [], + [], + ), + ( + { + "seir_modifiers": { + "scenarios": ["Ro_lockdown", "Ro_all"], + "modifiers": { + "Ro_lockdown": { + "method": "SinglePeriodModifier", + "parameter": "Ro", + "period_start_date": "2020-03-15", + "period_end_date": "2020-05-01", + "subpop": "all", + "value": 0.4, + }, + "Ro_relax": { + "method": "SinglePeriodModifier", + "parameter": "Ro", + "period_start_date": "2020-05-01", + "period_end_date": "2020-07-01", + "subpop": "all", + "value": 0.8, + }, + "Ro_all": { + "method": "StackedModifier", + "modifiers": ["Ro_lockdown", "Ro_relax"], + }, + }, + }, + "outcome_modifiers": { + "scenarios": ["test_limits"], + "modifiers": { + "test_limits": { + "method": "SinglePeriodModifier", + "parameter": "incidCase::probability", + "subpop": "all", + "period_start_date": "2020-02-01", + "period_end_date": "2020-06-01", + "value": 0.5, + }, + "test_expansion": { + "method": "SinglePeriodModifier", + "parameter": "incidCase::probability", + "period_start_date": "2020-06-01", + "period_end_date": "2020-08-01", + "subpop": "all", + "value": 0.7, + }, + "test_limits_expansion": { + "method": "StackedModifier", + "modifiers": ["test_limits", "test_expansion"], + }, + }, + }, + }, + ["Ro_relax"], + ["test_expansion"], + ), + ), +) +def test_editing_modifier_scenarios( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + data: dict[str, Any], + seir_modifier_scenarios: list[str], + outcome_modifier_scenarios: list[str], +) -> None: + # Setup the test + monkeypatch.chdir(tmp_path) + config_path = tmp_path / "config.yml" + config_path.write_text(yaml.dump(data)) + + # Invoke the command + runner = CliRunner() + args = [config_path.name] + if seir_modifier_scenarios: + for s in seir_modifier_scenarios: + args += ["--seir_modifiers_scenarios", s] + if outcome_modifier_scenarios: + for o in outcome_modifier_scenarios: + args += ["--outcome_modifiers_scenarios", o] + result = runner.invoke(patch, args) + assert result.exit_code == 0 + + # Check the output + patched_data = yaml.safe_load(result.output) + assert "seir_modifiers_scenarios" not in patched_data + assert patched_data.get("seir_modifiers", {}).get("scenarios", []) == ( + seir_modifier_scenarios + if seir_modifier_scenarios + else data.get("seir_modifiers", {}).get("scenarios", []) + ) + assert "outcome_modifiers_scenarios" not in patched_data + assert patched_data.get("outcome_modifiers", {}).get("scenarios", []) == ( + outcome_modifier_scenarios + if outcome_modifier_scenarios + else data.get("outcome_modifiers", {}).get("scenarios", []) + )