From ca9fffc4f39bb04a8c3c52bf2b45c563b93b7bd7 Mon Sep 17 00:00:00 2001 From: Nikita Melkozerov Date: Mon, 9 Dec 2024 19:35:13 +0000 Subject: [PATCH] refactor the arn tree --- .../fix_plugin_aws/access_edges/__init__.py | 0 .../fix_plugin_aws/access_edges/arn_tree.py | 231 +++++++++++ .../edge_builder.py} | 373 ++---------------- .../aws/fix_plugin_aws/access_edges/types.py | 117 ++++++ plugins/aws/fix_plugin_aws/collector.py | 2 +- plugins/aws/test/acccess_edges_test.py | 83 ++-- 6 files changed, 419 insertions(+), 387 deletions(-) create mode 100644 plugins/aws/fix_plugin_aws/access_edges/__init__.py create mode 100644 plugins/aws/fix_plugin_aws/access_edges/arn_tree.py rename plugins/aws/fix_plugin_aws/{access_edges.py => access_edges/edge_builder.py} (79%) create mode 100644 plugins/aws/fix_plugin_aws/access_edges/types.py diff --git a/plugins/aws/fix_plugin_aws/access_edges/__init__.py b/plugins/aws/fix_plugin_aws/access_edges/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/aws/fix_plugin_aws/access_edges/arn_tree.py b/plugins/aws/fix_plugin_aws/access_edges/arn_tree.py new file mode 100644 index 0000000000..581fea4f94 --- /dev/null +++ b/plugins/aws/fix_plugin_aws/access_edges/arn_tree.py @@ -0,0 +1,231 @@ +from typing import List, Set +from attrs import frozen +from fix_plugin_aws.access_edges.types import ArnResourceValueKind, FixPolicyDocument, WildcardKind +from policy_sentry.util.arns import ARN +import fnmatch +import logging + + +log = logging.getLogger("fix.plugins.aws") + + +@frozen(slots=True) +class ArnResource[T]: + key: str + values: Set[T] + kind: ArnResourceValueKind + not_resource: bool + + def matches(self, segment: str) -> bool: + _match = False + match self.kind: + case ArnResourceValueKind.Any: + _match = True + case ArnResourceValueKind.Pattern: + _match = fnmatch.fnmatch(segment, self.key) + case ArnResourceValueKind.Static: + _match = segment == self.key + + if self.not_resource: + _match = not _match + + return _match + + +@frozen(slots=True) +class ArnAccountId[T]: + key: str + wildcard: bool # if the account is a wildcard, e.g. "*" or "::" + values: Set[T] + children: List[ArnResource[T]] + + def matches(self, segment: str) -> bool: + return self.wildcard or self.key == segment + + +@frozen(slots=True) +class ArnRegion[T]: + key: str + wildcard: bool # if the region is a wildcard, e.g. "*" or "::" + values: Set[T] + children: List[ArnAccountId[T]] + + def matches(self, segment: str) -> bool: + return self.wildcard or self.key == segment + + +@frozen(slots=True) +class ArnService[T]: + key: str + values: Set[T] + children: List[ArnRegion[T]] + + def matches(self, segment: str) -> bool: + return self.key == segment + + +@frozen(slots=True) +class ArnPartition[T]: + key: str + wildcard: bool # for the cases like "Allow": "*" on all resources + values: Set[T] + children: List[ArnService[T]] + + def matches(self, segment: str) -> bool: + return self.wildcard or segment == self.key + + +class ArnTree[T]: + def __init__(self) -> None: + self.partitions: List[ArnPartition[T]] = [] + + def add_element(self, elem: T, policy_documents: List[FixPolicyDocument]) -> None: + """ + This method iterates over every policy statement and adds corresponding arns to principal tree. + """ + + for policy_doc in policy_documents: + for statement in policy_doc.fix_statements: + if statement.effect_allow: + has_wildcard_resource = False + for resource in statement.resources: + if resource == "*": + has_wildcard_resource = True + continue + self._add_resource(resource, elem) + for not_resource in statement.not_resource: + self._add_resource(not_resource, elem, nr=True) + + if has_wildcard_resource or (not statement.resources and not statement.not_resource): + for ap in statement.actions_patterns: + if ap.kind == WildcardKind.any: + self._add_allow_all_wildcard(elem) + self._add_service(ap.service, elem) + + def _add_allow_all_wildcard(self, elem: T) -> None: + partition = next((p for p in self.partitions if p.key == "*"), None) + if not partition: + partition = ArnPartition(key="*", wildcard=True, values=set(), children=[]) + self.partitions.append(partition) + + partition.values.add(elem) + + def _add_resource(self, resource_constraint: str, elem: T, nr: bool = False) -> None: + """ + _add resource will add the principal arn at the resource level + """ + + try: + arn = ARN(resource_constraint) + # Find existing or create partition + partition = next((p for p in self.partitions if p.key == arn.partition), None) + if not partition: + partition = ArnPartition[T](key=arn.partition, wildcard=False, values=set(), children=[]) + self.partitions.append(partition) + + # Find or create service + service = next((s for s in partition.children if s.key == arn.service_prefix), None) + if not service: + service = ArnService[T](key=arn.service_prefix, values=set(), children=[]) + partition.children.append(service) + + # Find or create region + region_wildcard = arn.region == "*" or not arn.region + region = next((r for r in service.children if r.key == (arn.region or "*")), None) + if not region: + region = ArnRegion[T](key=arn.region or "*", wildcard=region_wildcard, values=set(), children=[]) + service.children.append(region) + + # Find or create account + account_wildcard = arn.account == "*" or not arn.account + account = next((a for a in region.children if a.key == (arn.account or "*")), None) + if not account: + account = ArnAccountId[T](key=arn.account or "*", wildcard=account_wildcard, values=set(), children=[]) + region.children.append(account) + + # Add resource + resource = next( + (r for r in account.children if r.key == arn.resource_string and r.not_resource == nr), None + ) + if not resource: + if arn.resource_string == "*": + resource_kind = ArnResourceValueKind.Any + elif "*" in arn.resource_string: + resource_kind = ArnResourceValueKind.Pattern + else: + resource_kind = ArnResourceValueKind.Static + resource = ArnResource(key=arn.resource_string, values=set(), kind=resource_kind, not_resource=nr) + account.children.append(resource) + + resource.values.add(elem) + + except Exception as e: + log.error(f"Error parsing ARN {resource_constraint}: {e}") + pass + + def _add_service(self, service_prefix: str, elem: T) -> None: + # Find existing or create partition + partition = next((p for p in self.partitions if p.key == "*"), None) + if not partition: + partition = ArnPartition(key="*", wildcard=True, values=set(), children=[]) + self.partitions.append(partition) + + # Find or create service + service = next((s for s in partition.children if s.key == service_prefix), None) + if not service: + service = ArnService(key=service_prefix, values=set(), children=[]) + partition.children.append(service) + + service.values.add(elem) + + def find_matching_values(self, resource_arn: ARN) -> Set[T]: + """ + this will be called for every resource and it must be fast + """ + result: Set[T] = set() + + matching_partitions = [p for p in self.partitions if p.key if p.matches(resource_arn.partition)] + if not matching_partitions: + return result + + matching_services = [ + s for p in matching_partitions for s in p.children if s.matches(resource_arn.service_prefix) + ] + if not matching_services: + return result + result.update([arn for s in matching_services for arn in s.values]) + + matching_regions = [r for s in matching_services for r in s.children if r.matches(resource_arn.region)] + if not matching_regions: + return result + result.update([arn for r in matching_regions for arn in r.values]) + + matching_account_ids = [a for r in matching_regions for a in r.children if r.matches(resource_arn.account)] + if not matching_account_ids: + return result + result.update([arn for a in matching_account_ids for arn in a.values]) + + matching_resources = [ + r for a in matching_account_ids for r in a.children if r.matches(resource_arn.resource_string) + ] + if not matching_resources: + return result + + result.update([arn for r in matching_resources for arn in r.values]) + + return result + + +PrincipalArn = str + + +class PrincipalTree: + + def __init__(self) -> None: + self.arn_tree = ArnTree[PrincipalArn]() + + def add_principal(self, principal_arn: PrincipalArn, policy_documents: List[FixPolicyDocument]) -> None: + self.arn_tree.add_element(principal_arn, policy_documents) + + def list_principals(self, resource_arn: ARN) -> Set[str]: + return self.arn_tree.find_matching_values(resource_arn) diff --git a/plugins/aws/fix_plugin_aws/access_edges.py b/plugins/aws/fix_plugin_aws/access_edges/edge_builder.py similarity index 79% rename from plugins/aws/fix_plugin_aws/access_edges.py rename to plugins/aws/fix_plugin_aws/access_edges/edge_builder.py index 32f6482b1e..a8e0a6cd66 100644 --- a/plugins/aws/fix_plugin_aws/access_edges.py +++ b/plugins/aws/fix_plugin_aws/access_edges/edge_builder.py @@ -1,88 +1,51 @@ -from enum import Enum -import enum +import fnmatch +import logging +import re from functools import lru_cache -from attr import frozen +from typing import Callable, Dict, List, Literal, Optional, Pattern, Set, Tuple, Union + import networkx +from attr import frozen +from cloudsplaining.scan.statement_detail import StatementDetail +from fix_plugin_aws.access_edges.arn_tree import PrincipalTree +from fix_plugin_aws.access_edges.types import ( + ActionWildcardPattern, + ArnResourceValueKind, + FixPolicyDocument, + FixStatementDetail, + ResourceWildcardPattern, + WildcardKind, +) from fix_plugin_aws.resource.base import AwsAccount, AwsResource, GraphBuilder -from policy_sentry.querying.actions import get_actions_for_service -from typing import Callable, Dict, List, Literal, Set, Optional, Tuple, Union, Pattern -import fnmatch +from fix_plugin_aws.resource.iam import AwsIamGroup, AwsIamPolicy, AwsIamRole, AwsIamUser from networkx.algorithms.dag import is_directed_acyclic_graph +from policy_sentry.querying.actions import get_action_data, get_actions_for_service +from policy_sentry.querying.all import get_all_actions +from policy_sentry.querying.arns import get_matching_raw_arns, get_resource_type_name_with_raw_arn +from policy_sentry.shared.iam_data import get_service_prefix_data +from policy_sentry.util.arns import ARN, get_service_from_arn from fixlib.baseresources import ( + AccessPermission, + EdgeType, + HasResourcePolicy, PermissionCondition, - PolicySource, + PermissionLevel, PermissionScope, - AccessPermission, + PolicySource, + PolicySourceKind, ResourceConstraint, ) -from fix_plugin_aws.resource.iam import AwsIamGroup, AwsIamPolicy, AwsIamUser, AwsIamRole -from fixlib.baseresources import EdgeType, PolicySourceKind, HasResourcePolicy, PermissionLevel +from fixlib.graph import EdgeKey from fixlib.json import to_json, to_json_str from fixlib.types import Json -from cloudsplaining.scan.policy_document import PolicyDocument -from cloudsplaining.scan.statement_detail import StatementDetail -from policy_sentry.querying.actions import get_action_data -from policy_sentry.querying.all import get_all_actions -from policy_sentry.querying.arns import get_matching_raw_arns, get_resource_type_name_with_raw_arn -from policy_sentry.shared.iam_data import get_service_prefix_data -from policy_sentry.util.arns import ARN, get_service_from_arn -from fixlib.graph import EdgeKey -import re -import logging - log = logging.getLogger("fix.plugins.aws") ALL_ACTIONS = get_all_actions() -class WildcardKind(Enum): - fixed = 1 - pattern = 2 - any = 3 - - -@frozen(slots=True) -class ActionWildcardPattern: - pattern: str - service: str - kind: WildcardKind - - -class FixStatementDetail(StatementDetail): - def __init__(self, statement: Json): - super().__init__(statement) - - def pattern_from_action(action: str) -> ActionWildcardPattern: - if action == "*": - return ActionWildcardPattern(pattern=action, service="*", kind=WildcardKind.any) - - action = action.lower() - service, action_name = action.split(":", 1) - if action_name == "*": - kind = WildcardKind.any - elif "*" in action_name: - kind = WildcardKind.pattern - else: - kind = WildcardKind.fixed - - return ActionWildcardPattern(pattern=action, service=service, kind=kind) - - self.actions_patterns = [pattern_from_action(action) for action in self.actions] - self.not_action_patterns = [pattern_from_action(action) for action in self.not_action] - self.resource_patterns = [ResourceWildcardPattern.from_str(resource) for resource in self.resources] - self.not_resource_patterns = [ResourceWildcardPattern.from_str(resource) for resource in self.not_resource] - - -class FixPolicyDocument(PolicyDocument): - def __init__(self, policy_document: Json): - super().__init__(policy_document) - - self.fix_statements = [FixStatementDetail(statement.json) for statement in self.statements] - - @frozen(slots=True) class ActionToCheck: raw: str @@ -91,284 +54,6 @@ class ActionToCheck: action_name: str -class ArnResourceValueKind(enum.Enum): - Static = 1 # the segment is a fixed value, e.g. "s3", "vpc/vpc-0e9801d129EXAMPLE", - Pattern = 2 # the segment is a pattern, e.g. "my_corporate_bucket/*", - Any = 3 # the segment is missing, e.g. "::" or it is a wildcard, e.g. "*" - - @staticmethod - def from_str(value: str) -> "ArnResourceValueKind": - if value == "*": - return ArnResourceValueKind.Any - if "*" in value: - return ArnResourceValueKind.Pattern - return ArnResourceValueKind.Static - - -@frozen(slots=True) -class ArnResource: - value: str - principal_arns: Set[str] - kind: ArnResourceValueKind - not_resource: bool - - def matches(self, segment: str) -> bool: - _match = False - match self.kind: - case ArnResourceValueKind.Any: - _match = True - case ArnResourceValueKind.Pattern: - _match = fnmatch.fnmatch(segment, self.value) - case ArnResourceValueKind.Static: - _match = segment == self.value - - if self.not_resource: - _match = not _match - - return _match - - -@frozen(slots=True) -class ArnAccountId: - value: str - wildcard: bool # if the account is a wildcard, e.g. "*" or "::" - principal_arns: Set[str] - children: List[ArnResource] - - def matches(self, segment: str) -> bool: - return self.wildcard or self.value == segment - - -@frozen(slots=True) -class ArnRegion: - value: str - wildcard: bool # if the region is a wildcard, e.g. "*" or "::" - principal_arns: Set[str] - children: List[ArnAccountId] - - def matches(self, segment: str) -> bool: - return self.wildcard or self.value == segment - - -@frozen(slots=True) -class ArnService: - value: str - principal_arns: Set[str] - children: List[ArnRegion] - - def matches(self, segment: str) -> bool: - return self.value == segment - - -@frozen(slots=True) -class ArnPartition: - value: str - wildcard: bool # for the cases like "Allow": "*" on all resources - principal_arns: Set[str] - children: List[ArnService] - - def matches(self, segment: str) -> bool: - return self.wildcard or segment == self.value - - -def is_wildcard(segment: str) -> bool: - return segment == "*" or segment == "" - - -class PrincipalTree: - def __init__(self) -> None: - self.partitions: List[ArnPartition] = [] - - def _add_allow_all_wildcard(self, principal_arn: str) -> None: - partition = next((p for p in self.partitions if p.value == "*"), None) - if not partition: - partition = ArnPartition(value="*", wildcard=True, principal_arns=set(), children=[]) - self.partitions.append(partition) - - partition.principal_arns.add(principal_arn) - - def _add_resource(self, resource_constraint: str, principal_arn: str, nr: bool = False) -> None: - """ - _add resource will add the principal arn at the resource level - """ - - try: - arn = ARN(resource_constraint) - # Find existing or create partition - partition = next((p for p in self.partitions if p.value == arn.partition), None) - if not partition: - partition = ArnPartition(value=arn.partition, wildcard=False, principal_arns=set(), children=[]) - self.partitions.append(partition) - - # Find or create service - service = next((s for s in partition.children if s.value == arn.service_prefix), None) - if not service: - service = ArnService(value=arn.service_prefix, principal_arns=set(), children=[]) - partition.children.append(service) - - # Find or create region - region_wildcard = arn.region == "*" or not arn.region - region = next((r for r in service.children if r.value == (arn.region or "*")), None) - if not region: - region = ArnRegion(value=arn.region or "*", wildcard=region_wildcard, principal_arns=set(), children=[]) - service.children.append(region) - - # Find or create account - account_wildcard = arn.account == "*" or not arn.account - account = next((a for a in region.children if a.value == (arn.account or "*")), None) - if not account: - account = ArnAccountId( - value=arn.account or "*", wildcard=account_wildcard, principal_arns=set(), children=[] - ) - region.children.append(account) - - # Add resource - resource = next( - (r for r in account.children if r.value == arn.resource_string and r.not_resource == nr), None - ) - if not resource: - if arn.resource_string == "*": - resource_kind = ArnResourceValueKind.Any - elif "*" in arn.resource_string: - resource_kind = ArnResourceValueKind.Pattern - else: - resource_kind = ArnResourceValueKind.Static - resource = ArnResource( - value=arn.resource_string, principal_arns=set(), kind=resource_kind, not_resource=nr - ) - account.children.append(resource) - - resource.principal_arns.add(principal_arn) - - except Exception as e: - log.error(f"Error parsing ARN {principal_arn}: {e}") - pass - - def _add_service(self, service_prefix: str, principal_arn: str) -> None: - # Find existing or create partition - partition = next((p for p in self.partitions if p.value == "*"), None) - if not partition: - partition = ArnPartition(value="*", wildcard=True, principal_arns=set(), children=[]) - self.partitions.append(partition) - - # Find or create service - service = next((s for s in partition.children if s.value == service_prefix), None) - if not service: - service = ArnService(value=service_prefix, principal_arns=set(), children=[]) - partition.children.append(service) - - service.principal_arns.add(principal_arn) - - def add_principal(self, principal_arn: str, policy_documents: List[FixPolicyDocument]) -> None: - """ - This method iterates over every policy statement and adds corresponding arns to principal tree. - """ - - for policy_doc in policy_documents: - for statement in policy_doc.fix_statements: - if statement.effect_allow: - has_wildcard_resource = False - for resource in statement.resources: - if resource == "*": - has_wildcard_resource = True - continue - self._add_resource(resource, principal_arn) - for not_resource in statement.not_resource: - self._add_resource(not_resource, principal_arn, nr=True) - - if has_wildcard_resource or (not statement.resources and not statement.not_resource): - for ap in statement.actions_patterns: - if ap.kind == WildcardKind.any: - self._add_allow_all_wildcard(principal_arn) - self._add_service(ap.service, principal_arn) - - def list_principals(self, resource_arn: ARN) -> Set[str]: - """ - this will be called for every resource and it must be fast - """ - principals: Set[str] = set() - - matching_partitions = [p for p in self.partitions if p.value if p.matches(resource_arn.partition)] - if not matching_partitions: - return principals - - matching_services = [ - s for p in matching_partitions for s in p.children if s.matches(resource_arn.service_prefix) - ] - if not matching_services: - return principals - principals.update([arn for s in matching_services for arn in s.principal_arns]) - - matching_regions = [r for s in matching_services for r in s.children if r.matches(resource_arn.region)] - if not matching_regions: - return principals - principals.update([arn for r in matching_regions for arn in r.principal_arns]) - - matching_account_ids = [a for r in matching_regions for a in r.children if r.matches(resource_arn.account)] - if not matching_account_ids: - return principals - principals.update([arn for a in matching_account_ids for arn in a.principal_arns]) - - matching_resources = [ - r for a in matching_account_ids for r in a.children if r.matches(resource_arn.resource_string) - ] - if not matching_resources: - return principals - - principals.update([arn for r in matching_resources for arn in r.principal_arns]) - - return principals - - -@frozen(slots=True) -class ResourceWildcardPattern: - raw_value: str - partition: str | None # None in case the whole string is "*" - service: str - region: str - region_value_kind: ArnResourceValueKind - account: str - account_value_kind: ArnResourceValueKind - resource: str - resource_value_kind: ArnResourceValueKind - - @staticmethod - def from_str(value: str) -> "ResourceWildcardPattern": - if value == "*": - return ResourceWildcardPattern( - raw_value=value, - partition=None, - service="*", - region="*", - region_value_kind=ArnResourceValueKind.Any, - account="*", - account_value_kind=ArnResourceValueKind.Any, - resource="*", - resource_value_kind=ArnResourceValueKind.Any, - ) - - try: - splitted = value.split(":", 5) - if len(splitted) != 6: - raise ValueError(f"Invalid resource pattern: {value}") - _, partition, service, region, account, resource = splitted - - return ResourceWildcardPattern( - raw_value=value, - partition=partition, - service=service, - region=region, - region_value_kind=ArnResourceValueKind.from_str(region), - account=account, - account_value_kind=ArnResourceValueKind.from_str(account), - resource=resource, - resource_value_kind=ArnResourceValueKind.from_str(resource), - ) - except Exception as e: - log.error(f"Error parsing resource pattern {value}: {e}") - raise e - - @frozen(slots=True) class IamRequestContext: principal: AwsResource diff --git a/plugins/aws/fix_plugin_aws/access_edges/types.py b/plugins/aws/fix_plugin_aws/access_edges/types.py new file mode 100644 index 0000000000..ca9e83e2ca --- /dev/null +++ b/plugins/aws/fix_plugin_aws/access_edges/types.py @@ -0,0 +1,117 @@ +from enum import Enum +from attr import frozen +from cloudsplaining.scan.policy_document import PolicyDocument +from cloudsplaining.scan.statement_detail import StatementDetail +from fixlib.types import Json +import logging + + +log = logging.getLogger("fix.plugins.aws") + + +class WildcardKind(Enum): + fixed = 1 + pattern = 2 + any = 3 + + +@frozen(slots=True) +class ActionWildcardPattern: + pattern: str + service: str + kind: WildcardKind + + +class ArnResourceValueKind(Enum): + Static = 1 # the segment is a fixed value, e.g. "s3", "vpc/vpc-0e9801d129EXAMPLE", + Pattern = 2 # the segment is a pattern, e.g. "my_corporate_bucket/*", + Any = 3 # the segment is missing, e.g. "::" or it is a wildcard, e.g. "*" + + @staticmethod + def from_str(value: str) -> "ArnResourceValueKind": + if value == "*": + return ArnResourceValueKind.Any + if "*" in value: + return ArnResourceValueKind.Pattern + return ArnResourceValueKind.Static + + +@frozen(slots=True) +class ResourceWildcardPattern: + raw_value: str + partition: str | None # None in case the whole string is "*" + service: str + region: str + region_value_kind: ArnResourceValueKind + account: str + account_value_kind: ArnResourceValueKind + resource: str + resource_value_kind: ArnResourceValueKind + + @staticmethod + def from_str(value: str) -> "ResourceWildcardPattern": + if value == "*": + return ResourceWildcardPattern( + raw_value=value, + partition=None, + service="*", + region="*", + region_value_kind=ArnResourceValueKind.Any, + account="*", + account_value_kind=ArnResourceValueKind.Any, + resource="*", + resource_value_kind=ArnResourceValueKind.Any, + ) + + try: + splitted = value.split(":", 5) + if len(splitted) != 6: + raise ValueError(f"Invalid resource pattern: {value}") + _, partition, service, region, account, resource = splitted + + return ResourceWildcardPattern( + raw_value=value, + partition=partition, + service=service, + region=region, + region_value_kind=ArnResourceValueKind.from_str(region), + account=account, + account_value_kind=ArnResourceValueKind.from_str(account), + resource=resource, + resource_value_kind=ArnResourceValueKind.from_str(resource), + ) + except Exception as e: + log.error(f"Error parsing resource pattern {value}: {e}") + raise e + + +class FixStatementDetail(StatementDetail): + def __init__(self, statement: Json): + super().__init__(statement) + + def pattern_from_action(action: str) -> ActionWildcardPattern: + if action == "*": + return ActionWildcardPattern(pattern=action, service="*", kind=WildcardKind.any) + + action = action.lower() + service, action_name = action.split(":", 1) + if action_name == "*": + kind = WildcardKind.any + elif "*" in action_name: + kind = WildcardKind.pattern + else: + kind = WildcardKind.fixed + + return ActionWildcardPattern(pattern=action, service=service, kind=kind) + + self.actions_patterns = [pattern_from_action(action) for action in self.actions] + self.not_action_patterns = [pattern_from_action(action) for action in self.not_action] + self.resource_patterns = [ResourceWildcardPattern.from_str(resource) for resource in self.resources] + self.not_resource_patterns = [ResourceWildcardPattern.from_str(resource) for resource in self.not_resource] + + +class FixPolicyDocument(PolicyDocument): + def __init__(self, policy_document: Json): + super().__init__(policy_document) + + self.fix_statements = [FixStatementDetail(statement.json) for statement in self.statements] diff --git a/plugins/aws/fix_plugin_aws/collector.py b/plugins/aws/fix_plugin_aws/collector.py index 03a08346c5..0fb21eda0c 100644 --- a/plugins/aws/fix_plugin_aws/collector.py +++ b/plugins/aws/fix_plugin_aws/collector.py @@ -4,7 +4,7 @@ from datetime import datetime, timedelta, timezone from typing import List, Type, Optional, Union, cast, Any -from fix_plugin_aws.access_edges import AccessEdgeCreator +from fix_plugin_aws.access_edges.edge_builder import AccessEdgeCreator from fix_plugin_aws.aws_client import AwsClient from fix_plugin_aws.configuration import AwsConfig from fix_plugin_aws.resource import ( diff --git a/plugins/aws/test/acccess_edges_test.py b/plugins/aws/test/acccess_edges_test.py index cd15c1342d..4650a73bde 100644 --- a/plugins/aws/test/acccess_edges_test.py +++ b/plugins/aws/test/acccess_edges_test.py @@ -6,7 +6,7 @@ from policy_sentry.util.arns import ARN import re -from fix_plugin_aws.access_edges import ( +from fix_plugin_aws.access_edges.edge_builder import ( find_allowed_action, make_resoruce_regex, check_statement_match, @@ -14,13 +14,12 @@ IamRequestContext, check_explicit_deny, compute_permissions, - FixPolicyDocument, - FixStatementDetail, ActionToCheck, get_actions_matching_arn, - PrincipalTree, - ArnResourceValueKind, ) +from fix_plugin_aws.access_edges.types import FixPolicyDocument, FixStatementDetail, ArnResourceValueKind + +from fix_plugin_aws.access_edges.arn_tree import ArnTree from fixlib.baseresources import PolicySourceKind, PolicySource, PermissionLevel from fixlib.json import to_json_str @@ -1023,7 +1022,7 @@ def test_compute_permissions_role_inline_policy_allow() -> None: def test_principal_tree_add_allow_all_wildcard() -> None: """Test adding wildcard (*) permission to the principal tree.""" - tree = PrincipalTree() + tree = ArnTree[str]() principal_arn = "arn:aws:iam::123456789012:user/test-user" tree._add_allow_all_wildcard(principal_arn) @@ -1031,14 +1030,14 @@ def test_principal_tree_add_allow_all_wildcard() -> None: # Verify the wildcard partition exists assert len(tree.partitions) == 1 partition = tree.partitions[0] - assert partition.value == "*" + assert partition.key == "*" assert partition.wildcard is True - assert principal_arn in partition.principal_arns + assert principal_arn in partition.values def test_principal_tree_add_resource() -> None: """Test adding a resource ARN to the principal tree.""" - tree = PrincipalTree() + tree = ArnTree[str]() principal_arn = "arn:aws:iam::123456789012:user/test-user" resource_arn = "arn:aws:s3:::my-bucket/my-object" @@ -1047,38 +1046,38 @@ def test_principal_tree_add_resource() -> None: # Verify the partition structure assert len(tree.partitions) == 1 partition = tree.partitions[0] - assert partition.value == "aws" + assert partition.key == "aws" assert not partition.wildcard # Verify service level assert len(partition.children) == 1 service = partition.children[0] - assert service.value == "s3" + assert service.key == "s3" # Verify region level assert len(service.children) == 1 region = service.children[0] - assert region.value == "*" + assert region.key == "*" assert region.wildcard # Verify account level assert len(region.children) == 1 account = region.children[0] - assert account.value == "*" + assert account.key == "*" assert account.wildcard # Verify resource level assert len(account.children) == 1 resource = account.children[0] - assert resource.value == "my-bucket/my-object" + assert resource.key == "my-bucket/my-object" assert resource.kind == ArnResourceValueKind.Static - assert principal_arn in resource.principal_arns + assert principal_arn in resource.values assert not resource.not_resource def test_principal_tree_add_resource_with_wildcard() -> None: """Test adding a resource ARN with wildcards to the principal tree.""" - tree = PrincipalTree() + tree = ArnTree[str]() principal_arn = "arn:aws:iam::123456789012:user/test-user" resource_arn = "arn:aws:s3:::my-bucket/*" @@ -1091,14 +1090,14 @@ def test_principal_tree_add_resource_with_wildcard() -> None: account = region.children[0] resource = account.children[0] - assert resource.value == "my-bucket/*" + assert resource.key == "my-bucket/*" assert resource.kind == ArnResourceValueKind.Pattern - assert principal_arn in resource.principal_arns + assert principal_arn in resource.values def test_principal_tree_add_not_resource() -> None: """Test adding a NotResource ARN to the principal tree.""" - tree = PrincipalTree() + tree = ArnTree[str]() principal_arn = "arn:aws:iam::123456789012:user/test-user" resource_arn = "arn:aws:s3:::my-bucket/private/*" @@ -1115,7 +1114,7 @@ def test_principal_tree_add_not_resource() -> None: def test_principal_tree_add_service() -> None: """Test adding a service to the principal tree.""" - tree = PrincipalTree() + tree = ArnTree[str]() principal_arn = "arn:aws:iam::123456789012:user/test-user" service_prefix = "s3" @@ -1124,17 +1123,17 @@ def test_principal_tree_add_service() -> None: # Verify service is added under wildcard partition assert len(tree.partitions) == 1 partition = tree.partitions[0] - assert partition.value == "*" + assert partition.key == "*" assert len(partition.children) == 1 service = partition.children[0] - assert service.value == "s3" - assert principal_arn in service.principal_arns + assert service.key == "s3" + assert principal_arn in service.values def test_principal_tree_add_principal_policy() -> None: """Test adding a principal with policy documents to the principal tree.""" - tree = PrincipalTree() + tree = ArnTree[str]() principal_arn = "arn:aws:iam::123456789012:user/test-user" policy_json = { @@ -1146,16 +1145,16 @@ def test_principal_tree_add_principal_policy() -> None: } policy_doc = FixPolicyDocument(policy_json) - tree.add_principal(principal_arn, [policy_doc]) + tree.add_element(principal_arn, [policy_doc]) # Verify both the specific resource and wildcard permissions are added assert any( - p.value == "aws" + p.key == "aws" and any( - s.value == "s3" + s.key == "s3" and any( - r.value == "*" - and any(a.value == "*" and any(res.value == "my-bucket/*" for res in a.children) for a in r.children) + r.key == "*" + and any(a.key == "*" and any(res.key == "my-bucket/*" for res in a.children) for a in r.children) for r in s.children ) for s in p.children @@ -1166,7 +1165,7 @@ def test_principal_tree_add_principal_policy() -> None: def test_principal_tree_list_principals() -> None: """Test listing principals that have access to a given ARN.""" - tree = PrincipalTree() + tree = ArnTree[str]() principal1 = "arn:aws:iam::123456789012:user/test-user1" principal2 = "arn:aws:iam::123456789012:user/test-user2" @@ -1185,12 +1184,12 @@ def test_principal_tree_list_principals() -> None: } ) - tree.add_principal(principal1, [policy_doc1]) - tree.add_principal(principal2, [policy_doc2]) + tree.add_element(principal1, [policy_doc1]) + tree.add_element(principal2, [policy_doc2]) # Test specific resource access resource_arn = ARN("arn:aws:s3:::my-bucket/test.txt") - matching_principals = tree.list_principals(resource_arn) + matching_principals = tree.find_matching_values(resource_arn) assert principal1 in matching_principals # Has specific access assert principal2 in matching_principals # Has wildcard access @@ -1198,7 +1197,7 @@ def test_principal_tree_list_principals() -> None: def test_principal_tree_add_multiple_statements() -> None: """Test adding multiple statements for the same principal.""" - tree = PrincipalTree() + tree = ArnTree[str]() principal_arn = "arn:aws:iam::123456789012:user/test-user" policy_doc = FixPolicyDocument( @@ -1211,19 +1210,19 @@ def test_principal_tree_add_multiple_statements() -> None: } ) - tree.add_principal(principal_arn, [policy_doc]) + tree.add_element(principal_arn, [policy_doc]) # Test access to both buckets bucket1_arn = ARN("arn:aws:s3:::bucket1/test.txt") bucket2_arn = ARN("arn:aws:s3:::bucket2/test.txt") - assert principal_arn in tree.list_principals(bucket1_arn) - assert principal_arn in tree.list_principals(bucket2_arn) + assert principal_arn in tree.find_matching_values(bucket1_arn) + assert principal_arn in tree.find_matching_values(bucket2_arn) def test_principal_tree_not_resource() -> None: """Test NotResource handling in the principal tree.""" - tree = PrincipalTree() + tree = ArnTree[str]() principal_arn = "arn:aws:iam::123456789012:user/test-user" policy_doc = FixPolicyDocument( @@ -1235,18 +1234,18 @@ def test_principal_tree_not_resource() -> None: } ) - tree.add_principal(principal_arn, [policy_doc]) + tree.add_element(principal_arn, [policy_doc]) # Test access is denied to private bucket private_arn = ARN("arn:aws:s3:::private-bucket/secret.txt") public_arn = ARN("arn:aws:s3:::public-bucket/public.txt") ec2 = ARN("arn:aws:ec2:us-east-1:123456789012:instance/i-1234567890abcdef0") - matching_principals = tree.list_principals(private_arn) + matching_principals = tree.find_matching_values(private_arn) assert principal_arn not in matching_principals - matching_principals = tree.list_principals(public_arn) + matching_principals = tree.find_matching_values(public_arn) assert principal_arn in matching_principals - matching_principals = tree.list_principals(ec2) + matching_principals = tree.find_matching_values(ec2) assert len(matching_principals) == 0