diff --git a/docs/user/lib_use_cases_acl.md b/docs/user/lib_use_cases_acl.md index 49cfe583..eece36f8 100644 --- a/docs/user/lib_use_cases_acl.md +++ b/docs/user/lib_use_cases_acl.md @@ -82,7 +82,7 @@ The `ACLRule` class at a high level: ### Initialization & Loading Data -The initialization process calls on the `load_data` method. This on a high level verifies schema of initial data, allows you to process data (e.g. convert tcp/https -> 80/443), expand data, determine Cartesian product (or permutations) of the firewall rule (traditionally 5-tuple), and verifies schema of result data. +`ACLRule` initialization process calls on the `__init__` method. This on a high level verifies schema of initial data, allows you to process data (e.g. convert tcp/https -> 80/443), expand data, determine Cartesian product (or permutations) of the firewall rule (traditionally 5-tuple), and verifies schema of result data. The Cartesian product (or permutations) is key to the functionality of other steps, this allows you to evaluate each rule based on the smallest view of the data, so pay close attention to those steps, as it is important to other methods as well. @@ -100,14 +100,13 @@ Many of validations will be based on IPs, but not all. Here you will find a written understanding of what is happening in the code: -- The init method takes in data and calls `load_data`. -- The method `load_data` processes the input data. +- The init method takes in data as key-value arguments and calls `_load_data`. +- The method `_load_data` processes the input data. - The `input_data_check` method is called and verifies the input data based on the specified JSON schema. - This is controlled by the `input_data_verify` attribute and schema defined in `input_data_schema`. - - For each `self.attrs`, a method name matching `f"process_{attr}"`, (e.g. `process_src_ip()`) is called. + - For each `self.attrs`, a method name matching `f"process_{attr}"`, (e.g. `process_src_ip()`) is called for `src_ip` attribute. - This allows you to inherit from and provide your own custom processes to convert, expand, or otherwise modify data before being evaluated. - - The `process_dst_port` method processes the `dst_port` attribute by converting the protocol and port information, it is enabled by default but controlled with the `dst_port_process` attribute. - - Both a dictionary `self.processed` and attributes (e.g. self.action, self.src_ip, etc.) are created. + - Dictionaries `self._preprocessed_data`, `self._processed_data` and attributes (e.g. `self.action`, `self.src_ip`, etc.) are created. - The `result_data_check` method verifies the processed data based on the specified JSON schema. - This is controlled by the `result_data_verify` attribute which is disabled by default. - The `validate` method validating the rule using a series of custom methods starting with `validate_` prefixes. @@ -146,7 +145,7 @@ While not accurate in all use cases it would be best practice to run any of your ### Match & Match Details -The `match_details` method provides a verbose way of verifying match details between two ACL rule's, the `match` method uses `match_details` and provides a boolean if there are any rules in `rules_unmatched` which would tell you if you had a full match or not. We will only review in detail the `match_details`. +The `match_details` method provides a verbose way of verifying match details between two ACL rule's, the `match` method uses `match_details` and provides a boolean if there are any rules in `products_unmatched` which would tell you if you had a full match or not. We will only review in detail the `match_details`. Here you will find a written understanding of what is happening in the code: @@ -155,13 +154,13 @@ Here you will find a written understanding of what is happening in the code: - This allows you to inherit from and provide your own custom equality check or verify with your business logic. - You do not need to have a `f"match_{attr}"` method for every attr, description as example would not be a good candidate to match on. - Equality checks are done on `src_zone`, `dst_zone`, `action`, and `port` by default. - - An `is_ip_within` check is done with for `src_ip` and `dst_ip` by default. + - Equality checks and an `is_ip_within` check is done with for `src_ip` and `dst_ip` by default. - In the process, details are provided for and returned: - - `rules_matched` - Root key that is a list of dictionaries of rules that matched. - - `rules_unmatched` - Root key that is a list of dictionaries of rules that did not match. - - `existing_rule_product` - The original expanded_rule that existed in this item. + - `match` - Root key, of type `bool`, that indicates whether a rule was matched. - `existing_rule` - The full original rule (not expanded_rule) that existed. - - `match_rule` - The full original rule that tested against, only shown in `rules_matched` root key. + - `match_rule` - The full original rule that tested against. + - `products_matched` - Root key that is a list of dictionaries of match products that matched. + - `products_unmatched` - Root key that is a list of dictionaries of match products that did not match. This data could help you to understand what matched, why it matched, and other metadata. This detail data can be used to in `ACLRules` to aggregate and ask more interesting questions. @@ -183,21 +182,21 @@ Here we can test if a rule is matched via the existing ruleset. We can leverage **Simple Example** ```python ->>> from netutils.acl import ACLRules +>>> from netutils.acl import ACLRules, ACLRule >>> >>> existing_acls = [ ... dict( ... name="Allow to internal web", ... src_ip=["192.168.0.0/24", "10.0.0.0/16"], ... dst_ip=["172.16.0.0/16", "192.168.250.10-192.168.250.20"], -... dst_port=["tcp/80", "udp/53"], +... dst_port=["80", "53"], ... action="permit", ... ), ... dict( ... name="Allow to internal dns", ... src_ip=["192.168.1.0/24"], ... dst_ip=["172.16.0.0/16"], -... dst_port=["tcp/80", "udp/53"], +... dst_port=["80", "53"], ... action="permit", ... ) ... ] @@ -206,24 +205,24 @@ Here we can test if a rule is matched via the existing ruleset. We can leverage ... name="Check multiple sources pass", ... src_ip=["192.168.1.10", "192.168.1.11", "192.168.1.15-192.168.1.20"], ... dst_ip="172.16.0.10", -... dst_port="tcp/www-http", +... dst_port="80", ... action="permit", ... ) >>> ->>> ACLRules(existing_acls).match(new_acl_match) -'permit' +>>> ACLRules(existing_acls).match(ACLRule(**new_acl_match)) +True >>> >>> >>> new_acl_non_match = dict( ... name="Check no match", ... src_ip=["10.1.1.1"], ... dst_ip="172.16.0.10", -... dst_port="tcp/www-http", +... dst_port="80", ... action="permit", ... ) >>> ->>> ACLRules(existing_acls).match(new_acl_non_match) -'deny' +>>> ACLRules(existing_acls).match(ACLRule(**new_acl_non_match)) +False >>> ``` diff --git a/netutils/acl.py b/netutils/acl.py index 88eeba66..316cb671 100644 --- a/netutils/acl.py +++ b/netutils/acl.py @@ -2,8 +2,8 @@ import itertools import copy +import types import typing as t -from netutils.protocol_mapper import PROTO_NAME_TO_NUM, TCP_NAME_TO_NUM, UDP_NAME_TO_NUM from netutils.ip import is_ip_within try: @@ -16,12 +16,13 @@ INPUT_SCHEMA = { "type": "object", "properties": { - "name": {"type": "string"}, - "src_zone": {"type": ["string", "array"]}, + "name": {"type": ["string", "null"]}, + "src_zone": {"type": ["string", "array", "null"]}, "src_ip": {"$ref": "#/definitions/arrayOrIP"}, "dst_ip": {"$ref": "#/definitions/arrayOrIP"}, "dst_port": { "oneOf": [ + {"type": "null"}, { "$ref": "#/definitions/port", }, @@ -35,8 +36,9 @@ }, ], }, - "dst_zone": {"type": ["string", "array"]}, + "dst_zone": {"type": ["string", "array", "null"]}, "action": {"type": "string"}, + "protocol": {"type": ["string", "null"]}, }, "definitions": { "ipv4": {"type": "string", "pattern": "^(?:\\d{1,3}\\.){3}\\d{1,3}$"}, @@ -83,14 +85,25 @@ }, ], }, - "port": {"type": "string", "pattern": "^\\S+\\/\\S+$"}, + "port": { + "oneOf": [ + { + "type": "integer", + "minimum": 0, + "maximum": 65535, + }, + { + "type": "string", + "pattern": "^\\d+$", + }, + ] + }, }, "required": [], } RESULT_SCHEMA = copy.deepcopy(INPUT_SCHEMA) -RESULT_SCHEMA["definitions"]["port"]["pattern"] = "^\\d+\\/\\d+$" # type: ignore def _cartesian_product(data: t.Dict[str, str]) -> t.List[t.Dict[str, t.Any]]: @@ -130,6 +143,8 @@ def _cartesian_product(data: t.Dict[str, str]) -> t.List[t.Dict[str, t.Any]]: keys.append(key) if isinstance(value, (str, int)): values.append([value]) + elif value is None: + values.append([None]) else: values.append(value) product = list(itertools.product(*values)) @@ -147,30 +162,67 @@ def _check_schema(data: t.Any, schema: t.Any, verify: bool) -> None: raise ValueError() +def _get_attributes(obj: t.Any) -> t.Dict[str, t.Any]: + """Function that describes class attributes.""" + result = { + attr: getattr(obj, attr) + for attr in dir(obj) + if not attr.startswith("_") + and not callable(getattr(obj, attr)) + and not isinstance(getattr(obj, attr), types.FunctionType) + } + return result + + +def _get_match_funcs(obj: t.Any) -> t.Dict[str, t.Any]: + """Returns {'attr': match_attr_funct, ...} dict.""" + attrs = {} + for attr_name in dir(obj): + if attr_name.startswith("match_") and attr_name not in ["match_details"]: + match_name = attr_name[len("match_") :] # noqa: E203 + # When an attribute is not defined, can skip it + if not hasattr(obj, match_name): + continue + attrs[match_name] = getattr(obj, attr_name) + + return attrs + + class ACLRule: """A class that helps you imagine an acl rule via methodologies.""" - attrs: t.List[str] = ["name", "src_ip", "src_zone", "dst_ip", "dst_port", "dst_zone", "action"] - permit: str = "permit" - deny: str = "deny" + name: t.Any = None + src_ip: t.Any = None + src_zone: t.Any = None + dst_ip: t.Any = None + dst_port: t.Any = None + dst_zone: t.Any = None + protocol: t.Any = None + action: t.Any = None - input_data_verify: bool = False - input_data_schema: t.Any = INPUT_SCHEMA + class Meta: # pylint: disable=too-few-public-methods + """Default meta class.""" - result_data_verify: bool = False - result_data_schema: t.Any = RESULT_SCHEMA + permit: str = "permit" + deny: str = "deny" - matrix: t.Any = {} - matrix_enforced: bool = False - matrix_definition: t.Any = {} + input_data_verify: bool = False + input_data_schema: t.Any = INPUT_SCHEMA - dst_port_process: bool = True + result_data_verify: bool = False + result_data_schema: t.Any = RESULT_SCHEMA - order_validate: t.List[str] = [] - order_enforce: t.List[str] = [] - filter_same_ip: bool = True + matrix: t.Any = {} + matrix_enforced: bool = False + matrix_definition: t.Any = {} - def __init__(self, data: t.Any, *args: t.Any, **kwargs: t.Any): # pylint: disable=unused-argument + dst_port_process: bool = False + + order_validate: t.List[str] = [] + order_enforce: t.List[str] = [] + filter_same_ip: bool = True + + def __init__(self, **kwargs: t.Any) -> None: # pylint: disable=unused-argument """Initialize and load data. Args: @@ -179,55 +231,87 @@ def __init__(self, data: t.Any, *args: t.Any, **kwargs: t.Any): # pylint: disab Examples: >>> from netutils.acl import ACLRule >>> - >>> acl_data = dict( + >>> rule = ACLRule( ... name="Check no match", ... src_ip=["10.1.1.1"], + ... src_zone="internal", ... dst_ip="172.16.0.10", - ... dst_port="tcp/www-http", + ... dst_port="80", + ... dst_zone="external", + ... protocol='tcp', ... action="permit", ... ) >>> - >>> rule = ACLRule(acl_data) >>> >>> rule.expanded_rules - [{'name': 'Check no match', 'src_ip': '10.1.1.1', 'dst_ip': '172.16.0.10', 'dst_port': '6/80', 'action': 'permit'}] + [{'name': 'Check no match', 'src_ip': '10.1.1.1', 'src_zone': 'internal', 'dst_ip': '172.16.0.10', 'dst_port': '80', 'dst_zone': 'external', 'protocol': 'tcp', 'action': 'permit'}] >>> """ - self.processed: t.Dict[str, str] = {} - self.data = data - self.load_data() + self._load_data(kwargs=kwargs) - def load_data(self) -> None: + def _load_data(self, kwargs: t.Dict[str, t.Any]) -> None: """Load the data into the rule while verifying input data, result data, and processing data.""" + # Remaining kwargs stored under ACLRule.Meta + pop_kwargs = [] + for key, val in kwargs.items(): + if key not in _get_attributes(self): + setattr(self.Meta, key, val) + pop_kwargs.append(key) + + # Pop unneeded keys + for key in pop_kwargs: + kwargs.pop(key) + + # Ensure each class attr is in init kwargs. + for attr in _get_attributes(self): + if attr not in kwargs: + kwargs[attr] = getattr(self, attr) + + # Store the init input + self._preprocessed_data = copy.deepcopy(kwargs) + self._processed_data = copy.deepcopy(self._preprocessed_data) + + # Input check self.input_data_check() - for attr in self.attrs: - if not self.data.get(attr): - continue - if hasattr(self, f"process_{attr}"): - proccessor = getattr(self, f"process_{attr}") - _attr_data = proccessor(self.data[attr]) + + for attr in _get_attributes(self): + processor_func = getattr(self, f"process_{attr}", None) + if processor_func: + _attr_data = processor_func(self._processed_data[attr]) else: - _attr_data = self.data[attr] - self.processed[attr] = _attr_data + _attr_data = self._processed_data[attr] + + self._processed_data[attr] = _attr_data setattr(self, attr, _attr_data) + self.result_data_check() self.validate() - self.expanded_rules = _cartesian_product(self.processed) - if self.filter_same_ip: - self.expanded_rules = [item for item in self.expanded_rules if item["dst_ip"] != item["src_ip"]] + self._set_expanded_rules() + + def _set_expanded_rules(self) -> None: + """Expanded rule setter.""" + _expanded_rules = _cartesian_product(self._processed_data) + if self.Meta.filter_same_ip: + _expanded_rules = [ + item + for item in _expanded_rules + if (item["dst_ip"] != item["src_ip"]) or (item["dst_ip"] is None and item["src_ip"] is None) + ] + + self.expanded_rules = _expanded_rules # pylint: disable=attribute-defined-outside-init def input_data_check(self) -> None: """Verify the input data against the specified JSONSchema or using a simple dictionary check.""" - return _check_schema(self.data, self.input_data_schema, self.input_data_verify) + return _check_schema(self._preprocessed_data, self.Meta.input_data_schema, self.Meta.input_data_verify) def result_data_check(self) -> None: """Verify the result data against the specified JSONSchema or using a simple dictionary check.""" - return _check_schema(self.processed, self.result_data_schema, self.result_data_verify) + return _check_schema(self._processed_data, self.Meta.result_data_schema, self.Meta.result_data_verify) def validate(self) -> t.Any: """Run through any method that startswith('validate_') and run that method.""" - if self.order_validate: - method_order = self.order_validate + if self.Meta.order_validate: + method_order = self.Meta.order_validate else: method_order = dir(self) results = [] @@ -242,52 +326,14 @@ def validate(self) -> t.Any: results.extend(result) return results - def process_dst_port( - self, dst_port: t.Any - ) -> t.Union[t.List[str], None]: # pylint: disable=inconsistent-return-statements - """Convert port and protocol information. - - Method supports a single format of `{protocol}/{port}`, and will translate the - protocol for all IANA defined protocols. The port will be translated for TCP and - UDP ports only. For all other protocols should use port of 0, e.g. `ICMP/0` for ICMP - or `50/0` for ESP. Similarly, IANA defines the port mappings, while these are mostly - staying unchanged, but sourced from - https://www.iana.org/assignments/service-names-port-numbers/service-names-port-numbers.csv. - """ - output = [] - if not self.dst_port_process: - return None - if not isinstance(dst_port, list): - dst_port = [dst_port] - for item in dst_port: - protocol = item.split("/")[0] - port = item.split("/")[1] - if protocol.isalpha(): - if not PROTO_NAME_TO_NUM.get(protocol.upper()): - raise ValueError( - f"Protocol {protocol} was not found in netutils.protocol_mapper.PROTO_NAME_TO_NUM." - ) - protocol = PROTO_NAME_TO_NUM[protocol.upper()] - # test port[0] vs port, since dashes do not count, e.g. www-http - if int(protocol) == 6 and port[0].isalpha(): - if not TCP_NAME_TO_NUM.get(port.upper()): - raise ValueError(f"Port {port} was not found in netutils.protocol_mapper.TCP_NAME_TO_NUM.") - port = TCP_NAME_TO_NUM[port.upper()] - if int(protocol) == 17 and port[0].isalpha(): - if not UDP_NAME_TO_NUM.get(port.upper()): - raise ValueError(f"Port {port} was not found in netutils.protocol_mapper.UDP_NAME_TO_NUM.") - port = UDP_NAME_TO_NUM[port.upper()] - output.append(f"{protocol}/{port}") - return output - def enforce(self) -> t.List[t.Dict[str, t.Any]]: """Run through any method that startswith('enforce_') and run that method. Returns: A list of dictionaries that explains the results of the enforcement. """ - if self.order_enforce: - method_order = self.order_enforce + if self.Meta.order_enforce: + method_order = self.Meta.order_enforce else: method_order = dir(self) results = [] @@ -308,11 +354,11 @@ def enforce_matrix(self) -> t.Union[t.List[t.Dict[str, t.Any]], None]: Returns: A list of dictionaries that explains the results of the matrix being enforced. """ - if not self.matrix_enforced: + if not self.Meta.matrix_enforced: return None - if not self.matrix: + if not self.Meta.matrix: raise ValueError("You must set a matrix dictionary to use the matrix feature.") - if not self.matrix_definition: + if not self.Meta.matrix_definition: raise ValueError("You must set a matrix definition dictionary to use the matrix feature.") actions = [] for rule in self.expanded_rules: @@ -322,14 +368,14 @@ def enforce_matrix(self) -> t.Union[t.List[t.Dict[str, t.Any]], None]: src_zone = "" dst_zone = "" as_tuple = (source, destination, port) - for zone, ips in self.matrix_definition.items(): + for zone, ips in self.Meta.matrix_definition.items(): if is_ip_within(source, ips): src_zone = zone if is_ip_within(destination, ips): dst_zone = zone - if port in self.matrix.get(src_zone, {}).get(dst_zone, {}).get("allow", []): + if port in self.Meta.matrix.get(src_zone, {}).get(dst_zone, {}).get("allow", []): actions.append({"obj": as_tuple, "action": "allow"}) - elif port in self.matrix.get(src_zone, {}).get(dst_zone, {}).get("notify", []): + elif port in self.Meta.matrix.get(src_zone, {}).get(dst_zone, {}).get("notify", []): actions.append({"obj": as_tuple, "action": "notify"}) else: actions.append({"obj": as_tuple, "action": "deny"}) @@ -357,6 +403,9 @@ def match_src_ip(self, existing_ip: str, check_ip: str) -> bool: Returns: True if `check_ip` is within the range of `existing_ip`, False otherwise. """ + if existing_ip == check_ip: # None cases + return True + return is_ip_within(check_ip, existing_ip) def match_src_zone(self, existing_src_zone: str, check_src_zone: str) -> bool: @@ -381,6 +430,9 @@ def match_dst_ip(self, existing_ip: str, check_ip: str) -> bool: Returns: True if `check_ip` is within the range of `existing_ip`, False otherwise. """ + if existing_ip == check_ip: # None cases + return True + return is_ip_within(check_ip, existing_ip) def match_dst_zone(self, existing_dst_zone: str, check_dst_zone: str) -> bool: @@ -416,51 +468,34 @@ def match_details(self, match_rule: "ACLRule") -> t.Dict[str, t.Any]: # pylint: Returns: A dictionary with root keys of `rules_matched` and `rules_matched`. """ - attrs = [] - for name in dir(self): - if name.startswith("match_"): - obj_name = name[len("match_") :] # noqa: E203 - # When an attribute is not defined, can skip it - if not hasattr(match_rule, obj_name): - continue - attrs.append(obj_name) - - rules_found: t.List[bool] = [] - rules_unmatched: t.List[t.Dict[str, t.Any]] = [] - rules_matched: t.List[t.Dict[str, t.Any]] = [] + products_matched: t.List[t.Dict[str, t.Any]] = [] + products_unmatched: t.List[t.Dict[str, t.Any]] = [] - if not match_rule.expanded_rules: + if not match_rule.expanded_rules: # pylint: disable=protected-access raise ValueError("There is no expanded rules to test against.") - for rule in match_rule.expanded_rules: - rules_found.append(False) - for existing_rule in self.expanded_rules: - missing = False - for attr in attrs: - # Examples of obj are match_rule.src_ip, match_rule.dst_port - rule_value = rule[attr] - existing_value = existing_rule[attr] - # Examples of getter are self.match_src_ip, self.match_dst_port - getter = getattr(self, f"match_{attr}")(existing_value, rule_value) - if not getter and getter is not None: - missing = True - break - # If the loop gets through with each existing rule not flagging - # the `missing` value, we know everything was matched, and the rule has - # found a complete match, we can break out of the loop at this point. - if not missing: - rules_found[-1] = True - break - detailed_info = { - "existing_rule_product": existing_rule, # pylint: disable=undefined-loop-variable - "match_rule": match_rule.processed, - "existing_rule": self.processed, - } - if rules_found[-1]: - detailed_info["match_rule_product"] = rule - rules_matched.append(detailed_info) + + if not self.expanded_rules: # pylint: disable=protected-access + raise ValueError("There is no expanded rules to test.") + + for match_product in match_rule.expanded_rules: # pylint: disable=protected-access + for existing_product in self.expanded_rules: + # Break if we find match_product in existing_product (all matchers returned True) + if all( + attr_func(existing_product[attr_name], match_product[attr_name]) + for attr_name, attr_func in _get_match_funcs(self).items() + ): + products_matched.append(match_product) + break # Do not compare remaining existing products. else: - rules_unmatched.append(detailed_info) - return {"rules_matched": rules_matched, "rules_unmatched": rules_unmatched} + products_unmatched.append(match_product) + + return { + "match": not bool(products_unmatched), + "existing_rule": self.serialize(), + "match_rule": match_rule.serialize(), + "products_matched": products_matched, + "products_unmatched": products_unmatched, + } def match(self, match_rule: "ACLRule") -> bool: """Simple boolean way of verifying match or not. @@ -472,21 +507,27 @@ def match(self, match_rule: "ACLRule") -> bool: A boolean if there was a full match or not. """ details = self.match_details(match_rule) - return not bool(details["rules_unmatched"]) + + return details["match"] # type: ignore def __repr__(self) -> str: """Set repr of the object to be sane.""" - output = [] - for attr in self.attrs: - if self.processed.get(attr): - output.append(f"{attr}: {self.processed[attr]}") - return ", ".join(output) + return self.name or "Name is not set" + + def serialize(self) -> t.Dict[str, t.Any]: + """Primitive Serializer.""" + return {k: v for k, v in self.__dict__.items() if not (k.startswith("_") or k == "expanded_rules")} class ACLRules: """Class to help match multiple ACLRule objects.""" - class_obj = ACLRule + rules: t.List[t.Any] = [] + + class Meta: # pylint: disable=too-few-public-methods + """Default meta class.""" + + class_obj = ACLRule def __init__(self, data: t.Any, *args: t.Any, **kwargs: t.Any): # pylint: disable=unused-argument """Class to help match multiple ACLRule. @@ -494,16 +535,19 @@ def __init__(self, data: t.Any, *args: t.Any, **kwargs: t.Any): # pylint: disab Args: data: A list of `ACLRule` rules. """ - self.data: t.Any = data self.rules: t.List[t.Any] = [] - self.load_data() + self.load_data(data=data) - def load_data(self) -> None: + def load_data(self, data: t.Any) -> None: """Load the data for multiple rules.""" - for item in self.data: - self.rules.append(self.class_obj(item)) + for item in data: + self.rules.append(self.Meta.class_obj(**item)) - def match(self, rule: ACLRule) -> str: + def serialize(self) -> t.List[t.Any]: + """Primitive Serializer.""" + return [rule.serialize() for rule in self.rules] + + def match(self, rule: ACLRule) -> bool: """Check the rules loaded in `load_data` match against a new `rule`. Args: @@ -513,9 +557,10 @@ def match(self, rule: ACLRule) -> str: The response from the rule that matched, or `deny` by default. """ for item in self.rules: - if item.match(self.class_obj(rule)): - return str(item.action) - return str(item.deny) # pylint: disable=undefined-loop-variable + if item.match(rule): + return True + + return False def match_details(self, rule: ACLRule) -> t.Any: """Verbosely check the rules loaded in `load_data` match against a new `rule`. @@ -528,5 +573,5 @@ def match_details(self, rule: ACLRule) -> t.Any: """ output = [] for item in self.rules: - output.append(item.match_details(self.class_obj(rule))) + output.append(item.match_details(rule)) return output diff --git a/tests/unit/test_acl.py b/tests/unit/test_acl.py index 2e88a29c..24bae7e9 100644 --- a/tests/unit/test_acl.py +++ b/tests/unit/test_acl.py @@ -11,60 +11,60 @@ name="Check multiple sources pass. Check conversion of non-alpha tcp, e.g. with a dash", src_ip=["192.168.1.10", "192.168.1.11", "192.168.1.15-192.168.1.20"], dst_ip="172.16.0.10", - dst_port="tcp/www-http", + dst_port="80", action="permit", ), - "received": "permit", + "received": True, }, { "sent": dict( name="Check with number in port definition", src_ip="192.168.0.10", dst_ip="192.168.250.11", - dst_port="6/80", + dst_port="80", action="permit", ), - "received": "permit", + "received": True, }, { "sent": dict( name="Check with subnets", src_ip="192.168.0.0/25", dst_ip="172.16.0.0/24", - dst_port="6/80", + dst_port="80", action="permit", ), - "received": "permit", + "received": True, }, { "sent": dict( name="Test partial match on Source IP", src_ip=["192.168.1.10", "192.168.2.10"], dst_ip="172.16.0.11", - dst_port="tcp/80", + dst_port="80", action="permit", ), - "received": "deny", + "received": False, }, { "sent": dict( name="Test an entry that is not found", src_ip="192.168.1.10", dst_ip="192.168.240.1", - dst_port="tcp/80", + dst_port="80", action="permit", ), - "received": "deny", + "received": False, }, { "sent": dict( name="Test an action not permit or deny", src_ip="10.1.1.1", dst_ip="10.255.255.255", - dst_port="tcp/443", + dst_port="443", action="permit", ), - "received": "deny", + "received": False, }, ] @@ -73,35 +73,35 @@ name="Allow to internal web", src_ip=["192.168.0.0/24", "10.0.0.0/16"], dst_ip=["172.16.0.0/16", "192.168.250.10-192.168.250.20"], - dst_port=["tcp/80", "udp/53"], + dst_port=["80", "53"], action="permit", ), dict( name="Allow to internal dns", src_ip=["192.168.1.0/24"], dst_ip=["172.16.0.0/16"], - dst_port=["tcp/80", "udp/53"], + dst_port=["80", "53"], action="permit", ), dict( name="Allow to internal https", src_ip=["10.0.0.0/8"], dst_ip=["172.16.0.0/16"], - dst_port=["tcp/443"], + dst_port=["443"], action="deny", ), dict( name="Drop (not deny) this specfic packet", src_ip="10.1.1.1", dst_ip="10.255.255.255", - dst_port="tcp/443", + dst_port="443", action="drop", ), dict( name="Allow External DNS", src_ip=["0.0.0.0/0"], dst_ip=["8.8.8.8/32", "8.8.4.4/32"], - dst_port=["udp/53"], + dst_port=["53"], action="permit", ), ] @@ -112,42 +112,42 @@ name="Check allow", src_ip="10.1.100.5", dst_ip="10.1.200.0", - dst_port="tcp/www-http", + dst_port="80", action="permit", ), - "received": [{"obj": ("10.1.100.5", "10.1.200.0", "6/80"), "action": "allow"}], + "received": [{"obj": ("10.1.100.5", "10.1.200.0", "80"), "action": "allow"}], }, { "sent": dict( name="Check Notify", src_ip="10.1.100.5", dst_ip="10.1.200.0", - dst_port="tcp/25", + dst_port="25", action="permit", ), - "received": [{"obj": ("10.1.100.5", "10.1.200.0", "6/25"), "action": "notify"}], + "received": [{"obj": ("10.1.100.5", "10.1.200.0", "25"), "action": "notify"}], }, { "sent": dict( name="Check not found and denied", src_ip="10.1.100.5", dst_ip="10.1.200.0", - dst_port="tcp/53", + dst_port="53", action="permit", ), - "received": [{"obj": ("10.1.100.5", "10.1.200.0", "6/53"), "action": "deny"}], + "received": [{"obj": ("10.1.100.5", "10.1.200.0", "53"), "action": "deny"}], }, { "sent": dict( name="Check not found and denied", src_ip=["10.1.100.5", "10.1.100.6"], dst_ip="10.1.200.0", - dst_port="tcp/80", + dst_port="80", action="permit", ), "received": [ - {"obj": ("10.1.100.5", "10.1.200.0", "6/80"), "action": "allow"}, - {"obj": ("10.1.100.6", "10.1.200.0", "6/80"), "action": "allow"}, + {"obj": ("10.1.100.5", "10.1.200.0", "80"), "action": "allow"}, + {"obj": ("10.1.100.6", "10.1.200.0", "80"), "action": "allow"}, ], }, { @@ -155,10 +155,10 @@ name="Nothing found", src_ip="1.1.1.1", dst_ip="2.2.2.2", - dst_port="tcp/53", + dst_port="53", action="permit", ), - "received": [{"obj": ("1.1.1.1", "2.2.2.2", "6/53"), "action": "deny"}], + "received": [{"obj": ("1.1.1.1", "2.2.2.2", "53"), "action": "deny"}], }, ] @@ -168,7 +168,7 @@ name="Bad IP", src_ip="10.1.100.A", dst_ip="10.1.200.0", - dst_port="tcp/www-http", + dst_port="80", action="permit", ), }, @@ -177,7 +177,7 @@ name="Bad port", src_ip="10.1.100.5", dst_ip="10.1.200.0", - dst_port="tcp25", + dst_port="5o0", action="permit", ), }, @@ -186,7 +186,7 @@ name="Bad IP in list", src_ip=["10.1.100.5", "10.1.100.A"], dst_ip="10.1.200.0", - dst_port="tcp/25", + dst_port="25", action="permit", ), }, @@ -198,7 +198,7 @@ name="Check allow", src_ip="10.1.100.1", dst_ip="10.1.200.0", - dst_port="6/www-http", + dst_port="80", action=100, ), }, @@ -211,39 +211,42 @@ } MATRIX = { - "red": {"blue": {"allow": ["6/80", "6/443"], "notify": ["6/25"]}, "orange": {"allow": ["6/80"]}}, - "blue": {"red": {"allow": ["6/80"]}}, + "red": {"blue": {"allow": ["80", "443"], "notify": ["25"]}, "orange": {"allow": ["80"]}}, + "blue": {"red": {"allow": ["80"]}}, } -class TestMatrix(acl.ACLRule): +class TestMatrixRule(acl.ACLRule): """ACLRule inherited class to test the matrix.""" - matrix = MATRIX - matrix_enforced = True - matrix_definition = IP_DEFINITIONS + class Meta(acl.ACLRule.Meta): # pylint: disable=too-few-public-methods + matrix = MATRIX + matrix_enforced = True + matrix_definition = IP_DEFINITIONS -class TestSchema(acl.ACLRule): +class TestSchemaRule(acl.ACLRule): """ACLRule inherited class to test the schema.""" - input_data_verify = True + class Meta(acl.ACLRule.Meta): # pylint: disable=too-few-public-methods + input_data_verify = True -class TestSchema2(acl.ACLRule): +class TestSchema2Rule(acl.ACLRule): """ACLRule inherited class alternate to test the schema.""" - result_data_verify = True + class Meta(acl.ACLRule.Meta): # pylint: disable=too-few-public-methods + result_data_verify = True @pytest.mark.parametrize("data", verify_acl) def test_verify_acl(data): - assert acl.ACLRules(acls).match(data["sent"]) == data["received"] + assert acl.ACLRules(acls).match(acl.ACLRule(**data["sent"])) == data["received"] @pytest.mark.parametrize("data", verify_matrix) def test_matrix(data): - assert TestMatrix(data["sent"]).enforce() == data["received"] + assert TestMatrixRule(**data["sent"]).enforce() == data["received"] @pytest.mark.parametrize("data", verify_schema) @@ -255,19 +258,19 @@ def test_schema(data): pass with pytest.raises(jsonschema.exceptions.ValidationError): - TestSchema(data["sent"]) + TestSchemaRule(**data["sent"]) def test_schema_not_enforced_when_option_not_set(): try: - acl.ACLRule(dict(src_ip="10.1.1.1", dst_ip="10.2.2.2", dst_port="tcp/80", action=100)) + acl.ACLRule(src_ip="10.1.1.1", dst_ip="10.2.2.2", dst_port="80", action=100) except Exception: # pylint: disable=broad-exception-caught assert False, "No error should have been raised" def test_schema_valid(): try: - TestSchema(dict(src_ip="10.1.1.1", dst_ip="10.2.2.2", dst_port="tcp/80", action="permit")) + TestSchemaRule(src_ip="10.1.1.1", dst_ip="10.2.2.2", dst_port="80", action="permit") except Exception: # pylint: disable=broad-exception-caught assert False, "No error should have been raised" @@ -281,12 +284,12 @@ def test_schema2(data): pass with pytest.raises(jsonschema.exceptions.ValidationError): - TestSchema2(data["sent"]).validate() + TestSchema2Rule(**data["sent"]).validate() def test_schema2_valid(): try: - TestSchema2(dict(src_ip="10.1.1.1", dst_ip="10.2.2.2", dst_port="tcp/80", action="permit")).validate() + TestSchema2Rule(src_ip="10.1.1.1", dst_ip="10.2.2.2", dst_port="80", action="permit").validate() except Exception: # pylint: disable=broad-exception-caught assert False, "No error should have been raised" @@ -294,13 +297,13 @@ def test_schema2_valid(): class TestAddrGroups(acl.ACLRule): """ACLRule inherited class alternate to test expansions.""" - address_groups = {"red": ["white", "blue"], "blue": ["cyan"], "yellow": ["orange"]} + def __init__(self, **kwargs): + self._address_groups = {"red": ["white", "blue"], "blue": ["cyan"], "yellow": ["orange"]} + self._addresses = {"white": ["10.1.1.1", "10.2.2.2"], "cyan": ["10.3.3.3"], "orange": ["10.4.4.4"]} - addresses = {"white": ["10.1.1.1", "10.2.2.2"], "cyan": ["10.3.3.3"], "orange": ["10.4.4.4"]} + self._flattened_addresses = self.flatten_addresses(self._address_groups, self._addresses) - def __init__(self, data, *args, **kwargs): - self.flattened_addresses = self.flatten_addresses(self.address_groups, self.addresses) - super().__init__(data, *args, **kwargs) + super().__init__(**kwargs) def flatten_addresses(self, address_groups, addresses): """Go through and get the addresses given potential address groups.""" @@ -320,7 +323,7 @@ def flatten_addresses(self, address_groups, addresses): if group != sub_group: flattened_addresses.setdefault(group, []).extend(ips) - return flattened_addresses + return flattened_addresses def process_ip(self, ip): """Test ability to expand IP for both source and destination.""" @@ -331,10 +334,10 @@ def process_ip(self, ip): for ip_name in ip: if not ip_name[0].isalpha(): output.append(ip_name) - elif self.addresses.get(ip_name): - output.extend(self.addresses[ip_name]) - elif self.flattened_addresses.get(ip_name): - output.extend(self.flattened_addresses[ip_name]) + elif self._addresses.get(ip_name): + output.extend(self._addresses[ip_name]) + elif self._flattened_addresses.get(ip_name): + output.extend(self._flattened_addresses[ip_name]) return sorted(list(set(output))) def process_src_ip(self, src_ip): @@ -354,16 +357,70 @@ def process_dst_ip(self, dst_ip): name="Check allow", src_ip=["red", "blue", "10.4.4.4"], dst_ip=["white"], - dst_port="6/www-http", + dst_port="80", action="permit", ), "received": [ - {"action": "permit", "dst_ip": "10.2.2.2", "dst_port": "6/80", "name": "Check allow", "src_ip": "10.1.1.1"}, - {"action": "permit", "dst_ip": "10.1.1.1", "dst_port": "6/80", "name": "Check allow", "src_ip": "10.2.2.2"}, - {"action": "permit", "dst_ip": "10.1.1.1", "dst_port": "6/80", "name": "Check allow", "src_ip": "10.3.3.3"}, - {"action": "permit", "dst_ip": "10.2.2.2", "dst_port": "6/80", "name": "Check allow", "src_ip": "10.3.3.3"}, - {"action": "permit", "dst_ip": "10.1.1.1", "dst_port": "6/80", "name": "Check allow", "src_ip": "10.4.4.4"}, - {"action": "permit", "dst_ip": "10.2.2.2", "dst_port": "6/80", "name": "Check allow", "src_ip": "10.4.4.4"}, + { + "action": "permit", + "dst_ip": "10.2.2.2", + "dst_port": "80", + "name": "Check allow", + "src_ip": "10.1.1.1", + "protocol": None, + "src_zone": None, + "dst_zone": None, + }, + { + "action": "permit", + "dst_ip": "10.1.1.1", + "dst_port": "80", + "name": "Check allow", + "src_ip": "10.2.2.2", + "protocol": None, + "src_zone": None, + "dst_zone": None, + }, + { + "action": "permit", + "dst_ip": "10.1.1.1", + "dst_port": "80", + "name": "Check allow", + "src_ip": "10.3.3.3", + "protocol": None, + "src_zone": None, + "dst_zone": None, + }, + { + "action": "permit", + "dst_ip": "10.2.2.2", + "dst_port": "80", + "name": "Check allow", + "src_ip": "10.3.3.3", + "protocol": None, + "src_zone": None, + "dst_zone": None, + }, + { + "action": "permit", + "dst_ip": "10.1.1.1", + "dst_port": "80", + "name": "Check allow", + "src_ip": "10.4.4.4", + "protocol": None, + "src_zone": None, + "dst_zone": None, + }, + { + "action": "permit", + "dst_ip": "10.2.2.2", + "dst_port": "80", + "name": "Check allow", + "src_ip": "10.4.4.4", + "protocol": None, + "src_zone": None, + "dst_zone": None, + }, ], } ] @@ -371,5 +428,5 @@ def process_dst_ip(self, dst_ip): @pytest.mark.parametrize("data", add_group_check) def test_custom_address_group(data): - obj = TestAddrGroups(data["sent"]) - assert obj.expanded_rules == data["received"] + obj = TestAddrGroups(**data["sent"]) + assert obj.expanded_rules == data["received"] # pylint: disable=protected-access