diff --git a/.hatch/requirements-test.txt b/.hatch/requirements-test.txt index 5b75759987..a9397a2ba3 100644 --- a/.hatch/requirements-test.txt +++ b/.hatch/requirements-test.txt @@ -1,7 +1,7 @@ # # This file is autogenerated by hatch-pip-compile with Python 3.12 # -# [constraints] .hatch/requirements.txt (SHA256: 697cb5b4ddc1cb9481ae4847d33cf407119c028729be950240b30eaddd8c294e) +# [constraints] .hatch/requirements.txt (SHA256: f892a9714607641735b83f480e2c234b2ab8e1dffd2d59ad4188c887c06b24de) # # - appdirs==1.4.4 # - azure-core==1.31.0 @@ -25,6 +25,7 @@ # - fqdn==1.5.1 # - psycopg[binary]==3.2.3 # - pulumi-azure-native==2.64.3 +# - pulumi-azuread==5.53.4 # - pulumi-random==4.16.6 # - pulumi==3.135.1 # - pydantic==2.9.2 @@ -274,6 +275,7 @@ parver==0.5 # via # -c .hatch/requirements.txt # pulumi-azure-native + # pulumi-azuread # pulumi-random pluggy==1.5.0 # via pytest @@ -298,11 +300,16 @@ pulumi==3.135.1 # -c .hatch/requirements.txt # hatch.envs.test # pulumi-azure-native + # pulumi-azuread # pulumi-random pulumi-azure-native==2.64.3 # via # -c .hatch/requirements.txt # hatch.envs.test +pulumi-azuread==5.53.4 + # via + # -c .hatch/requirements.txt + # hatch.envs.test pulumi-random==4.16.6 # via # -c .hatch/requirements.txt @@ -381,6 +388,7 @@ semver==2.13.0 # -c .hatch/requirements.txt # pulumi # pulumi-azure-native + # pulumi-azuread # pulumi-random shellingham==1.5.4 # via diff --git a/.hatch/requirements.txt b/.hatch/requirements.txt index 9fad62e349..3d77a909db 100644 --- a/.hatch/requirements.txt +++ b/.hatch/requirements.txt @@ -23,6 +23,7 @@ # - fqdn==1.5.1 # - psycopg[binary]==3.2.3 # - pulumi-azure-native==2.64.3 +# - pulumi-azuread==5.53.4 # - pulumi-random==4.16.6 # - pulumi==3.135.1 # - pydantic==2.9.2 @@ -181,6 +182,7 @@ oauthlib==3.2.2 parver==0.5 # via # pulumi-azure-native + # pulumi-azuread # pulumi-random portalocker==2.10.1 # via msal-extensions @@ -194,9 +196,12 @@ pulumi==3.135.1 # via # hatch.envs.default # pulumi-azure-native + # pulumi-azuread # pulumi-random pulumi-azure-native==2.64.3 # via hatch.envs.default +pulumi-azuread==5.53.4 + # via hatch.envs.default pulumi-random==4.16.6 # via hatch.envs.default pycparser==2.22 @@ -243,6 +248,7 @@ semver==2.13.0 # via # pulumi # pulumi-azure-native + # pulumi-azuread # pulumi-random shellingham==1.5.4 # via typer diff --git a/data_safe_haven/commands/sre.py b/data_safe_haven/commands/sre.py index 805bb5e72a..d1f32ea278 100644 --- a/data_safe_haven/commands/sre.py +++ b/data_safe_haven/commands/sre.py @@ -91,7 +91,23 @@ def deploy( replace=True, ) logger.info( - f"SRE will be deployed to subscription '[green]{sre_config.azure.subscription_id}[/]', '[green]{sre_subscription_name}[/]'" + f"SRE will be deployed to subscription '[green]{sre_subscription_name}[/]'" + f" ('[bold]{sre_config.azure.subscription_id}[/]')" + ) + # Set Entra options + application = graph_api.get_application_by_name(context.entra_application_name) + if not application: + msg = f"No Entra application '{context.entra_application_name}' was found. Please redeploy your SHM." + raise DataSafeHavenConfigError(msg) + stack.add_option("azuread:clientId", application.get("appId", ""), replace=True) + if not context.entra_application_secret: + msg = f"No Entra application secret '{context.entra_application_secret_name}' was found. Please redeploy your SHM." + raise DataSafeHavenConfigError(msg) + stack.add_secret( + "azuread:clientSecret", context.entra_application_secret, replace=True + ) + stack.add_option( + "azuread:tenantId", shm_config.shm.entra_tenant_id, replace=True ) # Load SHM outputs stack.add_option( diff --git a/data_safe_haven/config/context.py b/data_safe_haven/config/context.py index 426795bf93..cadd875260 100644 --- a/data_safe_haven/config/context.py +++ b/data_safe_haven/config/context.py @@ -9,6 +9,7 @@ from data_safe_haven import __version__ from data_safe_haven.directories import config_dir +from data_safe_haven.exceptions import DataSafeHavenAzureError from data_safe_haven.external import AzureSdk from data_safe_haven.functions import alphanumeric from data_safe_haven.serialisers import ContextBase @@ -16,39 +17,48 @@ class Context(ContextBase, BaseModel, validate_assignment=True): + """Context for a Data Safe Haven deployment.""" + + entra_application_kvsecret_name: ClassVar[str] = "pulumi-deployment-secret" + entra_application_secret_name: ClassVar[str] = "Pulumi Deployment Secret" + pulumi_encryption_key_name: ClassVar[str] = "pulumi-encryption-key" + pulumi_storage_container_name: ClassVar[str] = "pulumi" + storage_container_name: ClassVar[str] = "config" + admin_group_name: EntraGroupName description: str name: SafeString subscription_name: AzureSubscriptionName - storage_container_name: ClassVar[str] = "config" - pulumi_storage_container_name: ClassVar[str] = "pulumi" - pulumi_encryption_key_name: ClassVar[str] = "pulumi-encryption-key" _pulumi_encryption_key = None + _entra_application_secret = None @property - def tags(self) -> dict[str, str]: - return { - "description": self.description, - "project": "Data Safe Haven", - "shm_name": self.name, - "version": __version__, - } + def entra_application_name(self) -> str: + return f"Data Safe Haven ({self.description}) Pulumi Service Principal" @property - def work_directory(self) -> Path: - return config_dir() / self.name - - @property - def resource_group_name(self) -> str: - return f"shm-{self.name}-rg" - - @property - def storage_account_name(self) -> str: - # https://learn.microsoft.com/en-us/azure/storage/common/storage-account-overview#storage-account-name - # Storage account names must be between 3 and 24 characters in length and may - # contain numbers and lowercase letters only. - return f"shm{alphanumeric(self.name)[:21]}" + def entra_application_secret(self) -> str: + if not self._entra_application_secret: + azure_sdk = AzureSdk(subscription_name=self.subscription_name) + try: + application_secret = azure_sdk.get_keyvault_secret( + secret_name=self.entra_application_kvsecret_name, + key_vault_name=self.key_vault_name, + ) + self._entra_application_secret = application_secret + except DataSafeHavenAzureError: + return "" + return self._entra_application_secret + + @entra_application_secret.setter + def entra_application_secret(self, application_secret: str) -> None: + azure_sdk = AzureSdk(subscription_name=self.subscription_name) + azure_sdk.set_keyvault_secret( + secret_name=self.entra_application_kvsecret_name, + secret_value=application_secret, + key_vault_name=self.key_vault_name, + ) @property def key_vault_name(self) -> str: @@ -83,5 +93,29 @@ def pulumi_encryption_key_version(self) -> str: def pulumi_secrets_provider_url(self) -> str: return f"azurekeyvault://{self.key_vault_name}.vault.azure.net/keys/{self.pulumi_encryption_key_name}/{self.pulumi_encryption_key_version}" + @property + def resource_group_name(self) -> str: + return f"shm-{self.name}-rg" + + @property + def storage_account_name(self) -> str: + # https://learn.microsoft.com/en-us/azure/storage/common/storage-account-overview#storage-account-name + # Storage account names must be between 3 and 24 characters in length and may + # contain numbers and lowercase letters only. + return f"shm{alphanumeric(self.name)[:21]}" + + @property + def tags(self) -> dict[str, str]: + return { + "description": self.description, + "project": "Data Safe Haven", + "shm_name": self.name, + "version": __version__, + } + + @property + def work_directory(self) -> Path: + return config_dir() / self.name + def to_yaml(self) -> str: return yaml.dump(self.model_dump(), indent=2) diff --git a/data_safe_haven/config/dsh_pulumi_config.py b/data_safe_haven/config/dsh_pulumi_config.py index 63581d31dd..1aae87996c 100644 --- a/data_safe_haven/config/dsh_pulumi_config.py +++ b/data_safe_haven/config/dsh_pulumi_config.py @@ -12,6 +12,7 @@ class DSHPulumiConfig(AzureSerialisableModel): config_type: ClassVar[str] = "Pulumi" default_filename: ClassVar[str] = "pulumi.yaml" + encrypted_key: str | None projects: dict[str, DSHPulumiProject] diff --git a/data_safe_haven/config/shm_config.py b/data_safe_haven/config/shm_config.py index 0c48f41cdf..a32d06ed57 100644 --- a/data_safe_haven/config/shm_config.py +++ b/data_safe_haven/config/shm_config.py @@ -12,8 +12,11 @@ class SHMConfig(AzureSerialisableModel): + """Serialisable config for a Data Safe Haven management component.""" + config_type: ClassVar[str] = "SHMConfig" default_filename: ClassVar[str] = "shm.yaml" + azure: ConfigSectionAzure shm: ConfigSectionSHM diff --git a/data_safe_haven/config/sre_config.py b/data_safe_haven/config/sre_config.py index 5a5d6367e1..f4ee5ed6c9 100644 --- a/data_safe_haven/config/sre_config.py +++ b/data_safe_haven/config/sre_config.py @@ -23,8 +23,11 @@ def sre_config_name(sre_name: str) -> str: class SREConfig(AzureSerialisableModel): + """Serialisable config for a secure research environment component.""" + config_type: ClassVar[str] = "SREConfig" default_filename: ClassVar[str] = "sre.yaml" + azure: ConfigSectionAzure description: str dockerhub: ConfigSectionDockerHub diff --git a/data_safe_haven/external/api/azure_sdk.py b/data_safe_haven/external/api/azure_sdk.py index a9245d069d..1792988348 100644 --- a/data_safe_haven/external/api/azure_sdk.py +++ b/data_safe_haven/external/api/azure_sdk.py @@ -14,7 +14,7 @@ ) from azure.keyvault.certificates import CertificateClient, KeyVaultCertificate from azure.keyvault.keys import KeyClient, KeyVaultKey -from azure.keyvault.secrets import SecretClient +from azure.keyvault.secrets import KeyVaultSecret, SecretClient from azure.mgmt.compute.v2021_07_01 import ComputeManagementClient from azure.mgmt.compute.v2021_07_01.models import ( ResourceSkuCapabilities, @@ -451,7 +451,7 @@ def ensure_keyvault_key( """Ensure that a key exists in the KeyVault Returns: - str: The key ID + KeyVaultKey: The key Raises: DataSafeHavenAzureError if the existence of the key could not be verified @@ -476,7 +476,7 @@ def ensure_keyvault_key( ) return key except AzureError as exc: - msg = f"Failed to create key {key_name}." + msg = f"Failed to create key '{key_name}' in KeyVault '{key_vault_name}'." raise DataSafeHavenAzureError(msg) from exc def ensure_managed_identity( @@ -693,7 +693,7 @@ def get_keyvault_secret(self, key_vault_name: str, secret_name: str) -> str: credential=self.credential(AzureSdkCredentialScope.KEY_VAULT), vault_url=f"https://{key_vault_name}.vault.azure.net", ) - # Ensure that secret exists + # Get secret if it exists try: secret = secret_client.get_secret(secret_name) if secret.value: @@ -1302,6 +1302,38 @@ def set_blob_container_acl( msg = f"Failed to set ACL '{desired_acl}' on container '{container_name}'." raise DataSafeHavenAzureError(msg) from exc + def set_keyvault_secret( + self, + secret_name: str, + secret_value: str, + key_vault_name: str, + ) -> KeyVaultSecret: + """Ensure that a secret exists in the KeyVault + + Returns: + KeyVaultSecret: The secret + + Raises: + DataSafeHavenAzureError if the secret could not be set + """ + try: + # Connect to Azure clients + secret_client = SecretClient( + credential=self.credential(AzureSdkCredentialScope.KEY_VAULT), + vault_url=f"https://{key_vault_name}.vault.azure.net", + ) + + # Set secret to given value + self.logger.debug(f"Setting secret [green]{secret_name}[/]...") + secret = secret_client.set_secret(secret_name, secret_value) + self.logger.info(f"Set secret [green]{secret_name}[/].") + return secret + except AzureError as exc: + msg = ( + f"Failed to set secret '{secret_name}' in KeyVault '{key_vault_name}'." + ) + raise DataSafeHavenAzureError(msg) from exc + def storage_exists( self, storage_account_name: str, diff --git a/data_safe_haven/external/api/graph_api.py b/data_safe_haven/external/api/graph_api.py index 70a3f298aa..20201a4919 100644 --- a/data_safe_haven/external/api/graph_api.py +++ b/data_safe_haven/external/api/graph_api.py @@ -16,7 +16,6 @@ DataSafeHavenMicrosoftGraphError, DataSafeHavenValueError, ) -from data_safe_haven.functions import alphanumeric from data_safe_haven.logging import get_logger, get_null_logger from .credentials import DeferredCredential, GraphApiCredential @@ -314,40 +313,6 @@ def create_application_secret( msg = f"Could not create application secret '{application_secret_name}'." raise DataSafeHavenMicrosoftGraphError(msg) from exc - def create_group(self, group_name: str) -> None: - """Create an Entra group if it does not already exist - - Raises: - DataSafeHavenMicrosoftGraphError if the group could not be created - """ - try: - if self.get_id_from_groupname(group_name): - self.logger.info( - f"Found existing Entra group '[green]{group_name}[/]'.", - ) - return - self.logger.debug( - f"Creating Entra group '[green]{group_name}[/]'...", - ) - request_json = { - "description": group_name, - "displayName": group_name, - "groupTypes": [], - "mailEnabled": False, - "mailNickname": alphanumeric(group_name).lower(), - "securityEnabled": True, - } - self.http_post( - f"{self.base_endpoint}/groups", - json=request_json, - ).json() - self.logger.info( - f"Created Entra group '[green]{group_name}[/]'.", - ) - except Exception as exc: - msg = f"Could not create Entra group '{group_name}'." - raise DataSafeHavenMicrosoftGraphError(msg) from exc - def ensure_application_service_principal( self, application_name: str ) -> dict[str, Any]: @@ -1060,17 +1025,17 @@ def verify_custom_domain( DataSafeHavenMicrosoftGraphError if domain could not be verified """ try: - # Create the Entra custom domain if it does not already exist + # Check whether the domain has been added to Entra ID domains = self.read_domains() if not any(d["id"] == domain_name for d in domains): msg = f"Domain {domain_name} has not been added to Entra ID." raise DataSafeHavenMicrosoftGraphError(msg) - # Wait until domain delegation is complete + # Loop until domain delegation is complete while True: # Check whether all expected nameservers are active with suppress(resolver.NXDOMAIN): self.logger.debug( - f"Checking [green]{domain_name}[/] domain verification status ..." + f"Checking [green]{domain_name}[/] domain registration status ..." ) active_nameservers = [ str(ns) for ns in iter(resolver.resolve(domain_name, "NS")) @@ -1080,11 +1045,11 @@ def verify_custom_domain( for nameserver in expected_nameservers ): self.logger.info( - f"Verified that domain [green]{domain_name}[/] is delegated to Azure." + f"Verified that [green]{domain_name}[/] is registered as a custom Entra ID domain." ) break self.logger.warning( - f"Domain [green]{domain_name}[/] is not currently delegated to Azure." + f"Domain [green]{domain_name}[/] is not currently registered as a custom Entra ID domain." ) # Prompt user to set domain delegation manually docs_link = "https://learn.microsoft.com/en-us/azure/dns/dns-delegate-domain-azure-dns#delegate-the-domain" @@ -1093,15 +1058,13 @@ def verify_custom_domain( ) ns_list = ", ".join([f"[green]{n}[/]" for n in expected_nameservers]) self.logger.info( - f"You will need to create an NS record pointing to: {ns_list}" + f"You will need to create NS records pointing to: {ns_list}" ) if not console.confirm( f"Are you ready to check whether [green]{domain_name}[/] has been delegated to Azure?", default_to_yes=True, ): - self.logger.error( - "Please use `az login` to connect to the correct Azure CLI account" - ) + self.logger.error("User terminated check for domain delegation.") raise typer.Exit(1) # Send verification request if needed if not any((d["id"] == domain_name and d["isVerified"]) for d in domains): diff --git a/data_safe_haven/infrastructure/programs/declarative_sre.py b/data_safe_haven/infrastructure/programs/declarative_sre.py index 1afc7f8470..ce678dbb4a 100644 --- a/data_safe_haven/infrastructure/programs/declarative_sre.py +++ b/data_safe_haven/infrastructure/programs/declarative_sre.py @@ -18,6 +18,7 @@ from .sre.data import SREDataComponent, SREDataProps from .sre.desired_state import SREDesiredStateComponent, SREDesiredStateProps from .sre.dns_server import SREDnsServerComponent, SREDnsServerProps +from .sre.entra import SREEntraComponent, SREEntraProps from .sre.firewall import SREFirewallComponent, SREFirewallProps from .sre.identity import SREIdentityComponent, SREIdentityProps from .sre.monitoring import SREMonitoringComponent, SREMonitoringProps @@ -111,6 +112,14 @@ def __call__(self) -> None: ] ) + # Deploy Entra resources + SREEntraComponent( + "sre_entra", + SREEntraProps( + group_names=ldap_group_names, + ), + ) + # Deploy resource group resource_group = resources.ResourceGroup( "sre_resource_group", diff --git a/data_safe_haven/infrastructure/programs/imperative_shm.py b/data_safe_haven/infrastructure/programs/imperative_shm.py index b13ec1680e..73893bff61 100644 --- a/data_safe_haven/infrastructure/programs/imperative_shm.py +++ b/data_safe_haven/infrastructure/programs/imperative_shm.py @@ -114,17 +114,18 @@ def deploy(self) -> None: msg = "Failed to create SHM resources." raise DataSafeHavenAzureError(msg) from exc + # Connect to GraphAPI + graph_api = GraphApi.from_scopes( + scopes=[ + "Application.ReadWrite.All", + "Domain.ReadWrite.All", + "Group.ReadWrite.All", + ], + tenant_id=self.config.shm.entra_tenant_id, + ) # Add the SHM domain to the Entra ID via interactive GraphAPI try: # Generate the verification record - graph_api = GraphApi.from_scopes( - scopes=[ - "Application.ReadWrite.All", - "Domain.ReadWrite.All", - "Group.ReadWrite.All", - ], - tenant_id=self.config.shm.entra_tenant_id, - ) verification_record = graph_api.add_custom_domain(self.config.shm.fqdn) # Add the record to DNS self.azure_sdk.ensure_dns_txt_record( @@ -142,6 +143,28 @@ def deploy(self) -> None: except (DataSafeHavenMicrosoftGraphError, DataSafeHavenAzureError) as exc: msg = f"Failed to add custom domain '{self.config.shm.fqdn}' to Entra ID." raise DataSafeHavenAzureError(msg) from exc + # Create an application for use by the pulumi-azuread module + try: + graph_api.create_application( + self.context.entra_application_name, + application_scopes=["Group.ReadWrite.All"], + delegated_scopes=[], + request_json={ + "displayName": self.context.entra_application_name, + "signInAudience": "AzureADMyOrg", + }, + ) + # Ensure that the application secret exists + if not self.context.entra_application_secret: + self.context.entra_application_secret = ( + graph_api.create_application_secret( + self.context.entra_application_name, + self.context.entra_application_secret_name, + ) + ) + except DataSafeHavenMicrosoftGraphError as exc: + msg = "Failed to create deployment application in Entra ID." + raise DataSafeHavenAzureError(msg) from exc def teardown(self) -> None: """Destroy all created resources diff --git a/data_safe_haven/infrastructure/programs/sre/entra.py b/data_safe_haven/infrastructure/programs/sre/entra.py new file mode 100644 index 0000000000..1f44995f9f --- /dev/null +++ b/data_safe_haven/infrastructure/programs/sre/entra.py @@ -0,0 +1,40 @@ +"""Pulumi component for SRE Entra resources""" + +from collections.abc import Mapping + +from pulumi import ComponentResource, ResourceOptions +from pulumi_azuread import Group + +from data_safe_haven.functions import replace_separators + + +class SREEntraProps: + """Properties for SREEntraComponent""" + + def __init__( + self, + group_names: Mapping[str, str], + ) -> None: + self.group_names = group_names + + +class SREEntraComponent(ComponentResource): + """Deploy SRE Entra resources with Pulumi""" + + def __init__( + self, + name: str, + props: SREEntraProps, + opts: ResourceOptions | None = None, + ) -> None: + super().__init__("dsh:sre:EntraComponent", name, {}, opts) + + for group_id, group_description in props.group_names.items(): + Group( + replace_separators(f"{self._name}_group_{group_id}", "_"), + description=group_description, + display_name=group_description, + mail_enabled=False, + prevent_duplicate_names=True, + security_enabled=True, + ) diff --git a/data_safe_haven/infrastructure/project_manager.py b/data_safe_haven/infrastructure/project_manager.py index f9706ec096..dcb3941af2 100644 --- a/data_safe_haven/infrastructure/project_manager.py +++ b/data_safe_haven/infrastructure/project_manager.py @@ -146,6 +146,10 @@ def add_option(self, name: str, value: str, *, replace: bool) -> None: """Add a public configuration option""" self._options[name] = (value, False, replace) + def add_secret(self, name: str, value: str, *, replace: bool) -> None: + """Add a secret configuration option""" + self._options[name] = (value, True, replace) + def apply_config_options(self) -> None: """Set Pulumi config options""" try: diff --git a/data_safe_haven/provisioning/sre_provisioning_manager.py b/data_safe_haven/provisioning/sre_provisioning_manager.py index 1111bc573f..7c39046b86 100644 --- a/data_safe_haven/provisioning/sre_provisioning_manager.py +++ b/data_safe_haven/provisioning/sre_provisioning_manager.py @@ -71,11 +71,6 @@ def available_vm_skus(self) -> dict[str, dict[str, Any]]: self._available_vm_skus = azure_sdk.list_available_vm_skus(self.location) return self._available_vm_skus - def create_security_groups(self) -> None: - """Create groups in Entra ID""" - for group_name in self.security_group_params.values(): - self.graph_api.create_group(group_name) - def restart_remote_desktop_containers(self) -> None: """Restart the Guacamole container group""" guacamole_provisioner = AzureContainerInstance( @@ -137,6 +132,5 @@ def update_remote_desktop_connections(self) -> None: def run(self) -> None: """Apply SRE configuration""" - self.create_security_groups() self.update_remote_desktop_connections() self.restart_remote_desktop_containers() diff --git a/pyproject.toml b/pyproject.toml index 90034aa75c..ae503bed03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ dependencies = [ "fqdn==1.5.1", "psycopg[binary]==3.2.3", "pulumi-azure-native==2.64.3", + "pulumi-azuread==5.53.4", "pulumi-random==4.16.6", "pulumi==3.135.1", "pydantic==2.9.2", @@ -184,6 +185,7 @@ module = [ "numpy.*", "psycopg.*", "pulumi_azure_native.*", + "pulumi_azuread.*", "pulumi_random.*", "pulumi.*", "pymssql.*", diff --git a/tests/commands/conftest.py b/tests/commands/conftest.py index 6459c84d6c..dab10adb7b 100644 --- a/tests/commands/conftest.py +++ b/tests/commands/conftest.py @@ -26,6 +26,12 @@ def mock_azure_sdk_blob_exists_false(mocker): mocker.patch.object(AzureSdk, "blob_exists", return_value=False) +@fixture +def mock_contextmanager_assert_context(mocker, context) -> Context: + context._entra_application_secret = "dummy-secret" # noqa: S105 + mocker.patch.object(ContextManager, "assert_context", return_value=context) + + @fixture def mock_graph_api_add_custom_domain(mocker): mocker.patch.object( @@ -33,6 +39,15 @@ def mock_graph_api_add_custom_domain(mocker): ) +@fixture +def mock_graph_api_get_application_by_name(mocker, request): + mocker.patch.object( + GraphApi, + "get_application_by_name", + return_value={"appId": request.config.guid_application}, + ) + + @fixture def mock_graph_api_token(mocker): mocker.patch.object(GraphApi, "token", return_value="dummy-token") @@ -174,7 +189,7 @@ def teardown_then_exit(*args, **kwargs): # noqa: ARG001 @fixture -def runner(tmp_contexts): +def runner(tmp_contexts) -> CliRunner: runner = CliRunner( env={ "DSH_CONFIG_DIRECTORY": str(tmp_contexts), diff --git a/tests/commands/test_sre.py b/tests/commands/test_sre.py index ced39630ba..f8818d20cc 100644 --- a/tests/commands/test_sre.py +++ b/tests/commands/test_sre.py @@ -1,34 +1,88 @@ +from pytest import CaptureFixture, LogCaptureFixture +from pytest_mock import MockerFixture +from typer.testing import CliRunner + from data_safe_haven.commands.sre import sre_command_group +from data_safe_haven.config import Context, ContextManager +from data_safe_haven.exceptions import DataSafeHavenAzureError +from data_safe_haven.external import AzureSdk class TestDeploySRE: def test_deploy( self, - runner, + runner: CliRunner, mock_azuresdk_get_subscription_name, # noqa: ARG002 mock_graph_api_token, # noqa: ARG002 + mock_contextmanager_assert_context, # noqa: ARG002 mock_ip_1_2_3_4, # noqa: ARG002 mock_pulumi_config_from_remote_or_create, # noqa: ARG002 mock_pulumi_config_upload, # noqa: ARG002 mock_shm_config_from_remote, # noqa: ARG002 mock_sre_config_from_remote, # noqa: ARG002 + mock_graph_api_get_application_by_name, # noqa: ARG002 mock_sre_project_manager_deploy_then_exit, # noqa: ARG002 - ): + ) -> None: result = runner.invoke(sre_command_group, ["deploy", "sandbox"]) assert result.exit_code == 1 assert "mock deploy" in result.stdout assert "mock deploy error" in result.stdout - def test_no_context_file(self, runner_no_context_file): + def test_no_application( + self, + caplog: LogCaptureFixture, + runner: CliRunner, + mock_azuresdk_get_subscription_name, # noqa: ARG002 + mock_contextmanager_assert_context, # noqa: ARG002 + mock_graph_api_token, # noqa: ARG002 + mock_ip_1_2_3_4, # noqa: ARG002 + mock_pulumi_config_from_remote_or_create, # noqa: ARG002 + mock_shm_config_from_remote, # noqa: ARG002 + mock_sre_config_from_remote, # noqa: ARG002 + ) -> None: + result = runner.invoke(sre_command_group, ["deploy", "sandbox"]) + assert result.exit_code == 1 + assert ( + "No Entra application 'Data Safe Haven (Acme Deployment) Pulumi Service Principal' was found." + in caplog.text + ) + assert "Please redeploy your SHM." in caplog.text + + def test_no_application_secret( + self, + caplog: LogCaptureFixture, + runner: CliRunner, + context: Context, + mocker: MockerFixture, + mock_azuresdk_get_subscription_name, # noqa: ARG002 + mock_graph_api_get_application_by_name, # noqa: ARG002 + mock_graph_api_token, # noqa: ARG002 + mock_ip_1_2_3_4, # noqa: ARG002 + mock_pulumi_config_from_remote_or_create, # noqa: ARG002 + mock_shm_config_from_remote, # noqa: ARG002 + mock_sre_config_from_remote, # noqa: ARG002 + ) -> None: + mocker.patch.object( + AzureSdk, "get_keyvault_secret", side_effect=DataSafeHavenAzureError("") + ) + mocker.patch.object(ContextManager, "assert_context", return_value=context) + result = runner.invoke(sre_command_group, ["deploy", "sandbox"]) + assert result.exit_code == 1 + assert ( + "No Entra application secret 'Pulumi Deployment Secret' was found. Please redeploy your SHM." + in caplog.text + ) + + def test_no_context_file(self, runner_no_context_file) -> None: result = runner_no_context_file.invoke(sre_command_group, ["deploy", "sandbox"]) assert result.exit_code == 1 assert "Could not find file" in result.stdout def test_auth_failure( self, - runner, + runner: CliRunner, mock_azuresdk_get_credential_failure, # noqa: ARG002 - ): + ) -> None: result = runner.invoke(sre_command_group, ["deploy", "sandbox"]) assert result.exit_code == 1 assert "mock get_credential\n" in result.stdout @@ -37,9 +91,9 @@ def test_auth_failure( def test_no_shm( self, capfd, - runner, + runner: CliRunner, mock_shm_config_from_remote_fails, # noqa: ARG002 - ): + ) -> None: result = runner.invoke(sre_command_group, ["deploy", "sandbox"]) out, _ = capfd.readouterr() assert result.exit_code == 1 @@ -49,19 +103,19 @@ def test_no_shm( class TestTeardownSRE: def test_teardown( self, - runner, + runner: CliRunner, mock_graph_api_token, # noqa: ARG002 mock_ip_1_2_3_4, # noqa: ARG002 mock_pulumi_config_from_remote, # noqa: ARG002 mock_shm_config_from_remote, # noqa: ARG002 mock_sre_config_from_remote, # noqa: ARG002 mock_sre_project_manager_teardown_then_exit, # noqa: ARG002 - ): + ) -> None: result = runner.invoke(sre_command_group, ["teardown", "sandbox"]) assert result.exit_code == 1 assert "mock teardown" in result.stdout - def test_no_context_file(self, runner_no_context_file): + def test_no_context_file(self, runner_no_context_file) -> None: result = runner_no_context_file.invoke( sre_command_group, ["teardown", "sandbox"] ) @@ -70,10 +124,10 @@ def test_no_context_file(self, runner_no_context_file): def test_no_shm( self, - capfd, - runner, + capfd: CaptureFixture, + runner: CliRunner, mock_shm_config_from_remote_fails, # noqa: ARG002 - ): + ) -> None: result = runner.invoke(sre_command_group, ["teardown", "sandbox"]) out, _ = capfd.readouterr() assert result.exit_code == 1 @@ -81,9 +135,9 @@ def test_no_shm( def test_auth_failure( self, - runner, + runner: CliRunner, mock_azuresdk_get_credential_failure, # noqa: ARG002 - ): + ) -> None: result = runner.invoke(sre_command_group, ["teardown", "sandbox"]) assert result.exit_code == 1 assert "mock get_credential\n" in result.stdout diff --git a/tests/config/test_context_manager.py b/tests/config/test_context_manager.py index cd30e84f82..6167fb2096 100644 --- a/tests/config/test_context_manager.py +++ b/tests/config/test_context_manager.py @@ -4,10 +4,12 @@ from data_safe_haven.config import Context, ContextManager from data_safe_haven.exceptions import ( + DataSafeHavenAzureError, DataSafeHavenConfigError, DataSafeHavenTypeError, DataSafeHavenValueError, ) +from data_safe_haven.external import AzureSdk from data_safe_haven.version import __version__ @@ -29,6 +31,35 @@ def test_invalid_subscription_name(self, context_dict): ): Context(**context_dict) + def test_entra_application_name(self, context: Context) -> None: + assert ( + context.entra_application_name + == "Data Safe Haven (Acme Deployment) Pulumi Service Principal" + ) + + def test_entra_application_secret(self, context: Context, mocker) -> None: + mocker.patch.object( + AzureSdk, "get_keyvault_secret", return_value="secret-value" + ) + assert context.entra_application_secret == "secret-value" # noqa: S105 + + def test_entra_application_secret_missing(self, context: Context, mocker) -> None: + mocker.patch.object( + AzureSdk, + "get_keyvault_secret", + side_effect=DataSafeHavenAzureError("Error message"), + ) + assert context.entra_application_secret == "" + + def test_entra_application_secret_setter(self, context: Context, mocker) -> None: + mock_set_keyvault_secret = mocker.patch.object(AzureSdk, "set_keyvault_secret") + context.entra_application_secret = "secret-value" # noqa: S105 + mock_set_keyvault_secret.assert_called_once_with( + key_vault_name="shm-acmedeployment-kv", + secret_name="pulumi-deployment-secret", + secret_value="secret-value", + ) + def test_tags(self, context): assert context.tags["description"] == "Acme Deployment" assert context.tags["project"] == "Data Safe Haven" diff --git a/tests/conftest.py b/tests/conftest.py index 219458a0dd..4626f1061e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -38,6 +38,7 @@ def pytest_configure(config): """Define constants for use across multiple tests""" config.guid_admin = "00edec65-b071-4d26-8779-a9fe791c6e14" + config.guid_application = "aa78dceb-4116-4713-8554-cf2b3027e119" config.guid_entra = "48b2425b-5f2c-4cbd-9458-0441daa8994c" config.guid_subscription = "35ebced1-4e7a-4c1f-b634-c0886937085d" config.guid_tenant = "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd"