Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove gempyor.seeding Dependence On gempyor.model_info.ModelInfo #422

Open
wants to merge 6 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flepimop/gempyor_pkg/docs/integration_benchmark.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"id": "07bfd952-5c42-4704-81cc-5de0c917c0ab",
"metadata": {
"execution": {
Expand Down Expand Up @@ -450,7 +450,7 @@
"\n",
"with Timer(\"onerun_SEIR.seeding\"):\n",
" initial_conditions = s.initial_conditions.get_from_config(sim_id, modinf=s)\n",
" seeding_data, seeding_amounts = s.seeding.get_from_config(sim_id, modinf=s)\n",
" seeding_data, seeding_amounts = s.get_seeding_data(sim_id)\n",
"\n",
"mobility_subpop_indices = s.mobility.indices\n",
"mobility_data_indices = s.mobility.indptr\n",
Expand Down
4 changes: 2 additions & 2 deletions flepimop/gempyor_pkg/docs/integration_doc.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "96230106-73e3-4681-b562-6a7269513375",
"metadata": {
"execution": {
Expand Down Expand Up @@ -109,7 +109,7 @@
"\n",
"\n",
"initial_conditions = gempyor_inference.s.initial_conditions.get_from_config(sim_id2write, modinf=gempyor_inference.s)\n",
"seeding_data, seeding_amounts = gempyor_inference.s.seeding.get_from_config(sim_id2write, modinf=gempyor_inference.s)\n",
"seeding_data, seeding_amounts = gempyor_inference.s.get_seeding_data(sim_id2write)\n",
"\n",
"\n",
"p_draw = gempyor_inference.s.parameters.parameters_quick_draw(\n",
Expand Down
9 changes: 2 additions & 7 deletions flepimop/gempyor_pkg/docs/interface.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": null,
"id": "63618148-81db-4fe0-9395-7f21878c1372",
"metadata": {
"execution": {
Expand Down Expand Up @@ -254,20 +254,15 @@
"\n",
"### Run every time:\n",
"with Timer(\"onerun_SEIR.seeding\"):\n",
" seeding_data, seeding_amounts = gempyor_inference.s.get_seeding_data(sim_id2load if load_ID else sim_id2write)\n",
" if load_ID:\n",
" initial_conditions = gempyor_inference.s.initial_conditions.get_from_file(\n",
" sim_id2load, modinf=gempyor_inference.s\n",
" )\n",
" seeding_data, seeding_amounts = gempyor_inference.s.seeding.get_from_file(\n",
" sim_id2load, modinf=gempyor_inference.s\n",
" )\n",
" else:\n",
" initial_conditions = gempyor_inference.s.initial_conditions.get_from_config(\n",
" sim_id2write, modinf=gempyor_inference.s\n",
" )\n",
" seeding_data, seeding_amounts = gempyor_inference.s.seeding.get_from_config(\n",
" sim_id2write, modinf=gempyor_inference.s\n",
" )\n",
"\n",
"with Timer(\"SEIR.parameters\"):\n",
" # Draw or load parameters\n",
Expand Down
2 changes: 1 addition & 1 deletion flepimop/gempyor_pkg/src/gempyor/dev/dev_seir.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
out_prefix=prefix,
)

seeding_data = modinf.seeding.get_from_config(sim_id=100, modinf=modinf)
seeding_data = modinf.get_seeding_data(100)
initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf)

mobility_subpop_indices = modinf.mobility.indices
Expand Down
11 changes: 4 additions & 7 deletions flepimop/gempyor_pkg/src/gempyor/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def get_static_arguments(modinf: model_info.ModelInfo):
)

initial_conditions = modinf.initial_conditions.get_from_config(sim_id=0, modinf=modinf)
seeding_data, seeding_amounts = modinf.seeding.get_from_config(sim_id=0, modinf=modinf)
seeding_data, seeding_amounts = modinf.get_seeding_data(0)

# reduce them
parameters = modinf.parameters.parameters_reduce(p_draw, npi_seir)
Expand Down Expand Up @@ -672,20 +672,17 @@ def one_simulation(
self.lastsim_parsed_parameters = parsed_parameters

with Timer("onerun_SEIR.seeding"):
seeding_data, seeding_amounts = self.modinf.get_seeding_data(
sim_id2load if load_ID else sim_id2write
)
if load_ID:
initial_conditions = self.modinf.initial_conditions.get_from_file(
sim_id2load, modinf=self.modinf
)
seeding_data, seeding_amounts = self.modinf.seeding.get_from_file(
sim_id2load, modinf=self.modinf
)
else:
initial_conditions = self.modinf.initial_conditions.get_from_config(
sim_id2write, modinf=self.modinf
)
seeding_data, seeding_amounts = self.modinf.seeding.get_from_config(
sim_id2write, modinf=self.modinf
)
self.lastsim_seeding_data = seeding_data
self.lastsim_seeding_amounts = seeding_amounts
self.lastsim_initial_conditions = initial_conditions
Expand Down
68 changes: 46 additions & 22 deletions flepimop/gempyor_pkg/src/gempyor/model_info.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
import datetime
import logging
import os
import pathlib

import confuse
import numba as nb
import numpy as np
import numpy.typing as npt
import pandas as pd
import datetime, os, logging, pathlib, confuse

from . import (
seeding,
subpopulation_structure,
Expand Down Expand Up @@ -296,30 +305,15 @@ def get_output_filename(self, ftype: str, sim_id: int, extension_override: str =
def get_filename(
self, ftype: str, sim_id: int, input: bool, extension_override: str = ""
):
"""return a CSP formated filename."""

if extension_override: # empty strings are Falsy
extension = extension_override
else: # Constructed like this because in some test, extension is not defined
extension = self.extension

if input:
run_id = self.in_run_id
prefix = self.in_prefix
else:
run_id = self.out_run_id
prefix = self.out_prefix

fn = self.path_prefix / file_paths.create_file_name(
run_id=run_id,
prefix=prefix,
index=sim_id + self.first_sim_index - 1,
return self.path_prefix / file_paths.create_file_name(
self.in_run_id if input else self.out_run_id,
self.in_prefix if input else self.out_prefix,
sim_id + self.first_sim_index - 1,
ftype,
extension=extension_override if extension_override else self.extension,
inference_filepath_suffix=self.inference_filepath_suffix,
inference_filename_prefix=self.inference_filename_prefix,
ftype=ftype,
extension=extension,
)
return fn

def get_setup_name(self):
return self.setup_name
Expand Down Expand Up @@ -359,3 +353,33 @@ def write_simID(
df=df,
)
return fname

def get_seeding_data(self, sim_id: int) -> tuple[nb.typed.Dict, npt.NDArray[np.number]]:
"""
Pull the seeding data for the info represented by this model info instance.

Args:
sim_id: The simulation ID to pull seeding data for.

Returns:
A tuple containing the seeding data dictionary and the seeding data array.

See Also:
`gempyor.seeding.Seeding.get_from_config`
"""
return self.seeding.get_from_config(
self.compartments,
self.subpop_struct,
self.n_days,
self.ti,
self.tf,
(
None
if self.seeding_config is None
else self.get_input_filename(
ftype=self.seeding_config["seeding_file_type"].get(),
sim_id=sim_id,
extension_override="csv",
)
),
)
Loading
Loading