diff --git a/tests/infrastructure/programs/sre/test_firewall.py b/tests/infrastructure/programs/sre/test_firewall.py index fdda776aeb..2abc7a86c9 100644 --- a/tests/infrastructure/programs/sre/test_firewall.py +++ b/tests/infrastructure/programs/sre/test_firewall.py @@ -1,18 +1,33 @@ +from collections.abc import Callable +from enum import Enum from functools import partial -from typing import Callable import pulumi import pulumi.runtime import pytest from pulumi_azure_native import network +from ..resource_assertions import assert_equal + + +class MyMocks(pulumi.runtime.Mocks): + def new_resource(self, args: pulumi.runtime.MockResourceArgs): + return [args.name + "_id", args.inputs] + + def call(self, _: pulumi.runtime.MockCallArgs): + return {} + + +pulumi.runtime.set_mocks( + MyMocks(), + preview=False, # Sets the flag `dry_run`, which is true at runtime during a preview. +) + from data_safe_haven.infrastructure.programs.sre.firewall import ( SREFirewallComponent, SREFirewallProps, ) -from ..resource_assertions import assert_equal, assert_equal_json - @pytest.fixture def allow_internet_props_setter( @@ -71,6 +86,11 @@ def set_allow_workspace_internet(allow_workspace_internet) -> SREFirewallCompone return set_allow_workspace_internet +class InternetAccess(Enum): + ENABLED = True + DISABLED = False + + class TestSREFirewallProps: @pulumi.runtime.test @@ -108,6 +128,45 @@ def test_component_allow_workspace_internet_enabled( ) firewall_component.firewall.application_rule_collections.apply( - partial(assert_equal_json, []), + partial(TestSREFirewallComponent.assert_allow_internet_access, InternetAccess.ENABLED), # type: ignore run_with_unknowns=True, ) + + @pulumi.runtime.test + def test_component_allow_workspace_internet_disabled( + self, + allow_internet_component_setter: Callable[[bool], SREFirewallComponent], + ): + firewall_component: SREFirewallComponent = allow_internet_component_setter( + allow_workspace_internet=False + ) + + firewall_component.firewall.application_rule_collections.apply( + partial(TestSREFirewallComponent.assert_allow_internet_access, InternetAccess.DISABLED), # type: ignore + run_with_unknowns=True, + ) + + @staticmethod + def assert_allow_internet_access( + internet_access: InternetAccess, + application_rule_collections: ( + list[network.outputs.AzureFirewallApplicationRuleCollectionResponse] | None + ), + ): + + if application_rule_collections is not None: + + workspace_deny_collection: list[ + network.outputs.AzureFirewallApplicationRuleCollectionResponse + ] = [ + rule_collection + for rule_collection in application_rule_collections + if rule_collection.name == "workspaces-deny" + ] + + if internet_access == InternetAccess.ENABLED: + assert not workspace_deny_collection + else: + assert len(workspace_deny_collection) == 1 + else: + raise AssertionError()