From 48f836c763e18ec4a99a38c2ad0352779bf9eff0 Mon Sep 17 00:00:00 2001 From: Costas Tyfoxylos Date: Fri, 4 Oct 2024 12:11:29 +0200 Subject: [PATCH] Reapply latest changes with fixes on dependencies. --- .github/workflows/main.yml | 4 +- .python-version | 2 +- Pipfile | 6 +- Pipfile.lock | 20 ++- _CI/files/environment_variables.json | 2 +- .../awsfindingsmanagerlib.py | 132 ++++++++++++------ dev-requirements.txt | 6 +- requirements.txt | 2 +- tests/test_suppressions.py | 52 ++++--- tests/utils.py | 29 ++-- tox.ini | 4 +- 11 files changed, 161 insertions(+), 98 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 5f7214e..e427378 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -19,7 +19,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: 3.8.18 + python-version: 3.11.10 - name: Install pipenv run: pip install pipenv @@ -46,7 +46,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: 3.8.18 + python-version: 3.11.10 - name: Install pipenv run: pip install pipenv diff --git a/.python-version b/.python-version index 9ad6380..3e72aa6 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.8.18 +3.11.10 diff --git a/Pipfile b/Pipfile index b97bb98..3d13456 100644 --- a/Pipfile +++ b/Pipfile @@ -8,16 +8,16 @@ sphinx = ">=7.0,<8.0" sphinx-rtd-theme = ">=1.0,<2.0" prospector = ">=1.8,<2.0" coverage = ">=7,<8.0" -nose = ">=1.3,<2.0" +pynose = ">=1.5.3,<2.0" nose-htmloutput = ">=0.1,<1.0" -tox = ">=4.0<5.0" +tox = ">=4.0,<5.0" betamax = ">=0.8,<1.0" betamax-serializers = "~=0.2,<1.0" semver = ">=3.0,<4.0" gitwrapperlib = ">=1.0,<2.0" twine = ">=4.0,<5.0" coloredlogs = ">=15.0,<16.0" -emoji = ">=2.0,<3.0" +emoji = ">=2.13.2,<3.0" toml = ">=0.1,<1.0" typing-extensions = ">=4.0,<5.0" astroid = "==2.15.6" diff --git a/Pipfile.lock b/Pipfile.lock index 8dab670..015de43 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { 
"hash": { - "sha256": "c749e461c12fd0d2dc102cd4b8eebab9051e34c21128210fcdcdef3549a11633" + "sha256": "23ca352a4cc39d46eefa3929d4976fd6a077a798a1e5e3ca0297baa116264127" }, "pipfile-spec": 6, "requires": {}, @@ -166,7 +166,7 @@ "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427" ], "index": "pypi", - "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", + "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2'", "version": "==2.9.0.post0" }, "pyyaml": { @@ -259,7 +259,7 @@ "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926", "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254" ], - "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", + "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2'", "version": "==1.16.0" }, "urllib3": { @@ -871,7 +871,6 @@ "sha256:dadcddc0aefbf99eea214e0f1232b94f2fa9bd98fa8353711dacb112bfcbbb2a", "sha256:f1bffef9cbc82628f6e7d7b40d7e255aefaa1adb6a1b1d26c69a8b79e6208a98" ], - "index": "pypi", "version": "==1.3.7" }, "nose-htmloutput": { @@ -927,7 +926,7 @@ "sha256:49236fe334652d1229a85bc10789ac4390cc5027f6becda989ea473f893dde57" ], "index": "pypi", - "markers": "python_version < '4.0' and python_full_version >= '3.8.1'", + "markers": "python_full_version >= '3.8.1' and python_version < '4.0'", "version": "==1.11.0" }, "pycodestyle": { @@ -997,6 +996,15 @@ "markers": "python_full_version >= '3.6.2'", "version": "==0.7" }, + "pynose": { + "hashes": [ + "sha256:1b00ab94447cd7fcbb0a344fc1435137404c043db4a8e3cda63ca2893f8e5903", + "sha256:f50091f3b11524a0c8a328b3075333af8317cc7a15378494789d88d5cf53fe11" + ], + "index": "pypi", + "markers": "python_version >= '3.7'", + "version": "==1.5.3" + }, "pyproject-api": { "hashes": [ "sha256:3d7d347a047afe796fd5d1885b1e391ba29be7169bd2f102fcd378f04273d228", @@ -1233,7 +1241,7 @@ 
"sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f" ], "index": "pypi", - "markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'", + "markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2'", "version": "==0.10.2" }, "tomlkit": { diff --git a/_CI/files/environment_variables.json b/_CI/files/environment_variables.json index acf36e8..901539f 100644 --- a/_CI/files/environment_variables.json +++ b/_CI/files/environment_variables.json @@ -1,6 +1,6 @@ { "PIPENV_VENV_IN_PROJECT": "true", - "PIPENV_DEFAULT_PYTHON_VERSION": "3.8", + "PIPENV_DEFAULT_PYTHON_VERSION": "3.11", "PYPI_URL": "https://upload.pypi.org/legacy/", "PROJECT_SLUG": "awsfindingsmanagerlib" } diff --git a/awsfindingsmanagerlib/awsfindingsmanagerlib.py b/awsfindingsmanagerlib/awsfindingsmanagerlib.py index a87aff4..5331914 100755 --- a/awsfindingsmanagerlib/awsfindingsmanagerlib.py +++ b/awsfindingsmanagerlib/awsfindingsmanagerlib.py @@ -82,7 +82,8 @@ class Finding: def __init__(self, data: Dict) -> None: self._data = self._validate_data(data) - self._logger = logging.getLogger(f'{LOGGER_BASENAME}.{self.__class__.__name__}') + self._logger = logging.getLogger( + f'{LOGGER_BASENAME}.{self.__class__.__name__}') self._matched_rule = None def __hash__(self) -> int: @@ -104,7 +105,8 @@ def __ne__(self, other: Finding) -> bool: def _validate_data(data: Dict) -> Dict: missing = set(Finding.required_fields) - set(data.keys()) if missing: - raise InvalidFindingData(f'Missing required keys: "{missing}" for data with ID "{data.get("Id")}"') + raise InvalidFindingData( + f'Missing required keys: "{missing}" for data with ID "{data.get("Id")}"') return data @property @@ -116,7 +118,8 @@ def matched_rule(self) -> Rule: def matched_rule(self, rule) -> None: """The matched rule setter that is registered in the finding.""" if not isinstance(rule, Rule): - raise InvalidRuleType(f'The argument provided is not a valid rule object. 
Received: "{rule}"') + raise InvalidRuleType( + f'The argument provided is not a valid rule object. Received: "{rule}"') self._matched_rule = rule @property @@ -277,7 +280,8 @@ def _parse_date_time(self, datetime_string) -> Optional[datetime]: try: return parse(datetime_string) except ValueError: - self._logger.warning(f'Could not automatically parse datetime string: "{datetime_string}"') + self._logger.warning( + f'Could not automatically parse datetime string: "{datetime_string}"') return None @property @@ -350,16 +354,20 @@ def is_matching_rule(self, rule: Rule) -> bool: if not isinstance(rule, Rule): raise InvalidRuleType(rule) if any([ - self.match_if_set(self.security_control_id, rule.security_control_id), + self.match_if_set(self.security_control_id, + rule.security_control_id), self.match_if_set(self.control_id, rule.rule_or_control_id), self.match_if_set(self.rule_id, rule.rule_or_control_id) ]): - self._logger.debug(f'Matched with rule "{rule.note}" on one of "control_id, security_control_id"') + self._logger.debug( + f'Matched with rule "{rule.note}" on one of "control_id, security_control_id"') if not any([rule.tags, rule.resource_id_regexps]): - self._logger.debug(f'Rule "{rule.note}" does not seem to have filters for resources or tags.') + self._logger.debug( + f'Rule "{rule.note}" does not seem to have filters for resources or tags.') return True if any([self.is_matching_tags(rule.tags), self.is_matching_resource_ids(rule.resource_id_regexps)]): - self._logger.debug(f'Matched with rule "{rule.note}" either on resources or tags.') + self._logger.debug( + f'Matched with rule "{rule.note}" either on resources or tags.') return True return False @@ -368,7 +376,8 @@ class Rule: """Models a suppression rule.""" def __init__(self, note: str, action: str, match_on: Dict) -> None: - self._data = validate_rule_data({'note': note, 'action': action, 'match_on': match_on}) + self._data = validate_rule_data( + {'note': note, 'action': action, 'match_on': 
match_on}) def __hash__(self) -> int: return hash(self.note) @@ -508,7 +517,8 @@ def __init__(self, denied_account_ids: Optional[List[str]] = None, strict_mode: bool = True, suppress_label: str = None): - self._logger = logging.getLogger(f'{LOGGER_BASENAME}.{self.__class__.__name__}') + self._logger = logging.getLogger( + f'{LOGGER_BASENAME}.{self.__class__.__name__}') self.allowed_regions, self.denied_regions = validate_allowed_denied_regions(allowed_regions, denied_regions) self.allowed_account_ids, self.denied_account_ids = validate_allowed_denied_account_ids(allowed_account_ids, @@ -516,7 +526,8 @@ def __init__(self, self.sts = self._get_sts_client() self.ec2 = self._get_ec2_client(region) self._aws_regions = None - self.aws_region = self._validate_region(region) or self._sts_client_config_region + self.aws_region = self._validate_region( + region) or self._sts_client_config_region self._rules = set() self._strict_mode = strict_mode self._rules_errors = [] @@ -611,7 +622,8 @@ def _get_security_hub_client(region: str): client = boto3.client('securityhub', **kwargs) except (botocore.exceptions.NoRegionError, botocore.exceptions.InvalidRegionError) as msg: - raise NoRegion(f'Security Hub client requires a valid region set to connect, message was:{msg}') from None + raise NoRegion( + f'Security Hub client requires a valid region set to connect, message was: {msg}') from None return client def _get_security_hub_paginator_iterator(self, region: str, operation_name: str, query_filter: dict): @@ -631,7 +643,8 @@ def _get_ec2_client(region: str): except (botocore.exceptions.NoRegionError, botocore.exceptions.InvalidRegionError, botocore.exceptions.EndpointConnectionError) as msg: - raise NoRegion(f'Ec2 client requires a valid region set to connect, message was:{msg}') from None + raise NoRegion( + f'Ec2 client requires a valid region set to connect, message was: {msg}') from None except (botocore.exceptions.ClientError, botocore.exceptions.NoCredentialsError) as msg: 
raise InvalidOrNoCredentials(msg) from None return client @@ -646,14 +659,20 @@ def regions(self): self._aws_regions = [region.get('RegionName') for region in self._describe_ec2_regions() if region.get('OptInStatus', '') != 'not-opted-in'] - self._logger.debug(f'Regions in EC2 that were opted in are : {self._aws_regions}') + self._logger.debug( + f'Regions in EC2 that were opted in are: {self._aws_regions}') if self.allowed_regions: - self._aws_regions = set(self._aws_regions).intersection(set(self.allowed_regions)) - self._logger.debug(f'Working on allowed regions {self._aws_regions}') + self._aws_regions = set(self._aws_regions).intersection( + set(self.allowed_regions)) + self._logger.debug( + f'Working on allowed regions {self._aws_regions}') elif self.denied_regions: - self._logger.debug(f'Excluding denied regions {self.denied_regions}') - self._aws_regions = set(self._aws_regions) - set(self.denied_regions) - self._logger.debug(f'Working on non-denied regions {self._aws_regions}') + self._logger.debug( + f'Excluding denied regions {self.denied_regions}') + self._aws_regions = set(self._aws_regions) - \ + set(self.denied_regions) + self._logger.debug( + f'Working on non-denied regions {self._aws_regions}') else: self._logger.debug('Working on all regions') return self._aws_regions @@ -663,10 +682,12 @@ def _get_aggregating_region(self): try: client = self._get_security_hub_client(self.aws_region) data = client.list_finding_aggregators() - aggregating_region = data.get('FindingAggregators')[0].get('FindingAggregatorArn').split(':')[3] + aggregating_region = data.get('FindingAggregators')[0].get( + 'FindingAggregatorArn').split(':')[3] self._logger.info(f'Found aggregating region {aggregating_region}') except (IndexError, botocore.exceptions.ClientError): - self._logger.debug('Could not get aggregating region, either not set, or a client error') + self._logger.debug( + 'Could not get aggregating region, either not set, or a client error') return 
aggregating_region @staticmethod @@ -688,7 +709,8 @@ def _calculate_account_id_filter(allowed_account_ids: Optional[List[str]], if any([allowed_account_ids, denied_account_ids]): comparison = 'EQUALS' if allowed_account_ids else 'NOT_EQUALS' iterator = allowed_account_ids if allowed_account_ids else denied_account_ids - aws_account_ids = [{'Comparison': comparison, 'Value': account} for account in iterator] + aws_account_ids = [{'Comparison': comparison, + 'Value': account} for account in iterator] return aws_account_ids # pylint: disable=dangerous-default-value @@ -711,7 +733,8 @@ def update_query_for_account_ids(query_filter: Dict = DEFAULT_SECURITY_HUB_FILTE """ query_filter = deepcopy(query_filter) - aws_account_ids = FindingsManager._calculate_account_id_filter(allowed_account_ids, denied_account_ids) + aws_account_ids = FindingsManager._calculate_account_id_filter( + allowed_account_ids, denied_account_ids) if aws_account_ids: query_filter.update({'AwsAccountId': aws_account_ids}) return query_filter @@ -720,7 +743,8 @@ def update_query_for_account_ids(query_filter: Dict = DEFAULT_SECURITY_HUB_FILTE def _get_findings(self, query_filter: Dict): findings = set() aggregating_region = self._get_aggregating_region() - regions_to_retrieve = [aggregating_region] if aggregating_region else self.regions + regions_to_retrieve = [ + aggregating_region] if aggregating_region else self.regions for region in regions_to_retrieve: self._logger.debug(f'Trying to get findings for region {region}') iterator = self._get_security_hub_paginator_iterator( @@ -732,11 +756,13 @@ def _get_findings(self, query_filter: Dict): for page in iterator: for finding_data in page['Findings']: finding = Finding(finding_data) - self._logger.debug(f'Adding finding with id {finding.id}') + self._logger.debug( + f'Adding finding with id {finding.id}') findings.add(finding) except botocore.exceptions.ClientError as error: if error.response['Error']['Code'] in ['AccessDeniedException', 
'InvalidAccessException']: - self._logger.debug(f'No access for Security Hub for region {region}.') + self._logger.debug( + f'No access for Security Hub for region {region}.') continue raise error return list(findings) @@ -750,7 +776,8 @@ def _get_matching_findings(rule: Rule, findings: List[Finding], logger: logging. logger.debug(f'Following findings matched with rule with note: "{rule.note}", ' f'{[finding.id for finding in matching_findings]}') else: - logger.debug('No resource id patterns or tags are provided in the rule, all findings used.') + logger.debug( + 'No resource id patterns or tags are provided in the rule, all findings used.') matching_findings = findings for finding in matching_findings: finding.matched_rule = rule @@ -771,7 +798,8 @@ def get_findings(self) -> List[Finding]: findings = list(set(all_findings)) diff = initial_size - len(findings) if diff: - self._logger.warning(f'Missmatch of finding numbers, there seems to be an overlap of {diff}') + self._logger.warning( + f'Mismatch of finding numbers, there seems to be an overlap of {diff}') return findings def get_findings_by_matching_rule(self, rule: Rule) -> List[Finding]: @@ -823,7 +851,8 @@ def _validate_rule_in_findings(self, findings: List[Finding]): NoRuleFindings if strict mode is enabled and any findings do not have matching rules. 
""" - no_rule_matches = [finding.id for finding in findings if not finding.matched_rule] + no_rule_matches = [ + finding.id for finding in findings if not finding.matched_rule] if no_rule_matches: message = f'Findings with the following ids "{no_rule_matches}" do not have matching rules' if self._strict_mode: @@ -844,7 +873,8 @@ def _get_suppressing_payload(self, findings: List[Finding]): A generator with suppressing payloads per common note chunked at MAX_SUPPRESSION_PAYLOAD_SIZE """ - findings = findings if isinstance(findings, (list, tuple, set)) else [findings] + findings = findings if isinstance( + findings, (list, tuple, set)) else [findings] findings = self._validate_rule_in_findings(findings) rule_findings_mapping = defaultdict(list) for finding in findings: @@ -871,7 +901,8 @@ def _get_unsuppressing_payload(self, findings: List[Finding]): A generator with unsuppressing payloads chunked at MAX_SUPPRESSION_PAYLOAD_SIZE """ - findings = findings if isinstance(findings, (list, tuple, set)) else [findings] + findings = findings if isinstance( + findings, (list, tuple, set)) else [findings] for chunk in FindingsManager._chunk([{'Id': finding.id, 'ProductArn': finding.product_arn} for finding in findings], MAX_SUPPRESSION_PAYLOAD_SIZE): @@ -892,15 +923,19 @@ def _workflow_state_change_on_findings(self, findings: List[Finding], suppress=T message_state = 'suppression' if suppress else 'unsuppression' method = self._get_suppressing_payload if suppress else self._get_unsuppressing_payload security_hub = self._get_security_hub_client(self.aws_region) - return all((result for result in self._batch_apply_payloads(security_hub, + successes, payloads = zip(*(result for result in self._batch_apply_payloads(security_hub, method(findings), # noqa message_state))) + success = all(successes) + return (success, list(payloads)) def _batch_apply_payloads(self, security_hub, payloads, message_state): for payload in payloads: - self._logger.debug(f'Sending payload {payload} for 
{message_state} to Security Hub.') + self._logger.debug( + f'Sending payload {payload} for {message_state} to Security Hub.') if os.environ.get('FINDINGS_MANAGER_DRY_RUN_MODE'): - self._logger.debug(f'Dry run mode is on, skipping the actual {message_state}.') + self._logger.debug( + f'Dry run mode is on, skipping the actual {message_state}.') continue yield self._batch_update_findings(security_hub, payload) @@ -935,7 +970,9 @@ def _batch_update_findings(self, security_hub, payload): security_hub: Security hub client payload: The payload to send to the service - Returns: True on success False otherwise + Returns: + tuple: A tuple containing a boolean status and the payload. + The status is True on success and False otherwise. Raises: FailedToBatchUpdate: if strict mode is set and there are failures to update. @@ -951,8 +988,9 @@ def _batch_update_findings(self, security_hub, payload): for fail in failed: id_ = fail.get('FindingIdentifier', '').get('Id') error = fail.get('ErrorMessage') - self._logger.error(f'Failed to update finding with ID: "{id_}" with error: "{error}"') - return status + self._logger.error( + f'Failed to update finding with ID: "{id_}" with error: "{error}"') + return (status, payload) def validate_finding_on_matching_rules(self, finding_data: Dict): """Validates that the provided data is correct data for a finding. 
@@ -982,12 +1020,14 @@ def _construct_findings_on_matching_rules(self, finding_data: Union[List[Dict], if isinstance(finding_data, dict): finding_data = [finding_data] if self._strict_mode: - findings = [self.validate_finding_on_matching_rules(payload) for payload in finding_data] + findings = [self.validate_finding_on_matching_rules( + payload) for payload in finding_data] else: findings = [] for payload in finding_data: try: - findings.append(self.validate_finding_on_matching_rules(payload)) + findings.append( + self.validate_finding_on_matching_rules(payload)) except InvalidFindingData: self._logger.error(f'Data {payload} seems to be invalid.') return [finding for finding in findings if finding] @@ -1002,7 +1042,8 @@ def suppress_finding_on_matching_rules(self, finding_data: Dict): finding_data: The data of a finding as provided by Security Hub. Returns: - True on success False otherwise. + tuple: A tuple containing a boolean status and the payload. + The status is True on success and False otherwise. Raises: InvalidFindingData: If the data is not valid finding data. @@ -1020,13 +1061,15 @@ def suppress_findings_on_matching_rules(self, finding_data: Union[List[Dict], Di finding_data: The data of a finding as provided by Security Hub. Returns: - True on success False otherwise. + tuple: A tuple containing a boolean status and the payload. + The status is True on success and False otherwise. Raises: InvalidFindingData: If any data is not valid finding data. 
""" - matching_findings = self._construct_findings_on_matching_rules(finding_data) + matching_findings = self._construct_findings_on_matching_rules( + finding_data) return self._workflow_state_change_on_findings(matching_findings) def get_unmanaged_suppressed_findings(self) -> List[Finding]: @@ -1042,11 +1085,12 @@ def get_unmanaged_suppressed_findings(self) -> List[Finding]: 'Comparison': 'EQUALS'}]} return self._get_findings(query) - def unsuppress_unmanaged_findings(self) -> bool: + def unsuppress_unmanaged_findings(self) -> tuple[bool, list]: """Unsuppresses findings that have not been suppressed by this library. Returns: - True on full success, False otherwise. + tuple: A tuple containing a boolean status and the payload. + The status is True on success and False otherwise. """ return self._workflow_state_change_on_findings(self.get_unmanaged_suppressed_findings(), suppress=False) diff --git a/dev-requirements.txt b/dev-requirements.txt index 29a55fd..eaca583 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -9,9 +9,9 @@ # sphinx>=7.3.7 ; python_version >= '3.9' sphinx-rtd-theme>=1.3.0 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5' -prospector>=1.11.0 ; python_version < '4.0' and python_full_version >= '3.8.1' +prospector>=1.11.0 ; python_full_version >= '3.8.1' and python_version < '4.0' coverage>=7.6.1 ; python_version >= '3.8' -nose>=1.3.7 +pynose>=1.5.3 ; python_version >= '3.7' nose-htmloutput>=0.6.0 tox>=4.21.2 ; python_version >= '3.8' betamax>=0.9.0 ; python_full_version >= '3.8.1' @@ -21,6 +21,6 @@ gitwrapperlib>=1.0.4 twine>=4.0.2 ; python_version >= '3.7' coloredlogs>=15.0.1 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4' emoji>=2.14.0 ; python_version >= '3.7' -toml>=0.10.2 ; python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3' +toml>=0.10.2 ; python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2' typing-extensions>=4.12.2 ; python_version 
>= '3.8' astroid==2.15.6 ; python_full_version >= '3.7.2' \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 07af9d5..2777493 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,7 @@ # boto3>=1.35.33 ; python_version >= '3.8' opnieuw>=1.2.1 ; python_version >= '3.7' -python-dateutil>=2.9.0.post0 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' +python-dateutil>=2.9.0.post0 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2' schema~=0.7.7 requests~=2.32.3 ; python_version >= '3.8' pyyaml~=6.0.2 ; python_version >= '3.8' \ No newline at end of file diff --git a/tests/test_suppressions.py b/tests/test_suppressions.py index 9bf1199..ae29be9 100644 --- a/tests/test_suppressions.py +++ b/tests/test_suppressions.py @@ -69,54 +69,68 @@ # hence no dev for S3.14 for env in ['dev', 'acc', 'prd'] if security_control_id != 'S3.14' else ['acc', 'prd']: with open(f'tests/fixtures/findings/full/{security_control_id}/{env}.json', encoding='utf-8') as findings_file: - findings_by_security_control_id_fixture[security_control_id].append(json.load(findings_file)) + findings_by_security_control_id_fixture[security_control_id].append( + json.load(findings_file)) with open('tests/fixtures/matches.json', encoding='utf-8') as matches_file: full_matches_fixture = json.load(matches_file) + +def batch_update_findings_mock(_, payload): + return (True, payload) + + class TestValidation(FindingsManagerTestCase): backend_file = './tests/fixtures/suppressions/single.yaml' def test_basic_run(self): self.assertEqual( [], - self.findings_manager._construct_findings_on_matching_rules(api_consolidated_findings_fixture['Findings']) + self.findings_manager._construct_findings_on_matching_rules( + api_consolidated_findings_fixture['Findings']) ) + class TestLegacyValidation(FindingsManagerTestCase): backend_file = './tests/fixtures/suppressions/legacy.yaml' def test_basic_run(self): self.assertEqual( [], - 
self.findings_manager._construct_findings_on_matching_rules(gui_legacy_findings_fixture) + self.findings_manager._construct_findings_on_matching_rules( + gui_legacy_findings_fixture) ) + class TestBasicRun(FindingsManagerTestCase): @patch( 'awsfindingsmanagerlib.FindingsManager._get_security_hub_paginator_iterator', lambda *_, **__: [api_consolidated_findings_fixture], ) - @patch('awsfindingsmanagerlib.FindingsManager._batch_update_findings') + @patch('awsfindingsmanagerlib.FindingsManager._batch_update_findings', side_effect=batch_update_findings_mock) def test_basic_run(self, _batch_update_findings_mocked: MagicMock): - self.assertTrue(self.findings_manager.suppress_matching_findings()) - self.assert_batch_update_findings_called_with( - [batch_update_findings_fixture], _batch_update_findings_mocked - ) + success, payloads = self.findings_manager.suppress_matching_findings() + self.assertTrue(success) + self.assert_batch_update_findings( + [batch_update_findings_fixture], payloads) + class TestFullSuppressions(FindingsManagerTestCase): backend_file = './tests/fixtures/suppressions/full.yaml' def test_validation(self): self.assertEqual(full_matches_fixture, - [dict(finding._data, matched_rule=finding._matched_rule._data) - for finding in self.findings_manager._construct_findings_on_matching_rules(full_findings_fixture)] - ) + [dict(finding._data, matched_rule=finding._matched_rule._data) + for finding in self.findings_manager._construct_findings_on_matching_rules(full_findings_fixture)] + ) - @patch('awsfindingsmanagerlib.FindingsManager._batch_update_findings') + @patch('awsfindingsmanagerlib.FindingsManager._batch_update_findings', side_effect=batch_update_findings_mock) def test_payload_construction(self, _batch_update_findings_mocked: MagicMock): - self.assertTrue(self.findings_manager.suppress_findings_on_matching_rules(full_findings_fixture)) - self.assert_batch_update_findings_called_with(batch_update_findings_full_fixture, _batch_update_findings_mocked) + 
success, payloads = self.findings_manager.suppress_findings_on_matching_rules( + full_findings_fixture) + self.assertTrue(success) + self.assert_batch_update_findings( + batch_update_findings_full_fixture, payloads) @patch( 'awsfindingsmanagerlib.FindingsManager._get_security_hub_paginator_iterator', @@ -124,9 +138,9 @@ def test_payload_construction(self, _batch_update_findings_mocked: MagicMock): 'Findings': findings_by_security_control_id_fixture[kwargs['query_filter']['ComplianceSecurityControlId'][0]['Value']] }], ) - @patch('awsfindingsmanagerlib.FindingsManager._batch_update_findings') + @patch('awsfindingsmanagerlib.FindingsManager._batch_update_findings', side_effect=batch_update_findings_mock) def test_from_query(self, _batch_update_findings_mocked: MagicMock): - self.assertTrue(self.findings_manager.suppress_matching_findings()) - self.assert_batch_update_findings_called_with( - batch_update_findings_full_fixture, _batch_update_findings_mocked - ) + success, payloads = self.findings_manager.suppress_matching_findings() + self.assertTrue(success) + self.assert_batch_update_findings( + batch_update_findings_full_fixture, payloads) diff --git a/tests/utils.py b/tests/utils.py index 7bb70d4..6112457 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -69,28 +69,25 @@ def setUp(self) -> None: self.findings_manager = FindingsManager() self.findings_manager.register_rules(rules) - def assert_batch_update_findings_called_with(self, batch_update_findings_expected: List[dict], _batch_update_findings_mocked: MagicMock): + def assert_batch_update_findings(self, batch_update_findings_expected: List[dict], batch_update_findings: List[dict]): """ - Compare expected to actual (=mocked) api call payload. - - Sadly, something like this does not work: _batch_update_findings_mocked.assert_called_once_with(ANY, batch_update_findings), - because FindingIdentifiers is a randomly ordered collection. + Compare expected to actual api call payload. 
""" - self.assertEqual( - len(batch_update_findings_expected), - _batch_update_findings_mocked.call_count - ) + self.assertEqual(len(batch_update_findings_expected), + len(batch_update_findings)) for expected in batch_update_findings_expected: - for call in _batch_update_findings_mocked.call_args_list: + for finding in batch_update_findings: try: - received_args = call.args[1] - self.assertEqual(expected.keys(), received_args.keys()) - self.assertEqual(expected['Note'], received_args['Note']) - self.assertEqual(expected['Workflow'], received_args['Workflow']) - self.assertEqual(len(expected['FindingIdentifiers']), len(received_args['FindingIdentifiers'])) + self.assertTrue( + set(expected.keys()).issubset(set(finding.keys()))) + self.assertEqual(expected['Note'], finding['Note']) + self.assertEqual( + expected['Workflow'], finding['Workflow']) + self.assertEqual(len(expected['FindingIdentifiers']), len( + finding['FindingIdentifiers'])) for item in expected['FindingIdentifiers']: - self.assertIn(item, received_args['FindingIdentifiers']) + self.assertIn(item, finding['FindingIdentifiers']) break except: continue diff --git a/tox.ini b/tox.ini index 216e368..414a0c6 100755 --- a/tox.ini +++ b/tox.ini @@ -5,11 +5,11 @@ # and then run "tox" from this directory. [tox] -envlist = py38, +envlist = py311, [testenv] allowlist_externals = * -commands = ./setup.py nosetests --with-coverage --cover-tests --cover-html --cover-html-dir=test-output/coverage --with-html --html-file test-output/nosetests.html +commands = nosetests --with-coverage --cover-tests --cover-html --cover-html-dir=test-output/coverage --with-html --html-file test-output/nosetests.html deps = -rrequirements.txt -rdev-requirements.txt