diff --git a/.github/workflows/flepicommon-ci.yml b/.github/workflows/flepicommon-ci.yml index 9f8abc089..6f9613fd7 100644 --- a/.github/workflows/flepicommon-ci.yml +++ b/.github/workflows/flepicommon-ci.yml @@ -16,6 +16,7 @@ on: jobs: tests: runs-on: ubuntu-latest + if: github.event_name != 'pull_request' || github.event.pull_request.draft == false strategy: matrix: R-version: ["4.3.3"] diff --git a/.github/workflows/gempyor-ci.yml b/.github/workflows/gempyor-ci.yml index e2637f1af..6ef20ce74 100644 --- a/.github/workflows/gempyor-ci.yml +++ b/.github/workflows/gempyor-ci.yml @@ -18,6 +18,7 @@ on: jobs: tests: runs-on: ubuntu-latest + if: github.event_name != 'pull_request' || github.event.pull_request.draft == false strategy: matrix: python-version: ["3.10", "3.11"] diff --git a/.github/workflows/inference-ci.yml b/.github/workflows/inference-ci.yml index d80a2e735..35e4d379e 100644 --- a/.github/workflows/inference-ci.yml +++ b/.github/workflows/inference-ci.yml @@ -20,7 +20,7 @@ on: jobs: tests: runs-on: ubuntu-latest - if: ${{ github.event_name != 'workflow_run' || github.event.workflow_run.conclusion == 'success' }} + if: (github.event_name != 'workflow_run' || github.event.workflow_run.conclusion == 'success') && (github.event_name != 'pull_request' || github.event.pull_request.draft == false) strategy: matrix: R-version: ["4.3.3"] diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 000000000..e2449af27 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,41 @@ +name: Lint + +on: + workflow_dispatch: + push: + paths: + - 'flepimop/gempyor_pkg/**/*.py' + pull_request: + paths: + - '**/*.py' + branches: + - main + +jobs: + black-for-python: + runs-on: ubuntu-latest + if: github.event_name != 'pull_request' || github.event.pull_request.draft == false + env: + BLACK_LINE_LENGTH: 92 + BLACK_EXTEND_EXCLUDE: 'flepimop/gempyor_pkg/src/gempyor/steps_rk4.py' + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + lfs: true + sparse-checkout: | + * + !documentation/ + sparse-checkout-cone-mode: false + - name: Determine Source + run: | + if [ ${{ github.event_name }} == "push" ]; then + echo "BLACK_SRC=flepimop/gempyor_pkg/" >> $GITHUB_ENV + else + echo "BLACK_SRC=." >> $GITHUB_ENV + fi + - name: Black Formatter Check + uses: psf/black@stable + with: + src: ${{ env.BLACK_SRC }} + options: "--line-length ${{ env.BLACK_LINE_LENGTH }} --extend-exclude '${{ env.BLACK_EXTEND_EXCLUDE }}' --check --verbose" diff --git a/batch/inference_job_launcher.py b/batch/inference_job_launcher.py index b69d223a5..884f210ca 100755 --- a/batch/inference_job_launcher.py +++ b/batch/inference_job_launcher.py @@ -403,23 +403,37 @@ def launch_batch( if "scenarios" in config["outcome_modifiers"]: outcome_modifiers_scenarios = config["outcome_modifiers"]["scenarios"] - handler.launch(job_name, config_filepath, seir_modifiers_scenarios, outcome_modifiers_scenarios) + handler.launch( + job_name, config_filepath, seir_modifiers_scenarios, outcome_modifiers_scenarios + ) # Set job_name as environmental variable so it can be pulled for pushing to git os.environ["job_name"] = job_name # Set run_id as environmental variable so it can be pulled for pushing to git TODO - (rc, txt) = subprocess.getstatusoutput(f"git checkout -b run_{job_name}") # TODO: cd ... + (rc, txt) = subprocess.getstatusoutput( + f"git checkout -b run_{job_name}" + ) # TODO: cd ... 
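The new `lint.yml` workflow above checks the tree with Black using a 92-character line length and an exclusion for `flepimop/gempyor_pkg/src/gempyor/steps_rk4.py`, which is the same configuration the reformatting throughout this diff follows. For reference, a minimal local equivalent driven from Python — a sketch only, assuming `black` is installed in the active environment; the flags mirror the workflow's settings:

```python
# Sketch of the same Black check the Lint workflow runs, invoked from Python.
# The exclude path comes from this diff; everything else is standard Black CLI.
import subprocess
import sys

result = subprocess.run(
    [
        sys.executable, "-m", "black",
        "--line-length", "92",
        "--extend-exclude", "flepimop/gempyor_pkg/src/gempyor/steps_rk4.py",
        "--check", ".",
    ],
    capture_output=True,
    text=True,
)
print(result.stdout or result.stderr)
raise SystemExit(result.returncode)  # non-zero means some files need reformatting
```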
print(txt) return rc -def autodetect_params(config, data_path, *, num_jobs=None, sims_per_job=None, num_blocks=None, batch_system=None): +def autodetect_params( + config, + data_path, + *, + num_jobs=None, + sims_per_job=None, + num_blocks=None, + batch_system=None, +): if num_jobs and sims_per_job and num_blocks: return (num_jobs, sims_per_job, num_blocks) if "inference" not in config or "iterations_per_slot" not in config["inference"]: - raise click.UsageError("inference::iterations_per_slot undefined in config, can't autodetect parameters") + raise click.UsageError( + "inference::iterations_per_slot undefined in config, can't autodetect parameters" + ) iterations_per_slot = int(config["inference"]["iterations_per_slot"]) if num_jobs is None: @@ -429,11 +443,17 @@ def autodetect_params(config, data_path, *, num_jobs=None, sims_per_job=None, nu if sims_per_job is None: if num_blocks is not None: sims_per_job = int(math.ceil(iterations_per_slot / num_blocks)) - print(f"Setting number of blocks to {num_blocks} [via num_blocks (-k) argument]") - print(f"Setting sims per job to {sims_per_job} [via {iterations_per_slot} iterations_per_slot in config]") + print( + f"Setting number of blocks to {num_blocks} [via num_blocks (-k) argument]" + ) + print( + f"Setting sims per job to {sims_per_job} [via {iterations_per_slot} iterations_per_slot in config]" + ) else: if "data_path" in config: - raise ValueError("The config has a data_path section. This is no longer supported.") + raise ValueError( + "The config has a data_path section. This is no longer supported." + ) geodata_fname = pathlib.Path(data_path) / config["subpop_setup"]["geodata"] with open(geodata_fname) as geodata_fp: num_subpops = sum(1 for line in geodata_fp) @@ -458,7 +478,9 @@ def autodetect_params(config, data_path, *, num_jobs=None, sims_per_job=None, nu if num_blocks is None: num_blocks = int(math.ceil(iterations_per_slot / sims_per_job)) - print(f"Setting number of blocks to {num_blocks} [via {iterations_per_slot} iterations_per_slot in config]") + print( + f"Setting number of blocks to {num_blocks} [via {iterations_per_slot} iterations_per_slot in config]" + ) return (num_jobs, sims_per_job, num_blocks) @@ -478,7 +500,9 @@ def get_aws_job_queues(job_queue_prefix): return sorted(queues_with_jobs, key=queues_with_jobs.get) -def aws_countfiles_autodetect_runid(s3_bucket, restart_from_location, restart_from_run_id, num_jobs, strict=False): +def aws_countfiles_autodetect_runid( + s3_bucket, restart_from_location, restart_from_run_id, num_jobs, strict=False +): import boto3 s3 = boto3.resource("s3") @@ -487,15 +511,21 @@ def aws_countfiles_autodetect_runid(s3_bucket, restart_from_location, restart_fr all_files = list(bucket.objects.filter(Prefix=prefix)) all_files = [f.key for f in all_files] if restart_from_run_id is None: - print("WARNING: no --restart_from_run_id specified, autodetecting... please wait querying S3 👀🔎...") + print( + "WARNING: no --restart_from_run_id specified, autodetecting... please wait querying S3 👀🔎..." + ) restart_from_run_id = all_files[0].split("/")[3] - if user_confirmation(question=f"Auto-detected run_id {restart_from_run_id}. Correct ?", default=True): + if user_confirmation( + question=f"Auto-detected run_id {restart_from_run_id}. 
Correct ?", default=True + ): print(f"great, continuing with run_id {restart_from_run_id}...") else: raise ValueError(f"Abording, please specify --restart_from_run_id manually.") final_llik = [f for f in all_files if ("llik" in f) and ("final" in f)] - if len(final_llik) == 0: # hacky: there might be a bucket with no llik files, e.g if init. + if ( + len(final_llik) == 0 + ): # hacky: there might be a bucket with no llik files, e.g if init. final_llik = [f for f in all_files if ("init" in f) and ("final" in f)] if len(final_llik) != num_jobs: @@ -583,8 +613,12 @@ def build_job_metadata(self, job_name): manifest = {} manifest["cmd"] = " ".join(sys.argv[:]) manifest["job_name"] = job_name - manifest["data_sha"] = subprocess.getoutput("cd {self.data_path}; git rev-parse HEAD") - manifest["flepimop_sha"] = subprocess.getoutput(f"cd {self.flepi_path}; git rev-parse HEAD") + manifest["data_sha"] = subprocess.getoutput( + "cd {self.data_path}; git rev-parse HEAD" + ) + manifest["flepimop_sha"] = subprocess.getoutput( + f"cd {self.flepi_path}; git rev-parse HEAD" + ) # Save the manifest file to S3 with open("manifest.json", "w") as f: @@ -594,17 +628,25 @@ def build_job_metadata(self, job_name): # need these to be uploaded so they can be executed. this_file_path = os.path.dirname(os.path.realpath(__file__)) self.save_file( - source=os.path.join(this_file_path, "AWS_inference_runner.sh"), destination=f"{job_name}-runner.sh" + source=os.path.join(this_file_path, "AWS_inference_runner.sh"), + destination=f"{job_name}-runner.sh", ) self.save_file( - source=os.path.join(this_file_path, "AWS_inference_copy.sh"), destination=f"{job_name}-copy.sh" + source=os.path.join(this_file_path, "AWS_inference_copy.sh"), + destination=f"{job_name}-copy.sh", ) tarfile_name = f"{job_name}.tar.gz" self.tar_working_dir(tarfile_name=tarfile_name) - self.save_file(source=tarfile_name, destination=f"{job_name}.tar.gz", remove_source=True) + self.save_file( + source=tarfile_name, destination=f"{job_name}.tar.gz", remove_source=True + ) - self.save_file(source="manifest.json", destination=f"{job_name}/manifest.json", remove_source=True) + self.save_file( + source="manifest.json", + destination=f"{job_name}/manifest.json", + remove_source=True, + ) def tar_working_dir(self, tarfile_name): # this tar file always has the structure: @@ -616,10 +658,13 @@ def tar_working_dir(self, tarfile_name): or q == "covid-dashboard-app" or q == "renv.cache" or q == "sample_data" - or q == "renv" # joseph: I added this to fix a bug, hopefully it doesn't break anything + or q + == "renv" # joseph: I added this to fix a bug, hopefully it doesn't break anything or q.startswith(".") ): - tar.add(os.path.join(self.flepi_path, q), arcname=os.path.join("flepiMoP", q)) + tar.add( + os.path.join(self.flepi_path, q), arcname=os.path.join("flepiMoP", q) + ) elif q == "sample_data": for r in os.listdir(os.path.join(self.flepi_path, "sample_data")): if r != "united-states-commutes": @@ -629,10 +674,17 @@ def tar_working_dir(self, tarfile_name): ) # tar.add(os.path.join("flepiMoP", "sample_data", r)) for p in os.listdir(self.data_path): - if not (p.startswith(".") or p.endswith("tar.gz") or p in self.outputs or p == "flepiMoP"): + if not ( + p.startswith(".") + or p.endswith("tar.gz") + or p in self.outputs + or p == "flepiMoP" + ): tar.add( p, - filter=lambda x: None if os.path.basename(x.name).startswith(".") else x, + filter=lambda x: ( + None if os.path.basename(x.name).startswith(".") else x + ), ) tar.close() @@ -656,7 +708,13 @@ def save_file(self, 
source, destination, remove_source=False, prefix=""): if remove_source: os.remove(source) - def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_modifiers_scenarios): + def launch( + self, + job_name, + config_filepath, + seir_modifiers_scenarios, + outcome_modifiers_scenarios, + ): s3_results_path = f"s3://{self.s3_bucket}/{job_name}" if self.batch_system == "slurm": @@ -676,7 +734,10 @@ def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_mo ## TODO: check how each of these variables are used downstream base_env_vars = [ {"name": "BATCH_SYSTEM", "value": self.batch_system}, - {"name": "S3_MODEL_PROJECT_PATH", "value": f"s3://{self.s3_bucket}/{job_name}.tar.gz"}, + { + "name": "S3_MODEL_PROJECT_PATH", + "value": f"s3://{self.s3_bucket}/{job_name}.tar.gz", + }, {"name": "DVC_OUTPUTS", "value": " ".join(self.outputs)}, {"name": "S3_RESULTS_PATH", "value": s3_results_path}, {"name": "FS_RESULTS_PATH", "value": fs_results_path}, @@ -700,14 +761,22 @@ def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_mo }, {"name": "FLEPI_STOCHASTIC_RUN", "value": str(self.stochastic)}, {"name": "FLEPI_RESET_CHIMERICS", "value": str(self.reset_chimerics)}, - {"name": "FLEPI_MEM_PROFILE", "value": str(os.getenv("FLEPI_MEM_PROFILE", default="FALSE"))}, - {"name": "FLEPI_MEM_PROF_ITERS", "value": str(os.getenv("FLEPI_MEM_PROF_ITERS", default="50"))}, + { + "name": "FLEPI_MEM_PROFILE", + "value": str(os.getenv("FLEPI_MEM_PROFILE", default="FALSE")), + }, + { + "name": "FLEPI_MEM_PROF_ITERS", + "value": str(os.getenv("FLEPI_MEM_PROF_ITERS", default="50")), + }, {"name": "SLACK_CHANNEL", "value": str(self.slack_channel)}, ] with open(config_filepath) as f: config = yaml.full_load(f) - for ctr, (s, d) in enumerate(itertools.product(seir_modifiers_scenarios, outcome_modifiers_scenarios)): + for ctr, (s, d) in enumerate( + itertools.product(seir_modifiers_scenarios, outcome_modifiers_scenarios) + ): cur_job_name = f"{job_name}_{s}_{d}" # Create first job cur_env_vars = base_env_vars.copy() @@ -719,7 +788,9 @@ def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_mo cur_env_vars.append({"name": "FLEPI_BLOCK_INDEX", "value": "1"}) cur_env_vars.append({"name": "FLEPI_RUN_INDEX", "value": f"{self.run_id}"}) if not (self.restart_from_location is None): - cur_env_vars.append({"name": "LAST_JOB_OUTPUT", "value": f"{self.restart_from_location}"}) + cur_env_vars.append( + {"name": "LAST_JOB_OUTPUT", "value": f"{self.restart_from_location}"} + ) cur_env_vars.append( { "name": "OLD_FLEPI_RUN_INDEX", @@ -732,8 +803,18 @@ def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_mo if self.continuation: cur_env_vars.append({"name": "FLEPI_CONTINUATION", "value": f"TRUE"}) - cur_env_vars.append({"name": "FLEPI_CONTINUATION_RUN_ID", "value": f"{self.continuation_run_id}"}) - cur_env_vars.append({"name": "FLEPI_CONTINUATION_LOCATION", "value": f"{self.continuation_location}"}) + cur_env_vars.append( + { + "name": "FLEPI_CONTINUATION_RUN_ID", + "value": f"{self.continuation_run_id}", + } + ) + cur_env_vars.append( + { + "name": "FLEPI_CONTINUATION_LOCATION", + "value": f"{self.continuation_location}", + } + ) cur_env_vars.append( { "name": "FLEPI_CONTINUATION_FTYPE", @@ -814,7 +895,9 @@ def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_mo postprod_command, command_name="sbatch postprod", fail_on_fail=True ) postprod_job_id = stdout.decode().split(" ")[-1][:-1] - print(f">>> SUCCESS 
SCHEDULING POST-PROCESSING JOB. Slurm job id is {postprod_job_id}") + print( + f">>> SUCCESS SCHEDULING POST-PROCESSING JOB. Slurm job id is {postprod_job_id}" + ) elif self.batch_system == "local": cur_env_vars.append({"name": "JOB_NAME", "value": f"{cur_job_name}"}) @@ -831,12 +914,24 @@ def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_mo cur_env_vars = base_env_vars.copy() cur_env_vars.append({"name": "FLEPI_SEIR_SCENARIOS", "value": s}) cur_env_vars.append({"name": "FLEPI_OUTCOME_SCENARIOS", "value": d}) - cur_env_vars.append({"name": "FLEPI_PREFIX", "value": f"{config['name']}_{s}_{d}"}) - cur_env_vars.append({"name": "FLEPI_BLOCK_INDEX", "value": f"{block_idx+1}"}) - cur_env_vars.append({"name": "FLEPI_RUN_INDEX", "value": f"{self.run_id}"}) - cur_env_vars.append({"name": "OLD_FLEPI_RUN_INDEX", "value": f"{self.run_id}"}) - cur_env_vars.append({"name": "LAST_JOB_OUTPUT", "value": f"{s3_results_path}/"}) - cur_env_vars.append({"name": "JOB_NAME", "value": f"{cur_job_name}_block{block_idx}"}) + cur_env_vars.append( + {"name": "FLEPI_PREFIX", "value": f"{config['name']}_{s}_{d}"} + ) + cur_env_vars.append( + {"name": "FLEPI_BLOCK_INDEX", "value": f"{block_idx+1}"} + ) + cur_env_vars.append( + {"name": "FLEPI_RUN_INDEX", "value": f"{self.run_id}"} + ) + cur_env_vars.append( + {"name": "OLD_FLEPI_RUN_INDEX", "value": f"{self.run_id}"} + ) + cur_env_vars.append( + {"name": "LAST_JOB_OUTPUT", "value": f"{s3_results_path}/"} + ) + cur_env_vars.append( + {"name": "JOB_NAME", "value": f"{cur_job_name}_block{block_idx}"} + ) cur_job = batch_client.submit_job( jobName=f"{cur_job_name}_block{block_idx}", jobQueue=cur_job_queue, @@ -895,19 +990,29 @@ def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_mo em = "" if self.resume_discard_seeding: em = f", discarding seeding results." 
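One detail worth flagging from the reformatting of `build_job_metadata` earlier in this file: the `data_sha` assignment passes a plain string to `subprocess.getoutput`, so `{self.data_path}` is never interpolated, while the neighbouring `flepimop_sha` line does use an f-string. A minimal sketch of what the interpolated call would look like (the path below is an illustrative placeholder, not from the repository):

```python
# Sketch only: with the f prefix, {data_path} is substituted before the shell
# command runs; without it, the literal text "{self.data_path}" is passed to cd.
import subprocess

data_path = "/path/to/project-data"  # illustrative placeholder
data_sha = subprocess.getoutput(f"cd {data_path}; git rev-parse HEAD")
print(data_sha)
```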
- print(f" >> Resuming from run id is {self.restart_from_run_id} located in {self.restart_from_location}{em}") + print( + f" >> Resuming from run id is {self.restart_from_run_id} located in {self.restart_from_location}{em}" + ) if self.batch_system == "aws": print(f" >> Final output will be: {s3_results_path}/model_output/") elif self.batch_system == "slurm": print(f" >> Final output will be: {fs_results_path}/model_output/") if self.s3_upload: - print(f" >> Final output will be uploaded to {s3_results_path}/model_output/") + print( + f" >> Final output will be uploaded to {s3_results_path}/model_output/" + ) if self.continuation: - print(f" >> Continuing from run id is {self.continuation_run_id} located in {self.continuation_location}") + print( + f" >> Continuing from run id is {self.continuation_run_id} located in {self.continuation_location}" + ) print(f" >> Run id is {self.run_id}") print(f" >> config is {config_filepath.split('/')[-1]}") - flepimop_branch = subprocess.getoutput(f"cd {self.flepi_path}; git rev-parse --abbrev-ref HEAD") - data_branch = subprocess.getoutput(f"cd {self.data_path}; git rev-parse --abbrev-ref HEAD") + flepimop_branch = subprocess.getoutput( + f"cd {self.flepi_path}; git rev-parse --abbrev-ref HEAD" + ) + data_branch = subprocess.getoutput( + f"cd {self.data_path}; git rev-parse --abbrev-ref HEAD" + ) data_hash = subprocess.getoutput(f"cd {self.data_path}; git rev-parse HEAD") flepimop_hash = subprocess.getoutput(f"cd {self.flepi_path}; git rev-parse HEAD") print(f""" >> FLEPIMOP branch is {flepimop_branch} with hash {flepimop_hash}""") diff --git a/batch/scenario_job.py b/batch/scenario_job.py index 1961974eb..62907a370 100755 --- a/batch/scenario_job.py +++ b/batch/scenario_job.py @@ -196,13 +196,17 @@ def launch_job_inner( tarfile_name = f"{job_name}.tar.gz" tar = tarfile.open(tarfile_name, "w:gz") for p in os.listdir("."): - if not (p.startswith(".") or p.endswith("tar.gz") or p in dvc_outputs or p == "batch"): + if not ( + p.startswith(".") or p.endswith("tar.gz") or p in dvc_outputs or p == "batch" + ): tar.add(p, filter=lambda x: None if x.name.startswith(".") else x) tar.close() # Upload the tar'd contents of this directory and the runner script to S3 runner_script_name = f"{job_name}-runner.sh" - local_runner_script = os.path.join(os.path.dirname(os.path.realpath(__file__)), "AWS_scenario_runner.sh") + local_runner_script = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "AWS_scenario_runner.sh" + ) s3_client = boto3.client("s3") s3_client.upload_file(local_runner_script, s3_input_bucket, runner_script_name) s3_client.upload_file(tarfile_name, s3_input_bucket, tarfile_name) @@ -219,7 +223,9 @@ def launch_job_inner( {"name": "S3_RESULTS_PATH", "value": results_path}, {"name": "SLOTS_PER_JOB", "value": str(slots_per_job)}, ] - s3_cp_run_script = f"aws s3 cp s3://{s3_input_bucket}/{runner_script_name} $PWD/run-flepimop-inference" + s3_cp_run_script = ( + f"aws s3 cp s3://{s3_input_bucket}/{runner_script_name} $PWD/run-flepimop-inference" + ) command = ["sh", "-c", f"{s3_cp_run_script}; /bin/bash $PWD/run-flepimop-inference"] container_overrides = { "vcpus": vcpu, @@ -246,7 +252,9 @@ def launch_job_inner( containerOverrides=container_overrides, ) - print(f"Batch job with id {resp['jobId']} launched; output will be written to {results_path}") + print( + f"Batch job with id {resp['jobId']} launched; output will be written to {results_path}" + ) def get_dvc_outputs(): diff --git a/bin/lint b/bin/lint new file mode 100755 index 
000000000..50b31272b
--- /dev/null
+++ b/bin/lint
@@ -0,0 +1,5 @@
+#!/usr/bin/env bash
+
+black --line-length 92 \
+    --extend-exclude 'flepimop/gempyor_pkg/src/gempyor/steps_rk4.py' \
+    --verbose .
diff --git a/bin/pre-commit b/bin/pre-commit
new file mode 100755
index 000000000..f9fdd15e6
--- /dev/null
+++ b/bin/pre-commit
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+if which black > /dev/null 2>&1; then
+    black --line-length 92 \
+        --extend-exclude 'flepimop/gempyor_pkg/src/gempyor/steps_rk4.py' \
+        --check --verbose .
+else
+    echo "'black' is not available so python files will not be checked."
+fi
diff --git a/documentation/gitbook/development/python-guidelines-for-developers.md b/documentation/gitbook/development/python-guidelines-for-developers.md
index 59f2d3c67..75a7e60ab 100644
--- a/documentation/gitbook/development/python-guidelines-for-developers.md
+++ b/documentation/gitbook/development/python-guidelines-for-developers.md
@@ -61,21 +61,23 @@ and to run just some subset of the tests (e.g here just the outcome tests), use:
 pytest -vvvv -k outcomes
 ```
 
-{% hint style="danger" %}
-Before committing, make sure you **format your code** using black (see below) and that the **test passes** (see above).
-{% endhint %}
+For more details on how to use `pytest` please refer to their [usage guide](https://docs.pytest.org/en/latest/how-to/usage.html).
 
 ### Formatting
 
-We try to remain close to Python conventions and to follow the updated rules and best practices. For formatting, we use [black](https://github.com/psf/black), the _Uncompromising Code Formatter_ before submitting pull requests. It provides a consistent style, which is useful when diffing. We use a custom length of 120 characters as the baseline is short for scientific code. Here is the line to use to format your code:
+We try to remain close to Python conventions and to follow the updated rules and best practices. For formatting, we use [black](https://github.com/psf/black), the _Uncompromising Code Formatter_, before submitting pull requests. It provides a consistent style, which is useful when diffing. We use a custom line length of 92 characters, as the default is short for scientific code. Here is the command to use to format your code:
 
 ```bash
-black --line-length 120 . --exclude renv*
+black --line-length 92 \
+  --extend-exclude 'flepimop/gempyor_pkg/src/gempyor/steps_rk4.py' \
+  --verbose .
 ```
 
-{% hint style="warning" %}
-Please use type-hints as much as possible, as we are trying to move towards static checks.
-{% endhint %}
+For those using a Mac or Linux system for development, this command is also available by calling `./bin/lint`. Similarly, you can take advantage of the formatting pre-commit hook found at `bin/pre-commit`.
To start using it copy this file to your git hooks folder: + +```bash +cp -f bin/pre-commit .git/hooks/ +``` #### Structure of the main classes diff --git a/examples/test_cli.py b/examples/test_cli.py index 349f7219a..857ad8249 100644 --- a/examples/test_cli.py +++ b/examples/test_cli.py @@ -1,4 +1,3 @@ - from click.testing import CliRunner from gempyor.simulate import simulate import os @@ -6,33 +5,35 @@ # See here to test click application https://click.palletsprojects.com/en/8.1.x/testing/ # would be useful to also call the command directly + def test_config_sample_2pop(): - os.chdir(os.path.dirname(__file__) + "/tutorials") - runner = CliRunner() - result = runner.invoke(simulate, ['-c', 'config_sample_2pop.yml']) - print(result.output) # useful for debug - print(result.exit_code) # useful for debug - print(result.exception) # useful for debug - assert result.exit_code == 0 - assert 'completed in' in result.output + os.chdir(os.path.dirname(__file__) + "/tutorials") + runner = CliRunner() + result = runner.invoke(simulate, ["-c", "config_sample_2pop.yml"]) + print(result.output) # useful for debug + print(result.exit_code) # useful for debug + print(result.exception) # useful for debug + assert result.exit_code == 0 + assert "completed in" in result.output def test_sample_2pop_modifiers(): - os.chdir(os.path.dirname(__file__) + "/tutorials") - runner = CliRunner() - result = runner.invoke(simulate, ['-c', 'config_sample_2pop_modifiers.yml']) - print(result.output) # useful for debug - print(result.exit_code) # useful for debug - print(result.exception) # useful for debug - assert result.exit_code == 0 - assert 'completed in' in result.output + os.chdir(os.path.dirname(__file__) + "/tutorials") + runner = CliRunner() + result = runner.invoke(simulate, ["-c", "config_sample_2pop_modifiers.yml"]) + print(result.output) # useful for debug + print(result.exit_code) # useful for debug + print(result.exception) # useful for debug + assert result.exit_code == 0 + assert "completed in" in result.output + def test_simple_usa_statelevel(): - os.chdir(os.path.dirname(__file__) + "/simple_usa_statelevel") - runner = CliRunner() - result = runner.invoke(simulate, ['-c', 'simple_usa_statelevel.yml', '-n', '1']) - print(result.output) # useful for debug - print(result.exit_code) # useful for debug - print(result.exception) # useful for debug - assert result.exit_code == 0 - assert 'completed in' in result.output \ No newline at end of file + os.chdir(os.path.dirname(__file__) + "/simple_usa_statelevel") + runner = CliRunner() + result = runner.invoke(simulate, ["-c", "simple_usa_statelevel.yml", "-n", "1"]) + print(result.output) # useful for debug + print(result.exit_code) # useful for debug + print(result.exception) # useful for debug + assert result.exit_code == 0 + assert "completed in" in result.output diff --git a/flepimop/gempyor_pkg/src/gempyor/NPI/MultiPeriodModifier.py b/flepimop/gempyor_pkg/src/gempyor/NPI/MultiPeriodModifier.py index ecbbba962..84414e7ff 100644 --- a/flepimop/gempyor_pkg/src/gempyor/NPI/MultiPeriodModifier.py +++ b/flepimop/gempyor_pkg/src/gempyor/NPI/MultiPeriodModifier.py @@ -21,7 +21,8 @@ def __init__( name=getattr( npi_config, "key", - (npi_config["scenario"].exists() and npi_config["scenario"].get()) or "unknown", + (npi_config["scenario"].exists() and npi_config["scenario"].get()) + or "unknown", ) ) @@ -68,14 +69,20 @@ def __init__( # if parameters are exceeding global start/end dates, index of parameter df will be out of range so check first if self.sanitize: - 
too_early = min([min(i) for i in self.parameters["start_date"]]) < self.start_date + too_early = ( + min([min(i) for i in self.parameters["start_date"]]) < self.start_date + ) too_late = max([max(i) for i in self.parameters["end_date"]]) > self.end_date if too_early or too_late: - raise ValueError("at least one period start or end date is not between global dates") + raise ValueError( + "at least one period start or end date is not between global dates" + ) for grp_config in npi_config["groups"]: affected_subpops_grp = self.__get_affected_subpops_grp(grp_config) - for sub_index in range(len(self.parameters["start_date"][affected_subpops_grp[0]])): + for sub_index in range( + len(self.parameters["start_date"][affected_subpops_grp[0]]) + ): period_range = pd.date_range( self.parameters["start_date"][affected_subpops_grp[0]][sub_index], self.parameters["end_date"][affected_subpops_grp[0]][sub_index], @@ -111,7 +118,9 @@ def __checkErrors(self): ) if not (self.parameters["start_date"] <= self.parameters["end_date"]).all(): - raise ValueError(f"at least one period_start_date is greater than the corresponding period end date") + raise ValueError( + f"at least one period_start_date is greater than the corresponding period end date" + ) for n in self.affected_subpops: if n not in self.subpops: @@ -153,7 +162,9 @@ def __createFromConfig(self, npi_config): else: start_dates = [self.start_date] end_dates = [self.end_date] - this_spatial_group = helpers.get_spatial_groups(grp_config, affected_subpops_grp) + this_spatial_group = helpers.get_spatial_groups( + grp_config, affected_subpops_grp + ) self.spatial_groups.append(this_spatial_group) # print(self.name, this_spatial_groups) @@ -209,7 +220,9 @@ def __createFromDf(self, loaded_df, npi_config): else: start_dates = [self.start_date] end_dates = [self.end_date] - this_spatial_group = helpers.get_spatial_groups(grp_config, affected_subpops_grp) + this_spatial_group = helpers.get_spatial_groups( + grp_config, affected_subpops_grp + ) self.spatial_groups.append(this_spatial_group) for subpop in this_spatial_group["ungrouped"]: @@ -227,7 +240,9 @@ def __createFromDf(self, loaded_df, npi_config): for subpop in group: self.parameters.at[subpop, "start_date"] = start_dates self.parameters.at[subpop, "end_date"] = end_dates - self.parameters.at[subpop, "value"] = loaded_df.at[",".join(group), "value"] + self.parameters.at[subpop, "value"] = loaded_df.at[ + ",".join(group), "value" + ] else: dist = npi_config["value"].as_random_distribution() drawn_value = dist(size=1) @@ -258,11 +273,16 @@ def __get_affected_subpops(self, npi_config): affected_subpops_grp += [str(n.get()) for n in grp_config["subpop"]] affected_subpops = set(affected_subpops_grp) if len(affected_subpops) != len(affected_subpops_grp): - raise ValueError(f"In NPI {self.name}, some subpops belong to several groups. This is unsupported.") + raise ValueError( + f"In NPI {self.name}, some subpops belong to several groups. This is unsupported." 
+ ) return affected_subpops def get_default(self, param): - if param in self.pnames_overlap_operation_sum or param in self.pnames_overlap_operation_reductionprod: + if ( + param in self.pnames_overlap_operation_sum + or param in self.pnames_overlap_operation_reductionprod + ): return 0.0 else: return 1.0 @@ -278,7 +298,9 @@ def getReductionToWrite(self): # self.parameters.index is a list of subpops for this_spatial_groups in self.spatial_groups: # spatially ungrouped dataframe - df_ungroup = self.parameters[self.parameters.index.isin(this_spatial_groups["ungrouped"])].copy() + df_ungroup = self.parameters[ + self.parameters.index.isin(this_spatial_groups["ungrouped"]) + ].copy() df_ungroup.index.name = "subpop" df_ungroup["start_date"] = df_ungroup["start_date"].apply( lambda l: ",".join([d.strftime("%Y-%m-%d") for d in l]) @@ -301,7 +323,9 @@ def getReductionToWrite(self): "start_date": df_group["start_date"].apply( lambda l: ",".join([d.strftime("%Y-%m-%d") for d in l]) ), - "end_date": df_group["end_date"].apply(lambda l: ",".join([d.strftime("%Y-%m-%d") for d in l])), + "end_date": df_group["end_date"].apply( + lambda l: ",".join([d.strftime("%Y-%m-%d") for d in l]) + ), "value": df_group["value"], } ).set_index("subpop") diff --git a/flepimop/gempyor_pkg/src/gempyor/NPI/SinglePeriodModifier.py b/flepimop/gempyor_pkg/src/gempyor/NPI/SinglePeriodModifier.py index e078ddeba..cdda3c4b9 100644 --- a/flepimop/gempyor_pkg/src/gempyor/NPI/SinglePeriodModifier.py +++ b/flepimop/gempyor_pkg/src/gempyor/NPI/SinglePeriodModifier.py @@ -21,7 +21,8 @@ def __init__( name=getattr( npi_config, "key", - (npi_config["scenario"].exists() and npi_config["scenario"].get()) or "unknown", + (npi_config["scenario"].exists() and npi_config["scenario"].get()) + or "unknown", ) ) @@ -60,15 +61,22 @@ def __init__( self.__createFromConfig(npi_config) # if parameters are exceeding global start/end dates, index of parameter df will be out of range so check first - if self.parameters["start_date"].min() < self.start_date or self.parameters["end_date"].max() > self.end_date: - raise ValueError(f"""{self.name} : at least one period start or end date is not between global dates""") + if ( + self.parameters["start_date"].min() < self.start_date + or self.parameters["end_date"].max() > self.end_date + ): + raise ValueError( + f"""{self.name} : at least one period start or end date is not between global dates""" + ) # for index in self.parameters.index: # period_range = pd.date_range(self.parameters["start_date"][index], self.parameters["end_date"][index]) ## This the line that does the work # self.npi_old.loc[index, period_range] = np.tile(self.parameters["value"][index], (len(period_range), 1)).T - period_range = pd.date_range(self.parameters["start_date"].iloc[0], self.parameters["end_date"].iloc[0]) + period_range = pd.date_range( + self.parameters["start_date"].iloc[0], self.parameters["end_date"].iloc[0] + ) self.npi.loc[self.parameters.index, period_range] = np.tile( self.parameters["value"][:], (len(period_range), 1) ).T @@ -90,7 +98,9 @@ def __checkErrors(self): ) if not (self.parameters["start_date"] <= self.parameters["end_date"]).all(): - raise ValueError(f"at least one period_start_date is greater than the corresponding period end date") + raise ValueError( + f"at least one period_start_date is greater than the corresponding period end date" + ) for n in self.affected_subpops: if n not in self.subpops: @@ -122,13 +132,19 @@ def __createFromConfig(self, npi_config): self.parameters["modifier_name"] = self.name 
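As context for the re-wrapped `pd.date_range`/`np.tile` assignment above in `SinglePeriodModifier`: each subpop's single drawn value is repeated across every date of the modifier period. A small illustrative sketch — the subpop names, dates, and values below are made up for the example:

```python
# Illustrative sketch of the broadcast: np.tile repeats the per-subpop values
# once per date, and the transpose lines the result up as (subpop, date).
import numpy as np
import pandas as pd

period_range = pd.date_range("2020-04-01", "2020-04-03")
values = np.array([0.4, 0.7])  # one modifier value per subpop
npi = pd.DataFrame(0.0, index=["small_province", "large_province"], columns=period_range)

npi.loc[["small_province", "large_province"], period_range] = np.tile(
    values, (len(period_range), 1)
).T
print(npi)  # each subpop's value repeated across the three dates
```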
self.parameters["start_date"] = ( - npi_config["period_start_date"].as_date() if npi_config["period_start_date"].exists() else self.start_date + npi_config["period_start_date"].as_date() + if npi_config["period_start_date"].exists() + else self.start_date ) self.parameters["end_date"] = ( - npi_config["period_end_date"].as_date() if npi_config["period_end_date"].exists() else self.end_date + npi_config["period_end_date"].as_date() + if npi_config["period_end_date"].exists() + else self.end_date ) self.parameters["parameter"] = self.param_name - self.spatial_groups = helpers.get_spatial_groups(npi_config, list(self.affected_subpops)) + self.spatial_groups = helpers.get_spatial_groups( + npi_config, list(self.affected_subpops) + ) if self.spatial_groups["ungrouped"]: self.parameters.loc[self.spatial_groups["ungrouped"], "value"] = self.dist( size=len(self.spatial_groups["ungrouped"]) @@ -153,10 +169,14 @@ def __createFromDf(self, loaded_df, npi_config): # self.parameters = loaded_df[["modifier_name", "start_date", "end_date", "parameter", "value"]].copy() # dates are picked from config self.parameters["start_date"] = ( - npi_config["period_start_date"].as_date() if npi_config["period_start_date"].exists() else self.start_date + npi_config["period_start_date"].as_date() + if npi_config["period_start_date"].exists() + else self.start_date ) self.parameters["end_date"] = ( - npi_config["period_end_date"].as_date() if npi_config["period_end_date"].exists() else self.end_date + npi_config["period_end_date"].as_date() + if npi_config["period_end_date"].exists() + else self.end_date ) ## This is more legible to me, but if we change it here, we should change it in __createFromConfig as well # if npi_config["period_start_date"].exists(): @@ -175,17 +195,24 @@ def __createFromDf(self, loaded_df, npi_config): # TODO: to be consistent with MTR, we want to also draw the values for the subpops # that are not in the loaded_df. 
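The `helpers.get_spatial_groups` calls being re-wrapped here return a dict with `grouped` and `ungrouped` keys: grouped subpops share a single drawn value, while ungrouped subpops each keep their own. A minimal sketch of how that split is consumed — the subpop names and values are illustrative, not from a real config:

```python
# Sketch: one shared value per grouped set of subpops, individual values for
# the ungrouped ones. Data is made up for the example.
import pandas as pd

spatial_groups = {"grouped": [["small_province", "large_province"]], "ungrouped": ["island"]}
parameters = pd.DataFrame(index=["small_province", "large_province", "island"], columns=["value"])

parameters.loc[spatial_groups["ungrouped"], "value"] = [0.3]  # one draw per ungrouped subpop
for group in spatial_groups["grouped"]:
    parameters.loc[group, "value"] = 0.5  # a single draw shared by the whole group
print(parameters)
```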
- self.spatial_groups = helpers.get_spatial_groups(npi_config, list(self.affected_subpops)) + self.spatial_groups = helpers.get_spatial_groups( + npi_config, list(self.affected_subpops) + ) if self.spatial_groups["ungrouped"]: self.parameters.loc[self.spatial_groups["ungrouped"], "value"] = loaded_df.loc[ self.spatial_groups["ungrouped"], "value" ] if self.spatial_groups["grouped"]: for group in self.spatial_groups["grouped"]: - self.parameters.loc[group, "value"] = loaded_df.loc[",".join(group), "value"] + self.parameters.loc[group, "value"] = loaded_df.loc[ + ",".join(group), "value" + ] def get_default(self, param): - if param in self.pnames_overlap_operation_sum or param in self.pnames_overlap_operation_reductionprod: + if ( + param in self.pnames_overlap_operation_sum + or param in self.pnames_overlap_operation_reductionprod + ): return 0.0 else: return 1.0 @@ -198,7 +225,9 @@ def getReduction(self, param): def getReductionToWrite(self): # spatially ungrouped dataframe - df = self.parameters[self.parameters.index.isin(self.spatial_groups["ungrouped"])].copy() + df = self.parameters[ + self.parameters.index.isin(self.spatial_groups["ungrouped"]) + ].copy() df.index.name = "subpop" df["start_date"] = df["start_date"].astype("str") df["end_date"] = df["end_date"].astype("str") diff --git a/flepimop/gempyor_pkg/src/gempyor/NPI/StackedModifier.py b/flepimop/gempyor_pkg/src/gempyor/NPI/StackedModifier.py index 489a48fbb..6cf178735 100644 --- a/flepimop/gempyor_pkg/src/gempyor/NPI/StackedModifier.py +++ b/flepimop/gempyor_pkg/src/gempyor/NPI/StackedModifier.py @@ -47,7 +47,9 @@ def __init__( if isinstance(scenario, str): settings = modifiers_library.get(scenario) if settings is None: - raise RuntimeError(f"couldn't find scenario in config file [got: {scenario}]") + raise RuntimeError( + f"couldn't find scenario in config file [got: {scenario}]" + ) # via profiling: faster to recreate the confuse view than to fetch+resolve due to confuse isinstance # checks scenario_npi_config = confuse.RootView([settings]) @@ -68,12 +70,16 @@ def __init__( ) new_params = sub_npi.param_name # either a list (if stacked) or a string - new_params = [new_params] if isinstance(new_params, str) else new_params # convert to list + new_params = ( + [new_params] if isinstance(new_params, str) else new_params + ) # convert to list # Add each parameter at first encounter, with a neutral start for new_p in new_params: if new_p not in self.param_name: self.param_name.append(new_p) - if new_p in pnames_overlap_operation_sum: # re.match("^transition_rate [1234567890]+$",new_p): + if ( + new_p in pnames_overlap_operation_sum + ): # re.match("^transition_rate [1234567890]+$",new_p): self.reductions[new_p] = 0 else: # for the reductionprod and product method, the initial neutral is 1 ) self.reductions[new_p] = 1 @@ -81,7 +87,9 @@ def __init__( for param in self.param_name: # Get reduction return a neutral value for this overlap operation if no parameeter exists reduction = sub_npi.getReduction(param) - if param in pnames_overlap_operation_sum: # re.match("^transition_rate [1234567890]+$",param): + if ( + param in pnames_overlap_operation_sum + ): # re.match("^transition_rate [1234567890]+$",param): self.reductions[param] += reduction elif param in pnames_overlap_operation_reductionprod: self.reductions[param] *= 1 - reduction @@ -104,7 +112,9 @@ def __init__( self.reduction_params.clear() for param in self.param_name: - if param in pnames_overlap_operation_reductionprod: # re.match("^transition_rate \d+$",param): + if ( + 
param in pnames_overlap_operation_reductionprod + ): # re.match("^transition_rate \d+$",param): self.reductions[param] = 1 - self.reductions[param] # check that no NPI is called several times, and retourn them @@ -124,7 +134,10 @@ def __checkErrors(self): # ) def get_default(self, param): - if param in self.pnames_overlap_operation_sum or param in self.pnames_overlap_operation_reductionprod: + if ( + param in self.pnames_overlap_operation_sum + or param in self.pnames_overlap_operation_reductionprod + ): return 0.0 else: return 1.0 diff --git a/flepimop/gempyor_pkg/src/gempyor/NPI/helpers.py b/flepimop/gempyor_pkg/src/gempyor/NPI/helpers.py index f964d4c6e..e18f43f28 100644 --- a/flepimop/gempyor_pkg/src/gempyor/NPI/helpers.py +++ b/flepimop/gempyor_pkg/src/gempyor/NPI/helpers.py @@ -43,16 +43,22 @@ def get_spatial_groups(grp_config, affected_subpops: list) -> dict: else: spatial_groups["grouped"] = grp_config["subpop_groups"].get() spatial_groups["ungrouped"] = list( - set(affected_subpops) - set(flatten_list_of_lists(spatial_groups["grouped"])) + set(affected_subpops) + - set(flatten_list_of_lists(spatial_groups["grouped"])) ) # flatten the list of lists of grouped subpops, so we can do some checks flat_grouped_list = flatten_list_of_lists(spatial_groups["grouped"]) # check that all subpops are either grouped or ungrouped if set(flat_grouped_list + spatial_groups["ungrouped"]) != set(affected_subpops): - print("set of grouped and ungrouped subpops", set(flat_grouped_list + spatial_groups["ungrouped"])) + print( + "set of grouped and ungrouped subpops", + set(flat_grouped_list + spatial_groups["ungrouped"]), + ) print("set of affected subpops ", set(affected_subpops)) - raise ValueError(f"The two above sets are differs for for intervention with config \n {grp_config}") + raise ValueError( + f"The two above sets are differs for for intervention with config \n {grp_config}" + ) if len(set(flat_grouped_list + spatial_groups["ungrouped"])) != len( flat_grouped_list + spatial_groups["ungrouped"] ): diff --git a/flepimop/gempyor_pkg/src/gempyor/calibrate.py b/flepimop/gempyor_pkg/src/gempyor/calibrate.py index e5cb287aa..875db12ca 100644 --- a/flepimop/gempyor_pkg/src/gempyor/calibrate.py +++ b/flepimop/gempyor_pkg/src/gempyor/calibrate.py @@ -158,7 +158,9 @@ def calibrate( # TODO here for resume if resume or resume_location is not None: - print("Doing a resume, this only work with the same number of slot and parameters right now") + print( + "Doing a resume, this only work with the same number of slot and parameters right now" + ) p0 = None if resume_location is not None: backend = emcee.backends.HDFBackend(resume_location) @@ -195,7 +197,10 @@ def calibrate( # plotting the chain sampler = emcee.backends.HDFBackend(filename, read_only=True) gempyor.postprocess_inference.plot_chains( - inferpar=gempyor_inference.inferpar, sampler_output=sampler, sampled_slots=None, save_to=f"{run_id}_chains.pdf" + inferpar=gempyor_inference.inferpar, + sampler_output=sampler, + sampled_slots=None, + save_to=f"{run_id}_chains.pdf", ) print("EMCEE Run done, doing sampling") @@ -203,11 +208,14 @@ def calibrate( shutil.rmtree(project_path + "model_output/", ignore_errors=True) max_indices = np.argsort(sampler.get_log_prob()[-1, :])[-nsamples:] - samples = sampler.get_chain()[-1, max_indices, :] # the last iteration, for selected slots + samples = sampler.get_chain()[ + -1, max_indices, : + ] # the last iteration, for selected slots gempyor_inference.set_save(True) with multiprocessing.Pool(ncpu) as pool: results = 
pool.starmap( - gempyor_inference.get_logloss_as_single_number, [(samples[i, :],) for i in range(len(max_indices))] + gempyor_inference.get_logloss_as_single_number, + [(samples[i, :],) for i in range(len(max_indices))], ) # results = [] # for fn in gempyor.utils.list_filenames(folder="model_output/", filters=[run_id, "hosp.parquet"]): diff --git a/flepimop/gempyor_pkg/src/gempyor/compartments.py b/flepimop/gempyor_pkg/src/gempyor/compartments.py index ec87cf7e5..9bce1a6be 100644 --- a/flepimop/gempyor_pkg/src/gempyor/compartments.py +++ b/flepimop/gempyor_pkg/src/gempyor/compartments.py @@ -13,7 +13,13 @@ class Compartments: # Minimal object to be easily picklable for // runs - def __init__(self, seir_config=None, compartments_config=None, compartments_file=None, transitions_file=None): + def __init__( + self, + seir_config=None, + compartments_config=None, + compartments_file=None, + transitions_file=None, + ): self.times_set = 0 ## Something like this is needed for check script: @@ -29,7 +35,7 @@ def __init__(self, seir_config=None, compartments_config=None, compartments_file return def constructFromConfig(self, seir_config, compartment_config): - """ + """ This method is called by the constructor if the compartments are not loaded from a file. It will parse the compartments and transitions from the configuration files. It will populate self.compartments and self.transitions. @@ -43,7 +49,7 @@ def __eq__(self, other): ).all().all() def parse_compartments(self, seir_config, compartment_config): - """ Parse the compartments from the configuration file: + """Parse the compartments from the configuration file: seir_config: the configuration file for the SEIR model compartment_config: the configuration file for the compartments Example: if config says: @@ -75,12 +81,16 @@ def parse_compartments(self, seir_config, compartment_config): else: compartment_df = pd.merge(compartment_df, tmp, on="key") compartment_df = compartment_df.drop(["key"], axis=1) - compartment_df["name"] = compartment_df.apply(lambda x: reduce(lambda a, b: a + "_" + b, x), axis=1) + compartment_df["name"] = compartment_df.apply( + lambda x: reduce(lambda a, b: a + "_" + b, x), axis=1 + ) return compartment_df def parse_transitions(self, seir_config, fake_config=False): rc = reduce( - lambda a, b: pd.concat([a, self.parse_single_transition(seir_config, b, fake_config)]), + lambda a, b: pd.concat( + [a, self.parse_single_transition(seir_config, b, fake_config)] + ), seir_config["transitions"], pd.DataFrame(), ) @@ -93,12 +103,17 @@ def check_transition_element(self, single_transition_config, problem_dimension=N def check_transition_elements(self, single_transition_config, problem_dimension): return True - def access_original_config_by_multi_index(self, config_piece, index, dimension=None, encapsulate_as_list=False): + def access_original_config_by_multi_index( + self, config_piece, index, dimension=None, encapsulate_as_list=False + ): if dimension is None: dimension = [None for i in index] tmp = [y for y in zip(index, range(len(index)), dimension)] tmp = zip(index, range(len(index)), dimension) - tmp = [list_access_element_safe(config_piece[x[1]], x[0], x[2], encapsulate_as_list) for x in tmp] + tmp = [ + list_access_element_safe(config_piece[x[1]], x[0], x[2], encapsulate_as_list) + for x in tmp + ] return tmp def expand_transition_elements(self, single_transition_config, problem_dimension): @@ -108,7 +123,9 @@ def expand_transition_elements(self, single_transition_config, problem_dimension # replace "source" by the actual 
source from the config for p_idx in range(proportion_size): if new_transition_config["proportional_to"][p_idx] == "source": - new_transition_config["proportional_to"][p_idx] = new_transition_config["source"] + new_transition_config["proportional_to"][p_idx] = new_transition_config[ + "source" + ] temp_array = np.zeros(problem_dimension) @@ -117,43 +134,77 @@ def expand_transition_elements(self, single_transition_config, problem_dimension new_transition_config["rate"] = np.zeros(problem_dimension, dtype=object) new_transition_config["proportional_to"] = np.zeros(problem_dimension, dtype=object) - new_transition_config["proportion_exponent"] = np.zeros(problem_dimension, dtype=object) + new_transition_config["proportion_exponent"] = np.zeros( + problem_dimension, dtype=object + ) - it = np.nditer(temp_array, flags=["multi_index"]) # it is an iterator that will go through all the indexes of the array + it = np.nditer( + temp_array, flags=["multi_index"] + ) # it is an iterator that will go through all the indexes of the array for x in it: try: - new_transition_config["source"][it.multi_index] = list_recursive_convert_to_string( - self.access_original_config_by_multi_index(single_transition_config["source"], it.multi_index) + new_transition_config["source"][it.multi_index] = ( + list_recursive_convert_to_string( + self.access_original_config_by_multi_index( + single_transition_config["source"], it.multi_index + ) + ) ) except Exception as e: print(f"Error {e}:") - print(f">>> in expand_transition_elements for `source:` at index {it.multi_index}") - print(f">>> this transition source is: {single_transition_config['source']}") - print(f">>> this transition destination is: {single_transition_config['destination']}") + print( + f">>> in expand_transition_elements for `source:` at index {it.multi_index}" + ) + print( + f">>> this transition source is: {single_transition_config['source']}" + ) + print( + f">>> this transition destination is: {single_transition_config['destination']}" + ) print(f"transition_dimension: {problem_dimension}") raise e try: - new_transition_config["destination"][it.multi_index] = list_recursive_convert_to_string( - self.access_original_config_by_multi_index(single_transition_config["destination"], it.multi_index) + new_transition_config["destination"][it.multi_index] = ( + list_recursive_convert_to_string( + self.access_original_config_by_multi_index( + single_transition_config["destination"], it.multi_index + ) + ) ) except Exception as e: print(f"Error {e}:") - print(f">>> in expand_transition_elements for `destination:` at index {it.multi_index}") - print(f">>> this transition source is: {single_transition_config['source']}") - print(f">>> this transition destination is: {single_transition_config['destination']}") + print( + f">>> in expand_transition_elements for `destination:` at index {it.multi_index}" + ) + print( + f">>> this transition source is: {single_transition_config['source']}" + ) + print( + f">>> this transition destination is: {single_transition_config['destination']}" + ) print(f"transition_dimension: {problem_dimension}") raise e - + try: - new_transition_config["rate"][it.multi_index] = list_recursive_convert_to_string( - self.access_original_config_by_multi_index(single_transition_config["rate"], it.multi_index) + new_transition_config["rate"][it.multi_index] = ( + list_recursive_convert_to_string( + self.access_original_config_by_multi_index( + single_transition_config["rate"], it.multi_index + ) + ) ) except Exception as e: print(f"Error {e}:") - 
print(f">>> in expand_transition_elements for `rate:` at index {it.multi_index}") - print(f">>> this transition source is: {single_transition_config['source']}") - print(f">>> this transition destination is: {single_transition_config['destination']}") + print( + f">>> in expand_transition_elements for `rate:` at index {it.multi_index}" + ) + print( + f">>> this transition source is: {single_transition_config['source']}" + ) + print( + f">>> this transition destination is: {single_transition_config['destination']}" + ) print(f"transition_dimension: {problem_dimension}") raise e @@ -173,43 +224,66 @@ def expand_transition_elements(self, single_transition_config, problem_dimension ) except Exception as e: print(f"Error {e}:") - print(f">>> in expand_transition_elements for `proportional_to:` at index {it.multi_index}") - print(f">>> this transition source is: {single_transition_config['source']}") - print(f">>> this transition destination is: {single_transition_config['destination']}") + print( + f">>> in expand_transition_elements for `proportional_to:` at index {it.multi_index}" + ) + print( + f">>> this transition source is: {single_transition_config['source']}" + ) + print( + f">>> this transition destination is: {single_transition_config['destination']}" + ) print(f"transition_dimension: {problem_dimension}") raise e - - if "proportion_exponent" in single_transition_config: # if proportion_exponent is not defined, it is set to 1 + + if ( + "proportion_exponent" in single_transition_config + ): # if proportion_exponent is not defined, it is set to 1 try: self.access_original_config_by_multi_index( single_transition_config["proportion_exponent"][0], it.multi_index, problem_dimension, ) - new_transition_config["proportion_exponent"][it.multi_index] = list_recursive_convert_to_string( - [ - self.access_original_config_by_multi_index( - single_transition_config["proportion_exponent"][p_idx], - it.multi_index, - problem_dimension, - ) - for p_idx in range(proportion_size) - ] + new_transition_config["proportion_exponent"][it.multi_index] = ( + list_recursive_convert_to_string( + [ + self.access_original_config_by_multi_index( + single_transition_config["proportion_exponent"][p_idx], + it.multi_index, + problem_dimension, + ) + for p_idx in range(proportion_size) + ] + ) ) except Exception as e: print(f"Error {e}:") - print(f">>> in expand_transition_elements for `proportion_exponent:` at index {it.multi_index}") - print(f">>> this transition source is: {single_transition_config['source']}") - print(f">>> this transition destination is: {single_transition_config['destination']}") + print( + f">>> in expand_transition_elements for `proportion_exponent:` at index {it.multi_index}" + ) + print( + f">>> this transition source is: {single_transition_config['source']}" + ) + print( + f">>> this transition destination is: {single_transition_config['destination']}" + ) print(f"transition_dimension: {problem_dimension}") raise e else: - new_transition_config["proportion_exponent"][it.multi_index] = ["1"] * proportion_size + new_transition_config["proportion_exponent"][it.multi_index] = [ + "1" + ] * proportion_size return new_transition_config def format_source(self, source_column): - rc = [y for y in map(lambda x: reduce(lambda a, b: str(a) + "_" + str(b), x), source_column)] + rc = [ + y + for y in map( + lambda x: reduce(lambda a, b: str(a) + "_" + str(b), x), source_column + ) + ] return rc def unformat_source(self, source_column): @@ -231,7 +305,12 @@ def unformat_destination(self, 
destination_column): return rc def format_rate(self, rate_column): - rc = [y for y in map(lambda x: reduce(lambda a, b: str(a) + "%*%" + str(b), x), rate_column)] + rc = [ + y + for y in map( + lambda x: reduce(lambda a, b: str(a) + "%*%" + str(b), x), rate_column + ) + ] return rc def unformat_rate(self, rate_column, compartment_dimension): @@ -251,7 +330,9 @@ def format_proportional_to(self, proportional_to_column): lambda x: reduce( lambda a, b: str(a) + "_" + str(b), map( - lambda x: reduce(lambda a, b: str(a) + "+" + str(b), as_list(x)), + lambda x: reduce( + lambda a, b: str(a) + "+" + str(b), as_list(x) + ), x, ), ), @@ -284,7 +365,9 @@ def format_proportion_exponent(self, proportion_exponent_column): ] return rc - def unformat_proportion_exponent(self, proportion_exponent_column, compartment_dimension): + def unformat_proportion_exponent( + self, proportion_exponent_column, compartment_dimension + ): rc = [x.split("%*%") for x in proportion_exponent_column] for row in range(len(rc)): rc[row] = [x.split("*", maxsplit=compartment_dimension - 1) for x in rc[row]] @@ -293,18 +376,28 @@ def unformat_proportion_exponent(self, proportion_exponent_column, compartment_d elem.append(1) return rc - def parse_single_transition(self, seir_config, single_transition_config, fake_config=False): + def parse_single_transition( + self, seir_config, single_transition_config, fake_config=False + ): ## This method relies on having run parse_compartments if not fake_config: single_transition_config = single_transition_config.get() self.check_transition_element(single_transition_config["source"]) self.check_transition_element(single_transition_config["destination"]) - source_dimension = [get_list_dimension(x) for x in single_transition_config["source"]] - destination_dimension = [get_list_dimension(x) for x in single_transition_config["destination"]] - problem_dimension = reduce(lambda x, y: max(x, y), (source_dimension, destination_dimension)) + source_dimension = [ + get_list_dimension(x) for x in single_transition_config["source"] + ] + destination_dimension = [ + get_list_dimension(x) for x in single_transition_config["destination"] + ] + problem_dimension = reduce( + lambda x, y: max(x, y), (source_dimension, destination_dimension) + ) self.check_transition_elements(single_transition_config, problem_dimension) - transitions = self.expand_transition_elements(single_transition_config, problem_dimension) + transitions = self.expand_transition_elements( + single_transition_config, problem_dimension + ) tmp_array = np.zeros(problem_dimension) it = np.nditer(tmp_array, flags=["multi_index"]) @@ -317,7 +410,9 @@ def parse_single_transition(self, seir_config, single_transition_config, fake_co "destination": [transitions["destination"][it.multi_index]], "rate": [transitions["rate"][it.multi_index]], "proportional_to": [transitions["proportional_to"][it.multi_index]], - "proportion_exponent": [transitions["proportion_exponent"][it.multi_index]], + "proportion_exponent": [ + transitions["proportion_exponent"][it.multi_index] + ], }, index=[0], ) @@ -328,7 +423,10 @@ def parse_single_transition(self, seir_config, single_transition_config, fake_co return rc def toFile( - self, compartments_file="compartments.parquet", transitions_file="transitions.parquet", write_parquet=True + self, + compartments_file="compartments.parquet", + transitions_file="transitions.parquet", + write_parquet=True, ): out_df = self.compartments.copy() if write_parquet: @@ -342,7 +440,9 @@ def toFile( out_df["destination"] = 
self.format_destination(out_df["destination"]) out_df["rate"] = self.format_rate(out_df["rate"]) out_df["proportional_to"] = self.format_proportional_to(out_df["proportional_to"]) - out_df["proportion_exponent"] = self.format_proportion_exponent(out_df["proportion_exponent"]) + out_df["proportion_exponent"] = self.format_proportion_exponent( + out_df["proportion_exponent"] + ) if write_parquet: pa_df = pa.Table.from_pandas(out_df, preserve_index=False) pa.parquet.write_table(pa_df, transitions_file) @@ -355,9 +455,15 @@ def fromFile(self, compartments_file, transitions_file): self.transitions = pq.read_table(transitions_file).to_pandas() compartment_dimension = self.compartments.shape[1] - 1 self.transitions["source"] = self.unformat_source(self.transitions["source"]) - self.transitions["destination"] = self.unformat_destination(self.transitions["destination"]) - self.transitions["rate"] = self.unformat_rate(self.transitions["rate"], compartment_dimension) - self.transitions["proportional_to"] = self.unformat_proportional_to(self.transitions["proportional_to"]) + self.transitions["destination"] = self.unformat_destination( + self.transitions["destination"] + ) + self.transitions["rate"] = self.unformat_rate( + self.transitions["rate"], compartment_dimension + ) + self.transitions["proportional_to"] = self.unformat_proportional_to( + self.transitions["proportional_to"] + ) self.transitions["proportion_exponent"] = self.unformat_proportion_exponent( self.transitions["proportion_exponent"], compartment_dimension ) @@ -371,7 +477,9 @@ def get_comp_idx(self, comp_dict: dict, error_info: str = "no information") -> i :param comp_dict: :return: """ - mask = pd.concat([self.compartments[k] == v for k, v in comp_dict.items()], axis=1).all(axis=1) + mask = pd.concat( + [self.compartments[k] == v for k, v in comp_dict.items()], axis=1 + ).all(axis=1) comp_idx = self.compartments[mask].index.values if len(comp_idx) != 1: raise ValueError( @@ -382,10 +490,11 @@ def get_comp_idx(self, comp_dict: dict, error_info: str = "no information") -> i def get_ncomp(self) -> int: return len(self.compartments) - def get_transition_array(self): with Timer("SEIR.compartments"): - transition_array = np.zeros((self.transitions.shape[1], self.transitions.shape[0]), dtype="int64") + transition_array = np.zeros( + (self.transitions.shape[1], self.transitions.shape[0]), dtype="int64" + ) for cit, colname in enumerate(("source", "destination")): for it, elem in enumerate(self.transitions[colname]): elem = reduce(lambda a, b: a + "_" + b, elem) @@ -395,7 +504,9 @@ def get_transition_array(self): rc = compartment if rc == -1: print(self.compartments) - raise ValueError(f"Could not find {colname} defined by {elem} in compartments") + raise ValueError( + f"Could not find {colname} defined by {elem} in compartments" + ) transition_array[cit, it] = rc unique_strings = [] @@ -417,8 +528,12 @@ def get_transition_array(self): # parenthesis are now supported # assert reduce(lambda a, b: a and b, [(x.find("(") == -1) for x in unique_strings]) # assert reduce(lambda a, b: a and b, [(x.find(")") == -1) for x in unique_strings]) - assert reduce(lambda a, b: a and b, [(x.find("%") == -1) for x in unique_strings]) - assert reduce(lambda a, b: a and b, [(x.find(" ") == -1) for x in unique_strings]) + assert reduce( + lambda a, b: a and b, [(x.find("%") == -1) for x in unique_strings] + ) + assert reduce( + lambda a, b: a and b, [(x.find(" ") == -1) for x in unique_strings] + ) for it, elem in enumerate(self.transitions["rate"]): candidate = 
reduce(lambda a, b: a + "*" + b, elem) @@ -454,8 +569,12 @@ def get_transition_array(self): # rc = compartment # if rc == -1: # raise ValueError(f"Could not find match for {elem3} in compartments") - proportion_info[0][current_proportion_sum_it] = current_proportion_sum_start - proportion_info[1][current_proportion_sum_it] = current_proportion_sum_start + len(elem_tmp) + proportion_info[0][ + current_proportion_sum_it + ] = current_proportion_sum_start + proportion_info[1][current_proportion_sum_it] = ( + current_proportion_sum_start + len(elem_tmp) + ) current_proportion_sum_it += 1 current_proportion_sum_start += len(elem_tmp) proportion_compartment_index = 0 @@ -490,7 +609,9 @@ def get_transition_array(self): if self.compartments["name"][compartment] == elem3: rc = compartment if rc == -1: - raise ValueError(f"Could not find proportional_to {elem3} in compartments") + raise ValueError( + f"Could not find proportional_to {elem3} in compartments" + ) proportion_array[proportion_index] = rc proportion_index += 1 @@ -528,18 +649,24 @@ def get_transition_array(self): def parse_parameters(self, parameters, parameter_names, unique_strings): # parsed_parameters_old = self.parse_parameter_strings_to_numpy_arrays(parameters, parameter_names, unique_strings) - parsed_parameters = self.parse_parameter_strings_to_numpy_arrays_v2(parameters, parameter_names, unique_strings) + parsed_parameters = self.parse_parameter_strings_to_numpy_arrays_v2( + parameters, parameter_names, unique_strings + ) # for i in range(len(unique_strings)): # print(unique_strings[i], (parsed_parameters[i]==parsed_parameters_old[i]).all()) return parsed_parameters - def parse_parameter_strings_to_numpy_arrays_v2(self, parameters, parameter_names, string_list): + def parse_parameter_strings_to_numpy_arrays_v2( + self, parameters, parameter_names, string_list + ): # is using eval a better way ??? import sympy as sp # Validate input lengths if len(parameters) != len(parameter_names): - raise ValueError("Number of parameter values does not match the number of parameter names.") + raise ValueError( + "Number of parameter values does not match the number of parameter names." + ) # Define the symbols used in the formulas symbolic_parameters_namespace = {name: sp.symbols(name) for name in parameter_names} @@ -554,11 +681,15 @@ def parse_parameter_strings_to_numpy_arrays_v2(self, parameters, parameter_names f = sp.sympify(formula, locals=symbolic_parameters_namespace) parsed_formulas.append(f) except Exception as e: - print(f"Cannot parse formula: '{formula}' from parameters {parameter_names}") + print( + f"Cannot parse formula: '{formula}' from parameters {parameter_names}" + ) raise (e) # Print the error message for debugging # the list order needs to be right. 
- parameter_values = {param: value for param, value in zip(symbolic_parameters, parameters)} + parameter_values = { + param: value for param, value in zip(symbolic_parameters, parameters) + } parameter_values_list = [parameter_values[param] for param in symbolic_parameters] # Create a lambdify function for substitution @@ -573,7 +704,9 @@ def parse_parameter_strings_to_numpy_arrays_v2(self, parameters, parameter_names if not isinstance(substituted_formulas[i], np.ndarray): for k in range(len(substituted_formulas)): if isinstance(substituted_formulas[k], np.ndarray): - substituted_formulas[i] = substituted_formulas[i] * np.ones_like(substituted_formulas[k]) + substituted_formulas[i] = substituted_formulas[i] * np.ones_like( + substituted_formulas[k] + ) return np.array(substituted_formulas) @@ -623,17 +756,23 @@ def parse_parameter_strings_to_numpy_arrays( if not is_totally_resolvable: not_resolvable_indices = [it for it, x in enumerate(is_resolvable) if not x] - tmp_rc[not_resolvable_indices] = self.parse_parameter_strings_to_numpy_arrays( - parameters, - parameter_names, - [string[not is_resolvable]], - operator_reduce_lambdas, - operators[1:], + tmp_rc[not_resolvable_indices] = ( + self.parse_parameter_strings_to_numpy_arrays( + parameters, + parameter_names, + [string[not is_resolvable]], + operator_reduce_lambdas, + operators[1:], + ) ) for numeric_index in [x for x in range(len(is_numeric)) if is_numeric[x]]: tmp_rc[numeric_index] = parameters[0] * 0 + float(string[numeric_index]) for parameter_index in [x for x in range(len(is_parameter)) if is_parameter[x]]: - parameter_name_index = [it for it, x in enumerate(parameter_names) if x == string[parameter_index]] + parameter_name_index = [ + it + for it, x in enumerate(parameter_names) + if x == string[parameter_index] + ] tmp_rc[parameter_index] = parameters[parameter_name_index] rc[sit] = reduce(operator_reduce_lambdas[operators[0]], tmp_rc) @@ -648,7 +787,9 @@ def get_compartments_explicitDF(self): df = df.rename(columns=rename_dict) return df - def plot(self, output_file="transition_graph", source_filters=[], destination_filters=[]): + def plot( + self, output_file="transition_graph", source_filters=[], destination_filters=[] + ): """ if source_filters is [["age0to17"], ["OMICRON", "WILD"]], it means filter all transitions that have as source age0to17 AND (OMICRON OR WILD). @@ -712,8 +853,12 @@ def list_access_element_safe(thing, idx, dimension=None, encapsulate_as_list=Fal except Exception as e: print(f"Error {e}:") print(f">>> in list_access_element_safe for {thing} at index {idx}") - print(">>> This is often, but not always because the object above is a list (there are brackets around it).") - print(">>> and in this case it is not broadcast, so if you want to it to be broadcasted, you need remove the brackets around it.") + print( + ">>> This is often, but not always because the object above is a list (there are brackets around it)." + ) + print( + ">>> and in this case it is not broadcast, so if you want to it to be broadcasted, you need remove the brackets around it." + ) print(f"dimension: {dimension}") raise e @@ -755,7 +900,9 @@ def compartments(): def plot(): assert config["compartments"].exists() assert config["seir"].exists() - comp = Compartments(seir_config=config["seir"], compartments_config=config["compartments"]) + comp = Compartments( + seir_config=config["seir"], compartments_config=config["compartments"] + ) # TODO: this should be a command like build compartments. 
( @@ -774,7 +921,9 @@ def plot(): def export(): assert config["compartments"].exists() assert config["seir"].exists() - comp = Compartments(seir_config=config["seir"], compartments_config=config["compartments"]) + comp = Compartments( + seir_config=config["seir"], compartments_config=config["compartments"] + ) ( unique_strings, transition_array, diff --git a/flepimop/gempyor_pkg/src/gempyor/config_validator.py b/flepimop/gempyor_pkg/src/gempyor/config_validator.py index 95f6e3029..61757cfa9 100644 --- a/flepimop/gempyor_pkg/src/gempyor/config_validator.py +++ b/flepimop/gempyor_pkg/src/gempyor/config_validator.py @@ -1,33 +1,60 @@ import yaml -from pydantic import BaseModel, ValidationError, model_validator, Field, AfterValidator, validator +from pydantic import ( + BaseModel, + ValidationError, + model_validator, + Field, + AfterValidator, + validator, +) from datetime import date from typing import Dict, List, Union, Literal, Optional, Annotated, Any from functools import partial from gempyor import compartments + def read_yaml(file_path: str) -> dict: - with open(file_path, 'r') as stream: + with open(file_path, "r") as stream: config = yaml.safe_load(stream) - + return CheckConfig(**config).model_dump() - + + def allowed_values(v, values): assert v in values return v + # def parse_value(cls, values): # value = values.get('value') # parsed_val = compartments.Compartments.parse_parameter_strings_to_numpy_arrays_v2(value) # return parsed_val - + + class SubpopSetupConfig(BaseModel): geodata: str mobility: Optional[str] selected: List[str] = Field(default_factory=list) # state_level: Optional[bool] = False # pretty sure this doesn't exist anymore + class InitialConditionsConfig(BaseModel): - method: Annotated[str, AfterValidator(partial(allowed_values, values=['Default', 'SetInitialConditions', 'SetInitialConditionsFolderDraw', 'InitialConditionsFolderDraw', 'FromFile', 'plugin']))] = 'Default' + method: Annotated[ + str, + AfterValidator( + partial( + allowed_values, + values=[ + "Default", + "SetInitialConditions", + "SetInitialConditionsFolderDraw", + "InitialConditionsFolderDraw", + "FromFile", + "plugin", + ], + ) + ), + ] = "Default" initial_file_type: Optional[str] = None initial_conditions_file: Optional[str] = None proportional: Optional[bool] = None @@ -36,105 +63,160 @@ class InitialConditionsConfig(BaseModel): ignore_population_checks: Optional[bool] = None plugin_file_path: Optional[str] = None - @model_validator(mode='before') + @model_validator(mode="before") def validate_initial_file_check(cls, values): - method = values.get('method') - initial_conditions_file = values.get('initial_conditions_file') - initial_file_type = values.get('initial_file_type') - if method in {'FromFile', 'SetInitialConditions'} and not initial_conditions_file: - raise ValueError(f'Error in InitialConditions: An initial_conditions_file is required when method is {method}') - if method in {'InitialConditionsFolderDraw','SetInitialConditionsFolderDraw'} and not initial_file_type: - raise ValueError(f'Error in InitialConditions: initial_file_type is required when method is {method}') + method = values.get("method") + initial_conditions_file = values.get("initial_conditions_file") + initial_file_type = values.get("initial_file_type") + if method in {"FromFile", "SetInitialConditions"} and not initial_conditions_file: + raise ValueError( + f"Error in InitialConditions: An initial_conditions_file is required when method is {method}" + ) + if ( + method in {"InitialConditionsFolderDraw", 
"SetInitialConditionsFolderDraw"} + and not initial_file_type + ): + raise ValueError( + f"Error in InitialConditions: initial_file_type is required when method is {method}" + ) return values - - @model_validator(mode='before') + + @model_validator(mode="before") def plugin_filecheck(cls, values): - method = values.get('method') - plugin_file_path = values.get('plugin_file_path') - if method == 'plugin' and not plugin_file_path: - raise ValueError('Error in InitialConditions: a plugin file path is required when method is plugin') + method = values.get("method") + plugin_file_path = values.get("plugin_file_path") + if method == "plugin" and not plugin_file_path: + raise ValueError( + "Error in InitialConditions: a plugin file path is required when method is plugin" + ) return values class SeedingConfig(BaseModel): - method: Annotated[str, AfterValidator(partial(allowed_values, values=['NoSeeding', 'PoissonDistributed', 'FolderDraw', 'FromFile', 'plugin']))] = 'NoSeeding' # note: removed NegativeBinomialDistributed because no longer supported + method: Annotated[ + str, + AfterValidator( + partial( + allowed_values, + values=[ + "NoSeeding", + "PoissonDistributed", + "FolderDraw", + "FromFile", + "plugin", + ], + ) + ), + ] = "NoSeeding" # note: removed NegativeBinomialDistributed because no longer supported lambda_file: Optional[str] = None seeding_file_type: Optional[str] = None seeding_file: Optional[str] = None plugin_file_path: Optional[str] = None - @model_validator(mode='before') + @model_validator(mode="before") def validate_seedingfile(cls, values): - method = values.get('method') - lambda_file = values.get('lambda_file') - seeding_file_type = values.get('seeding_file_type') - seeding_file = values.get('seeding_file') - if method == 'PoissonDistributed' and not lambda_file: - raise ValueError(f'Error in Seeding: A lambda_file is required when method is {method}') - if method == 'FolderDraw' and not seeding_file_type: - raise ValueError('Error in Seeding: A seeding_file_type is required when method is FolderDraw') - if method == 'FromFile' and not seeding_file: - raise ValueError('Error in Seeding: A seeding_file is required when method is FromFile') + method = values.get("method") + lambda_file = values.get("lambda_file") + seeding_file_type = values.get("seeding_file_type") + seeding_file = values.get("seeding_file") + if method == "PoissonDistributed" and not lambda_file: + raise ValueError( + f"Error in Seeding: A lambda_file is required when method is {method}" + ) + if method == "FolderDraw" and not seeding_file_type: + raise ValueError( + "Error in Seeding: A seeding_file_type is required when method is FolderDraw" + ) + if method == "FromFile" and not seeding_file: + raise ValueError( + "Error in Seeding: A seeding_file is required when method is FromFile" + ) return values - - @model_validator(mode='before') + + @model_validator(mode="before") def plugin_filecheck(cls, values): - method = values.get('method') - plugin_file_path = values.get('plugin_file_path') - if method == 'plugin' and not plugin_file_path: - raise ValueError('Error in Seeding: a plugin file path is required when method is plugin') + method = values.get("method") + plugin_file_path = values.get("plugin_file_path") + if method == "plugin" and not plugin_file_path: + raise ValueError( + "Error in Seeding: a plugin file path is required when method is plugin" + ) return values - + + class IntegrationConfig(BaseModel): - method: Annotated[str, AfterValidator(partial(allowed_values, values=['rk4', 'rk4.jit', 
'best.current', 'legacy']))] = 'rk4' + method: Annotated[ + str, + AfterValidator( + partial(allowed_values, values=["rk4", "rk4.jit", "best.current", "legacy"]) + ), + ] = "rk4" dt: float = 2.0 + class ValueConfig(BaseModel): - distribution: str = 'fixed' - value: Optional[float] = None # NEED TO ADD ABILITY TO PARSE PARAMETERS + distribution: str = "fixed" + value: Optional[float] = None # NEED TO ADD ABILITY TO PARSE PARAMETERS mean: Optional[float] = None sd: Optional[float] = None a: Optional[float] = None b: Optional[float] = None - @model_validator(mode='before') + @model_validator(mode="before") def check_distr(cls, values): - distr = values.get('distribution') - value = values.get('value') - mean = values.get('mean') - sd = values.get('sd') - a = values.get('a') - b = values.get('b') - if distr != 'fixed': + distr = values.get("distribution") + value = values.get("value") + mean = values.get("mean") + sd = values.get("sd") + a = values.get("a") + b = values.get("b") + if distr != "fixed": if not mean and not sd: - raise ValueError('Error in value: mean and sd must be provided for non-fixed distributions') - if distr == 'truncnorm' and not a and not b: - raise ValueError('Error in value: a and b must be provided for truncated normal distributions') - if distr == 'fixed' and not value: - raise ValueError('Error in value: value must be provided for fixed distributions') + raise ValueError( + "Error in value: mean and sd must be provided for non-fixed distributions" + ) + if distr == "truncnorm" and not a and not b: + raise ValueError( + "Error in value: a and b must be provided for truncated normal distributions" + ) + if distr == "fixed" and not value: + raise ValueError( + "Error in value: value must be provided for fixed distributions" + ) return values + class BaseParameterConfig(BaseModel): value: Optional[ValueConfig] = None modifier_parameter: Optional[str] = None - name: Optional[str] = None # this is only for outcomes, to build outcome_prevalence_name (how to restrict this?) + name: Optional[str] = ( + None # this is only for outcomes, to build outcome_prevalence_name (how to restrict this?) + ) + class SeirParameterConfig(BaseParameterConfig): value: Optional[ValueConfig] = None - stacked_modifier_method: Annotated[str, AfterValidator(partial(allowed_values, values=['sum', 'product', 'reduction_product']))] = None + stacked_modifier_method: Annotated[ + str, + AfterValidator( + partial(allowed_values, values=["sum", "product", "reduction_product"]) + ), + ] = None rolling_mean_windows: Optional[float] = None timeseries: Optional[str] = None - @model_validator(mode='before') + @model_validator(mode="before") def which_value(cls, values): - value = values.get('value') is not None - timeseries = values.get('timeseries') is not None + value = values.get("value") is not None + timeseries = values.get("timeseries") is not None if value and timeseries: - raise ValueError('Error in seir::parameters: your parameter is both a timeseries and a value, please choose one') + raise ValueError( + "Error in seir::parameters: your parameter is both a timeseries and a value, please choose one" + ) return values - - -class TransitionConfig(BaseModel): + + +class TransitionConfig(BaseModel): # !! sometimes these are lists of lists and sometimes they are lists... how to deal with this? 
source: List[List[str]] destination: List[List[str]] @@ -142,11 +224,15 @@ class TransitionConfig(BaseModel): proportion_exponent: List[List[str]] proportional_to: List[str] + class SeirConfig(BaseModel): - integration: IntegrationConfig # is this Optional? - parameters: Dict[str, SeirParameterConfig] # there was a previous issue that gempyor doesn't work if there are no parameters (eg if just numbers are used in the transitions) - do we want to get around this? + integration: IntegrationConfig # is this Optional? + parameters: Dict[ + str, SeirParameterConfig + ] # there was a previous issue that gempyor doesn't work if there are no parameters (eg if just numbers are used in the transitions) - do we want to get around this? transitions: List[TransitionConfig] + class SinglePeriodModifierConfig(BaseModel): method: Literal["SinglePeriodModifier"] parameter: str @@ -157,15 +243,18 @@ class SinglePeriodModifierConfig(BaseModel): value: ValueConfig perturbation: Optional[ValueConfig] = None + class MultiPeriodDatesConfig(BaseModel): start_date: date end_date: date - + + class MultiPeriodGroupsConfig(BaseModel): subpop: List[str] subpop_groups: Optional[str] = None periods: List[MultiPeriodDatesConfig] + class MultiPeriodModifierConfig(BaseModel): method: Literal["MultiPeriodModifier"] parameter: str @@ -173,37 +262,47 @@ class MultiPeriodModifierConfig(BaseModel): value: ValueConfig perturbation: Optional[ValueConfig] = None + class StackedModifierConfig(BaseModel): method: Literal["StackedModifier"] modifiers: List[str] + class ModifiersConfig(BaseModel): scenarios: List[str] modifiers: Dict[str, Any] - + @field_validator("modifiers") def validate_data_dict(cls, value: Dict[str, Any]) -> Dict[str, Any]: errors = [] for key, entry in value.items(): method = entry.get("method") - if method not in {"SinglePeriodModifier", "MultiPeriodModifier", "StackedModifier"}: + if method not in { + "SinglePeriodModifier", + "MultiPeriodModifier", + "StackedModifier", + }: errors.append(f"Invalid modifier method: {method}") if errors: raise ValueError("Errors in modifiers:\n" + "\n".join(errors)) return value -class SourceConfig(BaseModel): # set up only for incidence or prevalence. Can this be any name? i don't think so atm +class SourceConfig( + BaseModel +): # set up only for incidence or prevalence. Can this be any name? i don't think so atm incidence: Dict[str, str] = None - prevalence: Dict[str, str] = None + prevalence: Dict[str, str] = None # note: these dictionaries have to have compartment names... more complicated to set this up - @model_validator(mode='before') + @model_validator(mode="before") def which_source(cls, values): - incidence = values.get('incidence') - prevalence = values.get('prevalence') + incidence = values.get("incidence") + prevalence = values.get("prevalence") if incidence and prevalence: - raise ValueError('Error in outcomes::source. Can only be incidence or prevalence, not both.') + raise ValueError( + "Error in outcomes::source. Can only be incidence or prevalence, not both." 
+ ) return values # @model_validator(mode='before') # DOES NOT WORK @@ -214,12 +313,13 @@ def which_source(cls, values): # source_names.append(key) # return source_names # Access keys using a loop + class DelayFrameConfig(BaseModel): source: Optional[SourceConfig] = None probability: Optional[BaseParameterConfig] = None delay: Optional[BaseParameterConfig] = None duration: Optional[BaseParameterConfig] = None - sum: Optional[List[str]] = None # only for sums of other outcomes + sum: Optional[List[str]] = None # only for sums of other outcomes # @validator("sum") # def validate_sum_elements(cls, value: Optional[List[str]]) -> Optional[List[str]]: @@ -233,65 +333,94 @@ class DelayFrameConfig(BaseModel): # return value # note: ^^ this doesn't work yet because it needs to somehow be a level above? to access all OTHER source names - @model_validator(mode='before') + @model_validator(mode="before") def check_outcome_type(cls, values): - sum_present = values.get('sum') is not None - source_present = values.get('source') is not None + sum_present = values.get("sum") is not None + source_present = values.get("source") is not None if sum_present and source_present: - raise ValueError(f"Error in outcome: Both 'sum' and 'source' are present. Choose one.") + raise ValueError( + f"Error in outcome: Both 'sum' and 'source' are present. Choose one." + ) elif not sum_present and not source_present: - raise ValueError(f"Error in outcome: Neither 'sum' nor 'source' is present. Choose one.") + raise ValueError( + f"Error in outcome: Neither 'sum' nor 'source' is present. Choose one." + ) return values + class OutcomesConfig(BaseModel): - method: Literal["delayframe"] # Is this required? I don't see it anywhere in the gempyor code + method: Literal[ + "delayframe" + ] # Is this required? I don't see it anywhere in the gempyor code param_from_file: Optional[bool] = None param_subpop_file: Optional[str] = None outcomes: Dict[str, DelayFrameConfig] - - @model_validator(mode='before') + + @model_validator(mode="before") def check_paramfromfile_type(cls, values): - param_from_file = values.get('param_from_file') is not None - param_subpop_file = values.get('param_subpop_file') is not None + param_from_file = values.get("param_from_file") is not None + param_subpop_file = values.get("param_subpop_file") is not None if param_from_file and not param_subpop_file: - raise ValueError(f"Error in outcome: 'param_subpop_file' is required when 'param_from_file' is True") + raise ValueError( + f"Error in outcome: 'param_subpop_file' is required when 'param_from_file' is True" + ) return values + class ResampleConfig(BaseModel): aggregator: Optional[str] = None freq: Optional[str] = None skipna: Optional[bool] = False + class LikelihoodParams(BaseModel): scale: float # are there other options here? + class LikelihoodReg(BaseModel): - name: str + name: str + class LikelihoodConfig(BaseModel): - dist: Annotated[str, AfterValidator(partial(allowed_values, values=['pois', 'norm', 'norm_cov', 'nbinom', 'rmse', 'absolute_error']))] = None + dist: Annotated[ + str, + AfterValidator( + partial( + allowed_values, + values=["pois", "norm", "norm_cov", "nbinom", "rmse", "absolute_error"], + ) + ), + ] = None params: Optional[LikelihoodParams] = None + class StatisticsConfig(BaseModel): name: str sim_var: str data_var: str regularize: Optional[LikelihoodReg] = None resample: Optional[ResampleConfig] = None - scale: Optional[float] = None # is scale here or at likelihood level? 
- zero_to_one: Optional[bool] = False # is this the same as add_one? remove_na? + scale: Optional[float] = None # is scale here or at likelihood level? + zero_to_one: Optional[bool] = False # is this the same as add_one? remove_na? likelihood: LikelihoodConfig + class InferenceConfig(BaseModel): - method: Annotated[str, AfterValidator(partial(allowed_values, values=['emcee', 'default', 'classical']))] = 'default' # for now - i can only see emcee as an option here, otherwise ignored in classical - need to add these options - iterations_per_slot: Optional[int] # i think this is optional because it is also set in command line?? - do_inference: bool + method: Annotated[ + str, + AfterValidator(partial(allowed_values, values=["emcee", "default", "classical"])), + ] = "default" # for now - i can only see emcee as an option here, otherwise ignored in classical - need to add these options + iterations_per_slot: Optional[ + int + ] # i think this is optional because it is also set in command line?? + do_inference: bool gt_data_path: str statistics: Dict[str, StatisticsConfig] + class CheckConfig(BaseModel): name: str setup_name: Optional[str] = None @@ -312,32 +441,36 @@ class CheckConfig(BaseModel): outcome_modifiers: Optional[ModifiersConfig] = None inference: Optional[InferenceConfig] = None -# add validator for if modifiers exist but seir/outcomes do not - -# there is an error in the one below - @model_validator(mode='before') + # add validator for if modifiers exist but seir/outcomes do not + + # there is an error in the one below + @model_validator(mode="before") def verify_inference(cls, values): - inference_present = values.get('inference') is not None - start_date_groundtruth = values.get('start_date_groundtruth') is not None + inference_present = values.get("inference") is not None + start_date_groundtruth = values.get("start_date_groundtruth") is not None if inference_present and not start_date_groundtruth: - raise ValueError('Inference mode is enabled but no groundtruth dates are provided') + raise ValueError( + "Inference mode is enabled but no groundtruth dates are provided" + ) elif start_date_groundtruth and not inference_present: - raise ValueError('Groundtruth dates are provided but inference mode is not enabled') + raise ValueError( + "Groundtruth dates are provided but inference mode is not enabled" + ) return values - - @model_validator(mode='before') + + @model_validator(mode="before") def check_dates(cls, values): - start_date = values.get('start_date') - end_date = values.get('end_date') + start_date = values.get("start_date") + end_date = values.get("end_date") if start_date and end_date: if end_date <= start_date: - raise ValueError('end_date must be greater than start_date') + raise ValueError("end_date must be greater than start_date") return values - - @model_validator(mode='before') + + @model_validator(mode="before") def init_or_seed(cls, values): - init = values.get('initial_conditions') - seed = values.get('seeding') + init = values.get("initial_conditions") + seed = values.get("seeding") if not init or seed: - raise ValueError('either initial_conditions or seeding must be provided') + raise ValueError("either initial_conditions or seeding must be provided") return values diff --git a/flepimop/gempyor_pkg/src/gempyor/dev/dev_seir.py b/flepimop/gempyor_pkg/src/gempyor/dev/dev_seir.py index a847a5f90..c484a3980 100644 --- a/flepimop/gempyor_pkg/src/gempyor/dev/dev_seir.py +++ b/flepimop/gempyor_pkg/src/gempyor/dev/dev_seir.py @@ -48,7 +48,9 @@ 
modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + "reduction_product" + ], ) params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops) @@ -81,7 +83,12 @@ True, ) df = seir.states2Df(modinf, states) -assert df[(df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")].loc[str(modinf.tf), "20002"] > 1 +assert ( + df[(df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")].loc[ + str(modinf.tf), "20002" + ] + > 1 +) print(df) ts = df cp = "R" diff --git a/flepimop/gempyor_pkg/src/gempyor/dev/steps.py b/flepimop/gempyor_pkg/src/gempyor/dev/steps.py index 43066e5ee..1bbf5b207 100644 --- a/flepimop/gempyor_pkg/src/gempyor/dev/steps.py +++ b/flepimop/gempyor_pkg/src/gempyor/dev/steps.py @@ -53,7 +53,11 @@ def ode_integration( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -65,7 +69,9 @@ def ode_integration( def rhs(t, x, today): print("rhs.t", t) states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -80,9 +86,13 @@ def rhs(t, x, today): proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + proportion_info[proportion_exponent_col][proportion_index] + ][today] if first_proportion: only_one_proportion = ( transitions[transition_proportion_start_col][transition_index] + 1 @@ -91,41 +101,56 @@ def rhs(t, x, today): source_number = relevant_number_in_comp if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : 
mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = proportion_change_compartment rate_change_compartment *= ( - relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] + relevant_number_in_comp[visiting_compartment] + ** relevant_exponent[visiting_compartment] ) rate_change_compartment /= population[visiting_compartment] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_compartment] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) # compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -142,9 +167,15 @@ def rhs(t, x, today): # if number_move[spatial_node] > states_current[transitions[transition_source_col][transition_index]][spatial_node]: # number_move[spatial_node] = states_current[transitions[transition_source_col][transition_index]][spatial_node] # Not possible to enforce this anymore, but it shouldn't be a problem or maybe ? # TODO - states_diff[0, transitions[transition_source_col][transition_index]] -= number_move - states_diff[0, transitions[transition_destination_col][transition_index]] += number_move - states_diff[1, transitions[transition_destination_col][transition_index], :] += number_move # Cumumlative + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= number_move + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += number_move + states_diff[ + 1, transitions[transition_destination_col][transition_index], : + ] += number_move # Cumumlative # states_current = states_next.copy() return np.reshape(states_diff, states_diff.size) # return a 1D vector @@ -168,18 +199,22 @@ def rhs(t, x, today): this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? 
this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts total_seeded += this_seeding_amounts times_seeded += 1 # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts ### Shape @@ -257,7 +292,11 @@ def rk4_integration1( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -267,7 +306,9 @@ def rk4_integration1( def rhs(t, x, today): states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -282,9 +323,13 @@ def rhs(t, x, today): proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + proportion_info[proportion_exponent_col][proportion_index] + ][today] if first_proportion: only_one_proportion = ( transitions[transition_proportion_start_col][transition_index] + 1 @@ -293,41 +338,56 @@ def rhs(t, x, today): source_number = relevant_number_in_comp if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** 
relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = proportion_change_compartment rate_change_compartment *= ( - relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] + relevant_number_in_comp[visiting_compartment] + ** relevant_exponent[visiting_compartment] ) rate_change_compartment /= population[visiting_compartment] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_compartment] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) # compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -344,9 +404,15 @@ def rhs(t, x, today): # if number_move[spatial_node] > states_current[transitions[transition_source_col][transition_index]][spatial_node]: # number_move[spatial_node] = states_current[transitions[transition_source_col][transition_index]][spatial_node] # Not possible to enforce this anymore, but it shouldn't be a problem or maybe ? # TODO - states_diff[0, transitions[transition_source_col][transition_index]] -= number_move - states_diff[0, transitions[transition_destination_col][transition_index]] += number_move - states_diff[1, transitions[transition_destination_col][transition_index], :] += number_move # Cumumlative + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= number_move + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += number_move + states_diff[ + 1, transitions[transition_destination_col][transition_index], : + ] += number_move # Cumumlative # states_current = states_next.copy() return np.reshape(states_diff, states_diff.size) # return a 1D vector @@ -380,18 +446,22 @@ def rk4_integrate(t, x, today): this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? 
this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts total_seeded += this_seeding_amounts times_seeded += 1 # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts ### Shape @@ -447,7 +517,11 @@ def rk4_integration2( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -458,7 +532,9 @@ def rk4_integration2( @jit(nopython=True) def rhs(t, x, today): states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -473,9 +549,13 @@ def rhs(t, x, today): proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + proportion_info[proportion_exponent_col][proportion_index] + ][today] if first_proportion: only_one_proportion = ( transitions[transition_proportion_start_col][transition_index] + 1 @@ -484,41 +564,56 @@ def rhs(t, x, today): source_number = relevant_number_in_comp if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] 
** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = proportion_change_compartment rate_change_compartment *= ( - relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] + relevant_number_in_comp[visiting_compartment] + ** relevant_exponent[visiting_compartment] ) rate_change_compartment /= population[visiting_compartment] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_compartment] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) # compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -535,9 +630,15 @@ def rhs(t, x, today): # if number_move[spatial_node] > states_current[transitions[transition_source_col][transition_index]][spatial_node]: # number_move[spatial_node] = states_current[transitions[transition_source_col][transition_index]][spatial_node] # Not possible to enforce this anymore, but it shouldn't be a problem or maybe ? # TODO - states_diff[0, transitions[transition_source_col][transition_index]] -= number_move - states_diff[0, transitions[transition_destination_col][transition_index]] += number_move - states_diff[1, transitions[transition_destination_col][transition_index], :] += number_move # Cumumlative + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= number_move + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += number_move + states_diff[ + 1, transitions[transition_destination_col][transition_index], : + ] += number_move # Cumumlative # states_current = states_next.copy() return np.reshape(states_diff, states_diff.size) # return a 1D vector @@ -572,18 +673,22 @@ def rk4_integrate(t, x, today): this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? 
this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts total_seeded += this_seeding_amounts times_seeded += 1 # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts ### Shape @@ -644,7 +749,11 @@ def rk4_integration3( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -655,7 +764,9 @@ def rk4_integration3( @jit(nopython=True) def rhs(t, x, today): states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -670,9 +781,13 @@ def rhs(t, x, today): proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + proportion_info[proportion_exponent_col][proportion_index] + ][today] if first_proportion: only_one_proportion = ( transitions[transition_proportion_start_col][transition_index] + 1 @@ -681,41 +796,56 @@ def rhs(t, x, today): source_number = relevant_number_in_comp if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] 
** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = proportion_change_compartment rate_change_compartment *= ( - relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] + relevant_number_in_comp[visiting_compartment] + ** relevant_exponent[visiting_compartment] ) rate_change_compartment /= population[visiting_compartment] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_compartment] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) # compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -732,9 +862,15 @@ def rhs(t, x, today): # if number_move[spatial_node] > states_current[transitions[transition_source_col][transition_index]][spatial_node]: # number_move[spatial_node] = states_current[transitions[transition_source_col][transition_index]][spatial_node] # Not possible to enforce this anymore, but it shouldn't be a problem or maybe ? # TODO - states_diff[0, transitions[transition_source_col][transition_index]] -= number_move - states_diff[0, transitions[transition_destination_col][transition_index]] += number_move - states_diff[1, transitions[transition_destination_col][transition_index], :] += number_move # Cumumlative + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= number_move + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += number_move + states_diff[ + 1, transitions[transition_destination_col][transition_index], : + ] += number_move # Cumumlative # states_current = states_next.copy() return np.reshape(states_diff, states_diff.size) # return a 1D vector @@ -753,16 +889,18 @@ def rk4_integrate(t, x, today): @jit(nopython=True) def day_wrapper_rk4(today, states_next): x_ = np.zeros((2, ncompartments, nspatial_nodes)) - for seeding_instance_idx in range(day_start_idx_dict[today], day_start_idx_dict[today + 1]): + for seeding_instance_idx in range( + day_start_idx_dict[today], day_start_idx_dict[today + 1] + ): this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_subpops_dict[seeding_instance_idx] seeding_sources = seeding_sources_dict[seeding_instance_idx] seeding_destinations = seeding_destinations_dict[seeding_instance_idx] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? 
this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) + states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][ + seeding_subpops + ] * (states_next[seeding_sources][seeding_subpops] > 0) states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts # ADD TO cumulative, this is debatable, @@ -838,7 +976,11 @@ def rk4_integration4( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -849,7 +991,9 @@ def rk4_integration4( @jit(nopython=True) # , fastmath=True, parallel=True) def rhs(t, x, today): states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -864,9 +1008,13 @@ def rhs(t, x, today): proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + proportion_info[proportion_exponent_col][proportion_index] + ][today] if first_proportion: only_one_proportion = ( transitions[transition_proportion_start_col][transition_index] + 1 @@ -875,41 +1023,56 @@ def rhs(t, x, today): source_number = relevant_number_in_comp if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * 
parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = proportion_change_compartment rate_change_compartment *= ( - relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] + relevant_number_in_comp[visiting_compartment] + ** relevant_exponent[visiting_compartment] ) rate_change_compartment /= population[visiting_compartment] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_compartment] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) # compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -926,9 +1089,15 @@ def rhs(t, x, today): # if number_move[spatial_node] > states_current[transitions[transition_source_col][transition_index]][spatial_node]: # number_move[spatial_node] = states_current[transitions[transition_source_col][transition_index]][spatial_node] # Not possible to enforce this anymore, but it shouldn't be a problem or maybe ? # TODO - states_diff[0, transitions[transition_source_col][transition_index]] -= number_move - states_diff[0, transitions[transition_destination_col][transition_index]] += number_move - states_diff[1, transitions[transition_destination_col][transition_index], :] += number_move # Cumumlative + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= number_move + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += number_move + states_diff[ + 1, transitions[transition_destination_col][transition_index], : + ] += number_move # Cumumlative # states_current = states_next.copy() return np.reshape(states_diff, states_diff.size) # return a 1D vector @@ -963,18 +1132,22 @@ def rk4_integrate(t, x, today): this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? 
this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts total_seeded += this_seeding_amounts times_seeded += 1 # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts ### Shape @@ -1031,7 +1204,11 @@ def rk4_integration5( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -1058,18 +1235,22 @@ def rk4_integration5( this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts total_seeded += this_seeding_amounts times_seeded += 1 # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts ### Shape @@ -1092,7 +1273,9 @@ def rk4_integration5( x = x_ + kx[i - 1] * rk_coefs[i] states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -1107,45 +1290,62 @@ def rk4_integration5( proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] - # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][ - today + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] ] + # exponents should not be a 
proportion, since we don't sum them over sum compartments + relevant_exponent = parameters[ + proportion_info[proportion_exponent_col][proportion_index] + ][today] if first_proportion: only_one_proportion = ( - transitions[transition_proportion_start_col][transition_index] + 1 - ) == transitions[transition_proportion_stop_col][transition_index] + transitions[transition_proportion_start_col][ + transition_index + ] + + 1 + ) == transitions[transition_proportion_stop_col][ + transition_index + ] first_proportion = False source_number = relevant_number_in_comp if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[ + spatial_node + ] : mobility_data_indices[spatial_node + 1] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][ - spatial_node - ] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[ + spatial_node + ] : mobility_data_indices[spatial_node + 1] ] rate_change_compartment = proportion_change_compartment @@ -1157,7 +1357,9 @@ def rk4_integration5( rate_change_compartment *= parameters[ transitions[transition_rate_col][transition_index] ][today][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) # compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -1174,8 +1376,12 @@ def rk4_integration5( # if number_move[spatial_node] > states_current[transitions[transition_source_col][transition_index]][spatial_node]: # number_move[spatial_node] = states_current[transitions[transition_source_col][transition_index]][spatial_node] # Not possible to enforce this anymore, but it shouldn't be a problem or maybe ? 
# TODO - states_diff[0, transitions[transition_source_col][transition_index]] -= number_move - states_diff[0, transitions[transition_destination_col][transition_index]] += number_move + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= number_move + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += number_move states_diff[ 1, transitions[transition_destination_col][transition_index], : ] += number_move # Cumumlative @@ -1234,7 +1440,11 @@ def rk4_integration2_smart( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -1248,7 +1458,9 @@ def rhs(t, x): if (today) > ndays: today = ndays - 1 states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -1263,9 +1475,13 @@ def rhs(t, x): proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + proportion_info[proportion_exponent_col][proportion_index] + ][today] if first_proportion: only_one_proportion = ( transitions[transition_proportion_start_col][transition_index] + 1 @@ -1274,41 +1490,56 @@ def rhs(t, x): source_number = relevant_number_in_comp if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) 
visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = proportion_change_compartment rate_change_compartment *= ( - relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] + relevant_number_in_comp[visiting_compartment] + ** relevant_exponent[visiting_compartment] ) rate_change_compartment /= population[visiting_compartment] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_compartment] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) # compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -1325,9 +1556,15 @@ def rhs(t, x): # if number_move[spatial_node] > states_current[transitions[transition_source_col][transition_index]][spatial_node]: # number_move[spatial_node] = states_current[transitions[transition_source_col][transition_index]][spatial_node] # Not possible to enforce this anymore, but it shouldn't be a problem or maybe ? # TODO - states_diff[0, transitions[transition_source_col][transition_index]] -= number_move - states_diff[0, transitions[transition_destination_col][transition_index]] += number_move - states_diff[1, transitions[transition_destination_col][transition_index], :] += number_move # Cumumlative + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= number_move + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += number_move + states_diff[ + 1, transitions[transition_destination_col][transition_index], : + ] += number_move # Cumumlative # states_current = states_next.copy() return np.reshape(states_diff, states_diff.size) # return a 1D vector @@ -1374,18 +1611,24 @@ def rk4_integrate(today, x): this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? 
this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) - states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) + states_next[seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts total_seeded += this_seeding_amounts times_seeded += 1 # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts ### Shape @@ -1450,9 +1693,12 @@ def rk4_integrate(today, x): ## Return "UniTuple(float64[:, :, :], 2) (" ## return states and cumlative states, both [ ndays x ncompartments x nspatial_nodes ] ## Dimensions - "int32," "int32," "int32," ## ncompartments ## nspatial_nodes ## Number of days + "int32," + "int32," + "int32," ## ncompartments ## nspatial_nodes ## Number of days ## Parameters - "float64[:, :, :]," "float64," ## Parameters [ nparameters x ndays x nspatial_nodes] ## dt + "float64[:, :, :]," + "float64," ## Parameters [ nparameters x ndays x nspatial_nodes] ## dt ## Transitions "int64[:, :]," ## transitions [ [source, destination, proportion_start, proportion_stop, rate] x ntransitions ] "int64[:, :]," ## proportions_info [ [sum_starts, sum_stops, exponent] x ntransition_proportions ] @@ -1504,7 +1750,11 @@ def rk4_integration_aot( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -1515,7 +1765,9 @@ def rk4_integration_aot( def rhs(t, x, today): # states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] states_current = x[0] - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -1530,9 +1782,13 @@ def rhs(t, x, today): proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + proportion_info[proportion_exponent_col][proportion_index] + ][today] if first_proportion: only_one_proportion = ( transitions[transition_proportion_start_col][transition_index] + 1 @@ -1541,41 +1797,56 @@ def rhs(t, x, today): source_number = relevant_number_in_comp if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + 
source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = proportion_change_compartment rate_change_compartment *= ( - relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] + relevant_number_in_comp[visiting_compartment] + ** relevant_exponent[visiting_compartment] ) rate_change_compartment /= population[visiting_compartment] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_compartment] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) # compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -1592,9 +1863,15 @@ def rhs(t, x, today): # if number_move[spatial_node] > states_current[transitions[transition_source_col][transition_index]][spatial_node]: # number_move[spatial_node] = states_current[transitions[transition_source_col][transition_index]][spatial_node] # Not possible to enforce this anymore, but it shouldn't be a problem or maybe ? 
# TODO - states_diff[0, transitions[transition_source_col][transition_index]] -= number_move - states_diff[0, transitions[transition_destination_col][transition_index]] += number_move - states_diff[1, transitions[transition_destination_col][transition_index], :] += number_move # Cumumlative + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= number_move + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += number_move + states_diff[ + 1, transitions[transition_destination_col][transition_index], : + ] += number_move # Cumumlative # states_current = states_next.copy() return states_diff # return a 1D vector @@ -1628,18 +1905,22 @@ def rk4_integrate(t, x, today): this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts total_seeded += this_seeding_amounts times_seeded += 1 # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts ### Shape diff --git a/flepimop/gempyor_pkg/src/gempyor/inference.py b/flepimop/gempyor_pkg/src/gempyor/inference.py index fc0613045..9d5accc2c 100644 --- a/flepimop/gempyor_pkg/src/gempyor/inference.py +++ b/flepimop/gempyor_pkg/src/gempyor/inference.py @@ -82,11 +82,17 @@ def simulation_atomic( np.random.seed(int.from_bytes(os.urandom(4), byteorder="little")) random_id = np.random.randint(0, 1e8) - npi_seir = seir.build_npi_SEIR(modinf=modinf, load_ID=False, sim_id2load=None, config=config, bypass_DF=snpi_df_in) + npi_seir = seir.build_npi_SEIR( + modinf=modinf, load_ID=False, sim_id2load=None, config=config, bypass_DF=snpi_df_in + ) if modinf.npi_config_outcomes: npi_outcomes = outcomes.build_outcome_modifiers( - modinf=modinf, load_ID=False, sim_id2load=None, config=config, bypass_DF=hnpi_df_in + modinf=modinf, + load_ID=False, + sim_id2load=None, + config=config, + bypass_DF=hnpi_df_in, ) else: npi_outcomes = None @@ -94,10 +100,14 @@ def simulation_atomic( # reduce them parameters = modinf.parameters.parameters_reduce(p_draw, npi_seir) # Parse them - parsed_parameters = modinf.compartments.parse_parameters(parameters, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + parameters, modinf.parameters.pnames, unique_strings + ) # Convert the seeding data dictionnary to a numba dictionnary - seeding_data_nbdict = nb.typed.Dict.empty(key_type=nb.types.unicode_type, value_type=nb.types.int64[:]) + seeding_data_nbdict = nb.typed.Dict.empty( + key_type=nb.types.unicode_type, 
value_type=nb.types.int64[:] + ) for k, v in seeding_data.items(): seeding_data_nbdict[k] = np.array(v, dtype=np.int64) @@ -151,7 +161,9 @@ def get_static_arguments(modinf): ) = modinf.compartments.get_transition_array() outcomes_parameters = outcomes.read_parameters_from_config(modinf) - npi_seir = seir.build_npi_SEIR(modinf=modinf, load_ID=False, sim_id2load=None, config=config) + npi_seir = seir.build_npi_SEIR( + modinf=modinf, load_ID=False, sim_id2load=None, config=config + ) if modinf.npi_config_outcomes: npi_outcomes = outcomes.build_outcome_modifiers( modinf=modinf, @@ -162,7 +174,9 @@ def get_static_arguments(modinf): else: npi_outcomes = None - p_draw = modinf.parameters.parameters_quick_draw(n_days=modinf.n_days, nsubpops=modinf.nsubpops) + p_draw = modinf.parameters.parameters_quick_draw( + n_days=modinf.n_days, nsubpops=modinf.nsubpops + ) initial_conditions = modinf.initial_conditions.get_from_config(sim_id=0, modinf=modinf) seeding_data, seeding_amounts = modinf.seeding.get_from_config(sim_id=0, modinf=modinf) @@ -170,7 +184,9 @@ def get_static_arguments(modinf): # reduce them parameters = modinf.parameters.parameters_reduce(p_draw, npi_seir) # Parse them - parsed_parameters = modinf.compartments.parse_parameters(parameters, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + parameters, modinf.parameters.pnames, unique_strings + ) if real_simulation: states = seir.steps_SEIR( @@ -198,7 +214,9 @@ def get_static_arguments(modinf): subpop=modinf.subpop_struct.subpop_names, ) - zeros = np.zeros((len(coords["date"]), len(coords["mc_name"][1]), len(coords["subpop"]))) + zeros = np.zeros( + (len(coords["date"]), len(coords["mc_name"][1]), len(coords["subpop"])) + ) states = xr.Dataset( data_vars=dict( prevalence=(["date", "compartment", "subpop"], zeros), @@ -261,7 +279,9 @@ def autodetect_scenarios(config): outcome_modifiers_scenarios = None if config["outcomes"].exists() and config["outcome_modifiers"].exists(): if config["outcome_modifiers"]["scenarios"].exists(): - outcome_modifiers_scenarios = config["outcome_modifiers"]["scenarios"].as_str_seq() + outcome_modifiers_scenarios = config["outcome_modifiers"][ + "scenarios" + ].as_str_seq() outcome_modifiers_scenarios = as_list(outcome_modifiers_scenarios) seir_modifiers_scenarios = as_list(seir_modifiers_scenarios) @@ -275,41 +295,41 @@ def autodetect_scenarios(config): return seir_modifiers_scenarios[0], outcome_modifiers_scenarios[0] + # rewrite the get log loss functions as single functions, not in a class. 
This is not faster # def get_logloss(proposal, inferpar, logloss, static_sim_arguments, modinf, silent=True, save=False): # if not inferpar.check_in_bound(proposal=proposal): # if not silent: # print("OUT OF BOUND!!") # return -np.inf, -np.inf, -np.inf -# +# # snpi_df_mod, hnpi_df_mod = inferpar.inject_proposal( # proposal=proposal, # snpi_df=static_sim_arguments["snpi_df_ref"], # hnpi_df=static_sim_arguments["hnpi_df_ref"], # ) -# +# # ss = copy.deepcopy(static_sim_arguments) # ss["snpi_df_in"] = snpi_df_mod # ss["hnpi_df_in"] = hnpi_df_mod # del ss["snpi_df_ref"] # del ss["hnpi_df_ref"] -# +# # outcomes_df = simulation_atomic(**ss, modinf=modinf, save=save) -# +# # ll_total, logloss, regularizations = logloss.compute_logloss( # model_df=outcomes_df, subpop_names=modinf.subpop_struct.subpop_names # ) # if not silent: # print(f"llik is {ll_total}") -# +# # return ll_total, logloss, regularizations -# +# # def get_logloss_as_single_number(proposal, inferpar, logloss, static_sim_arguments, modinf, silent=True, save=False): # ll_total, logloss, regularizations = get_logloss(proposal, inferpar, logloss, static_sim_arguments, modinf, silent, save) # return ll_total - class GempyorInference: def __init__( self, @@ -333,12 +353,20 @@ def __init__( config.set_file(os.path.join(path_prefix, config_filepath)) - self.seir_modifiers_scenario, self.outcome_modifiers_scenario = autodetect_scenarios(config) + self.seir_modifiers_scenario, self.outcome_modifiers_scenario = ( + autodetect_scenarios(config) + ) if run_id is None: run_id = file_paths.run_id() if prefix is None: - prefix = config["name"].get() + f"_{self.seir_modifiers_scenario}_{self.outcome_modifiers_scenario}" + "/" + run_id + "/" + prefix = ( + config["name"].get() + + f"_{self.seir_modifiers_scenario}_{self.outcome_modifiers_scenario}" + + "/" + + run_id + + "/" + ) in_run_id = run_id if out_run_id is None: out_run_id = in_run_id @@ -387,7 +415,8 @@ def __init__( self.do_inference = True self.inference_method = "emcee" self.inferpar = inference_parameter.InferenceParameters( - global_config=config, subpop_names=self.modinf.subpop_struct.subpop_names + global_config=config, + subpop_names=self.modinf.subpop_struct.subpop_names, ) self.logloss = logloss.LogLoss( inference_config=config["inference"], @@ -412,7 +441,14 @@ def set_save(self, save): def get_all_sim_arguments(self): # inferpar, logloss, static_sim_arguments, modinf, proposal, silent, save - return [self.inferpar, self.logloss, self.static_sim_arguments, self.modinf, self.silent, self.save] + return [ + self.inferpar, + self.logloss, + self.static_sim_arguments, + self.modinf, + self.silent, + self.save, + ] def get_logloss(self, proposal): if not self.inferpar.check_in_bound(proposal=proposal): @@ -479,11 +515,15 @@ def update_run_id(self, new_run_id, new_out_run_id=None): else: self.modinf.out_run_id = new_out_run_id - def one_simulation_legacy(self, sim_id2write: int, load_ID: bool = False, sim_id2load: int = None): + def one_simulation_legacy( + self, sim_id2write: int, load_ID: bool = False, sim_id2load: int = None + ): sim_id2write = int(sim_id2write) if load_ID: sim_id2load = int(sim_id2load) - with Timer(f">>> GEMPYOR onesim {'(loading file)' if load_ID else '(from config)'}"): + with Timer( + f">>> GEMPYOR onesim {'(loading file)' if load_ID else '(from config)'}" + ): with Timer("onerun_SEIR"): seir.onerun_SEIR( sim_id2write=sim_id2write, @@ -533,7 +573,9 @@ def one_simulation( sim_id2load = int(sim_id2load) self.lastsim_sim_id2load = sim_id2load - with Timer(f">>> 
GEMPYOR onesim {'(loading file)' if load_ID else '(from config)'}"): + with Timer( + f">>> GEMPYOR onesim {'(loading file)' if load_ID else '(from config)'}" + ): if not self.already_built and self.modinf.outcomes_config is not None: self.outcomes_parameters = outcomes.read_parameters_from_config(self.modinf) @@ -541,10 +583,24 @@ def one_simulation( npi_seir = None if parallel: with Timer("//things"): - with ProcessPoolExecutor(max_workers=max(mp.cpu_count(), 3)) as executor: - if self.modinf.seir_config is not None and self.modinf.npi_config_seir is not None: - ret_seir = executor.submit(seir.build_npi_SEIR, self.modinf, load_ID, sim_id2load, config) - if self.modinf.outcomes_config is not None and self.modinf.npi_config_outcomes: + with ProcessPoolExecutor( + max_workers=max(mp.cpu_count(), 3) + ) as executor: + if ( + self.modinf.seir_config is not None + and self.modinf.npi_config_seir is not None + ): + ret_seir = executor.submit( + seir.build_npi_SEIR, + self.modinf, + load_ID, + sim_id2load, + config, + ) + if ( + self.modinf.outcomes_config is not None + and self.modinf.npi_config_outcomes + ): ret_outcomes = executor.submit( outcomes.build_outcome_modifiers, self.modinf, @@ -553,7 +609,9 @@ def one_simulation( config, ) if not self.already_built: - ret_comparments = executor.submit(self.modinf.compartments.get_transition_array) + ret_comparments = executor.submit( + self.modinf.compartments.get_transition_array + ) # print("expections:", ret_seir.exception(), ret_outcomes.exception(), ret_comparments.exception()) @@ -565,18 +623,33 @@ def one_simulation( self.proportion_info, ) = ret_comparments.result() self.already_built = True - if self.modinf.seir_config is not None and self.modinf.npi_config_seir is not None: + if ( + self.modinf.seir_config is not None + and self.modinf.npi_config_seir is not None + ): npi_seir = ret_seir.result() - if self.modinf.outcomes_config is not None and self.modinf.npi_config_outcomes: + if ( + self.modinf.outcomes_config is not None + and self.modinf.npi_config_outcomes + ): npi_outcomes = ret_outcomes.result() else: if not self.already_built: self.build_structure() - if self.modinf.seir_config is not None and self.modinf.npi_config_seir is not None: + if ( + self.modinf.seir_config is not None + and self.modinf.npi_config_seir is not None + ): npi_seir = seir.build_npi_SEIR( - modinf=self.modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config + modinf=self.modinf, + load_ID=load_ID, + sim_id2load=sim_id2load, + config=config, ) - if self.modinf.outcomes_config is not None and self.modinf.npi_config_outcomes: + if ( + self.modinf.outcomes_config is not None + and self.modinf.npi_config_outcomes + ): npi_outcomes = outcomes.build_outcome_modifiers( modinf=self.modinf, load_ID=load_ID, @@ -602,8 +675,12 @@ def one_simulation( with Timer("onerun_SEIR.seeding"): if load_ID: - initial_conditions = self.modinf.initial_conditions.get_from_file(sim_id2load, modinf=self.modinf) - seeding_data, seeding_amounts = self.modinf.seeding.get_from_file(sim_id2load, modinf=self.modinf) + initial_conditions = self.modinf.initial_conditions.get_from_file( + sim_id2load, modinf=self.modinf + ) + seeding_data, seeding_amounts = self.modinf.seeding.get_from_file( + sim_id2load, modinf=self.modinf + ) else: initial_conditions = self.modinf.initial_conditions.get_from_config( sim_id2write, modinf=self.modinf @@ -649,7 +726,7 @@ def one_simulation( parameters=self.outcomes_parameters, loaded_values=loaded_values, npi=npi_outcomes, - bypass_seir_xr=states + 
bypass_seir_xr=states, ) self.lastsim_outcomes_df = outcomes_df self.lastsim_hpar_df = hpar_df @@ -664,14 +741,18 @@ def one_simulation( ) return 0 - def plot_transition_graph(self, output_file="transition_graph", source_filters=[], destination_filters=[]): + def plot_transition_graph( + self, output_file="transition_graph", source_filters=[], destination_filters=[] + ): self.modinf.compartments.plot( output_file=output_file, source_filters=source_filters, destination_filters=destination_filters, ) - def get_outcome_npi(self, load_ID=False, sim_id2load=None, bypass_DF=None, bypass_FN=None): + def get_outcome_npi( + self, load_ID=False, sim_id2load=None, bypass_DF=None, bypass_FN=None + ): npi_outcomes = None if self.modinf.npi_config_outcomes: npi_outcomes = outcomes.build_outcome_modifiers( @@ -695,7 +776,9 @@ def get_seir_npi(self, load_ID=False, sim_id2load=None, bypass_DF=None, bypass_F ) return npi_seir - def get_seir_parameters(self, load_ID=False, sim_id2load=None, bypass_DF=None, bypass_FN=None): + def get_seir_parameters( + self, load_ID=False, sim_id2load=None, bypass_DF=None, bypass_FN=None + ): param_df = None if bypass_DF is not None: param_df = bypass_DF @@ -716,7 +799,9 @@ def get_seir_parameters(self, load_ID=False, sim_id2load=None, bypass_DF=None, b ) return p_draw - def get_seir_parametersDF(self, load_ID=False, sim_id2load=None, bypass_DF=None, bypass_FN=None): + def get_seir_parametersDF( + self, load_ID=False, sim_id2load=None, bypass_DF=None, bypass_FN=None + ): p_draw = self.get_seir_parameters( load_ID=load_ID, sim_id2load=sim_id2load, @@ -771,7 +856,9 @@ def get_parsed_parameters_seir( if not self.already_built: self.build_structure() - npi_seir = seir.build_npi_SEIR(modinf=self.modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config) + npi_seir = seir.build_npi_SEIR( + modinf=self.modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config + ) p_draw = self.get_seir_parameters(load_ID=load_ID, sim_id2load=sim_id2load) parameters = self.modinf.parameters.parameters_reduce(p_draw, npi_seir) @@ -788,7 +875,9 @@ def get_reduced_parameters_seir( # bypass_DF=None, # bypass_FN=None, ): - npi_seir = seir.build_npi_SEIR(modinf=self.modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config) + npi_seir = seir.build_npi_SEIR( + modinf=self.modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config + ) p_draw = self.get_seir_parameters(load_ID=load_ID, sim_id2load=sim_id2load) parameters = self.modinf.parameters.parameters_reduce(p_draw, npi_seir) @@ -809,7 +898,9 @@ def paramred_parallel(run_spec, snpi_fn): seir_modifiers_scenario="inference", # NPIs scenario to use outcome_modifiers_scenario="med", # Outcome scenario to use stoch_traj_flag=False, - path_prefix=run_spec["geodata"], # prefix where to find the folder indicated in subpop_setup$ + path_prefix=run_spec[ + "geodata" + ], # prefix where to find the folder indicated in subpop_setup$ ) snpi = pq.read_table(snpi_fn).to_pandas() @@ -820,7 +911,9 @@ def paramred_parallel(run_spec, snpi_fn): params_draw_arr = gempyor_inference.get_seir_parameters( bypass_FN=snpi_fn.replace("snpi", "spar") ) # could also accept (load_ID=True, sim_id2load=XXX) or (bypass_DF=) or (bypass_FN=) - param_reduc_from = gempyor_inference.get_seir_parameter_reduced(npi_seir=npi_seir, p_draw=params_draw_arr) + param_reduc_from = gempyor_inference.get_seir_parameter_reduced( + npi_seir=npi_seir, p_draw=params_draw_arr + ) return param_reduc_from @@ -835,7 +928,9 @@ def paramred_parallel_config(run_spec, dummy): 
seir_modifiers_scenario="inference", # NPIs scenario to use outcome_modifiers_scenario="med", # Outcome scenario to use stoch_traj_flag=False, - path_prefix=run_spec["geodata"], # prefix where to find the folder indicated in subpop_setup$ + path_prefix=run_spec[ + "geodata" + ], # prefix where to find the folder indicated in subpop_setup$ ) npi_seir = gempyor_inference.get_seir_npi() @@ -843,6 +938,8 @@ def paramred_parallel_config(run_spec, dummy): params_draw_arr = ( gempyor_inference.get_seir_parameters() ) # could also accept (load_ID=True, sim_id2load=XXX) or (bypass_DF=) or (bypass_FN=) - param_reduc_from = gempyor_inference.get_seir_parameter_reduced(npi_seir=npi_seir, p_draw=params_draw_arr) + param_reduc_from = gempyor_inference.get_seir_parameter_reduced( + npi_seir=npi_seir, p_draw=params_draw_arr + ) return param_reduc_from diff --git a/flepimop/gempyor_pkg/src/gempyor/inference_parameter.py b/flepimop/gempyor_pkg/src/gempyor/inference_parameter.py index 8f1b0fc53..e6a795192 100644 --- a/flepimop/gempyor_pkg/src/gempyor/inference_parameter.py +++ b/flepimop/gempyor_pkg/src/gempyor/inference_parameter.py @@ -36,9 +36,14 @@ def add_modifier(self, pname, ptype, parameter_config, subpops): """ # identify spatial group affected_subpops = set(subpops) - if parameter_config["subpop"].exists() and parameter_config["subpop"].get() != "all": + if ( + parameter_config["subpop"].exists() + and parameter_config["subpop"].get() != "all" + ): affected_subpops = {str(n.get()) for n in parameter_config["subpop"]} - spatial_groups = NPI.helpers.get_spatial_groups(parameter_config, list(affected_subpops)) + spatial_groups = NPI.helpers.get_spatial_groups( + parameter_config, list(affected_subpops) + ) # ungrouped subpop (all affected subpop by default) have one parameter per subpop if spatial_groups["ungrouped"]: @@ -87,7 +92,9 @@ def build_from_config(self, global_config, subpop_names): for config_part in ["seir_modifiers", "outcome_modifiers"]: if global_config[config_part].exists(): for npi in global_config[config_part]["modifiers"].get(): - if global_config[config_part]["modifiers"][npi]["perturbation"].exists(): + if global_config[config_part]["modifiers"][npi][ + "perturbation" + ].exists(): self.add_modifier( pname=npi, ptype=config_part, diff --git a/flepimop/gempyor_pkg/src/gempyor/initial_conditions.py b/flepimop/gempyor_pkg/src/gempyor/initial_conditions.py index 554f3910c..7048a684e 100644 --- a/flepimop/gempyor_pkg/src/gempyor/initial_conditions.py +++ b/flepimop/gempyor_pkg/src/gempyor/initial_conditions.py @@ -45,19 +45,30 @@ def __init__( if self.initial_conditions_config is not None: if "ignore_population_checks" in self.initial_conditions_config.keys(): - self.ignore_population_checks = self.initial_conditions_config["ignore_population_checks"].get(bool) + self.ignore_population_checks = self.initial_conditions_config[ + "ignore_population_checks" + ].get(bool) if "allow_missing_subpops" in self.initial_conditions_config.keys(): - self.allow_missing_subpops = self.initial_conditions_config["allow_missing_subpops"].get(bool) + self.allow_missing_subpops = self.initial_conditions_config[ + "allow_missing_subpops" + ].get(bool) if "allow_missing_compartments" in self.initial_conditions_config.keys(): - self.allow_missing_compartments = self.initial_conditions_config["allow_missing_compartments"].get(bool) + self.allow_missing_compartments = self.initial_conditions_config[ + "allow_missing_compartments" + ].get(bool) # TODO: add check, this option onlywork with tidy dataframe if 
"proportional" in self.initial_conditions_config.keys(): - self.proportional_ic = self.initial_conditions_config["proportional"].get(bool) + self.proportional_ic = self.initial_conditions_config["proportional"].get( + bool + ) def get_from_config(self, sim_id: int, modinf) -> np.ndarray: method = "Default" - if self.initial_conditions_config is not None and "method" in self.initial_conditions_config.keys(): + if ( + self.initial_conditions_config is not None + and "method" in self.initial_conditions_config.keys() + ): method = self.initial_conditions_config["method"].as_str() if method == "Default": @@ -69,10 +80,13 @@ def get_from_config(self, sim_id: int, modinf) -> np.ndarray: if method == "SetInitialConditions" or method == "SetInitialConditionsFolderDraw": # TODO Think about - Does not support the new way of doing compartment indexing if method == "SetInitialConditionsFolderDraw": - ic_df = modinf.read_simID(ftype=self.initial_conditions_config["initial_file_type"], sim_id=sim_id) + ic_df = modinf.read_simID( + ftype=self.initial_conditions_config["initial_file_type"], sim_id=sim_id + ) else: ic_df = read_df( - self.path_prefix / self.initial_conditions_config["initial_conditions_file"].get(), + self.path_prefix + / self.initial_conditions_config["initial_conditions_file"].get(), ) y0 = read_initial_condition_from_tidydataframe( ic_df=ic_df, @@ -85,11 +99,13 @@ def get_from_config(self, sim_id: int, modinf) -> np.ndarray: elif method == "InitialConditionsFolderDraw" or method == "FromFile": if method == "InitialConditionsFolderDraw": ic_df = modinf.read_simID( - ftype=self.initial_conditions_config["initial_file_type"].get(), sim_id=sim_id + ftype=self.initial_conditions_config["initial_file_type"].get(), + sim_id=sim_id, ) elif method == "FromFile": ic_df = read_df( - self.path_prefix / self.initial_conditions_config["initial_conditions_file"].get(), + self.path_prefix + / self.initial_conditions_config["initial_conditions_file"].get(), ) y0 = read_initial_condition_from_seir_output( @@ -102,7 +118,9 @@ def get_from_config(self, sim_id: int, modinf) -> np.ndarray: raise NotImplementedError(f"unknown initial conditions method [got: {method}]") # check that the inputed values sums to the subpop population: - check_population(y0=y0, modinf=modinf, ignore_population_checks=self.ignore_population_checks) + check_population( + y0=y0, modinf=modinf, ignore_population_checks=self.ignore_population_checks + ) return y0 @@ -145,7 +163,9 @@ def read_initial_condition_from_tidydataframe( states_pl = ic_df[ic_df["subpop"] == pl] for comp_idx, comp_name in modinf.compartments.compartments["name"].items(): if "mc_name" in states_pl.columns: - ic_df_compartment_val = states_pl[states_pl["mc_name"] == comp_name]["amount"] + ic_df_compartment_val = states_pl[states_pl["mc_name"] == comp_name][ + "amount" + ] else: filters = modinf.compartments.compartments.iloc[comp_idx].drop("name") ic_df_compartment_val = states_pl.copy() @@ -177,7 +197,9 @@ def read_initial_condition_from_tidydataframe( logger.critical( f"No initial conditions for for subpop {pl}, assuming everyone (n={modinf.subpop_pop[pl_idx]}) in the first metacompartment ({modinf.compartments.compartments['name'].iloc[0]})" ) - raise ValueError("THERE IS A BUG; REPORT THIS MESSAGE. Past implemenation was buggy") + raise ValueError( + "THERE IS A BUG; REPORT THIS MESSAGE. 
Past implemenation was buggy" + ) # TODO: this is probably ok but highlighting for consistency if "proportional" in self.initial_conditions_config.keys(): if self.initial_conditions_config["proportional"].get(): @@ -202,7 +224,9 @@ def read_initial_condition_from_tidydataframe( return y0 -def read_initial_condition_from_seir_output(ic_df, modinf, allow_missing_subpops, allow_missing_compartments): +def read_initial_condition_from_seir_output( + ic_df, modinf, allow_missing_subpops, allow_missing_compartments +): """ Read the initial conditions from the SEIR output. @@ -227,9 +251,13 @@ def read_initial_condition_from_seir_output(ic_df, modinf, allow_missing_subpops ic_df["date"] = ic_df["date"].dt.date ic_df["date"] = ic_df["date"].astype(str) - ic_df = ic_df[(ic_df["date"] == str(modinf.ti)) & (ic_df["mc_value_type"] == "prevalence")] + ic_df = ic_df[ + (ic_df["date"] == str(modinf.ti)) & (ic_df["mc_value_type"] == "prevalence") + ] if ic_df.empty: - raise ValueError(f"There is no entry for initial time ti in the provided initial_conditions::states_file.") + raise ValueError( + f"There is no entry for initial time ti in the provided initial_conditions::states_file." + ) y0 = np.zeros((modinf.compartments.compartments.shape[0], modinf.nsubpops)) for comp_idx, comp_name in modinf.compartments.compartments["name"].items(): @@ -239,7 +267,9 @@ def read_initial_condition_from_seir_output(ic_df, modinf, allow_missing_subpops filters = modinf.compartments.compartments.iloc[comp_idx].drop("name") ic_df_compartment = ic_df.copy() for mc_name, mc_value in filters.items(): - ic_df_compartment = ic_df_compartment[ic_df_compartment["mc_" + mc_name] == mc_value] + ic_df_compartment = ic_df_compartment[ + ic_df_compartment["mc_" + mc_name] == mc_value + ] if len(ic_df_compartment) > 1: # ic_df_compartment = ic_df_compartment.iloc[0] @@ -248,7 +278,9 @@ def read_initial_condition_from_seir_output(ic_df, modinf, allow_missing_subpops ) elif ic_df_compartment.empty: if allow_missing_compartments: - ic_df_compartment = pd.DataFrame(0, columns=ic_df_compartment.columns, index=[0]) + ic_df_compartment = pd.DataFrame( + 0, columns=ic_df_compartment.columns, index=[0] + ) else: raise ValueError( f"Initial Conditions: Could not set compartment {comp_name} (id: {comp_idx}) in subpop {pl} (id: {pl_idx}). The data from the init file is {ic_df_compartment[pl]}." @@ -262,7 +294,9 @@ def read_initial_condition_from_seir_output(ic_df, modinf, allow_missing_subpops if pl in ic_df.columns: y0[comp_idx, pl_idx] = float(ic_df_compartment[pl].iloc[0]) elif allow_missing_subpops: - raise ValueError("THERE IS A BUG; REPORT THIS MESSAGE. Past implemenation was buggy") + raise ValueError( + "THERE IS A BUG; REPORT THIS MESSAGE. 
Past implemenation was buggy" + ) # TODO this should set the full subpop, not just the 0th commpartment logger.critical( f"No initial conditions for for subpop {pl}, assuming everyone (n={modinf.subpop_pop[pl_idx]}) in the first metacompartments ({modinf.compartments.compartments['name'].iloc[0]})" diff --git a/flepimop/gempyor_pkg/src/gempyor/logloss.py b/flepimop/gempyor_pkg/src/gempyor/logloss.py index 95892dac0..09993421f 100644 --- a/flepimop/gempyor_pkg/src/gempyor/logloss.py +++ b/flepimop/gempyor_pkg/src/gempyor/logloss.py @@ -17,7 +17,13 @@ class LogLoss: - def __init__(self, inference_config: confuse.ConfigView, subpop_struct, time_setup, path_prefix: str = "."): + def __init__( + self, + inference_config: confuse.ConfigView, + subpop_struct, + time_setup, + path_prefix: str = ".", + ): # TODO: bad format for gt because each date must have a value for each column, but if it doesn't and you add NA # then this NA has a meaning that depends on skip NA, which is annoying. # A lot of things can go wrong here, in the previous approach where GT was cast to xarray as @@ -35,20 +41,30 @@ def __init__(self, inference_config: confuse.ConfigView, subpop_struct, time_set # made the controversial choice of storing the gt as an xarray dataset instead of a dictionary # of dataframes - self.gt_xr = xr.Dataset.from_dataframe(self.gt.reset_index().set_index(["date", "subpop"])) + self.gt_xr = xr.Dataset.from_dataframe( + self.gt.reset_index().set_index(["date", "subpop"]) + ) # Very important: subsample the subpop in the population, in the right order, and sort by the date index. - self.gt_xr = self.gt_xr.sortby("date").reindex({"subpop": subpop_struct.subpop_names}) + self.gt_xr = self.gt_xr.sortby("date").reindex( + {"subpop": subpop_struct.subpop_names} + ) # This will force at 0, if skipna is False, data of some variable that don't exist if iother exist # and damn python datetime types are ugly... - self.first_date = max(pd.to_datetime(self.gt_xr.date[0].values).date(), time_setup.ti) - self.last_date = min(pd.to_datetime(self.gt_xr.date[-1].values).date(), time_setup.tf) + self.first_date = max( + pd.to_datetime(self.gt_xr.date[0].values).date(), time_setup.ti + ) + self.last_date = min( + pd.to_datetime(self.gt_xr.date[-1].values).date(), time_setup.tf + ) self.statistics = {} for key, value in inference_config["statistics"].items(): self.statistics[key] = statistics.Statistic(key, value) - def plot_gt(self, ax=None, subpop=None, statistic=None, subplot=False, filename=None, **kwargs): + def plot_gt( + self, ax=None, subpop=None, statistic=None, subplot=False, filename=None, **kwargs + ): """Plots ground truth data. 
Args: @@ -68,7 +84,10 @@ def plot_gt(self, ax=None, subpop=None, statistic=None, subplot=False, filename= fig, axes = plt.subplots( len(self.gt["subpop"].unique()), len(self.gt.columns.drop("subpop")), - figsize=(4 * len(self.gt.columns.drop("subpop")), 3 * len(self.gt["subpop"].unique())), + figsize=( + 4 * len(self.gt.columns.drop("subpop")), + 3 * len(self.gt["subpop"].unique()), + ), dpi=250, sharex=True, ) @@ -81,7 +100,9 @@ def plot_gt(self, ax=None, subpop=None, statistic=None, subplot=False, filename= subpops = [subpop] if statistic is None: - statistics = self.gt.columns.drop("subpop") # Assuming other columns are statistics + statistics = self.gt.columns.drop( + "subpop" + ) # Assuming other columns are statistics else: statistics = [statistic] @@ -107,7 +128,10 @@ def plot_gt(self, ax=None, subpop=None, statistic=None, subplot=False, filename= plt.savefig(filename, **kwargs) # Save the figure if subplot: - return fig, axes # Return figure and subplots for potential further customization + return ( + fig, + axes, + ) # Return figure and subplots for potential further customization else: return ax # Optionally return the axis @@ -121,7 +145,9 @@ def compute_logloss(self, model_df, subpop_names): coords = {"statistic": list(self.statistics.keys()), "subpop": subpop_names} logloss = xr.DataArray( - np.zeros((len(coords["statistic"]), len(coords["subpop"]))), dims=["statistic", "subpop"], coords=coords + np.zeros((len(coords["statistic"]), len(coords["subpop"]))), + dims=["statistic", "subpop"], + coords=coords, ) regularizations = 0 diff --git a/flepimop/gempyor_pkg/src/gempyor/model_info.py b/flepimop/gempyor_pkg/src/gempyor/model_info.py index 54f981d8f..35502aadf 100644 --- a/flepimop/gempyor_pkg/src/gempyor/model_info.py +++ b/flepimop/gempyor_pkg/src/gempyor/model_info.py @@ -1,6 +1,13 @@ import pandas as pd import datetime, os, logging, pathlib, confuse -from . import seeding, subpopulation_structure, parameters, compartments, file_paths, initial_conditions +from . import ( + seeding, + subpopulation_structure, + parameters, + compartments, + file_paths, + initial_conditions, +) from .utils import read_df, write_df logger = logging.getLogger(__name__) @@ -11,7 +18,9 @@ def __init__(self, config: confuse.ConfigView): self.ti = config["start_date"].as_date() self.tf = config["end_date"].as_date() if self.tf <= self.ti: - raise ValueError("tf (time to finish) is less than or equal to ti (time to start)") + raise ValueError( + "tf (time to finish) is less than or equal to ti (time to start)" + ) self.n_days = (self.tf - self.ti).days + 1 self.dates = pd.date_range(start=self.ti, end=self.tf, freq="D") @@ -29,7 +38,7 @@ class ModelInfo: seeding # One of seeding or initial_conditions is required when running seir outcomes # Required if running outcomes seir_modifiers # Not required. If exists, every modifier will be applied to seir parameters - outcomes_modifiers # Not required. If exists, every modifier will be applied to outcomes + outcomes_modifiers # Not required. If exists, every modifier will be applied to outcomes inference # Required if running inference ``` """ @@ -94,7 +103,9 @@ def __init__( # 3. What about subpopulations subpop_config = config["subpop_setup"] if "data_path" in config: - raise ValueError("The config has a data_path section. This is no longer supported.") + raise ValueError( + "The config has a data_path section. This is no longer supported." 
+ ) self.path_prefix = pathlib.Path(path_prefix) self.subpop_struct = subpopulation_structure.SubpopulationStructure( @@ -112,7 +123,9 @@ def __init__( self.seir_config = config["seir"] self.parameters_config = config["seir"]["parameters"] self.initial_conditions_config = ( - config["initial_conditions"] if config["initial_conditions"].exists() else None + config["initial_conditions"] + if config["initial_conditions"].exists() + else None ) self.seeding_config = config["seeding"] if config["seeding"].exists() else None @@ -130,7 +143,9 @@ def __init__( subpop_names=self.subpop_struct.subpop_names, path_prefix=self.path_prefix, ) - self.seeding = seeding.SeedingFactory(config=self.seeding_config, path_prefix=self.path_prefix) + self.seeding = seeding.SeedingFactory( + config=self.seeding_config, path_prefix=self.path_prefix + ) self.initial_conditions = initial_conditions.InitialConditionsFactory( config=self.initial_conditions_config, path_prefix=self.path_prefix ) @@ -144,11 +159,19 @@ def __init__( self.npi_config_seir = None if config["seir_modifiers"].exists(): if config["seir_modifiers"]["scenarios"].exists(): - self.npi_config_seir = config["seir_modifiers"]["modifiers"][seir_modifiers_scenario] - self.seir_modifiers_library = config["seir_modifiers"]["modifiers"].get() + self.npi_config_seir = config["seir_modifiers"]["modifiers"][ + seir_modifiers_scenario + ] + self.seir_modifiers_library = config["seir_modifiers"][ + "modifiers" + ].get() else: - self.seir_modifiers_library = config["seir_modifiers"]["modifiers"].get() - raise ValueError("Not implemented yet") # TODO create a Stacked from all + self.seir_modifiers_library = config["seir_modifiers"][ + "modifiers" + ].get() + raise ValueError( + "Not implemented yet" + ) # TODO create a Stacked from all elif self.seir_modifiers_scenario is not None: raise ValueError( "An seir modifiers scenario was provided to ModelInfo but no 'seir_modifiers' sections in config" @@ -157,7 +180,9 @@ def __init__( logging.info("Running ModelInfo with seir but without SEIR Modifiers") elif self.seir_modifiers_scenario is not None: - raise ValueError("A seir modifiers scenario was provided to ModelInfo but no 'seir:' sections in config") + raise ValueError( + "A seir modifiers scenario was provided to ModelInfo but no 'seir:' sections in config" + ) else: logging.critical("Running ModelInfo without SEIR") @@ -167,11 +192,19 @@ def __init__( self.npi_config_outcomes = None if config["outcome_modifiers"].exists(): if config["outcome_modifiers"]["scenarios"].exists(): - self.npi_config_outcomes = config["outcome_modifiers"]["modifiers"][self.outcome_modifiers_scenario] - self.outcome_modifiers_library = config["outcome_modifiers"]["modifiers"].get() + self.npi_config_outcomes = config["outcome_modifiers"]["modifiers"][ + self.outcome_modifiers_scenario + ] + self.outcome_modifiers_library = config["outcome_modifiers"][ + "modifiers" + ].get() else: - self.outcome_modifiers_library = config["outcome_modifiers"]["modifiers"].get() - raise ValueError("Not implemented yet") # TODO create a Stacked from all + self.outcome_modifiers_library = config["outcome_modifiers"][ + "modifiers" + ].get() + raise ValueError( + "Not implemented yet" + ) # TODO create a Stacked from all ## NEED TO IMPLEMENT THIS -- CURRENTLY CANNOT USE outcome modifiers elif self.outcome_modifiers_scenario is not None: @@ -182,7 +215,9 @@ def __init__( else: self.outcome_modifiers_scenario = None else: - logging.info("Running ModelInfo with outcomes but without Outcomes Modifiers") + 
logging.info( + "Running ModelInfo with outcomes but without Outcomes Modifiers" + ) elif self.outcome_modifiers_scenario is not None: raise ValueError( "An outcome modifiers scenario was provided to ModelInfo but no 'outcomes:' sections in config" @@ -228,7 +263,9 @@ def __init__( os.makedirs(datadir, exist_ok=True) if self.write_parquet and self.write_csv: - print("Confused between reading .csv or parquet. Assuming input file is .parquet") + print( + "Confused between reading .csv or parquet. Assuming input file is .parquet" + ) if self.write_parquet: self.extension = "parquet" elif self.write_csv: @@ -252,7 +289,9 @@ def get_output_filename(self, ftype: str, sim_id: int, extension_override: str = extension_override=extension_override, ) - def get_filename(self, ftype: str, sim_id: int, input: bool, extension_override: str = ""): + def get_filename( + self, ftype: str, sim_id: int, input: bool, extension_override: str = "" + ): """return a CSP formated filename.""" if extension_override: # empty strings are Falsy @@ -281,7 +320,9 @@ def get_filename(self, ftype: str, sim_id: int, input: bool, extension_override: def get_setup_name(self): return self.setup_name - def read_simID(self, ftype: str, sim_id: int, input: bool = True, extension_override: str = ""): + def read_simID( + self, ftype: str, sim_id: int, input: bool = True, extension_override: str = "" + ): fname = self.get_filename( ftype=ftype, sim_id=sim_id, diff --git a/flepimop/gempyor_pkg/src/gempyor/outcomes.py b/flepimop/gempyor_pkg/src/gempyor/outcomes.py index 5563f4d85..e5f09c1d2 100644 --- a/flepimop/gempyor_pkg/src/gempyor/outcomes.py +++ b/flepimop/gempyor_pkg/src/gempyor/outcomes.py @@ -22,7 +22,9 @@ def run_parallel_outcomes(modinf, *, sim_id2write, nslots=1, n_jobs=1): sim_id2writes = np.arange(sim_id2write, sim_id2write + modinf.nslots) loaded_values = None - if (n_jobs == 1) or (modinf.nslots == 1): # run single process for debugging/profiling purposes + if (n_jobs == 1) or ( + modinf.nslots == 1 + ): # run single process for debugging/profiling purposes for sim_offset in np.arange(nslots): onerun_delayframe_outcomes( sim_id2write=sim_id2writes[sim_offset], @@ -100,7 +102,9 @@ def onerun_delayframe_outcomes( npi_outcomes = None if modinf.npi_config_outcomes: - npi_outcomes = build_outcome_modifiers(modinf=modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config) + npi_outcomes = build_outcome_modifiers( + modinf=modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config + ) loaded_values = None if load_ID: @@ -117,7 +121,13 @@ def onerun_delayframe_outcomes( ) with Timer("onerun_delayframe_outcomes.postprocess"): - postprocess_and_write(sim_id=sim_id2write, modinf=modinf, outcomes_df=outcomes_df, hpar=hpar, npi=npi_outcomes) + postprocess_and_write( + sim_id=sim_id2write, + modinf=modinf, + outcomes_df=outcomes_df, + hpar=hpar, + npi=npi_outcomes, + ) def read_parameters_from_config(modinf: model_info.ModelInfo): @@ -129,7 +139,10 @@ def read_parameters_from_config(modinf: model_info.ModelInfo): if modinf.outcomes_config["param_from_file"].exists(): if modinf.outcomes_config["param_from_file"].get(): # Load the actual csv file - branching_file = modinf.path_prefix / modinf.outcomes_config["param_subpop_file"].as_str() + branching_file = ( + modinf.path_prefix + / modinf.outcomes_config["param_subpop_file"].as_str() + ) branching_data = pa.parquet.read_table(branching_file).to_pandas() if "relative_probability" not in list(branching_data["quantity"]): raise ValueError( @@ -142,14 +155,18 @@ def 
read_parameters_from_config(modinf: model_info.ModelInfo): "", end="", ) - branching_data = branching_data[branching_data["subpop"].isin(modinf.subpop_struct.subpop_names)] + branching_data = branching_data[ + branching_data["subpop"].isin(modinf.subpop_struct.subpop_names) + ] print( "Intersect with seir simulation: ", len(branching_data.subpop.unique()), "kept", ) - if len(branching_data.subpop.unique()) != len(modinf.subpop_struct.subpop_names): + if len(branching_data.subpop.unique()) != len( + modinf.subpop_struct.subpop_names + ): raise ValueError( f"Places in seir input files does not correspond to subpops in outcome probability file {branching_file}" ) @@ -170,10 +187,14 @@ def read_parameters_from_config(modinf: model_info.ModelInfo): f"unsure how to read outcome {new_comp}: not a str, nor an incidence or prevalence: {src_name}" ) - parameters[new_comp]["probability"] = outcomes_config[new_comp]["probability"]["value"] + parameters[new_comp]["probability"] = outcomes_config[new_comp][ + "probability" + ]["value"] if outcomes_config[new_comp]["probability"]["modifier_parameter"].exists(): parameters[new_comp]["probability::npi_param_name"] = ( - outcomes_config[new_comp]["probability"]["modifier_parameter"].as_str().lower() + outcomes_config[new_comp]["probability"]["modifier_parameter"] + .as_str() + .lower() ) logging.debug( f"probability of outcome {new_comp} is affected by intervention " @@ -181,13 +202,19 @@ def read_parameters_from_config(modinf: model_info.ModelInfo): f"instead of {new_comp}::probability" ) else: - parameters[new_comp]["probability::npi_param_name"] = f"{new_comp}::probability".lower() + parameters[new_comp][ + "probability::npi_param_name" + ] = f"{new_comp}::probability".lower() if outcomes_config[new_comp]["delay"].exists(): - parameters[new_comp]["delay"] = outcomes_config[new_comp]["delay"]["value"] + parameters[new_comp]["delay"] = outcomes_config[new_comp]["delay"][ + "value" + ] if outcomes_config[new_comp]["delay"]["modifier_parameter"].exists(): parameters[new_comp]["delay::npi_param_name"] = ( - outcomes_config[new_comp]["delay"]["modifier_parameter"].as_str().lower() + outcomes_config[new_comp]["delay"]["modifier_parameter"] + .as_str() + .lower() ) logging.debug( f"delay of outcome {new_comp} is affected by intervention " @@ -195,18 +222,28 @@ def read_parameters_from_config(modinf: model_info.ModelInfo): f"instead of {new_comp}::delay" ) else: - parameters[new_comp]["delay::npi_param_name"] = f"{new_comp}::delay".lower() + parameters[new_comp][ + "delay::npi_param_name" + ] = f"{new_comp}::delay".lower() else: logging.critical(f"No delay for outcome {new_comp}, using a 0 delay") outcomes_config[new_comp]["delay"] = {"value": 0} - parameters[new_comp]["delay"] = outcomes_config[new_comp]["delay"]["value"] - parameters[new_comp]["delay::npi_param_name"] = f"{new_comp}::delay".lower() + parameters[new_comp]["delay"] = outcomes_config[new_comp]["delay"][ + "value" + ] + parameters[new_comp][ + "delay::npi_param_name" + ] = f"{new_comp}::delay".lower() if outcomes_config[new_comp]["duration"].exists(): - parameters[new_comp]["duration"] = outcomes_config[new_comp]["duration"]["value"] + parameters[new_comp]["duration"] = outcomes_config[new_comp][ + "duration" + ]["value"] if outcomes_config[new_comp]["duration"]["modifier_parameter"].exists(): parameters[new_comp]["duration::npi_param_name"] = ( - outcomes_config[new_comp]["duration"]["modifier_parameter"].as_str().lower() + outcomes_config[new_comp]["duration"]["modifier_parameter"] + .as_str() + 
.lower() ) logging.debug( f"duration of outcome {new_comp} is affected by intervention " @@ -214,7 +251,9 @@ def read_parameters_from_config(modinf: model_info.ModelInfo): f"instead of {new_comp}::duration" ) else: - parameters[new_comp]["duration::npi_param_name"] = f"{new_comp}::duration".lower() + parameters[new_comp][ + "duration::npi_param_name" + ] = f"{new_comp}::duration".lower() if outcomes_config[new_comp]["duration"]["name"].exists(): parameters[new_comp]["outcome_prevalence_name"] = ( @@ -231,14 +270,22 @@ def read_parameters_from_config(modinf: model_info.ModelInfo): & (branching_data["quantity"] == "relative_probability") ].copy(deep=True) if len(rel_probability) > 0: - logging.debug(f"Using 'param_from_file' for relative probability in outcome {new_comp}") + logging.debug( + f"Using 'param_from_file' for relative probability in outcome {new_comp}" + ) # Sort it in case the relative probablity file is mispecified - rel_probability.subpop = rel_probability.subpop.astype("category") - rel_probability.subpop = rel_probability.subpop.cat.set_categories( - modinf.subpop_struct.subpop_names + rel_probability.subpop = rel_probability.subpop.astype( + "category" + ) + rel_probability.subpop = ( + rel_probability.subpop.cat.set_categories( + modinf.subpop_struct.subpop_names + ) ) rel_probability = rel_probability.sort_values(["subpop"]) - parameters[new_comp]["rel_probability"] = rel_probability["value"].to_numpy() + parameters[new_comp]["rel_probability"] = rel_probability[ + "value" + ].to_numpy() else: logging.debug( f"*NOT* Using 'param_from_file' for relative probability in outcome {new_comp}" @@ -348,7 +395,9 @@ def compute_all_multioutcomes( outcome_name=new_comp, ) else: - raise ValueError(f"Unknown type for seir simulation provided, got f{type(seir_sim)}") + raise ValueError( + f"Unknown type for seir simulation provided, got f{type(seir_sim)}" + ) # we don't keep source in this cases else: # already defined outcomes if source_name in all_data: @@ -358,16 +407,22 @@ def compute_all_multioutcomes( f"ERROR with outcome {new_comp}: the specified source {source_name} is not a dictionnary (for seir outcome) nor an existing pre-identified outcomes." 
) - if (loaded_values is not None) and (new_comp in loaded_values["outcome"].values): + if (loaded_values is not None) and ( + new_comp in loaded_values["outcome"].values + ): ## This may be unnecessary probabilities = loaded_values[ - (loaded_values["quantity"] == "probability") & (loaded_values["outcome"] == new_comp) + (loaded_values["quantity"] == "probability") + & (loaded_values["outcome"] == new_comp) + ]["value"].to_numpy() + delays = loaded_values[ + (loaded_values["quantity"] == "delay") + & (loaded_values["outcome"] == new_comp) ]["value"].to_numpy() - delays = loaded_values[(loaded_values["quantity"] == "delay") & (loaded_values["outcome"] == new_comp)][ - "value" - ].to_numpy() else: - probabilities = parameters[new_comp]["probability"].as_random_distribution()( + probabilities = parameters[new_comp][ + "probability" + ].as_random_distribution()( size=len(modinf.subpop_struct.subpop_names) ) # one draw per subpop if "rel_probability" in parameters[new_comp]: @@ -378,8 +433,12 @@ def compute_all_multioutcomes( ) # one draw per subpop probabilities[probabilities > 1] = 1 probabilities[probabilities < 0] = 0 - probabilities = np.repeat(probabilities[:, np.newaxis], len(dates), axis=1).T # duplicate in time - delays = np.repeat(delays[:, np.newaxis], len(dates), axis=1).T # duplicate in time + probabilities = np.repeat( + probabilities[:, np.newaxis], len(dates), axis=1 + ).T # duplicate in time + delays = np.repeat( + delays[:, np.newaxis], len(dates), axis=1 + ).T # duplicate in time delays = np.round(delays).astype(int) # Write hpar before NPI subpop_names_len = len(modinf.subpop_struct.subpop_names) @@ -387,7 +446,7 @@ def compute_all_multioutcomes( { "subpop": 2 * modinf.subpop_struct.subpop_names, "quantity": (subpop_names_len * ["probability"]) - + (subpop_names_len * ["delay"]), + + (subpop_names_len * ["delay"]), "outcome": 2 * subpop_names_len * [new_comp], "value": np.concatenate( ( @@ -402,42 +461,59 @@ def compute_all_multioutcomes( if npi is not None: delays = NPI.reduce_parameter( parameter=delays, - modification=npi.getReduction(parameters[new_comp]["delay::npi_param_name"].lower()), + modification=npi.getReduction( + parameters[new_comp]["delay::npi_param_name"].lower() + ), ) delays = np.round(delays).astype(int) probabilities = NPI.reduce_parameter( parameter=probabilities, - modification=npi.getReduction(parameters[new_comp]["probability::npi_param_name"].lower()), + modification=npi.getReduction( + parameters[new_comp]["probability::npi_param_name"].lower() + ), ) # Create new compartment incidence: all_data[new_comp] = np.empty_like(source_array) # Draw with from source compartment if modinf.stoch_traj_flag: - all_data[new_comp] = np.random.binomial(source_array.astype(np.int32), probabilities) + all_data[new_comp] = np.random.binomial( + source_array.astype(np.int32), probabilities + ) else: - all_data[new_comp] = source_array * (probabilities * np.ones_like(source_array)) + all_data[new_comp] = source_array * ( + probabilities * np.ones_like(source_array) + ) # Shift to account for the delay ## stoch_delay_flag is whether to use stochastic delays or not stoch_delay_flag = False - all_data[new_comp] = multishift(all_data[new_comp], delays, stoch_delay_flag=stoch_delay_flag) + all_data[new_comp] = multishift( + all_data[new_comp], delays, stoch_delay_flag=stoch_delay_flag + ) # Produce a dataframe an merge it - df_p = dataframe_from_array(all_data[new_comp], modinf.subpop_struct.subpop_names, dates, new_comp) + df_p = dataframe_from_array( + 
all_data[new_comp], modinf.subpop_struct.subpop_names, dates, new_comp + ) outcomes = pd.merge(outcomes, df_p) # Make duration if "duration" in parameters[new_comp]: - if (loaded_values is not None) and (new_comp in loaded_values["outcome"].values): + if (loaded_values is not None) and ( + new_comp in loaded_values["outcome"].values + ): durations = loaded_values[ - (loaded_values["quantity"] == "duration") & (loaded_values["outcome"] == new_comp) + (loaded_values["quantity"] == "duration") + & (loaded_values["outcome"] == new_comp) ]["value"].to_numpy() else: durations = parameters[new_comp]["duration"].as_random_distribution()( size=len(modinf.subpop_struct.subpop_names) ) # one draw per subpop - durations = np.repeat(durations[:, np.newaxis], len(dates), axis=1).T # duplicate in time + durations = np.repeat( + durations[:, np.newaxis], len(dates), axis=1 + ).T # duplicate in time durations = np.round(durations).astype(int) hpar = pd.DataFrame( data={ @@ -458,7 +534,9 @@ def compute_all_multioutcomes( # print(f"{new_comp}-duration".lower(), npi.getReduction(f"{new_comp}-duration".lower())) durations = NPI.reduce_parameter( parameter=durations, - modification=npi.getReduction(parameters[new_comp]["duration::npi_param_name"].lower()), + modification=npi.getReduction( + parameters[new_comp]["duration::npi_param_name"].lower() + ), ) # npi.getReduction(f"{new_comp}::duration".lower())) durations = np.round(durations).astype(int) # plt.imshow(durations) @@ -492,7 +570,9 @@ def compute_all_multioutcomes( for cmp in parameters[new_comp]["sum"]: sum_outcome += all_data[cmp] all_data[new_comp] = sum_outcome - df_p = dataframe_from_array(sum_outcome, modinf.subpop_struct.subpop_names, dates, new_comp) + df_p = dataframe_from_array( + sum_outcome, modinf.subpop_struct.subpop_names, dates, new_comp + ) outcomes = pd.merge(outcomes, df_p) # Concat our hpar dataframes hpar = ( @@ -554,7 +634,9 @@ def filter_seir_xr(diffI, dates, subpops, filters, outcome_name) -> np.ndarray: if isinstance(mc_value, str): mc_value = [mc_value] # Filter data along the specified mc_type dimension - diffI_filtered = diffI_filtered.where(diffI_filtered[f"mc_{mc_type}"].isin(mc_value), drop=True) + diffI_filtered = diffI_filtered.where( + diffI_filtered[f"mc_{mc_type}"].isin(mc_value), drop=True + ) # Sum along the compartment dimension incidI_arr += diffI_filtered[vtype].sum(dim="compartment") @@ -626,7 +708,9 @@ def multishift(arr, shifts, stoch_delay_flag=True): # for k,case in enumerate(cases): # results[i+k][j] = cases[k] else: - for i in range(arr.shape[0]): # numba nopython does not allow iterating over 2D array + for i in range( + arr.shape[0] + ): # numba nopython does not allow iterating over 2D array for j in range(arr.shape[1]): if i + shifts[i, j] < arr.shape[0]: result[i + shifts[i, j], j] += arr[i, j] diff --git a/flepimop/gempyor_pkg/src/gempyor/parameters.py b/flepimop/gempyor_pkg/src/gempyor/parameters.py index 79292fcd9..689f846b8 100644 --- a/flepimop/gempyor_pkg/src/gempyor/parameters.py +++ b/flepimop/gempyor_pkg/src/gempyor/parameters.py @@ -95,20 +95,14 @@ def __init__( # Parameter characterized by it's distribution if self.pconfig[pn]["value"].exists(): - self.pdata[pn]["dist"] = self.pconfig[pn][ - "value" - ].as_random_distribution() + self.pdata[pn]["dist"] = self.pconfig[pn]["value"].as_random_distribution() # Parameter given as a file elif self.pconfig[pn]["timeseries"].exists(): - fn_name = os.path.join( - path_prefix, self.pconfig[pn]["timeseries"].get() - ) + fn_name = 
os.path.join(path_prefix, self.pconfig[pn]["timeseries"].get()) df = utils.read_df(fn_name).set_index("date") df.index = pd.to_datetime(df.index) - if ( - len(df.columns) == 1 - ): # if only one ts, assume it applies to all subpops + if len(df.columns) == 1: # if only one ts, assume it applies to all subpops df = pd.DataFrame( pd.concat([df] * len(subpop_names), axis=1).values, index=df.index, @@ -165,9 +159,9 @@ def __init__( "rolling_mean_windows" ].get() - self.stacked_modifier_method[ - self.pdata[pn]["stacked_modifier_method"] - ].append(pn.lower()) + self.stacked_modifier_method[self.pdata[pn]["stacked_modifier_method"]].append( + pn.lower() + ) logging.debug(f"We have {self.npar} parameter: {self.pnames}") logging.debug(f"Data to sample is: {self.pdata}") @@ -315,9 +309,7 @@ def getParameterDF(self, p_draw: ndarray) -> pd.DataFrame: if "dist" in self.pdata[pn] ], columns=["value"], - index=[ - pn for idx, pn in enumerate(self.pnames) if "dist" in self.pdata[pn] - ], + index=[pn for idx, pn in enumerate(self.pnames) if "dist" in self.pdata[pn]], ) out_df["parameter"] = out_df.index return out_df diff --git a/flepimop/gempyor_pkg/src/gempyor/postprocess_inference.py b/flepimop/gempyor_pkg/src/gempyor/postprocess_inference.py index 2b42e2944..684fa886c 100644 --- a/flepimop/gempyor_pkg/src/gempyor/postprocess_inference.py +++ b/flepimop/gempyor_pkg/src/gempyor/postprocess_inference.py @@ -21,7 +21,15 @@ import pandas as pd import pyarrow.parquet as pq import xarray as xr -from gempyor import config, model_info, outcomes, seir, inference_parameter, logloss, inference +from gempyor import ( + config, + model_info, + outcomes, + seir, + inference_parameter, + logloss, + inference, +) from gempyor.inference import GempyorInference import tqdm import os @@ -37,7 +45,9 @@ def find_walkers_to_sample(inferpar, sampler_output, nsamples, nwalker, nthin): last_llik = sampler_output.get_log_prob()[-1, :] sampled_slots = last_llik > (last_llik.mean() - 1 * last_llik.std()) - print(f"there are {sampled_slots.sum()}/{len(sampled_slots)} good walkers... keeping these") + print( + f"there are {sampled_slots.sum()}/{len(sampled_slots)} good walkers... 
keeping these" + ) # TODO this function give back good_samples = sampler.get_chain()[:, sampled_slots, :] @@ -50,9 +60,9 @@ def find_walkers_to_sample(inferpar, sampler_output, nsamples, nwalker, nthin): ] # parentesis around i//(sampled_slots.sum() are very important - - -def plot_chains(inferpar, chains, llik, save_to, sampled_slots=None, param_gt=None, llik_gt=None): +def plot_chains( + inferpar, chains, llik, save_to, sampled_slots=None, param_gt=None, llik_gt=None +): """ Plot the chains of the inference :param inferpar: the inference parameter object @@ -113,24 +123,44 @@ def plot_single_chain(frompt, ax, chain, label, gt=None): for sp in tqdm.tqdm(set(inferpar.subpops)): # find unique supopulation these_pars = inferpar.get_parameters_for_subpop(sp) - fig, axes = plt.subplots(max(len(these_pars), 2), 2, figsize=(6, (len(these_pars) + 1) * 2)) + fig, axes = plt.subplots( + max(len(these_pars), 2), 2, figsize=(6, (len(these_pars) + 1) * 2) + ) for idx, par_id in enumerate(these_pars): - plot_single_chain(first_thresh, axes[idx, 0], chains[:, :, par_id], labels[par_id], gt=param_gt[par_id] if param_gt is not None else None) - plot_single_chain(second_thresh, axes[idx, 1], chains[:, :, par_id], labels[par_id], gt=param_gt[par_id] if param_gt is not None else None) + plot_single_chain( + first_thresh, + axes[idx, 0], + chains[:, :, par_id], + labels[par_id], + gt=param_gt[par_id] if param_gt is not None else None, + ) + plot_single_chain( + second_thresh, + axes[idx, 1], + chains[:, :, par_id], + labels[par_id], + gt=param_gt[par_id] if param_gt is not None else None, + ) fig.tight_layout() pdf.savefig(fig) plt.close(fig) + def plot_fit(modinf, loss): subpop_names = modinf.subpop_struct.subpop_names fig, axes = plt.subplots( - len(subpop_names), len(loss.statistics), figsize=(3 * len(loss.statistics), 3 * len(subpop_names)), sharex=True + len(subpop_names), + len(loss.statistics), + figsize=(3 * len(loss.statistics), 3 * len(subpop_names)), + sharex=True, ) for j, subpop in enumerate(modinf.subpop_struct.subpop_names): gt_s = loss.gt[loss.gt["subpop"] == subpop].sort_index() first_date = max(gt_s.index.min(), results[0].index.min()) last_date = min(gt_s.index.max(), results[0].index.max()) - gt_s = gt_s.loc[first_date:last_date].drop(["subpop"], axis=1).resample("W-SAT").sum() + gt_s = ( + gt_s.loc[first_date:last_date].drop(["subpop"], axis=1).resample("W-SAT").sum() + ) for i, (stat_name, stat) in enumerate(loss.statistics.items()): ax = axes[j, i] diff --git a/flepimop/gempyor_pkg/src/gempyor/seeding.py b/flepimop/gempyor_pkg/src/gempyor/seeding.py index fe58657c0..53b81c8dc 100644 --- a/flepimop/gempyor_pkg/src/gempyor/seeding.py +++ b/flepimop/gempyor_pkg/src/gempyor/seeding.py @@ -17,9 +17,13 @@ def _DataFrame2NumbaDict(df, amounts, modinf) -> nb.typed.Dict: if not df["date"].is_monotonic_increasing: - raise ValueError("_DataFrame2NumbaDict got an unsorted dataframe, exposing itself to non-sense") + raise ValueError( + "_DataFrame2NumbaDict got an unsorted dataframe, exposing itself to non-sense" + ) - cmp_grp_names = [col for col in modinf.compartments.compartments.columns if col != "name"] + cmp_grp_names = [ + col for col in modinf.compartments.compartments.columns if col != "name" + ] seeding_dict: nb.typed.Dict = nb.typed.Dict.empty( key_type=nb.types.unicode_type, value_type=nb.types.int64[:], @@ -45,16 +49,25 @@ def _DataFrame2NumbaDict(df, amounts, modinf) -> nb.typed.Dict: nb_seed_perday[(row["date"].date() - modinf.ti).days] = ( nb_seed_perday[(row["date"].date() - 
modinf.ti).days] + 1 ) - source_dict = {grp_name: row[f"source_{grp_name}"] for grp_name in cmp_grp_names} - destination_dict = {grp_name: row[f"destination_{grp_name}"] for grp_name in cmp_grp_names} + source_dict = { + grp_name: row[f"source_{grp_name}"] for grp_name in cmp_grp_names + } + destination_dict = { + grp_name: row[f"destination_{grp_name}"] for grp_name in cmp_grp_names + } seeding_dict["seeding_sources"][idx] = modinf.compartments.get_comp_idx( - source_dict, error_info=f"(seeding source at idx={idx}, row_index={row_index}, row=>>{row}<<)" + source_dict, + error_info=f"(seeding source at idx={idx}, row_index={row_index}, row=>>{row}<<)", + ) + seeding_dict["seeding_destinations"][idx] = ( + modinf.compartments.get_comp_idx( + destination_dict, + error_info=f"(seeding destination at idx={idx}, row_index={row_index}, row=>>{row}<<)", + ) ) - seeding_dict["seeding_destinations"][idx] = modinf.compartments.get_comp_idx( - destination_dict, - error_info=f"(seeding destination at idx={idx}, row_index={row_index}, row=>>{row}<<)", + seeding_dict["seeding_subpops"][idx] = ( + modinf.subpop_struct.subpop_names.index(row["subpop"]) ) - seeding_dict["seeding_subpops"][idx] = modinf.subpop_struct.subpop_names.index(row["subpop"]) seeding_amounts[idx] = amounts[idx] # id_seed+=1 else: @@ -97,7 +110,9 @@ def get_from_config(self, sim_id: int, modinf) -> nb.typed.Dict: ) dupes = seeding[seeding.duplicated(["subpop", "date"])].index + 1 if not dupes.empty: - raise ValueError(f"Repeated subpop-date in rows {dupes.tolist()} of seeding::lambda_file.") + raise ValueError( + f"Repeated subpop-date in rows {dupes.tolist()} of seeding::lambda_file." + ) elif method == "FolderDraw": seeding = pd.read_csv( self.path_prefix @@ -127,7 +142,9 @@ def get_from_config(self, sim_id: int, modinf) -> nb.typed.Dict: # print(seeding.shape) seeding = seeding.sort_values(by="date", axis="index").reset_index() # print(seeding) - mask = (seeding["date"].dt.date > modinf.ti) & (seeding["date"].dt.date <= modinf.tf) + mask = (seeding["date"].dt.date > modinf.ti) & ( + seeding["date"].dt.date <= modinf.tf + ) seeding = seeding.loc[mask].reset_index() # print(seeding.shape) # print(seeding) @@ -138,7 +155,9 @@ def get_from_config(self, sim_id: int, modinf) -> nb.typed.Dict: if method == "PoissonDistributed": amounts = np.random.poisson(seeding["amount"]) elif method == "NegativeBinomialDistributed": - raise ValueError("Seeding method 'NegativeBinomialDistributed' is not supported by flepiMoP anymore.") + raise ValueError( + "Seeding method 'NegativeBinomialDistributed' is not supported by flepiMoP anymore." + ) elif method == "FolderDraw" or method == "FromFile": amounts = seeding["amount"] else: diff --git a/flepimop/gempyor_pkg/src/gempyor/seir.py b/flepimop/gempyor_pkg/src/gempyor/seir.py index 4e59761f2..5ea236c98 100644 --- a/flepimop/gempyor_pkg/src/gempyor/seir.py +++ b/flepimop/gempyor_pkg/src/gempyor/seir.py @@ -41,7 +41,9 @@ def build_step_source_arg( else: integration_method = "rk4.jit" dt = 2.0 - logging.info(f"Integration method not provided, assuming type {integration_method} with dt=2") + logging.info( + f"Integration method not provided, assuming type {integration_method} with dt=2" + ) ## The type is very important for the call to the compiled function, and e.g mixing an int64 for an int32 can ## result in serious error. 
Note that "In Microsoft C, even on a 64 bit system, the size of the long int data type @@ -58,7 +60,10 @@ def build_step_source_arg( assert type(transition_array[0][0]) == np.int64 assert type(proportion_array[0]) == np.int64 assert type(proportion_info[0][0]) == np.int64 - assert initial_conditions.shape == (modinf.compartments.compartments.shape[0], modinf.nsubpops) + assert initial_conditions.shape == ( + modinf.compartments.compartments.shape[0], + modinf.nsubpops, + ) assert type(initial_conditions[0][0]) == np.float64 # Test of empty seeding: assert len(seeding_data.keys()) == 4 @@ -162,7 +167,9 @@ def steps_SEIR( f"with method {integration_method}, only deterministic " f"integration is possible (got stoch_straj_flag={modinf.stoch_traj_flag}" ) - seir_sim = steps_experimental.ode_integration(**fnct_args, integration_method=integration_method) + seir_sim = steps_experimental.ode_integration( + **fnct_args, integration_method=integration_method + ) elif integration_method == "rk4.jit1": seir_sim = steps_experimental.rk4_integration1(**fnct_args) elif integration_method == "rk4.jit2": @@ -200,7 +207,9 @@ def steps_SEIR( **compartment_coords, subpop=modinf.subpop_struct.subpop_names, ), - attrs=dict(description="Dynamical simulation results", run_id=modinf.in_run_id), # TODO add more information + attrs=dict( + description="Dynamical simulation results", run_id=modinf.in_run_id + ), # TODO add more information ) return states @@ -223,8 +232,12 @@ def build_npi_SEIR(modinf, load_ID, sim_id2load, config, bypass_DF=None, bypass_ modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, loaded_df=loaded_df, - pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method[ + "sum" + ], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + "reduction_product" + ], ) else: npi = NPI.NPIBase.execute( @@ -232,8 +245,12 @@ def build_npi_SEIR(modinf, load_ID, sim_id2load, config, bypass_DF=None, bypass_ modinf=modinf, modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, - pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method[ + "sum" + ], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + "reduction_product" + ], ) return npi @@ -248,7 +265,9 @@ def onerun_SEIR( np.random.seed() npi = None if modinf.npi_config_seir: - npi = build_npi_SEIR(modinf=modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config) + npi = build_npi_SEIR( + modinf=modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config + ) with Timer("onerun_SEIR.compartments"): ( @@ -260,11 +279,19 @@ def onerun_SEIR( with Timer("onerun_SEIR.seeding"): if load_ID: - initial_conditions = modinf.initial_conditions.get_from_file(sim_id2load, modinf=modinf) - seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id2load, modinf=modinf) + initial_conditions = modinf.initial_conditions.get_from_file( + sim_id2load, modinf=modinf + ) + seeding_data, seeding_amounts = modinf.seeding.get_from_file( + sim_id2load, modinf=modinf + ) else: - initial_conditions = 
modinf.initial_conditions.get_from_config(sim_id2write, modinf=modinf) - seeding_data, seeding_amounts = modinf.seeding.get_from_config(sim_id2write, modinf=modinf) + initial_conditions = modinf.initial_conditions.get_from_config( + sim_id2write, modinf=modinf + ) + seeding_data, seeding_amounts = modinf.seeding.get_from_config( + sim_id2write, modinf=modinf + ) with Timer("onerun_SEIR.parameters"): # Draw or load parameters @@ -275,14 +302,18 @@ def onerun_SEIR( nsubpops=modinf.nsubpops, ) else: - p_draw = modinf.parameters.parameters_quick_draw(n_days=modinf.n_days, nsubpops=modinf.nsubpops) + p_draw = modinf.parameters.parameters_quick_draw( + n_days=modinf.n_days, nsubpops=modinf.nsubpops + ) # reduce them parameters = modinf.parameters.parameters_reduce(p_draw, npi) log_debug_parameters(p_draw, "Parameters without seir_modifiers") log_debug_parameters(parameters, "Parameters with seir_modifiers") # Parse them - parsed_parameters = modinf.compartments.parse_parameters(parameters, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + parameters, modinf.parameters.pnames, unique_strings + ) log_debug_parameters(parsed_parameters, "Unique Parameters used by transitions") with Timer("onerun_SEIR.compute"): @@ -310,7 +341,13 @@ def run_parallel_SEIR(modinf, config, *, n_jobs=1): if n_jobs == 1: # run single process for debugging/profiling purposes for sim_id in tqdm.tqdm(sim_ids): - onerun_SEIR(sim_id2write=sim_id, modinf=modinf, load_ID=False, sim_id2load=None, config=config) + onerun_SEIR( + sim_id2write=sim_id, + modinf=modinf, + load_ID=False, + sim_id2load=None, + config=config, + ) else: tqdm.contrib.concurrent.process_map( onerun_SEIR, @@ -322,7 +359,9 @@ def run_parallel_SEIR(modinf, config, *, n_jobs=1): max_workers=n_jobs, ) - logging.info(f""">> {modinf.nslots} seir simulations completed in {time.monotonic() - start:.1f} seconds""") + logging.info( + f""">> {modinf.nslots} seir simulations completed in {time.monotonic() - start:.1f} seconds""" + ) def states2Df(modinf, states): @@ -337,12 +376,17 @@ def states2Df(modinf, states): # states_diff = np.diff(states_diff, axis=0) ts_index = pd.MultiIndex.from_product( - [pd.date_range(modinf.ti, modinf.tf, freq="D"), modinf.compartments.compartments["name"]], + [ + pd.date_range(modinf.ti, modinf.tf, freq="D"), + modinf.compartments.compartments["name"], + ], names=["date", "mc_name"], ) # prevalence data, we use multi.index dataframe, sparring us the array manipulation we use to do prev_df = pd.DataFrame( - data=states["prevalence"].to_numpy().reshape(modinf.n_days * modinf.compartments.get_ncomp(), modinf.nsubpops), + data=states["prevalence"] + .to_numpy() + .reshape(modinf.n_days * modinf.compartments.get_ncomp(), modinf.nsubpops), index=ts_index, columns=modinf.subpop_struct.subpop_names, ).reset_index() @@ -355,12 +399,17 @@ def states2Df(modinf, states): prev_df.insert(loc=0, column="mc_value_type", value="prevalence") ts_index = pd.MultiIndex.from_product( - [pd.date_range(modinf.ti, modinf.tf, freq="D"), modinf.compartments.compartments["name"]], + [ + pd.date_range(modinf.ti, modinf.tf, freq="D"), + modinf.compartments.compartments["name"], + ], names=["date", "mc_name"], ) incid_df = pd.DataFrame( - data=states["incidence"].to_numpy().reshape(modinf.n_days * modinf.compartments.get_ncomp(), modinf.nsubpops), + data=states["incidence"] + .to_numpy() + .reshape(modinf.n_days * modinf.compartments.get_ncomp(), modinf.nsubpops), index=ts_index, 
columns=modinf.subpop_struct.subpop_names, ).reset_index() @@ -384,7 +433,9 @@ def write_spar_snpi(sim_id, modinf, p_draw, npi): if npi is not None: modinf.write_simID(ftype="snpi", sim_id=sim_id, df=npi.getReductionDF()) # Parameters - modinf.write_simID(ftype="spar", sim_id=sim_id, df=modinf.parameters.getParameterDF(p_draw=p_draw)) + modinf.write_simID( + ftype="spar", sim_id=sim_id, df=modinf.parameters.getParameterDF(p_draw=p_draw) + ) def write_seir(sim_id, modinf, states): diff --git a/flepimop/gempyor_pkg/src/gempyor/simulate.py b/flepimop/gempyor_pkg/src/gempyor/simulate.py index 34fdd9d4b..d97b94a93 100644 --- a/flepimop/gempyor_pkg/src/gempyor/simulate.py +++ b/flepimop/gempyor_pkg/src/gempyor/simulate.py @@ -299,23 +299,31 @@ def simulate( seir_modifiers_scenarios = None if config["seir_modifiers"].exists(): if config["seir_modifiers"]["scenarios"].exists(): - seir_modifiers_scenarios = config["seir_modifiers"]["scenarios"].as_str_seq() + seir_modifiers_scenarios = config["seir_modifiers"][ + "scenarios" + ].as_str_seq() # Model Info handles the case of the default scneario if not outcome_modifiers_scenarios: outcome_modifiers_scenarios = None if config["outcomes"].exists() and config["outcome_modifiers"].exists(): if config["outcome_modifiers"]["scenarios"].exists(): - outcome_modifiers_scenarios = config["outcome_modifiers"]["scenarios"].as_str_seq() + outcome_modifiers_scenarios = config["outcome_modifiers"][ + "scenarios" + ].as_str_seq() outcome_modifiers_scenarios = as_list(outcome_modifiers_scenarios) seir_modifiers_scenarios = as_list(seir_modifiers_scenarios) print(outcome_modifiers_scenarios, seir_modifiers_scenarios) - scenarios_combinations = [[s, d] for s in seir_modifiers_scenarios for d in outcome_modifiers_scenarios] + scenarios_combinations = [ + [s, d] for s in seir_modifiers_scenarios for d in outcome_modifiers_scenarios + ] print("Combination of modifiers scenarios to be run: ") print(scenarios_combinations) for seir_modifiers_scenario, outcome_modifiers_scenario in scenarios_combinations: - print(f"seir_modifier: {seir_modifiers_scenario}, outcomes_modifier:{outcome_modifiers_scenario}") + print( + f"seir_modifier: {seir_modifiers_scenario}, outcomes_modifier:{outcome_modifiers_scenario}" + ) if not nslots: nslots = config["nslots"].as_number() @@ -354,7 +362,9 @@ def simulate( if config["seir"].exists(): seir.run_parallel_SEIR(modinf, config=config, n_jobs=jobs) if config["outcomes"].exists(): - outcomes.run_parallel_outcomes(sim_id2write=first_sim_index, modinf=modinf, nslots=nslots, n_jobs=jobs) + outcomes.run_parallel_outcomes( + sim_id2write=first_sim_index, modinf=modinf, nslots=nslots, n_jobs=jobs + ) print( f">>> {seir_modifiers_scenario}_{outcome_modifiers_scenario} completed in {time.monotonic() - start:.1f} seconds" ) diff --git a/flepimop/gempyor_pkg/src/gempyor/steps_source.py b/flepimop/gempyor_pkg/src/gempyor/steps_source.py index b8af1d493..ea6f7ad21 100644 --- a/flepimop/gempyor_pkg/src/gempyor/steps_source.py +++ b/flepimop/gempyor_pkg/src/gempyor/steps_source.py @@ -30,9 +30,12 @@ ## Return "UniTuple(float64[:, :, :], 2) (" ## return states and cumlative states, both [ ndays x ncompartments x nspatial_nodes ] ## Dimensions - "int32," "int32," "int32," ## ncompartments ## nspatial_nodes ## Number of days + "int32," + "int32," + "int32," ## ncompartments ## nspatial_nodes ## Number of days ## Parameters - "float64[:, :, :]," "float64," ## Parameters [ nparameters x ndays x nspatial_nodes] ## dt + "float64[:, :, :]," + "float64," ## 
Parameters [ nparameters x ndays x nspatial_nodes] ## dt ## Transitions "int64[:, :]," ## transitions [ [source, destination, proportion_start, proportion_stop, rate] x ntransitions ] "int64[:, :]," ## proportions_info [ [sum_starts, sum_stops, exponent] x ntransition_proportions ] @@ -84,7 +87,11 @@ def steps_SEIR_nb( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -111,18 +118,22 @@ def steps_SEIR_nb( this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts total_seeded += this_seeding_amounts times_seeded += 1 # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts total_infected = 0 for transition_index in range(ntransitions): @@ -138,9 +149,13 @@ def steps_SEIR_nb( proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + proportion_info[proportion_exponent_col][proportion_index] + ][today] if first_proportion: only_one_proportion = ( transitions[transition_proportion_start_col][transition_index] + 1 @@ -149,41 +164,56 @@ def steps_SEIR_nb( source_number = relevant_number_in_comp if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * 
percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = proportion_change_compartment rate_change_compartment *= ( - relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] + relevant_number_in_comp[visiting_compartment] + ** relevant_exponent[visiting_compartment] ) rate_change_compartment /= population[visiting_compartment] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_compartment] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -221,14 +251,20 @@ def steps_SEIR_nb( for spatial_node in range(nspatial_nodes): if ( number_move[spatial_node] - > states_next[transitions[transition_source_col][transition_index]][spatial_node] - ): - number_move[spatial_node] = states_next[transitions[transition_source_col][transition_index]][ + > states_next[transitions[transition_source_col][transition_index]][ spatial_node ] + ): + number_move[spatial_node] = states_next[ + transitions[transition_source_col][transition_index] + ][spatial_node] states_next[transitions[transition_source_col][transition_index]] -= number_move - states_next[transitions[transition_destination_col][transition_index]] += number_move - states_daily_incid[today, transitions[transition_destination_col][transition_index], :] += number_move + states_next[ + transitions[transition_destination_col][transition_index] + ] += number_move + states_daily_incid[ + today, transitions[transition_destination_col][transition_index], : + ] += number_move states_current = states_next.copy() diff --git a/flepimop/gempyor_pkg/src/gempyor/subpopulation_structure.py b/flepimop/gempyor_pkg/src/gempyor/subpopulation_structure.py index 1e2b9b8de..7c46a9960 100644 --- a/flepimop/gempyor_pkg/src/gempyor/subpopulation_structure.py +++ b/flepimop/gempyor_pkg/src/gempyor/subpopulation_structure.py @@ -27,7 +27,9 @@ def __init__(self, *, setup_name, subpop_config, path_prefix=pathlib.Path(".")): self.setup_name = setup_name self.data = pd.read_csv( - geodata_file, converters={subpop_names_key: lambda x: str(x).strip()}, skipinitialspace=True + geodata_file, + converters={subpop_names_key: lambda x: str(x).strip()}, + skipinitialspace=True, ) # subpops and populations, strip whitespaces self.nsubpops = len(self.data) # K = # of locations @@ -44,7 +46,9 @@ def 
__init__(self, *, setup_name, subpop_config, path_prefix=pathlib.Path(".")): # subpop_names_key is the name of the column in geodata_file with subpops if subpop_names_key not in self.data: - raise ValueError(f"subpop_names_key: {subpop_names_key} does not correspond to a column in geodata.") + raise ValueError( + f"subpop_names_key: {subpop_names_key} does not correspond to a column in geodata." + ) self.subpop_names = self.data[subpop_names_key].tolist() if len(self.subpop_names) != len(set(self.subpop_names)): raise ValueError(f"There are duplicate subpop_names in geodata.") @@ -53,7 +57,9 @@ def __init__(self, *, setup_name, subpop_config, path_prefix=pathlib.Path(".")): mobility_file = path_prefix / subpop_config["mobility"].get() mobility_file = pathlib.Path(mobility_file) if mobility_file.suffix == ".txt": - print("Mobility files as matrices are not recommended. Please switch soon to long form csv files.") + print( + "Mobility files as matrices are not recommended. Please switch soon to long form csv files." + ) self.mobility = scipy.sparse.csr_matrix( np.loadtxt(mobility_file), dtype=int ) # K x K matrix of people moving @@ -64,7 +70,11 @@ def __init__(self, *, setup_name, subpop_config, path_prefix=pathlib.Path(".")): ) elif mobility_file.suffix == ".csv": - mobility_data = pd.read_csv(mobility_file, converters={"ori": str, "dest": str}, skipinitialspace=True) + mobility_data = pd.read_csv( + mobility_file, + converters={"ori": str, "dest": str}, + skipinitialspace=True, + ) nn_dict = {v: k for k, v in enumerate(self.subpop_names)} mobility_data["ori_idx"] = mobility_data["ori"].apply(nn_dict.__getitem__) mobility_data["dest_idx"] = mobility_data["dest"].apply(nn_dict.__getitem__) @@ -115,7 +125,9 @@ def __init__(self, *, setup_name, subpop_config, path_prefix=pathlib.Path(".")): ) else: logging.critical("No mobility matrix specified -- assuming no one moves") - self.mobility = scipy.sparse.csr_matrix(np.zeros((self.nsubpops, self.nsubpops)), dtype=int) + self.mobility = scipy.sparse.csr_matrix( + np.zeros((self.nsubpops, self.nsubpops)), dtype=int + ) if subpop_config["selected"].exists(): selected = subpop_config["selected"].get() @@ -129,4 +141,6 @@ def __init__(self, *, setup_name, subpop_config, path_prefix=pathlib.Path(".")): self.subpop_names = selected self.nsubpops = len(self.data) # TODO: this needs to be tested - self.mobility = self.mobility[selected_subpop_indices][:, selected_subpop_indices] + self.mobility = self.mobility[selected_subpop_indices][ + :, selected_subpop_indices + ] diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index 25c33d574..377dcdee5 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -99,7 +99,9 @@ def read_df( ) -def command_safe_run(command: str, command_name: str="mycommand", fail_on_fail: bool=True) -> tuple[int, str, str]: +def command_safe_run( + command: str, command_name: str = "mycommand", fail_on_fail: bool = True +) -> tuple[int, str, str]: """ Runs a shell command and prints diagnostics if command fails. @@ -110,14 +112,16 @@ def command_safe_run(command: str, command_name: str="mycommand", fail_on_fail: Returns: As a tuple; the return code, the standard output, and standard error from running the command. - + Raises: Exception: If fail_on_fail=True and the command fails, an exception will be thrown. 
""" import subprocess import shlex # using shlex to split the command because it's not obvious https://docs.python.org/3/library/subprocess.html#subprocess.Popen - sr = subprocess.Popen(shlex.split(command), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + sr = subprocess.Popen( + shlex.split(command), stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) (stdout, stderr) = sr.communicate() if sr.returncode != 0: print(f"{command_name} failed failed with returncode {sr.returncode}") @@ -143,7 +147,7 @@ def add_method(cls): Args: cls: The class you want to add a method to. - + Returns: decorator: The decorator. """ @@ -159,7 +163,9 @@ def wrapper(*args, **kwargs): return decorator -def search_and_import_plugins_class(plugin_file_path: str, path_prefix: str, class_name: str, **kwargs: dict[str, Any]) -> Any: +def search_and_import_plugins_class( + plugin_file_path: str, path_prefix: str, class_name: str, **kwargs: dict[str, Any] +) -> Any: """ Function serving to create a class that finds and imports the necessary modules. @@ -174,11 +180,11 @@ def search_and_import_plugins_class(plugin_file_path: str, path_prefix: str, cla Examples: Suppose there is a module called `my_plugin.py with a class `MyClass` located at `/path/to/plugin/`. - + Dynamically import and instantiate the class: >>> instance = search_and_import_plugins_class('/path/to/plugin', path_prefix, 'MyClass', **params) - + View the instance: >>> print(instance) @@ -206,28 +212,33 @@ def search_and_import_plugins_class(plugin_file_path: str, path_prefix: str, cla from functools import wraps -def profile(output_file: str = None, sort_by: str = "cumulative", lines_to_print: int = None, strip_dirs: bool = False): +def profile( + output_file: str = None, + sort_by: str = "cumulative", + lines_to_print: int = None, + strip_dirs: bool = False, +): """ A time profiler decorator. Inspired by and modified the profile decorator of Giampaolo Rodola: http://code.activestate.com/recipes/577817-profile-decorator/ Args: - output_file: + output_file: Path of the output file. If only name of the file is given, it's saved in the current directory. If it's None, the name of the decorated function is used. - sort_by: + sort_by: Sorting criteria for the Stats object. For a list of valid string and SortKey refer to: https://docs.python.org/3/library/profile.html#pstats.Stats.sort_stats - lines_to_print: + lines_to_print: Number of lines to print. Default (None) is for all the lines. This is useful in reducing the size of the printout, especially that sorting by 'cumulative', the time consuming operations are printed toward the top of the file. - strip_dirs: - Whether to remove the leading path info from file names. + strip_dirs: + Whether to remove the leading path info from file names. Returns: Profile of the decorated function. @@ -238,7 +249,7 @@ def profile(output_file: str = None, sort_by: str = "cumulative", lines_to_print # Function body content pass >>> my_function() - After running ``my_function``, a file named ``my_function.prof`` will be created in the current WD. + After running ``my_function``, a file named ``my_function.prof`` will be created in the current WD. This file contains the profiling data. """ @@ -282,6 +293,7 @@ class Timer(object): name: Name of event. tstart: Time start. """ + def __init__(self, name): self.name = name @@ -297,6 +309,7 @@ class ISO8601Date(confuse.Template): """ Reads in config dates into datetimes.dates. 
""" + def convert(self, value: any, view: confuse.ConfigView): """ Converts the given value to a datetime.date object. @@ -331,7 +344,7 @@ def as_date(self) -> datetime.date: @add_method(confuse.ConfigView) def as_evaled_expression(self): """ - Evaluates an expression string, returning a float. + Evaluates an expression string, returning a float. Returns: A float data type of the value associated with the object. @@ -373,7 +386,7 @@ def get_truncated_normal( Returns: rv_frozen: A frozen instance of the truncated normal distribution with the specified parameters. - + Examples: Create a truncated normal distribution with specified parameters (truncated between 1 and 10): >>> truncated_normal_dist = get_truncated_normal(mean=5, sd=2, a=1, b=10) @@ -402,7 +415,7 @@ def get_log_normal( Returns: rv_frozen: A frozen instance of the log normal distribution with the specified parameters. - + Examples: Create a log-normal distribution with specified parameters: >>> log_normal_dist = get_log_normal(meanlog=1, sdlog=0.5) @@ -415,17 +428,17 @@ def get_log_normal( def random_distribution_sampler( distribution: Literal[ "fixed", "uniform", "poisson", "binomial", "truncnorm", "lognorm" - ], - **kwargs: dict[str, Any] + ], + **kwargs: dict[str, Any], ) -> Callable[[], float | int]: """ Create function to sample from a random distribution. - + Args: distribution: The type of distribution to generate a sampling function for. **kwargs: Further parameters that are passed to the underlying function for the given distribution. - + Notes: The further args expected by each distribution type are: - fixed: value, @@ -434,14 +447,14 @@ def random_distribution_sampler( - binomial: n, p, - truncnorm: mean, sd, a, b, - lognorm: meanlog, sdlog. - + Returns: A function that can be called to sample from that distribution. - + Raises: ValueError: If `distribution` is 'binomial' the given `p` must be in (0,1). NotImplementedError: If `distribution` is not one of the type hinted options. - + Examples: >>> import numpy as np >>> np.random.seed(123) @@ -454,15 +467,15 @@ def random_distribution_sampler( if distribution == "fixed": # Fixed value is the same as uniform on [a, a) return functools.partial( - np.random.uniform, - kwargs.get("value"), + np.random.uniform, + kwargs.get("value"), kwargs.get("value"), ) elif distribution == "uniform": # Uniform on [low, high) return functools.partial( - np.random.uniform, - kwargs.get("low"), + np.random.uniform, + kwargs.get("low"), kwargs.get("high"), ) elif distribution == "poisson": @@ -476,9 +489,9 @@ def random_distribution_sampler( elif distribution == "truncnorm": # Truncated normal with mean, sd on interval [a, b] return get_truncated_normal( - mean=kwargs.get("mean"), - sd=kwargs.get("sd"), - a=kwargs.get("a"), + mean=kwargs.get("mean"), + sd=kwargs.get("sd"), + a=kwargs.get("a"), b=kwargs.get("b"), ).rvs elif distribution == "lognorm": @@ -497,8 +510,8 @@ def as_random_distribution(self): Returns: A partial object containing the random distribution. - - Raises: + + Raises: ValueError: When values are out of range. NotImplementedError: If an unknown distribution is found. 
@@ -515,7 +528,7 @@ def as_random_distribution(self): "distribution": "truncnorm", "mean": 0 "sd": 1, - "a": -1, + "a": -1, "b": 1 }) >>> truncnorm_dist_function = config_truncnorm.as_random_distribution() @@ -528,11 +541,15 @@ def as_random_distribution(self): dist = self["distribution"].get() if dist == "fixed": return functools.partial( - np.random.uniform, self["value"].as_evaled_expression(), self["value"].as_evaled_expression(), + np.random.uniform, + self["value"].as_evaled_expression(), + self["value"].as_evaled_expression(), ) elif dist == "uniform": return functools.partial( - np.random.uniform, self["low"].as_evaled_expression(), self["high"].as_evaled_expression(), + np.random.uniform, + self["low"].as_evaled_expression(), + self["high"].as_evaled_expression(), ) elif dist == "poisson": return functools.partial(np.random.poisson, self["lam"].as_evaled_expression()) @@ -557,13 +574,18 @@ def as_random_distribution(self): ).rvs elif dist == "lognorm": return get_log_normal( - meanlog=self["meanlog"].as_evaled_expression(), sdlog=self["sdlog"].as_evaled_expression(), + meanlog=self["meanlog"].as_evaled_expression(), + sdlog=self["sdlog"].as_evaled_expression(), ).rvs else: raise NotImplementedError(f"unknown distribution [got: {dist}]") else: # we allow a fixed value specified directly: - return functools.partial(np.random.uniform, self.as_evaled_expression(), self.as_evaled_expression(),) + return functools.partial( + np.random.uniform, + self.as_evaled_expression(), + self.as_evaled_expression(), + ) def list_filenames( @@ -578,9 +600,9 @@ def list_filenames( in the filters will be returned. Args: - folder: + folder: The directory to search for files. Defaults to the current directory. - filters: + filters: A string or a list of strings to filter filenames. Only files containing all the provided substrings will be returned. Defaults to an empty list. @@ -652,14 +674,14 @@ def rolling_mean_pad( [22.6, 23.6, 24.6, 25.6]]) ``` """ - weights = (1. / window) * np.ones(window) + weights = (1.0 / window) * np.ones(window) output = scipy.ndimage.convolve1d(data, weights, axis=0, mode="nearest") if window % 2 == 0: rows, cols = data.shape i = rows - 1 - output[i, :] = 0. + output[i, :] = 0.0 window -= 1 - weight = 1. / window + weight = 1.0 / window for l in range(-((window - 1) // 2), 1 + (window // 2)): i_star = min(max(i + l, 0), i) for j in range(cols): @@ -705,7 +727,12 @@ def bash(command: str) -> str: def create_resume_out_filename( - flepi_run_index: str, flepi_prefix: str, flepi_slot_index: str, flepi_block_index: str, filetype: str, liketype: str + flepi_run_index: str, + flepi_prefix: str, + flepi_slot_index: str, + flepi_block_index: str, + filetype: str, + liketype: str, ) -> str: """ Compiles run output information. @@ -717,10 +744,10 @@ def create_resume_out_filename( flepi_block_index: Index of the block. filetype: File type. liketype: Chimeric or global. - + Returns: The path to a corresponding output file. - + Examples: Generate an output file with specified parameters: >>> filename = create_resume_out_filename( @@ -753,7 +780,11 @@ def create_resume_out_filename( def create_resume_input_filename( - resume_run_index: str, flepi_prefix: str, flepi_slot_index: str, filetype: str, liketype: str + resume_run_index: str, + flepi_prefix: str, + flepi_slot_index: str, + filetype: str, + liketype: str, ) -> str: """ Compiles run input information. @@ -764,10 +795,10 @@ def create_resume_input_filename( flepi_slot_index: Index of the slot. filetype: File type. 
liketype: Chimeric or global. - + Returns: The path to the a corresponding input file. - + Examples: Generate an input file with specified parameters: >>> filename = create_resume_input_filename( @@ -796,10 +827,12 @@ def create_resume_input_filename( ) -def get_filetype_for_resume(resume_discard_seeding: str, flepi_block_index: str) -> list[str]: +def get_filetype_for_resume( + resume_discard_seeding: str, flepi_block_index: str +) -> list[str]: """ Retrieves a list of parquet file types that are relevant for resuming a process based on - specific environment variable settings. + specific environment variable settings. This function dynamically determines the list based on the current operational context given by the environment. @@ -809,7 +842,7 @@ def get_filetype_for_resume(resume_discard_seeding: str, flepi_block_index: str) Returns: List of file types. - + Examples: Determine file types for block index 1 with seeding data NOT discarded: >>> filetypes = get_filetype_for_resume(resume_discard_seeding="false", flepi_block_index="1") @@ -855,7 +888,7 @@ def create_resume_file_names_map( Returns: A dictionary where keys are input file paths and values are corresponding - output file paths. + output file paths. The mappings depend on: - Parquet file types appropriate for resuming a process, as determined by the environment. @@ -870,7 +903,7 @@ def create_resume_file_names_map( No explicit exceptions are raised within the function, but it relies heavily on external functions and environment variables which if improperly configured could lead to unexpected behavior. - + Examples: Generate a mapping of file names for a given resume process: >>> file_names_map = create_resume_file_names_map( @@ -935,7 +968,7 @@ def download_file_from_s3(name_map: dict[str, str]) -> None: local path, and handles errors if the S3 URI format is incorrect or if the download fails. Args: - name_map: + name_map: A dictionary where keys are S3 URIs (strings) and values are the local file paths (strings) where the files should be saved. @@ -969,10 +1002,12 @@ def download_file_from_s3(name_map: dict[str, str]) -> None: import boto3 from botocore.exceptions import ClientError except ModuleNotFoundError: - raise ModuleNotFoundError(( - "No module named 'boto3', which is required for " - "gempyor.utils.download_file_from_s3. Please install the aws target." - )) + raise ModuleNotFoundError( + ( + "No module named 'boto3', which is required for " + "gempyor.utils.download_file_from_s3. Please install the aws target." + ) + ) s3 = boto3.client("s3") first_output_filename = next(iter(name_map.values())) output_dir = os.path.dirname(first_output_filename) @@ -994,9 +1029,9 @@ def download_file_from_s3(name_map: dict[str, str]) -> None: def move_file_at_local(name_map: dict[str, str]) -> None: """ Moves files locally according to a given mapping. - This function takes a dictionary where the keys are source file paths and - the values are destination file paths. It ensures that the destination - directories exist and then copies the files from the source paths to the + This function takes a dictionary where the keys are source file paths and + the values are destination file paths. It ensures that the destination + directories exist and then copies the files from the source paths to the destination paths. 
Args: diff --git a/flepimop/gempyor_pkg/tests/file_paths/test_run_id.py b/flepimop/gempyor_pkg/tests/file_paths/test_run_id.py index 2e3f1c79b..cb8513ad6 100644 --- a/flepimop/gempyor_pkg/tests/file_paths/test_run_id.py +++ b/flepimop/gempyor_pkg/tests/file_paths/test_run_id.py @@ -37,9 +37,7 @@ def test_get_run_id_default_timestamp(self) -> None: datetime(2023, 8, 9, 16, 0, 0), ], ) - def test_get_run_id_user_provided_timestamp( - self, timestamp: None | datetime - ) -> None: + def test_get_run_id_user_provided_timestamp(self, timestamp: None | datetime) -> None: # Setup rid = run_id(timestamp=timestamp) diff --git a/flepimop/gempyor_pkg/tests/npi/test_SinglePeriodModifier.py b/flepimop/gempyor_pkg/tests/npi/test_SinglePeriodModifier.py index 7e3c2bc59..704949cfb 100644 --- a/flepimop/gempyor_pkg/tests/npi/test_SinglePeriodModifier.py +++ b/flepimop/gempyor_pkg/tests/npi/test_SinglePeriodModifier.py @@ -49,7 +49,9 @@ def test_SinglePeriodModifier_start_date_fail(self): config.clear() config.read(user=False) config.set_file(f"{DATA_DIR}/config_test.yml") - with pytest.raises(ValueError, match=r".*at least one period start or end date is not between.*"): + with pytest.raises( + ValueError, match=r".*at least one period start or end date is not between.*" + ): s = model_info.ModelInfo( setup_name="test_seir", config=config, @@ -72,7 +74,9 @@ def test_SinglePeriodModifier_end_date_fail(self): config.clear() config.read(user=False) config.set_file(f"{DATA_DIR}/config_test.yml") - with pytest.raises(ValueError, match=r".*at least one period start or end date is not between.*"): + with pytest.raises( + ValueError, match=r".*at least one period start or end date is not between.*" + ): s = model_info.ModelInfo( setup_name="test_seir", config=config, diff --git a/flepimop/gempyor_pkg/tests/npi/test_npis.py b/flepimop/gempyor_pkg/tests/npi/test_npis.py index bad306b1e..c2a4047ac 100644 --- a/flepimop/gempyor_pkg/tests/npi/test_npis.py +++ b/flepimop/gempyor_pkg/tests/npi/test_npis.py @@ -47,12 +47,18 @@ def test_full_npis_read_write(): # inference_simulator.s, load_ID=False, sim_id2load=None, config=config # ) - inference_simulator.modinf.write_simID(ftype="hnpi", sim_id=1, df=npi_outcomes.getReductionDF()) + inference_simulator.modinf.write_simID( + ftype="hnpi", sim_id=1, df=npi_outcomes.getReductionDF() + ) - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet" + ).to_pandas() hnpi_read["value"] = np.random.random(len(hnpi_read)) * 2 - 1 out_hnpi = pa.Table.from_pandas(hnpi_read, preserve_index=False) - pa.parquet.write_table(out_hnpi, file_paths.create_file_name(105, "", 1, "hnpi", "parquet")) + pa.parquet.write_table( + out_hnpi, file_paths.create_file_name(105, "", 1, "hnpi", "parquet") + ) import random random.seed(10) @@ -74,10 +80,16 @@ def test_full_npis_read_write(): npi_outcomes = outcomes.build_outcome_modifiers( inference_simulator.modinf, load_ID=True, sim_id2load=1, config=config ) - inference_simulator.modinf.write_simID(ftype="hnpi", sim_id=1, df=npi_outcomes.getReductionDF()) + inference_simulator.modinf.write_simID( + ftype="hnpi", sim_id=1, df=npi_outcomes.getReductionDF() + ) - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet").to_pandas() + hnpi_read = 
pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() # runs with the new, random NPI @@ -98,10 +110,16 @@ def test_full_npis_read_write(): npi_outcomes = outcomes.build_outcome_modifiers( inference_simulator.modinf, load_ID=True, sim_id2load=1, config=config ) - inference_simulator.modinf.write_simID(ftype="hnpi", sim_id=1, df=npi_outcomes.getReductionDF()) + inference_simulator.modinf.write_simID( + ftype="hnpi", sim_id=1, df=npi_outcomes.getReductionDF() + ) - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.107.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.107.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() @@ -116,10 +134,15 @@ def test_spatial_groups(): ) # Test build from config, value of the reduction array - npi = seir.build_npi_SEIR(inference_simulator.modinf, load_ID=False, sim_id2load=None, config=config) + npi = seir.build_npi_SEIR( + inference_simulator.modinf, load_ID=False, sim_id2load=None, config=config + ) # all independent: r1 - assert len(npi.getReduction("r1")["2021-01-01"].unique()) == inference_simulator.modinf.nsubpops + assert ( + len(npi.getReduction("r1")["2021-01-01"].unique()) + == inference_simulator.modinf.nsubpops + ) assert npi.getReduction("r1").isna().sum().sum() == 0 # all the same: r2 @@ -127,7 +150,10 @@ def test_spatial_groups(): assert npi.getReduction("r2").isna().sum().sum() == 0 # two groups: r3 - assert len(npi.getReduction("r3")["2020-04-15"].unique()) == inference_simulator.modinf.nsubpops - 2 + assert ( + len(npi.getReduction("r3")["2020-04-15"].unique()) + == inference_simulator.modinf.nsubpops - 2 + ) assert npi.getReduction("r3").isna().sum().sum() == 0 assert len(npi.getReduction("r3").loc[["01000", "02000"], "2020-04-15"].unique()) == 1 assert len(npi.getReduction("r3").loc[["04000", "06000"], "2020-04-15"].unique()) == 1 @@ -160,7 +186,9 @@ def test_spatial_groups(): # all the same: r2 df = npi_df[npi_df["modifier_name"] == "all_together"] assert len(df) == 1 - assert set(df["subpop"].iloc[0].split(",")) == set(inference_simulator.modinf.subpop_struct.subpop_names) + assert set(df["subpop"].iloc[0].split(",")) == set( + inference_simulator.modinf.subpop_struct.subpop_names + ) assert len(df["subpop"].iloc[0].split(",")) == inference_simulator.modinf.nsubpops # two groups: r3 @@ -175,7 +203,9 @@ def test_spatial_groups(): df = npi_df[npi_df["modifier_name"] == "mt_reduce"] assert len(df) == 4 assert df.subpop.to_list() == ["09000,10000", "02000", "06000", "01000,04000"] - assert df[df["subpop"] == "09000,10000"]["start_date"].iloc[0] == "2020-12-01,2021-12-01" + assert ( + df[df["subpop"] == "09000,10000"]["start_date"].iloc[0] == "2020-12-01,2021-12-01" + ) assert ( df[df["subpop"] == "01000,04000"]["start_date"].iloc[0] == df[df["subpop"] == "06000"]["start_date"].iloc[0] @@ -194,15 +224,21 @@ def test_spatial_groups(): ) # Test build from config, value of the reduction array - npi = seir.build_npi_SEIR(inference_simulator.modinf, load_ID=False, sim_id2load=None, 
config=config) + npi = seir.build_npi_SEIR( + inference_simulator.modinf, load_ID=False, sim_id2load=None, config=config + ) npi_df = npi.getReductionDF() inference_simulator.modinf.write_simID(ftype="snpi", sim_id=1, df=npi_df) - snpi_read = pq.read_table(f"{config_filepath_prefix}model_output/snpi/000000001.105.snpi.parquet").to_pandas() + snpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/snpi/000000001.105.snpi.parquet" + ).to_pandas() snpi_read["value"] = np.random.random(len(snpi_read)) * 2 - 1 out_snpi = pa.Table.from_pandas(snpi_read, preserve_index=False) - pa.parquet.write_table(out_snpi, file_paths.create_file_name(106, "", 1, "snpi", "parquet")) + pa.parquet.write_table( + out_snpi, file_paths.create_file_name(106, "", 1, "snpi", "parquet") + ) inference_simulator = gempyor.GempyorInference( config_filepath=f"{config_filepath_prefix}config_test_spatial_group_npi.yml", @@ -213,11 +249,19 @@ def test_spatial_groups(): out_run_id=107, ) - npi_seir = seir.build_npi_SEIR(inference_simulator.modinf, load_ID=True, sim_id2load=1, config=config) - inference_simulator.modinf.write_simID(ftype="snpi", sim_id=1, df=npi_seir.getReductionDF()) + npi_seir = seir.build_npi_SEIR( + inference_simulator.modinf, load_ID=True, sim_id2load=1, config=config + ) + inference_simulator.modinf.write_simID( + ftype="snpi", sim_id=1, df=npi_seir.getReductionDF() + ) - snpi_read = pq.read_table(f"{config_filepath_prefix}model_output/snpi/000000001.106.snpi.parquet").to_pandas() - snpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/snpi/000000001.107.snpi.parquet").to_pandas() + snpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/snpi/000000001.106.snpi.parquet" + ).to_pandas() + snpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/snpi/000000001.107.snpi.parquet" + ).to_pandas() # now the order can change, so we need to sort by subpop and start_date snpi_wrote = snpi_wrote.sort_values(by=["subpop", "start_date"]).reset_index(drop=True) @@ -225,10 +269,18 @@ def test_spatial_groups(): assert (snpi_read == snpi_wrote).all().all() npi_read = seir.build_npi_SEIR( - inference_simulator.modinf, load_ID=False, sim_id2load=1, config=config, bypass_DF=snpi_read + inference_simulator.modinf, + load_ID=False, + sim_id2load=1, + config=config, + bypass_DF=snpi_read, ) npi_wrote = seir.build_npi_SEIR( - inference_simulator.modinf, load_ID=False, sim_id2load=1, config=config, bypass_DF=snpi_wrote + inference_simulator.modinf, + load_ID=False, + sim_id2load=1, + config=config, + bypass_DF=snpi_wrote, ) assert (npi_read.getReductionDF() == npi_wrote.getReductionDF()).all().all() diff --git a/flepimop/gempyor_pkg/tests/outcomes/make_seir_test_file.py b/flepimop/gempyor_pkg/tests/outcomes/make_seir_test_file.py index 56df652cf..7a192ce13 100644 --- a/flepimop/gempyor_pkg/tests/outcomes/make_seir_test_file.py +++ b/flepimop/gempyor_pkg/tests/outcomes/make_seir_test_file.py @@ -54,7 +54,9 @@ diffI = np.arange(5) * 2 date_data = datetime.date(2020, 4, 15) for i in range(5): - b.loc[(b["mc_value_type"] == "incidence") & (b["date"] == str(date_data)), subpop[i]] = diffI[i] + b.loc[ + (b["mc_value_type"] == "incidence") & (b["date"] == str(date_data)), subpop[i] + ] = diffI[i] pa_df = pa.Table.from_pandas(b, preserve_index=False) pa.parquet.write_table(pa_df, "new_test_no_vacc.parquet") diff --git a/flepimop/gempyor_pkg/tests/outcomes/test_outcomes.py b/flepimop/gempyor_pkg/tests/outcomes/test_outcomes.py index 35f490920..eba5e9ad3 100644 --- 
a/flepimop/gempyor_pkg/tests/outcomes/test_outcomes.py +++ b/flepimop/gempyor_pkg/tests/outcomes/test_outcomes.py @@ -40,87 +40,140 @@ def test_outcome(): stoch_traj_flag=False, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=False) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=False + ) - hosp = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.1.hosp.parquet").to_pandas() + hosp = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.1.hosp.parquet" + ).to_pandas() hosp.set_index("date", drop=True, inplace=True) for i, place in enumerate(subpop): for dt in hosp.index: if dt.date() == date_data: assert hosp[hosp["subpop"] == place]["incidI"][dt] == diffI[i] - assert hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] == diffI[i] * 0.1 - assert hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] == diffI[i] * 0.01 - assert hosp[hosp["subpop"] == place]["incidICU"][dt + datetime.timedelta(7)] == diffI[i] * 0.1 * 0.4 + assert ( + hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] + == diffI[i] * 0.1 + ) + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] + == diffI[i] * 0.01 + ) + assert ( + hosp[hosp["subpop"] == place]["incidICU"][dt + datetime.timedelta(7)] + == diffI[i] * 0.1 * 0.4 + ) for j in range(7): - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7 + j)] == diffI[i] * 0.1 - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7 + 8)] == 0 + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][ + dt + datetime.timedelta(7 + j) + ] + == diffI[i] * 0.1 + ) + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][ + dt + datetime.timedelta(7 + 8) + ] + == 0 + ) elif dt.date() < date_data: - assert hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] == 0 + ) assert hosp[hosp["subpop"] == place]["incidI"][dt] == 0 - assert hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] == 0 - assert hosp[hosp["subpop"] == place]["incidICU"][dt + datetime.timedelta(7)] == 0 - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] == 0 + ) + assert ( + hosp[hosp["subpop"] == place]["incidICU"][dt + datetime.timedelta(7)] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7)] + == 0 + ) elif dt.date() > (date_data + datetime.timedelta(7)): assert hosp[hosp["subpop"] == place]["incidH"][dt] == 0 - assert hosp[hosp["subpop"] == place]["incidI"][dt - datetime.timedelta(7)] == 0 - assert hosp[hosp["subpop"] == place]["incidD"][dt - datetime.timedelta(4)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidI"][dt - datetime.timedelta(7)] == 0 + ) + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt - datetime.timedelta(4)] == 0 + ) assert hosp[hosp["subpop"] == place]["incidICU"][dt] == 0 - hpar = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.1.hpar.parquet").to_pandas() + hpar = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.1.hpar.parquet" + ).to_pandas() for i, place in enumerate(subpop): assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + 
hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.1 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 7 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "duration")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "duration") + ]["value"].iloc[0] ) == 7 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidD") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidD") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.01 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidD") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidD") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidICU") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidICU") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.4 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidICU") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidICU") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 0 ) @@ -136,10 +189,16 @@ def test_outcome_modifiers_scenario_with_load(): stoch_traj_flag=False, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=False) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=False + ) - hpar_config = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.1.hpar.parquet").to_pandas() - hpar_rel = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.2.hpar.parquet").to_pandas() + hpar_config = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.1.hpar.parquet" + ).to_pandas() + hpar_rel = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.2.hpar.parquet" + ).to_pandas() for out in ["incidH", "incidD", "incidICU"]: for i, place in enumerate(subpop): @@ -171,16 +230,30 @@ def test_outcomes_read_write_hpar(): stoch_traj_flag=False, out_run_id=3, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1 + ) - hpar_read = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.2.hpar.parquet").to_pandas() - hpar_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.3.hpar.parquet").to_pandas() + hpar_read = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.2.hpar.parquet" + ).to_pandas() + hpar_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.3.hpar.parquet" + ).to_pandas() assert (hpar_read == hpar_wrote).all().all() - hnpi_read = 
pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.2.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.3.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.2.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.3.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() - hosp_read = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.2.hosp.parquet").to_pandas() - hosp_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.3.hosp.parquet").to_pandas() + hosp_read = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.2.hosp.parquet" + ).to_pandas() + hosp_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.3.hosp.parquet" + ).to_pandas() assert (hosp_read == hosp_wrote).all().all() @@ -201,7 +274,9 @@ def test_multishift_notstochdelays(): [36, 29], ] ) - shifts = np.array([[1, 0], [2, 1], [1, 0], [2, 2], [1, 2], [0, 1], [1, 1], [1, 2], [1, 2], [1, 0]]) + shifts = np.array( + [[1, 0], [2, 1], [1, 0], [2, 2], [1, 2], [0, 1], [1, 1], [1, 2], [1, 2], [1, 0]] + ) expected = np.array( [ [0, 39], @@ -232,87 +307,138 @@ def test_outcomes_npi(): ) outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf) - hosp = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.105.hosp.parquet").to_pandas() + hosp = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.105.hosp.parquet" + ).to_pandas() hosp.set_index("date", drop=True, inplace=True) # same as config.yaml (doubled, then NPI halve it) for i, place in enumerate(subpop): for dt in hosp.index: if dt.date() == date_data: assert hosp[hosp["subpop"] == place]["incidI"][dt] == diffI[i] - assert hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] == diffI[i] * 0.1 - assert hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] == diffI[i] * 0.01 - assert hosp[hosp["subpop"] == place]["incidICU"][dt + datetime.timedelta(7)] == diffI[i] * 0.1 * 0.4 + assert ( + hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] + == diffI[i] * 0.1 + ) + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] + == diffI[i] * 0.01 + ) + assert ( + hosp[hosp["subpop"] == place]["incidICU"][dt + datetime.timedelta(7)] + == diffI[i] * 0.1 * 0.4 + ) for j in range(7): - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7 + j)] == diffI[i] * 0.1 - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7 + 8)] == 0 + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][ + dt + datetime.timedelta(7 + j) + ] + == diffI[i] * 0.1 + ) + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][ + dt + datetime.timedelta(7 + 8) + ] + == 0 + ) elif dt.date() < date_data: - assert hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] == 0 + ) assert hosp[hosp["subpop"] == place]["incidI"][dt] == 0 - assert hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] == 0 - assert hosp[hosp["subpop"] == place]["incidICU"][dt + datetime.timedelta(7)] == 0 - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] == 0 
+ ) + assert ( + hosp[hosp["subpop"] == place]["incidICU"][dt + datetime.timedelta(7)] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7)] + == 0 + ) elif dt.date() > (date_data + datetime.timedelta(7)): assert hosp[hosp["subpop"] == place]["incidH"][dt] == 0 - assert hosp[hosp["subpop"] == place]["incidI"][dt - datetime.timedelta(7)] == 0 - assert hosp[hosp["subpop"] == place]["incidD"][dt - datetime.timedelta(4)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidI"][dt - datetime.timedelta(7)] == 0 + ) + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt - datetime.timedelta(4)] == 0 + ) assert hosp[hosp["subpop"] == place]["incidICU"][dt] == 0 - hpar = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.105.hpar.parquet").to_pandas() + hpar = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.105.hpar.parquet" + ).to_pandas() # Doubled everything from previous config.yaml for i, place in enumerate(subpop): assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.1 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 7 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "duration")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "duration") + ]["value"].iloc[0] ) == 7 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidD") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidD") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.01 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidD") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidD") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 2 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidICU") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidICU") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.4 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidICU") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidICU") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 0 * 2 ) @@ -330,17 +456,31 @@ def test_outcomes_read_write_hnpi(): out_run_id=106, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1 + ) - hpar_read = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.105.hpar.parquet").to_pandas() - hpar_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.106.hpar.parquet").to_pandas() + hpar_read = 
pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.105.hpar.parquet" + ).to_pandas() + hpar_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.106.hpar.parquet" + ).to_pandas() assert (hpar_read == hpar_wrote).all().all() - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() - hosp_read = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.105.hosp.parquet").to_pandas() - hosp_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.106.hosp.parquet").to_pandas() + hosp_read = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.105.hosp.parquet" + ).to_pandas() + hosp_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.106.hosp.parquet" + ).to_pandas() assert (hosp_read == hosp_wrote).all().all() @@ -356,17 +496,27 @@ def test_outcomes_read_write_hnpi2(): out_run_id=106, ) - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet" + ).to_pandas() hnpi_read["value"] = np.random.random(len(hnpi_read)) * 2 - 1 out_hnpi = pa.Table.from_pandas(hnpi_read, preserve_index=False) - pa.parquet.write_table(out_hnpi, file_paths.create_file_name(105, "", 1, "hnpi", "parquet")) + pa.parquet.write_table( + out_hnpi, file_paths.create_file_name(105, "", 1, "hnpi", "parquet") + ) import random random.seed(10) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1 + ) - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() # runs with the new, random NPI @@ -378,16 +528,30 @@ def test_outcomes_read_write_hnpi2(): stoch_traj_flag=False, out_run_id=107, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1 + ) - hpar_read = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.106.hpar.parquet").to_pandas() - hpar_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.107.hpar.parquet").to_pandas() + hpar_read = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.106.hpar.parquet" + ).to_pandas() + hpar_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.107.hpar.parquet" + ).to_pandas() 
assert (hpar_read == hpar_wrote).all().all() - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.107.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.107.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() - hosp_read = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.106.hosp.parquet").to_pandas() - hosp_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.107.hosp.parquet").to_pandas() + hosp_read = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.106.hosp.parquet" + ).to_pandas() + hosp_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.107.hosp.parquet" + ).to_pandas() assert (hosp_read == hosp_wrote).all().all() @@ -402,89 +566,142 @@ def test_outcomes_npi_custom_pname(): stoch_traj_flag=False, out_run_id=105, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=False, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=False, sim_id2load=1 + ) - hosp = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.105.hosp.parquet").to_pandas() + hosp = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.105.hosp.parquet" + ).to_pandas() hosp.set_index("date", drop=True, inplace=True) # same as config.yaml (doubled, then NPI halve it) for i, place in enumerate(subpop): for dt in hosp.index: if dt.date() == date_data: assert hosp[hosp["subpop"] == place]["incidI"][dt] == diffI[i] - assert hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] == diffI[i] * 0.1 - assert hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] == diffI[i] * 0.01 - assert hosp[hosp["subpop"] == place]["incidICU"][dt + datetime.timedelta(7)] == diffI[i] * 0.1 * 0.4 + assert ( + hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] + == diffI[i] * 0.1 + ) + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] + == diffI[i] * 0.01 + ) + assert ( + hosp[hosp["subpop"] == place]["incidICU"][dt + datetime.timedelta(7)] + == diffI[i] * 0.1 * 0.4 + ) for j in range(7): - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7 + j)] == diffI[i] * 0.1 - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7 + 8)] == 0 + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][ + dt + datetime.timedelta(7 + j) + ] + == diffI[i] * 0.1 + ) + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][ + dt + datetime.timedelta(7 + 8) + ] + == 0 + ) elif dt.date() < date_data: - assert hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] == 0 + ) assert hosp[hosp["subpop"] == place]["incidI"][dt] == 0 - assert hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] == 0 - assert hosp[hosp["subpop"] == place]["incidICU"][dt + datetime.timedelta(7)] == 0 - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] == 0 + ) + assert ( + 
hosp[hosp["subpop"] == place]["incidICU"][dt + datetime.timedelta(7)] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7)] + == 0 + ) elif dt.date() > (date_data + datetime.timedelta(7)): assert hosp[hosp["subpop"] == place]["incidH"][dt] == 0 - assert hosp[hosp["subpop"] == place]["incidI"][dt - datetime.timedelta(7)] == 0 - assert hosp[hosp["subpop"] == place]["incidD"][dt - datetime.timedelta(4)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidI"][dt - datetime.timedelta(7)] == 0 + ) + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt - datetime.timedelta(4)] == 0 + ) assert hosp[hosp["subpop"] == place]["incidICU"][dt] == 0 - hpar = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.105.hpar.parquet").to_pandas() + hpar = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.105.hpar.parquet" + ).to_pandas() # Doubled everything from previous config.yaml for i, place in enumerate(subpop): assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.1 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 7 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "duration")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "duration") + ]["value"].iloc[0] ) == 7 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidD") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidD") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.01 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidD") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidD") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 2 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidICU") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidICU") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.4 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidICU") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidICU") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 0 * 2 ) @@ -502,16 +719,30 @@ def test_outcomes_read_write_hnpi_custom_pname(): out_run_id=106, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1 + ) - hpar_read = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.105.hpar.parquet").to_pandas() - hpar_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.106.hpar.parquet").to_pandas() + hpar_read = 
pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.105.hpar.parquet" + ).to_pandas() + hpar_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.106.hpar.parquet" + ).to_pandas() assert (hpar_read == hpar_wrote).all().all() - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() - hosp_read = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.105.hosp.parquet").to_pandas() - hosp_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.106.hosp.parquet").to_pandas() + hosp_read = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.105.hosp.parquet" + ).to_pandas() + hosp_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.106.hosp.parquet" + ).to_pandas() assert (hosp_read == hosp_wrote).all().all() @@ -520,10 +751,14 @@ def test_outcomes_read_write_hnpi2_custom_pname(): prefix = "" - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet" + ).to_pandas() hnpi_read["value"] = np.random.random(len(hnpi_read)) * 2 - 1 out_hnpi = pa.Table.from_pandas(hnpi_read, preserve_index=False) - pa.parquet.write_table(out_hnpi, file_paths.create_file_name(105, prefix, 1, "hnpi", "parquet")) + pa.parquet.write_table( + out_hnpi, file_paths.create_file_name(105, prefix, 1, "hnpi", "parquet") + ) import random random.seed(10) @@ -537,10 +772,16 @@ def test_outcomes_read_write_hnpi2_custom_pname(): out_run_id=106, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1 + ) - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() # runs with the new, random NPI @@ -553,16 +794,30 @@ def test_outcomes_read_write_hnpi2_custom_pname(): out_run_id=107, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1 + ) - hpar_read = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.106.hpar.parquet").to_pandas() - hpar_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.107.hpar.parquet").to_pandas() + hpar_read = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.106.hpar.parquet" + ).to_pandas() + hpar_wrote = pq.read_table( 
+ f"{config_filepath_prefix}model_output/hpar/000000001.107.hpar.parquet" + ).to_pandas() assert (hpar_read == hpar_wrote).all().all() - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.107.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.107.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() - hosp_read = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.106.hosp.parquet").to_pandas() - hosp_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.107.hosp.parquet").to_pandas() + hosp_read = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.106.hosp.parquet" + ).to_pandas() + hosp_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.107.hosp.parquet" + ).to_pandas() assert (hosp_read == hosp_wrote).all().all() @@ -580,7 +835,9 @@ def test_outcomes_pcomp(): ) p_compmult = [1, 3] - seir = pq.read_table(f"{config_filepath_prefix}model_output/seir/000000001.105.seir.parquet").to_pandas() + seir = pq.read_table( + f"{config_filepath_prefix}model_output/seir/000000001.105.seir.parquet" + ).to_pandas() seir2 = seir.copy() seir2["mc_vaccination_stage"] = "first_dose" @@ -591,10 +848,16 @@ def test_outcomes_pcomp(): seir2[pl] = seir2[pl] * p_compmult[1] new_seir = pd.concat([seir, seir2]) out_df = pa.Table.from_pandas(new_seir, preserve_index=False) - pa.parquet.write_table(out_df, file_paths.create_file_name(110, prefix, 1, "seir", "parquet")) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=False) + pa.parquet.write_table( + out_df, file_paths.create_file_name(110, prefix, 1, "seir", "parquet") + ) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=False + ) - hosp_f = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.111.hosp.parquet").to_pandas() + hosp_f = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.111.hosp.parquet" + ).to_pandas() hosp_f.set_index("date", drop=True, inplace=True) # same as config.yaml (doubled, then NPI halve it) for k, p_comp in enumerate(["0dose", "1dose"]): @@ -602,42 +865,90 @@ def test_outcomes_pcomp(): for i, place in enumerate(subpop): for dt in hosp.index: if dt.date() == date_data: - assert hosp[hosp["subpop"] == place][f"incidI_{p_comp}"][dt] == diffI[i] * p_compmult[k] assert ( - hosp[hosp["subpop"] == place][f"incidH_{p_comp}"][dt + datetime.timedelta(7)] + hosp[hosp["subpop"] == place][f"incidI_{p_comp}"][dt] + == diffI[i] * p_compmult[k] + ) + assert ( + hosp[hosp["subpop"] == place][f"incidH_{p_comp}"][ + dt + datetime.timedelta(7) + ] - diffI[i] * 0.1 * p_compmult[k] < 1e-8 ) assert ( - hosp[hosp["subpop"] == place][f"incidD_{p_comp}"][dt + datetime.timedelta(2)] + hosp[hosp["subpop"] == place][f"incidD_{p_comp}"][ + dt + datetime.timedelta(2) + ] - diffI[i] * 0.01 * p_compmult[k] < 1e-8 ) assert ( - hosp[hosp["subpop"] == place][f"incidICU_{p_comp}"][dt + datetime.timedelta(7)] + hosp[hosp["subpop"] == place][f"incidICU_{p_comp}"][ + dt + datetime.timedelta(7) + ] - diffI[i] * 0.1 * 0.4 * p_compmult[k] < 1e-8 ) for j in range(7): assert ( - hosp[hosp["subpop"] == 
place][f"incidH_{p_comp}_curr"][dt + datetime.timedelta(7 + j)] + hosp[hosp["subpop"] == place][f"incidH_{p_comp}_curr"][ + dt + datetime.timedelta(7 + j) + ] - diffI[i] * 0.1 * p_compmult[k] < 1e-8 ) - assert hosp[hosp["subpop"] == place][f"incidH_{p_comp}_curr"][dt + datetime.timedelta(7 + 8)] == 0 + assert ( + hosp[hosp["subpop"] == place][f"incidH_{p_comp}_curr"][ + dt + datetime.timedelta(7 + 8) + ] + == 0 + ) elif dt.date() < date_data: - assert hosp[hosp["subpop"] == place][f"incidH_{p_comp}"][dt + datetime.timedelta(7)] == 0 + assert ( + hosp[hosp["subpop"] == place][f"incidH_{p_comp}"][ + dt + datetime.timedelta(7) + ] + == 0 + ) assert hosp[hosp["subpop"] == place][f"incidI_{p_comp}"][dt] == 0 - assert hosp[hosp["subpop"] == place][f"incidD_{p_comp}"][dt + datetime.timedelta(2)] == 0 - assert hosp[hosp["subpop"] == place][f"incidICU_{p_comp}"][dt + datetime.timedelta(7)] == 0 - assert hosp[hosp["subpop"] == place][f"incidH_{p_comp}_curr"][dt + datetime.timedelta(7)] == 0 + assert ( + hosp[hosp["subpop"] == place][f"incidD_{p_comp}"][ + dt + datetime.timedelta(2) + ] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place][f"incidICU_{p_comp}"][ + dt + datetime.timedelta(7) + ] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place][f"incidH_{p_comp}_curr"][ + dt + datetime.timedelta(7) + ] + == 0 + ) elif dt.date() > (date_data + datetime.timedelta(7)): assert hosp[hosp["subpop"] == place][f"incidH_{p_comp}"][dt] == 0 - assert hosp[hosp["subpop"] == place][f"incidI_{p_comp}"][dt - datetime.timedelta(7)] == 0 - assert hosp[hosp["subpop"] == place][f"incidD_{p_comp}"][dt - datetime.timedelta(4)] == 0 + assert ( + hosp[hosp["subpop"] == place][f"incidI_{p_comp}"][ + dt - datetime.timedelta(7) + ] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place][f"incidD_{p_comp}"][ + dt - datetime.timedelta(4) + ] + == 0 + ) assert hosp[hosp["subpop"] == place][f"incidICU_{p_comp}"][dt] == 0 - hpar_f = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.111.hpar.parquet").to_pandas() + hpar_f = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.111.hpar.parquet" + ).to_pandas() # Doubled everything from previous config.yaml # for k, p_comp in enumerate(["unvaccinated", "first_dose"]): for k, p_comp in enumerate(["0dose", "1dose"]): @@ -727,16 +1038,30 @@ def test_outcomes_pcomp_read_write(): out_run_id=112, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1 + ) - hpar_read = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.111.hpar.parquet").to_pandas() - hpar_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.112.hpar.parquet").to_pandas() + hpar_read = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.111.hpar.parquet" + ).to_pandas() + hpar_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.112.hpar.parquet" + ).to_pandas() assert (hpar_read == hpar_wrote).all().all() - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.111.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.112.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.111.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + 
f"{config_filepath_prefix}model_output/hnpi/000000001.112.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() - hosp_read = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.111.hosp.parquet").to_pandas() - hosp_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.112.hosp.parquet").to_pandas() + hosp_read = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.111.hosp.parquet" + ).to_pandas() + hosp_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.112.hosp.parquet" + ).to_pandas() assert (hosp_read == hosp_wrote).all().all() diff --git a/flepimop/gempyor_pkg/tests/parameters/test_parameters_class.py b/flepimop/gempyor_pkg/tests/parameters/test_parameters_class.py index 4bfb86fd2..0286c9b89 100644 --- a/flepimop/gempyor_pkg/tests/parameters/test_parameters_class.py +++ b/flepimop/gempyor_pkg/tests/parameters/test_parameters_class.py @@ -323,9 +323,9 @@ def test_parameters_instance_attributes( assert set(params.pdata.keys()) == set(mock_inputs.config.keys()) for param_name, param_conf in mock_inputs.config.items(): assert params.pdata[param_name]["idx"] == params.pnames2pindex[param_name] - assert params.pdata[param_name][ - "stacked_modifier_method" - ] == param_conf.get("stacked_modifier_method", "product") + assert params.pdata[param_name]["stacked_modifier_method"] == param_conf.get( + "stacked_modifier_method", "product" + ) if "timeseries" in param_conf: assert params.pdata[param_name]["ts"].equals( mock_inputs.get_timeseries_df(param_name) @@ -355,8 +355,7 @@ def test_parameters_instance_attributes( }, ) assert ( - params.pdata[param_name]["dist"].__self__.kwds - == expected.__self__.kwds + params.pdata[param_name]["dist"].__self__.kwds == expected.__self__.kwds ) assert ( params.pdata[param_name]["dist"].__self__.support() @@ -453,9 +452,7 @@ def test_get_pnames2pindex( # Assertions assert params.get_pnames2pindex() == params.pnames2pindex - assert params.pnames2pindex == { - p: params.pnames.index(p) for p in params.pnames - } + assert params.pnames2pindex == {p: params.pnames.index(p) for p in params.pnames} @pytest.mark.parametrize( "factory,n_days,nsubpops", diff --git a/flepimop/gempyor_pkg/tests/seir/dev_new_test0.py b/flepimop/gempyor_pkg/tests/seir/dev_new_test0.py index 1e0915b82..53c34e039 100644 --- a/flepimop/gempyor_pkg/tests/seir/dev_new_test0.py +++ b/flepimop/gempyor_pkg/tests/seir/dev_new_test0.py @@ -42,7 +42,9 @@ def test_parameters_from_timeserie_file(): ) # p = inference_simulator.s.parameters - p_draw = p.parameters_quick_draw(n_days=inference_simulator.s.n_days, nnodes=inference_simulator.s.nnodes) + p_draw = p.parameters_quick_draw( + n_days=inference_simulator.s.n_days, nnodes=inference_simulator.s.nnodes + ) p_df = p.getParameterDF(p_draw)["parameter"] diff --git a/flepimop/gempyor_pkg/tests/seir/test_compartments.py b/flepimop/gempyor_pkg/tests/seir/test_compartments.py index 1d4319e3b..492d37400 100644 --- a/flepimop/gempyor_pkg/tests/seir/test_compartments.py +++ b/flepimop/gempyor_pkg/tests/seir/test_compartments.py @@ -24,7 +24,9 @@ def test_check_transitions_parquet_creation(): config.set_file(f"{DATA_DIR}/config_compartmental_model_format.yml") original_compartments_file = f"{DATA_DIR}/parsed_compartment_compartments.parquet" original_transitions_file = f"{DATA_DIR}/parsed_compartment_transitions.parquet" - lhs = compartments.Compartments(seir_config=config["seir"], compartments_config=config["compartments"]) + lhs = 
compartments.Compartments( + seir_config=config["seir"], compartments_config=config["compartments"] + ) rhs = compartments.Compartments( seir_config=config["seir"], compartments_file=original_compartments_file, @@ -43,10 +45,16 @@ def test_check_transitions_parquet_writing_and_loading(): config.clear() config.read(user=False) config.set_file(f"{DATA_DIR}/config_compartmental_model_format.yml") - lhs = compartments.Compartments(seir_config=config["seir"], compartments_config=config["compartments"]) + lhs = compartments.Compartments( + seir_config=config["seir"], compartments_config=config["compartments"] + ) temp_compartments_file = f"{DATA_DIR}/parsed_compartment_compartments.test.parquet" temp_transitions_file = f"{DATA_DIR}/parsed_compartment_transitions.test.parquet" - lhs.toFile(compartments_file=temp_compartments_file, transitions_file=temp_transitions_file, write_parquet=True) + lhs.toFile( + compartments_file=temp_compartments_file, + transitions_file=temp_transitions_file, + write_parquet=True, + ) rhs = compartments.Compartments( seir_config=config["seir"], compartments_file=temp_compartments_file, @@ -86,4 +94,3 @@ def test_ModelInfo_has_compartments_component(): ) assert type(s.compartments) == compartments.Compartments assert type(s.compartments) == compartments.Compartments - diff --git a/flepimop/gempyor_pkg/tests/seir/test_ic.py b/flepimop/gempyor_pkg/tests/seir/test_ic.py index b4cd240ee..e9c275347 100644 --- a/flepimop/gempyor_pkg/tests/seir/test_ic.py +++ b/flepimop/gempyor_pkg/tests/seir/test_ic.py @@ -21,7 +21,9 @@ def test_IC_success(self): outcome_modifiers_scenario=None, write_csv=False, ) - sic = initial_conditions.InitialConditionsFactory(config=s.initial_conditions_config) + sic = initial_conditions.InitialConditionsFactory( + config=s.initial_conditions_config + ) assert sic.initial_conditions_config == s.initial_conditions_config def test_IC_allow_missing_node_compartments_success(self): @@ -40,7 +42,9 @@ def test_IC_allow_missing_node_compartments_success(self): s.initial_conditions_config["allow_missing_nodes"] = True s.initial_conditions_config["allow_missing_compartments"] = True - sic = initial_conditions.InitialConditionsFactory(config=s.initial_conditions_config) + sic = initial_conditions.InitialConditionsFactory( + config=s.initial_conditions_config + ) sic.get_from_config(sim_id=100, modinf=s) def test_IC_IC_notImplemented_fail(self): @@ -58,6 +62,8 @@ def test_IC_IC_notImplemented_fail(self): write_csv=False, ) s.initial_conditions_config["method"] = "unknown" - sic = initial_conditions.InitialConditionsFactory(config=s.initial_conditions_config) + sic = initial_conditions.InitialConditionsFactory( + config=s.initial_conditions_config + ) sic.get_from_config(sim_id=100, modinf=s) diff --git a/flepimop/gempyor_pkg/tests/seir/test_parameters.py b/flepimop/gempyor_pkg/tests/seir/test_parameters.py index 9e03bf87d..384c3eaa2 100644 --- a/flepimop/gempyor_pkg/tests/seir/test_parameters.py +++ b/flepimop/gempyor_pkg/tests/seir/test_parameters.py @@ -65,7 +65,9 @@ def test_parameters_from_config_plus_read_write(): tf=s.tf, subpop_names=s.subpop_struct.subpop_names, ) - p_load = rhs.parameters_load(param_df=read_df("test_pwrite.parquet"), n_days=n_days, nsubpops=nsubpops) + p_load = rhs.parameters_load( + param_df=read_df("test_pwrite.parquet"), n_days=n_days, nsubpops=nsubpops + ) assert (p_draw == p_load).all() @@ -122,7 +124,10 @@ def test_parameters_quick_draw_old(): assert ((2 <= R0s) & (R0s <= 3)).all() assert sigma.shape == (modinf.n_days, 
modinf.nsubpops) - assert (sigma == config["seir"]["parameters"]["sigma"]["value"]["value"].as_evaled_expression()).all() + assert ( + sigma + == config["seir"]["parameters"]["sigma"]["value"]["value"].as_evaled_expression() + ).all() assert gamma.shape == (modinf.n_days, modinf.nsubpops) assert len(np.unique(gamma)) == 1 @@ -174,6 +179,8 @@ def test_parameters_from_timeseries_file(): tf=s.tf, subpop_names=s.subpop_struct.subpop_names, ) - p_load = rhs.parameters_load(param_df=read_df("test_pwrite.parquet"), n_days=n_days, nsubpops=nsubpops) + p_load = rhs.parameters_load( + param_df=read_df("test_pwrite.parquet"), n_days=n_days, nsubpops=nsubpops + ) assert (p_draw == p_load).all() diff --git a/flepimop/gempyor_pkg/tests/seir/test_seir.py b/flepimop/gempyor_pkg/tests/seir/test_seir.py index 99a4cc236..d9d6d696f 100644 --- a/flepimop/gempyor_pkg/tests/seir/test_seir.py +++ b/flepimop/gempyor_pkg/tests/seir/test_seir.py @@ -74,7 +74,9 @@ def test_constant_population_legacy_integration(): integration_method = "legacy" seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id=100, modinf=modinf) - initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf) + initial_conditions = modinf.initial_conditions.get_from_config( + sim_id=100, modinf=modinf + ) npi = NPI.NPIBase.execute( npi_config=modinf.npi_config_seir, @@ -82,7 +84,9 @@ def test_constant_population_legacy_integration(): modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + "reduction_product" + ], ) params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops) @@ -94,7 +98,9 @@ def test_constant_population_legacy_integration(): proportion_array, proportion_info, ) = modinf.compartments.get_transition_array() - parsed_parameters = modinf.compartments.parse_parameters(params, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + params, modinf.parameters.pnames, unique_strings + ) states = seir.steps_SEIR( modinf, @@ -142,8 +148,12 @@ def test_constant_population_rk4jit_integration_fail(): ) modinf.seir_config["integration"]["method"] = "rk4.jit" - seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id=100, modinf=modinf) - initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf) + seeding_data, seeding_amounts = modinf.seeding.get_from_file( + sim_id=100, modinf=modinf + ) + initial_conditions = modinf.initial_conditions.get_from_config( + sim_id=100, modinf=modinf + ) npi = NPI.NPIBase.execute( npi_config=modinf.npi_config_seir, @@ -151,7 +161,9 @@ def test_constant_population_rk4jit_integration_fail(): modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + "reduction_product" + ], ) params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops) @@ -163,7 +175,9 @@ def test_constant_population_rk4jit_integration_fail(): proportion_array, proportion_info, ) = 
modinf.compartments.get_transition_array()
-    parsed_parameters = modinf.compartments.parse_parameters(params, modinf.parameters.pnames, unique_strings)
+    parsed_parameters = modinf.compartments.parse_parameters(
+        params, modinf.parameters.pnames, unique_strings
+    )
     states = seir.steps_SEIR(
         modinf,
@@ -213,7 +227,9 @@ def test_constant_population_rk4jit_integration():
     assert modinf.seir_config["integration"]["method"].get() == "rk4"
     seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id=100, modinf=modinf)
-    initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf)
+    initial_conditions = modinf.initial_conditions.get_from_config(
+        sim_id=100, modinf=modinf
+    )
     npi = NPI.NPIBase.execute(
         npi_config=modinf.npi_config_seir,
@@ -221,7 +237,9 @@ def test_constant_population_rk4jit_integration():
         modifiers_library=modinf.seir_modifiers_library,
         subpops=modinf.subpop_struct.subpop_names,
         pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"],
-        pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"],
+        pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[
+            "reduction_product"
+        ],
     )
     params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops)
@@ -233,7 +251,9 @@ def test_constant_population_rk4jit_integration():
         proportion_array,
         proportion_info,
     ) = modinf.compartments.get_transition_array()
-    parsed_parameters = modinf.compartments.parse_parameters(params, modinf.parameters.pnames, unique_strings)
+    parsed_parameters = modinf.compartments.parse_parameters(
+        params, modinf.parameters.pnames, unique_strings
+    )
     states = seir.steps_SEIR(
         modinf,
         parsed_parameters,
@@ -281,7 +301,9 @@ def test_steps_SEIR_nb_simple_spread_with_txt_matrices():
     )
     seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id=100, modinf=modinf)
-    initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf)
+    initial_conditions = modinf.initial_conditions.get_from_config(
+        sim_id=100, modinf=modinf
+    )
     npi = NPI.NPIBase.execute(
         npi_config=modinf.npi_config_seir,
@@ -289,7 +311,9 @@ def test_steps_SEIR_nb_simple_spread_with_txt_matrices():
         modifiers_library=modinf.seir_modifiers_library,
         subpops=modinf.subpop_struct.subpop_names,
         pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"],
-        pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"],
+        pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[
+            "reduction_product"
+        ],
     )
     params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops)
@@ -301,7 +325,9 @@ def test_steps_SEIR_nb_simple_spread_with_txt_matrices():
         proportion_array,
         proportion_info,
     ) = modinf.compartments.get_transition_array()
-    parsed_parameters = modinf.compartments.parse_parameters(params, modinf.parameters.pnames, unique_strings)
+    parsed_parameters = modinf.compartments.parse_parameters(
+        params, modinf.parameters.pnames, unique_strings
+    )
     for i in range(5):
         states = seir.steps_SEIR(
@@ -316,11 +342,15 @@ def test_steps_SEIR_nb_simple_spread_with_txt_matrices():
         )
         df = seir.states2Df(modinf, states)
         assert (
-            df[(df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")].loc[str(modinf.tf), "10001"]
+            df[
+                (df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")
+            ].loc[str(modinf.tf), "10001"]
             > 1
         )
         assert (
-            df[(df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")].loc[str(modinf.tf), "20002"]
+            df[
+                (df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")
+            ].loc[str(modinf.tf), "20002"]
             > 1
         )
@@ -336,11 +366,23 @@ def test_steps_SEIR_nb_simple_spread_with_txt_matrices():
         )
         df = seir.states2Df(modinf, states)
         assert (
-            df[(df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")].loc[str(modinf.tf), "20002"]
+            df[
+                (df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")
+            ].loc[str(modinf.tf), "20002"]
             > 1
         )
-        assert df[(df["mc_value_type"] == "incidence") & (df["mc_infection_stage"] == "I1")].max()["20002"] > 0
-        assert df[(df["mc_value_type"] == "incidence") & (df["mc_infection_stage"] == "I1")].max()["10001"] > 0
+        assert (
+            df[
+                (df["mc_value_type"] == "incidence") & (df["mc_infection_stage"] == "I1")
+            ].max()["20002"]
+            > 0
+        )
+        assert (
+            df[
+                (df["mc_value_type"] == "incidence") & (df["mc_infection_stage"] == "I1")
+            ].max()["10001"]
+            > 0
+        )
 def test_steps_SEIR_nb_simple_spread_with_csv_matrices():
@@ -367,7 +409,9 @@ def test_steps_SEIR_nb_simple_spread_with_csv_matrices():
     )
     seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id=100, modinf=modinf)
-    initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf)
+    initial_conditions = modinf.initial_conditions.get_from_config(
+        sim_id=100, modinf=modinf
+    )
     npi = NPI.NPIBase.execute(
         npi_config=modinf.npi_config_seir,
@@ -375,7 +419,9 @@ def test_steps_SEIR_nb_simple_spread_with_csv_matrices():
         modifiers_library=modinf.seir_modifiers_library,
         subpops=modinf.subpop_struct.subpop_names,
         pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"],
-        pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"],
+        pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[
+            "reduction_product"
+        ],
     )
     params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops)
@@ -387,7 +433,9 @@ def test_steps_SEIR_nb_simple_spread_with_csv_matrices():
         proportion_array,
         proportion_info,
     ) = modinf.compartments.get_transition_array()
-    parsed_parameters = modinf.compartments.parse_parameters(params, modinf.parameters.pnames, unique_strings)
+    parsed_parameters = modinf.compartments.parse_parameters(
+        params, modinf.parameters.pnames, unique_strings
+    )
     for i in range(5):
         states = seir.steps_SEIR(
@@ -402,8 +450,18 @@ def test_steps_SEIR_nb_simple_spread_with_csv_matrices():
         )
         df = seir.states2Df(modinf, states)
-        assert df[(df["mc_value_type"] == "incidence") & (df["mc_infection_stage"] == "I1")].max()["20002"] > 0
-        assert df[(df["mc_value_type"] == "incidence") & (df["mc_infection_stage"] == "I1")].max()["10001"] > 0
+        assert (
+            df[
+                (df["mc_value_type"] == "incidence") & (df["mc_infection_stage"] == "I1")
+            ].max()["20002"]
+            > 0
+        )
+        assert (
+            df[
+                (df["mc_value_type"] == "incidence") & (df["mc_infection_stage"] == "I1")
+            ].max()["10001"]
+            > 0
+        )
 def test_steps_SEIR_no_spread():
@@ -427,7 +485,9 @@ def test_steps_SEIR_no_spread():
     )
     seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id=100, modinf=modinf)
-    initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf)
+    initial_conditions = modinf.initial_conditions.get_from_config(
+        sim_id=100, modinf=modinf
+    )
     modinf.mobility.data = modinf.mobility.data * 0
@@ -437,7 +497,9 @@ def test_steps_SEIR_no_spread():
     npi = NPI.NPIBase.execute(
         modifiers_library=modinf.seir_modifiers_library,
         subpops=modinf.subpop_struct.subpop_names,
         pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"],
-        pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"],
+        pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[
+            "reduction_product"
+        ],
     )
     params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops)
@@ -449,7 +511,9 @@ def test_steps_SEIR_no_spread():
         proportion_array,
         proportion_info,
     ) = modinf.compartments.get_transition_array()
-    parsed_parameters = modinf.compartments.parse_parameters(params, modinf.parameters.pnames, unique_strings)
+    parsed_parameters = modinf.compartments.parse_parameters(
+        params, modinf.parameters.pnames, unique_strings
+    )
     for i in range(10):
         states = seir.steps_SEIR(
@@ -464,7 +528,9 @@ def test_steps_SEIR_no_spread():
         )
         df = seir.states2Df(modinf, states)
         assert (
-            df[(df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")].loc[str(modinf.tf), "20002"]
+            df[
+                (df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")
+            ].loc[str(modinf.tf), "20002"]
             == 0.0
         )
@@ -480,7 +546,9 @@ def test_steps_SEIR_no_spread():
         )
         df = seir.states2Df(modinf, states)
         assert (
-            df[(df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")].loc[str(modinf.tf), "20002"]
+            df[
+                (df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")
+            ].loc[str(modinf.tf), "20002"]
             == 0.0
         )
@@ -515,7 +583,9 @@ def test_continuation_resume():
     seir.onerun_SEIR(sim_id2write=int(sim_id2write), modinf=modinf, config=config)
     states_old = pq.read_table(
-        file_paths.create_file_name(modinf.in_run_id, modinf.in_prefix, 100, "seir", "parquet"),
+        file_paths.create_file_name(
+            modinf.in_run_id, modinf.in_prefix, 100, "seir", "parquet"
+        ),
     ).to_pandas()
     states_old = states_old[states_old["date"] == "2020-03-15"].reset_index(drop=True)
@@ -547,7 +617,9 @@ def test_continuation_resume():
     seir.onerun_SEIR(sim_id2write=sim_id2write, modinf=modinf, config=config)
     states_new = pq.read_table(
-        file_paths.create_file_name(modinf.in_run_id, modinf.in_prefix, sim_id2write, "seir", "parquet"),
+        file_paths.create_file_name(
+            modinf.in_run_id, modinf.in_prefix, sim_id2write, "seir", "parquet"
+        ),
     ).to_pandas()
     states_new = states_new[states_new["date"] == "2020-03-15"].reset_index(drop=True)
     assert (
@@ -560,10 +632,16 @@
     )
     seir.onerun_SEIR(
-        sim_id2write=sim_id2write + 1, modinf=modinf, sim_id2load=sim_id2write, load_ID=True, config=config
+        sim_id2write=sim_id2write + 1,
+        modinf=modinf,
+        sim_id2load=sim_id2write,
+        load_ID=True,
+        config=config,
     )
     states_new = pq.read_table(
-        file_paths.create_file_name(modinf.in_run_id, modinf.in_prefix, sim_id2write + 1, "seir", "parquet"),
+        file_paths.create_file_name(
+            modinf.in_run_id, modinf.in_prefix, sim_id2write + 1, "seir", "parquet"
+        ),
     ).to_pandas()
     states_new = states_new[states_new["date"] == "2020-03-15"].reset_index(drop=True)
     for path in ["model_output/seir", "model_output/snpi", "model_output/spar"]:
@@ -604,7 +682,9 @@ def test_inference_resume():
     seir.onerun_SEIR(sim_id2write=int(sim_id2write), modinf=modinf, config=config)
     npis_old = pq.read_table(
-        file_paths.create_file_name(modinf.in_run_id, modinf.in_prefix, sim_id2write, "snpi", "parquet")
+        file_paths.create_file_name(
+            modinf.in_run_id, modinf.in_prefix, sim_id2write, "snpi", "parquet"
+        )
     ).to_pandas()
     config.clear()
@@ -632,10 +712,16 @@
     )
     seir.onerun_SEIR(
-        sim_id2write=sim_id2write + 1, modinf=modinf, sim_id2load=sim_id2write, load_ID=True, config=config
+        sim_id2write=sim_id2write + 1,
+        modinf=modinf,
+        sim_id2load=sim_id2write,
+        load_ID=True,
+        config=config,
    )
     npis_new = pq.read_table(
-        file_paths.create_file_name(modinf.in_run_id, modinf.in_prefix, sim_id2write + 1, "snpi", "parquet")
+        file_paths.create_file_name(
+            modinf.in_run_id, modinf.in_prefix, sim_id2write + 1, "snpi", "parquet"
+        )
     ).to_pandas()
     assert npis_old["modifier_name"].isin(["None", "Wuhan", "KansasCity"]).all()
@@ -675,7 +761,9 @@ def test_parallel_compartments_with_vacc():
     )
     seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id=100, modinf=modinf)
-    initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf)
+    initial_conditions = modinf.initial_conditions.get_from_config(
+        sim_id=100, modinf=modinf
+    )
     npi = NPI.NPIBase.execute(
         npi_config=modinf.npi_config_seir,
@@ -683,7 +771,9 @@ def test_parallel_compartments_with_vacc():
         modifiers_library=modinf.seir_modifiers_library,
         subpops=modinf.subpop_struct.subpop_names,
         pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"],
-        pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"],
+        pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[
+            "reduction_product"
+        ],
     )
     params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops)
@@ -695,7 +785,9 @@ def test_parallel_compartments_with_vacc():
         proportion_array,
         proportion_info,
     ) = modinf.compartments.get_transition_array()
-    parsed_parameters = modinf.compartments.parse_parameters(params, modinf.parameters.pnames, unique_strings)
+    parsed_parameters = modinf.compartments.parse_parameters(
+        params, modinf.parameters.pnames, unique_strings
+    )
     for i in range(5):
         states = seir.steps_SEIR(
@@ -762,7 +854,9 @@ def test_parallel_compartments_no_vacc():
     )
     seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id=100, modinf=modinf)
-    initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf)
+    initial_conditions = modinf.initial_conditions.get_from_config(
+        sim_id=100, modinf=modinf
+    )
     npi = NPI.NPIBase.execute(
         npi_config=modinf.npi_config_seir,
@@ -770,7 +864,9 @@ def test_parallel_compartments_no_vacc():
         modifiers_library=modinf.seir_modifiers_library,
         subpops=modinf.subpop_struct.subpop_names,
         pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"],
-        pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"],
+        pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[
+            "reduction_product"
+        ],
     )
     params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops)
@@ -782,7 +878,9 @@ def test_parallel_compartments_no_vacc():
         proportion_array,
         proportion_info,
     ) = modinf.compartments.get_transition_array()
-    parsed_parameters = modinf.compartments.parse_parameters(params, modinf.parameters.pnames, unique_strings)
+    parsed_parameters = modinf.compartments.parse_parameters(
+        params, modinf.parameters.pnames, unique_strings
+    )
     for i in range(5):
         states = seir.steps_SEIR(
diff --git a/flepimop/gempyor_pkg/tests/seir/test_subpopulationstructure.py b/flepimop/gempyor_pkg/tests/seir/test_subpopulationstructure.py
index 34df630c3..b2161cc9b 100644
--- a/flepimop/gempyor_pkg/tests/seir/test_subpopulationstructure.py
+++ b/flepimop/gempyor_pkg/tests/seir/test_subpopulationstructure.py
@@ -135,7 +135,9 @@ def test_subpopulation_structure_mobility_shape_fail():
     temp_file.close()  # Ensure the file is closed
     config.set_file(temp_file.name)  # Load from the temporary file path
-    with pytest.raises(ValueError, match=r"mobility data must have dimensions of length of geodata.*"):
+    with pytest.raises(
+        ValueError, match=r"mobility data must have dimensions of length of geodata.*"
+    ):
         subpop_struct = subpopulation_structure.SubpopulationStructure(
             setup_name=TEST_SETUP_NAME, subpop_config=config["subpop_setup"]
         )
@@ -155,7 +157,9 @@ def test_subpopulation_structure_mobility_fluxes_same_ori_and_dest_fail():
     temp_file.close()  # Ensure the file is closed
     config.set_file(temp_file.name)  # Load from the temporary file path
-    with pytest.raises(ValueError, match=r"Mobility fluxes with same origin and destination.*"):
+    with pytest.raises(
+        ValueError, match=r"Mobility fluxes with same origin and destination.*"
+    ):
         subpop_struct = subpopulation_structure.SubpopulationStructure(
             setup_name=TEST_SETUP_NAME, subpop_config=config["subpop_setup"]
         )
@@ -175,7 +179,9 @@ def test_subpopulation_structure_mobility_npz_shape_fail():
     temp_file.close()  # Ensure the file is closed
     config.set_file(temp_file.name)  # Load from the temporary file path
-    with pytest.raises(ValueError, match=r"mobility data must have dimensions of length of geodata.*"):
+    with pytest.raises(
+        ValueError, match=r"mobility data must have dimensions of length of geodata.*"
+    ):
         subpop_struct = subpopulation_structure.SubpopulationStructure(
             setup_name=TEST_SETUP_NAME, subpop_config=config["subpop_setup"]
         )
@@ -216,7 +222,8 @@ def test_subpopulation_structure_mobility_exceed_source_node_pop_fail():
     config.set_file(temp_file.name)  # Load from the temporary file path
     with pytest.raises(
-        ValueError, match=r"The following entries in the mobility data exceed the source subpop populations.*"
+        ValueError,
+        match=r"The following entries in the mobility data exceed the source subpop populations.*",
     ):
         subpop_struct = subpopulation_structure.SubpopulationStructure(
             setup_name=TEST_SETUP_NAME, subpop_config=config["subpop_setup"]
@@ -238,7 +245,8 @@ def test_subpopulation_structure_mobility_rows_exceed_source_node_pop_fail():
     config.set_file(temp_file.name)  # Load from the temporary file path
     with pytest.raises(
-        ValueError, match=r"The following entries in the mobility data exceed the source subpop populations.*"
+        ValueError,
+        match=r"The following entries in the mobility data exceed the source subpop populations.*",
     ):
         subpop_struct = subpopulation_structure.SubpopulationStructure(
             setup_name=TEST_SETUP_NAME, subpop_config=config["subpop_setup"]
diff --git a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py
index e18861e9e..4843986c1 100644
--- a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py
+++ b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py
@@ -445,9 +445,7 @@ def test_apply_transforms(self, factory: Callable[[], MockStatisticInput]) -> No
        )
         if (scale_func := mock_inputs.config.get("scale")) is not None:
             # Scale config
-            expected_transformed_data = getattr(np, scale_func)(
-                expected_transformed_data
-            )
+            expected_transformed_data = getattr(np, scale_func)(expected_transformed_data)
         assert transformed_data.identical(expected_transformed_data)
     @pytest.mark.parametrize("factory", all_valid_factories)
@@ -464,8 +462,7 @@ def test_llik(self, factory: Callable[[], MockStatisticInput]) -> None:
         assert isinstance(log_likelihood, xr.DataArray)
         assert (
-            log_likelihood.dims
-            == mock_inputs.gt_data[mock_inputs.config["data_var"]].dims
+            log_likelihood.dims == mock_inputs.gt_data[mock_inputs.config["data_var"]].dims
         )
         assert log_likelihood.coords.identical(
             mock_inputs.gt_data[mock_inputs.config["data_var"]].coords
@@ -523,9 +520,7 @@ def test_compute_logloss_data_misshape_value_error(
         mock_inputs = factory()
         statistic = mock_inputs.create_statistic_instance()
-        model_rows, model_cols = mock_inputs.model_data[
-            mock_inputs.config["sim_var"]
-        ].shape
+        model_rows, model_cols = mock_inputs.model_data[mock_inputs.config["sim_var"]].shape
         gt_rows, gt_cols = mock_inputs.gt_data[mock_inputs.config["data_var"]].shape
         expected_match = (
             rf"^{mock_inputs.name} Statistic error\: data and groundtruth do not have "
diff --git a/flepimop/gempyor_pkg/tests/utils/test_get_log_normal.py b/flepimop/gempyor_pkg/tests/utils/test_get_log_normal.py
index 367a7f550..b12989ec4 100644
--- a/flepimop/gempyor_pkg/tests/utils/test_get_log_normal.py
+++ b/flepimop/gempyor_pkg/tests/utils/test_get_log_normal.py
@@ -7,7 +7,7 @@ class TestGetLogNormal:
     """Unit tests for the `gempyor.utils.get_log_normal` function."""
-    
+
     @pytest.mark.parametrize(
         "meanlog,sdlog",
         [
@@ -22,15 +22,15 @@ class TestGetLogNormal:
         ],
     )
     def test_construct_distribution(
-        self, 
-        meanlog: float | int, 
-        sdlog: float | int, 
+        self,
+        meanlog: float | int,
+        sdlog: float | int,
     ) -> None:
         """Test the construction of a log normal distribution.
 
-        This test checks whether the `get_log_normal` function correctly constructs 
-        a log normal distribution with the specified parameters. It verifies that 
-        the returned object is an instance of `rv_frozen`, and that its support and 
+        This test checks whether the `get_log_normal` function correctly constructs
+        a log normal distribution with the specified parameters. It verifies that
+        the returned object is an instance of `rv_frozen`, and that its support and
         parameters (log mean and log standard deviation) are correctly set.
 
         Args:
diff --git a/flepimop/gempyor_pkg/tests/utils/test_get_truncated_normal.py b/flepimop/gempyor_pkg/tests/utils/test_get_truncated_normal.py
index 23e4fad58..c3ccbca79 100644
--- a/flepimop/gempyor_pkg/tests/utils/test_get_truncated_normal.py
+++ b/flepimop/gempyor_pkg/tests/utils/test_get_truncated_normal.py
@@ -7,7 +7,7 @@ class TestGetTruncatedNormal:
     """Unit tests for the `gempyor.utils.get_truncated_normal` function."""
-    
+
     @pytest.mark.parametrize(
         "mean,sd,a,b",
         [
@@ -21,10 +21,10 @@ class TestGetTruncatedNormal:
         ],
     )
     def test_construct_distribution(
-        self, 
-        mean: float | int, 
-        sd: float | int, 
-        a: float | int, 
+        self,
+        mean: float | int,
+        sd: float | int,
+        a: float | int,
         b: float | int,
     ) -> None:
         """Test the construction of a truncated normal distribution.
diff --git a/flepimop/gempyor_pkg/tests/utils/test_random_distribution_sampler.py b/flepimop/gempyor_pkg/tests/utils/test_random_distribution_sampler.py
index 64efc98b1..4ca9c0ac7 100644
--- a/flepimop/gempyor_pkg/tests/utils/test_random_distribution_sampler.py
+++ b/flepimop/gempyor_pkg/tests/utils/test_random_distribution_sampler.py
@@ -51,9 +51,7 @@ def test_binomial_p_value_error(self, p: float) -> None:
     def test_output_validation(self, distribution: str, kwargs: dict[str, Any]) -> None:
         actual = random_distribution_sampler(distribution, **kwargs)
         if distribution == "fixed":
-            expected = partial(
-                np.random.uniform, kwargs.get("value"), kwargs.get("value")
-            )
+            expected = partial(np.random.uniform, kwargs.get("value"), kwargs.get("value"))
             assert partials_are_similar(actual, expected)
         elif distribution == "uniform":
             expected = partial(np.random.uniform, kwargs.get("low"), kwargs.get("high"))
@@ -68,12 +66,12 @@ def test_output_validation(self, distribution: str, kwargs: dict[str, Any]) -> N
             assert inspect.ismethod(actual)
             assert actual.__self__.kwds.get("loc") == kwargs.get("mean")
             assert actual.__self__.kwds.get("scale") == kwargs.get("sd")
-            assert actual.__self__.a == (
-                kwargs.get("a") - kwargs.get("mean")
-            ) / kwargs.get("sd")
-            assert actual.__self__.b == (
-                kwargs.get("b") - kwargs.get("mean")
-            ) / kwargs.get("sd")
+            assert actual.__self__.a == (kwargs.get("a") - kwargs.get("mean")) / kwargs.get(
+                "sd"
+            )
+            assert actual.__self__.b == (kwargs.get("b") - kwargs.get("mean")) / kwargs.get(
+                "sd"
+            )
         elif distribution == "lognorm":
             assert inspect.ismethod(actual)
             assert actual.__self__.kwds.get("s") == kwargs.get("sdlog")
diff --git a/flepimop/gempyor_pkg/tests/utils/test_read_df.py b/flepimop/gempyor_pkg/tests/utils/test_read_df.py
index 7a0a0c581..48f05dbd5 100644
--- a/flepimop/gempyor_pkg/tests/utils/test_read_df.py
+++ b/flepimop/gempyor_pkg/tests/utils/test_read_df.py
@@ -137,8 +137,7 @@ def test_subpop_is_cast_as_str(self) -> None:
             temp_path = Path(temp_file.name)
             assert temp_path.stat().st_size == 0
             assert (
-                self.subpop_df.to_parquet(temp_path, engine="pyarrow", index=False)
-                is None
+                self.subpop_df.to_parquet(temp_path, engine="pyarrow", index=False) is None
             )
             assert temp_path.stat().st_size > 0
             test_df = read_df(fname=temp_path)
diff --git a/flepimop/gempyor_pkg/tests/utils/test_rolling_mean_pad.py b/flepimop/gempyor_pkg/tests/utils/test_rolling_mean_pad.py
index 94be3394a..665f34159 100644
--- a/flepimop/gempyor_pkg/tests/utils/test_rolling_mean_pad.py
+++ b/flepimop/gempyor_pkg/tests/utils/test_rolling_mean_pad.py
@@ -112,9 +112,7 @@ def test_rolling_mean_pad(
         rolling_mean_data = rolling_mean_pad(test_data, window)
         rolling_mean_reference = self._rolling_mean_pad_reference(test_data, window)
         assert rolling_mean_data.shape == expected_shape
-        assert np.isclose(
-            rolling_mean_data, rolling_mean_reference, equal_nan=True
-        ).all()
+        assert np.isclose(rolling_mean_data, rolling_mean_reference, equal_nan=True).all()
     def _rolling_mean_pad_reference(
         self, data: npt.NDArray[np.number], window: int
diff --git a/flepimop/gempyor_pkg/tests/utils/test_utils.py b/flepimop/gempyor_pkg/tests/utils/test_utils.py
index 768451b2a..058e2401b 100644
--- a/flepimop/gempyor_pkg/tests/utils/test_utils.py
+++ b/flepimop/gempyor_pkg/tests/utils/test_utils.py
@@ -10,7 +10,11 @@
 @pytest.mark.parametrize(
-    ("fname", "extension"), [("mobility", "csv"), ("usa-geoid-params-output", "parquet"),],
+    ("fname", "extension"),
+    [
+        ("mobility", "csv"),
+        ("usa-geoid-params-output", "parquet"),
+    ],
 )
 def test_read_df_and_write_success(fname, extension):
     os.chdir(tmp_path)
@@ -29,7 +33,9 @@ def test_read_df_and_write_success(fname, extension):
     assert os.path.isfile(tmp_path + "/data/" + fname + "." + extension)
-@pytest.mark.parametrize(("fname", "extension"), [("mobility", "csv"), ("usa-geoid-params-output", "parquet")])
+@pytest.mark.parametrize(
+    ("fname", "extension"), [("mobility", "csv"), ("usa-geoid-params-output", "parquet")]
+)
 def test_read_df_and_write_fail(fname, extension):
     with pytest.raises(NotImplementedError, match=r".*Invalid.*extension.*Must.*"):
         os.chdir(tmp_path)
@@ -41,7 +47,9 @@ def test_read_df_and_write_fail(fname, extension):
             assert df2.equals(df1)
             utils.write_df(tmp_path + "/data/" + fname, df2, extension="")
         elif extension == "parquet":
-            df2 = pa.parquet.read_table(f"{DATA_DIR}/" + fname + "." + extension).to_pandas()
+            df2 = pa.parquet.read_table(
+                f"{DATA_DIR}/" + fname + "." + extension
+            ).to_pandas()
             assert df2.equals(df1)
             utils.write_df(tmp_path + "/data/" + fname, df2, extension="")
@@ -91,9 +99,7 @@ def test_create_resume_out_filename():
         filetype="spar",
         liketype="global",
     )
-    expected_filename = (
-        "model_output/output/123/spar/global/intermediate/000000002.000000001.000000001.123.spar.parquet"
-    )
+    expected_filename = "model_output/output/123/spar/global/intermediate/000000002.000000001.000000001.123.spar.parquet"
     assert result == expected_filename
     result2 = utils.create_resume_out_filename(
@@ -111,14 +117,22 @@ def test_create_resume_input_filename():
     result = utils.create_resume_input_filename(
-        flepi_slot_index="2", resume_run_index="321", flepi_prefix="output", filetype="spar", liketype="global"
+        flepi_slot_index="2",
+        resume_run_index="321",
+        flepi_prefix="output",
+        filetype="spar",
+        liketype="global",
     )
     expect_filename = "model_output/output/321/spar/global/final/000000002.321.spar.parquet"
     assert result == expect_filename
     result2 = utils.create_resume_input_filename(
-        flepi_slot_index="2", resume_run_index="321", flepi_prefix="output", filetype="seed", liketype="chimeric"
+        flepi_slot_index="2",
+        resume_run_index="321",
+        flepi_prefix="output",
+        filetype="seed",
+        liketype="chimeric",
    )
     expect_filename2 = "model_output/output/321/seed/chimeric/final/000000002.321.seed.csv"
     assert result2 == expect_filename2
@@ -126,17 +140,26 @@ def test_create_resume_input_filename():
 def test_get_filetype_resume_discard_seeding_true_flepi_block_index_1():
     expected_types = ["spar", "snpi", "hpar", "hnpi", "init"]
-    assert utils.get_filetype_for_resume(resume_discard_seeding="true", flepi_block_index="1") == expected_types
+    assert (
+        utils.get_filetype_for_resume(resume_discard_seeding="true", flepi_block_index="1")
+        == expected_types
+    )
 def test_get_filetype_resume_discard_seeding_false_flepi_block_index_1():
     expected_types = ["seed", "spar", "snpi", "hpar", "hnpi", "init"]
-    assert utils.get_filetype_for_resume(resume_discard_seeding="false", flepi_block_index="1") == expected_types
+    assert (
+        utils.get_filetype_for_resume(resume_discard_seeding="false", flepi_block_index="1")
+        == expected_types
+    )
 def test_get_filetype_flepi_block_index_2():
     expected_types = ["seed", "spar", "snpi", "hpar", "hnpi", "host", "llik", "init"]
-    assert utils.get_filetype_for_resume(resume_discard_seeding="false", flepi_block_index="2") == expected_types
+    assert (
+        utils.get_filetype_for_resume(resume_discard_seeding="false", flepi_block_index="2")
+        == expected_types
+    )
 def test_create_resume_file_names_map():
diff --git a/flepimop/gempyor_pkg/tests/utils/test_utils2.py b/flepimop/gempyor_pkg/tests/utils/test_utils2.py
index 4b0ae59ba..0822604ed 100644
--- a/flepimop/gempyor_pkg/tests/utils/test_utils2.py
+++ b/flepimop/gempyor_pkg/tests/utils/test_utils2.py
@@ -26,7 +26,9 @@ class SampleClass:
     def __init__(self):
         self.value = 11
-    @utils.profile(output_file="get_value.prof", sort_by="time", lines_to_print=10, strip_dirs=True)
+    @utils.profile(
+        output_file="get_value.prof", sort_by="time", lines_to_print=10, strip_dirs=True
+    )
     def get_value(self):
         return self.value
@@ -198,7 +200,9 @@ def test_as_random_distribution_binomial_w_fraction_error(config):
 def test_as_random_distribution_truncnorm(config):
-    config.add({"value": {"distribution": "truncnorm", "mean": 0, "sd": 1, "a": -1, "b": 1}})
+    config.add(
+        {"value": {"distribution": "truncnorm", "mean": 0, "sd": 1, "a": -1, "b": 1}}
+    )
     dist = config["value"].as_random_distribution()
     rvs = dist(size=1000)
     assert len(rvs) == 1000
diff --git a/postprocessing/postprocess_auto.py b/postprocessing/postprocess_auto.py
index aaf4a0bff..7cb5edfcf 100644
--- a/postprocessing/postprocess_auto.py
+++ b/postprocessing/postprocess_auto.py
@@ -26,7 +26,9 @@ def __init__(self, run_id, config_filepath=None, folder_path=None):
         self.folder_path = folder_path
-def get_all_filenames(file_type, all_runs, finals_only=False, intermediates_only=False, ignore_chimeric=True) -> dict:
+def get_all_filenames(
+    file_type, all_runs, finals_only=False, intermediates_only=False, ignore_chimeric=True
+) -> dict:
     """
     return dictionanary for each run name
     """
@@ -159,7 +161,14 @@ def slack_multiple_files_v2(slack_token, message, file_list, channel):
     help="Maximum number of files to load for in depth plot and individual sim plot",
 )
 def generate_pdf(
-    config_filepath, run_id, job_name, fs_results_path, slack_token, slack_channel, max_files, max_files_deep
+    config_filepath,
+    run_id,
+    job_name,
+    fs_results_path,
+    slack_token,
+    slack_channel,
+    max_files,
+    max_files_deep,
 ):
     print("Generating plots")
     print(f">> config {config_filepath} for run_id {run_id}")
@@ -217,7 +226,9 @@ def generate_pdf(
     for filename in file_list:
         slot = int(filename.split("/")[-1].split(".")[0])
         block = int(filename.split("/")[-1].split(".")[1])
-        sim_str = filename.split("/")[-1].split(".")[2]  # not necessarily a sim number now
+        sim_str = filename.split("/")[-1].split(".")[
+            2
+        ]  # not necessarily a sim number now
         if sim_str.isdigit():
             sim = int(sim_str)
             if block == 1 and (sim == 1 or sim % 5 == 0):  ## first block, only one
@@ -238,7 +249,9 @@ def generate_pdf(
     # In[23]:
-    fig, axes = plt.subplots(len(node_names) + 1, 4, figsize=(4 * 4, len(node_names) * 3), sharex=True)
+    fig, axes = plt.subplots(
+        len(node_names) + 1, 4, figsize=(4 * 4, len(node_names) * 3), sharex=True
+    )
     colors = ["b", "r", "y", "c"]
     icl = 0
@@ -255,32 +268,60 @@ def generate_pdf(
                 lls = lls.cumsum()
                 feature = "accepts, cumulative"
             axes[idp, ift].fill_between(
-                lls.index, lls.quantile(0.025, axis=1), lls.quantile(0.975, axis=1), alpha=0.1, color=colors[icl]
+                lls.index,
+                lls.quantile(0.025, axis=1),
+                lls.quantile(0.975, axis=1),
+                alpha=0.1,
+                color=colors[icl],
             )
             axes[idp, ift].fill_between(
-                lls.index, lls.quantile(0.25, axis=1), lls.quantile(0.75, axis=1), alpha=0.1, color=colors[icl]
+                lls.index,
+                lls.quantile(0.25, axis=1),
+                lls.quantile(0.75, axis=1),
+                alpha=0.1,
+                color=colors[icl],
+            )
+            axes[idp, ift].plot(
+                lls.index, lls.median(axis=1), marker="o", label=run_id, color=colors[icl]
             )
-            axes[idp, ift].plot(lls.index, lls.median(axis=1), marker="o", label=run_id, color=colors[icl])
             axes[idp, ift].plot(lls.index, lls.iloc[:, 0:max_files_deep], color="k", lw=0.3)
             axes[idp, ift].set_title(f"National, {feature}")
             axes[idp, ift].grid()
     for idp, nn in enumerate(node_names):
         idp = idp + 1
-        all_nn = full_df[full_df["subpop"] == nn][["sim", "slot", "ll", "accept", "accept_avg", "accept_prob"]]
+        all_nn = full_df[full_df["subpop"] == nn][
+            ["sim", "slot", "ll", "accept", "accept_avg", "accept_prob"]
+        ]
         for ift, feature in enumerate(["ll", "accept", "accept_avg", "accept_prob"]):
             lls = all_nn.pivot(index="sim", columns="slot", values=feature)
             if feature == "accept":
                lls = lls.cumsum()
                feature = "accepts, cumulative"
            axes[idp, ift].fill_between(
-                lls.index, lls.quantile(0.025, axis=1), lls.quantile(0.975, axis=1), alpha=0.1, color=colors[icl]
+                lls.index,
+                lls.quantile(0.025, axis=1),
+                lls.quantile(0.975, axis=1),
+                alpha=0.1,
+                color=colors[icl],
            )
            axes[idp, ift].fill_between(
-                lls.index, lls.quantile(0.25, axis=1), lls.quantile(0.75, axis=1), alpha=0.1, color=colors[icl]
+                lls.index,
+                lls.quantile(0.25, axis=1),
+                lls.quantile(0.75, axis=1),
+                alpha=0.1,
+                color=colors[icl],
+            )
+            axes[idp, ift].plot(
+                lls.index,
+                lls.median(axis=1),
+                marker="o",
+                label=run_id,
+                color=colors[icl],
+            )
+            axes[idp, ift].plot(
+                lls.index, lls.iloc[:, 0:max_files_deep], color="k", lw=0.3
            )
-            axes[idp, ift].plot(lls.index, lls.median(axis=1), marker="o", label=run_id, color=colors[icl])
-            axes[idp, ift].plot(lls.index, lls.iloc[:, 0:max_files_deep], color="k", lw=0.3)
            axes[idp, ift].set_title(f"{nn}, {feature}")
            axes[idp, ift].grid()
        if idp == len(node_names) - 1:
@@ -292,8 +333,11 @@ def generate_pdf(
        pass
    import gempyor.utils
-    llik_filenames = gempyor.utils.list_filenames(folder="model_output/", filters=["final", "llik" , ".parquet"])
-    #get_all_filenames("llik", fs_results_path, finals_only=True, intermediates_only=False)
+
+    llik_filenames = gempyor.utils.list_filenames(
+        folder="model_output/", filters=["final", "llik", ".parquet"]
+    )
+    # get_all_filenames("llik", fs_results_path, finals_only=True, intermediates_only=False)
     # In[7]:
     resultST = []
     for filename in llik_filenames:
diff --git a/utilities/clean_s3.py b/utilities/clean_s3.py
index 2998c65c2..08982b2b2 100644
--- a/utilities/clean_s3.py
+++ b/utilities/clean_s3.py
@@ -9,7 +9,9 @@
 s3 = boto3.client("s3")
 paginator = s3.get_paginator("list_objects_v2")
-pages = paginator.paginate(Bucket=bucket, Prefix="", Delimiter="/")  # needs paginator cause more than 1000 files
+pages = paginator.paginate(
+    Bucket=bucket, Prefix="", Delimiter="/"
+)  # needs paginator cause more than 1000 files
 to_prun = []
 # folders:
diff --git a/utilities/copy_for_continuation.py b/utilities/copy_for_continuation.py
index 33d9da40f..05b7a803a 100644
--- a/utilities/copy_for_continuation.py
+++ b/utilities/copy_for_continuation.py
@@ -77,10 +77,14 @@ def detect_old_run_id(fp):
     fn = files[0]
     old_run_id = detect_old_run_id(fn)
     new_name = (
-        fn.replace("seir", "cont").replace(f"{input_folder}/model_output", "model_output").replace(old_run_id, run_id)
+        fn.replace("seir", "cont")
+        .replace(f"{input_folder}/model_output", "model_output")
+        .replace(old_run_id, run_id)
     )
-    print(f"detected old_run_id: {old_run_id} which will be replaced by user provided run_id: {run_id}")
+    print(
+        f"detected old_run_id: {old_run_id} which will be replaced by user provided run_id: {run_id}"
+    )
     empty_str = "°" * len(input_folder)
     print(f"file: \n OLD NAME: {fn}\n NEW NAME: {empty_str}{new_name}")
     for fn in tqdm.tqdm(files):
diff --git a/utilities/prune_by_llik.py b/utilities/prune_by_llik.py
index 5b1f3224b..78ef6c8af 100644
--- a/utilities/prune_by_llik.py
+++ b/utilities/prune_by_llik.py
@@ -11,7 +11,11 @@
 def get_all_filenames(
-    file_type, fs_results_path="to_prune/", finals_only=False, intermediates_only=True, ignore_chimeric=True
+    file_type,
+    fs_results_path="to_prune/",
+    finals_only=False,
+    intermediates_only=True,
+    ignore_chimeric=True,
 ) -> dict:
     """
     return dictionary for each run name
@@ -113,14 +117,18 @@ def get_all_filenames(
 if fill_missing:
     # Extract the numbers from the filenames
     numbers = [int(os.path.basename(filename).split(".")[0]) for filename in all_files]
-    missing_numbers = [num for num in range(fill_from_min, fill_from_max + 1) if num not in numbers]
+    missing_numbers = [
+        num for num in range(fill_from_min, fill_from_max + 1) if num not in numbers
+    ]
     if missing_numbers:
         missing_filenames = []
         for num in missing_numbers:
             filename = os.path.basename(all_files[0])
             filename_prefix = re.search(r"^.*?(\d+)", filename).group()
             filename_suffix = re.search(r"(\..*?)$", filename).group()
-            missing_filename = os.path.join(os.path.dirname(all_files[0]), f"{num:09d}{filename_suffix}")
+            missing_filename = os.path.join(
+                os.path.dirname(all_files[0]), f"{num:09d}{filename_suffix}"
+            )
             missing_filenames.append(missing_filename)
         print("The missing filenames with full paths are:")
         for missing_filename in missing_filenames:
@@ -143,7 +151,7 @@ def copy_path(src, dst):
 file_types = [
     "llik",
-    #"seed",
+    # "seed",
     "init",
     "snpi",
     "hnpi",
diff --git a/utilities/prune_by_llik_and_proj.py b/utilities/prune_by_llik_and_proj.py
index 53e623224..585473bf7 100644
--- a/utilities/prune_by_llik_and_proj.py
+++ b/utilities/prune_by_llik_and_proj.py
@@ -11,7 +11,11 @@
 def get_all_filenames(
-    file_type, fs_results_path="to_prune/", finals_only=False, intermediates_only=True, ignore_chimeric=True
+    file_type,
+    fs_results_path="to_prune/",
+    finals_only=False,
+    intermediates_only=True,
+    ignore_chimeric=True,
 ) -> dict:
     """
     return dictionary for each run name
@@ -23,7 +27,7 @@ def get_all_filenames(
     l = []
     for f in Path(str(fs_results_path + "model_output")).rglob(f"*.{ext}"):
         f = str(f)
-        
+
         if file_type in f:
             print(f)
             if (
@@ -61,7 +65,9 @@ def get_all_filenames(
 fs_results_path = "to_prune/"
 best_n = 200
-llik_filenames = get_all_filenames("llik", fs_results_path, finals_only=True, intermediates_only=False)
+llik_filenames = get_all_filenames(
+    "llik", fs_results_path, finals_only=True, intermediates_only=False
+)
 # In[7]:
 resultST = []
 for filename in llik_filenames:
@@ -100,7 +106,6 @@ def get_all_filenames(
     print(f" - {slot:4}, llik: {sorted_llik.loc[slot]['ll']:0.3f}")
-
 #### RERUN FROM HERE TO CHANGE THE REGULARIZATION
 files_to_keep = list(full_df.loc[best_slots]["filename"].unique())
 # important to sort by llik
@@ -109,8 +114,9 @@ def get_all_filenames(
 files_to_keep = []
 for fn in all_files:
     if fn in files_to_keep3:
-        outcome_fn = fn.replace("llik", "hosp") 
+        outcome_fn = fn.replace("llik", "hosp")
         import gempyor.utils
+
         outcomes_df = gempyor.utils.read_df(outcome_fn)
         outcomes_df = outcomes_df.set_index("date")
         reg = 1.5
@@ -118,19 +124,27 @@ def get_all_filenames(
         this_bad = 0
         bad_subpops = []
         for sp in outcomes_df["subpop"].unique():
-            max_fit = outcomes_df[outcomes_df["subpop"]==sp]["incidC"][:"2024-04-08"].max()
-            max_summer = outcomes_df[outcomes_df["subpop"]==sp]["incidC"]["2024-04-08":"2024-09-30"].max()
-            if max_summer > max_fit*reg:
+            max_fit = outcomes_df[outcomes_df["subpop"] == sp]["incidC"][
+                :"2024-04-08"
+            ].max()
+            max_summer = outcomes_df[outcomes_df["subpop"] == sp]["incidC"][
+                "2024-04-08":"2024-09-30"
+            ].max()
+            if max_summer > max_fit * reg:
                 this_bad += 1
-                max_reg = max(max_reg, max_summer/max_fit)
+                max_reg = max(max_reg, max_summer / max_fit)
                 bad_subpops.append(sp)
-                #print(f"changing {sp} because max_summer max_summer={max_summer:.1f} > reg*max_fit={max_fit:.1f}, diff {max_fit/max_summer*100:.1f}%")
-                #print(f">>> MULT BY {max_summer/max_fit*mult:2f}")
-                #outcomes_df.loc[outcomes_df["subpop"]==sp, ["incidH", "incidD"]] = outcomes_df.loc[outcomes_df["subpop"]==sp, ["incidH", "incidD"]]*max_summer/max_fit*mult
-        if this_bad>4 or max_reg>4:
-            print(f"{outcome_fn.split('/')[-1].split('.')[0]} >>> BAAD: {this_bad} subpops AND max_ratio={max_reg:.1f}, sp with max_summer > max_fit*{reg} {bad_subpops}")
+                # print(f"changing {sp} because max_summer max_summer={max_summer:.1f} > reg*max_fit={max_fit:.1f}, diff {max_fit/max_summer*100:.1f}%")
+                # print(f">>> MULT BY {max_summer/max_fit*mult:2f}")
+                # outcomes_df.loc[outcomes_df["subpop"]==sp, ["incidH", "incidD"]] = outcomes_df.loc[outcomes_df["subpop"]==sp, ["incidH", "incidD"]]*max_summer/max_fit*mult
+        if this_bad > 4 or max_reg > 4:
+            print(
+                f"{outcome_fn.split('/')[-1].split('.')[0]} >>> BAAD: {this_bad} subpops AND max_ratio={max_reg:.1f}, sp with max_summer > max_fit*{reg} {bad_subpops}"
+            )
         else:
-            print(f"{outcome_fn.split('/')[-1].split('.')[0]} >>> GOOD: {this_bad} subpops AND max_ratio={max_reg:.1f}, sp with max_summer > max_fit*{reg} {bad_subpops}")
+            print(
+                f"{outcome_fn.split('/')[-1].split('.')[0]} >>> GOOD: {this_bad} subpops AND max_ratio={max_reg:.1f}, sp with max_summer > max_fit*{reg} {bad_subpops}"
+            )
             files_to_keep.append(fn)
 print(len(files_to_keep))
 ### END OF CODE
@@ -146,14 +160,18 @@ def get_all_filenames(
 if fill_missing:
     # Extract the numbers from the filenames
     numbers = [int(os.path.basename(filename).split(".")[0]) for filename in all_files]
-    missing_numbers = [num for num in range(fill_from_min, fill_from_max + 1) if num not in numbers]
+    missing_numbers = [
+        num for num in range(fill_from_min, fill_from_max + 1) if num not in numbers
+    ]
     if missing_numbers:
         missing_filenames = []
         for num in missing_numbers:
            filename = os.path.basename(all_files[0])
            filename_prefix = re.search(r"^.*?(\d+)", filename).group()
            filename_suffix = re.search(r"(\..*?)$", filename).group()
-            missing_filename = os.path.join(os.path.dirname(all_files[0]), f"{num:09d}{filename_suffix}")
+            missing_filename = os.path.join(
+                os.path.dirname(all_files[0]), f"{num:09d}{filename_suffix}"
+            )
            missing_filenames.append(missing_filename)
        print("The missing filenames with full paths are:")
        for missing_filename in missing_filenames: