diff --git a/plugins/aws/fix_plugin_aws/resource/guardduty.py b/plugins/aws/fix_plugin_aws/resource/guardduty.py index d7563265c3..11ca119654 100644 --- a/plugins/aws/fix_plugin_aws/resource/guardduty.py +++ b/plugins/aws/fix_plugin_aws/resource/guardduty.py @@ -1,3 +1,4 @@ +from concurrent.futures import as_completed from datetime import datetime, timezone from typing import ClassVar, Dict, List, Optional, Tuple, Type, Any import logging @@ -1284,32 +1285,25 @@ def set_findings(builder: GraphBuilder, resource_to_set: AwsResource, to_check: resource_to_set._assessments.append(Assessment("guard_duty", provider_findings)) def parse_finding(self, source: Json) -> Finding: - def get_severity() -> str: + def get_severity() -> Severity: if not self.finding_severity: - return "MEDIUM" + return Severity.medium if self.finding_severity <= 2: - return "INFORMATIONAL" + return Severity.info elif self.finding_severity <= 4: - return "LOW" + return Severity.low elif self.finding_severity <= 6: - return "MEDIUM" + return Severity.medium elif self.finding_severity <= 8: - return "HIGH" + return Severity.high else: - return "CRITICAL" - - severity_map = { - "INFORMATIONAL": Severity.info, - "LOW": Severity.low, - "MEDIUM": Severity.medium, - "HIGH": Severity.high, - "CRITICAL": Severity.critical, - } + return Severity.critical + finding_title = self.safe_name if not self.finding_severity: finding_severity = Severity.medium else: - finding_severity = severity_map.get(get_severity(), Severity.medium) + finding_severity = get_severity() description = self.description updated_at = self.mtime details = source.get("Service", {}) @@ -1369,28 +1363,40 @@ def check_type_and_adjust_id( return finding_resources try: - detector_ids = builder.client.list( - service_name, - "list-detectors", - "DetectorIds", - ) - for detector_id in detector_ids: - finding_ids = builder.client.list( + detector_ids = builder.client.list(service_name, "list-detectors", "DetectorIds") + finding_id_futures = { + builder.submit_work( + service_name, + builder.client.list, service_name, "list-findings", "FindingIds", expected_errors=["BadRequestException"], DetectorId=detector_id, - ) + ): detector_id + for detector_id in detector_ids + } + + for future in as_completed(finding_id_futures): + detector_id = finding_id_futures[future] + finding_ids = future.result() + chunk_futures = [] for chunk_ids in chunks(finding_ids, 49): - for finding in builder.client.list( + future = builder.submit_work( + service_name, + builder.client.list, service_name, "get-findings", "Findings", expected_errors=["BadRequestException"], DetectorId=detector_id, FindingIds=chunk_ids, - ): + ) + chunk_futures.append(future) + + for chunk_future in as_completed(chunk_futures): + findings = chunk_future.result() + for finding in findings: if finding.get("AccountId", None) == builder.account.id: if instance := AwsGuardDutyFinding.from_api(finding, builder): if fr := instance.finding_resource: diff --git a/plugins/aws/test/collector_test.py b/plugins/aws/test/collector_test.py index 2c4f545cb9..4dcb0464a5 100644 --- a/plugins/aws/test/collector_test.py +++ b/plugins/aws/test/collector_test.py @@ -38,8 +38,8 @@ def count_kind(clazz: Type[AwsResource]) -> int: # make sure all threads have been joined assert len(threading.enumerate()) == 1 # ensure the correct number of nodes and edges - assert count_kind(AwsResource) == 262 - assert len(account_collector.graph.edges) == 577 + assert count_kind(AwsResource) == 261 + assert len(account_collector.graph.edges) == 575 assert len(account_collector.graph.deferred_edges) == 2 for node in account_collector.graph.nodes: if isinstance(node, AwsRegion):