From c8fce036a951b1b47bc09ac2f5edc79d8a63c73b Mon Sep 17 00:00:00 2001 From: fang19911030 Date: Wed, 17 Apr 2024 14:37:38 -0400 Subject: [PATCH 01/33] add function to generate filename for resume files --- flepimop/gempyor_pkg/src/gempyor/utils.py | 21 +++++++++++++++++++ .../gempyor_pkg/tests/utils/test_utils.py | 14 +++++++++++++ 2 files changed, 35 insertions(+) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index 905584cf5..9d515cf17 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -1,3 +1,4 @@ +import os import datetime import functools import numbers @@ -9,6 +10,7 @@ import scipy.stats import sympy.parsing.sympy_parser import logging +from gempyor import file_paths logger = logging.getLogger(__name__) @@ -287,3 +289,22 @@ def bash(command): print("------------") print(f"lsblk: {bash('lsblk')}") print("END AWS DIAGNOSIS ================================") + +def create_resume_out_filename(filetype: str, liketype: str) -> str: + run_id = os.environ.get("FLEPI_RUN_INDEX") + prefix = f"{os.environ.get("FLEPI_PREFIX")}/{os.environ.get("FLEPI_RUN_INDEX")}" + inference_filepath_suffix = f"{liketype}/intermidate" + FLEPI_SLOT_INDEX = int(os.environ.get("FLEPI_SLOT_INDEX")) + inference_filename_prefix='%09d.' % FLEPI_SLOT_INDEX + index='{:09d}.{:09d}'.format(1, int(os.environ.get("FLEPI_BLOCK_INDEX")-1)) + extension = "parquet" + if filetype == "seed": + extension = "csv" + return file_paths.create_file_name(run_id=run_id, + prefix=prefix, + inference_filename_prefix=inference_filename_prefix, + inference_filepath_suffix=inference_filepath_suffix, + index=index, + extension=extension) + + \ No newline at end of file diff --git a/flepimop/gempyor_pkg/tests/utils/test_utils.py b/flepimop/gempyor_pkg/tests/utils/test_utils.py index 694a7296f..1a1353009 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_utils.py +++ b/flepimop/gempyor_pkg/tests/utils/test_utils.py @@ -90,3 +90,17 @@ def test_get_truncated_normal_success(): def test_get_log_normal_success(): utils.get_log_normal(meanlog=0, sdlog=1) + +def test_create_resume_out_filename(): + os.environ["FLEPI_RUN_INDEX"] = "123" + os.environ["FLEPI_PREFIX"] = "prefix" + os.environ["FLEPI_SLOT_INDEX"] = "2" + os.environ["FLEPI_BLOCK_INDEX"] = "2" + + expected_filename = "prefix/123/000000002./intermidate/000000001.000000001.parquet" + assert utils.create_resume_out_filename("spar", "like") == expected_filename + + expected_filename = "prefix/123/000000002./intermidate/000000001.000000001.csv" + assert utils.create_resume_out_filename("seed", "like") == expected_filename + + os.environ.clear() \ No newline at end of file From a63105ea4d148e728061c2d570e35c1855309c1e Mon Sep 17 00:00:00 2001 From: fang19911030 Date: Wed, 24 Apr 2024 10:37:21 -0400 Subject: [PATCH 02/33] fix functions and adding unit tests --- flepimop/gempyor_pkg/src/gempyor/utils.py | 21 +++++++-- .../gempyor_pkg/tests/utils/test_utils.py | 43 +++++++++++++------ 2 files changed, 48 insertions(+), 16 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index 9d515cf17..f28487f0a 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -292,11 +292,11 @@ def bash(command): def create_resume_out_filename(filetype: str, liketype: str) -> str: run_id = os.environ.get("FLEPI_RUN_INDEX") - prefix = f"{os.environ.get("FLEPI_PREFIX")}/{os.environ.get("FLEPI_RUN_INDEX")}" + 
prefix = f"{os.environ.get('FLEPI_PREFIX')}/{os.environ.get('FLEPI_RUN_INDEX')}" inference_filepath_suffix = f"{liketype}/intermidate" FLEPI_SLOT_INDEX = int(os.environ.get("FLEPI_SLOT_INDEX")) inference_filename_prefix='%09d.' % FLEPI_SLOT_INDEX - index='{:09d}.{:09d}'.format(1, int(os.environ.get("FLEPI_BLOCK_INDEX")-1)) + index='{:09d}.{:09d}'.format(1, int(os.environ.get("FLEPI_BLOCK_INDEX"))-1) extension = "parquet" if filetype == "seed": extension = "csv" @@ -305,6 +305,21 @@ def create_resume_out_filename(filetype: str, liketype: str) -> str: inference_filename_prefix=inference_filename_prefix, inference_filepath_suffix=inference_filepath_suffix, index=index, + ftype=filetype, extension=extension) - \ No newline at end of file +def create_resume_input_filename(filetype: str, liketype: str) -> str: + run_id = os.environ.get("RESUME_RUN_INDEX") + prefix = f"{os.environ.get('FLEPI_PREFIX')}/{os.environ.get('RESUME_RUN_INDEX')}" + inference_filepath_suffix = f"{liketype}/final" + index = os.environ.get("FLEPI_SLOT_INDEX") + extension = "parquet" + if filetype == "seed": + extension = "csv" + return file_paths.create_file_name(run_id=run_id, + prefix=prefix, + inference_filepath_suffix=inference_filepath_suffix, + index=index, + ftype=filetype, + extension=extension) + \ No newline at end of file diff --git a/flepimop/gempyor_pkg/tests/utils/test_utils.py b/flepimop/gempyor_pkg/tests/utils/test_utils.py index 1a1353009..83597bbd5 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_utils.py +++ b/flepimop/gempyor_pkg/tests/utils/test_utils.py @@ -1,5 +1,4 @@ import pytest -import datetime import os import pandas as pd @@ -10,8 +9,6 @@ from gempyor import utils DATA_DIR = os.path.dirname(__file__) + "/data" -# os.chdir(os.path.dirname(__file__)) - tmp_path = "/tmp" @@ -91,16 +88,36 @@ def test_get_truncated_normal_success(): def test_get_log_normal_success(): utils.get_log_normal(meanlog=0, sdlog=1) -def test_create_resume_out_filename(): - os.environ["FLEPI_RUN_INDEX"] = "123" - os.environ["FLEPI_PREFIX"] = "prefix" - os.environ["FLEPI_SLOT_INDEX"] = "2" - os.environ["FLEPI_BLOCK_INDEX"] = "2" - expected_filename = "prefix/123/000000002./intermidate/000000001.000000001.parquet" - assert utils.create_resume_out_filename("spar", "like") == expected_filename +@pytest.fixture +def env_vars(monkeypatch): + # Setting environment variables for the test + monkeypatch.setenv("RESUME_RUN_INDEX", "321") + monkeypatch.setenv("FLEPI_PREFIX", "output") + monkeypatch.setenv("FLEPI_SLOT_INDEX", "2") + monkeypatch.setenv("FLEPI_BLOCK_INDEX", "2") + monkeypatch.setenv("FLEPI_RUN_INDEX", "123") + + +def test_create_resume_out_filename(env_vars): + result = utils.create_resume_out_filename("spar", "global") + expected_filename = """model_output/output/123/spar/global/intermidate + /000000002.000000001.000000001.123.spar.parquet""" + assert result == expected_filename + + result2 = utils.create_resume_out_filename("seed", "chimeric") + expected_filename2 = """model_output/output/123/seed/chimeric/intermidate + /000000002.000000001.000000001.123.seed.csv'""" + assert result2 == expected_filename2 + + +def test_create_resume_input_filename(env_vars): - expected_filename = "prefix/123/000000002./intermidate/000000001.000000001.csv" - assert utils.create_resume_out_filename("seed", "like") == expected_filename + result = utils.create_resume_input_filename("spar", "global") + expect_filename = 'model_output/output/321/spar/global/final/000000002.321.spar.parquet' - os.environ.clear() \ No newline at end of file + 
assert result == expect_filename + + result2 = utils.create_resume_input_filename("seed", "chimeric") + expect_filename2 = 'model_output/output/321/seed/chimeric/final/000000002.321.seed.csv' + assert result2 == expect_filename2 From aca80bcabc856cee6937258d180b5b97e37c41a9 Mon Sep 17 00:00:00 2001 From: fang19911030 Date: Wed, 24 Apr 2024 14:38:40 -0400 Subject: [PATCH 03/33] add copy function --- flepimop/gempyor_pkg/src/gempyor/utils.py | 45 ++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index f28487f0a..9af2156c3 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -9,6 +9,8 @@ import pyarrow as pa import scipy.stats import sympy.parsing.sympy_parser +import subprocess +import shutil import logging from gempyor import file_paths @@ -322,4 +324,45 @@ def create_resume_input_filename(filetype: str, liketype: str) -> str: index=index, ftype=filetype, extension=extension) - \ No newline at end of file + + +def copy_file_based_on_last_job_output(): + last_job_output = os.environ.get("LAST_JOB_OUTPUT") + resume_discard_seeding = os.environ.get("RESUME_DISCARD_SEEDING") + parquet_types = ["seed", "spar", "snpi", "hpar", "hnpi", "init"] + if resume_discard_seeding == "true": + parquet_types.remove("seed") + liketypes = ["global", "chimeric"] + file_name_map = dict() + + for filetype in parquet_types: + for liketype in liketypes: + input_file_name = create_resume_input_filename(filetype=filetype, liketype=liketype) + output_file_name = create_resume_out_filename(filetype=filetype, liketype=liketype) + file_name_map[input_file_name] = output_file_name + + if last_job_output.find("s3://") >= 0: + for in_filename in file_name_map: + command = ['aws', 's3', 'cp', '--quiet', last_job_output+"/"+in_filename, file_name_map[in_filename]] + try: + result = subprocess.run(command, check=True, stdout = subprocess.PIPE, + stderr = subprocess.PIPE) + print("Output:", result.stdout.decode()) + except subprocess.CalledProcessError as e: + print("Error: ", e.stderr.decode()) + else: + first_output_filename = next(iter(file_name_map.values())) + output_dir = os.path.dirname(first_output_filename) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + for in_filename in file_name_map: + shutil.copy(os.path.join(last_job_output, in_filename), file_name_map[in_filename]) + + + for in_filename in file_name_map: + output_file_name = file_name_map[in_filename] + parquet_type = [ptype for ptype in parquet_types if ptype in output_file_name] + if os.path.exists(output_file_name): + print(f"Copy successful for file of type {parquet_type} {in_filename}->{output_file_name}") + else: + print(f"Could not copy file of type {parquet_type} {in_filename}->{output_file_name}") \ No newline at end of file From 727169cb79bba3e869676402e9466e0d221b15da Mon Sep 17 00:00:00 2001 From: fang19911030 Date: Mon, 29 Apr 2024 10:22:48 -0400 Subject: [PATCH 04/33] format change --- flepimop/gempyor_pkg/src/gempyor/utils.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index 9af2156c3..3fc810183 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -334,19 +334,19 @@ def copy_file_based_on_last_job_output(): parquet_types.remove("seed") liketypes = ["global", "chimeric"] file_name_map = 
dict() - + for filetype in parquet_types: for liketype in liketypes: input_file_name = create_resume_input_filename(filetype=filetype, liketype=liketype) output_file_name = create_resume_out_filename(filetype=filetype, liketype=liketype) file_name_map[input_file_name] = output_file_name - + if last_job_output.find("s3://") >= 0: for in_filename in file_name_map: command = ['aws', 's3', 'cp', '--quiet', last_job_output+"/"+in_filename, file_name_map[in_filename]] try: - result = subprocess.run(command, check=True, stdout = subprocess.PIPE, - stderr = subprocess.PIPE) + result = subprocess.run(command, check=True, stdout=subprocess.PIPE, + stderr=subprocess.PIPE) print("Output:", result.stdout.decode()) except subprocess.CalledProcessError as e: print("Error: ", e.stderr.decode()) @@ -358,11 +358,10 @@ def copy_file_based_on_last_job_output(): for in_filename in file_name_map: shutil.copy(os.path.join(last_job_output, in_filename), file_name_map[in_filename]) - for in_filename in file_name_map: output_file_name = file_name_map[in_filename] parquet_type = [ptype for ptype in parquet_types if ptype in output_file_name] if os.path.exists(output_file_name): print(f"Copy successful for file of type {parquet_type} {in_filename}->{output_file_name}") else: - print(f"Could not copy file of type {parquet_type} {in_filename}->{output_file_name}") \ No newline at end of file + print(f"Could not copy file of type {parquet_type} {in_filename}->{output_file_name}") From 2d35b8d1c9e4328423207d9f5488a3e97fe780b2 Mon Sep 17 00:00:00 2001 From: fang19911030 Date: Mon, 29 Apr 2024 14:18:14 -0400 Subject: [PATCH 05/33] add new function to set parquest types and add tests for it --- flepimop/gempyor_pkg/src/gempyor/utils.py | 33 ++++++++++++++++--- .../gempyor_pkg/tests/utils/test_utils.py | 27 ++++++++++++--- 2 files changed, 50 insertions(+), 10 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index a8f3094ef..44e42463e 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -13,6 +13,7 @@ import shutil import logging from gempyor import file_paths +from typing import List logger = logging.getLogger(__name__) @@ -385,12 +386,31 @@ def create_resume_input_filename(filetype: str, liketype: str) -> str: extension=extension) -def copy_file_based_on_last_job_output(): - last_job_output = os.environ.get("LAST_JOB_OUTPUT") +def get_parquet_types()-> List[str]: resume_discard_seeding = os.environ.get("RESUME_DISCARD_SEEDING") - parquet_types = ["seed", "spar", "snpi", "hpar", "hnpi", "init"] - if resume_discard_seeding == "true": - parquet_types.remove("seed") + flepi_block_index = os.environ.get("FLEPI_BLOCK_INDEX") + if flepi_block_index == "1": + if resume_discard_seeding == "true": + return ["spar", "snpi", "hpar", "hnpi", "init"] + else: + return ["seed", "spar", "snpi", "hpar", "hnpi", "init"] + else: + return ["seed", "spar", "snpi", "hpar", "hnpi", "host", "llik", "init"] + + +def copy_file_based_on_last_job_output() -> bool: + """ + Copies files based on the last job output. + + This function copies files from the last job output directory to the corresponding output directory + based on the file types and like types. The file names are determined using the `create_resume_input_filename` + and `create_resume_out_filename` functions. + + Returns: + bool: True if all files are successfully copied, False otherwise. 
+ """ + last_job_output = os.environ.get("LAST_JOB_OUTPUT") + parquet_types = get_parquet_types() liketypes = ["global", "chimeric"] file_name_map = dict() @@ -424,3 +444,6 @@ def copy_file_based_on_last_job_output(): print(f"Copy successful for file of type {parquet_type} {in_filename}->{output_file_name}") else: print(f"Could not copy file of type {parquet_type} {in_filename}->{output_file_name}") + return False + + return True diff --git a/flepimop/gempyor_pkg/tests/utils/test_utils.py b/flepimop/gempyor_pkg/tests/utils/test_utils.py index a51a593c1..7f828b103 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_utils.py +++ b/flepimop/gempyor_pkg/tests/utils/test_utils.py @@ -5,7 +5,8 @@ # import dask.dataframe as dd import pyarrow as pa import time - +from typing import List +from unittest.mock import patch from gempyor import utils DATA_DIR = os.path.dirname(__file__) + "/data" @@ -101,13 +102,11 @@ def env_vars(monkeypatch): def test_create_resume_out_filename(env_vars): result = utils.create_resume_out_filename("spar", "global") - expected_filename = """model_output/output/123/spar/global/intermidate - /000000002.000000001.000000001.123.spar.parquet""" + expected_filename = "model_output/output/123/spar/global/intermidate/000000002.000000001.000000001.123.spar.parquet" assert result == expected_filename result2 = utils.create_resume_out_filename("seed", "chimeric") - expected_filename2 = """model_output/output/123/seed/chimeric/intermidate - /000000002.000000001.000000001.123.seed.csv'""" + expected_filename2 = "model_output/output/123/seed/chimeric/intermidate/000000002.000000001.000000001.123.seed.csv" assert result2 == expected_filename2 @@ -121,3 +120,21 @@ def test_create_resume_input_filename(env_vars): result2 = utils.create_resume_input_filename("seed", "chimeric") expect_filename2 = 'model_output/output/321/seed/chimeric/final/000000002.321.seed.csv' assert result2 == expect_filename2 + + +@patch.dict(os.environ, {"RESUME_DISCARD_SEEDING": "true", "FLEPI_BLOCK_INDEX": "1"}) +def test_get_parquet_types_resume_discard_seeding_true_flepi_block_index_1(): + expected_types = ["spar", "snpi", "hpar", "hnpi", "init"] + assert utils.get_parquet_types() == expected_types + + +@patch.dict(os.environ, {"RESUME_DISCARD_SEEDING": "false", "FLEPI_BLOCK_INDEX": "1"}) +def test_get_parquet_types_resume_discard_seeding_false_flepi_block_index_1(): + expected_types = ["seed", "spar", "snpi", "hpar", "hnpi", "init"] + assert utils.get_parquet_types() == expected_types + + +@patch.dict(os.environ, {"FLEPI_BLOCK_INDEX": "2"}) +def test_get_parquet_types_flepi_block_index_2(): + expected_types = ["seed", "spar", "snpi", "hpar", "hnpi", "host", "llik", "init"] + assert utils.get_parquet_types() == expected_types \ No newline at end of file From bb5c191e647a43b5a5fc6fd677e0eb80c070eef2 Mon Sep 17 00:00:00 2001 From: fang19911030 Date: Fri, 10 May 2024 10:53:46 -0400 Subject: [PATCH 06/33] add functions to create resume file name map and download from s3 bucket --- flepimop/gempyor_pkg/src/gempyor/utils.py | 147 +++++++++++++----- .../gempyor_pkg/tests/utils/test_utils.py | 10 +- 2 files changed, 113 insertions(+), 44 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index 44e42463e..6b937cd18 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -12,8 +12,10 @@ import subprocess import shutil import logging +import boto3 from gempyor import file_paths -from typing import List +from 
typing import List, Dict +from botocore.exceptions import ClientError logger = logging.getLogger(__name__) @@ -352,6 +354,7 @@ def bash(command): print(f"lsblk: {bash('lsblk')}") print("END AWS DIAGNOSIS ================================") + def create_resume_out_filename(filetype: str, liketype: str) -> str: run_id = os.environ.get("FLEPI_RUN_INDEX") prefix = f"{os.environ.get('FLEPI_PREFIX')}/{os.environ.get('FLEPI_RUN_INDEX')}" @@ -386,7 +389,16 @@ def create_resume_input_filename(filetype: str, liketype: str) -> str: extension=extension) -def get_parquet_types()-> List[str]: +def get_parquet_types_for_resume() -> List[str]: + """ + Retrieves a list of parquet file types that are relevant for resuming a process based on + specific environment variable settings. This function dynamically determines the list + based on the current operational context given by the environment. + + The function checks two environment variables: + - `RESUME_DISCARD_SEEDING`: Determines whether seeding-related file types should be included. + - `FLEPI_BLOCK_INDEX`: Determines a specific operational mode or block of the process. + """ resume_discard_seeding = os.environ.get("RESUME_DISCARD_SEEDING") flepi_block_index = os.environ.get("FLEPI_BLOCK_INDEX") if flepi_block_index == "1": @@ -398,52 +410,103 @@ def get_parquet_types()-> List[str]: return ["seed", "spar", "snpi", "hpar", "hnpi", "host", "llik", "init"] -def copy_file_based_on_last_job_output() -> bool: +def create_resume_file_names_map() -> Dict[str, str]: """ - Copies files based on the last job output. - - This function copies files from the last job output directory to the corresponding output directory - based on the file types and like types. The file names are determined using the `create_resume_input_filename` - and `create_resume_out_filename` functions. + Generates a mapping of input file names to output file names for a resume process based on + parquet file types and environmental conditions. The function adjusts the file name mappings + based on the operational block index and the location of the last job output. + + The mappings depend on: + - Parquet file types appropriate for resuming a process, as determined by the environment. + - Whether the files are for 'global' or 'chimeric' types, as these liketypes influence the + file naming convention. + - The operational block index ('FLEPI_BLOCK_INDEX'), which can alter the input file names for + block index '1'. + - The presence and value of 'LAST_JOB_OUTPUT' environment variable, which if set to an S3 path, + adjusts the keys in the mapping to be prefixed with this path. Returns: - bool: True if all files are successfully copied, False otherwise. + Dict[str, str]: A dictionary where keys are input file paths and values are corresponding + output file paths. The paths may be modified by the 'LAST_JOB_OUTPUT' if it + is set and points to an S3 location. + + Raises: + No explicit exceptions are raised within the function, but it relies heavily on external + functions and environment variables which if improperly configured could lead to unexpected + behavior. 
""" - last_job_output = os.environ.get("LAST_JOB_OUTPUT") - parquet_types = get_parquet_types() + parquet_types = get_parquet_types_for_resume() + resume_file_name_mapping = dict() liketypes = ["global", "chimeric"] - file_name_map = dict() - for filetype in parquet_types: for liketype in liketypes: - input_file_name = create_resume_input_filename(filetype=filetype, liketype=liketype) output_file_name = create_resume_out_filename(filetype=filetype, liketype=liketype) - file_name_map[input_file_name] = output_file_name - + input_file_name = output_file_name + if os.environ.get("FLEPI_BLOCK_INDEX") == "1": + input_file_name = create_resume_input_filename(filetype=filetype, liketype=liketype) + resume_file_name_mapping[input_file_name] = output_file_name + + last_job_output = os.environ.get("LAST_JOB_OUTPUT") if last_job_output.find("s3://") >= 0: - for in_filename in file_name_map: - command = ['aws', 's3', 'cp', '--quiet', last_job_output+"/"+in_filename, file_name_map[in_filename]] - try: - result = subprocess.run(command, check=True, stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - print("Output:", result.stdout.decode()) - except subprocess.CalledProcessError as e: - print("Error: ", e.stderr.decode()) - else: - first_output_filename = next(iter(file_name_map.values())) - output_dir = os.path.dirname(first_output_filename) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - for in_filename in file_name_map: - shutil.copy(os.path.join(last_job_output, in_filename), file_name_map[in_filename]) - - for in_filename in file_name_map: - output_file_name = file_name_map[in_filename] - parquet_type = [ptype for ptype in parquet_types if ptype in output_file_name] - if os.path.exists(output_file_name): - print(f"Copy successful for file of type {parquet_type} {in_filename}->{output_file_name}") - else: - print(f"Could not copy file of type {parquet_type} {in_filename}->{output_file_name}") - return False - - return True + old_keys = list(resume_file_name_mapping.keys()) + for k in old_keys: + new_key = os.path.join(last_job_output, k) + resume_file_name_mapping[new_key] = resume_file_name_mapping[k] + del resume_file_name_mapping[k] + return resume_file_name_mapping + + +def download_file_from_s3(name_map: Dict[str, str]) -> None: + """ + Downloads files from AWS S3 based on a mapping of S3 URIs to local file paths. The function + checks if the directory for the first output file exists and creates it if necessary. It + then iterates over each S3 URI in the provided mapping, downloads the file to the corresponding + local path, and handles errors if the S3 URI format is incorrect or if the download fails. + + Parameters: + name_map (Dict[str, str]): A dictionary where keys are S3 URIs (strings) and values + are the local file paths (strings) where the files should + be saved. + + Returns: + None: This function does not return a value; its primary effect is the side effect of + downloading files and potentially creating directories. + + Raises: + ValueError: If an S3 URI does not start with 's3://', indicating an invalid format. + ClientError: If an error occurs during the download from S3, such as a permissions issue, + a missing file, or network-related errors. These are caught and logged but not + re-raised, to allow the function to attempt subsequent downloads. 
+ + Examples: + >>> name_map = { + "s3://mybucket/data/file1.txt": "/local/path/to/file1.txt", + "s3://mybucket/data/file2.txt": "/local/path/to/file2.txt" + } + >>> download_file_from_s3(name_map) + # This would download 'file1.txt' and 'file2.txt' from 'mybucket' on S3 to the specified local paths. + + # If an S3 URI is malformed: + >>> name_map = { + "http://wrongurl.com/data/file1.txt": "/local/path/to/file1.txt" + } + >>> download_file_from_s3(name_map) + # This will raise a ValueError indicating the invalid S3 URI format. + """ + s3 = boto3.client('s3') + first_output_filename = next(iter(name_map.values())) + output_dir = os.path.dirname(first_output_filename) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + for s3_uri in name_map: + try: + if s3_uri.startswith('s3://'): + bucket = s3_uri.split('/')[2] + object = s3_uri[len(bucket)+6:] + s3.download_file(bucket, object, name_map[s3_uri]) + else: + raise ValueError(f'Invalid S3 URI format {s3_uri}') + except ClientError as e: + print(f"An error occurred: {e}") + print("Could not download file from s3") \ No newline at end of file diff --git a/flepimop/gempyor_pkg/tests/utils/test_utils.py b/flepimop/gempyor_pkg/tests/utils/test_utils.py index 7f828b103..0d408c90e 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_utils.py +++ b/flepimop/gempyor_pkg/tests/utils/test_utils.py @@ -5,7 +5,6 @@ # import dask.dataframe as dd import pyarrow as pa import time -from typing import List from unittest.mock import patch from gempyor import utils @@ -98,6 +97,7 @@ def env_vars(monkeypatch): monkeypatch.setenv("FLEPI_SLOT_INDEX", "2") monkeypatch.setenv("FLEPI_BLOCK_INDEX", "2") monkeypatch.setenv("FLEPI_RUN_INDEX", "123") + monkeypatch.setenv("LAST_JOB_OUTPUT", "s3://bucket") def test_create_resume_out_filename(env_vars): @@ -137,4 +137,10 @@ def test_get_parquet_types_resume_discard_seeding_false_flepi_block_index_1(): @patch.dict(os.environ, {"FLEPI_BLOCK_INDEX": "2"}) def test_get_parquet_types_flepi_block_index_2(): expected_types = ["seed", "spar", "snpi", "hpar", "hnpi", "host", "llik", "init"] - assert utils.get_parquet_types() == expected_types \ No newline at end of file + assert utils.get_parquet_types() == expected_types + + +def test_create_resume_file_names_map(env_vars): + name_map = utils.create_resume_file_names_map() + for k in name_map: + assert k.find("s3://bucket") >= 0 From 34394a341d5b89e17b29fa189447490adb2c618c Mon Sep 17 00:00:00 2001 From: fang19911030 Date: Tue, 14 May 2024 08:59:56 -0400 Subject: [PATCH 07/33] correct wrong functions name in tests --- flepimop/gempyor_pkg/tests/utils/test_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flepimop/gempyor_pkg/tests/utils/test_utils.py b/flepimop/gempyor_pkg/tests/utils/test_utils.py index 0d408c90e..260a623ac 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_utils.py +++ b/flepimop/gempyor_pkg/tests/utils/test_utils.py @@ -125,19 +125,19 @@ def test_create_resume_input_filename(env_vars): @patch.dict(os.environ, {"RESUME_DISCARD_SEEDING": "true", "FLEPI_BLOCK_INDEX": "1"}) def test_get_parquet_types_resume_discard_seeding_true_flepi_block_index_1(): expected_types = ["spar", "snpi", "hpar", "hnpi", "init"] - assert utils.get_parquet_types() == expected_types + assert utils.get_parquet_types_for_resume() == expected_types @patch.dict(os.environ, {"RESUME_DISCARD_SEEDING": "false", "FLEPI_BLOCK_INDEX": "1"}) def test_get_parquet_types_resume_discard_seeding_false_flepi_block_index_1(): expected_types = ["seed", "spar", 
"snpi", "hpar", "hnpi", "init"] - assert utils.get_parquet_types() == expected_types + assert utils.get_parquet_types_for_resume() == expected_types @patch.dict(os.environ, {"FLEPI_BLOCK_INDEX": "2"}) def test_get_parquet_types_flepi_block_index_2(): expected_types = ["seed", "spar", "snpi", "hpar", "hnpi", "host", "llik", "init"] - assert utils.get_parquet_types() == expected_types + assert utils.get_parquet_types_for_resume() == expected_types def test_create_resume_file_names_map(env_vars): From aab8388b290492a3e9326524e0b5b7c9e3ae83da Mon Sep 17 00:00:00 2001 From: kjsato Date: Tue, 14 May 2024 12:41:47 -0400 Subject: [PATCH 08/33] modified to use a common style in the function 'create_resume_out_filename()' in utils.py --- flepimop/gempyor_pkg/src/gempyor/utils.py | 86 ++++++++++++----------- 1 file changed, 46 insertions(+), 40 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index 6b937cd18..1f4cc55e6 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -359,20 +359,24 @@ def create_resume_out_filename(filetype: str, liketype: str) -> str: run_id = os.environ.get("FLEPI_RUN_INDEX") prefix = f"{os.environ.get('FLEPI_PREFIX')}/{os.environ.get('FLEPI_RUN_INDEX')}" inference_filepath_suffix = f"{liketype}/intermidate" - FLEPI_SLOT_INDEX = int(os.environ.get("FLEPI_SLOT_INDEX")) - inference_filename_prefix='%09d.' % FLEPI_SLOT_INDEX - index='{:09d}.{:09d}'.format(1, int(os.environ.get("FLEPI_BLOCK_INDEX"))-1) + # FLEPI_SLOT_INDEX = int(os.environ.get("FLEPI_SLOT_INDEX")) + # inference_filename_prefix='%09d.' % FLEPI_SLOT_INDEX + inference_filename_prefix = "{:09d}.".format(int(os.environ.get("FLEPI_SLOT_INDEX"))) + index = "{:09d}.{:09d}".format(1, int(os.environ.get("FLEPI_BLOCK_INDEX")) - 1) extension = "parquet" if filetype == "seed": extension = "csv" - return file_paths.create_file_name(run_id=run_id, - prefix=prefix, - inference_filename_prefix=inference_filename_prefix, - inference_filepath_suffix=inference_filepath_suffix, - index=index, - ftype=filetype, - extension=extension) - + return file_paths.create_file_name( + run_id=run_id, + prefix=prefix, + inference_filename_prefix=inference_filename_prefix, + inference_filepath_suffix=inference_filepath_suffix, + index=index, + ftype=filetype, + extension=extension, + ) + + def create_resume_input_filename(filetype: str, liketype: str) -> str: run_id = os.environ.get("RESUME_RUN_INDEX") prefix = f"{os.environ.get('FLEPI_PREFIX')}/{os.environ.get('RESUME_RUN_INDEX')}" @@ -381,18 +385,20 @@ def create_resume_input_filename(filetype: str, liketype: str) -> str: extension = "parquet" if filetype == "seed": extension = "csv" - return file_paths.create_file_name(run_id=run_id, - prefix=prefix, - inference_filepath_suffix=inference_filepath_suffix, - index=index, - ftype=filetype, - extension=extension) + return file_paths.create_file_name( + run_id=run_id, + prefix=prefix, + inference_filepath_suffix=inference_filepath_suffix, + index=index, + ftype=filetype, + extension=extension, + ) def get_parquet_types_for_resume() -> List[str]: """ - Retrieves a list of parquet file types that are relevant for resuming a process based on - specific environment variable settings. This function dynamically determines the list + Retrieves a list of parquet file types that are relevant for resuming a process based on + specific environment variable settings. 
This function dynamically determines the list based on the current operational context given by the environment. The function checks two environment variables: @@ -408,7 +414,7 @@ def get_parquet_types_for_resume() -> List[str]: return ["seed", "spar", "snpi", "hpar", "hnpi", "init"] else: return ["seed", "spar", "snpi", "hpar", "hnpi", "host", "llik", "init"] - + def create_resume_file_names_map() -> Dict[str, str]: """ @@ -418,21 +424,21 @@ def create_resume_file_names_map() -> Dict[str, str]: The mappings depend on: - Parquet file types appropriate for resuming a process, as determined by the environment. - - Whether the files are for 'global' or 'chimeric' types, as these liketypes influence the + - Whether the files are for 'global' or 'chimeric' types, as these liketypes influence the file naming convention. - - The operational block index ('FLEPI_BLOCK_INDEX'), which can alter the input file names for + - The operational block index ('FLEPI_BLOCK_INDEX'), which can alter the input file names for block index '1'. - The presence and value of 'LAST_JOB_OUTPUT' environment variable, which if set to an S3 path, adjusts the keys in the mapping to be prefixed with this path. Returns: - Dict[str, str]: A dictionary where keys are input file paths and values are corresponding - output file paths. The paths may be modified by the 'LAST_JOB_OUTPUT' if it + Dict[str, str]: A dictionary where keys are input file paths and values are corresponding + output file paths. The paths may be modified by the 'LAST_JOB_OUTPUT' if it is set and points to an S3 location. Raises: - No explicit exceptions are raised within the function, but it relies heavily on external - functions and environment variables which if improperly configured could lead to unexpected + No explicit exceptions are raised within the function, but it relies heavily on external + functions and environment variables which if improperly configured could lead to unexpected behavior. """ parquet_types = get_parquet_types_for_resume() @@ -445,7 +451,7 @@ def create_resume_file_names_map() -> Dict[str, str]: if os.environ.get("FLEPI_BLOCK_INDEX") == "1": input_file_name = create_resume_input_filename(filetype=filetype, liketype=liketype) resume_file_name_mapping[input_file_name] = output_file_name - + last_job_output = os.environ.get("LAST_JOB_OUTPUT") if last_job_output.find("s3://") >= 0: old_keys = list(resume_file_name_mapping.keys()) @@ -458,24 +464,24 @@ def create_resume_file_names_map() -> Dict[str, str]: def download_file_from_s3(name_map: Dict[str, str]) -> None: """ - Downloads files from AWS S3 based on a mapping of S3 URIs to local file paths. The function - checks if the directory for the first output file exists and creates it if necessary. It - then iterates over each S3 URI in the provided mapping, downloads the file to the corresponding + Downloads files from AWS S3 based on a mapping of S3 URIs to local file paths. The function + checks if the directory for the first output file exists and creates it if necessary. It + then iterates over each S3 URI in the provided mapping, downloads the file to the corresponding local path, and handles errors if the S3 URI format is incorrect or if the download fails. Parameters: - name_map (Dict[str, str]): A dictionary where keys are S3 URIs (strings) and values - are the local file paths (strings) where the files should + name_map (Dict[str, str]): A dictionary where keys are S3 URIs (strings) and values + are the local file paths (strings) where the files should be saved. 
Returns: - None: This function does not return a value; its primary effect is the side effect of + None: This function does not return a value; its primary effect is the side effect of downloading files and potentially creating directories. Raises: ValueError: If an S3 URI does not start with 's3://', indicating an invalid format. ClientError: If an error occurs during the download from S3, such as a permissions issue, - a missing file, or network-related errors. These are caught and logged but not + a missing file, or network-related errors. These are caught and logged but not re-raised, to allow the function to attempt subsequent downloads. Examples: @@ -493,7 +499,7 @@ def download_file_from_s3(name_map: Dict[str, str]) -> None: >>> download_file_from_s3(name_map) # This will raise a ValueError indicating the invalid S3 URI format. """ - s3 = boto3.client('s3') + s3 = boto3.client("s3") first_output_filename = next(iter(name_map.values())) output_dir = os.path.dirname(first_output_filename) if not os.path.exists(output_dir): @@ -501,12 +507,12 @@ def download_file_from_s3(name_map: Dict[str, str]) -> None: for s3_uri in name_map: try: - if s3_uri.startswith('s3://'): - bucket = s3_uri.split('/')[2] - object = s3_uri[len(bucket)+6:] + if s3_uri.startswith("s3://"): + bucket = s3_uri.split("/")[2] + object = s3_uri[len(bucket) + 6 :] s3.download_file(bucket, object, name_map[s3_uri]) else: - raise ValueError(f'Invalid S3 URI format {s3_uri}') + raise ValueError(f"Invalid S3 URI format {s3_uri}") except ClientError as e: print(f"An error occurred: {e}") - print("Could not download file from s3") \ No newline at end of file + print("Could not download file from s3") From dc44991003f33f14f7c39bf4cfe129820f628b95 Mon Sep 17 00:00:00 2001 From: fang19911030 Date: Wed, 15 May 2024 09:46:29 -0400 Subject: [PATCH 09/33] address requested changes by koji --- flepimop/gempyor_pkg/src/gempyor/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index 1f4cc55e6..6c2aeb18d 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -359,8 +359,6 @@ def create_resume_out_filename(filetype: str, liketype: str) -> str: run_id = os.environ.get("FLEPI_RUN_INDEX") prefix = f"{os.environ.get('FLEPI_PREFIX')}/{os.environ.get('FLEPI_RUN_INDEX')}" inference_filepath_suffix = f"{liketype}/intermidate" - # FLEPI_SLOT_INDEX = int(os.environ.get("FLEPI_SLOT_INDEX")) - # inference_filename_prefix='%09d.' 
% FLEPI_SLOT_INDEX inference_filename_prefix = "{:09d}.".format(int(os.environ.get("FLEPI_SLOT_INDEX"))) index = "{:09d}.{:09d}".format(1, int(os.environ.get("FLEPI_BLOCK_INDEX")) - 1) extension = "parquet" From 5498298cf3b0d69f18c99b5bf1b0af730d89ce8a Mon Sep 17 00:00:00 2001 From: fang19911030 Date: Mon, 17 Jun 2024 10:49:40 -0400 Subject: [PATCH 10/33] add functions to copy files at local and get rid of the reading of environment variables --- flepimop/gempyor_pkg/src/gempyor/utils.py | 84 +++++++++++++------ .../gempyor_pkg/tests/utils/test_utils.py | 54 ++++++++---- 2 files changed, 97 insertions(+), 41 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index 6c2aeb18d..ab0e4302f 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -355,17 +355,21 @@ def bash(command): print("END AWS DIAGNOSIS ================================") -def create_resume_out_filename(filetype: str, liketype: str) -> str: - run_id = os.environ.get("FLEPI_RUN_INDEX") - prefix = f"{os.environ.get('FLEPI_PREFIX')}/{os.environ.get('FLEPI_RUN_INDEX')}" - inference_filepath_suffix = f"{liketype}/intermidate" - inference_filename_prefix = "{:09d}.".format(int(os.environ.get("FLEPI_SLOT_INDEX"))) - index = "{:09d}.{:09d}".format(1, int(os.environ.get("FLEPI_BLOCK_INDEX")) - 1) +def create_resume_out_filename(flepi_run_index: str, + flepi_prefix: str, + flepi_slot_index: str, + flepi_block_index: str, + filetype: str, + liketype: str) -> str: + prefix = f"{flepi_prefix}/{flepi_run_index}" + inference_filepath_suffix = f"{liketype}/intermediate" + inference_filename_prefix = "{:09d}.".format(int(flepi_slot_index)) + index = "{:09d}.{:09d}".format(1, int(flepi_block_index) - 1) extension = "parquet" if filetype == "seed": extension = "csv" return file_paths.create_file_name( - run_id=run_id, + run_id=flepi_run_index, prefix=prefix, inference_filename_prefix=inference_filename_prefix, inference_filepath_suffix=inference_filepath_suffix, @@ -375,16 +379,15 @@ def create_resume_out_filename(filetype: str, liketype: str) -> str: ) -def create_resume_input_filename(filetype: str, liketype: str) -> str: - run_id = os.environ.get("RESUME_RUN_INDEX") - prefix = f"{os.environ.get('FLEPI_PREFIX')}/{os.environ.get('RESUME_RUN_INDEX')}" +def create_resume_input_filename(resume_run_index: str, flepi_prefix: str, flepi_slot_index: str, filetype: str, liketype: str) -> str: + prefix = f"{flepi_prefix}/{resume_run_index}" inference_filepath_suffix = f"{liketype}/final" - index = os.environ.get("FLEPI_SLOT_INDEX") + index = flepi_slot_index extension = "parquet" if filetype == "seed": extension = "csv" return file_paths.create_file_name( - run_id=run_id, + run_id=resume_run_index, prefix=prefix, inference_filepath_suffix=inference_filepath_suffix, index=index, @@ -393,18 +396,16 @@ def create_resume_input_filename(filetype: str, liketype: str) -> str: ) -def get_parquet_types_for_resume() -> List[str]: +def get_filetype_for_resume(resume_discard_seeding: str, flepi_block_index: str) -> List[str]: """ Retrieves a list of parquet file types that are relevant for resuming a process based on specific environment variable settings. This function dynamically determines the list based on the current operational context given by the environment. The function checks two environment variables: - - `RESUME_DISCARD_SEEDING`: Determines whether seeding-related file types should be included. 
- - `FLEPI_BLOCK_INDEX`: Determines a specific operational mode or block of the process. + - `resume_discard_seeding`: Determines whether seeding-related file types should be included. + - `flepi_block_index`: Determines a specific operational mode or block of the process. """ - resume_discard_seeding = os.environ.get("RESUME_DISCARD_SEEDING") - flepi_block_index = os.environ.get("FLEPI_BLOCK_INDEX") if flepi_block_index == "1": if resume_discard_seeding == "true": return ["spar", "snpi", "hpar", "hnpi", "init"] @@ -414,7 +415,14 @@ def get_parquet_types_for_resume() -> List[str]: return ["seed", "spar", "snpi", "hpar", "hnpi", "host", "llik", "init"] -def create_resume_file_names_map() -> Dict[str, str]: +def create_resume_file_names_map(resume_discard_seeding, + flepi_block_index, + resume_run_index, + flepi_prefix, + flepi_slot_index, + flepi_run_index, + last_job_output + ) -> Dict[str, str]: """ Generates a mapping of input file names to output file names for a resume process based on parquet file types and environmental conditions. The function adjusts the file name mappings @@ -439,18 +447,26 @@ def create_resume_file_names_map() -> Dict[str, str]: functions and environment variables which if improperly configured could lead to unexpected behavior. """ - parquet_types = get_parquet_types_for_resume() + file_types = get_filetype_for_resume(resume_discard_seeding=resume_discard_seeding, + flepi_block_index=flepi_block_index) resume_file_name_mapping = dict() liketypes = ["global", "chimeric"] - for filetype in parquet_types: + for filetype in file_types: for liketype in liketypes: - output_file_name = create_resume_out_filename(filetype=filetype, liketype=liketype) + output_file_name = create_resume_out_filename(flepi_run_index=flepi_run_index, + flepi_prefix=flepi_prefix, + flepi_slot_index=flepi_slot_index, + flepi_block_index=flepi_block_index, + filetype=filetype, + liketype=liketype) input_file_name = output_file_name if os.environ.get("FLEPI_BLOCK_INDEX") == "1": - input_file_name = create_resume_input_filename(filetype=filetype, liketype=liketype) + input_file_name = create_resume_input_filename(resume_run_index=resume_run_index, + flepi_prefix=flepi_prefix, + flepi_slot_index=flepi_slot_index, + filetype=filetype, + liketype=liketype) resume_file_name_mapping[input_file_name] = output_file_name - - last_job_output = os.environ.get("LAST_JOB_OUTPUT") if last_job_output.find("s3://") >= 0: old_keys = list(resume_file_name_mapping.keys()) for k in old_keys: @@ -514,3 +530,23 @@ def download_file_from_s3(name_map: Dict[str, str]) -> None: except ClientError as e: print(f"An error occurred: {e}") print("Could not download file from s3") + +def move_file_at_local(name_map: Dict[str, str]) -> None: + """ + Moves files locally according to a given mapping. + + This function takes a dictionary where the keys are source file paths and + the values are destination file paths. It ensures that the destination + directories exist and then copies the files from the source paths to the + destination paths. + + Parameters: + name_map (Dict[str, str]): A dictionary mapping source file paths to + destination file paths. 
+ + Returns: + None + """ + for src, dst in name_map.items(): + os.path.makedirs(os.path.dirname(dst), exist_ok = True) + shutil.copy(src, dst) \ No newline at end of file diff --git a/flepimop/gempyor_pkg/tests/utils/test_utils.py b/flepimop/gempyor_pkg/tests/utils/test_utils.py index 260a623ac..3743df4f3 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_utils.py +++ b/flepimop/gempyor_pkg/tests/utils/test_utils.py @@ -101,46 +101,66 @@ def env_vars(monkeypatch): def test_create_resume_out_filename(env_vars): - result = utils.create_resume_out_filename("spar", "global") - expected_filename = "model_output/output/123/spar/global/intermidate/000000002.000000001.000000001.123.spar.parquet" + result = utils.create_resume_out_filename(flepi_run_index="123", + flepi_prefix="output", + flepi_slot_index="2", + flepi_block_index="2", + filetype = "spar", + liketype = "global") + expected_filename = "model_output/output/123/spar/global/intermediate/000000002.000000001.000000001.123.spar.parquet" assert result == expected_filename - result2 = utils.create_resume_out_filename("seed", "chimeric") - expected_filename2 = "model_output/output/123/seed/chimeric/intermidate/000000002.000000001.000000001.123.seed.csv" + result2 = utils.create_resume_out_filename(flepi_run_index="123", + flepi_prefix="output", + flepi_slot_index="2", + flepi_block_index="2", + filetype = "seed", + liketype = "chimeric") + expected_filename2 = "model_output/output/123/seed/chimeric/intermediate/000000002.000000001.000000001.123.seed.csv" assert result2 == expected_filename2 def test_create_resume_input_filename(env_vars): - result = utils.create_resume_input_filename("spar", "global") + result = utils.create_resume_input_filename(flepi_slot_index="2", + resume_run_index="321", + flepi_prefix="output", + filetype="spar", + liketype="global") expect_filename = 'model_output/output/321/spar/global/final/000000002.321.spar.parquet' assert result == expect_filename - result2 = utils.create_resume_input_filename("seed", "chimeric") + result2 = utils.create_resume_input_filename(flepi_slot_index="2", + resume_run_index="321", + flepi_prefix="output", + filetype="seed", liketype="chimeric") expect_filename2 = 'model_output/output/321/seed/chimeric/final/000000002.321.seed.csv' assert result2 == expect_filename2 -@patch.dict(os.environ, {"RESUME_DISCARD_SEEDING": "true", "FLEPI_BLOCK_INDEX": "1"}) -def test_get_parquet_types_resume_discard_seeding_true_flepi_block_index_1(): +def test_get_filetype_resume_discard_seeding_true_flepi_block_index_1(): expected_types = ["spar", "snpi", "hpar", "hnpi", "init"] - assert utils.get_parquet_types_for_resume() == expected_types + assert utils.get_filetype_for_resume(resume_discard_seeding="true", flepi_block_index="1") == expected_types -@patch.dict(os.environ, {"RESUME_DISCARD_SEEDING": "false", "FLEPI_BLOCK_INDEX": "1"}) -def test_get_parquet_types_resume_discard_seeding_false_flepi_block_index_1(): +def test_get_filetype_resume_discard_seeding_false_flepi_block_index_1(): expected_types = ["seed", "spar", "snpi", "hpar", "hnpi", "init"] - assert utils.get_parquet_types_for_resume() == expected_types + assert utils.get_filetype_for_resume(resume_discard_seeding="false", flepi_block_index="1") == expected_types -@patch.dict(os.environ, {"FLEPI_BLOCK_INDEX": "2"}) -def test_get_parquet_types_flepi_block_index_2(): +def test_get_filetype_flepi_block_index_2(): expected_types = ["seed", "spar", "snpi", "hpar", "hnpi", "host", "llik", "init"] - assert utils.get_parquet_types_for_resume() == 
expected_types + assert utils.get_filetype_for_resume(resume_discard_seeding="false", flepi_block_index="2") == expected_types -def test_create_resume_file_names_map(env_vars): - name_map = utils.create_resume_file_names_map() +def test_create_resume_file_names_map(): + name_map = utils.create_resume_file_names_map(resume_discard_seeding="false", + flepi_block_index="2", + resume_run_index="321", + flepi_prefix="output", + flepi_slot_index="2", + flepi_run_index="123", + last_job_output="s3://bucket") for k in name_map: assert k.find("s3://bucket") >= 0 From ef0c86447e77fea4ddf95aa8502a4b279d78fea0 Mon Sep 17 00:00:00 2001 From: fang19911030 Date: Tue, 18 Jun 2024 09:26:31 -0400 Subject: [PATCH 11/33] remove unnecessary test code --- flepimop/gempyor_pkg/tests/utils/test_utils.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/flepimop/gempyor_pkg/tests/utils/test_utils.py b/flepimop/gempyor_pkg/tests/utils/test_utils.py index 3743df4f3..254e8290f 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_utils.py +++ b/flepimop/gempyor_pkg/tests/utils/test_utils.py @@ -1,11 +1,8 @@ import pytest import os import pandas as pd - -# import dask.dataframe as dd import pyarrow as pa import time -from unittest.mock import patch from gempyor import utils DATA_DIR = os.path.dirname(__file__) + "/data" @@ -89,17 +86,6 @@ def test_get_log_normal_success(): utils.get_log_normal(meanlog=0, sdlog=1) -@pytest.fixture -def env_vars(monkeypatch): - # Setting environment variables for the test - monkeypatch.setenv("RESUME_RUN_INDEX", "321") - monkeypatch.setenv("FLEPI_PREFIX", "output") - monkeypatch.setenv("FLEPI_SLOT_INDEX", "2") - monkeypatch.setenv("FLEPI_BLOCK_INDEX", "2") - monkeypatch.setenv("FLEPI_RUN_INDEX", "123") - monkeypatch.setenv("LAST_JOB_OUTPUT", "s3://bucket") - - def test_create_resume_out_filename(env_vars): result = utils.create_resume_out_filename(flepi_run_index="123", flepi_prefix="output", From 413485a1f9ce47c855d06a745ccd7ddde682f241 Mon Sep 17 00:00:00 2001 From: fang19911030 Date: Tue, 18 Jun 2024 09:42:03 -0400 Subject: [PATCH 12/33] format code --- flepimop/gempyor_pkg/src/gempyor/utils.py | 83 +++++++++---------- .../gempyor_pkg/tests/utils/test_utils.py | 77 ++++++++--------- 2 files changed, 79 insertions(+), 81 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index ab0e4302f..4d3209061 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -227,15 +227,11 @@ def as_random_distribution(self): dist = self["distribution"].get() if dist == "fixed": return functools.partial( - np.random.uniform, - self["value"].as_evaled_expression(), - self["value"].as_evaled_expression(), + np.random.uniform, self["value"].as_evaled_expression(), self["value"].as_evaled_expression(), ) elif dist == "uniform": return functools.partial( - np.random.uniform, - self["low"].as_evaled_expression(), - self["high"].as_evaled_expression(), + np.random.uniform, self["low"].as_evaled_expression(), self["high"].as_evaled_expression(), ) elif dist == "poisson": return functools.partial(np.random.poisson, self["lam"].as_evaled_expression()) @@ -260,18 +256,13 @@ def as_random_distribution(self): ).rvs elif dist == "lognorm": return get_log_normal( - meanlog=self["meanlog"].as_evaled_expression(), - sdlog=self["sdlog"].as_evaled_expression(), + meanlog=self["meanlog"].as_evaled_expression(), sdlog=self["sdlog"].as_evaled_expression(), ).rvs else: raise NotImplementedError(f"unknown 
distribution [got: {dist}]") else: # we allow a fixed value specified directly: - return functools.partial( - np.random.uniform, - self.as_evaled_expression(), - self.as_evaled_expression(), - ) + return functools.partial(np.random.uniform, self.as_evaled_expression(), self.as_evaled_expression(),) def list_filenames(folder: str = ".", filters: list = []) -> list: @@ -355,12 +346,9 @@ def bash(command): print("END AWS DIAGNOSIS ================================") -def create_resume_out_filename(flepi_run_index: str, - flepi_prefix: str, - flepi_slot_index: str, - flepi_block_index: str, - filetype: str, - liketype: str) -> str: +def create_resume_out_filename( + flepi_run_index: str, flepi_prefix: str, flepi_slot_index: str, flepi_block_index: str, filetype: str, liketype: str +) -> str: prefix = f"{flepi_prefix}/{flepi_run_index}" inference_filepath_suffix = f"{liketype}/intermediate" inference_filename_prefix = "{:09d}.".format(int(flepi_slot_index)) @@ -379,7 +367,9 @@ def create_resume_out_filename(flepi_run_index: str, ) -def create_resume_input_filename(resume_run_index: str, flepi_prefix: str, flepi_slot_index: str, filetype: str, liketype: str) -> str: +def create_resume_input_filename( + resume_run_index: str, flepi_prefix: str, flepi_slot_index: str, filetype: str, liketype: str +) -> str: prefix = f"{flepi_prefix}/{resume_run_index}" inference_filepath_suffix = f"{liketype}/final" index = flepi_slot_index @@ -415,14 +405,15 @@ def get_filetype_for_resume(resume_discard_seeding: str, flepi_block_index: str) return ["seed", "spar", "snpi", "hpar", "hnpi", "host", "llik", "init"] -def create_resume_file_names_map(resume_discard_seeding, - flepi_block_index, - resume_run_index, - flepi_prefix, - flepi_slot_index, - flepi_run_index, - last_job_output - ) -> Dict[str, str]: +def create_resume_file_names_map( + resume_discard_seeding, + flepi_block_index, + resume_run_index, + flepi_prefix, + flepi_slot_index, + flepi_run_index, + last_job_output, +) -> Dict[str, str]: """ Generates a mapping of input file names to output file names for a resume process based on parquet file types and environmental conditions. The function adjusts the file name mappings @@ -447,25 +438,30 @@ def create_resume_file_names_map(resume_discard_seeding, functions and environment variables which if improperly configured could lead to unexpected behavior. 
""" - file_types = get_filetype_for_resume(resume_discard_seeding=resume_discard_seeding, - flepi_block_index=flepi_block_index) + file_types = get_filetype_for_resume( + resume_discard_seeding=resume_discard_seeding, flepi_block_index=flepi_block_index + ) resume_file_name_mapping = dict() liketypes = ["global", "chimeric"] for filetype in file_types: for liketype in liketypes: - output_file_name = create_resume_out_filename(flepi_run_index=flepi_run_index, - flepi_prefix=flepi_prefix, - flepi_slot_index=flepi_slot_index, - flepi_block_index=flepi_block_index, - filetype=filetype, - liketype=liketype) + output_file_name = create_resume_out_filename( + flepi_run_index=flepi_run_index, + flepi_prefix=flepi_prefix, + flepi_slot_index=flepi_slot_index, + flepi_block_index=flepi_block_index, + filetype=filetype, + liketype=liketype, + ) input_file_name = output_file_name if os.environ.get("FLEPI_BLOCK_INDEX") == "1": - input_file_name = create_resume_input_filename(resume_run_index=resume_run_index, - flepi_prefix=flepi_prefix, - flepi_slot_index=flepi_slot_index, - filetype=filetype, - liketype=liketype) + input_file_name = create_resume_input_filename( + resume_run_index=resume_run_index, + flepi_prefix=flepi_prefix, + flepi_slot_index=flepi_slot_index, + filetype=filetype, + liketype=liketype, + ) resume_file_name_mapping[input_file_name] = output_file_name if last_job_output.find("s3://") >= 0: old_keys = list(resume_file_name_mapping.keys()) @@ -531,6 +527,7 @@ def download_file_from_s3(name_map: Dict[str, str]) -> None: print(f"An error occurred: {e}") print("Could not download file from s3") + def move_file_at_local(name_map: Dict[str, str]) -> None: """ Moves files locally according to a given mapping. @@ -548,5 +545,5 @@ def move_file_at_local(name_map: Dict[str, str]) -> None: None """ for src, dst in name_map.items(): - os.path.makedirs(os.path.dirname(dst), exist_ok = True) - shutil.copy(src, dst) \ No newline at end of file + os.path.makedirs(os.path.dirname(dst), exist_ok=True) + shutil.copy(src, dst) diff --git a/flepimop/gempyor_pkg/tests/utils/test_utils.py b/flepimop/gempyor_pkg/tests/utils/test_utils.py index 254e8290f..d1d3de50d 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_utils.py +++ b/flepimop/gempyor_pkg/tests/utils/test_utils.py @@ -10,11 +10,7 @@ @pytest.mark.parametrize( - ("fname", "extension"), - [ - ("mobility", "csv"), - ("usa-geoid-params-output", "parquet"), - ], + ("fname", "extension"), [("mobility", "csv"), ("usa-geoid-params-output", "parquet"),], ) def test_read_df_and_write_success(fname, extension): os.chdir(tmp_path) @@ -87,41 +83,44 @@ def test_get_log_normal_success(): def test_create_resume_out_filename(env_vars): - result = utils.create_resume_out_filename(flepi_run_index="123", - flepi_prefix="output", - flepi_slot_index="2", - flepi_block_index="2", - filetype = "spar", - liketype = "global") - expected_filename = "model_output/output/123/spar/global/intermediate/000000002.000000001.000000001.123.spar.parquet" + result = utils.create_resume_out_filename( + flepi_run_index="123", + flepi_prefix="output", + flepi_slot_index="2", + flepi_block_index="2", + filetype="spar", + liketype="global", + ) + expected_filename = ( + "model_output/output/123/spar/global/intermediate/000000002.000000001.000000001.123.spar.parquet" + ) assert result == expected_filename - - result2 = utils.create_resume_out_filename(flepi_run_index="123", - flepi_prefix="output", - flepi_slot_index="2", - flepi_block_index="2", - filetype = "seed", - liketype = "chimeric") 
+ + result2 = utils.create_resume_out_filename( + flepi_run_index="123", + flepi_prefix="output", + flepi_slot_index="2", + flepi_block_index="2", + filetype="seed", + liketype="chimeric", + ) expected_filename2 = "model_output/output/123/seed/chimeric/intermediate/000000002.000000001.000000001.123.seed.csv" assert result2 == expected_filename2 def test_create_resume_input_filename(env_vars): - result = utils.create_resume_input_filename(flepi_slot_index="2", - resume_run_index="321", - flepi_prefix="output", - filetype="spar", - liketype="global") - expect_filename = 'model_output/output/321/spar/global/final/000000002.321.spar.parquet' + result = utils.create_resume_input_filename( + flepi_slot_index="2", resume_run_index="321", flepi_prefix="output", filetype="spar", liketype="global" + ) + expect_filename = "model_output/output/321/spar/global/final/000000002.321.spar.parquet" assert result == expect_filename - - result2 = utils.create_resume_input_filename(flepi_slot_index="2", - resume_run_index="321", - flepi_prefix="output", - filetype="seed", liketype="chimeric") - expect_filename2 = 'model_output/output/321/seed/chimeric/final/000000002.321.seed.csv' + + result2 = utils.create_resume_input_filename( + flepi_slot_index="2", resume_run_index="321", flepi_prefix="output", filetype="seed", liketype="chimeric" + ) + expect_filename2 = "model_output/output/321/seed/chimeric/final/000000002.321.seed.csv" assert result2 == expect_filename2 @@ -141,12 +140,14 @@ def test_get_filetype_flepi_block_index_2(): def test_create_resume_file_names_map(): - name_map = utils.create_resume_file_names_map(resume_discard_seeding="false", - flepi_block_index="2", - resume_run_index="321", - flepi_prefix="output", - flepi_slot_index="2", - flepi_run_index="123", - last_job_output="s3://bucket") + name_map = utils.create_resume_file_names_map( + resume_discard_seeding="false", + flepi_block_index="2", + resume_run_index="321", + flepi_prefix="output", + flepi_slot_index="2", + flepi_run_index="123", + last_job_output="s3://bucket", + ) for k in name_map: assert k.find("s3://bucket") >= 0 From 59cc58eccbec86189016fbf57dd48d071c3ef7bd Mon Sep 17 00:00:00 2001 From: fang19911030 Date: Tue, 18 Jun 2024 09:53:14 -0400 Subject: [PATCH 13/33] fix --- flepimop/gempyor_pkg/tests/utils/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flepimop/gempyor_pkg/tests/utils/test_utils.py b/flepimop/gempyor_pkg/tests/utils/test_utils.py index d1d3de50d..768451b2a 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_utils.py +++ b/flepimop/gempyor_pkg/tests/utils/test_utils.py @@ -82,7 +82,7 @@ def test_get_log_normal_success(): utils.get_log_normal(meanlog=0, sdlog=1) -def test_create_resume_out_filename(env_vars): +def test_create_resume_out_filename(): result = utils.create_resume_out_filename( flepi_run_index="123", flepi_prefix="output", @@ -108,7 +108,7 @@ def test_create_resume_out_filename(env_vars): assert result2 == expected_filename2 -def test_create_resume_input_filename(env_vars): +def test_create_resume_input_filename(): result = utils.create_resume_input_filename( flepi_slot_index="2", resume_run_index="321", flepi_prefix="output", filetype="spar", liketype="global" From bb207a203c5619c7f002829da3ba2cf4325a6bb0 Mon Sep 17 00:00:00 2001 From: saraloo <45245630+saraloo@users.noreply.github.com> Date: Mon, 24 Jun 2024 15:05:07 -0400 Subject: [PATCH 14/33] fix errors in aggregation stuff --- .../{config_library => }/config_sample_2pop_inference.yml | 8 ++++---- 
flepimop/main_scripts/inference_slot.R        |  7 ++++---
 2 files changed, 8 insertions(+), 7 deletions(-)
 rename examples/tutorial_two_subpops/{config_library => }/config_sample_2pop_inference.yml (98%)

diff --git a/examples/tutorial_two_subpops/config_library/config_sample_2pop_inference.yml b/examples/tutorial_two_subpops/config_sample_2pop_inference.yml
similarity index 98%
rename from examples/tutorial_two_subpops/config_library/config_sample_2pop_inference.yml
rename to examples/tutorial_two_subpops/config_sample_2pop_inference.yml
index e6f644201..bd3cba6d5 100644
--- a/examples/tutorial_two_subpops/config_library/config_sample_2pop_inference.yml
+++ b/examples/tutorial_two_subpops/config_sample_2pop_inference.yml
@@ -47,7 +47,7 @@ seir:
 
 seir_modifiers:
   scenarios:
-    - Ro_all
+    - inference
   modifiers:
     Ro_mod: # assume same for all subpopulations
       method: SinglePeriodModifier
@@ -86,7 +86,7 @@ seir_modifiers:
           sd: 0.025
           a: -0.1
           b: 0.1
-    Ro_all:
+    inference:
       method: StackedModifier
       modifiers: ["Ro_mod","Ro_lockdown"]
 
@@ -122,10 +122,10 @@ outcomes:
 
 outcome_modifiers:
   scenarios:
-    - test_limits
+    - all
   modifiers:
     # assume that due to limitations in testing, initially the case detection probability was lower
-    test_limits:
+    all:
       method: SinglePeriodModifier
       parameter: incidCase
       subpop: "all"
diff --git a/flepimop/main_scripts/inference_slot.R b/flepimop/main_scripts/inference_slot.R
index 90d289640..031838c03 100644
--- a/flepimop/main_scripts/inference_slot.R
+++ b/flepimop/main_scripts/inference_slot.R
@@ -91,6 +91,7 @@ if (opt$config == ""){
 }
 config = flepicommon::load_config(opt$config)
 
+opt$total_ll_multiplier <- 1
 if (!is.null(config$inference$incl_aggr_likelihood)){
   print("Using config option for `incl_aggr_likelihood`.")
   opt$incl_aggr_likelihood <- config$inference$incl_aggr_likelihood
@@ -450,7 +451,7 @@ for(seir_modifiers_scenario in seir_modifiers_scenarios) {
           autowrite_seir = TRUE
         )
       }, error = function(e) {
-        print("GempyorInference failed to run (call on l. 405 of inference_slot.R).")
+        print("GempyorInference failed to run (call on l. 443 of inference_slot.R).")
         print("Here is all the debug information I could find:")
         for(m in reticulate::py_last_error()) print(m)
         stop("GempyorInference failed to run... stopping")
@@ -637,8 +638,8 @@ for(seir_modifiers_scenario in seir_modifiers_scenarios) {
       sim_hosp <- sim_hosp %>%
         dplyr::bind_rows(
           sim_hosp %>%
-            dplyr::select(-tidyselect::all_of(obs_subpop), -tidyselect::starts_with("date")) %>%
-            dplyr::group_by(time) %>%
+            dplyr::select(-tidyselect::all_of(obs_subpop)) %>%
+            dplyr::group_by(date) %>%
            dplyr::summarise(dplyr::across(tidyselect::everything(), sum)) %>% # no likelihood is calculated for time periods with missing data for any subpop
             dplyr::mutate(!!obs_subpop := "Total")
         )

From 9b99b89b823ada1c7c9a8a06d268e99cd91e15d3 Mon Sep 17 00:00:00 2001
From: Timothy Willard
Date: Mon, 1 Jul 2024 11:46:20 -0400
Subject: [PATCH 15/33] Overhaul write_df unit tests.

Overhauled the gempyor.utils.write_df unit tests by placing them in a new
file with a class grouping similar fixtures. Added tests for the
NotImplementedError, writing to csv, and writing to parquet.
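As a quick illustration of the behavior these tests pin down (a sketch, not
part of the patch; the file names here are hypothetical):

```
import pandas as pd
from gempyor.utils import write_df

df = pd.DataFrame({"subpop": ["06000"], "value": [1.2]})

write_df("values.parquet", df)           # format inferred from the suffix
write_df("values", df, extension="csv")  # explicit extension appended to fname
write_df("values.txt", df)               # raises NotImplementedError
```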
--- .../gempyor_pkg/tests/utils/test_write_df.py | 127 ++++++++++++++++++ 1 file changed, 127 insertions(+) create mode 100644 flepimop/gempyor_pkg/tests/utils/test_write_df.py diff --git a/flepimop/gempyor_pkg/tests/utils/test_write_df.py b/flepimop/gempyor_pkg/tests/utils/test_write_df.py new file mode 100644 index 000000000..d465f6434 --- /dev/null +++ b/flepimop/gempyor_pkg/tests/utils/test_write_df.py @@ -0,0 +1,127 @@ +import os +from tempfile import NamedTemporaryFile +from pathlib import Path +from typing import Callable, Any + +import pytest +import pandas as pd + +from gempyor.utils import write_df + + +class TestWriteDf: + """ + Unit tests for the `gempyor.utils.write_df` function. + """ + + sample_df: pd.DataFrame = pd.DataFrame( + { + "abc": [1, 2, 3, 4, 5], + "def": ["v", "w", "x", "y", "z"], + "ghi": [True, False, False, None, True], + "jkl": [1.2, 3.4, 5.6, 7.8, 9.0], + } + ) + + def test_raises_not_implemented_error(self) -> None: + """ + Tests that write_df raises a NotImplementedError for unsupported file + extensions. + """ + with pytest.raises( + expected_exception=NotImplementedError, + match="Invalid extension txt. Must be 'csv' or 'parquet'", + ) as _: + with NamedTemporaryFile(suffix=".txt") as temp_file: + write_df(fname=temp_file.name, df=self.sample_df) + + @pytest.mark.parametrize( + "fname_transformer,extension", + [ + (lambda x: str(x), ""), + (lambda x: x, ""), + (lambda x: f"{x.parent}/{x.stem}", "csv"), + (lambda x: Path(f"{x.parent}/{x.stem}"), "csv"), + ], + ) + def test_write_csv_dataframe( + self, + fname_transformer: Callable[[os.PathLike], Any], + extension: str, + ) -> None: + """ + Tests writing a DataFrame to a CSV file. + + Args: + fname_transformer: A function that transforms the file name to create the + `fname` arg. + extension: The file extension to use, provided directly to + `gempyor.utils.write_df`. + """ + self._test_write_df( + fname_transformer=fname_transformer, + df=self.sample_df, + extension=extension, + suffix=".csv", + path_reader=lambda x: pd.read_csv(x, index_col=False), + ) + + @pytest.mark.parametrize( + "fname_transformer,extension", + [ + (lambda x: str(x), ""), + (lambda x: x, ""), + (lambda x: f"{x.parent}/{x.stem}", "parquet"), + (lambda x: Path(f"{x.parent}/{x.stem}"), "parquet"), + ], + ) + def test_write_parquet_dataframe( + self, + fname_transformer: Callable[[os.PathLike], Any], + extension: str, + ) -> None: + """ + Tests writing a DataFrame to a Parquet file. + + Args: + fname_transformer: A function that transforms the file name to create the + `fname` arg. + extension: The file extension to use, provided directly to + `gempyor.utils.write_df`. + """ + self._test_write_df( + fname_transformer=fname_transformer, + df=self.sample_df, + extension=extension, + suffix=".parquet", + path_reader=lambda x: pd.read_parquet(x), + ) + + def _test_write_df( + self, + fname_transformer: Callable[[os.PathLike], Any], + df: pd.DataFrame, + extension: str, + suffix: str | None, + path_reader: Callable[[os.PathLike], pd.DataFrame], + ) -> None: + """ + Helper method to test writing a DataFrame to a file. + + Args: + fname_transformer: A function that transforms the file name. + df: The DataFrame to write. + extension: The file extension to use. + suffix: The suffix to use for the temporary file. + path_reader: A function to read the DataFrame from the file. 
+        """
+        with NamedTemporaryFile(suffix=suffix) as temp_file:
+            temp_path = Path(temp_file.name)
+            assert temp_path.stat().st_size == 0
+            assert (
+                write_df(fname=fname_transformer(temp_path), df=df, extension=extension)
+                is None
+            )
+            assert temp_path.stat().st_size > 0
+            test_df = path_reader(temp_path)
+            assert test_df.equals(df)

From 6ae23b34ca3c094ff26371acf12a742816363926 Mon Sep 17 00:00:00 2001
From: Timothy Willard
Date: Mon, 1 Jul 2024 12:19:56 -0400
Subject: [PATCH 16/33] Overhaul read_df unit tests

Overhauled the gempyor.utils.read_df unit tests by placing them in a new
file with a class for grouping similar fixtures. Added tests for the
NotImplementedError, reading from csv, and reading from parquet.
---
 .../gempyor_pkg/tests/utils/test_read_df.py   | 128 ++++++++++++++++++
 1 file changed, 128 insertions(+)
 create mode 100644 flepimop/gempyor_pkg/tests/utils/test_read_df.py

diff --git a/flepimop/gempyor_pkg/tests/utils/test_read_df.py b/flepimop/gempyor_pkg/tests/utils/test_read_df.py
new file mode 100644
index 000000000..d25ed9a59
--- /dev/null
+++ b/flepimop/gempyor_pkg/tests/utils/test_read_df.py
@@ -0,0 +1,128 @@
+import os
+from tempfile import NamedTemporaryFile
+from pathlib import Path
+from typing import Callable, Any
+
+import pytest
+import pandas as pd
+
+from gempyor.utils import read_df
+
+
+class TestReadDf:
+    """
+    Unit tests for the `gempyor.utils.read_df` function.
+    """
+
+    sample_df: pd.DataFrame = pd.DataFrame(
+        {
+            "abc": [1, 2, 3, 4, 5],
+            "def": ["v", "w", "x", "y", "z"],
+            "ghi": [True, False, False, None, True],
+            "jkl": [1.2, 3.4, 5.6, 7.8, 9.0],
+        }
+    )
+
+    def test_raises_not_implemented_error(self) -> None:
+        """
+        Tests that write_df raises a NotImplementedError for unsupported file
+        extensions.
+        """
+        with pytest.raises(
+            expected_exception=NotImplementedError,
+            match="Invalid extension txt. Must be 'csv' or 'parquet'",
+        ) as _:
+            with NamedTemporaryFile(suffix=".txt") as temp_file:
+                read_df(fname=temp_file.name)
+        with pytest.raises(
+            expected_exception=NotImplementedError,
+            match="Invalid extension txt. Must be 'csv' or 'parquet'",
+        ) as _:
+            with NamedTemporaryFile(suffix=".txt") as temp_file:
+                fname = temp_file.name[:-4]
+                read_df(fname=fname, extension="txt")
+
+    @pytest.mark.parametrize(
+        "fname_transformer,extension",
+        [
+            (lambda x: str(x), ""),
+            (lambda x: x, ""),
+            (lambda x: f"{x.parent}/{x.stem}", "csv"),
+            (lambda x: Path(f"{x.parent}/{x.stem}"), "csv"),
+        ],
+    )
+    def test_read_csv_dataframe(
+        self,
+        fname_transformer: Callable[[os.PathLike], Any],
+        extension: str,
+    ) -> None:
+        """
+        Tests reading a DataFrame from a CSV file.
+
+        Args:
+            fname_transformer: A function that transforms the file name to create the
+                `fname` arg.
+            extension: The file extension to use, provided directly to
+                `gempyor.utils.read_df`.
+        """
+        self._test_read_df(
+            fname_transformer=fname_transformer,
+            extension=extension,
+            suffix=".csv",
+            path_writer=lambda p, df: df.to_csv(p, index=False)
+        )
+
+    @pytest.mark.parametrize(
+        "fname_transformer,extension",
+        [
+            (lambda x: str(x), ""),
+            (lambda x: x, ""),
+            (lambda x: f"{x.parent}/{x.stem}", "parquet"),
+            (lambda x: Path(f"{x.parent}/{x.stem}"), "parquet"),
+        ],
+    )
+    def test_read_parquet_dataframe(
+        self,
+        fname_transformer: Callable[[os.PathLike], Any],
+        extension: str,
+    ) -> None:
+        """
+        Tests reading a DataFrame from a Parquet file.
+
+        Args:
+            fname_transformer: A function that transforms the file name to create the
+                `fname` arg.
+ extension: The file extension to use, provided directly to + `gempyor.utils.read_df`. + """ + self._test_read_df( + fname_transformer=fname_transformer, + extension=extension, + suffix=".parquet", + path_writer=lambda p, df: df.to_parquet(p, engine='pyarrow', index=False), + ) + + def _test_read_df( + self, + fname_transformer: Callable[[os.PathLike], Any], + extension: str, + suffix: str | None, + path_writer: Callable[[os.PathLike, pd.DataFrame], None], + ) -> None: + """ + Helper method to test writing a DataFrame to a file. + + Args: + fname_transformer: A function that transforms the file name. + extension: The file extension to use. + suffix: The suffix to use for the temporary file. + path_writer: A function to write the DataFrame to the file. + """ + with NamedTemporaryFile(suffix=suffix) as temp_file: + temp_path = Path(temp_file.name) + assert temp_path.stat().st_size == 0 + path_writer(temp_path, self.sample_df) + test_df = read_df(fname=fname_transformer(temp_path), extension=extension) + assert isinstance(test_df, pd.DataFrame) + assert temp_path.stat().st_size > 0 + assert test_df.equals(self.sample_df) From 9ad1abdd9cfc898f69c14144162988533726a5ea Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Mon, 1 Jul 2024 12:38:24 -0400 Subject: [PATCH 17/33] Formatted read_df tests, explicit parquet engine * Formatted the `tests/utils/test_read_df.py` file. * Added `engine="pyarrow"` to `write_df` unit tests. --- flepimop/gempyor_pkg/tests/utils/test_read_df.py | 10 +++++----- flepimop/gempyor_pkg/tests/utils/test_write_df.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/flepimop/gempyor_pkg/tests/utils/test_read_df.py b/flepimop/gempyor_pkg/tests/utils/test_read_df.py index d25ed9a59..0b2334602 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_read_df.py +++ b/flepimop/gempyor_pkg/tests/utils/test_read_df.py @@ -41,7 +41,7 @@ def test_raises_not_implemented_error(self) -> None: with NamedTemporaryFile(suffix=".txt") as temp_file: fname = temp_file.name[:-4] read_df(fname=fname, extension="txt") - + @pytest.mark.parametrize( "fname_transformer,extension", [ @@ -69,9 +69,9 @@ def test_read_csv_dataframe( fname_transformer=fname_transformer, extension=extension, suffix=".csv", - path_writer=lambda p, df: df.to_csv(p, index=False) + path_writer=lambda p, df: df.to_csv(p, index=False), ) - + @pytest.mark.parametrize( "fname_transformer,extension", [ @@ -99,9 +99,9 @@ def test_read_parquet_dataframe( fname_transformer=fname_transformer, extension=extension, suffix=".parquet", - path_writer=lambda p, df: df.to_parquet(p, engine='pyarrow', index=False), + path_writer=lambda p, df: df.to_parquet(p, engine="pyarrow", index=False), ) - + def _test_read_df( self, fname_transformer: Callable[[os.PathLike], Any], diff --git a/flepimop/gempyor_pkg/tests/utils/test_write_df.py b/flepimop/gempyor_pkg/tests/utils/test_write_df.py index d465f6434..aad3bc1c4 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_write_df.py +++ b/flepimop/gempyor_pkg/tests/utils/test_write_df.py @@ -94,7 +94,7 @@ def test_write_parquet_dataframe( df=self.sample_df, extension=extension, suffix=".parquet", - path_reader=lambda x: pd.read_parquet(x), + path_reader=lambda x: pd.read_parquet(x, engine="pyarrow"), ) def _test_write_df( From 3d22841f76812f2de65c0eaaf087c9b76e435c57 Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Mon, 1 Jul 2024 12:54:11 -0400 Subject: [PATCH 18/33] Reorganized utils.py Moved read_df to be next to write_df in utils.py. 
--- flepimop/gempyor_pkg/src/gempyor/utils.py | 34 +++++++++++------------ 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index 4d3209061..2f13f0f57 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -38,6 +38,23 @@ def write_df(fname: str, df: pd.DataFrame, extension: str = ""): raise NotImplementedError(f"Invalid extension {extension}. Must be 'csv' or 'parquet'") +def read_df(fname: str, extension: str = "") -> pd.DataFrame: + """Load a dataframe from a file, agnostic to whether it is a parquet or a csv. The extension + can be provided as an argument or it is infered""" + fname = str(fname) + if extension: # Empty strings are falsy in python + fname = f"{fname}.{extension}" + extension = fname.split(".")[-1] + if extension == "csv": + # The converter prevents e.g leading geoid (0600) to be converted as int; and works when the column is absent + df = pd.read_csv(fname, converters={"subpop": lambda x: str(x)}, skipinitialspace=True) + elif extension == "parquet": + df = pa.parquet.read_table(fname).to_pandas() + else: + raise NotImplementedError(f"Invalid extension {extension}. Must be 'csv' or 'parquet'") + return df + + def command_safe_run(command, command_name="mycommand", fail_on_fail=True): import subprocess import shlex # using shlex to split the command because it's not obvious https://docs.python.org/3/library/subprocess.html#subprocess.Popen @@ -62,23 +79,6 @@ def command_safe_run(command, command_name="mycommand", fail_on_fail=True): return sr.returncode, stdout, stderr -def read_df(fname: str, extension: str = "") -> pd.DataFrame: - """Load a dataframe from a file, agnostic to whether it is a parquet or a csv. The extension - can be provided as an argument or it is infered""" - fname = str(fname) - if extension: # Empty strings are falsy in python - fname = f"{fname}.{extension}" - extension = fname.split(".")[-1] - if extension == "csv": - # The converter prevents e.g leading geoid (0600) to be converted as int; and works when the column is absent - df = pd.read_csv(fname, converters={"subpop": lambda x: str(x)}, skipinitialspace=True) - elif extension == "parquet": - df = pa.parquet.read_table(fname).to_pandas() - else: - raise NotImplementedError(f"Invalid extension {extension}. Must be 'csv' or 'parquet'") - return df - - def add_method(cls): "Decorator to add a method to a class" From df794e92609d1b56a2b8faf9340d74364c242902 Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Mon, 1 Jul 2024 13:02:08 -0400 Subject: [PATCH 19/33] Documented and extended write_df * Documented the gempyor.utils.write_df function using the Google style guide. * Extended write_df to explicitly support os.PathLike types for fname (was implicitly supported) and added support for bytes. * Changed the file manipulation logic to use pathlib rather than manipulating strings in write_df. 
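A sketch of the call patterns this change is meant to support (illustrative
only; the paths are hypothetical):

```
from pathlib import Path
import pandas as pd
from gempyor.utils import write_df

df = pd.DataFrame({"value": [1, 2, 3]})

write_df("values.csv", df)        # plain str, as before
write_df(b"values.csv", df)       # bytes, decoded internally
write_df(Path("values.csv"), df)  # os.PathLike
```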
--- flepimop/gempyor_pkg/src/gempyor/utils.py | 49 ++++++++++++++++------- 1 file changed, 35 insertions(+), 14 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index 2f13f0f57..bace01582 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -16,26 +16,47 @@ from gempyor import file_paths from typing import List, Dict from botocore.exceptions import ClientError +from pathlib import Path logger = logging.getLogger(__name__) config = confuse.Configuration("flepiMoP", read=False) -def write_df(fname: str, df: pd.DataFrame, extension: str = ""): - """write without index, so assume the index has been put a column""" - # cast to str to use .split in case fname is a PosixPath - fname = str(fname) - if extension: # Empty strings are falsy in python - fname = f"{fname}.{extension}" - extension = fname.split(".")[-1] - if extension == "csv": - df.to_csv(fname, index=False) - elif extension == "parquet": - df = pa.Table.from_pandas(df, preserve_index=False) - pa.parquet.write_table(df, fname) - else: - raise NotImplementedError(f"Invalid extension {extension}. Must be 'csv' or 'parquet'") +def write_df( + fname: str | bytes | os.PathLike, + df: pd.DataFrame, + extension: str = "", +) -> None: + """Writes a pandas DataFrame without its index to a file. + + Writes a pandas DataFrame to either a CSV or Parquet file without its index and can + infer which format to use based on the extension given in `fname` or based on + explicit `extension`. + + Args: + fname: The name of the file to write to. + df: A pandas DataFrame whose contents to write, but without its index. + extension: A user specified extension to use for the file if not contained in + `fname` already. + + Returns: + None + + Raises: + NotImplementedError: The given output extension is not supported yet. + """ + # Decipher the path given + fname = fname.decode() if isinstance(fname, bytes) else fname + path = Path(f"{fname}.{extension}") if extension else Path(fname) + # Write df to either a csv or parquet or raise if an invalid extension + if path.suffix == ".csv": + return df.to_csv(path, index=False) + elif path.suffix == ".parquet": + return df.to_parquet(path, index=False, engine="pyarrow") + raise NotImplementedError( + f"Invalid extension {extension}. Must be 'csv' or 'parquet'" + ) def read_df(fname: str, extension: str = "") -> pd.DataFrame: From 19078a03595284ab42d9ef7ee5d4cc0fa7a8e4fe Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Mon, 1 Jul 2024 13:48:39 -0400 Subject: [PATCH 20/33] Add test for read_df with subpop column * Added unit test reading a file with a column called 'subpop', converted to a string when the file is a csv and left unaltered when the file is a parquet file. * Typo in test_raises_not_implemented_error docstring. 
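The 'subpop' converter matters because GEOID-style codes carry leading zeros
that a plain CSV read would silently drop; roughly (a sketch):

```
import io
import pandas as pd

buffer = io.StringIO("subpop,value\n06000,1\n")
pd.read_csv(buffer)["subpop"].iloc[0]  # 6000 -- parsed as int, leading zero lost

buffer.seek(0)
df = pd.read_csv(buffer, converters={"subpop": lambda x: str(x)})
df["subpop"].iloc[0]  # "06000" -- preserved as a string
```

Parquet files carry the column dtype with them, so no converter is needed on
that path.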
--- .../gempyor_pkg/tests/utils/test_read_df.py | 42 ++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/flepimop/gempyor_pkg/tests/utils/test_read_df.py b/flepimop/gempyor_pkg/tests/utils/test_read_df.py index 0b2334602..c4c9a14b0 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_read_df.py +++ b/flepimop/gempyor_pkg/tests/utils/test_read_df.py @@ -5,6 +5,7 @@ import pytest import pandas as pd +from pandas.api.types import is_object_dtype, is_numeric_dtype from gempyor.utils import read_df @@ -23,9 +24,16 @@ class TestReadDf: } ) + subpop_df: pd.DataFrame = pd.DataFrame( + { + "subpop": [1, 2, 3, 4], + "value": [5, 6, 7, 8], + } + ) + def test_raises_not_implemented_error(self) -> None: """ - Tests that write_df raises a NotImplementedError for unsupported file + Tests that read_df raises a NotImplementedError for unsupported file extensions. """ with pytest.raises( @@ -102,6 +110,38 @@ def test_read_parquet_dataframe( path_writer=lambda p, df: df.to_parquet(p, engine="pyarrow", index=False), ) + def test_subpop_is_cast_as_str(self) -> None: + """ + Tests that read_df returns an object dtype for the column 'subpop' when reading + a csv file, but not when reading a parquet file. + """ + # First confirm the dtypes of our test DataFrame + assert is_numeric_dtype(self.subpop_df["subpop"]) + assert is_numeric_dtype(self.subpop_df["value"]) + # Test that the subpop column is converted to a string for a csv file + with NamedTemporaryFile(suffix=".csv") as temp_file: + temp_path = Path(temp_file.name) + assert temp_path.stat().st_size == 0 + assert self.subpop_df.to_csv(temp_path, index=False) is None + assert temp_path.stat().st_size > 0 + test_df = read_df(fname=temp_path) + assert isinstance(test_df, pd.DataFrame) + assert is_object_dtype(test_df["subpop"]) + assert is_numeric_dtype(test_df["value"]) + # Test that the subpop column remains unaltered for a parquet file + with NamedTemporaryFile(suffix=".parquet") as temp_file: + temp_path = Path(temp_file.name) + assert temp_path.stat().st_size == 0 + assert ( + self.subpop_df.to_parquet(temp_path, engine="pyarrow", index=False) + is None + ) + assert temp_path.stat().st_size > 0 + test_df = read_df(fname=temp_path) + assert isinstance(test_df, pd.DataFrame) + assert is_numeric_dtype(test_df["subpop"]) + assert is_numeric_dtype(test_df["value"]) + def _test_read_df( self, fname_transformer: Callable[[os.PathLike], Any], From fd4021cf5a3b7955d9aab7cf1cf76724f849fb73 Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Mon, 1 Jul 2024 13:59:49 -0400 Subject: [PATCH 21/33] Documented and extended read_df * Documented the gempyor.utils.read_df function using the Google style guide. * Extended read_df to explicitly support os.PathLike types for fname (was implicitly supported) and added support for bytes. * Changed the file manipulation logic to use pathlib rather than manipulating strings in read_df. --- flepimop/gempyor_pkg/src/gempyor/utils.py | 47 +++++++++++++++-------- 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index bace01582..f59646dc0 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -59,21 +59,38 @@ def write_df( ) -def read_df(fname: str, extension: str = "") -> pd.DataFrame: - """Load a dataframe from a file, agnostic to whether it is a parquet or a csv. 
The extension - can be provided as an argument or it is infered""" - fname = str(fname) - if extension: # Empty strings are falsy in python - fname = f"{fname}.{extension}" - extension = fname.split(".")[-1] - if extension == "csv": - # The converter prevents e.g leading geoid (0600) to be converted as int; and works when the column is absent - df = pd.read_csv(fname, converters={"subpop": lambda x: str(x)}, skipinitialspace=True) - elif extension == "parquet": - df = pa.parquet.read_table(fname).to_pandas() - else: - raise NotImplementedError(f"Invalid extension {extension}. Must be 'csv' or 'parquet'") - return df +def read_df(fname: str | bytes | os.PathLike, extension: str = "") -> pd.DataFrame: + """Reads a pandas DataFrame from either a CSV or Parquet file. + + Reads a pandas DataFrame to either a CSV or Parquet file and can infer which format + to use based on the extension given in `fname` or based on explicit `extension`. If + the file being read is a csv with a column called 'subpop' then that column will be + cast as a string. + + Args: + fname: The name of the file to read from. + extension: A user specified extension to use for the file if not contained in + `fname` already. + + Returns: + A pandas DataFrame parsed from the file given. + + Raises: + NotImplementedError: The given output extension is not supported yet. + """ + # Decipher the path given + fname = fname.decode() if isinstance(fname, bytes) else fname + path = Path(f"{fname}.{extension}") if extension else Path(fname) + # Read df from either a csv or parquet or raise if an invalid extension + if path.suffix == ".csv": + return pd.read_csv( + path, converters={"subpop": lambda x: str(x)}, skipinitialspace=True + ) + elif path.suffix == ".parquet": + return pd.read_parquet(path, engine="pyarrow") + raise NotImplementedError( + f"Invalid extension {extension}. Must be 'csv' or 'parquet'" + ) def command_safe_run(command, command_name="mycommand", fail_on_fail=True): From c51a509aeb677b09c5035efaabf77812e15d2630 Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Mon, 1 Jul 2024 14:06:56 -0400 Subject: [PATCH 22/33] Changed extension type from str to Literal * Changed extension param of read_df/write_df from str to Literal[None, "", "csv", "parquet"]. * Added unit tests for `extension=None` for both read_df/write_df. --- flepimop/gempyor_pkg/src/gempyor/utils.py | 9 ++++++--- flepimop/gempyor_pkg/tests/utils/test_read_df.py | 12 ++++++++---- flepimop/gempyor_pkg/tests/utils/test_write_df.py | 12 ++++++++---- 3 files changed, 22 insertions(+), 11 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index f59646dc0..2ddf74f24 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -14,7 +14,7 @@ import logging import boto3 from gempyor import file_paths -from typing import List, Dict +from typing import List, Dict, Literal from botocore.exceptions import ClientError from pathlib import Path @@ -26,7 +26,7 @@ def write_df( fname: str | bytes | os.PathLike, df: pd.DataFrame, - extension: str = "", + extension: Literal[None, "", "csv", "parquet"] = "", ) -> None: """Writes a pandas DataFrame without its index to a file. 
@@ -59,7 +59,10 @@ def write_df( ) -def read_df(fname: str | bytes | os.PathLike, extension: str = "") -> pd.DataFrame: +def read_df( + fname: str | bytes | os.PathLike, + extension: Literal[None, "", "csv", "parquet"] = "", +) -> pd.DataFrame: """Reads a pandas DataFrame from either a CSV or Parquet file. Reads a pandas DataFrame to either a CSV or Parquet file and can infer which format diff --git a/flepimop/gempyor_pkg/tests/utils/test_read_df.py b/flepimop/gempyor_pkg/tests/utils/test_read_df.py index c4c9a14b0..68345ba29 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_read_df.py +++ b/flepimop/gempyor_pkg/tests/utils/test_read_df.py @@ -1,7 +1,7 @@ import os from tempfile import NamedTemporaryFile from pathlib import Path -from typing import Callable, Any +from typing import Callable, Any, Literal import pytest import pandas as pd @@ -55,6 +55,8 @@ def test_raises_not_implemented_error(self) -> None: [ (lambda x: str(x), ""), (lambda x: x, ""), + (lambda x: str(x), None), + (lambda x: x, None), (lambda x: f"{x.parent}/{x.stem}", "csv"), (lambda x: Path(f"{x.parent}/{x.stem}"), "csv"), ], @@ -62,7 +64,7 @@ def test_raises_not_implemented_error(self) -> None: def test_read_csv_dataframe( self, fname_transformer: Callable[[os.PathLike], Any], - extension: str, + extension: Literal[None, "", "csv", "parquet"], ) -> None: """ Tests reading a DataFrame from a CSV file. @@ -85,6 +87,8 @@ def test_read_csv_dataframe( [ (lambda x: str(x), ""), (lambda x: x, ""), + (lambda x: str(x), None), + (lambda x: x, None), (lambda x: f"{x.parent}/{x.stem}", "parquet"), (lambda x: Path(f"{x.parent}/{x.stem}"), "parquet"), ], @@ -92,7 +96,7 @@ def test_read_csv_dataframe( def test_read_parquet_dataframe( self, fname_transformer: Callable[[os.PathLike], Any], - extension: str, + extension: Literal[None, "", "csv", "parquet"], ) -> None: """ Tests reading a DataFrame from a Parquet file. @@ -145,7 +149,7 @@ def test_subpop_is_cast_as_str(self) -> None: def _test_read_df( self, fname_transformer: Callable[[os.PathLike], Any], - extension: str, + extension: Literal[None, "", "csv", "parquet"], suffix: str | None, path_writer: Callable[[os.PathLike, pd.DataFrame], None], ) -> None: diff --git a/flepimop/gempyor_pkg/tests/utils/test_write_df.py b/flepimop/gempyor_pkg/tests/utils/test_write_df.py index aad3bc1c4..22c29240b 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_write_df.py +++ b/flepimop/gempyor_pkg/tests/utils/test_write_df.py @@ -1,7 +1,7 @@ import os from tempfile import NamedTemporaryFile from pathlib import Path -from typing import Callable, Any +from typing import Callable, Any, Literal import pytest import pandas as pd @@ -40,6 +40,8 @@ def test_raises_not_implemented_error(self) -> None: [ (lambda x: str(x), ""), (lambda x: x, ""), + (lambda x: str(x), None), + (lambda x: x, None), (lambda x: f"{x.parent}/{x.stem}", "csv"), (lambda x: Path(f"{x.parent}/{x.stem}"), "csv"), ], @@ -47,7 +49,7 @@ def test_raises_not_implemented_error(self) -> None: def test_write_csv_dataframe( self, fname_transformer: Callable[[os.PathLike], Any], - extension: str, + extension: Literal[None, "", "csv", "parquet"], ) -> None: """ Tests writing a DataFrame to a CSV file. 
@@ -71,6 +73,8 @@ def test_write_csv_dataframe( [ (lambda x: str(x), ""), (lambda x: x, ""), + (lambda x: str(x), None), + (lambda x: x, None), (lambda x: f"{x.parent}/{x.stem}", "parquet"), (lambda x: Path(f"{x.parent}/{x.stem}"), "parquet"), ], @@ -78,7 +82,7 @@ def test_write_csv_dataframe( def test_write_parquet_dataframe( self, fname_transformer: Callable[[os.PathLike], Any], - extension: str, + extension: Literal[None, "", "csv", "parquet"], ) -> None: """ Tests writing a DataFrame to a Parquet file. @@ -101,7 +105,7 @@ def _test_write_df( self, fname_transformer: Callable[[os.PathLike], Any], df: pd.DataFrame, - extension: str, + extension: Literal[None, "", "csv", "parquet"], suffix: str | None, path_reader: Callable[[os.PathLike], pd.DataFrame], ) -> None: From 18d9498e5ba7bf434832719f71388cbd85cdb961 Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Mon, 1 Jul 2024 14:13:29 -0400 Subject: [PATCH 23/33] Formatted/organized utils * Applied formatting (all whitespace) to read_df/write_df functions. * Reorganized the imports in utils to be clearer. --- flepimop/gempyor_pkg/src/gempyor/utils.py | 45 ++++++++++++----------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index 2ddf74f24..e6056b815 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -1,22 +1,25 @@ -import os import datetime import functools +import logging import numbers +import os +from pathlib import Path +import shutil +import subprocess import time +from typing import List, Dict, Literal + +import boto3 +from botocore.exceptions import ClientError import confuse import numpy as np import pandas as pd import pyarrow as pa import scipy.stats import sympy.parsing.sympy_parser -import subprocess -import shutil -import logging -import boto3 + from gempyor import file_paths -from typing import List, Dict, Literal -from botocore.exceptions import ClientError -from pathlib import Path + logger = logging.getLogger(__name__) @@ -24,25 +27,25 @@ def write_df( - fname: str | bytes | os.PathLike, - df: pd.DataFrame, + fname: str | bytes | os.PathLike, + df: pd.DataFrame, extension: Literal[None, "", "csv", "parquet"] = "", ) -> None: """Writes a pandas DataFrame without its index to a file. - + Writes a pandas DataFrame to either a CSV or Parquet file without its index and can - infer which format to use based on the extension given in `fname` or based on + infer which format to use based on the extension given in `fname` or based on explicit `extension`. - + Args: fname: The name of the file to write to. df: A pandas DataFrame whose contents to write, but without its index. extension: A user specified extension to use for the file if not contained in `fname` already. - + Returns: None - + Raises: NotImplementedError: The given output extension is not supported yet. """ @@ -60,24 +63,24 @@ def write_df( def read_df( - fname: str | bytes | os.PathLike, + fname: str | bytes | os.PathLike, extension: Literal[None, "", "csv", "parquet"] = "", ) -> pd.DataFrame: """Reads a pandas DataFrame from either a CSV or Parquet file. - - Reads a pandas DataFrame to either a CSV or Parquet file and can infer which format + + Reads a pandas DataFrame to either a CSV or Parquet file and can infer which format to use based on the extension given in `fname` or based on explicit `extension`. If the file being read is a csv with a column called 'subpop' then that column will be cast as a string. 
- + Args: fname: The name of the file to read from. extension: A user specified extension to use for the file if not contained in `fname` already. - + Returns: A pandas DataFrame parsed from the file given. - + Raises: NotImplementedError: The given output extension is not supported yet. """ From 0e8241d62101a9ca56029af373befc52f7fac728 Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Mon, 1 Jul 2024 14:49:28 -0400 Subject: [PATCH 24/33] Typo in NotImplementedError, corresponding tests * Added missing period in NotImplementedError in read_df/write_df, updated corresponding unit tests. Also use path suffix directly instead of given extension. * Added missing test for write_df with provided extension raising NotImplementedError. --- flepimop/gempyor_pkg/src/gempyor/utils.py | 4 ++-- flepimop/gempyor_pkg/tests/utils/test_read_df.py | 4 ++-- flepimop/gempyor_pkg/tests/utils/test_write_df.py | 9 ++++++++- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index e6056b815..075162ce4 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -58,7 +58,7 @@ def write_df( elif path.suffix == ".parquet": return df.to_parquet(path, index=False, engine="pyarrow") raise NotImplementedError( - f"Invalid extension {extension}. Must be 'csv' or 'parquet'" + f"Invalid extension {path.suffix[1:]}. Must be 'csv' or 'parquet'." ) @@ -95,7 +95,7 @@ def read_df( elif path.suffix == ".parquet": return pd.read_parquet(path, engine="pyarrow") raise NotImplementedError( - f"Invalid extension {extension}. Must be 'csv' or 'parquet'" + f"Invalid extension {path.suffix[1:]}. Must be 'csv' or 'parquet'." ) diff --git a/flepimop/gempyor_pkg/tests/utils/test_read_df.py b/flepimop/gempyor_pkg/tests/utils/test_read_df.py index 68345ba29..7a0a0c581 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_read_df.py +++ b/flepimop/gempyor_pkg/tests/utils/test_read_df.py @@ -38,13 +38,13 @@ def test_raises_not_implemented_error(self) -> None: """ with pytest.raises( expected_exception=NotImplementedError, - match="Invalid extension txt. Must be 'csv' or 'parquet'", + match="Invalid extension txt. Must be 'csv' or 'parquet'.", ) as _: with NamedTemporaryFile(suffix=".txt") as temp_file: read_df(fname=temp_file.name) with pytest.raises( expected_exception=NotImplementedError, - match="Invalid extension txt. Must be 'csv' or 'parquet'", + match="Invalid extension txt. Must be 'csv' or 'parquet'.", ) as _: with NamedTemporaryFile(suffix=".txt") as temp_file: fname = temp_file.name[:-4] diff --git a/flepimop/gempyor_pkg/tests/utils/test_write_df.py b/flepimop/gempyor_pkg/tests/utils/test_write_df.py index 22c29240b..b13e0b948 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_write_df.py +++ b/flepimop/gempyor_pkg/tests/utils/test_write_df.py @@ -30,10 +30,17 @@ def test_raises_not_implemented_error(self) -> None: """ with pytest.raises( expected_exception=NotImplementedError, - match="Invalid extension txt. Must be 'csv' or 'parquet'", + match="Invalid extension txt. Must be 'csv' or 'parquet'.", ) as _: with NamedTemporaryFile(suffix=".txt") as temp_file: write_df(fname=temp_file.name, df=self.sample_df) + with pytest.raises( + expected_exception=NotImplementedError, + match="Invalid extension txt. 
Must be 'csv' or 'parquet'.", + ) as _: + with NamedTemporaryFile(suffix=".txt") as temp_file: + fname = temp_file.name[:-4] + write_df(fname=fname, df=self.sample_df, extension="txt") @pytest.mark.parametrize( "fname_transformer,extension", From 1b5b92d97371a91112ba28209cda1d9c9b46b2e9 Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Fri, 5 Jul 2024 10:39:13 -0400 Subject: [PATCH 25/33] Added unit tests for list_filenames * Created unit tests in tests/utils/test_list_filenames.py, including tests for searching flat and nested folders. * Added relevant documentation to the tests as well. * Created create_directories_with_files pytest class fixture, might need to be extracted into a more general purpose location later. --- .../tests/utils/test_list_filenames.py | 166 ++++++++++++++++++ 1 file changed, 166 insertions(+) create mode 100644 flepimop/gempyor_pkg/tests/utils/test_list_filenames.py diff --git a/flepimop/gempyor_pkg/tests/utils/test_list_filenames.py b/flepimop/gempyor_pkg/tests/utils/test_list_filenames.py new file mode 100644 index 000000000..dea2b4cbc --- /dev/null +++ b/flepimop/gempyor_pkg/tests/utils/test_list_filenames.py @@ -0,0 +1,166 @@ +"""Unit tests for the `gempyor.utils.list_filenames` function. + +These tests cover scenarios for finding files in both flat and nested directories. +""" + +from collections.abc import Generator +from pathlib import Path +from tempfile import TemporaryDirectory + +import pytest + +from gempyor.utils import list_filenames + + +@pytest.fixture(scope="class") +def create_directories_with_files( + request: pytest.FixtureRequest, +) -> Generator[tuple[TemporaryDirectory, TemporaryDirectory], None, None]: + """Fixture to create temporary directories with files for testing. + + This fixture creates two temporary directories: + - A flat directory with files. + - A nested directory with files organized in subdirectories. + + The directories and files are cleaned up after the tests are run. + + Args: + request: The pytest fixture request object. + + Yields: + tuple: A tuple containing the flat and nested TemporaryDirectory objects. 
+ """ + # Create a flat and nested directories + flat_temp_dir = TemporaryDirectory() + nested_temp_dir = TemporaryDirectory() + # Setup flat directory + for file in ["hosp.csv", "hosp.parquet", "spar.csv", "spar.parquet"]: + Path(f"{flat_temp_dir.name}/{file}").touch() + # Setup nested directory structure + for file in [ + "hpar/chimeric/001.parquet", + "hpar/chimeric/002.parquet", + "hpar/global/001.parquet", + "hpar/global/002.parquet", + "seed/001.csv", + "seed/002.csv", + ]: + path = Path(f"{nested_temp_dir.name}/{file}") + path.parent.mkdir(parents=True, exist_ok=True) + path.touch() + # Yield + request.cls.flat_temp_dir = flat_temp_dir + request.cls.nested_temp_dir = nested_temp_dir + yield (flat_temp_dir, nested_temp_dir) + # Clean up directories on test end + flat_temp_dir.cleanup() + nested_temp_dir.cleanup() + + +@pytest.mark.usefixtures("create_directories_with_files") +class TestListFilenames: + """Unit tests for the `gempyor.utils.list_filenames` function.""" + + @pytest.mark.parametrize( + "filters,expected_basenames", + [ + ([], ["hosp.csv", "hosp.parquet", "spar.csv", "spar.parquet"]), + (["hosp"], ["hosp.csv", "hosp.parquet"]), + (["spar"], ["spar.csv", "spar.parquet"]), + ([".parquet"], ["hosp.parquet", "spar.parquet"]), + ([".csv"], ["hosp.csv", "spar.csv"]), + (["hosp", ".csv"], ["hosp.csv"]), + (["spar", ".parquet"], ["spar.parquet"]), + (["hosp", "spar"], []), + ([".csv", ".parquet"], []), + ], + ) + def test_finds_files_in_flat_directory( + self, + filters: list[str], + expected_basenames: list[str], + ) -> None: + """Test `list_filenames` in a flat directory. + + Args: + filters: List of filters to apply to filenames. + expected_basenames: List of expected filenames that match the filters. + """ + self._test_list_filenames( + folder=self.flat_temp_dir.name, + filters=filters, + expected_basenames=expected_basenames, + ) + + @pytest.mark.parametrize( + "filters,expected_basenames", + [ + ( + [], + [ + "hpar/chimeric/001.parquet", + "hpar/chimeric/002.parquet", + "hpar/global/001.parquet", + "hpar/global/002.parquet", + "seed/001.csv", + "seed/002.csv", + ], + ), + ( + ["hpar"], + [ + "hpar/chimeric/001.parquet", + "hpar/chimeric/002.parquet", + "hpar/global/001.parquet", + "hpar/global/002.parquet", + ], + ), + (["seed"], ["seed/001.csv", "seed/002.csv"]), + (["global"], ["hpar/global/001.parquet", "hpar/global/002.parquet"]), + ( + ["001"], + [ + "hpar/chimeric/001.parquet", + "hpar/global/001.parquet", + "seed/001.csv", + ], + ), + (["hpar", "001"], ["hpar/chimeric/001.parquet", "hpar/global/001.parquet"]), + (["seed", "002"], ["seed/002.csv"]), + ([".tsv"], []), + ], + ) + def test_find_files_in_nested_directory( + self, + filters: list[str], + expected_basenames: list[str], + ) -> None: + """Test `list_filenames` in a nested directory. + + Args: + filters: List of filters to apply to filenames. + expected_basenames: List of expected filenames that match the filters. + """ + self._test_list_filenames( + folder=self.nested_temp_dir.name, + filters=filters, + expected_basenames=expected_basenames, + ) + + def _test_list_filenames( + self, + folder: str, + filters: list[str], + expected_basenames: list[str], + ) -> None: + """Helper method to test `list_filenames`. + + Args: + folder: The directory to search for files. + filters: List of filters to apply to filenames. + expected_basenames: List of expected filenames that match the filters. 
+ """ + files = list_filenames(folder=folder, filters=filters) + assert len(files) == len(expected_basenames) + basenames = [f.removeprefix(f"{folder}/") for f in files] + assert sorted(basenames) == sorted(expected_basenames) From d963fa407e90ec2b4c8552576db23e7f614aa942 Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Fri, 5 Jul 2024 11:36:35 -0400 Subject: [PATCH 26/33] Documented & refactored list_filenames * Added more detail to the `gempyor.utils.list_filenames` type hints. * Formatted the documentation to comply with the Google style guide. * Refactored the internals of list_filenames to be single list comprehension instead of a loop of nested conditionals. * Allow `filters` to accept a single string. --- flepimop/gempyor_pkg/src/gempyor/utils.py | 67 ++++++++++++++--------- 1 file changed, 41 insertions(+), 26 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index 075162ce4..afafc5f89 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -309,33 +309,48 @@ def as_random_distribution(self): return functools.partial(np.random.uniform, self.as_evaled_expression(), self.as_evaled_expression(),) -def list_filenames(folder: str = ".", filters: list = []) -> list: - """ - return the list of all filename and path in the provided folders. - If filters [list] is provided, then only the files that contains each of the - substrings in filter will be returned. Example to get all hosp file: - ``` - gempyor.utils.list_filenames(folder="model_output/", filters=["hosp"]) - ``` - and be sure we only get parquet: - ``` - gempyor.utils.list_filenames(folder="model_output/", filters=["hosp" , ".parquet"]) - ``` +def list_filenames(folder: str = ".", filters: str | list[str] = []) -> list[str]: + """Return the list of all filenames and paths in the provided folder. + + This function lists all files in the specified folder and its subdirectories. + If filters are provided, only the files containing each of the substrings + in the filters will be returned. + + Example: + To get all files containing "hosp": + ``` + gempyor.utils.list_filenames( + folder="model_output/", + filters=["hosp"], + ) + ``` + + To get only "hosp" files with a ".parquet" extension: + ``` + gempyor.utils.list_filenames( + folder="model_output/", + filters=["hosp", ".parquet"], + ) + ``` + + Args: + folder: The directory to search for files. Defaults to the current directory. + filters: A string or a list of strings to filter filenames. Only files + containing all the provided substrings will be returned. Defaults to an + empty list. + + Returns: + A list of strings representing the paths to the files that match the filters. 
""" - from pathlib import Path - - fn_list = [] - for f in Path(str(folder)).rglob(f"*"): - if f.is_file(): # not a folder - f = str(f) - if not filters: - fn_list.append(f) - else: - if all(c in f for c in filters): - fn_list.append(str(f)) - else: - pass - return fn_list + filters = list(filters) if not isinstance(filters, list) else filters + filters = filters if len(filters) else [""] + folder = Path(folder) + files = [ + str(file) + for file in folder.rglob("*") + if file.is_file() and all(f in str(file) for f in filters) + ] + return files def rolling_mean_pad(data, window): From 0e710f4b5750ec7a6ddd145b8d32a0fa7a037303 Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Fri, 5 Jul 2024 12:02:36 -0400 Subject: [PATCH 27/33] Expanded types and tests for list_filenames * Expanded type support for the `folder` arg of `gempyor.utils.list_filnames` to support bytes and os.PathLike. * Vastly expanded test suite to target new supported types. * Corrected bug when filters was given as a string, uncovered by extensive test suite. --- flepimop/gempyor_pkg/src/gempyor/utils.py | 9 ++- .../tests/utils/test_list_filenames.py | 69 +++++++++++++++++-- 2 files changed, 71 insertions(+), 7 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index afafc5f89..5c894225a 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -309,7 +309,10 @@ def as_random_distribution(self): return functools.partial(np.random.uniform, self.as_evaled_expression(), self.as_evaled_expression(),) -def list_filenames(folder: str = ".", filters: str | list[str] = []) -> list[str]: +def list_filenames( + folder: str | bytes | os.PathLike = ".", + filters: str | list[str] = [], +) -> list[str]: """Return the list of all filenames and paths in the provided folder. This function lists all files in the specified folder and its subdirectories. @@ -342,9 +345,9 @@ def list_filenames(folder: str = ".", filters: str | list[str] = []) -> list[str Returns: A list of strings representing the paths to the files that match the filters. 
""" - filters = list(filters) if not isinstance(filters, list) else filters + filters = [filters] if not isinstance(filters, list) else filters filters = filters if len(filters) else [""] - folder = Path(folder) + folder = Path(folder.decode() if isinstance(folder, bytes) else folder) files = [ str(file) for file in folder.rglob("*") diff --git a/flepimop/gempyor_pkg/tests/utils/test_list_filenames.py b/flepimop/gempyor_pkg/tests/utils/test_list_filenames.py index dea2b4cbc..f24c6aaa6 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_list_filenames.py +++ b/flepimop/gempyor_pkg/tests/utils/test_list_filenames.py @@ -4,6 +4,7 @@ """ from collections.abc import Generator +import os from pathlib import Path from tempfile import TemporaryDirectory @@ -64,11 +65,18 @@ class TestListFilenames: @pytest.mark.parametrize( "filters,expected_basenames", [ + ("", ["hosp.csv", "hosp.parquet", "spar.csv", "spar.parquet"]), ([], ["hosp.csv", "hosp.parquet", "spar.csv", "spar.parquet"]), + ("hosp", ["hosp.csv", "hosp.parquet"]), (["hosp"], ["hosp.csv", "hosp.parquet"]), + ("spar", ["spar.csv", "spar.parquet"]), (["spar"], ["spar.csv", "spar.parquet"]), + (".parquet", ["hosp.parquet", "spar.parquet"]), ([".parquet"], ["hosp.parquet", "spar.parquet"]), + (".csv", ["hosp.csv", "spar.csv"]), ([".csv"], ["hosp.csv", "spar.csv"]), + (".tsv", []), + ([".tsv"], []), (["hosp", ".csv"], ["hosp.csv"]), (["spar", ".parquet"], ["spar.parquet"]), (["hosp", "spar"], []), @@ -77,7 +85,7 @@ class TestListFilenames: ) def test_finds_files_in_flat_directory( self, - filters: list[str], + filters: str | list[str], expected_basenames: list[str], ) -> None: """Test `list_filenames` in a flat directory. @@ -91,10 +99,31 @@ def test_finds_files_in_flat_directory( filters=filters, expected_basenames=expected_basenames, ) + self._test_list_filenames( + folder=self.flat_temp_dir.name.encode(), + filters=filters, + expected_basenames=expected_basenames, + ) + self._test_list_filenames( + folder=Path(self.flat_temp_dir.name), + filters=filters, + expected_basenames=expected_basenames, + ) @pytest.mark.parametrize( "filters,expected_basenames", [ + ( + "", + [ + "hpar/chimeric/001.parquet", + "hpar/chimeric/002.parquet", + "hpar/global/001.parquet", + "hpar/global/002.parquet", + "seed/001.csv", + "seed/002.csv", + ], + ), ( [], [ @@ -106,6 +135,15 @@ def test_finds_files_in_flat_directory( "seed/002.csv", ], ), + ( + "hpar", + [ + "hpar/chimeric/001.parquet", + "hpar/chimeric/002.parquet", + "hpar/global/001.parquet", + "hpar/global/002.parquet", + ], + ), ( ["hpar"], [ @@ -115,8 +153,18 @@ def test_finds_files_in_flat_directory( "hpar/global/002.parquet", ], ), + ("seed", ["seed/001.csv", "seed/002.csv"]), (["seed"], ["seed/001.csv", "seed/002.csv"]), + ("global", ["hpar/global/001.parquet", "hpar/global/002.parquet"]), (["global"], ["hpar/global/001.parquet", "hpar/global/002.parquet"]), + ( + "001", + [ + "hpar/chimeric/001.parquet", + "hpar/global/001.parquet", + "seed/001.csv", + ], + ), ( ["001"], [ @@ -127,12 +175,14 @@ def test_finds_files_in_flat_directory( ), (["hpar", "001"], ["hpar/chimeric/001.parquet", "hpar/global/001.parquet"]), (["seed", "002"], ["seed/002.csv"]), + (["hpar", "001", "global"], ["hpar/global/001.parquet"]), + (".tsv", []), ([".tsv"], []), ], ) def test_find_files_in_nested_directory( self, - filters: list[str], + filters: str | list[str], expected_basenames: list[str], ) -> None: """Test `list_filenames` in a nested directory. 
@@ -146,11 +196,21 @@ def test_find_files_in_nested_directory( filters=filters, expected_basenames=expected_basenames, ) + self._test_list_filenames( + folder=self.nested_temp_dir.name.encode(), + filters=filters, + expected_basenames=expected_basenames, + ) + self._test_list_filenames( + folder=Path(self.nested_temp_dir.name), + filters=filters, + expected_basenames=expected_basenames, + ) def _test_list_filenames( self, - folder: str, - filters: list[str], + folder: str | bytes | os.PathLike, + filters: str | list[str], expected_basenames: list[str], ) -> None: """Helper method to test `list_filenames`. @@ -162,5 +222,6 @@ def _test_list_filenames( """ files = list_filenames(folder=folder, filters=filters) assert len(files) == len(expected_basenames) + folder = folder.decode() if isinstance(folder, bytes) else str(folder) basenames = [f.removeprefix(f"{folder}/") for f in files] assert sorted(basenames) == sorted(expected_basenames) From 2c7a81d78ce8e5cfc19ac9fe2865a02e7b31900d Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Fri, 5 Jul 2024 13:28:27 -0400 Subject: [PATCH 28/33] Added unit tests for get_truncated_normal Added unit tests for `gempyor.utils.get_truncated_normal` function. --- .../tests/utils/test_get_truncated_normal.py | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 flepimop/gempyor_pkg/tests/utils/test_get_truncated_normal.py diff --git a/flepimop/gempyor_pkg/tests/utils/test_get_truncated_normal.py b/flepimop/gempyor_pkg/tests/utils/test_get_truncated_normal.py new file mode 100644 index 000000000..23e4fad58 --- /dev/null +++ b/flepimop/gempyor_pkg/tests/utils/test_get_truncated_normal.py @@ -0,0 +1,49 @@ +import numpy as np +import pytest +import scipy.stats + +from gempyor.utils import get_truncated_normal + + +class TestGetTruncatedNormal: + """Unit tests for the `gempyor.utils.get_truncated_normal` function.""" + + @pytest.mark.parametrize( + "mean,sd,a,b", + [ + (0.0, 1.0, 0.0, 10.0), + (0.0, 2.0, -4.0, 4.0), + (-5.0, 3.0, -5.0, 10.0), + (-3.25, 1.4, -8.74, 4.89), + (0, 1, 0, 10), + (0, 2, -4, 4), + (-5, 3, -5, 10), + ], + ) + def test_construct_distribution( + self, + mean: float | int, + sd: float | int, + a: float | int, + b: float | int, + ) -> None: + """Test the construction of a truncated normal distribution. + + This test checks whether the `get_truncated_normal` function correctly + constructs a truncated normal distribution with the specified parameters. + It verifies that the returned object is an instance of `rv_frozen`, and that + its support and parameters (mean and standard deviation) are correctly set. + + Args: + mean: The mean of the truncated normal distribution. + sd: The standard deviation of the truncated normal distribution. + a: The lower bound of the truncated normal distribution. + b: The upper bound of the truncated normal distribution. + """ + dist = get_truncated_normal(mean=mean, sd=sd, a=a, b=b) + assert isinstance(dist, scipy.stats._distn_infrastructure.rv_frozen) + lower, upper = dist.support() + assert np.isclose(lower, a) + assert np.isclose(upper, b) + assert np.isclose(dist.kwds.get("loc"), mean) + assert np.isclose(dist.kwds.get("scale"), sd) From 17c715d344023fd4e3a0926739e1da2814859d97 Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Fri, 5 Jul 2024 13:36:33 -0400 Subject: [PATCH 29/33] Documented and refactored get_truncated_normal * Documented `gempyor.utils.get_truncated_normal` including adding appropriate type hints. * Refactored the function lightly for legibility. 
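For reference, `scipy.stats.truncnorm` expects its bounds in standard-deviation
units around `loc`, which is what the `lower`/`upper` computation in the diff
below encodes (illustrative numbers):

```
import scipy.stats

# a normal with mean 2 and sd 3, truncated to [0, 10]
mean, sd, a, b = 2.0, 3.0, 0.0, 10.0
dist = scipy.stats.truncnorm((a - mean) / sd, (b - mean) / sd, loc=mean, scale=sd)
dist.support()  # (0.0, 10.0) -- back on the original scale
```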
--- flepimop/gempyor_pkg/src/gempyor/utils.py | 27 ++++++++++++++++++++---
 1 file changed, 24 insertions(+), 3 deletions(-)

diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py
index 5c894225a..826727c86 100644
--- a/flepimop/gempyor_pkg/src/gempyor/utils.py
+++ b/flepimop/gempyor_pkg/src/gempyor/utils.py
@@ -252,10 +252,31 @@ def as_evaled_expression(self):
         raise ValueError(f"expected numeric or string expression [got: {value}]")
 
 
-def get_truncated_normal(*, mean=0, sd=1, a=0, b=10):
-    "Returns the truncated normal distribution"
+def get_truncated_normal(
+    mean: float | int = 0,
+    sd: float | int = 1,
+    a: float | int = 0,
+    b: float | int = 10,
+) -> scipy.stats._distn_infrastructure.rv_frozen:
+    """Returns a truncated normal distribution.
+
+    This function constructs a truncated normal distribution with the specified
+    mean, standard deviation, and bounds. The truncated normal distribution is
+    a normal distribution bounded within the interval [a, b].
 
-    return scipy.stats.truncnorm((a - mean) / sd, (b - mean) / sd, loc=mean, scale=sd)
+    Args:
+        mean: The mean of the truncated normal distribution. Defaults to 0.
+        sd: The standard deviation of the truncated normal distribution. Defaults to 1.
+        a: The lower bound of the truncated normal distribution. Defaults to 0.
+        b: The upper bound of the truncated normal distribution. Defaults to 10.
+
+    Returns:
+        rv_frozen: A frozen instance of the truncated normal distribution with the
+            specified parameters.
+    """
+    lower = (a - mean) / sd
+    upper = (b - mean) / sd
+    return scipy.stats.truncnorm(lower, upper, loc=mean, scale=sd)
 
 
 def get_log_normal(meanlog, sdlog):

From 2aaa6fbb8e7dd27f8e29c41ca9015920ce03130c Mon Sep 17 00:00:00 2001
From: Timothy Willard
Date: Fri, 5 Jul 2024 14:30:15 -0400
Subject: [PATCH 30/33] Added unit tests for get_log_normal

Added unit tests for the `gempyor.utils.get_log_normal` function.
--- .../tests/utils/test_get_log_normal.py       | 49 +++++++++++++++++++
 1 file changed, 49 insertions(+)
 create mode 100644 flepimop/gempyor_pkg/tests/utils/test_get_log_normal.py

diff --git a/flepimop/gempyor_pkg/tests/utils/test_get_log_normal.py b/flepimop/gempyor_pkg/tests/utils/test_get_log_normal.py
new file mode 100644
index 000000000..367a7f550
--- /dev/null
+++ b/flepimop/gempyor_pkg/tests/utils/test_get_log_normal.py
@@ -0,0 +1,49 @@
+import numpy as np
+import pytest
+import scipy.stats
+
+from gempyor.utils import get_log_normal
+
+
+class TestGetLogNormal:
+    """Unit tests for the `gempyor.utils.get_log_normal` function."""
+
+    @pytest.mark.parametrize(
+        "meanlog,sdlog",
+        [
+            (1.0, 1.0),
+            (0.0, 2.0),
+            (10.0, 30.0),
+            (0.33, 4.56),
+            (9.87, 4.21),
+            (1, 1),
+            (0, 2),
+            (10, 30),
+        ],
+    )
+    def test_construct_distribution(
+        self,
+        meanlog: float | int,
+        sdlog: float | int,
+    ) -> None:
+        """Test the construction of a log normal distribution.
+
+        This test checks whether the `get_log_normal` function correctly constructs
+        a log normal distribution with the specified parameters. It verifies that
+        the returned object is an instance of `rv_frozen`, and that its support and
+        parameters (log mean and log standard deviation) are correctly set.
+
+        Args:
+            meanlog: The mean, on the log scale, of the log normal distribution
+                under test.
+            sdlog: The standard deviation, on the log scale, of the log normal
+                distribution under test.
+ """ + dist = get_log_normal(meanlog=meanlog, sdlog=sdlog) + assert isinstance(dist, scipy.stats._distn_infrastructure.rv_frozen) + lower, upper = dist.support() + assert np.isclose(lower, 0.0) + assert np.isclose(upper, np.inf) + assert np.isclose(dist.kwds.get("s"), sdlog) + assert np.isclose(dist.kwds.get("scale"), np.exp(meanlog)) + assert np.isclose(dist.kwds.get("loc"), 0.0) From c09b9535ddfc364702bb048fad168fceab71ac3d Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Fri, 5 Jul 2024 14:35:47 -0400 Subject: [PATCH 31/33] Documented get_log_normal Added documentation for `gempyor.utils.get_log_normal` including adding appropriate type hints. --- flepimop/gempyor_pkg/src/gempyor/utils.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index 826727c86..98c54a1f1 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -279,8 +279,23 @@ def get_truncated_normal( return scipy.stats.truncnorm(lower, upper, loc=mean, scale=sd) -def get_log_normal(meanlog, sdlog): - "Returns the log normal distribution" +def get_log_normal( + meanlog: float | int, + sdlog: float | int, +) -> scipy.stats._distn_infrastructure.rv_frozen: + """Returns a log normal distribution. + + This function constructs a log normal distribution with the specified + log mean and log standard deviation. + + Args: + meanlog: The log of the mean of the log normal distribution. + sdlog: The log of the standard deviation of the log normal distribution. + + Returns: + rv_frozen: A frozen instance of the log normal distribution with the + specified parameters. + """ return scipy.stats.lognorm(s=sdlog, scale=np.exp(meanlog), loc=0) From e073f883fa028467386f7292f4cf3f2703a82358 Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Mon, 8 Jul 2024 10:33:06 -0400 Subject: [PATCH 32/33] Added unit tests for rolling_mean_pad * Added unit tests for `gempyor.utils.rolling_mean_pad` in `tests/utils/test_rolling_mean_pad.py`. * Wrote a reference implementation of `rolling_mean_pad` for comparison purposes. 
--- .../tests/utils/test_rolling_mean_pad.py | 149 ++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 flepimop/gempyor_pkg/tests/utils/test_rolling_mean_pad.py diff --git a/flepimop/gempyor_pkg/tests/utils/test_rolling_mean_pad.py b/flepimop/gempyor_pkg/tests/utils/test_rolling_mean_pad.py new file mode 100644 index 000000000..94be3394a --- /dev/null +++ b/flepimop/gempyor_pkg/tests/utils/test_rolling_mean_pad.py @@ -0,0 +1,149 @@ +import numpy as np +import numpy.typing as npt +import pytest + +from gempyor.utils import rolling_mean_pad + + +class TestRollingMeanPad: + """Unit tests for the `gempyor.utils.rolling_mean_pad` function.""" + + # Test data for various matrix configurations + test_data = { + # 1x1 matrices + "one_by_one_const": np.array([[1.0]]), + "one_by_one_nan": np.array([[np.nan]]), + "one_by_one_rand": np.random.uniform(size=(1, 1)), + # 1xN matrices + "one_by_many_const": np.arange(start=1.0, stop=6.0).reshape((1, 5)), + "one_by_many_nan": np.repeat(np.nan, 5).reshape((1, 5)), + "one_by_many_rand": np.random.uniform(size=(1, 5)), + # Mx1 matrices + "many_by_one_const": np.arange(start=3.0, stop=9.0).reshape((6, 1)), + "many_by_one_nan": np.repeat(np.nan, 6).reshape((6, 1)), + "many_by_one_rand": np.random.uniform(size=(6, 1)), + # MxN matrices + "many_by_many_const": np.arange(start=1.0, stop=49.0).reshape((12, 4)), + "many_by_many_nan": np.repeat(np.nan, 48).reshape((12, 4)), + "many_by_many_rand": np.random.uniform(size=(12, 4)), + } + + @pytest.mark.parametrize( + "test_data_name,expected_shape,window,put_nans", + [ + # 1x1 matrices + ("one_by_one_const", (1, 1), 3, []), + ("one_by_one_const", (1, 1), 4, []), + ("one_by_one_nan", (1, 1), 3, []), + ("one_by_one_nan", (1, 1), 4, []), + ("one_by_one_rand", (1, 1), 3, []), + ("one_by_one_rand", (1, 1), 4, []), + ("one_by_one_rand", (1, 1), 5, []), + ("one_by_one_rand", (1, 1), 6, []), + # 1xN matrices + ("one_by_many_const", (1, 5), 3, []), + ("one_by_many_const", (1, 5), 4, []), + ("one_by_many_nan", (1, 5), 3, []), + ("one_by_many_nan", (1, 5), 4, []), + ("one_by_many_rand", (1, 5), 3, []), + ("one_by_many_rand", (1, 5), 4, []), + ("one_by_many_rand", (1, 5), 5, []), + ("one_by_many_rand", (1, 5), 6, []), + # Mx1 matrices + ("many_by_one_const", (6, 1), 3, []), + ("many_by_one_const", (6, 1), 4, []), + ("many_by_one_nan", (6, 1), 3, []), + ("many_by_one_nan", (6, 1), 4, []), + ("many_by_one_rand", (6, 1), 3, []), + ("many_by_one_rand", (6, 1), 4, []), + ("many_by_one_rand", (6, 1), 5, []), + ("many_by_one_rand", (6, 1), 6, []), + # MxN matrices + ("many_by_many_const", (12, 4), 3, []), + ("many_by_many_const", (12, 4), 4, []), + ("many_by_many_const", (12, 4), 5, []), + ("many_by_many_const", (12, 4), 6, []), + ("many_by_many_nan", (12, 4), 3, []), + ("many_by_many_nan", (12, 4), 4, []), + ("many_by_many_nan", (12, 4), 5, []), + ("many_by_many_nan", (12, 4), 6, []), + ("many_by_many_rand", (12, 4), 3, []), + ("many_by_many_rand", (12, 4), 4, []), + ("many_by_many_rand", (12, 4), 5, []), + ("many_by_many_rand", (12, 4), 6, []), + ("many_by_many_rand", (12, 4), 7, []), + ("many_by_many_rand", (12, 4), 8, []), + ("many_by_many_rand", (12, 4), 9, []), + ("many_by_many_rand", (12, 4), 10, []), + ("many_by_many_rand", (12, 4), 11, []), + ("many_by_many_rand", (12, 4), 12, []), + ("many_by_many_rand", (12, 4), 13, []), + ("many_by_many_rand", (12, 4), 14, []), + ("many_by_many_rand", (12, 4), 3, [(2, 2), (4, 4)]), + ("many_by_many_rand", (12, 4), 4, [(2, 2), (4, 4)]), + ("many_by_many_rand", (12, 4), 5, 
[(2, 2), (4, 4)]),
+            ("many_by_many_rand", (12, 4), 6, [(2, 2), (4, 4)]),
+            ("many_by_many_rand", (12, 4), 7, [(2, 2), (4, 4)]),
+            ("many_by_many_rand", (12, 4), 8, [(2, 2), (4, 4)]),
+        ],
+    )
+    def test_rolling_mean_pad(
+        self,
+        test_data_name: str,
+        expected_shape: tuple[int, int],
+        window: int,
+        put_nans: list[tuple[int, int]],
+    ) -> None:
+        """Tests the `rolling_mean_pad` function with various configurations of input data.
+
+        Args:
+            test_data_name: The name of the test data set to use.
+            expected_shape: The expected shape of the output array.
+            window: The size of the rolling window.
+            put_nans: A list of indices at which to insert NaNs into the input data.
+
+        Raises:
+            AssertionError: If the shape or contents of the output do not match the
+                expected values.
+        """
+        test_data = self.test_data.get(test_data_name).copy()
+        if put_nans:
+            np.put(test_data, put_nans, np.nan)
+        rolling_mean_data = rolling_mean_pad(test_data, window)
+        rolling_mean_reference = self._rolling_mean_pad_reference(test_data, window)
+        assert rolling_mean_data.shape == expected_shape
+        assert np.isclose(
+            rolling_mean_data, rolling_mean_reference, equal_nan=True
+        ).all()
+
+    def _rolling_mean_pad_reference(
+        self, data: npt.NDArray[np.number], window: int
+    ) -> npt.NDArray[np.number]:
+        """Generates a reference rolling mean with padding.
+
+        This implementation should match the `gempyor.utils.rolling_mean_pad`
+        implementation, but is written for readability. As a result, this
+        reference implementation is extremely slow.
+
+        Args:
+            data: The input array for which to compute the rolling mean.
+            window: The size of the rolling window.
+
+        Returns:
+            An array of the same shape as `data` containing the rolling mean values.
+        """
+        # Setup
+        rows, cols = data.shape
+        output = np.zeros((rows, cols), dtype=data.dtype)
+        # Slow but intuitive triple loop
+        for i in range(rows):
+            for j in range(cols):
+                # On the last row with an even window, shrink the window by one,
+                # so 4 -> 3, while an odd 5 stays 5.
+                sub_window = window - 1 if window % 2 == 0 and i == rows - 1 else window
+                weight = 1.0 / sub_window
+                for l in range(-((sub_window - 1) // 2), 1 + (sub_window // 2)):
+                    i_star = min(max(i + l, 0), rows - 1)
+                    output[i, j] += weight * data[i_star, j]
+        # Done
+        return output

From a91fc6df031181b0e537b88e0ca60f9f3df287db Mon Sep 17 00:00:00 2001
From: Timothy Willard
Date: Mon, 8 Jul 2024 14:15:57 -0400
Subject: [PATCH 33/33] Documented `rolling_mean_pad`

* Added type hints for `gempyor.utils.rolling_mean_pad`.
* Expanded the existing docstring and included an example.
--- flepimop/gempyor_pkg/src/gempyor/utils.py | 37 ++++++++++++++++++++++++++----
 1 file changed, 33 insertions(+), 4 deletions(-)

diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py
index 98c54a1f1..85820d5b2 100644
--- a/flepimop/gempyor_pkg/src/gempyor/utils.py
+++ b/flepimop/gempyor_pkg/src/gempyor/utils.py
@@ -13,6 +13,7 @@
 from botocore.exceptions import ClientError
 import confuse
 import numpy as np
+import numpy.typing as npt
 import pandas as pd
 import pyarrow as pa
 import scipy.stats
@@ -392,16 +393,44 @@ def list_filenames(
     return files
 
 
-def rolling_mean_pad(data, window):
+def rolling_mean_pad(
+    data: npt.NDArray[np.number],
+    window: int,
+) -> npt.NDArray[np.number]:
     """
-    Calculates rolling mean with centered window and pads the edges.
+    Calculates the column-wise rolling mean with centered window.
 
     Args:
-        data: A NumPy array !!! shape must be (n_days, nsubpops).
+        data: A two dimensional numpy array, typically with the row dimension
+            corresponding to time and the column dimension to subpop.
         window: The window size for the rolling mean.
 
     Returns:
-        A NumPy array with the padded rolling mean (n_days, nsubpops).
+        A two dimensional numpy array that is the same shape as `data`.
+
+    Examples:
+        Below is a brief set of examples showcasing how to smooth a metric, like
+        hospitalizations, using this function.
+
+        >>> import numpy as np
+        >>> from gempyor.utils import rolling_mean_pad
+        >>> hospitalizations = np.arange(1., 29.).reshape((7, 4))
+        >>> hospitalizations
+        array([[ 1.,  2.,  3.,  4.],
+               [ 5.,  6.,  7.,  8.],
+               [ 9., 10., 11., 12.],
+               [13., 14., 15., 16.],
+               [17., 18., 19., 20.],
+               [21., 22., 23., 24.],
+               [25., 26., 27., 28.]])
+        >>> rolling_mean_pad(hospitalizations, 5)
+        array([[ 3.4,  4.4,  5.4,  6.4],
+               [ 5.8,  6.8,  7.8,  8.8],
+               [ 9. , 10. , 11. , 12. ],
+               [13. , 14. , 15. , 16. ],
+               [17. , 18. , 19. , 20. ],
+               [20.2, 21.2, 22.2, 23.2],
+               [22.6, 23.6, 24.6, 25.6]])
     """
     padding_size = (window - 1) // 2
     padded_data = np.pad(data, ((padding_size, padding_size), (0, 0)), mode="edge")
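The docstring example above uses an odd window; the even-window edge case, as
handled by the reference implementation added in patch 32, can be worked by hand.
A small sketch (the 4x1 input below is illustrative, not from the tests):

    import numpy as np

    data = np.arange(1.0, 5.0).reshape((4, 1))  # column of 1, 2, 3, 4
    # With window=4 the reference implementation averages one row before
    # through two rows after, clipping indices at the array edges:
    #   row 0 -> mean(1, 1, 2, 3) = 1.75
    #   row 1 -> mean(1, 2, 3, 4) = 2.50
    #   row 2 -> mean(2, 3, 4, 4) = 3.25
    # On the last row the even window shrinks from 4 to 3:
    #   row 3 -> mean(3, 4, 4) = 11 / 3, roughly 3.67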