Skip to content

Commit

Permalink
Add documentation to the Seeding class
Browse files Browse the repository at this point in the history
  • Loading branch information
TimothyWillard committed Dec 10, 2024
1 parent 8ba5f0c commit e31cac8
Showing 1 changed file with 66 additions and 12 deletions.
78 changes: 66 additions & 12 deletions flepimop/gempyor_pkg/src/gempyor/seeding.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from datetime import date
import logging
from typing import Any
import warnings

import confuse
import numba as nb
Expand All @@ -23,9 +24,6 @@
logger = logging.getLogger(__name__)


## TODO: ideally here path_prefix should not be used and all files loaded from modinf


# Internal functionality
def _DataFrame2NumbaDict(
df: pd.DataFrame,
Expand Down Expand Up @@ -114,7 +112,22 @@ def _DataFrame2NumbaDict(

# Exported functionality
class Seeding(SimulationComponent):
"""
Class to handle the seeding of the simulation.
Attributes:
seeding_config: The configuration for the seeding.
path_prefix: The path prefix to use when reading files.
"""

def __init__(self, config: confuse.ConfigView, path_prefix: str = "."):
"""
Initialize a seeding instance.
Args:
config: The configuration for the seeding.
path_prefix: The path prefix to use when reading files.
"""
self.seeding_config = config
self.path_prefix = path_prefix

Expand All @@ -127,6 +140,27 @@ def get_from_config(
tf: date,
input_filename: str | None,
) -> tuple[nb.typed.Dict, npt.NDArray[np.number]]:
"""
Get seeding data from the configuration.
Args:
compartments: The compartments for the simulation.
subpop_struct: The subpopulation structure for the simulation.
n_days: The number of days in the simulation.
ti: The start date of the simulation.
tf: The end date of the simulation.
input_filename: The input filename to use for seeding data. Only used if
the seeding method is 'FolderDraw'.
Returns:
A tuple containing the seeding data as a Numba dictionary and the seeding
amounts as a Numpy array. The seeding data is a dictionary with the
following keys:
- "seeding_sources": The source compartments for the seeding.
- "seeding_destinations": The destination compartments for the seeding.
- "seeding_subpops": The subpopulations for the seeding.
- "day_start_idx": The start index for each day in the seeding data.
"""
method = "NoSeeding"
if self.seeding_config is not None and "method" in self.seeding_config.keys():
method = self.seeding_config["method"].as_str()
Expand Down Expand Up @@ -166,16 +200,10 @@ def get_from_config(
else:
raise ValueError(f"Unknown seeding method given, '{method}'.")

# Sorting by date is very important here for the seeding format necessary !!!!
# print(seeding.shape)
# Sorting by date is important for the seeding format
seeding = seeding.sort_values(by="date", axis="index").reset_index()
# print(seeding)
mask = (seeding["date"].dt.date > ti) & (seeding["date"].dt.date <= tf)
seeding = seeding.loc[mask].reset_index()
# print(seeding.shape)
# print(seeding)

# TODO: print.

amounts = np.zeros(len(seeding))
if method == "PoissonDistributed":
Expand All @@ -195,11 +223,37 @@ def get_from_config(
def get_from_file(
self, *args: Any, **kwargs: Any
) -> tuple[nb.typed.Dict, npt.NDArray[np.number]]:
"""only difference with draw seeding is that the sim_id is now sim_id2load"""
"""
This method is deprecated. Use `get_from_config` instead.
Args:
*args: Positional arguments to pass to `get_from_config`.
**kwargs: Keyword arguments to pass to `get_from_config`.
Returns:
The result of `get_from_config`.
"""
warnings.warn(
"The 'get_from_file' method is deprecated. Use 'get_from_config' instead.",
DeprecationWarning,
)
return self.get_from_config(*args, **kwargs)


def SeedingFactory(config: confuse.ConfigView, path_prefix: str = "."):
def SeedingFactory(config: confuse.ConfigView, path_prefix: str = ".") -> Seeding:
"""
Create a Seeding instance based on the given configuration.
This function will use the given configuration to either lookup a plugin class for
the seeding instance or fallback to the default Seeding class.
Args:
config: The configuration for the seeding.
path_prefix: The path prefix to use when reading files.
Returns:
A Seeding instance.
"""
if config is not None and "method" in config.keys():
if config["method"].as_str() == "plugin":
klass = utils.search_and_import_plugins_class(
Expand Down

0 comments on commit e31cac8

Please sign in to comment.