diff --git a/.github/workflows/linter.yml b/.github/workflows/linter.yml index 2ceae8a..a444d41 100644 --- a/.github/workflows/linter.yml +++ b/.github/workflows/linter.yml @@ -24,7 +24,7 @@ jobs: - name: Install python dependencies run: | python -m pip install --upgrade pip - pip install pytest pylint==2.17.5 python-dateutil==2.8.2 pint==0.21 importlib-metadata==6.7.0 jsonschema==4.19.0 pika==1.3.1 pyproj numpy==1.26.2 shapely==2.0.2 netcdf4==1.6.3 h5netcdf==1.1.0 pillow==10.2.0 python-logging-rabbitmq==2.3.0 + pip install pytest pytest_httpserver pylint==2.17.5 requests==2.31.0 python-dateutil==2.8.2 pint==0.21 importlib-metadata==6.7.0 jsonschema==4.19.0 pika==1.3.1 pyproj numpy==1.26.2 shapely==2.0.2 netcdf4==1.6.3 h5netcdf==1.1.0 pillow==10.2.0 python-logging-rabbitmq==2.3.0 - name: Checkout idss-engine-commons uses: actions/checkout@v2 @@ -35,6 +35,17 @@ jobs: - name: Install IDSSE python commons working-directory: commons/python/idsse_common run: pip install . + + - name: Checkout idsse-testing + uses: actions/checkout@v2 + with: + repository: NOAA-GSL/idsse-testing + ref: main + path: testing/ + + - name: Install IDSSE python testing + working-directory: testing/python + run: pip install . - name: Set PYTHONPATH for pylint run: | diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 35cca26..9c13c0b 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -23,7 +23,7 @@ jobs: - name: Install python dependencies run: | python -m pip install --upgrade pip - pip install pytest pylint==2.17.5 python-dateutil==2.8.2 pint==0.21 importlib-metadata==6.7.0 jsonschema==4.19.0 pika==1.3.1 pyproj==3.6.1 numpy==1.26.2 shapely==2.0.2 netcdf4==1.6.3 h5netcdf==1.1.0 pytest-cov==4.1.0 pillow==10.2.0 python-logging-rabbitmq==2.3.0 + pip install pytest pytest_httpserver requests==2.31.0 pylint==2.17.5 python-dateutil==2.8.2 pint==0.21 importlib-metadata==6.7.0 jsonschema==4.19.0 pika==1.3.1 pyproj==3.6.1 numpy==1.26.2 shapely==2.0.2 netcdf4==1.6.3 h5netcdf==1.1.0 pytest-cov==4.1.0 pillow==10.2.0 python-logging-rabbitmq==2.3.0 - name: Set PYTHONPATH for pytest run: | @@ -39,6 +39,17 @@ jobs: working-directory: commons/python/idsse_common run: pip install . + - name: Checkout idsse-testing + uses: actions/checkout@v2 + with: + repository: NOAA-GSL/idsse-testing + ref: main + path: testing/ + + - name: Install IDSSE python testing + working-directory: testing/python + run: pip install . + - name: Test with pytest working-directory: python/idsse_common # run Pytest, exiting nonzero if pytest throws errors (otherwise "| tee" obfuscates) diff --git a/python/idsse_common/idsse/common/aws_utils.py b/python/idsse_common/idsse/common/aws_utils.py index e54af6d..bb8b15d 100644 --- a/python/idsse_common/idsse/common/aws_utils.py +++ b/python/idsse_common/idsse/common/aws_utils.py @@ -2,56 +2,37 @@ # ------------------------------------------------------------------------------- # Created on Tue Feb 14 2023 # -# Copyright (c) 2023 Regents of the University of Colorado. All rights reserved. +# Copyright (c) 2023 Colorado State University. All rights reserved. (1) +# Copyright (c) 2023 Regents of the University of Colorado. All rights reserved. 
(2) # # Contributors: -# Geary J Layne +# Geary Layne (2) +# Paul Hamer (1) # # ------------------------------------------------------------------------------- import logging -import fnmatch import os from collections.abc import Sequence -from datetime import datetime, timedelta, UTC -from .path_builder import PathBuilder -from .utils import TimeDelta, datetime_gen, exec_cmd +from .protocol_utils import ProtocolUtils +from .utils import exec_cmd logger = logging.getLogger(__name__) -class AwsUtils(): +class AwsUtils(ProtocolUtils): """AWS Utility Class""" - def __init__(self, - basedir: str, - subdir: str, - file_base: str, - file_ext: str) -> None: - self.path_builder = PathBuilder(basedir, subdir, file_base, file_ext) - - def get_path(self, issue: datetime, valid: datetime) -> str: - """Delegates to instant PathBuilder to get full path given issue and valid - - Args: - issue (datetime): Issue date time - valid (datetime): Valid date time - - Returns: - str: Absolute path to file or object - """ - lead = TimeDelta(valid-issue) - return self.path_builder.build_path(issue=issue, valid=valid, lead=lead) - - def aws_ls(self, path: str, prepend_path: bool = True) -> Sequence[str]: - """Execute an 'ls' on the AWS s3 bucket specified by path + def ls(self, path: str, prepend_path: bool = True) -> Sequence[str]: + """Execute a 'ls' on the AWS s3 bucket specified by path Args: path (str): s3 bucket + prepend_path (bool): If True, prepend the path to each returned filename Returns: - Sequence[str]: The results sent to stdout from executing an 'ls' on passed path + Sequence[str]: The results sent to stdout from executing a 'ls' on passed path """ try: commands = ['s5cmd', '--no-sign-request', 'ls', path] @@ -65,12 +46,8 @@ def aws_ls(self, path: str, prepend_path: bool = True) -> Sequence[str]: return [os.path.join(path, filename.split(' ')[-1]) for filename in commands_result] return [filename.split(' ')[-1] for filename in commands_result] - def aws_cp(self, - path: str, - dest: str, - concurrency: int | None = None, - chunk_size: int | None = None) -> bool: - """Execute an 'cp' on the AWS s3 bucket specified by path, dest. Attempts to use + def cp(self, path: str, dest: str, concurrency: int | None = None, chunk_size: int | None = None) -> bool: + """Execute a 'cp' on the AWS s3 bucket specified by path, dest. Attempts to use [s5cmd](https://github.com/peak/s5cmd) to copy the file from S3 with parallelization, but falls back to (slower) aws-cli if s5cmd is not installed or throws an error. @@ -78,7 +55,7 @@ def aws_cp(self, path (str): Relative or Absolute path to the object to be copied dest (str): The destination location concurrency (optional, int): Number of parallel threads for s5cmd to use to copy - the file down from AWS (may be helpful to tweak for large files). + the file down from AWS (may be helpful to tweak for large files). Default is None (s5cmd default). chunk_size (optional, int): Size of chunks (in MB) for s5cmd to split up the source AWS S3 file so it can download quicker with more threads. 
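A minimal usage sketch of the refactored API (the bucket name and path format strings below are hypothetical; AwsUtils now inherits ls/cp plus get_path/check_for/get_issues/get_valids from the new ProtocolUtils base introduced later in this patch):

from idsse.common.aws_utils import AwsUtils

# hypothetical bucket and path formats -- each placeholder carries an explicit size/type
aws = AwsUtils(basedir='s3://my-bucket/',
               subdir='{issue.year:04d}{issue.month:02d}{issue.day:02d}/',
               file_base='prod.t{issue.hour:02d}z.f{lead.hour:03d}',
               file_ext='.grib2')

files = aws.ls('s3://my-bucket/20241203/')   # formerly aws_ls()
if files:
    aws.cp(files[0], '/tmp/latest.grib2')    # formerly aws_cp(); falls back to aws-cli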
@@ -111,114 +88,3 @@ def aws_cp(self, return False finally: pass - - def check_for(self, issue: datetime, valid: datetime) -> tuple[datetime, str] | None: - """Checks if an object passed issue/valid exists - - Args: - issue (datetime): The issue date/time used to format the path to the object's location - valid (datetime): The valid date/time used to format the path to the object's location - - Returns: - [tuple[datetime, str] | None]: A tuple of the valid date/time (indicated by object's - location) and location (path) of a object, or None - if object does not exist - """ - lead = TimeDelta(valid-issue) - file_path = self.get_path(issue, valid) - dir_path = os.path.dirname(file_path) - filenames = self.aws_ls(file_path, prepend_path=False) - filename = self.path_builder.build_filename(issue=issue, valid=valid, lead=lead) - for fname in filenames: - # Support wildcard matches - used for '?' as a single wildcard character in - # issue/valid time specs. - if fnmatch.fnmatch(os.path.basename(fname), filename): - return (valid, os.path.join(dir_path, fname)) - return None - - def get_issues(self, - num_issues: int = 1, - issue_start: datetime | None = None, - issue_end: datetime = datetime.now(UTC), - time_delta: timedelta = timedelta(hours=1) - ) -> Sequence[datetime]: - """Determine the available issue date/times - - Args: - num_issues (int): Maximum number of issue to return. Defaults to 1. - issue_start (datetime, optional): The oldest date/time to look for. Defaults to None. - issue_end (datetime): The newest date/time to look for. Defaults to now (UTC). - time_delta (timedelta): The time step size. Defaults to 1 hour. - - Returns: - Sequence[datetime]: A sequence of issue date/times - """ - zero_time_delta = timedelta(seconds=0) - if time_delta == zero_time_delta: - raise ValueError('Time delta must be non zero') - - issues_set: set[datetime] = set() - if issue_start: - datetimes = datetime_gen(issue_end, time_delta, issue_start, num_issues) - else: - # check if time delta is positive, if so make negative - if time_delta > zero_time_delta: - time_delta = timedelta(seconds=-1.0 * time_delta.total_seconds()) - datetimes = datetime_gen(issue_end, time_delta) - for issue_dt in datetimes: - if issue_start and issue_dt < issue_start: - break - try: - dir_path = self.path_builder.build_dir(issue=issue_dt) - issues = {self.path_builder.get_issue(file_path) - for file_path in self.aws_ls(dir_path) - if file_path.endswith(self.path_builder.file_ext)} - issues_set.update(issues) - if num_issues and len(issues_set) >= num_issues: - break - except PermissionError: - pass - return sorted(issues_set, reverse=True)[:num_issues] - - def get_valids(self, - issue: datetime, - valid_start: datetime | None = None, - valid_end: datetime | None = None) -> Sequence[tuple[datetime, str]]: - """Get all objects consistent with the passed issue date/time and filter by valid range - - Args: - issue (datetime): The issue date/time used to format the path to the object's location - valid_start (datetime | None, optional): All returned objects will be for - valids >= valid_start. Defaults to None. - valid_end (datetime | None, optional): All returned objects will be for - valids <= valid_end. Defaults to None. - - Returns: - Sequence[tuple[datetime, str]]: A sequence of tuples with valid date/time (indicated by - object's location) and the object's location (path). - Empty Sequence if no valids found for given time range. 
- """ - if valid_start and valid_start == valid_end: - valids_and_filenames = self.check_for(issue, valid_start) - return [valids_and_filenames] if valids_and_filenames is not None else [] - - dir_path = self.path_builder.build_dir(issue=issue) - valid_and_file = [(self.path_builder.get_valid(file_path), file_path) - for file_path in self.aws_ls(dir_path) - if file_path.endswith(self.path_builder.file_ext)] - - if valid_start: - if valid_end: - valid_and_file = [(valid, filename) - for valid, filename in valid_and_file - if valid_start <= valid <= valid_end] - else: - valid_and_file = [(valid, filename) - for valid, filename in valid_and_file - if valid >= valid_start] - elif valid_end: - valid_and_file = [(valid, filename) - for valid, filename in valid_and_file - if valid <= valid_end] - - return valid_and_file diff --git a/python/idsse_common/idsse/common/http_utils.py b/python/idsse_common/idsse/common/http_utils.py new file mode 100644 index 0000000..6da83fe --- /dev/null +++ b/python/idsse_common/idsse/common/http_utils.py @@ -0,0 +1,80 @@ +"""Helper class for listing directories and retrieving objects over http(s)""" +# ------------------------------------------------------------------------------- +# Created on Tue Dec 3 2024 +# +# Copyright (c) 2023 Colorado State University. All rights reserved. (1) +# Copyright (c) 2023 Regents of the University of Colorado. All rights reserved. (2) +# +# Contributors: +# Paul Hamer (1) +# +# ------------------------------------------------------------------------------- +import logging +import os +import shutil +from collections.abc import Sequence + +import requests + +from .protocol_utils import ProtocolUtils + +logger = logging.getLogger(__name__) + +class HttpUtils(ProtocolUtils): + """http Utility Class - Used by DAS for file downloads""" + + + def ls(self, path: str, prepend_path: bool = True) -> Sequence[str]: + """Execute a 'ls' on the http(s) server + Args: + path (str): path + prepend_path (bool): If True, prepend the path to each returned filename + Returns: + Sequence[str]: The results from executing a GET request on the passed path + """ + try: + files = [] + response = requests.get(path, timeout=5) + response.raise_for_status() # Raise an exception for bad status codes + + for line in response.text.splitlines(): + if 'href="' in line: + filename = line.split('href="')[1].split('"')[0] + + # Exclude directories and files without the expected suffix + if not filename.endswith('/') and filename.endswith(self.path_builder.file_ext): + files.append(filename) + + except requests.exceptions.RequestException as exp: + logger.warning('Unable to query supplied Path : %s', str(exp)) + return [] + files = sorted(files, reverse=True) + if prepend_path: + return [os.path.join(path, filename) for filename in files] + return files + + # pylint: disable=unused-argument + def cp(self, path: str, dest: str, concurrency: int | None = None, chunk_size: int | None = None) -> bool: + """Execute http request download from path to dest. 
+ + Args: + path (str): Path to the object to be copied + dest (str): The destination location + concurrency (optional, int): Number of parallel threads - ignored + chunk_size (optional, int): Size of chunks (in MB) - ignored + Returns: + bool: Returns True if copy is successful + """ + try: + with requests.get(path, timeout=5, stream=True) as response: + # Check if the request was successful + if response.status_code == 200: + # Open a file in binary write mode + with open(dest, "wb") as file: + shutil.copyfileobj(response.raw, file) + return True + + logger.debug('copy failed: request status code: %s', response.status_code) + return False + except Exception: # pylint: disable=broad-exception-caught + return False diff --git a/python/idsse_common/idsse/common/path_builder.py b/python/idsse_common/idsse/common/path_builder.py index 92a5286..2c77cd7 100644 --- a/python/idsse_common/idsse/common/path_builder.py +++ b/python/idsse_common/idsse/common/path_builder.py @@ -18,30 +18,49 @@ import os import re +from copy import deepcopy from datetime import datetime, timedelta, UTC -from typing import Self +from typing import Final, NamedTuple, Self from .utils import TimeDelta +# The public class class PathBuilder: """Path Builder Class""" + ISSUE: Final[str] = 'issue' + VALID: Final[str] = 'valid' + LEAD: Final[str] = 'lead' + INT: Final[str] = 'd' + FLOAT: Final[str] = 'f' + STR: Final[str] = 's' + def __init__(self, basedir: str, subdir: str, file_base: str, file_ext: str) -> None: - self._basedir = basedir - self._subdir = subdir + + # store path format parts to private vars, they are accessible via properties + self._base_dir = basedir + self._sub_dir = subdir self._file_base = file_base self._file_ext = file_ext + # create a dictionary to hold lookup info + self._lookup_dict = {} + self._update_lookup(self.path_fmt) + + # these are used for caching the most recent previously parsed paths (for optimization) + self._last_parsed_path = None + self._parsed_args = None + def __str__(self) -> str: - return f"'{self._basedir}','{self._subdir}','{self._file_base}','{self._file_ext}'" + return f"'{self._base_dir}','{self._sub_dir}','{self._file_base}','{self._file_ext}'" def __repr__(self) -> str: - return (f"PathBuilder(basedir='{self._basedir}', subdir='{self._subdir}', " + return (f"PathBuilder(basedir='{self._base_dir}', subdir='{self._sub_dir}', " f"file_base='{self._file_base}', file_ext='{self._file_ext}')") @classmethod @@ -73,51 +92,80 @@ def from_path(cls, path_fmt: str) -> Self: Self: The newly created PathBuilder object """ idx = path_fmt.rindex(os.path.sep) - return PathBuilder(path_fmt[:idx], '', path_fmt[:idx], '') + return PathBuilder(path_fmt[:idx], '', path_fmt[idx+1:], '') @property def dir_fmt(self): - """Provides access to the directory format str""" - return os.path.join(self._basedir, self._subdir) + """Provides access to the directory format string""" + return os.path.join(self.base_dir, self.sub_dir) @property def filename_fmt(self): - """Provides access to the filename format str""" - if not self._file_ext or self._file_ext.startswith('.'): - return f'{self._file_base}{self._file_ext}' - return f'{self._file_base}.{self._file_ext}' + """Provides access to the filename format string""" + if not self.file_ext or self.file_ext.startswith('.'): + return f'{self.file_base}{self.file_ext}' + return f'{self.file_base}.{self.file_ext}' + + @property + def path_fmt(self): + """Provides access to the path format string""" + return os.path.join(self.dir_fmt, self.filename_fmt)
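Under the reworked PathBuilder (the parsing internals follow below in this file's diff), every '{...}' placeholder must declare an explicit width and a type of 'd', 'f', or 's'; a hedged sketch of the intended round trip, with made-up format strings:

from datetime import datetime, timedelta, UTC
from idsse.common.path_builder import PathBuilder

pb = PathBuilder('s3://bucket',                                   # hypothetical formats
                 '{region:2s}/{issue.year:04d}{issue.month:02d}{issue.day:02d}',
                 'prod.t{issue.hour:02d}z.f{lead.hour:03d}',
                 '.grib2')

issue = datetime(2024, 12, 3, 12, tzinfo=UTC)
path = pb.build_path(issue=issue, valid=issue + timedelta(hours=3), region='co')
# expected: s3://bucket/co/20241203/prod.t12z.f003.grib2

args = pb.parse_path(path)
# expected to contain the extra 'region' plus derived 'issue', 'valid', and 'lead' entries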
@property def base_dir(self): - """Provides access to the file base directory format str""" - return self._basedir + """Provides access to the file base directory format string""" + return self._base_dir @base_dir.setter def base_dir(self, basedir): - """Set the file extension format str""" - self._basedir = basedir + """Set the base directory format string""" + # update base directory format + self._base_dir = basedir + self._update_lookup(self.path_fmt) + + @property + def sub_dir(self): + """Provides access to the sub directory format string""" + return self._sub_dir + + @sub_dir.setter + def sub_dir(self, subdir): + """Set the sub directory format string""" + # update sub directory format + self._sub_dir = subdir + self._update_lookup(self.path_fmt) + + @property + def file_base(self): + """Provides access to the file base format string""" + return self._file_base + + @file_base.setter + def file_base(self, file_base): + """Set the file base format string""" + # update file base format + self._file_base = file_base + self._update_lookup(self.path_fmt) @property def file_ext(self): - """Provides access to the file extension format str""" + """Provides access to the file extension format string""" if self._file_ext: return self._file_ext return self._file_base[self._file_base.rindex('.'):] @file_ext.setter def file_ext(self, ext): - """Set the file extension format str""" + """Set the file extension format string""" + # update file extension format self._file_ext = ext - - @property - def path_fmt(self): - """Provides access to the path format str""" - return os.path.join(self.dir_fmt, self.filename_fmt) + self._update_lookup(self.path_fmt) def build_dir(self, issue: datetime | None = None, valid: datetime | None = None, - lead: timedelta | TimeDelta | None = None) -> str: + lead: timedelta | TimeDelta | None = None, + **kwargs) -> str: """Attempts to build the directory with provided arguments Args: @@ -127,6 +175,7 @@ directory is dependant on it. . Defaults to None. lead (timedelta | TimeDelta | None, optional): Lead can be provided in addition to issue or valid. Defaults to None. + **kwargs: Any additional keyword args (e.g. 'region'='co') Returns: str: Directory as a string @@ -134,12 +183,13 @@ if issue is None: return None lead = self._ensure_lead(issue, valid, lead) - return self.dir_fmt.format(issue=issue, valid=valid, lead=lead) + return self.dir_fmt.format(issue=issue, valid=valid, lead=lead, **kwargs) def build_filename(self, issue: datetime | None = None, valid: datetime | None = None, - lead: timedelta | TimeDelta | None = None) -> str: + lead: timedelta | TimeDelta | None = None, + **kwargs) -> str: """Attempts to build the filename with provided arguments Args: @@ -149,17 +199,19 @@ filename is dependant on it. . Defaults to None. lead (timedelta | TimeDelta | None, optional): Lead can be provided in addition to issue or valid. Defaults to None. + **kwargs: Any additional keyword args (e.g. 
'region'='co') Returns: str: File name as a string """ lead = self._ensure_lead(issue, valid, lead) - return self.filename_fmt.format(issue=issue, valid=valid, lead=lead) + return self.filename_fmt.format(issue=issue, valid=valid, lead=lead, **kwargs) def build_path(self, issue: datetime | None = None, valid: datetime | None = None, - lead: timedelta | TimeDelta | None = None) -> str: + lead: timedelta | TimeDelta | None = None, + **kwargs: dict) -> str: """Attempts to build the path with provided arguments Args: @@ -169,26 +221,27 @@ path is dependant on it. . Defaults to None. lead (timedelta | TimeDelta | None, optional): Lead can be provided in addition to issue or valid. Defaults to None. + **kwargs: Any additional keyword args (e.g. 'region'='co') Returns: str: Path as a string """ lead = self._ensure_lead(issue, valid, lead) - return self.path_fmt.format(issue=issue, valid=valid, lead=lead) + return self._apply_format(self.path_fmt, issue=issue, valid=valid, lead=lead, **kwargs) - def parse_dir(self, dir_: str) -> dict: + def parse_dir(self, dir_str: str) -> dict: """Extracts issue, valid, and/or lead information from the provided directory Args: - dir_ (str): A directory consistent with this PathBuilder + dir_str (str): A directory consistent with this PathBuilder Returns: dict: Lookup of all information identified and extracted """ - return self._parse_times(dir_, self.dir_fmt) + return self._get_parsed_arg_parts(dir_str, self.dir_fmt) def parse_filename(self, filename: str) -> dict: - """Extracts issue, valid, and/or lead information from the provided filename + """Extracts issue, valid, lead, and any extra information from the provided filename Args: filename (str): A filename consistent with this PathBuilder @@ -197,10 +250,11 @@ dict: Lookup of all information identified and extracted """ filename = os.path.basename(filename) - return self._parse_times(filename, self.filename_fmt) + self._parse_path(filename, self.filename_fmt) + return deepcopy(self._parsed_args) def parse_path(self, path: str) -> dict: - """Extracts issue, valid, and/or lead information from the provided path + """Extracts issue, valid, lead, and any extra information from the provided path Args: path (str): A path consistent with this PathBuilder @@ -208,36 +262,177 @@ Returns: dict: Lookup of all information identified and extracted """ - return self._parse_times(path, self.path_fmt) + # do the core parsing + self._parse_path(path, self.path_fmt) + # return a copy of parsed_args + return deepcopy(self._parsed_args) - def get_issue(self, path: str) -> datetime: + def get_issue(self, path: str) -> datetime | None: """Retrieves the issue date/time from the provided path Args: path (str): A path consistent with this PathBuilder Returns: - datetime: After parsing the path, builds and returns the issue date/time + datetime | None: After parsing the path, builds and returns the issue date/time if + possible, else returns None if insufficient info is available to build """ - time_args = self.parse_path(path) - return self.get_issue_from_time_args(time_args) + # do the core parsing + self._parse_path(path, self.path_fmt) + # return the issue datetime, if determined + return self._parsed_args.get(self.ISSUE, None) - def get_valid(self, path: str) -> datetime: + def get_valid(self, path: str) -> datetime | None: """Retrieves the valid date/time from the provided path Args: path (str): A path consistent with this PathBuilder Returns: - datetime: After parsing the path, builds and returns the valid date/time + datetime | None: After parsing the path, builds and returns the valid date/time if + possible, else returns None if insufficient info is available to build + """ + # do the core parsing + self._parse_path(path, self.path_fmt) + # return the valid datetime, if determined + return self._parsed_args.get(self.VALID, None) + + def _update_lookup(self, fmt_str: str) -> None: + """This method should be called whenever some part of the format has been changed. + + Args: + fmt_str (str): The changed format (either a part of, or the complete combined, format) + + Raises: + ValueError: If the format is not descriptive enough. Formats must specify size and type. """ - time_args = self.parse_path(path) - return self.get_valid_from_time_args(time_args) + # if a format is being updated any cache will be out of date + self._last_parsed_path = None + + for fmt_part in os.path.normpath(fmt_str).split(os.sep): + remaining_fmt_part = fmt_part + lookup_info = [] + cum_start = 0 + while (re_match := re.search(r'\{(.*?)\}', remaining_fmt_part)): + arg_parts = re_match.group()[1:-1].split(':') + if len(arg_parts) != 2: + raise ValueError('Format string must have explicit specification ' + '(must include a ":")') + try: + arg_size = int(re.search(r'^\d+', arg_parts[1]).group()) + except Exception: + # pylint: disable=raise-missing-from + raise ValueError('Format string must have explicit size ' + '(must include a number after ":")') + arg_type = arg_parts[1][-1] + if arg_parts[1][-1] not in [self.INT, self.FLOAT, self.STR]: + raise ValueError('Format string must have explicit type (must include one of ' + f'["{self.INT}", "{self.FLOAT}", "{self.STR}"] after ":")') + + arg_start = re_match.start() + cum_start + arg_end = cum_start = arg_start + arg_size + lookup_info.append(_LookupInfo(arg_parts[0], arg_start, arg_end, arg_type)) + # update the format str to find the next argument + remaining_fmt_part = remaining_fmt_part[re_match.end():] + + exp_len = (sum(end-start for _, start, end, _ in lookup_info) + + len(re.sub(r'\{(.*?)\}', '', fmt_part))) + + self._lookup_dict[fmt_part] = _FormatLookup(exp_len, lookup_info) + # add default for empty string + self._lookup_dict[''] = _FormatLookup(0, []) + + def _parse_path(self, path: str, fmt_str: str) -> None: + """Parse a path for any knowable variables given the provided format string. + + Args: + path (str): The path string to be parsed + fmt_str (str): The format string that the path is assumed to correspond with + """ + if path != self._last_parsed_path: + parsed_arg_parts = self._get_parsed_arg_parts(path, fmt_str) + issue_dt = self._get_issue_from_arg_parts(parsed_arg_parts) + valid_dt = self._get_valid_from_arg_parts(parsed_arg_parts) + # add the issue/valid/lead datetime and timedelta objects + if issue_dt: + parsed_arg_parts[self.ISSUE] = issue_dt + if valid_dt: + parsed_arg_parts[self.VALID] = valid_dt + if issue_dt and valid_dt: + parsed_arg_parts[self.LEAD] = TimeDelta(valid_dt - issue_dt) + + # cache this path and the parsed_arg_parts for repeat requests + self._last_parsed_path = path + self._parsed_args = parsed_arg_parts + + def _get_parsed_arg_parts(self, path: str, fmt_str: str) -> dict: + """Build a dictionary of knowable variables based on path and format string. This + dictionary can be used to create complete issue/valid datetimes and/or contain + extra variables. 
+ + Args: + path (str): The path string from which variables will be extracted + fmt_str (str): The format string used to identify where variables can be found + + Raises: + ValueError: If the path string does not conform to the format string (not expected len) + + Returns: + dict: Dictionary of variables + """ + # Split path and format strings into lists of parts, either dir and/or filenames + fmt_parts = os.path.normpath(fmt_str).split(os.sep) + path_parts = os.path.normpath(path).split(os.sep) + + parsed_arg_parts = {} + for path_part, fmt_part in zip(path_parts, fmt_parts): + expected_len, lookup_info = self._lookup_dict[fmt_part] + if (part_len := len(path_part)) != expected_len: + raise ValueError('Path is not expected length. Passed path part ' + f"'{path_part}' doesn't match format '{fmt_part}'") + for lookup in lookup_info: + if not (0 <= lookup.start <= part_len and 0 <= lookup.end <= part_len): + raise ValueError('Parse indices are out of range for path') + try: + match lookup.type: + case self.INT: + parsed_arg_parts[lookup.key] = int(path_part[lookup.start:lookup.end]) + case self.FLOAT: + parsed_arg_parts[lookup.key] = float(path_part[lookup.start:lookup.end]) + case self.STR: + parsed_arg_parts[lookup.key] = path_part[lookup.start:lookup.end] + except ValueError as exc: + raise ValueError('Unable to apply formatting') from exc + return parsed_arg_parts + + def _apply_format(self, fmt_str: str, **kwargs) -> str: + """Use the format string and any variables in the kwargs to build a path. + + Args: + fmt_str (str): A format string, for part of or the whole path + + Raises: + ValueError: If the generated path part does not match expected length + + Returns: + str: A string representation of a system path, combined with os specific separators + """ + path_parts = [] + # we split the format string without normalizing to maintain user specified path + # structure such as a double separator (sometimes this can be needed) + for fmt_part in fmt_str.split(os.sep): + path_part = fmt_part.format_map(kwargs) + if len(path_part) == self._lookup_dict[fmt_part].exp_len: + path_parts.append(path_part) + else: + raise ValueError('Arguments generate a path that violates ' + f"at least part of the format, part '{fmt_part}'") + return os.path.sep.join(path_parts) @staticmethod - def get_issue_from_time_args(parsed_args: dict, - valid: datetime | None = None, - lead: timedelta | None = None) -> datetime: + def _get_issue_from_arg_parts(parsed_args: dict, + valid: datetime | None = None, + lead: timedelta | None = None) -> datetime: """Static method for creating an issue date/time from parsed arguments and optional inputs Args: @@ -260,22 +455,18 @@ parsed_args.get('issue.second', 0), parsed_args.get('issue.microsecond', 0), tzinfo=UTC) - if lead is None and 'lead.hour' in parsed_args: - lead = PathBuilder.get_lead_from_time_args(parsed_args) - + lead = PathBuilder._get_lead_from_time_args(parsed_args) if valid is None and 'valid.year' in parsed_args: - valid = PathBuilder.get_valid_from_time_args(parsed_args) - + valid = PathBuilder._get_valid_from_arg_parts(parsed_args) if valid and lead: return valid - lead - return None @staticmethod - def get_valid_from_time_args(parsed_args: dict, - issue: datetime | None = None, - lead: timedelta | None = None) -> datetime: + def _get_valid_from_arg_parts(parsed_args: dict, + issue: datetime | None = None, + lead: timedelta | None = None) -> datetime: """Static method for creating a valid date/time from parsed arguments 
and optional inputs Args: @@ -298,20 +489,16 @@ parsed_args.get('valid.second', 0), parsed_args.get('valid.microsecond', 0), tzinfo=UTC) - if lead is None and 'lead.hour' in parsed_args: - lead = PathBuilder.get_lead_from_time_args(parsed_args) - + lead = PathBuilder._get_lead_from_time_args(parsed_args) if issue is None and 'issue.year' in parsed_args: - issue = PathBuilder.get_issue_from_time_args(parsed_args) - + issue = PathBuilder._get_issue_from_arg_parts(parsed_args) if issue and lead: return issue + lead - return None @staticmethod - def get_lead_from_time_args(time_args: dict) -> timedelta: + def _get_lead_from_time_args(time_args: dict) -> timedelta: """Static method for creating a lead time from parsed arguments and optional inputs Args: @@ -323,13 +510,22 @@ """ if 'lead.hour' in time_args.keys(): return timedelta(hours=time_args['lead.hour']) - return None @staticmethod - def _ensure_lead(issue: datetime, - valid: datetime, - lead: timedelta | TimeDelta) -> TimeDelta: + def _ensure_lead(issue: datetime | None, - valid: datetime | None, hold + valid: datetime | None, + lead: timedelta | TimeDelta | None) -> TimeDelta: + """Make every attempt to ensure lead is known, by calculating or converting if needed. + + Args: + issue (datetime | None): An issue datetime if known, else None + valid (datetime | None): A valid datetime if known, else None + lead (timedelta | TimeDelta | None): A lead if known, else None + + Returns: + TimeDelta: The lead as a TimeDelta if known or calculable from issue and valid, else None + """ if lead: if isinstance(lead, timedelta): return TimeDelta(lead) @@ -338,29 +534,17 @@ return TimeDelta(valid-issue) return None - def _parse_times(self, string: str, format_str: str) -> dict: - def parse_args(key: str, value: str, result: dict): - for arg in key.split('{')[1:]: - var_name, var_size = arg.split(':') - var_type = var_size[2:3] - var_size = int(var_size[0:2]) - match var_type: - case 'd': - result[var_name] = int(value[:var_size]) - case _: - raise ValueError(f'Unknown format type: {var_type}') - key = key[var_size:] - # Check for additional characters following the end of the format element to reach - # next offset position for value... - value = value[var_size + len(arg.partition('}')[2]):] - - # Update to more generically handle time formats... - dirs = os.path.normpath(format_str).split(os.sep) - vals = os.path.normpath(string).split(os.sep) - time_args = {} - for i, _dir in enumerate(dirs): - res = re.search(r'{.*}', _dir) - if res: - parse_args(res.group(), vals[i][res.span()[0]:], time_args) - - return time_args + +# Private utility classes +class _LookupInfo(NamedTuple): + """Data class used to hold lookup info""" + key: str + start: int + end: int + type: str # should be one of 'd', 'f', 's' + + +class _FormatLookup(NamedTuple): + """Data class used to hold format and lookup info""" + exp_len: int + lookups: list[_LookupInfo] diff --git a/python/idsse_common/idsse/common/protocol_utils.py b/python/idsse_common/idsse/common/protocol_utils.py new file mode 100644 index 0000000..2ab046f --- /dev/null +++ b/python/idsse_common/idsse/common/protocol_utils.py @@ -0,0 +1,212 @@ +"""Base class for http and aws data access""" +# ------------------------------------------------------------------------------- +# Created on Tue Dec 3 2024 +# +# Copyright (c) 2023 Colorado State University. All rights reserved. (1) +# Copyright (c) 2023 Regents of the University of Colorado. All rights reserved. 
(2) +# +# Contributors: +# Paul Hamer (1) +# +# ------------------------------------------------------------------------------- +import fnmatch +import os + +from abc import abstractmethod, ABC +from collections.abc import Sequence +from datetime import datetime, timedelta, UTC + +from .path_builder import PathBuilder +from .utils import TimeDelta, datetime_gen + +class ProtocolUtils(ABC): + """Base Class - interface for DAS data discovery""" + + def __init__(self, + basedir: str, + subdir: str, + file_base: str, + file_ext: str) -> None: + self.path_builder = PathBuilder(basedir, subdir, file_base, file_ext) + + # pylint: disable=invalid-name + @abstractmethod + def ls(self, path: str, prepend_path: bool = True) -> Sequence[str]: + """Execute a 'ls' on the specified path + + Args: + path (str): path + prepend_path (bool): If True, prepend the path to each returned filename + + Returns: + Sequence[str]: The results sent to stdout from executing a 'ls' on passed path + """ + + @abstractmethod + def cp(self, path: str, dest: str) -> bool: + """Execute download from path to dest. + + Args: + path (str): The object to be copied + dest (str): The destination location + Returns: + bool: Returns True if copy is successful + """ + + + def get_path(self, issue: datetime, valid: datetime) -> str: + """Delegates to the instance's PathBuilder to get full path given issue and valid + + Args: + issue (datetime): Issue date time + valid (datetime): Valid date time + + Returns: + str: Absolute path to file or object + """ + lead = TimeDelta(valid-issue) + return self.path_builder.build_path(issue=issue, valid=valid, lead=lead) + + + def check_for(self, issue: datetime, valid: datetime) -> tuple[datetime, str] | None: + """Checks if an object for the passed issue/valid exists + + Args: + issue (datetime): The issue date/time used to format the path to the object's location + valid (datetime): The valid date/time used to format the path to the object's location + + Returns: + [tuple[datetime, str] | None]: A tuple of the valid date/time (indicated by object's + location) and location (path) of an object, or None + if object does not exist + """ + lead = TimeDelta(valid - issue) + file_path = self.get_path(issue, valid) + dir_path = os.path.dirname(file_path) + filenames = self.ls(dir_path, prepend_path=False) + filename = self.path_builder.build_filename(issue=issue, valid=valid, lead=lead) + + for fname in filenames: + # Support wildcard matches - used for '?' as a single wildcard character in + # issue/valid time specs. + if fnmatch.fnmatch(os.path.basename(fname), filename): + return valid, os.path.join(dir_path, fname) + return None + + def get_issues(self, + num_issues: int = 1, + issue_start: datetime | None = None, + issue_end: datetime = datetime.now(UTC), + time_delta: timedelta = timedelta(hours=1) + ) -> Sequence[datetime]: + """Determine the available issue date/times + + Args: + num_issues (int): Maximum number of issues to return. Defaults to 1. + issue_start (datetime, optional): The oldest date/time to look for. Defaults to None. + issue_end (datetime): The newest date/time to look for. Defaults to now (UTC). + time_delta (timedelta): The time step size. Defaults to 1 hour. 
+ + Returns: + Sequence[datetime]: A sequence of issue date/times + """ + zero_time_delta = timedelta(seconds=0) + if time_delta == zero_time_delta: + raise ValueError('Time delta must be non-zero') + + issues_set: set[datetime] = set() + if issue_start: + datetimes = datetime_gen(issue_end, time_delta, issue_start, num_issues) + else: + # check if time delta is positive, if so make negative + if time_delta > zero_time_delta: + time_delta = timedelta(seconds=-1.0 * time_delta.total_seconds()) + datetimes = datetime_gen(issue_end, time_delta) + for issue_dt in datetimes: + if issue_start and issue_dt < issue_start: + break + try: + dir_path = self.path_builder.build_dir(issue=issue_dt) + issues_set.update(self._get_issues(dir_path, num_issues)) + if num_issues and len(issues_set) >= num_issues: + break + except PermissionError: + pass + if None in issues_set: + issues_set.remove(None) + return sorted(issues_set)[:num_issues] + + def get_valids(self, + issue: datetime, + valid_start: datetime | None = None, + valid_end: datetime | None = None) -> Sequence[tuple[datetime, str]]: + """Get all objects consistent with the passed issue date/time and filter by valid range + + Args: + issue (datetime): The issue date/time used to format the path to the object's location + valid_start (datetime | None, optional): All returned objects will be for + valids >= valid_start. Defaults to None. + valid_end (datetime | None, optional): All returned objects will be for + valids <= valid_end. Defaults to None. + + Returns: + Sequence[tuple[datetime, str]]: A sequence of tuples with valid date/time (indicated by + object's location) and the object's location (path). + Empty Sequence if no valids found for given time range. + """ + if valid_start and valid_start == valid_end: + valids_and_filenames = self.check_for(issue, valid_start) + return [valids_and_filenames] if valids_and_filenames is not None else [] + + dir_path = self.path_builder.build_dir(issue=issue) + valid_and_file = [] + for file_path in self.ls(dir_path): + if file_path.endswith(self.path_builder.file_ext): + try: + if issue == self.path_builder.get_issue(file_path): + valid_and_file.append((self.path_builder.get_valid(file_path), file_path)) + except ValueError: # Ignore invalid filepaths... + pass + # Remove any tuple that has "None" as the valid time + valid_and_file = [(dt, path) for (dt, path) in valid_and_file if dt is not None] + if valid_start: + if valid_end: + valid_and_file = [(valid, filename) + for valid, filename in valid_and_file + if valid_start <= valid <= valid_end] + else: + valid_and_file = [(valid, filename) + for valid, filename in valid_and_file + if valid >= valid_start] + elif valid_end: + valid_and_file = [(valid, filename) + for valid, filename in valid_and_file + if valid <= valid_end] + + return valid_and_file + + def _get_issues(self, + dir_path: str, + num_issues: int = 1 + ) -> set[datetime]: + """Collect the issue date/times for objects found under the passed directory path + + Args: + dir_path (str): The directory path + num_issues (int): Maximum number of issues to return. Defaults to 1. + + Returns: + set[datetime]: A set of issue date/times parsed from objects found under dir_path + (an empty set if none were found). + """ + issues_set: set[datetime] = set() + for file_path in self.ls(dir_path): + if file_path.endswith(self.path_builder.file_ext): + try: + issues_set.add(self.path_builder.get_issue(file_path)) + if num_issues and len(issues_set) >= num_issues: + break + except ValueError: # Ignore invalid filepaths... + pass + return issues_set diff --git a/python/idsse_common/idsse/common/rabbitmq_utils.py b/python/idsse_common/idsse/common/rabbitmq_utils.py index e24faec..8887829 100644 --- a/python/idsse_common/idsse/common/rabbitmq_utils.py +++ b/python/idsse_common/idsse/common/rabbitmq_utils.py @@ -11,10 +11,12 @@ # # ---------------------------------------------------------------------------------- +import contextvars import logging import logging.config import uuid -from concurrent.futures import ThreadPoolExecutor + +from concurrent.futures import Future, ThreadPoolExecutor from collections.abc import Callable from functools import partial from threading import Event, Thread @@ -22,6 +24,7 @@ from pika import BasicProperties, ConnectionParameters, PlainCredentials from pika.adapters import BlockingConnection +from pika.adapters.blocking_connection import BlockingChannel from pika.channel import Channel from pika.exceptions import UnroutableError from pika.frame import Method @@ -70,7 +73,7 @@ class Queue(NamedTuple): durable: bool exclusive: bool auto_delete: bool - type: str = 'classic' + arguments: dict = {} class RabbitMqParams(NamedTuple): @@ -88,72 +91,339 @@ callback: Callable -def _initialize_exchange_and_queue( - channel: Channel, - params: RabbitMqParams -) -> str: - """Declare and bind RabbitMQ exchange and queue using the provided channel. +class RabbitMqMessage(NamedTuple): + """ + Data class to hold a RabbitMQ message body, properties, and optional route_key (if outbound) + """ + body: str + properties: BasicProperties + route_key: str | None = None - Returns: - str: the name of the newly-initialized queue. + +class Consumer(Thread): """ - exch, queue = params - logger.info('Subscribing to exchange: %s', exch.name) + RabbitMQ consumer, runs in own thread to not block heartbeat. A thread pool + is used not so much to parallelize the execution as to manage the + execution of the callbacks, including being able to wait for completion on + shutdown. The start() and stop() methods should be called from the same + thread as the one used to create the instance. + """ + def __init__( + self, + conn_params: Conn, + rmq_params_and_callbacks: RabbitMqParamsAndCallback | list[RabbitMqParamsAndCallback], + *args, + num_message_handlers: int = 2, + **kwargs, + ): + """ + Args: + conn_params (Conn): parameters to create a new RabbitMQ connection + rmq_params_and_callbacks (RabbitMqParamsAndCallback | list[RabbitMqParamsAndCallback]): + 1 or more Exch/Queue tuples, and a function to invoke when messages arrive on the + listed queue. + num_message_handlers (optional, int): The max thread pool size for workers to handle + message callbacks concurrently. Default is 2. + """ + super().__init__(*args, **kwargs, name='Consumer') + self.context = contextvars.copy_context() + self.daemon = True + self._tpx = ThreadPoolExecutor(max_workers=num_message_handlers) - # Do not try to declare the default exchange. 
It already exists - if exch.name != '': - channel.exchange_declare(exchange=exch.name, - exchange_type=exch.type, - durable=exch.durable) + if isinstance(rmq_params_and_callbacks, list): + _rmq_params_and_callbacks = rmq_params_and_callbacks + else: + _rmq_params_and_callbacks = [rmq_params_and_callbacks] - # Do not try to declare or bind built-in queues. They are pseudo-queues that already exist - if queue.name.startswith('amq.rabbitmq.'): - return queue.name + self.connection = BlockingConnection(conn_params.connection_parameters) + self.channel = self.connection.channel() + self.channel.basic_qos(prefetch_count=1) - # If we have a 'private' queue, i.e. one used to support message publishing, not consumed - # Set message time-to-live (TTL) to 10 seconds - arguments = {'x-message-ttl': 10 * 1000} if queue.name.startswith('_') else None - frame: Method = channel.queue_declare( - queue=queue.name, - exclusive=queue.exclusive, - durable=queue.durable, - auto_delete=queue.auto_delete, - arguments=arguments - ) + self._consumer_tags = [] + for (exch, queue), func in _rmq_params_and_callbacks: + _setup_exch_and_queue(self.channel, exch, queue) + self._consumer_tags.append( + self.channel.basic_consume(queue.name, + partial(self._on_message, func=func), + # RMQ requires auto_ack=True for Direct Reply-to + auto_ack=queue.name == DIRECT_REPLY_QUEUE) + ) - # Bind queue to exchange with routing_key. May need to support multiple keys in the future - if exch.name != '': - logger.info(' binding key %s to queue: %s', queue.route_key, queue.name) - channel.queue_bind(queue.name, exch.name, queue.route_key) - return frame.method.queue + def run(self): + _set_context(self.context) + # create a local logger since this is run in a separate thread when start() is called + _logger = logging.getLogger(f'{__name__}::{self.__class__.__name__}') + _logger.info('Start Consuming... (to stop press CTRL+C)') + self.channel.start_consuming() + def stop(self): + """Cleanly end the running of a thread, free up resources""" + logger.info('Stopping consumption of messages...') + logger.debug('Waiting for any currently running workers (this could take some time)') + self._tpx.shutdown(wait=True, cancel_futures=True) + # it would be nice to stop consuming before shutting down the thread pool, but when done in + # the other order completed tasks can't be (n)ack-ed; this does mean that messages consumed + # from the queue after the shutdown starts will not be processed, nor (n)ack-ed -def _initialize_connection_and_channel( - connection: Conn, - params: RabbitMqParams, - channel: Channel | Channel | None = None, -) -> tuple[BlockingConnection, Channel, str]: - """Establish RabbitMQ connection, and declare exchange and queue on new Channel""" - if not isinstance(connection, Conn): - # connection of unsupported type passed - raise ValueError( - (f'Cannot use or create new RabbitMQ connection using type {type(connection)}. ' 'Should be type Conn (a dict with connection parameters)') + if self.connection and self.connection.is_open: + # there should be one consumer tag for each channel being consumed from + if self._consumer_tags: + threadsafe_call(self.channel, + *[partial(self.channel.stop_consuming, consumer_tag) + for consumer_tag in self._consumer_tags], + lambda: logger.info('Stopped Consuming')) + + threadsafe_call(self.channel, + self.channel.close, + self.connection.close) + + # pylint: disable=too-many-arguments + def _on_message(self, channel, method, properties, body, func): + """This is the callback wrapper, the core callback is passed as func""" + try: + self._tpx.submit(func, channel, method, properties, body) + except RuntimeError as exe: + logger.error('Unable to submit it to thread pool, Cause: %s', exe) + + +class Publisher(Thread): + """ + RabbitMQ publisher, runs in own thread to not block heartbeat. The start() and stop() + methods should be called from the same thread as the one used to create the instance. + """ + def __init__( + self, + conn_params: Conn, + exch_params: Exch, + *args, + **kwargs, + ): + """ + Args: + conn_params (Conn): RabbitMQ Conn parameters to create a new RabbitMQ connection + exch_params (Exch): params for what RabbitMQ exchange to publish messages to. + """ + super().__init__(*args, **kwargs, name='Publisher') + self.context = contextvars.copy_context() + self.daemon = True + self._is_running = True + self._exch = exch_params + self._queue = None + + # create new RabbitMQ Connection and Channel using the provided params + self.connection = BlockingConnection(conn_params.connection_parameters) + self.channel = self.connection.channel() + + # if delivery is mandatory there must be a queue attached to the exchange + if self._exch.mandatory: + self._queue = Queue(name=f'_{self._exch.name}_{uuid.uuid4()}', + route_key=self._exch.route_key, + durable=False, + exclusive=True, + auto_delete=False, + arguments={'x-queue-type': 'classic', + 'x-message-ttl': 10 * 1000}) + + _setup_exch_and_queue(self.channel, self._exch, self._queue) + elif self._exch.name != '': # if using default exchange, skip declare (not allowed by RMQ) + _setup_exch(self.channel, self._exch) + + if self._exch.delivery_conf: + self.channel.confirm_delivery() + + def run(self): + _set_context(self.context) + # create a local logger since this is run in a separate thread when start() is called + _logger = logging.getLogger(f'{__name__}::{self.__class__.__name__}') + _logger.info('Starting publisher') + while self._is_running: + if self.connection and self.connection.is_open: + self.connection.process_data_events(time_limit=1) + + def publish(self, message: bytes, properties: BasicProperties = None, route_key: str = None): + """ + Publish a message to this preconfigured exchange. The actual publication + is asynchronous and this method only schedules it to be done. + + Args: + message (bytes): The message to be published + properties (BasicProperties): The props to be attached to message when published + route_key (str): Optional route key, overriding key provided during initialization + """ + threadsafe_call( + self.channel, + lambda: _publish(self.channel, + self._exch, + RabbitMqMessage(message, properties, route_key), + self._queue) ) - _connection = BlockingConnection(connection.connection_parameters) - logger.info('Established new RabbitMQ connection to %s on port %i', - connection.host, connection.port) + def blocking_publish(self, + message: bytes, + properties: BasicProperties = None, + route_key: str = None) -> bool: + """ + Blocking publish. Works by waiting for the completion of an asynchronous + publication. - if channel is None: - logger.info('Creating new RabbitMQ channel') - _channel = _connection.channel() - else: - _channel = channel + Args: + message (bytes): The message to be published + properties (BasicProperties): The props to be attached to message when published + route_key (str): Optional route key, overriding key provided during initialization - queue_name = _initialize_exchange_and_queue(_channel, params) + Returns: + bool: Returns True if no errors occurred during publication. If this + publisher is configured to confirm delivery, it will return False if + delivery could not be confirmed. + """ + return _blocking_publish(self.channel, + self._exch, + RabbitMqMessage(message, properties, route_key), + self._queue) - return _connection, _channel, queue_name + def stop(self): + """Cleanly end the running of a thread, free up resources""" + logger.info("Stopping publisher") + self._is_running = False + # Wait until all the data events have been processed + if self.connection and self.connection.is_open: + self.connection.process_data_events(time_limit=1) + threadsafe_call(self.channel, + self.channel.close, + self.connection.close) + + +class Rpc: + """ + RabbitMQ RPC (remote procedure call) client, runs in own thread to not block heartbeat. + The start() and stop() methods should be called from the same thread that created the instance. + + This RPC class can be used to send "requests" (outbound messages) over RabbitMQ and block until + a "response" (inbound message) comes back from the receiving app. All producing to/consuming of + different queues, and associating requests with their responses, is abstracted away. + + Note that RPC by RabbitMQ convention uses the built-in Direct Reply-To queue to field the + response messages, rather than creating its own queue. Directing responses to a custom queue + is not yet supported by Rpc. + + Example usage: + + my_client = Rpc(...insert params here...) + + response = my_client.send_request('{"some": "json"}') # blocks while waiting for response + + logger.info(f'Response from external service: {response}') + """ + def __init__(self, conn_params: Conn, exch: Exch, timeout: float | None = None): + """ + Args: + conn_params (Conn): parameters to connect to RabbitMQ server + exch (Exch): parameters of RMQ Exchange where messages should be sent + timeout (float | None): optional timeout to give up on receiving each response. + Default is None, meaning wait indefinitely for response from external RMQ service. 
+ """ + self._exch = exch + self._timeout = timeout + # only publish to built-in Direct Reply-to queue (recommended for RPC, less setup needed) + self._queue = Queue(DIRECT_REPLY_QUEUE, '', True, False, False) + + # worklist to track corr_ids sent to remote service, and associated response when it arrives + self._pending_requests: dict[str, Future] = {} + + # Start long-running thread to consume any messages from response queue + self.consumer = Consumer( + conn_params, + RabbitMqParamsAndCallback(RabbitMqParams(Exch('', 'direct'), self._queue), + self._response_callback) + ) + + @property + def is_open(self) -> bool: + """Returns True if RabbitMQ connection (Publisher) is open and ready to send messages""" + return self.consumer.is_alive() and self.consumer.channel.is_open + + def send_request(self, request_body: str | bytes) -> RabbitMqMessage | None: + """Send message to remote RabbitMQ service using thread-safe RPC. Will block until response + is received back, or timeout occurs. + + Returns: + RabbitMqMessage | None: The response message (body and properties), or None on request + timeout or error handling response. + """ + if not self.is_open: + logger.debug('RPC thread not yet initialized. Setting up now') + self.start() + + # generate unique ID to associate our request to external service's response + request_id = str(uuid.uuid4()) + + # send request to external RMQ service, providing the queue where it should respond + properties = BasicProperties(content_type='application/json', + correlation_id=request_id, + reply_to=self._queue.name) + + # add future to dict where callback can retrieve it and set result + request_future = Future() + self._pending_requests[request_id] = request_future + + logger.debug('Publishing request message to external service with body: %s', request_body) + _blocking_publish(self.consumer.channel, + self._exch, + RabbitMqMessage(request_body, properties, self._exch.route_key), + self._queue) + + try: + # block until callback runs (we'll know when the future's result has been changed) + return request_future.result(timeout=self._timeout) + except TimeoutError: + logger.warning('Timed out waiting for response. correlation_id: %s', request_id) + self._pending_requests.pop(request_id) # stop tracking request Future + return None + except Exception as exc: # pylint: disable=broad-exception-caught + logger.warning('Unexpected response from external service: %s', str(exc)) + self._pending_requests.pop(request_id) # stop tracking request Future + return None + + def start(self): + """Start dedicated threads to asynchronously send and receive RPC messages using a new + RabbitMQ connection and channel. Note: this method can be called externally, but it is + not required to use the client. 
It will automatically call this internally as needed.""" + if not self.is_open: + logger.debug('Starting RPC thread to send and consume messages') + self.consumer.start() + + def stop(self): + """Unsubscribe from Direct Reply-To queue and cleanup thread""" + logger.debug('Shutting down RPC threads') + if not self.is_open: + logger.debug('RPC threads not running, nothing to clean up') + return + + # tell Consumer to clean up RabbitMQ resources and wait for thread to terminate + self.consumer.stop() + self.consumer.join() + + def _response_callback( + self, + channel: Channel, + method: Basic.Deliver, + properties: BasicProperties, + body: bytes + ): + """Handle RabbitMQ message emitted to response queue.""" + logger.debug('Received response with routing_key: %s, content_type: %s, message: %s', + method.routing_key, properties.content_type, str(body, encoding='utf-8')) + + # remove future from pending list. we will update result shortly + request_future = self._pending_requests.pop(properties.correlation_id) + + # messages sent through RabbitMQ Direct reply-to are auto acked + is_direct_reply = str(method.routing_key).startswith(DIRECT_REPLY_QUEUE) + if not is_direct_reply: + channel.basic_ack(delivery_tag=method.delivery_tag) + + # update future with response body to communicate it back up to main thread + request_future.set_result(RabbitMqMessage(str(body, encoding='utf-8'), properties)) def subscribe_to_queue( @@ -162,7 +432,7 @@ on_message_callback: Callable[ [Channel, Basic.Deliver, BasicProperties, bytes], None], channel: Channel | None = None -) -> tuple[BlockingConnection, Channel]: +) -> tuple[BlockingConnection, BlockingChannel]: """ Function that handles setup of consumer of RabbitMQ queue messages, declaring the exchange and queue if needed, and invoking the provided callback when a message is received. @@ -182,8 +452,8 @@ on_message_callback (Callable[ [BlockingChannel, Basic.Deliver, BasicProperties, bytes], None]): function to handle messages that are received over the subscribed exchange and queue. - channel (BlockingChannel | None): optional existing (open) RabbitMQ channel to reuse. - Default is to create unique channel for this consumer. + channel (Channel | None): optional existing (open) RabbitMQ channel to reuse. Default is + to create unique channel for this consumer. 
 
 
-def _setup_exch_and_queue(channel: Channel, exch: Exch, queue: Queue):
-    """Setup an exchange and queue and bind them with the queue's route key(s)"""
-    if queue.type == 'quorum' and queue.auto_delete:
-        raise ValueError('Quorum queues can not be configured to auto delete')
-
-    _setup_exch(channel, exch)
-
-    result: Method = channel.queue_declare(
-        queue=queue.name,
-        exclusive=queue.exclusive,
-        durable=queue.durable,
-        auto_delete=queue.auto_delete,
-        arguments={'x-queue-type': queue.type}
-    )
-    queue_name = result.method.queue
-    logger.debug('Declared queue: %s', queue_name)
-
-    if isinstance(queue.route_key, list):
-        for route_key in queue.route_key:
-            channel.queue_bind(
-                queue_name,
-                exchange=exch.name,
-                routing_key=route_key
-            )
-            logger.debug('Bound queue(%s) to exchange(%s) with route_key(%s)',
-                         queue_name, exch.name, route_key)
-    else:
-        channel.queue_bind(
-            queue_name,
-            exchange=exch.name,
-            routing_key=queue.route_key
-        )
-        logger.debug('Bound queue(%s) to exchange(%s) with route_key(%s)',
-                     queue_name, exch.name, queue.route_key)
-
-
-def _setup_exch(channel: Channel, exch: Exch):
-    """Setup and exchange"""
-    channel.exchange_declare(
-        exchange=exch.name,
-        exchange_type=exch.type,
-        durable=exch.durable
-    )
-    logger.debug('Declared exchange: %s', exch.name)
-
-
 def threadsafe_call(channel: Channel, *functions: Callable):
     """
     This function provides a thread safe way to call pika functions (or functions that call
@@ -282,7 +506,7 @@ def threadsafe_call(channel: Channel, *functions: Callable):
                             partial(self.pub_conf.publish_message, message=message))
 
     Args:
-        channel (BlockingChannel): RabbitMQ channel.
+        channel (Channel): RabbitMQ channel.
         functions (Callable): One or more callable function, typically created via
            functools.partial or lambda, but can be function without args
     """
@@ -305,7 +529,7 @@ def threadsafe_ack(
     This is just a convenance function that acks a message via threadsafe_call
 
     Args:
-        channel (BlockingChannel): RabbitMQ channel.
+        channel (Channel): RabbitMQ channel.
         delivery_tag (int): Delivery tag to be used when nacking.
         extra_func (Callable): Any extra function that you would like to be called after
             the nack. Typical use case would we to send a log via a lambda
@@ -327,7 +551,7 @@ def threadsafe_nack(
     This is just a convenance function that nacks a message via threadsafe_call
 
    Args:
-        channel (BlockingChannel): RabbitMQ channel.
+        channel (Channel): RabbitMQ channel.
         delivery_tag (int): Delivery tag to be used when nacking.
         extra_func (Callable): Any extra function that you would like to be called after
            the nack. Typical use case would we to send a log via a lambda
@@ -342,211 +566,214 @@ def threadsafe_nack(
     threadsafe_call(channel, lambda: channel.basic_nack(delivery_tag, requeue=requeue))
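As the docstring above notes, `threadsafe_call` is usually fed `functools.partial` objects or lambdas. A small grounded sketch, assuming `channel` and `method` exist in the caller's scope:

    # from a worker thread, schedule pika calls back onto the connection's own thread
    threadsafe_call(channel,
                    lambda: channel.basic_ack(method.delivery_tag),
                    lambda: logger.info('acked delivery %s', method.delivery_tag))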
 
 
-class Consumer(Thread):
-    """
-    RabbitMQ consumer, runs in own thread to not block heartbeat. A thread pool
-    is used to not so much to parallelize the execution but rather to manage the
-    execution of the callbacks, including being able to wait for completion on
-    shutdown. The start() and stop() methods should be called from the same
-    thread as the one used to create the instance.
-    """
-    def __init__(
-        self,
-        conn_params: Conn,
-        rmq_params_and_callbacks: RabbitMqParamsAndCallback | list[RabbitMqParamsAndCallback],
-        num_message_handlers: int,
-        *args,
-        **kwargs,
-    ):
-        super().__init__(*args, **kwargs)
-        self.daemon = True
-        self._tpx = ThreadPoolExecutor(max_workers=num_message_handlers)
-        self._conn_params = conn_params
-        if isinstance(rmq_params_and_callbacks, list):
-            self._rmq_params_and_callbacks = rmq_params_and_callbacks
-        else:
-            self._rmq_params_and_callbacks = [rmq_params_and_callbacks]
-        self.connection = BlockingConnection(self._conn_params.connection_parameters)
-        self.channel = self.connection.channel()
-
-        self._consumer_tags = []
-        for (exch, queue), func in self._rmq_params_and_callbacks:
-            _setup_exch_and_queue(self.channel, exch, queue)
-            self._consumer_tags.append(
-                self.channel.basic_consume(queue.name,
-                                           partial(self._on_message, func=func))
-            )
+def _initialize_exchange_and_queue(channel: Channel, params: RabbitMqParams) -> str:
+    """Declare and bind RabbitMQ exchange and queue using the provided channel.
 
-        self.channel.basic_qos(prefetch_count=1)
+    Returns:
+        str: the name of the newly-initialized queue.
+    """
+    exch, queue = params
+    logger.info('Subscribing to exchange: %s', exch.name)
 
-    def run(self):
-        logger.info('Start Consuming... (to stop press CTRL+C)')
-        self.channel.start_consuming()
+    # Do not try to declare the default exchange. It already exists
+    if exch.name != '':
+        channel.exchange_declare(exchange=exch.name,
+                                 exchange_type=exch.type,
+                                 durable=exch.durable)
 
-    def stop(self):
-        """Cleanly end the running of a thread, free up resources"""
-        logger.info('Stopping consumption of messages...')
-        logger.debug('Waiting for any currently running workers (this could take some time)')
-        self._tpx.shutdown(wait=True, cancel_futures=True)
-        # it would be nice to stop consuming before shutting down the thread pool, but when done in
-        # in the other order completed tasks can't be (n)ack-ed, this does mean that messages can be
-        # consumed from the queue and the shutdown starts that will not be processed, nor (n)ack-ed
+    # Do not try to declare or bind built-in queues. They are pseudo-queues that already exist
+    if queue.name.startswith('amq.rabbitmq.'):
+        return queue.name
 
-        if self.connection and self.connection.is_open:
-            # there should be one consumer tag for each channel being consumed from
-            if self._consumer_tags:
-                threadsafe_call(self.channel,
-                                *[partial(self.channel.stop_consuming, consumer_tag)
-                                  for consumer_tag in self._consumer_tags],
-                                lambda: logger.info('Stopped Consuming'))
+    # If this is a 'private' queue (one used to support message publishing, not consumption),
+    # set its message time-to-live (TTL) to 10 seconds
+    if queue.name.startswith('_'):
+        queue.arguments['x-message-ttl'] = 10 * 1000
+    frame: Method = channel.queue_declare(
+        queue=queue.name,
+        exclusive=queue.exclusive,
+        durable=queue.durable,
+        auto_delete=queue.auto_delete,
+        arguments=queue.arguments
+    )
 
-            threadsafe_call(self.channel,
-                            self.channel.close,
-                            self.connection.close)
+    # Bind queue to exchange with routing_key. May need to support multiple keys in the future
+    if exch.name != '':
+        logger.info('    binding key %s to queue: %s', queue.route_key, queue.name)
+        channel.queue_bind(queue.name, exch.name, queue.route_key)
+    return frame.method.queue
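A sketch of how a caller might exercise the private-queue path above; the `Queue`/`Exch` field names are assumptions inferred from their usage elsewhere in this module, and `channel` is an already-open pika channel. A `'_'`-prefixed queue name triggers the 10-second message TTL on declare.

    # exchange/queue values illustrative
    params = RabbitMqParams(
        Exch('data.events', 'topic'),
        Queue('_data_events_tmp', route_key='data.#', durable=False,
              exclusive=True, auto_delete=False, arguments={})
    )
    queue_name = _initialize_exchange_and_queue(channel, params)  # declares, binds, returns name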
-    # pylint: disable=too-many-arguments
-    def _on_message(self, channel, method, properties, body, func):
-        """This is the callback wrapper, the core callback is passed as func"""
-        try:
-            self._tpx.submit(func, channel, method, properties, body)
-        except RuntimeError as exe:
-            logger.error('Unable to submit it to thread pool, Cause: %s', exe)
 
+def _initialize_connection_and_channel(
+    connection: Conn,
+    params: RabbitMqParams,
+    channel: Channel | None = None,
+) -> tuple[BlockingConnection, BlockingChannel, str]:
+    """Establish RabbitMQ connection, and declare exchange and queue on new Channel"""
+    if not isinstance(connection, Conn):
+        # connection of unsupported type passed
+        raise ValueError(
+            (f'Cannot use or create new RabbitMQ connection using type {type(connection)}. '
+             'Should be type Conn (a dict with connection parameters)')
+        )
 
-class Publisher(Thread):
-    """
-    RabbitMQ publisher, runs in own thread to not block heartbeat. The start() and stop()
-    methods should be called from the same thread as the one used to create the instance.
-    """
-    def __init__(
-        self,
-        conn_params: Conn,
-        exch_params: Exch,
-        *args,
-        **kwargs,
-    ):
-        super().__init__(*args, **kwargs)
-        self.daemon = True
-        self._is_running = True
-        self._exch = exch_params
-        self._queue = None
+    _connection = BlockingConnection(connection.connection_parameters)
+    logger.info('Established new RabbitMQ connection to %s on port %i',
+                connection.host, connection.port)
 
-        self.connection = BlockingConnection(conn_params.connection_parameters)
-        self.channel = self.connection.channel()
+    if channel is None:
+        logger.info('Creating new RabbitMQ channel')
+        _channel = _connection.channel()
+    else:
+        _channel = channel
 
-        # if delivery is mandatory there must be a queue attach to the exchange
-        if self._exch.mandatory:
-            self._queue = Queue(name=f'_{self._exch.name}_{uuid.uuid4()}',
-                                route_key=self._exch.route_key,
-                                durable=False,
-                                exclusive=True,
-                                auto_delete=False)
+    queue_name = _initialize_exchange_and_queue(_channel, params)
 
-            _setup_exch_and_queue(self.channel, self._exch, self._queue)
-        else:
-            _setup_exch(self.channel, self._exch)
+    return _connection, _channel, queue_name
 
-        if self._exch.delivery_conf:
-            self.channel.confirm_delivery()
 
-    def run(self):
-        logger.info('Starting publisher')
-        while self._is_running:
-            if self.connection and self.connection.is_open:
-                self.connection.process_data_events(time_limit=1)
+def _setup_exch_and_queue(channel: Channel, exch: Exch, queue: Queue):
+    """Setup an exchange and queue and bind them with the queue's route key(s)"""
+    if queue.arguments and 'x-queue-type' in queue.arguments and \
+            queue.arguments['x-queue-type'] == 'quorum' and queue.auto_delete:
+        raise ValueError('Quorum queues can not be configured to auto delete')
 
-    def publish(self, message: bytes, properties: BasicProperties = None, route_key: str = None):
-        """
-        Publish a message to this pre configured exchange. The actual publication
-        is asynchronous and this method only schedules it to be done.
+ if exch.name != '': # if using default exchange, skip declaring (not allowed by RMQ) + _setup_exch(channel, exch) - Args: - message (bytes): The message to be published - properties (BasicProperties): The props to be attached to message when published - route_key (str): Optional route key, overriding key provided during initialization - """ - threadsafe_call(self.channel, - lambda: self._publish(message, properties, route_key, [False])) + if queue.name == DIRECT_REPLY_QUEUE: + queue_name = queue.name + logger.debug('Using Direct Reply-to queue, skipping declare') - def blocking_publish(self, - message: bytes, - properties: BasicProperties = None, - route_key: str = None) -> bool: - """ - Blocking publish. Works by waiting for the completion of an asynchronous - publication. + else: + result: Method = channel.queue_declare( + queue=queue.name, + exclusive=queue.exclusive, + durable=queue.durable, + auto_delete=queue.auto_delete, + arguments=queue.arguments + ) + queue_name = result.method.queue + logger.debug('Declared queue: %s', queue_name) + + if exch.name != '': # if using default exchange, skip binding queues (not allowed by RMQ) + if isinstance(queue.route_key, list): + for route_key in queue.route_key: + channel.queue_bind( + queue_name, + exchange=exch.name, + routing_key=route_key + ) + logger.debug('Bound queue(%s) to exchange(%s) with route_key(%s)', + queue_name, exch.name, route_key) + else: + channel.queue_bind( + queue_name, + exchange=exch.name, + routing_key=queue.route_key + ) + logger.debug('Bound queue(%s) to exchange(%s) with route_key(%s)', + queue_name, exch.name, queue.route_key) - Args: - message (bytes): The message to be published - properties (BasicProperties): The props to be attached to message when published - route_key (str): Optional route key, overriding key provided during initialization - Returns: - bool: Returns True if no errors ocurred during publication. If this - publisher is configured to confirm delivery will return False if - failed to confirm. - """ - success_flag = [False] - done_event = Event() - threadsafe_call(self.channel, lambda: self._publish(message, - properties, - route_key, - success_flag, - done_event)) - done_event.wait() - return success_flag[0] +def _setup_exch(channel: Channel, exch: Exch): + """Setup an exchange""" + channel.exchange_declare( + exchange=exch.name, + exchange_type=exch.type, + durable=exch.durable + ) + logger.debug('Declared exchange: %s', exch.name) - def stop(self): - """Cleanly end the running of a thread, free up resources""" - logger.info("Stopping publisher") - self._is_running = False - # Wait until all the data events have been processed - if self.connection and self.connection.is_open: - self.connection.process_data_events(time_limit=1) - threadsafe_call(self.channel, - self.channel.close, - self.connection.close) - # pylint: disable=too-many-arguments,unused-argument - def _publish( - self, - message: bytes, - properties: BasicProperties, - route_key: str = None, - success_flag: list[bool] = None, - done_event: Event = None - ): - """ - Core publish method. Success flag is passed by reference, and done event, if not None - can be used to block until message is actually publish, vs being scheduled to be. - - success_flag (list[bool]): This is effectively passing a boolean by reference. This - will change the value of the first element it this list - to indicate if the core publishing was successful. 
-        done_event (Event): A Thread.Event that can be used to indicate when publishing is
-                            complete in a different thread. This can be used to wait for the
-                            completion via 'done_event.wait()' following calling this function.
-        """
+# pylint: disable=too-many-arguments
+def _publish(channel: BlockingChannel,
+             exch: Exch,
+             message_params: RabbitMqMessage,
+             queue: Queue | None = None,
+             success_flag: list[bool] = None,
+             done_event: Event = None):
+    """
+    Core method to publish to a given RabbitMQ exchange (threadsafe).
+    Success flag is passed by reference, and done event, if not None, can be used to block
+    until the message is actually published, vs. merely scheduled to be.
+
+    channel (BlockingChannel): the pika channel to use to publish.
+    exch (Exch): parameters for the RabbitMQ exchange to publish message to.
+    message_params (RabbitMqMessage): the message body to publish, plus properties and
+        (optional) route_key.
+    queue (optional, Queue | None): parameters for RabbitMQ queue, if message is being
+        published to a "temporary"/"private" message queue. The published message will be
+        purged from this queue after its TTL expires.
+        Default is None (destination queue not private).
+    success_flag (list[bool]): This is effectively passing a boolean by reference. This
+                               will change the value of the first element in this list
+                               to indicate if the core publishing was successful.
+    done_event (Event): A Thread.Event that can be used to indicate when publishing is
+                        complete in a different thread. This can be used to wait for the
+                        completion via 'done_event.wait()' following calling this function.
+    """
+    if success_flag:
+        success_flag[0] = False
+    try:
+        channel.basic_publish(
+            exch.name,
+            message_params.route_key if message_params.route_key else exch.route_key,
+            body=message_params.body,
+            properties=message_params.properties,
+            mandatory=exch.mandatory
+        )
+        if success_flag:
+            success_flag[0] = True
+        if queue and queue.name.startswith('_'):
+            try:
+                channel.queue_purge(queue.name)
+            except ValueError as exc:
+                logger.warning('Exception when removing message from private queue: %s', exc)
+    except UnroutableError:
+        logger.warning('Message was not delivered')
+    except Exception as exc:
+        logger.warning('Message not published, cause: %s', exc)
+        raise exc
+    finally:
+        if done_event:
+            done_event.set()
+
+
+def _blocking_publish(
+    channel: BlockingChannel,
+    exch: Exch,
+    message_params: RabbitMqMessage,
+    queue: Queue | None = None,
+) -> bool:
+    """
+    Threadsafe, blocking publish on the specified RabbitMQ exch via the provided channel.
+
+    Args:
+        channel (BlockingChannel): the pika channel to use to publish.
+        exch (Exch): parameters for the RabbitMQ exchange to publish message to.
+        message_params (RabbitMqMessage): the message body to publish, plus properties and
+            (optional) route_key.
+        queue (optional, Queue | None): parameters for RabbitMQ queue, if message is being
+            published to a "temporary"/"private" message queue. The published message will be
+            purged from this queue after its TTL expires.
+            Default is None (destination queue not private).
+    Returns:
+        (bool) True if message published successfully. If the channel is configured to
+            confirm delivery, will return False if delivery could not be confirmed.
+    """
+    success_flag = [False]
+    done_event = Event()
+    threadsafe_call(channel, lambda: _publish(channel,
+                                              exch,
+                                              message_params,
+                                              queue,
+                                              success_flag,
+                                              done_event))
+    done_event.wait()
+    return success_flag[0]
+
+
+def _set_context(context):
+    """Set each ContextVar from the given context in the current thread"""
+    for var, value in context.items():
+        var.set(value)
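A short usage sketch of the blocking publish above. The positional fields of `RabbitMqMessage` (body, properties, route_key) follow its usage earlier in this module; the exchange name and route key are illustrative, and `channel`/`exch` are assumed to exist in the caller's scope.

    message = RabbitMqMessage('{"status": "ok"}',
                              BasicProperties(content_type='application/json'),
                              'status.route')  # route key illustrative
    if not _blocking_publish(channel, exch, message):
        logger.warning('message could not be published or confirmed')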
diff --git a/python/idsse_common/idsse/common/utils.py b/python/idsse_common/idsse/common/utils.py
index ea12da6..f447a1a 100644
--- a/python/idsse_common/idsse/common/utils.py
+++ b/python/idsse_common/idsse/common/utils.py
@@ -31,26 +31,37 @@ class RoundingMethod(Enum):
 RoundingParam = str | RoundingMethod
 
 
-class TimeDelta:
-    """Wrapper class for datetime.timedelta to add helpful properties"""
-
-    def __init__(self, time_delta: timedelta) -> None:
-        self._td = time_delta
+class TimeDelta(timedelta):
+    """Extension of datetime.timedelta that adds helpful properties."""
+    def __new__(cls, *args, **kwargs):
+        # accept either an existing timedelta or the usual timedelta arguments
+        if args and isinstance(args[0], timedelta):
+            return super().__new__(cls, seconds=args[0].total_seconds())
+        return super().__new__(cls, *args, **kwargs)
 
     @property
     def minute(self):
         """Property to get the number of minutes this instance represents"""
-        return int(self._td / timedelta(minutes=1))
+        return int(self / timedelta(minutes=1))
+
+    @property
+    def minutes(self):
+        """Property to get the number of minutes this instance represents"""
+        return self.minute
 
     @property
     def hour(self):
         """Property to get the number of hours this instance represents"""
-        return int(self._td / timedelta(hours=1))
+        return int(self / timedelta(hours=1))
+
+    @property
+    def hours(self):
+        """Property to get the number of hours this instance represents"""
+        return self.hour
 
     @property
     def day(self):
         """Property to get the number of days this instance represents"""
-        return self._td.days
+        return self.days
 
 
 class Map(dict):
@@ -202,6 +213,7 @@ def _round_toward_zero(number: float) -> int:
     func = math.ceil if number < 0 else math.floor
     return func(number)
 
+
 def round_half_away(number: int | float, precision: int = 0) -> int | float:
     """
     *Deprecated: avoid using this function directly, instead use idsse.commons.round_()*
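Because TimeDelta now subclasses timedelta, an instance is a drop-in timedelta while also wrapping existing ones. A minimal sketch based only on the class above:

    from datetime import timedelta

    lead = TimeDelta(timedelta(hours=2))      # wrap an existing timedelta
    assert lead.hour == 2 and lead.minutes == 120
    assert TimeDelta(hours=2) == lead         # or construct with normal timedelta kwargs
    assert lead + timedelta(hours=1) == timedelta(hours=3)  # still a real timedelta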
diff --git a/python/idsse_common/test/test_aws_utils.py b/python/idsse_common/test/test_aws_utils.py
index 36564a3..ec6466c 100644
--- a/python/idsse_common/test/test_aws_utils.py
+++ b/python/idsse_common/test/test_aws_utils.py
@@ -8,6 +8,7 @@
 # Contributors:
 #     Mackenzie Grimes (1)
 #     Geary Layne (2)
+#     Paul Hamer (1)
 #
 # ----------------------------------------------------------------------------------
 # pylint: disable=missing-function-docstring,redefined-outer-name,pointless-statement
@@ -56,7 +57,10 @@ def aws_utils_with_wild() -> AwsUtils:
 @fixture
 def mock_exec_cmd(monkeypatch: MonkeyPatch) -> Mock:
     def get_files_for_dir(args: Iterable[str]) -> Sequence[str]:
-        hour = args[-1].split('/')[-3]
+        if args[-1].endswith('grib2') or args[-1].endswith('/'):
+            hour = args[-1].split('/')[-3]
+        else:
+            hour = args[-1].split('/')[-2]
         return [f'blend.t{hour}z.core.f002.co.grib2',
                 f'blend.t{hour}z.core.f003.co.grib2',
                 f'blend.t{hour}z.core.f004.co.grib2']
@@ -71,70 +75,70 @@ def test_get_path(aws_utils: AwsUtils):
     assert result_path == f'{EXAMPLE_DIR}blend.t12z.core.f002.co.grib2'
 
 
-def test_aws_ls(aws_utils: AwsUtils, mock_exec_cmd):
-    result = aws_utils.aws_ls(EXAMPLE_DIR)
+def test_ls(aws_utils: AwsUtils, mock_exec_cmd):
+    result = aws_utils.ls(EXAMPLE_DIR)
     assert len(result) == len(EXAMPLE_FILES)
     assert result[0] == f'{EXAMPLE_DIR}{EXAMPLE_FILES[0]}'
     mock_exec_cmd.assert_called_once()
 
 
-def test_aws_ls_without_prepend_path(aws_utils: AwsUtils, mock_exec_cmd):
-    result = aws_utils.aws_ls(EXAMPLE_DIR, prepend_path=False)
+def test_ls_without_prepend_path(aws_utils: AwsUtils, mock_exec_cmd):
+    result = aws_utils.ls(EXAMPLE_DIR, prepend_path=False)
     assert len(result) == len(EXAMPLE_FILES)
     assert result[0] == EXAMPLE_FILES[0]
     mock_exec_cmd.assert_called_once()
 
 
-def test_aws_ls_retries_with_s3(aws_utils: AwsUtils, monkeypatch: MonkeyPatch):
+def test_ls_retries_with_s3(aws_utils: AwsUtils, monkeypatch: MonkeyPatch):
     # fails first call, succeeds second call
     mock_exec_cmd_failure = Mock(
         side_effect=[FileNotFoundError, EXAMPLE_FILES])
     monkeypatch.setattr('idsse.common.aws_utils.exec_cmd', mock_exec_cmd_failure)
 
-    result = aws_utils.aws_ls(EXAMPLE_DIR)
+    result = aws_utils.ls(EXAMPLE_DIR)
 
     assert len(result) == 3  # ls should have eventually returned good data
     assert mock_exec_cmd_failure.call_count == 2
 
 
-def test_aws_ls_on_error(aws_utils: AwsUtils, monkeypatch: MonkeyPatch):
+def test_ls_on_error(aws_utils: AwsUtils, monkeypatch: MonkeyPatch):
     mock_exec_cmd_failure = Mock(side_effect=PermissionError('No permissions'))
     monkeypatch.setattr('idsse.common.aws_utils.exec_cmd', mock_exec_cmd_failure)
 
-    result = aws_utils.aws_ls(EXAMPLE_DIR)
+    result = aws_utils.ls(EXAMPLE_DIR)
     assert result == []
     mock_exec_cmd_failure.assert_called_once()
 
 
-def test_aws_cp_succeeds(aws_utils: AwsUtils, mock_exec_cmd):
+def test_cp_succeeds(aws_utils: AwsUtils, mock_exec_cmd):
     path = f'{EXAMPLE_DIR}file.grib2.idx'
     dest = f'{EXAMPLE_DIR}new_file.grib2.idx'
 
-    copy_success = aws_utils.aws_cp(path, dest)
+    copy_success = aws_utils.cp(path, dest)
     assert copy_success
 
 
-def test_aws_cp_retries_with_s3_command_line(aws_utils: AwsUtils, monkeypatch: MonkeyPatch):
+def test_cp_retries_with_s3_command_line(aws_utils: AwsUtils, monkeypatch: MonkeyPatch):
     mock_exec_cmd_failure = Mock(
         side_effect=[FileNotFoundError, ['cp worked']])
     monkeypatch.setattr('idsse.common.aws_utils.exec_cmd', mock_exec_cmd_failure)
 
-    copy_success = aws_utils.aws_cp('s3:/some/path', 's3:/new/path')
+    copy_success = aws_utils.cp('s3:/some/path', 's3:/new/path')
 
     assert copy_success
     assert mock_exec_cmd_failure.call_count == 2
 
 
-def test_aws_cp_fails(aws_utils: AwsUtils, monkeypatch: MonkeyPatch):
+def test_cp_fails(aws_utils: AwsUtils, monkeypatch: MonkeyPatch):
     mock_exec_cmd_failure = Mock(
         side_effect=[FileNotFoundError, Exception('unexpected bad thing happened')])
     monkeypatch.setattr('idsse.common.aws_utils.exec_cmd', mock_exec_cmd_failure)
 
-    copy_success = aws_utils.aws_cp('s3:/some/path', 's3:/new/path')
+    copy_success = aws_utils.cp('s3:/some/path', 's3:/new/path')
 
     assert not copy_success
     assert mock_exec_cmd_failure.call_count == 2
 
@@ -155,8 +159,8 @@ def test_get_issues(aws_utils: AwsUtils, mock_exec_cmd):
     result = aws_utils.get_issues(
         issue_start=EXAMPLE_ISSUE, issue_end=EXAMPLE_VALID, num_issues=2)
     assert len(result) == 2
-    assert result[0] == EXAMPLE_VALID
-    assert result[1] == EXAMPLE_VALID - timedelta(hours=1)
+    assert result[0] == EXAMPLE_VALID - timedelta(hours=1)
+    assert result[1] == EXAMPLE_VALID
 
 
 def test_get_issues_with_same_start_stop(aws_utils: AwsUtils, mock_exec_cmd):
diff --git a/python/idsse_common/test/test_http_utils.py b/python/idsse_common/test/test_http_utils.py
new file mode 100644
index 0000000..fbffd08
--- /dev/null
+++ b/python/idsse_common/test/test_http_utils.py
@@ -0,0 +1,181 @@
+"""Test suite for http_utils.py"""
+# ----------------------------------------------------------------------------------
+# Created on Tue Dec 3
+#
+# Copyright (c) 2023 Colorado State University. All rights reserved. (1)
+# Copyright (c) 2023 Regents of the University of Colorado. All rights reserved. (2)
+#
+# Contributors:
+#     Paul Hamer (1)
+#
+# ----------------------------------------------------------------------------------
+# pylint: disable=missing-function-docstring,redefined-outer-name,pointless-statement
+# pylint: disable=invalid-name,unused-argument,duplicate-code,line-too-long
+
+from datetime import datetime, timedelta, UTC
+
+from pytest import fixture
+from pytest_httpserver import HTTPServer
+
+from idsse.common.http_utils import HttpUtils
+from idsse.testing.utils.resources import get_resource_from_file
+
+
+EXAMPLE_ISSUE = datetime(2024, 10, 30, 20, 56, 40, tzinfo=UTC)
+EXAMPLE_VALID = datetime(2024, 10, 30, 20, 56, 40, tzinfo=UTC)
+
+EXAMPLE_URL = 'http://127.0.0.1:5000/data/'
+EXAMPLE_ENDPOINT = '3DRefl/MergedReflectivityQC_00.50'
+EXAMPLE_PROD_DIR = '3DRefl/MergedReflectivityQC_00.50/'
+EXAMPLE_FILES = ['MRMS_MergedReflectivityQC_00.50.latest.grib2.gz',
+                 'MRMS_MergedReflectivityQC_00.50_20241030-205438.grib2.gz',
+                 'MRMS_MergedReflectivityQC_00.50_20241030-205640.grib2.gz']
+
+EXAMPLE_VALID_FILES = ['MRMS_MergedReflectivityQC_00.50.latest.grib2.gz',
+                       'MRMS_MergedReflectivityQC_00.50_20241030-205438_20241030-205438.grib2.gz',
+                       'MRMS_MergedReflectivityQC_00.50_20241030-205640_20241030-205640.grib2.gz']
+
+EXAMPLE_RETURN = get_resource_from_file('idsse.testing.idsse_common',
+                                        'mrms_response.html')
+EXAMPLE_VALID_RETURN = get_resource_from_file('idsse.testing.idsse_common',
+                                              'mrms_valid_response.html')
+
+
+# fixtures
+@fixture(scope="session")
+def httpserver_listen_address():
+    return "127.0.0.1", 5000
+
+
+@fixture
+def http_utils() -> HttpUtils:
+    EXAMPLE_BASE_DIR = 'http://127.0.0.1:5000/data/'
+    EXAMPLE_SUB_DIR = '3DRefl/MergedReflectivityQC_00.50/'
+    EXAMPLE_FILE_BASE = ('MRMS_MergedReflectivityQC_00.50_'
+                         '{issue.year:04d}{issue.month:02d}{issue.day:02d}'
+                         '-{issue.hour:02d}{issue.minute:02d}{issue.second:02d}')
+    EXAMPLE_FILE_EXT = '.grib2.gz'
+
+    return HttpUtils(EXAMPLE_BASE_DIR, EXAMPLE_SUB_DIR, EXAMPLE_FILE_BASE, EXAMPLE_FILE_EXT)
+
+
+@fixture
+def http_utils_with_valid() -> HttpUtils:
+    EXAMPLE_BASE_DIR = 'http://127.0.0.1:5000/data/'
+    EXAMPLE_SUB_DIR = '3DRefl/MergedReflectivityQC_00.50/'
+    EXAMPLE_FILE_BASE = ('MRMS_MergedReflectivityQC_00.50_'
+                         '{issue.year:04d}{issue.month:02d}{issue.day:02d}'
+                         '-{issue.hour:02d}{issue.minute:02d}{issue.second:02d}_'
+                         '{valid.year:04d}{valid.month:02d}{valid.day:02d}'
+                         '-{valid.hour:02d}{valid.minute:02d}{valid.second:02d}')
+    EXAMPLE_FILE_EXT = '.grib2.gz'
+
+    return HttpUtils(EXAMPLE_BASE_DIR, EXAMPLE_SUB_DIR, EXAMPLE_FILE_BASE, EXAMPLE_FILE_EXT)
+
+
+@fixture
+def http_utils_with_wild() -> HttpUtils:
+    EXAMPLE_BASE_DIR = 'http://127.0.0.1:5000/data/'
+    EXAMPLE_SUB_DIR = '3DRefl/MergedReflectivityQC_00.50/'
+    EXAMPLE_FILE_BASE = ('MRMS_MergedReflectivityQC_00.50_{issue.year:04d}{issue.month:02d}{issue.day:02d}'
+                         '-{issue.hour:02d}{issue.minute:02d}?{issue.second:02d}')
+    EXAMPLE_FILE_EXT = '.grib2.gz'
+
+    return HttpUtils(EXAMPLE_BASE_DIR, EXAMPLE_SUB_DIR, EXAMPLE_FILE_BASE, EXAMPLE_FILE_EXT)
+
+
+# test class methods
+def test_get_path(http_utils: HttpUtils):
+    result_path = http_utils.get_path(EXAMPLE_ISSUE, EXAMPLE_VALID)
+    assert result_path == f'{EXAMPLE_URL}{EXAMPLE_PROD_DIR}MRMS_MergedReflectivityQC_00.50_20241030-205640.grib2.gz'
+
+
+def test_ls(http_utils: HttpUtils, httpserver: HTTPServer):
+    httpserver.expect_request('/data/'+EXAMPLE_ENDPOINT).respond_with_data(EXAMPLE_RETURN,
+                                                                           content_type="text/plain")
+    result = http_utils.ls(EXAMPLE_URL + EXAMPLE_ENDPOINT)
+    assert len(result) == len(EXAMPLE_FILES)
+    assert result[0] == f'{EXAMPLE_URL}{EXAMPLE_PROD_DIR}{EXAMPLE_FILES[-1]}'
+
+
+def test_ls_without_prepend_path(http_utils: HttpUtils, httpserver: HTTPServer):
+    httpserver.expect_request('/data/'+EXAMPLE_ENDPOINT).respond_with_data(EXAMPLE_RETURN,
+                                                                           content_type="text/plain")
+    result = http_utils.ls(EXAMPLE_URL + EXAMPLE_ENDPOINT, prepend_path=False)
+    assert len(result) == len(EXAMPLE_FILES)
+    assert result[0] == EXAMPLE_FILES[-1]
+
+
+def test_ls_on_error(http_utils: HttpUtils, httpserver: HTTPServer):
+    httpserver.expect_request('/data/'+EXAMPLE_ENDPOINT).respond_with_data('', content_type="text/plain")
+    result = http_utils.ls(EXAMPLE_URL + EXAMPLE_ENDPOINT)
+    assert result == []
+
+
+def test_cp_succeeds(http_utils: HttpUtils, httpserver: HTTPServer):
+    url = '/data/'+EXAMPLE_PROD_DIR+'/temp.grib2.gz'
+    httpserver.expect_request(url).respond_with_data(bytes([0, 1, 2]), status=200,
+                                                     content_type="application/octet-stream")
+    path = f'{EXAMPLE_URL}{EXAMPLE_PROD_DIR}/temp.grib2.gz'
+    dest = '/tmp/temp.grib2.gz'
+
+    copy_success = http_utils.cp(path, dest)
+    assert copy_success
+
+
+def test_cp_fails(http_utils: HttpUtils, httpserver: HTTPServer):
+    url = '/data/'+EXAMPLE_PROD_DIR+'/temp.grib2.gz'
+    httpserver.expect_request(url).respond_with_data(bytes([0, 1, 2]), status=404,
+                                                     content_type="application/octet-stream")
+    path = f'{EXAMPLE_URL}{EXAMPLE_PROD_DIR}/temp.grib2.gz'
+    dest = '/tmp/temp.grib2.gz'
+    copy_success = http_utils.cp(path, dest)
+    assert not copy_success
+
+
+def test_check_for_succeeds(http_utils: HttpUtils, httpserver: HTTPServer):
+    url = '/data/'+EXAMPLE_ENDPOINT
+    httpserver.expect_request(url).respond_with_data(EXAMPLE_RETURN, content_type="text/plain")
+
+    result = http_utils.check_for(EXAMPLE_ISSUE, EXAMPLE_VALID)
+    assert result is not None
+    assert result == (EXAMPLE_VALID, f'{EXAMPLE_URL}{EXAMPLE_PROD_DIR}{EXAMPLE_FILES[-1]}')
+
+
+def test_check_for_does_not_find_valid(http_utils: HttpUtils, httpserver: HTTPServer):
+    url = '/data/'+EXAMPLE_ENDPOINT
+    httpserver.expect_request(url).respond_with_data('', content_type="text/plain")
+    unexpected_valid = datetime(1970, 10, 3, 23, tzinfo=UTC)
+    result = http_utils.check_for(EXAMPLE_ISSUE, unexpected_valid)
+    assert result is None
+
+
+def test_get_issues(http_utils: HttpUtils, httpserver: HTTPServer):
+    url = '/data/' + EXAMPLE_ENDPOINT + '/'
+    httpserver.expect_request(url).respond_with_data(EXAMPLE_RETURN, content_type="text/plain")
+    result = http_utils.get_issues(issue_start=EXAMPLE_ISSUE,
+                                   time_delta=timedelta(minutes=1))
+    assert len(result) == 1
+    assert result[0] == EXAMPLE_ISSUE
+
+
+def test_get_issues_with_same_start_stop(http_utils: HttpUtils, httpserver: HTTPServer):
+    url = '/data/'+EXAMPLE_ENDPOINT+'/'
+    httpserver.expect_request(url).respond_with_data(EXAMPLE_RETURN, content_type="text/plain")
+    result = http_utils.get_issues(issue_start=EXAMPLE_ISSUE, issue_end=EXAMPLE_ISSUE,
+                                   time_delta=timedelta(minutes=1))
+    assert len(result) == 1
+    assert result[0] == EXAMPLE_ISSUE
+
+
+def test_get_valids(http_utils: HttpUtils, httpserver: HTTPServer):
+    url = '/data/'+EXAMPLE_ENDPOINT+'/'
+    httpserver.expect_request(url).respond_with_data(EXAMPLE_RETURN, content_type="text/plain")
+    result = http_utils.get_valids(EXAMPLE_ISSUE)
+    assert len(result) == 0
+
+
+def test_get_valids_all(http_utils_with_valid: HttpUtils, httpserver: HTTPServer):
+    url = '/data/'+EXAMPLE_ENDPOINT+'/'
+    httpserver.expect_request(url).respond_with_data(EXAMPLE_VALID_RETURN, content_type="text/plain")
+    result = http_utils_with_valid.get_valids(EXAMPLE_ISSUE)
+    assert len(result) == 1
+
+
+def test_get_valids_with_wildcards(http_utils_with_wild: HttpUtils, httpserver: HTTPServer):
+    url = '/data/'+EXAMPLE_ENDPOINT+'/'
+    httpserver.expect_request(url).respond_with_data(EXAMPLE_RETURN, content_type="text/plain")
+    result = http_utils_with_wild.get_valids(EXAMPLE_ISSUE)
+    assert len(result) == 0
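For reference, a sketch of the HttpUtils API these tests exercise, reusing the fixtures' own URL and filename pattern; error handling is omitted, and the assumption that cp performs an HTTP download is inferred from the mocked responses above.

    from datetime import datetime, UTC
    from idsse.common.http_utils import HttpUtils

    utils = HttpUtils('http://127.0.0.1:5000/data/',
                      '3DRefl/MergedReflectivityQC_00.50/',
                      'MRMS_MergedReflectivityQC_00.50_'
                      '{issue.year:04d}{issue.month:02d}{issue.day:02d}'
                      '-{issue.hour:02d}{issue.minute:02d}{issue.second:02d}',
                      '.grib2.gz')
    issue = datetime(2024, 10, 30, 20, 56, 40, tzinfo=UTC)
    print(utils.get_path(issue, issue))     # full URL built from the issue datetime
    found = utils.check_for(issue, issue)   # (valid datetime, URL), or None if not found
    if found:
        utils.cp(found[1], '/tmp/latest.grib2.gz')  # fetch the file to a local path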
diff --git a/python/idsse_common/test/test_path_builder.py b/python/idsse_common/test/test_path_builder.py
index 7843199..7b0dfb8 100644
--- a/python/idsse_common/test/test_path_builder.py
+++ b/python/idsse_common/test/test_path_builder.py
@@ -9,34 +9,11 @@
 #
 # --------------------------------------------------------------------------------
 # pylint: disable=missing-function-docstring,invalid-name,redefined-outer-name,protected-access
-# cspell:words pathbuilder
 
-from datetime import datetime, timedelta, UTC
-import pytest
-
-from idsse.common.utils import TimeDelta
-from idsse.common.path_builder import PathBuilder
-
-
-def test_from_dir_filename_creates_valid_pathbuilder():
-    directory = './test_directory'
-    filename = 'some_file.txt'
-    path_builder = PathBuilder.from_dir_filename(directory, filename)
-
-    assert isinstance(path_builder, PathBuilder)
-    assert path_builder._basedir == directory
-    assert path_builder._file_ext == ''
-
-
-def test_from_path_creates_valid_pathbuilder():
-    base_dir = './test_directory'
-    path_builder = PathBuilder.from_path(f'{base_dir}/some_file.txt')
-
-    assert isinstance(path_builder, PathBuilder)
-    assert path_builder._basedir == base_dir
-    assert path_builder._file_base == base_dir
-    assert path_builder._file_ext == ''
+from datetime import datetime, UTC
+from pytest import fixture, raises
 
+from idsse.common.path_builder import TimeDelta, PathBuilder
 
 # properties
 EXAMPLE_BASE_DIR = './some/directory'
@@ -44,13 +21,56 @@ def test_from_path_creates_valid_pathbuilder():
 EXAMPLE_FILE = 'my_file'
 EXAMPLE_FILE_EXT = '.txt'
 
+EXAMPLE_ISSUE = datetime(1970, 10, 3, 12, tzinfo=UTC)  # a.k.a. issued at
+EXAMPLE_VALID = datetime(1970, 10, 3, 14, tzinfo=UTC)  # a.k.a. valid until
+EXAMPLE_LEAD = TimeDelta(EXAMPLE_VALID - EXAMPLE_ISSUE)  # a.k.a. lead: time between issue and valid
+EXAMPLE_FULL_PATH = '~/blend.19701003/12/core/blend.t12z.core.f002.co.grib2.idx'
 
-@pytest.fixture
+
+@fixture
 def local_path_builder() -> PathBuilder:
     # create example PathBuilder instance using test strings
     return PathBuilder(EXAMPLE_BASE_DIR, EXAMPLE_SUB_DIR, EXAMPLE_FILE, EXAMPLE_FILE_EXT)
 
 
+@fixture
+def path_builder() -> PathBuilder:
+    subdirectory_pattern = (
+        'blend.{issue.year:04d}{issue.month:02d}{issue.day:02d}/{issue.hour:02d}/core/'
+    )
+    file_base_pattern = 'blend.t{issue.hour:02d}z.core.f{lead.hour:03d}.co'
+    return PathBuilder('~', subdirectory_pattern, file_base_pattern, 'grib2.idx')
+
+
+@fixture
+def path_builder_with_region() -> PathBuilder:
+    subdirectory_pattern = (
+        'blend.{issue.year:04d}{issue.month:02d}{issue.day:02d}/{issue.hour:02d}/core/'
+    )
+    file_base_pattern = 'blend.t{issue.hour:02d}z.core.f{lead.hour:03d}.{region:2s}'
+    return PathBuilder('~', subdirectory_pattern, file_base_pattern, 'grib2.idx')
+
+
+def test_from_dir_filename_creates_valid_path_builder():
+    directory = './test_directory'
+    filename = 'some_file.txt'
+    path_builder = PathBuilder.from_dir_filename(directory, filename)
+
+    assert isinstance(path_builder, PathBuilder)
+    assert path_builder.base_dir == directory
+    assert path_builder.file_ext == '.txt'
+
+
+def test_from_path_creates_valid_path_builder():
+    base_dir = './test_directory'
+    filename = 'some_file.txt'
+    path_builder = PathBuilder.from_path(f'{base_dir}/{filename}')
+    assert isinstance(path_builder, PathBuilder)
+    assert path_builder.base_dir == base_dir
+    assert path_builder.file_base == filename
+    assert path_builder.file_ext == '.txt'
+
+
 def test_dir_fmt(local_path_builder: PathBuilder):
     assert local_path_builder.dir_fmt == f'{EXAMPLE_BASE_DIR}/{EXAMPLE_SUB_DIR}'
 
@@ -69,31 +89,14 @@ def test_path_fmt(local_path_builder: PathBuilder):
     )
 
 
-# methods
-EXAMPLE_ISSUE = datetime(1970, 10, 3, 12, tzinfo=UTC)  # a.k.a. issued at
-EXAMPLE_VALID = datetime(1970, 10, 3, 14, tzinfo=UTC)  # a.k.a. valid until
-EXAMPLE_LEAD = TimeDelta(EXAMPLE_VALID - EXAMPLE_ISSUE)  # a.k.a. duration of time that issue lasts
-
-EXAMPLE_FULL_PATH = '~/blend.19701003/12/core/blend.t12z.core.f002.co.grib2.idx'
-
-
-@pytest.fixture
-def path_builder() -> PathBuilder:
-    subdirectory_pattern = (
-        'blend.{issue.year:04d}{issue.month:02d}{issue.day:02d}/{issue.hour:02d}/core/'
-    )
-    file_base_pattern = 'blend.t{issue.hour:02d}z.core.f{lead.hour:03d}.co'
-    return PathBuilder('~', subdirectory_pattern, file_base_pattern, 'grib2.idx')
-
-
 def test_build_dir_gets_issue_valid_and_lead(path_builder: PathBuilder):
-    result_dict = path_builder.build_dir(issue=EXAMPLE_ISSUE)
-    assert result_dict == '~/blend.19701003/12/core/'
+    result = path_builder.build_dir(issue=EXAMPLE_ISSUE)
+    assert result == '~/blend.19701003/12/core/'
 
 
 def test_build_dir_fails_without_issue(path_builder: PathBuilder):
-    result_dict = path_builder.build_dir(issue=None)
-    assert result_dict is None
+    result = path_builder.build_dir(issue=None)
+    assert result is None
 
 
 def test_build_filename(path_builder: PathBuilder):
@@ -106,6 +109,38 @@ def test_build_path(path_builder: PathBuilder):
     assert result_filepath == '~/blend.19701003/12/core/blend.t12z.core.f002.co.grib2.idx'
 
 
+def test_build_path_with_invalid_lead(path_builder: PathBuilder):
+    # if lead needs more than 3 chars to be represented, ValueError will be raised
+    with raises(ValueError):
+        path_builder.build_path(issue=EXAMPLE_ISSUE,
+                                lead=EXAMPLE_LEAD*1000)
+
+
+def test_build_path_with_region(path_builder_with_region: PathBuilder):
+    region = 'co'
+    result = path_builder_with_region.build_path(issue=EXAMPLE_ISSUE,
+                                                 lead=EXAMPLE_LEAD,
+                                                 region=region)
+    result_dict = path_builder_with_region.parse_path(result)
+    assert result_dict['issue'] == EXAMPLE_ISSUE
+    assert result_dict['lead'] == EXAMPLE_LEAD
+    assert result_dict['region'] == region
+
+
+def test_build_path_with_invalid_region(path_builder_with_region: PathBuilder):
+    # if region is more than 2 chars, ValueError will be raised
+    with raises(ValueError):
+        path_builder_with_region.build_path(issue=EXAMPLE_ISSUE,
+                                            lead=EXAMPLE_LEAD,
+                                            region='conus')
+
+
+def test_build_path_with_required_but_missing_region(path_builder_with_region: PathBuilder):
+    # if a required variable (region) is not provided, KeyError will be raised
+    with raises(KeyError):
+        path_builder_with_region.build_path(issue=EXAMPLE_ISSUE, lead=EXAMPLE_LEAD)
+
+
 def test_parse_dir(path_builder: PathBuilder):
     result_dict = path_builder.parse_dir(EXAMPLE_FULL_PATH)
 
@@ -137,57 +172,3 @@ def test_get_valid_returns_none_when_issue_or_lead_failed(path_builder: PathBuil
     result_valid = path_builder.get_valid(path_with_invalid_lead)
 
     assert result_valid is None
-
-
-# static methods
-def test_get_issue_from_time_args(path_builder: PathBuilder):
-    parsed_dict = path_builder.parse_path(EXAMPLE_FULL_PATH)
-    issue_result = PathBuilder.get_issue_from_time_args(parsed_args=parsed_dict)
-
-    assert issue_result == EXAMPLE_ISSUE
-
-
-def test_get_issue_returns_none_if_args_empty():
-    issue_result = PathBuilder.get_issue_from_time_args({})
-    assert issue_result is None
-
-
-def test_get_valid_from_time_args():
-    parsed_dict = {}
-    parsed_dict['valid.year'] = 1970
-    parsed_dict['valid.month'] = 10
-    parsed_dict['valid.day'] = 3
-    parsed_dict['valid.hour'] = 14
-
-    valid_result = PathBuilder.get_valid_from_time_args(parsed_dict)
-    assert valid_result == EXAMPLE_VALID
-
-
-def test_get_valid_returns_none_if_args_empty():
-    valid_result = PathBuilder.get_valid_from_time_args({})
-    assert valid_result is None
-
-
-def test_get_valid_from_time_args_calculates_based_on_lead(path_builder: PathBuilder):
-    parsed_dict = path_builder.parse_path(EXAMPLE_FULL_PATH)
-    result_valid: datetime = PathBuilder.get_valid_from_time_args(parsed_args=parsed_dict)
-    assert result_valid == EXAMPLE_VALID
-
-
-def test_get_lead_from_time_args(path_builder: PathBuilder):
-    parsed_dict = path_builder.parse_path(EXAMPLE_FULL_PATH)
-    lead_result: timedelta = PathBuilder.get_lead_from_time_args(parsed_dict)
-    assert lead_result.seconds == EXAMPLE_LEAD.minute * 60
-
-
-def test_calculate_issue_from_valid_and_lead():
-    parsed_dict = {
-        'valid.year': 1970,
-        'valid.month': 10,
-        'valid.day': 3,
-        'valid.hour': 14,
-        'lead.hour': 2
-    }
-
-    result_issue = PathBuilder.get_issue_from_time_args(parsed_args=parsed_dict)
-    assert result_issue == EXAMPLE_ISSUE
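Finally, a minimal sketch of the build/parse round trip the tests above exercise, reusing the same patterns as the path_builder fixture (everything here comes from that fixture and the assertions above):

    from datetime import datetime, UTC
    from idsse.common.path_builder import TimeDelta, PathBuilder

    builder = PathBuilder(
        '~',
        'blend.{issue.year:04d}{issue.month:02d}{issue.day:02d}/{issue.hour:02d}/core/',
        'blend.t{issue.hour:02d}z.core.f{lead.hour:03d}.co',
        'grib2.idx')
    issue = datetime(1970, 10, 3, 12, tzinfo=UTC)
    path = builder.build_path(issue=issue, lead=TimeDelta(hours=2))
    # -> '~/blend.19701003/12/core/blend.t12z.core.f002.co.grib2.idx'
    assert builder.parse_path(path)['issue'] == issue  # parse recovers the issue datetime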