Merge pull request #218 from HopkinsIDD/new_inference_pengcheng
Add new functions to support downloading from AWS s3
jcblemai authored Jun 20, 2024
2 parents a0392a6 + 9e93b99 commit 9e0ad95
Showing 2 changed files with 286 additions and 24 deletions.
227 changes: 214 additions & 13 deletions flepimop/gempyor_pkg/src/gempyor/utils.py
@@ -1,3 +1,4 @@
+import os
import datetime
import functools
import numbers
@@ -8,7 +9,13 @@
import pyarrow as pa
import scipy.stats
import sympy.parsing.sympy_parser
import subprocess
+import shutil
+import logging
+import boto3
+from gempyor import file_paths
+from typing import List, Dict
+from botocore.exceptions import ClientError

logger = logging.getLogger(__name__)

@@ -220,15 +227,11 @@ def as_random_distribution(self):
dist = self["distribution"].get()
if dist == "fixed":
return functools.partial(
-np.random.uniform,
-self["value"].as_evaled_expression(),
-self["value"].as_evaled_expression(),
+np.random.uniform, self["value"].as_evaled_expression(), self["value"].as_evaled_expression(),
)
elif dist == "uniform":
return functools.partial(
-np.random.uniform,
-self["low"].as_evaled_expression(),
-self["high"].as_evaled_expression(),
+np.random.uniform, self["low"].as_evaled_expression(), self["high"].as_evaled_expression(),
)
elif dist == "poisson":
return functools.partial(np.random.poisson, self["lam"].as_evaled_expression())
@@ -253,18 +256,13 @@ def as_random_distribution(self):
).rvs
elif dist == "lognorm":
return get_log_normal(
-meanlog=self["meanlog"].as_evaled_expression(),
-sdlog=self["sdlog"].as_evaled_expression(),
+meanlog=self["meanlog"].as_evaled_expression(), sdlog=self["sdlog"].as_evaled_expression(),
).rvs
else:
raise NotImplementedError(f"unknown distribution [got: {dist}]")
else:
# we allow a fixed value specified directly:
-return functools.partial(
-np.random.uniform,
-self.as_evaled_expression(),
-self.as_evaled_expression(),
-)
+return functools.partial(np.random.uniform, self.as_evaled_expression(), self.as_evaled_expression(),)


def list_filenames(folder: str = ".", filters: list = []) -> list:
@@ -346,3 +344,206 @@ def bash(command):
print("------------")
print(f"lsblk: {bash('lsblk')}")
print("END AWS DIAGNOSIS ================================")


def create_resume_out_filename(
flepi_run_index: str, flepi_prefix: str, flepi_slot_index: str, flepi_block_index: str, filetype: str, liketype: str
) -> str:
prefix = f"{flepi_prefix}/{flepi_run_index}"
inference_filepath_suffix = f"{liketype}/intermediate"
inference_filename_prefix = "{:09d}.".format(int(flepi_slot_index))
index = "{:09d}.{:09d}".format(1, int(flepi_block_index) - 1)
extension = "parquet"
if filetype == "seed":
extension = "csv"
return file_paths.create_file_name(
run_id=flepi_run_index,
prefix=prefix,
inference_filename_prefix=inference_filename_prefix,
inference_filepath_suffix=inference_filepath_suffix,
index=index,
ftype=filetype,
extension=extension,
)
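
For illustration, a minimal usage sketch (the values mirror the unit tests further down in this diff; the leading "model_output/" in the result comes from file_paths.create_file_name, as the expected values in those tests show):

    from gempyor.utils import create_resume_out_filename

    # The name embeds the slot index (000000002) plus an index pair whose second
    # component is flepi_block_index - 1 (here 000000001), all zero-padded to nine digits.
    name = create_resume_out_filename(
        flepi_run_index="123",
        flepi_prefix="output",
        flepi_slot_index="2",
        flepi_block_index="2",
        filetype="spar",
        liketype="global",
    )
    # name == "model_output/output/123/spar/global/intermediate/000000002.000000001.000000001.123.spar.parquet"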


def create_resume_input_filename(
resume_run_index: str, flepi_prefix: str, flepi_slot_index: str, filetype: str, liketype: str
) -> str:
prefix = f"{flepi_prefix}/{resume_run_index}"
inference_filepath_suffix = f"{liketype}/final"
index = flepi_slot_index
extension = "parquet"
if filetype == "seed":
extension = "csv"
return file_paths.create_file_name(
run_id=resume_run_index,
prefix=prefix,
inference_filepath_suffix=inference_filepath_suffix,
index=index,
ftype=filetype,
extension=extension,
)
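
The input-side counterpart, again mirroring the tests below. Resume inputs are read from the *final* outputs of the resumed run, so only the slot index appears in the name:

    from gempyor.utils import create_resume_input_filename

    name = create_resume_input_filename(
        resume_run_index="321",
        flepi_prefix="output",
        flepi_slot_index="2",
        filetype="spar",
        liketype="global",
    )
    # name == "model_output/output/321/spar/global/final/000000002.321.spar.parquet"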


def get_filetype_for_resume(resume_discard_seeding: str, flepi_block_index: str) -> List[str]:
"""
Retrieves a list of parquet file types that are relevant for resuming a process based on
specific environment variable settings. This function dynamically determines the list
based on the current operational context given by the environment.
The function checks two environment variables:
- `resume_discard_seeding`: Determines whether seeding-related file types should be included.
- `flepi_block_index`: Determines a specific operational mode or block of the process.
"""
if flepi_block_index == "1":
if resume_discard_seeding == "true":
return ["spar", "snpi", "hpar", "hnpi", "init"]
else:
return ["seed", "spar", "snpi", "hpar", "hnpi", "init"]
else:
return ["seed", "spar", "snpi", "hpar", "hnpi", "host", "llik", "init"]


def create_resume_file_names_map(
resume_discard_seeding,
flepi_block_index,
resume_run_index,
flepi_prefix,
flepi_slot_index,
flepi_run_index,
last_job_output,
) -> Dict[str, str]:
"""
Generates a mapping of input file names to output file names for a resume process based on
parquet file types and environmental conditions. The function adjusts the file name mappings
based on the operational block index and the location of the last job output.
The mappings depend on:
- Parquet file types appropriate for resuming a process, as determined by the environment.
- Whether the files are for 'global' or 'chimeric' types, as these liketypes influence the
file naming convention.
- The operational block index (`flepi_block_index`, from 'FLEPI_BLOCK_INDEX'), which alters
the input file names when the block index is '1'.
- The `last_job_output` location (from 'LAST_JOB_OUTPUT'), which, if it is an S3 path,
prefixes the keys in the mapping with that path.
Returns:
Dict[str, str]: A dictionary where keys are input file paths and values are corresponding
output file paths. The paths may be modified by the 'LAST_JOB_OUTPUT' if it
is set and points to an S3 location.
Raises:
No explicit exceptions are raised within the function, but it relies heavily on external
functions and inputs (often sourced from environment variables) which, if improperly
configured, could lead to unexpected behavior.
"""
file_types = get_filetype_for_resume(
resume_discard_seeding=resume_discard_seeding, flepi_block_index=flepi_block_index
)
resume_file_name_mapping = dict()
liketypes = ["global", "chimeric"]
for filetype in file_types:
for liketype in liketypes:
output_file_name = create_resume_out_filename(
flepi_run_index=flepi_run_index,
flepi_prefix=flepi_prefix,
flepi_slot_index=flepi_slot_index,
flepi_block_index=flepi_block_index,
filetype=filetype,
liketype=liketype,
)
input_file_name = output_file_name
if flepi_block_index == "1":  # use the passed-in block index rather than re-reading the environment
input_file_name = create_resume_input_filename(
resume_run_index=resume_run_index,
flepi_prefix=flepi_prefix,
flepi_slot_index=flepi_slot_index,
filetype=filetype,
liketype=liketype,
)
resume_file_name_mapping[input_file_name] = output_file_name
if last_job_output.find("s3://") >= 0:
old_keys = list(resume_file_name_mapping.keys())
for k in old_keys:
new_key = os.path.join(last_job_output, k)
resume_file_name_mapping[new_key] = resume_file_name_mapping[k]
del resume_file_name_mapping[k]
return resume_file_name_mapping
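
Putting the pieces together, a usage sketch based on the test at the bottom of this diff. With an S3 `last_job_output`, every key in the returned map is an S3 URI and every value a local output path:

    from gempyor.utils import create_resume_file_names_map

    name_map = create_resume_file_names_map(
        resume_discard_seeding="false",
        flepi_block_index="2",
        resume_run_index="321",
        flepi_prefix="output",
        flepi_slot_index="2",
        flepi_run_index="123",
        last_job_output="s3://bucket",
    )
    # Keys look like "s3://bucket/model_output/output/123/<filetype>/<liketype>/intermediate/...",
    # ready to be passed to download_file_from_s3 below.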


def download_file_from_s3(name_map: Dict[str, str]) -> None:
"""
Downloads files from AWS S3 based on a mapping of S3 URIs to local file paths. The function
checks if the directory for the first output file exists and creates it if necessary. It
then iterates over each S3 URI in the provided mapping, downloads the file to the corresponding
local path, and handles errors if the S3 URI format is incorrect or if the download fails.
Parameters:
name_map (Dict[str, str]): A dictionary where keys are S3 URIs (strings) and values
are the local file paths (strings) where the files should
be saved.
Returns:
None: This function does not return a value; its primary effect is the side effect of
downloading files and potentially creating directories.
Raises:
ValueError: If an S3 URI does not start with 's3://', indicating an invalid format.
ClientError: If an error occurs during the download from S3, such as a permissions issue,
a missing file, or network-related errors. These are caught and logged but not
re-raised, to allow the function to attempt subsequent downloads.
Examples:
>>> name_map = {
"s3://mybucket/data/file1.txt": "/local/path/to/file1.txt",
"s3://mybucket/data/file2.txt": "/local/path/to/file2.txt"
}
>>> download_file_from_s3(name_map)
# This would download 'file1.txt' and 'file2.txt' from 'mybucket' on S3 to the specified local paths.
# If an S3 URI is malformed:
>>> name_map = {
"http://wrongurl.com/data/file1.txt": "/local/path/to/file1.txt"
}
>>> download_file_from_s3(name_map)
# This will raise a ValueError indicating the invalid S3 URI format.
"""
s3 = boto3.client("s3")
first_output_filename = next(iter(name_map.values()))
output_dir = os.path.dirname(first_output_filename)
if not os.path.exists(output_dir):
os.makedirs(output_dir)

for s3_uri in name_map:
try:
if s3_uri.startswith("s3://"):
# "s3://bucket/key": the bucket is the third "/"-separated token, and the key
# begins after "s3://" + bucket + "/", i.e. at offset len(bucket) + 6.
bucket = s3_uri.split("/")[2]
object = s3_uri[len(bucket) + 6 :]
s3.download_file(bucket, object, name_map[s3_uri])
else:
raise ValueError(f"Invalid S3 URI format {s3_uri}")
except ClientError as e:
print(f"An error occurred: {e}")
print("Could not download file from s3")


def move_file_at_local(name_map: Dict[str, str]) -> None:
"""
Moves files locally according to a given mapping.
This function takes a dictionary where the keys are source file paths and
the values are destination file paths. It ensures that the destination
directories exist and then copies the files from the source paths to the
destination paths.
Parameters:
name_map (Dict[str, str]): A dictionary mapping source file paths to
destination file paths.
Returns:
None
"""
for src, dst in name_map.items():
os.makedirs(os.path.dirname(dst), exist_ok=True)
shutil.copy(src, dst)
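
A minimal sketch of the local (non-S3) resume path, with hypothetical paths. Note that despite its name, the function copies with shutil.copy rather than moving, so source files are left in place:

    from gempyor.utils import move_file_at_local

    move_file_at_local({
        "/tmp/last_job/000000002.321.spar.parquet": "model_output/output/123/spar/global/final/000000002.321.spar.parquet",
    })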
83 changes: 72 additions & 11 deletions flepimop/gempyor_pkg/tests/utils/test_utils.py
@@ -1,26 +1,16 @@
import pytest
import datetime
import os
import pandas as pd

# import dask.dataframe as dd
import pyarrow as pa
import time

from gempyor import utils

DATA_DIR = os.path.dirname(__file__) + "/data"
# os.chdir(os.path.dirname(__file__))

tmp_path = "/tmp"


@pytest.mark.parametrize(
("fname", "extension"),
[
("mobility", "csv"),
("usa-geoid-params-output", "parquet"),
],
("fname", "extension"), [("mobility", "csv"), ("usa-geoid-params-output", "parquet"),],
)
def test_read_df_and_write_success(fname, extension):
os.chdir(tmp_path)
@@ -90,3 +80,74 @@ def test_get_truncated_normal_success():

def test_get_log_normal_success():
utils.get_log_normal(meanlog=0, sdlog=1)


def test_create_resume_out_filename():
result = utils.create_resume_out_filename(
flepi_run_index="123",
flepi_prefix="output",
flepi_slot_index="2",
flepi_block_index="2",
filetype="spar",
liketype="global",
)
expected_filename = (
"model_output/output/123/spar/global/intermediate/000000002.000000001.000000001.123.spar.parquet"
)
assert result == expected_filename

result2 = utils.create_resume_out_filename(
flepi_run_index="123",
flepi_prefix="output",
flepi_slot_index="2",
flepi_block_index="2",
filetype="seed",
liketype="chimeric",
)
expected_filename2 = "model_output/output/123/seed/chimeric/intermediate/000000002.000000001.000000001.123.seed.csv"
assert result2 == expected_filename2


def test_create_resume_input_filename():

result = utils.create_resume_input_filename(
flepi_slot_index="2", resume_run_index="321", flepi_prefix="output", filetype="spar", liketype="global"
)
expect_filename = "model_output/output/321/spar/global/final/000000002.321.spar.parquet"

assert result == expect_filename

result2 = utils.create_resume_input_filename(
flepi_slot_index="2", resume_run_index="321", flepi_prefix="output", filetype="seed", liketype="chimeric"
)
expect_filename2 = "model_output/output/321/seed/chimeric/final/000000002.321.seed.csv"
assert result2 == expect_filename2


def test_get_filetype_resume_discard_seeding_true_flepi_block_index_1():
expected_types = ["spar", "snpi", "hpar", "hnpi", "init"]
assert utils.get_filetype_for_resume(resume_discard_seeding="true", flepi_block_index="1") == expected_types


def test_get_filetype_resume_discard_seeding_false_flepi_block_index_1():
expected_types = ["seed", "spar", "snpi", "hpar", "hnpi", "init"]
assert utils.get_filetype_for_resume(resume_discard_seeding="false", flepi_block_index="1") == expected_types


def test_get_filetype_flepi_block_index_2():
expected_types = ["seed", "spar", "snpi", "hpar", "hnpi", "host", "llik", "init"]
assert utils.get_filetype_for_resume(resume_discard_seeding="false", flepi_block_index="2") == expected_types


def test_create_resume_file_names_map():
name_map = utils.create_resume_file_names_map(
resume_discard_seeding="false",
flepi_block_index="2",
resume_run_index="321",
flepi_prefix="output",
flepi_slot_index="2",
flepi_run_index="123",
last_job_output="s3://bucket",
)
for k in name_map:
assert k.find("s3://bucket") >= 0
