From 02afb30c88ed4b01e5c4f1abba84ea94afba4650 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Wed, 3 Jul 2024 16:06:28 +0100 Subject: [PATCH 01/54] :sparkles: Add credential loaders which will wait until the credentials are needed before making external calls --- .../external/interface/credentials.py | 112 ++++++++++++++++++ 1 file changed, 112 insertions(+) create mode 100644 data_safe_haven/external/interface/credentials.py diff --git a/data_safe_haven/external/interface/credentials.py b/data_safe_haven/external/interface/credentials.py new file mode 100644 index 0000000000..fc07b67c04 --- /dev/null +++ b/data_safe_haven/external/interface/credentials.py @@ -0,0 +1,112 @@ +"""Classes related to Azure credentials""" + +import pathlib +from abc import abstractmethod +from collections.abc import Sequence + +from azure.core.credentials import AccessToken, TokenCredential +from azure.identity import ( + AuthenticationRecord, + AzureCliCredential, + DeviceCodeCredential, + TokenCachePersistenceOptions, +) + +from data_safe_haven.singleton import Singleton + + +class DeferredCredentialLoader(metaclass=Singleton): + """A wrapper class that initialises and caches credentials as they are needed""" + + def __init__(self) -> None: + self.credential_: TokenCredential | None = None + + @property + def credential(self) -> TokenCredential: + """Return the cached credential provider.""" + if not self.credential_: + self.credential_ = self.get_credential() + return self.credential_ + + @property + def token(self) -> str: + """Get a token from the credential provider.""" + return str(self.get_token().token) + + @abstractmethod + def get_credential(self) -> TokenCredential: + """Get new credential provider.""" + pass + + def get_token(self) -> AccessToken: + """Get new access token.""" + return self.credential.get_token() + + +class AzureApiCredentialLoader(DeferredCredentialLoader): + """ + Credential loader used by AzureApi + + Uses AzureCliCredential for authentication + """ + + def get_credential(self) -> TokenCredential: + """Get a new AzureCliCredential.""" + return AzureCliCredential(additionally_allowed_tenants=["*"]) + + +class GraphApiCredentialLoader(DeferredCredentialLoader): + """ + Credential loader used by GraphApi + + Uses DeviceCodeCredential for authentication + """ + + def __init__( + self, + tenant_id: str, + default_scopes: Sequence[str] = [], + ) -> None: + super().__init__() + self.default_scopes = default_scopes + self.tenant_id = tenant_id + + def get_credential(self) -> TokenCredential: + """Get a new DeviceCodeCredential, using cached credentials if they are available""" + cache_name = f"dsh-{self.tenant_id}" + authentication_record_path = ( + pathlib.Path.home() / f".msal-authentication-cache-{cache_name}" + ) + + # Read an existing authentication record, using default arguments if unavailable + kwargs = {} + try: + with open(authentication_record_path) as f_auth: + existing_auth_record = AuthenticationRecord.deserialize(f_auth.read()) + kwargs["authentication_record"] = existing_auth_record + except FileNotFoundError: + kwargs["authority"] = "https://login.microsoftonline.com/" + # Use the Microsoft Graph Command Line Tools client ID + kwargs["client_id"] = "14d82eec-204b-4c2f-b7e8-296a70dab67e" + kwargs["tenant_id"] = self.tenant_id + + # Get a credential + credential = DeviceCodeCredential( + cache_persistence_options=TokenCachePersistenceOptions(name=cache_name), + **kwargs, + ) + + # Write out an authentication record for this credential + new_auth_record = credential.authenticate(scopes=self.default_scopes) + with open(authentication_record_path, "w") as f_auth: + f_auth.write(new_auth_record.serialize()) + + # Return the credential + return credential + + def get_token(self) -> AccessToken: + """Get new access token using pre-defined scopes and tenant ID.""" + return self.credential.get_token( + *self.default_scopes, + tenant_id=self.tenant_id, + ) From f13a2bf310e7d216e36dcaec913d1ff0ba01c68f Mon Sep 17 00:00:00 2001 From: James Robinson Date: Wed, 3 Jul 2024 16:52:39 +0100 Subject: [PATCH 02/54] :sparkles: Add subscription_id function to AzureCliSingleton --- data_safe_haven/external/api/azure_cli.py | 28 +++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/data_safe_haven/external/api/azure_cli.py b/data_safe_haven/external/api/azure_cli.py index 7e612fa35f..a5bec0558e 100644 --- a/data_safe_haven/external/api/azure_cli.py +++ b/data_safe_haven/external/api/azure_cli.py @@ -103,3 +103,31 @@ def group_id_from_name(self, group_name: str) -> str: except (IndexError, KeyError) as exc: msg = f"Group '{group_name}' was not found in Azure CLI." raise DataSafeHavenAzureError(msg) from exc + + def subscription_id(self, subscription_name: str) -> str: + """Get subscription ID from an Azure subscription name.""" + try: + result = subprocess.check_output( + [ + self.path, + "account", + "subscription", + "list", + "--query", + f"[?displayName == '{subscription_name}']", + ], + stderr=subprocess.PIPE, + encoding="utf8", + ) + result_dict = json.loads(result) + return str(result_dict[0]["subscriptionId"]) + except subprocess.CalledProcessError as exc: + self.logger.critical(exc.stderr) + msg = "Error reading subscriptions from Azure CLI." + raise DataSafeHavenAzureError(msg) from exc + except json.JSONDecodeError as exc: + msg = "Unable to parse Azure CLI output as JSON." + raise DataSafeHavenAzureError(msg) from exc + except (IndexError, KeyError) as exc: + msg = f"Subscription '{subscription_name}' was not found in Azure CLI." + raise DataSafeHavenAzureError(msg) from exc From 311871377957c66ad7c5a880011c263e12a6774e Mon Sep 17 00:00:00 2001 From: James Robinson Date: Wed, 3 Jul 2024 16:51:29 +0100 Subject: [PATCH 03/54] :recycle: Remove inheritance of AzureApi from AzureAuthenticator --- data_safe_haven/external/api/azure_api.py | 32 +++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/data_safe_haven/external/api/azure_api.py b/data_safe_haven/external/api/azure_api.py index 8676febbe2..ef11bca14d 100644 --- a/data_safe_haven/external/api/azure_api.py +++ b/data_safe_haven/external/api/azure_api.py @@ -11,6 +11,7 @@ ResourceNotFoundError, ServiceRequestError, ) +from azure.identity import AzureCliCredential from azure.keyvault.certificates import ( CertificateClient, KeyVaultCertificate, @@ -66,13 +67,40 @@ from data_safe_haven.external.interface.azure_authenticator import AzureAuthenticator from data_safe_haven.logging import get_logger +from .azure_cli import AzureCliSingleton -class AzureApi(AzureAuthenticator): + +class AzureApi: """Interface to the Azure REST API""" def __init__(self, subscription_name: str) -> None: - super().__init__(subscription_name) self.logger = get_logger() + self.subscription_name = subscription_name + self.credential_: str | None = None + self.subscription_id_: str | None = None + self.tenant_id_: str | None = None + + @property + def credential(self) -> AzureCliCredential: + if not self.credential_: + authenticator = AzureAuthenticator(self.subscription_name) + self.credential_ = authenticator.credential + return self.credential_ + + @property + def subscription_id(self) -> str: + if not self.subscription_id_: + self.subscription_id_ = AzureCliSingleton().subscription_id( + self.subscription_name + ) + return self.subscription_id_ + + @property + def tenant_id(self) -> str: + if not self.tenant_id_: + authenticator = AzureAuthenticator(self.subscription_name) + self.tenant_id_ = authenticator.tenant_id + return self.tenant_id_ def blob_client( self, From 6c9b90e14982a4ad58114b62ec58ad58d44b76ba Mon Sep 17 00:00:00 2001 From: James Robinson Date: Wed, 3 Jul 2024 17:33:24 +0100 Subject: [PATCH 04/54] :recycle: Use AzureApiCredentialLoader to lazy-load AzureCliCredential when needed --- data_safe_haven/external/api/azure_api.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/data_safe_haven/external/api/azure_api.py b/data_safe_haven/external/api/azure_api.py index ef11bca14d..ab0e790305 100644 --- a/data_safe_haven/external/api/azure_api.py +++ b/data_safe_haven/external/api/azure_api.py @@ -4,6 +4,7 @@ from contextlib import suppress from typing import Any, cast +from azure.core.credentials import TokenCredential from azure.core.exceptions import ( AzureError, HttpResponseError, @@ -11,7 +12,6 @@ ResourceNotFoundError, ServiceRequestError, ) -from azure.identity import AzureCliCredential from azure.keyvault.certificates import ( CertificateClient, KeyVaultCertificate, @@ -65,6 +65,10 @@ from data_safe_haven.exceptions import DataSafeHavenAzureError from data_safe_haven.external.interface.azure_authenticator import AzureAuthenticator +from data_safe_haven.external.interface.credentials import AzureApiCredentialLoader +from data_safe_haven.exceptions import ( + DataSafeHavenAzureError, +) from data_safe_haven.logging import get_logger from .azure_cli import AzureCliSingleton @@ -76,16 +80,12 @@ class AzureApi: def __init__(self, subscription_name: str) -> None: self.logger = get_logger() self.subscription_name = subscription_name - self.credential_: str | None = None self.subscription_id_: str | None = None self.tenant_id_: str | None = None @property - def credential(self) -> AzureCliCredential: - if not self.credential_: - authenticator = AzureAuthenticator(self.subscription_name) - self.credential_ = authenticator.credential - return self.credential_ + def credential(self) -> TokenCredential: + return AzureApiCredentialLoader().credential @property def subscription_id(self) -> str: From 6ca4b80f2067e8658843e19a1002e03ff5472d9c Mon Sep 17 00:00:00 2001 From: James Robinson Date: Wed, 3 Jul 2024 17:30:20 +0100 Subject: [PATCH 05/54] :recycle: Add get_subscription function to AzureApi --- data_safe_haven/external/api/azure_api.py | 36 +++++++++++++++++------ 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/data_safe_haven/external/api/azure_api.py b/data_safe_haven/external/api/azure_api.py index ab0e790305..f9634e6823 100644 --- a/data_safe_haven/external/api/azure_api.py +++ b/data_safe_haven/external/api/azure_api.py @@ -7,6 +7,7 @@ from azure.core.credentials import TokenCredential from azure.core.exceptions import ( AzureError, + ClientAuthenticationError, HttpResponseError, ResourceExistsError, ResourceNotFoundError, @@ -48,7 +49,7 @@ from azure.mgmt.resource.resources.v2021_04_01 import ResourceManagementClient from azure.mgmt.resource.resources.v2021_04_01.models import ResourceGroup from azure.mgmt.resource.subscriptions import SubscriptionClient -from azure.mgmt.resource.subscriptions.models import Location +from azure.mgmt.resource.subscriptions.models import Location, Subscription from azure.mgmt.storage.v2021_08_01 import StorageManagementClient from azure.mgmt.storage.v2021_08_01.models import ( BlobContainer, @@ -63,7 +64,11 @@ from azure.storage.blob import BlobClient, BlobServiceClient from azure.storage.filedatalake import DataLakeServiceClient -from data_safe_haven.exceptions import DataSafeHavenAzureError +from data_safe_haven.exceptions import ( + DataSafeHavenAzureAPIAuthenticationError, + DataSafeHavenAzureError, + DataSafeHavenValueError, +) from data_safe_haven.external.interface.azure_authenticator import AzureAuthenticator from data_safe_haven.external.interface.credentials import AzureApiCredentialLoader from data_safe_haven.exceptions import ( @@ -71,35 +76,35 @@ ) from data_safe_haven.logging import get_logger -from .azure_cli import AzureCliSingleton - class AzureApi: """Interface to the Azure REST API""" def __init__(self, subscription_name: str) -> None: self.logger = get_logger() + self.authenticator = AzureApiCredentialLoader() self.subscription_name = subscription_name self.subscription_id_: str | None = None self.tenant_id_: str | None = None @property def credential(self) -> TokenCredential: - return AzureApiCredentialLoader().credential + return self.authenticator.credential @property def subscription_id(self) -> str: if not self.subscription_id_: - self.subscription_id_ = AzureCliSingleton().subscription_id( - self.subscription_name + self.subscription_id_ = str( + self.get_subscription(self.subscription_name).subscription_id ) return self.subscription_id_ @property def tenant_id(self) -> str: if not self.tenant_id_: - authenticator = AzureAuthenticator(self.subscription_name) - self.tenant_id_ = authenticator.tenant_id + self.tenant_id_ = str( + self.get_subscription(self.subscription_name).tenant_id + ) return self.tenant_id_ def blob_client( @@ -713,6 +718,19 @@ def get_storage_account_keys( msg = f"Keys could not be loaded for {msg_sa} in {msg_rg}." raise DataSafeHavenAzureError(msg) from exc + def get_subscription(self, subscription_name: str) -> Subscription: + """Get the current Azure subscription.""" + try: + subscription_client = SubscriptionClient(self.credential) + for subscription in subscription_client.subscriptions.list(): + if subscription.display_name == subscription_name: + return subscription + except ClientAuthenticationError as exc: + msg = "Failed to authenticate with Azure API." + raise DataSafeHavenAzureAPIAuthenticationError(msg) from exc + msg = f"Could not find subscription '{subscription_name}'" + raise DataSafeHavenValueError(msg) + def import_keyvault_certificate( self, certificate_name: str, From 837f0fe63681bc08e461b3ec7f42c8892b2d90f4 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Wed, 3 Jul 2024 19:30:24 +0100 Subject: [PATCH 06/54] :coffin: Remove unused AzureAuthenticator --- .../external/interface/azure_authenticator.py | 71 ------------------- 1 file changed, 71 deletions(-) delete mode 100644 data_safe_haven/external/interface/azure_authenticator.py diff --git a/data_safe_haven/external/interface/azure_authenticator.py b/data_safe_haven/external/interface/azure_authenticator.py deleted file mode 100644 index 9724359464..0000000000 --- a/data_safe_haven/external/interface/azure_authenticator.py +++ /dev/null @@ -1,71 +0,0 @@ -"""Standalone utility class for anything that needs to authenticate against Azure""" - -from typing import cast - -from azure.core.exceptions import ClientAuthenticationError -from azure.identity import AzureCliCredential -from azure.mgmt.resource.subscriptions import SubscriptionClient -from azure.mgmt.resource.subscriptions.models import Subscription - -from data_safe_haven.exceptions import ( - DataSafeHavenAzureAPIAuthenticationError, - DataSafeHavenAzureError, - DataSafeHavenValueError, -) - - -class AzureAuthenticator: - """Standalone utility class for anything that needs to authenticate against Azure""" - - def __init__(self, subscription_name: str) -> None: - self.subscription_name: str = subscription_name - self.credential_: AzureCliCredential | None = None - self.subscription_id_: str | None = None - self.tenant_id_: str | None = None - - @property - def credential(self) -> AzureCliCredential: - if not self.credential_: - self.credential_ = AzureCliCredential( - additionally_allowed_tenants=["*"], - ) - return self.credential_ - - @property - def subscription_id(self) -> str: - if not self.subscription_id_: - self.login() - if not self.subscription_id_: - msg = "Failed to load subscription ID." - raise DataSafeHavenAzureError(msg) - return self.subscription_id_ - - @property - def tenant_id(self) -> str: - if not self.tenant_id_: - self.login() - if not self.tenant_id_: - msg = "Failed to load tenant ID." - raise DataSafeHavenAzureError(msg) - return self.tenant_id_ - - def login(self) -> None: - """Get subscription and tenant IDs""" - # Connect to Azure clients - subscription_client = SubscriptionClient(self.credential) - - # Check that the Azure credentials are valid - try: - for subscription in cast( - list[Subscription], subscription_client.subscriptions.list() - ): - if subscription.display_name == self.subscription_name: - self.subscription_id_ = subscription.subscription_id - self.tenant_id_ = subscription.tenant_id - break - except ClientAuthenticationError as exc: - msg = "Failed to authenticate with Azure API." - raise DataSafeHavenAzureAPIAuthenticationError(msg) from exc - if not (self.subscription_id_ and self.tenant_id_): - msg = f"Could not find subscription '{self.subscription_name}'" - raise DataSafeHavenValueError(msg) From 9dc437074a35a631be6dbb46b56f10123d6e278d Mon Sep 17 00:00:00 2001 From: James Robinson Date: Wed, 3 Jul 2024 16:45:43 +0100 Subject: [PATCH 07/54] :truck: Move AzureCliSingleton confirmation from PulumiAccount to AzureApiCredentialLoader --- data_safe_haven/external/api/azure_cli.py | 11 +++++++++-- data_safe_haven/external/interface/credentials.py | 2 ++ data_safe_haven/external/interface/pulumi_account.py | 6 +----- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/data_safe_haven/external/api/azure_cli.py b/data_safe_haven/external/api/azure_cli.py index a5bec0558e..35fcfaf9a1 100644 --- a/data_safe_haven/external/api/azure_cli.py +++ b/data_safe_haven/external/api/azure_cli.py @@ -20,6 +20,7 @@ class AzureCliAccount: name: str id_: str tenant_id: str + tenant_name: str class AzureCliSingleton(metaclass=Singleton): @@ -60,6 +61,7 @@ def account(self) -> AzureCliAccount: name=result_dict.get("user").get("name"), id_=result_dict.get("id"), tenant_id=result_dict.get("tenantId"), + tenant_name=result_dict.get("tenantDisplayName"), ) return self._account @@ -70,8 +72,13 @@ def confirm(self) -> None: return None account = self.account - self.logger.info(f"Azure user: {account.name} ({account.id_})") - self.logger.info(f"Azure tenant ID: {account.tenant_id})") + self.logger.info( + "You are currently logged into the Azure CLI with the following details:" + ) + self.logger.info(f"... Azure user: [blue]{account.name}[/] ({account.id_})") + self.logger.info( + f"... Azure tenant: [blue]{account.tenant_name}[/] ({account.tenant_id})" + ) if not console.confirm( "Is this the Azure account you expect?", default_to_yes=False ): diff --git a/data_safe_haven/external/interface/credentials.py b/data_safe_haven/external/interface/credentials.py index fc07b67c04..f96b7eb914 100644 --- a/data_safe_haven/external/interface/credentials.py +++ b/data_safe_haven/external/interface/credentials.py @@ -13,6 +13,7 @@ ) from data_safe_haven.singleton import Singleton +from data_safe_haven.external.api.azure_cli import AzureCliSingleton class DeferredCredentialLoader(metaclass=Singleton): @@ -52,6 +53,7 @@ class AzureApiCredentialLoader(DeferredCredentialLoader): def get_credential(self) -> TokenCredential: """Get a new AzureCliCredential.""" + AzureCliSingleton().confirm() # get user confirmation of the current account return AzureCliCredential(additionally_allowed_tenants=["*"]) diff --git a/data_safe_haven/external/interface/pulumi_account.py b/data_safe_haven/external/interface/pulumi_account.py index fbd270f99a..8f7fc55c90 100644 --- a/data_safe_haven/external/interface/pulumi_account.py +++ b/data_safe_haven/external/interface/pulumi_account.py @@ -4,7 +4,7 @@ from typing import Any from data_safe_haven.exceptions import DataSafeHavenPulumiError -from data_safe_haven.external import AzureApi, AzureCliSingleton +from data_safe_haven.external import AzureApi class PulumiAccount: @@ -26,10 +26,6 @@ def __init__( msg = "Unable to find Pulumi CLI executable in your path.\nPlease ensure that Pulumi is installed" raise DataSafeHavenPulumiError(msg) - # Ensure Azure CLI account is correct - # This will be needed to populate env - AzureCliSingleton().confirm() - @property def env(self) -> dict[str, Any]: """Get necessary Pulumi environment variables""" From 792bc35ab1cc77a9a9105732c4ccc2bc8598f314 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 4 Jul 2024 15:15:37 +0100 Subject: [PATCH 08/54] :recycle: Make DeferredCredentialLoader a regular class instead of a singleton --- .../external/interface/credentials.py | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/data_safe_haven/external/interface/credentials.py b/data_safe_haven/external/interface/credentials.py index f96b7eb914..a9fd2a7057 100644 --- a/data_safe_haven/external/interface/credentials.py +++ b/data_safe_haven/external/interface/credentials.py @@ -1,10 +1,10 @@ """Classes related to Azure credentials""" import pathlib -from abc import abstractmethod +from abc import ABCMeta, abstractmethod from collections.abc import Sequence -from azure.core.credentials import AccessToken, TokenCredential +from azure.core.credentials import TokenCredential from azure.identity import ( AuthenticationRecord, AzureCliCredential, @@ -12,15 +12,20 @@ TokenCachePersistenceOptions, ) -from data_safe_haven.singleton import Singleton from data_safe_haven.external.api.azure_cli import AzureCliSingleton -class DeferredCredentialLoader(metaclass=Singleton): +class DeferredCredentialLoader(metaclass=ABCMeta): """A wrapper class that initialises and caches credentials as they are needed""" - def __init__(self) -> None: + def __init__( + self, + scopes: Sequence[str], + tenant_id: str | None = None, + ) -> None: self.credential_: TokenCredential | None = None + self.scopes = scopes + self.tenant_id = tenant_id @property def credential(self) -> TokenCredential: @@ -32,17 +37,18 @@ def credential(self) -> TokenCredential: @property def token(self) -> str: """Get a token from the credential provider.""" - return str(self.get_token().token) + return str( + self.credential.get_token( + *self.scopes, + tenant_id=self.tenant_id, + ).token + ) @abstractmethod def get_credential(self) -> TokenCredential: """Get new credential provider.""" pass - def get_token(self) -> AccessToken: - """Get new access token.""" - return self.credential.get_token() - class AzureApiCredentialLoader(DeferredCredentialLoader): """ @@ -51,6 +57,9 @@ class AzureApiCredentialLoader(DeferredCredentialLoader): Uses AzureCliCredential for authentication """ + def __init__(self) -> None: + super().__init__(scopes=["https://management.azure.com/.default"]) + def get_credential(self) -> TokenCredential: """Get a new AzureCliCredential.""" AzureCliSingleton().confirm() # get user confirmation of the current account @@ -67,11 +76,9 @@ class GraphApiCredentialLoader(DeferredCredentialLoader): def __init__( self, tenant_id: str, - default_scopes: Sequence[str] = [], + scopes: Sequence[str] = [], ) -> None: - super().__init__() - self.default_scopes = default_scopes - self.tenant_id = tenant_id + super().__init__(scopes=scopes, tenant_id=tenant_id) def get_credential(self) -> TokenCredential: """Get a new DeviceCodeCredential, using cached credentials if they are available""" @@ -99,16 +106,9 @@ def get_credential(self) -> TokenCredential: ) # Write out an authentication record for this credential - new_auth_record = credential.authenticate(scopes=self.default_scopes) + new_auth_record = credential.authenticate(scopes=self.scopes) with open(authentication_record_path, "w") as f_auth: f_auth.write(new_auth_record.serialize()) # Return the credential return credential - - def get_token(self) -> AccessToken: - """Get new access token using pre-defined scopes and tenant ID.""" - return self.credential.get_token( - *self.default_scopes, - tenant_id=self.tenant_id, - ) From 1d88c62b2c8ab892b53a111b2b4b7bf6487dde4b Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 4 Jul 2024 15:20:38 +0100 Subject: [PATCH 09/54] :recycle: Use GraphApiCredentialLoader to get credentials for GraphApi --- data_safe_haven/commands/pulumi.py | 4 +-- data_safe_haven/commands/sre.py | 6 ++-- data_safe_haven/commands/users.py | 12 +++---- data_safe_haven/external/api/graph_api.py | 31 +++++++++---------- .../infrastructure/programs/imperative_shm.py | 4 +-- 5 files changed, 27 insertions(+), 30 deletions(-) diff --git a/data_safe_haven/commands/pulumi.py b/data_safe_haven/commands/pulumi.py index 569217ccaf..3949467636 100644 --- a/data_safe_haven/commands/pulumi.py +++ b/data_safe_haven/commands/pulumi.py @@ -42,13 +42,13 @@ def run( sre_config = SREConfig.from_remote_by_name(context, sre_name) graph_api = GraphApi( - tenant_id=shm_config.shm.entra_tenant_id, - default_scopes=[ + scopes=[ "Application.ReadWrite.All", "AppRoleAssignment.ReadWrite.All", "Directory.ReadWrite.All", "Group.ReadWrite.All", ], + tenant_id=shm_config.shm.entra_tenant_id, ) project = SREProjectManager( diff --git a/data_safe_haven/commands/sre.py b/data_safe_haven/commands/sre.py index 734292fed0..825efc80b0 100644 --- a/data_safe_haven/commands/sre.py +++ b/data_safe_haven/commands/sre.py @@ -41,13 +41,13 @@ def deploy( # Load GraphAPI as this may require user-interaction graph_api = GraphApi( - tenant_id=shm_config.shm.entra_tenant_id, - default_scopes=[ + scopes=[ "Application.ReadWrite.All", "AppRoleAssignment.ReadWrite.All", "Directory.ReadWrite.All", "Group.ReadWrite.All", ], + tenant_id=shm_config.shm.entra_tenant_id, ) # Load Pulumi and SRE configs @@ -146,8 +146,8 @@ def teardown( # Load GraphAPI as this may require user-interaction graph_api = GraphApi( + scopes=["Application.ReadWrite.All", "Group.ReadWrite.All"], tenant_id=shm_config.shm.entra_tenant_id, - default_scopes=["Application.ReadWrite.All", "Group.ReadWrite.All"], ) # Load Pulumi and SRE configs diff --git a/data_safe_haven/commands/users.py b/data_safe_haven/commands/users.py index efb9e56d91..272f4dc995 100644 --- a/data_safe_haven/commands/users.py +++ b/data_safe_haven/commands/users.py @@ -42,12 +42,12 @@ def add( # Load GraphAPI graph_api = GraphApi( - tenant_id=shm_config.shm.entra_tenant_id, - default_scopes=[ + scopes=[ "Group.Read.All", "User.ReadWrite.All", "UserAuthenticationMethod.ReadWrite.All", ], + tenant_id=shm_config.shm.entra_tenant_id, ) # Add users to SHM @@ -81,8 +81,8 @@ def list_users( # Load GraphAPI graph_api = GraphApi( + scopes=["Directory.Read.All", "Group.Read.All"], tenant_id=shm_config.shm.entra_tenant_id, - default_scopes=["Directory.Read.All", "Group.Read.All"], ) # Load Pulumi config @@ -137,8 +137,8 @@ def register( # Load GraphAPI graph_api = GraphApi( + scopes=["Group.ReadWrite.All", "GroupMember.ReadWrite.All"], tenant_id=shm_config.shm.entra_tenant_id, - default_scopes=["Group.ReadWrite.All", "GroupMember.ReadWrite.All"], ) logger.debug( @@ -188,8 +188,8 @@ def remove( # Load GraphAPI graph_api = GraphApi( + scopes=["User.ReadWrite.All"], tenant_id=shm_config.shm.entra_tenant_id, - default_scopes=["User.ReadWrite.All"], ) # Remove users from SHM @@ -242,8 +242,8 @@ def unregister( # Load GraphAPI graph_api = GraphApi( + scopes=["Group.ReadWrite.All", "GroupMember.ReadWrite.All"], tenant_id=shm_config.shm.entra_tenant_id, - default_scopes=["Group.ReadWrite.All", "GroupMember.ReadWrite.All"], ) logger.debug( diff --git a/data_safe_haven/external/api/graph_api.py b/data_safe_haven/external/api/graph_api.py index 363c90a2b5..585d1365ce 100644 --- a/data_safe_haven/external/api/graph_api.py +++ b/data_safe_haven/external/api/graph_api.py @@ -21,6 +21,7 @@ from data_safe_haven.exceptions import ( DataSafeHavenMicrosoftGraphError, ) +from data_safe_haven.external.interface.credentials import GraphApiCredentialLoader from data_safe_haven.functions import alphanumeric from data_safe_haven.logging import get_logger @@ -71,27 +72,23 @@ class GraphApi: def __init__( self, *, - tenant_id: str | None = None, auth_token: str | None = None, - application_id: str | None = None, - application_secret: str | None = None, - base_endpoint: str = "", - default_scopes: Sequence[str] = [], + scopes: Sequence[str] = [], + tenant_id: str | None = None, ): - self.base_endpoint = ( - base_endpoint if base_endpoint else "https://graph.microsoft.com/v1.0" - ) - self.default_scopes = list(default_scopes) + if tenant_id and scopes: + self.authenticator = GraphApiCredentialLoader(tenant_id, list(scopes)) + self.base_endpoint = "https://graph.microsoft.com/v1.0" + self.default_scopes = list(scopes) self.logger = get_logger() self.tenant_id = tenant_id - if auth_token: - self.token = auth_token - elif application_id and application_secret: - self.token = self.create_token_application( - application_id, application_secret - ) - else: - self.token = self.create_token_administrator() + self.auth_token = auth_token + + @property + def token(self) -> str: + if self.auth_token: + return self.auth_token + return self.authenticator.token def add_custom_domain(self, domain_name: str) -> str: """Add Entra ID custom domain diff --git a/data_safe_haven/infrastructure/programs/imperative_shm.py b/data_safe_haven/infrastructure/programs/imperative_shm.py index 3f73da49f9..f59f8480ba 100644 --- a/data_safe_haven/infrastructure/programs/imperative_shm.py +++ b/data_safe_haven/infrastructure/programs/imperative_shm.py @@ -118,12 +118,12 @@ def deploy(self) -> None: try: # Generate the verification record graph_api = GraphApi( - tenant_id=self.config.shm.entra_tenant_id, - default_scopes=[ + scopes=[ "Application.ReadWrite.All", "Domain.ReadWrite.All", "Group.ReadWrite.All", ], + tenant_id=self.config.shm.entra_tenant_id, ) verification_record = graph_api.add_custom_domain(self.config.shm.fqdn) # Add the record to DNS From 05497e18151c9944a8fa33139f934109929e92ce Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 4 Jul 2024 15:30:23 +0100 Subject: [PATCH 10/54] :coffin: Drop unused create_token_administrator and create_token_application methods --- data_safe_haven/external/api/azure_api.py | 4 -- data_safe_haven/external/api/graph_api.py | 87 ----------------------- 2 files changed, 91 deletions(-) diff --git a/data_safe_haven/external/api/azure_api.py b/data_safe_haven/external/api/azure_api.py index f9634e6823..66ce83db66 100644 --- a/data_safe_haven/external/api/azure_api.py +++ b/data_safe_haven/external/api/azure_api.py @@ -69,11 +69,7 @@ DataSafeHavenAzureError, DataSafeHavenValueError, ) -from data_safe_haven.external.interface.azure_authenticator import AzureAuthenticator from data_safe_haven.external.interface.credentials import AzureApiCredentialLoader -from data_safe_haven.exceptions import ( - DataSafeHavenAzureError, -) from data_safe_haven.logging import get_logger diff --git a/data_safe_haven/external/api/graph_api.py b/data_safe_haven/external/api/graph_api.py index 585d1365ce..5933e6c8af 100644 --- a/data_safe_haven/external/api/graph_api.py +++ b/data_safe_haven/external/api/graph_api.py @@ -12,8 +12,6 @@ import typer from dns import resolver from msal import ( - ConfidentialClientApplication, - PublicClientApplication, SerializableTokenCache, ) @@ -79,9 +77,7 @@ def __init__( if tenant_id and scopes: self.authenticator = GraphApiCredentialLoader(tenant_id, list(scopes)) self.base_endpoint = "https://graph.microsoft.com/v1.0" - self.default_scopes = list(scopes) self.logger = get_logger() - self.tenant_id = tenant_id self.auth_token = auth_token @property @@ -381,89 +377,6 @@ def ensure_application_service_principal( msg = f"Could not create service principal for application '{application_name}'." raise DataSafeHavenMicrosoftGraphError(msg) from exc - def create_token_administrator(self) -> str: - """Create an access token for a global administrator - - Raises: - DataSafeHavenMicrosoftGraphError if the token could not be created - """ - result = None - try: - # Load local token cache - local_token_cache = LocalTokenCache( - pathlib.Path.home() / f".msal_cache_{self.tenant_id}" - ) - # Use the Powershell application by default as this should be pre-installed - app = PublicClientApplication( - authority=f"https://login.microsoftonline.com/{self.tenant_id}", - client_id="14d82eec-204b-4c2f-b7e8-296a70dab67e", # this is the Powershell client id - token_cache=local_token_cache, - ) - # Attempt to load token from cache - if accounts := app.get_accounts(): - result = app.acquire_token_silent( - self.default_scopes, account=accounts[0] - ) - # Initiate device code flow - if not result: - flow = app.initiate_device_flow(scopes=self.default_scopes) - if "user_code" not in flow: - msg = f"Could not initiate device login for scopes {self.default_scopes}." - raise DataSafeHavenMicrosoftGraphError(msg) - self.logger.info( - "Administrator approval is needed in order to interact with Entra ID." - ) - self.logger.info( - "Please sign-in with [bold]global administrator[/] credentials for" - f" Entra tenant '[green]{self.tenant_id}[/]'." - ) - self.logger.info( - "Note that the sign-in screen will prompt you to sign-in to" - " [blue]Microsoft Graph Command Line Tools[/] - this is expected." - ) - self.logger.info(flow["message"]) - # Block until a response is received - result = app.acquire_token_by_device_flow(flow) - return str(result["access_token"]) - except Exception as exc: - error_description = "Could not create Microsoft Graph access token." - if isinstance(result, dict) and "error_description" in result: - error_description += f"\n{result['error_description']}." - msg = f"{error_description}" - raise DataSafeHavenMicrosoftGraphError(msg) from exc - - def create_token_application( - self, application_id: str, application_secret: str - ) -> str: - """Return an access token for the given application ID and secret - - Raises: - DataSafeHavenMicrosoftGraphError if the token could not be created - """ - result = None - try: - # Use a created application - app = ConfidentialClientApplication( - client_id=application_id, - client_credential=application_secret, - authority=f"https://login.microsoftonline.com/{self.tenant_id}", - ) - # Block until a response is received - # For this call the scopes are pre-defined by the application privileges - result = app.acquire_token_for_client( - scopes=["https://graph.microsoft.com/.default"] - ) - if not isinstance(result, dict): - msg = "Invalid application token returned from Microsoft Graph." - raise DataSafeHavenMicrosoftGraphError(msg) - return str(result["access_token"]) - except Exception as exc: - error_description = "Could not create access token" - if result and "error_description" in result: - error_description += f": {result['error_description']}" - msg = f"{error_description}." - raise DataSafeHavenMicrosoftGraphError(msg) from exc - def create_user( self, request_json: dict[str, Any], From 613dc87f63f30bdaa6d0fd01e13171fcc5828f4a Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 4 Jul 2024 17:04:01 +0100 Subject: [PATCH 11/54] :sparkles: Add a custom callback to the GraphApiCredentialLoader prompt --- data_safe_haven/external/api/azure_cli.py | 6 ++--- .../external/interface/credentials.py | 25 ++++++++++++++++++- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/data_safe_haven/external/api/azure_cli.py b/data_safe_haven/external/api/azure_cli.py index 35fcfaf9a1..3cd5c337a5 100644 --- a/data_safe_haven/external/api/azure_cli.py +++ b/data_safe_haven/external/api/azure_cli.py @@ -73,11 +73,11 @@ def confirm(self) -> None: account = self.account self.logger.info( - "You are currently logged into the Azure CLI with the following details:" + "You are currently logged into the [blue]Azure CLI[/] with the following details:" ) - self.logger.info(f"... Azure user: [blue]{account.name}[/] ({account.id_})") + self.logger.info(f"... user: [green]{account.name}[/] ({account.id_})") self.logger.info( - f"... Azure tenant: [blue]{account.tenant_name}[/] ({account.tenant_id})" + f"... tenant: [green]{account.tenant_name}[/] ({account.tenant_id})" ) if not console.confirm( "Is this the Azure account you expect?", default_to_yes=False diff --git a/data_safe_haven/external/interface/credentials.py b/data_safe_haven/external/interface/credentials.py index a9fd2a7057..07b2738314 100644 --- a/data_safe_haven/external/interface/credentials.py +++ b/data_safe_haven/external/interface/credentials.py @@ -3,6 +3,7 @@ import pathlib from abc import ABCMeta, abstractmethod from collections.abc import Sequence +from datetime import datetime from azure.core.credentials import TokenCredential from azure.identity import ( @@ -13,6 +14,7 @@ ) from data_safe_haven.external.api.azure_cli import AzureCliSingleton +from data_safe_haven.logging import get_logger class DeferredCredentialLoader(metaclass=ABCMeta): @@ -79,6 +81,7 @@ def __init__( scopes: Sequence[str] = [], ) -> None: super().__init__(scopes=scopes, tenant_id=tenant_id) + self.logger = get_logger() def get_credential(self) -> TokenCredential: """Get a new DeviceCodeCredential, using cached credentials if they are available""" @@ -99,9 +102,18 @@ def get_credential(self) -> TokenCredential: kwargs["client_id"] = "14d82eec-204b-4c2f-b7e8-296a70dab67e" kwargs["tenant_id"] = self.tenant_id - # Get a credential + # Get a credential with a custom callback + def callback(verification_uri: str, user_code: str, _: datetime) -> None: + self.logger.info( + f"Go to [bold]{verification_uri}[/] in a web browser and enter the code [bold]{user_code}[/] at the prompt." + ) + self.logger.info( + "Use [bold]global administrator credentials[/] for your [blue]Entra ID directory[/] to sign-in." + ) + credential = DeviceCodeCredential( cache_persistence_options=TokenCachePersistenceOptions(name=cache_name), + prompt_callback=callback, **kwargs, ) @@ -110,5 +122,16 @@ def get_credential(self) -> TokenCredential: with open(authentication_record_path, "w") as f_auth: f_auth.write(new_auth_record.serialize()) + # Write confirmation details about this login + self.logger.info( + "You are currently logged into the [blue]Microsoft Graph API[/] with the following details:" + ) + self.logger.info( + f"... user: [green]{new_auth_record.username}[/] ({new_auth_record._home_account_id.split('.')[0]})" + ) + self.logger.info( + f"... tenant: [green]{new_auth_record._username.split('@')[1]}[/] ({new_auth_record._tenant_id})" + ) + # Return the credential return credential From d92f605f2fd5e3aa2dcbce79d2e501f3f4aabc24 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 4 Jul 2024 17:52:30 +0100 Subject: [PATCH 12/54] :art: Simplify the GraphApi constructor by using an auth_token to build a new GraphApiCredentialLoader rather than using the token directly --- data_safe_haven/external/api/graph_api.py | 27 ++++++++++++------- .../components/dynamic/entra_application.py | 6 ++--- .../provisioning/sre_provisioning_manager.py | 2 +- pyproject.toml | 2 ++ 4 files changed, 24 insertions(+), 13 deletions(-) diff --git a/data_safe_haven/external/api/graph_api.py b/data_safe_haven/external/api/graph_api.py index 5933e6c8af..9bc37613c4 100644 --- a/data_safe_haven/external/api/graph_api.py +++ b/data_safe_haven/external/api/graph_api.py @@ -6,8 +6,9 @@ from collections.abc import Sequence from contextlib import suppress from io import UnsupportedOperation -from typing import Any, ClassVar +from typing import Any, ClassVar, Self +import jwt import requests import typer from dns import resolver @@ -18,6 +19,7 @@ from data_safe_haven import console from data_safe_haven.exceptions import ( DataSafeHavenMicrosoftGraphError, + DataSafeHavenValueError, ) from data_safe_haven.external.interface.credentials import GraphApiCredentialLoader from data_safe_haven.functions import alphanumeric @@ -70,20 +72,27 @@ class GraphApi: def __init__( self, *, - auth_token: str | None = None, - scopes: Sequence[str] = [], - tenant_id: str | None = None, + scopes: Sequence[str], + tenant_id: str, ): - if tenant_id and scopes: - self.authenticator = GraphApiCredentialLoader(tenant_id, list(scopes)) + self.authenticator = GraphApiCredentialLoader(tenant_id, list(scopes)) self.base_endpoint = "https://graph.microsoft.com/v1.0" self.logger = get_logger() - self.auth_token = auth_token + + @classmethod + def from_token(cls: type[Self], auth_token: str) -> "GraphApi": + """Construct a GraphApi from an existing authentication token.""" + try: + decoded = jwt.decode( + auth_token, algorithms=["RS256"], options={"verify_signature": False} + ) + return cls(scopes=str(decoded["scp"]).split(), tenant_id=decoded["tid"]) + except (jwt.exceptions.DecodeError, KeyError) as exc: + msg = "Could not interpret Graph API authentication token." + raise DataSafeHavenValueError(msg) from exc @property def token(self) -> str: - if self.auth_token: - return self.auth_token return self.authenticator.token def add_custom_domain(self, domain_name: str) -> str: diff --git a/data_safe_haven/infrastructure/components/dynamic/entra_application.py b/data_safe_haven/infrastructure/components/dynamic/entra_application.py index 475a5e95cf..5cc909af41 100644 --- a/data_safe_haven/infrastructure/components/dynamic/entra_application.py +++ b/data_safe_haven/infrastructure/components/dynamic/entra_application.py @@ -41,7 +41,7 @@ def refresh(self, props: dict[str, Any]) -> dict[str, Any]: try: outs = dict(**props) with suppress(DataSafeHavenMicrosoftGraphError): - graph_api = GraphApi(auth_token=self.auth_token) + graph_api = GraphApi.from_token(self.auth_token) if json_response := graph_api.get_application_by_name( outs["application_name"] ): @@ -67,7 +67,7 @@ def create(self, props: dict[str, Any]) -> CreateResult: """Create new Entra application.""" outs = dict(**props) try: - graph_api = GraphApi(auth_token=self.auth_token) + graph_api = GraphApi.from_token(self.auth_token) request_json = { "displayName": props["application_name"], "signInAudience": "AzureADMyOrg", @@ -123,7 +123,7 @@ def delete(self, id_: str, props: dict[str, Any]) -> None: # Use `id` as a no-op to avoid ARG002 while maintaining function signature id(id_) try: - graph_api = GraphApi(auth_token=self.auth_token) + graph_api = GraphApi.from_token(self.auth_token) graph_api.delete_application(props["application_name"]) except Exception as exc: msg = f"Failed to delete application '{props['application_name']}' from Entra ID." diff --git a/data_safe_haven/provisioning/sre_provisioning_manager.py b/data_safe_haven/provisioning/sre_provisioning_manager.py index 9b685cb71f..096739d742 100644 --- a/data_safe_haven/provisioning/sre_provisioning_manager.py +++ b/data_safe_haven/provisioning/sre_provisioning_manager.py @@ -28,7 +28,7 @@ def __init__( ): self._available_vm_skus: dict[str, dict[str, Any]] | None = None self.location = location - self.graph_api = GraphApi(auth_token=graph_api_token) + self.graph_api = GraphApi.from_token(graph_api_token) self.logger = get_logger() self.sre_name = sre_name self.subscription_name = subscription_name diff --git a/pyproject.toml b/pyproject.toml index 4f302a7719..f1f7779307 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "pulumi-azure-native>=2.14", "pulumi-random>=4.14", "pydantic>=2.4", + "pyjwt>=2.8", "pytz>=2023.3", "PyYAML>=6.0", "rich>=13.4", @@ -144,6 +145,7 @@ module = [ "azure.storage.*", "cryptography.*", "dns.*", + "jwt.*", "msal.*", "numpy.*", "pandas.*", From 4a3b3d87217fb14619e2efc3d8b4e5a2149940d9 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 4 Jul 2024 18:01:25 +0100 Subject: [PATCH 13/54] :loud_sound: Log AzureCli command line error separately from calling exception --- data_safe_haven/external/api/azure_cli.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/data_safe_haven/external/api/azure_cli.py b/data_safe_haven/external/api/azure_cli.py index 3cd5c337a5..135a41752a 100644 --- a/data_safe_haven/external/api/azure_cli.py +++ b/data_safe_haven/external/api/azure_cli.py @@ -47,12 +47,11 @@ def account(self) -> AzureCliAccount: stderr=subprocess.PIPE, encoding="utf8", ) + result_dict = json.loads(result) except subprocess.CalledProcessError as exc: - msg = f"Error getting account information from Azure CLI.\n{exc.stderr}" + self.logger.error(str(exc.stderr).replace("ERROR:", "").strip()) + msg = "Error getting account information from Azure CLI." raise DataSafeHavenAzureError(msg) from exc - - try: - result_dict = json.loads(result) except json.JSONDecodeError as exc: msg = f"Unable to parse Azure CLI output as JSON.\n{result}" raise DataSafeHavenAzureError(msg) from exc From 613b4c2e2ab2c5de0d9c71db36523a5f4fe88cb4 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 4 Jul 2024 23:24:56 +0100 Subject: [PATCH 14/54] :recycle: Make DeferredCredential a TokenCredential with class-level token caching --- data_safe_haven/commands/sre.py | 3 +- data_safe_haven/external/api/azure_api.py | 9 +--- data_safe_haven/external/api/graph_api.py | 6 +-- .../external/interface/credentials.py | 52 +++++++++++-------- 4 files changed, 36 insertions(+), 34 deletions(-) diff --git a/data_safe_haven/commands/sre.py b/data_safe_haven/commands/sre.py index 825efc80b0..3172b73766 100644 --- a/data_safe_haven/commands/sre.py +++ b/data_safe_haven/commands/sre.py @@ -39,7 +39,7 @@ def deploy( context = ContextManager.from_file().assert_context() shm_config = SHMConfig.from_remote(context) - # Load GraphAPI as this may require user-interaction + # Load GraphAPI graph_api = GraphApi( scopes=[ "Application.ReadWrite.All", @@ -65,6 +65,7 @@ def deploy( raise DataSafeHavenConfigError(msg) # Initialise Pulumi stack + # Note that requesting a GraphApi token will trigger possible user-interaction stack = SREProjectManager( context=context, config=sre_config, diff --git a/data_safe_haven/external/api/azure_api.py b/data_safe_haven/external/api/azure_api.py index 66ce83db66..406254bb76 100644 --- a/data_safe_haven/external/api/azure_api.py +++ b/data_safe_haven/external/api/azure_api.py @@ -4,7 +4,6 @@ from contextlib import suppress from typing import Any, cast -from azure.core.credentials import TokenCredential from azure.core.exceptions import ( AzureError, ClientAuthenticationError, @@ -69,7 +68,7 @@ DataSafeHavenAzureError, DataSafeHavenValueError, ) -from data_safe_haven.external.interface.credentials import AzureApiCredentialLoader +from data_safe_haven.external.interface.credentials import AzureApiCredential from data_safe_haven.logging import get_logger @@ -78,15 +77,11 @@ class AzureApi: def __init__(self, subscription_name: str) -> None: self.logger = get_logger() - self.authenticator = AzureApiCredentialLoader() + self.credential = AzureApiCredential() self.subscription_name = subscription_name self.subscription_id_: str | None = None self.tenant_id_: str | None = None - @property - def credential(self) -> TokenCredential: - return self.authenticator.credential - @property def subscription_id(self) -> str: if not self.subscription_id_: diff --git a/data_safe_haven/external/api/graph_api.py b/data_safe_haven/external/api/graph_api.py index 9bc37613c4..d94cb33c1d 100644 --- a/data_safe_haven/external/api/graph_api.py +++ b/data_safe_haven/external/api/graph_api.py @@ -21,7 +21,7 @@ DataSafeHavenMicrosoftGraphError, DataSafeHavenValueError, ) -from data_safe_haven.external.interface.credentials import GraphApiCredentialLoader +from data_safe_haven.external.interface.credentials import GraphApiCredential from data_safe_haven.functions import alphanumeric from data_safe_haven.logging import get_logger @@ -75,8 +75,8 @@ def __init__( scopes: Sequence[str], tenant_id: str, ): - self.authenticator = GraphApiCredentialLoader(tenant_id, list(scopes)) self.base_endpoint = "https://graph.microsoft.com/v1.0" + self.credential = GraphApiCredential(tenant_id, list(scopes)) self.logger = get_logger() @classmethod @@ -93,7 +93,7 @@ def from_token(cls: type[Self], auth_token: str) -> "GraphApi": @property def token(self) -> str: - return self.authenticator.token + return self.credential.token def add_custom_domain(self, domain_name: str) -> str: """Add Entra ID custom domain diff --git a/data_safe_haven/external/interface/credentials.py b/data_safe_haven/external/interface/credentials.py index 07b2738314..88de571788 100644 --- a/data_safe_haven/external/interface/credentials.py +++ b/data_safe_haven/external/interface/credentials.py @@ -1,11 +1,12 @@ """Classes related to Azure credentials""" import pathlib -from abc import ABCMeta, abstractmethod +from abc import abstractmethod from collections.abc import Sequence -from datetime import datetime +from datetime import UTC, datetime +from typing import Any -from azure.core.credentials import TokenCredential +from azure.core.credentials import AccessToken, TokenCredential from azure.identity import ( AuthenticationRecord, AzureCliCredential, @@ -17,42 +18,48 @@ from data_safe_haven.logging import get_logger -class DeferredCredentialLoader(metaclass=ABCMeta): - """A wrapper class that initialises and caches credentials as they are needed""" +class DeferredCredential(TokenCredential): + """A token credential that wraps and caches other credential classes.""" + + token_: AccessToken | None = None def __init__( self, scopes: Sequence[str], tenant_id: str | None = None, ) -> None: - self.credential_: TokenCredential | None = None + self.logger = get_logger() self.scopes = scopes self.tenant_id = tenant_id - @property - def credential(self) -> TokenCredential: - """Return the cached credential provider.""" - if not self.credential_: - self.credential_ = self.get_credential() - return self.credential_ - @property def token(self) -> str: """Get a token from the credential provider.""" - return str( - self.credential.get_token( - *self.scopes, - tenant_id=self.tenant_id, - ).token - ) + return str(self.get_token(*self.scopes, tenant_id=self.tenant_id).token) @abstractmethod def get_credential(self) -> TokenCredential: - """Get new credential provider.""" + """Get a credential provider from the child class.""" pass + def get_token( + self, + *scopes: str, + **kwargs: Any, + ) -> AccessToken: + # Require at least 10 minutes of remaining validity + validity_cutoff = datetime.now(tz=UTC).timestamp() + 10 * 60 + if not DeferredCredential.token_ or ( + DeferredCredential.token_.expires_on < validity_cutoff + ): + # Generate a new token and store it at class-level token + DeferredCredential.token_ = self.get_credential().get_token( + *scopes, **kwargs + ) + return DeferredCredential.token_ -class AzureApiCredentialLoader(DeferredCredentialLoader): + +class AzureApiCredential(DeferredCredential): """ Credential loader used by AzureApi @@ -68,7 +75,7 @@ def get_credential(self) -> TokenCredential: return AzureCliCredential(additionally_allowed_tenants=["*"]) -class GraphApiCredentialLoader(DeferredCredentialLoader): +class GraphApiCredential(DeferredCredential): """ Credential loader used by GraphApi @@ -81,7 +88,6 @@ def __init__( scopes: Sequence[str] = [], ) -> None: super().__init__(scopes=scopes, tenant_id=tenant_id) - self.logger = get_logger() def get_credential(self) -> TokenCredential: """Get a new DeviceCodeCredential, using cached credentials if they are available""" From 5e43ca6831c09bd25822cf77db0be3c8f33c40a5 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 4 Jul 2024 23:28:08 +0100 Subject: [PATCH 15/54] :coffin: Drop unused LocalTokenCache --- data_safe_haven/external/api/graph_api.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/data_safe_haven/external/api/graph_api.py b/data_safe_haven/external/api/graph_api.py index d94cb33c1d..7bc8b97b8e 100644 --- a/data_safe_haven/external/api/graph_api.py +++ b/data_safe_haven/external/api/graph_api.py @@ -1,20 +1,15 @@ """Interface to the Microsoft Graph API""" import datetime -import pathlib import time from collections.abc import Sequence from contextlib import suppress -from io import UnsupportedOperation from typing import Any, ClassVar, Self import jwt import requests import typer from dns import resolver -from msal import ( - SerializableTokenCache, -) from data_safe_haven import console from data_safe_haven.exceptions import ( @@ -26,22 +21,6 @@ from data_safe_haven.logging import get_logger -class LocalTokenCache(SerializableTokenCache): - def __init__(self, token_cache_filename: pathlib.Path) -> None: - super().__init__() - self.token_cache_filename = token_cache_filename - try: - if self.token_cache_filename.exists(): - with open(self.token_cache_filename, encoding="utf-8") as f_token: - self.deserialize(f_token.read()) - except (FileNotFoundError, UnsupportedOperation): - self.deserialize(None) - - def __del__(self) -> None: - with open(self.token_cache_filename, "w", encoding="utf-8") as f_token: - f_token.write(self.serialize()) - - class GraphApi: """Interface to the Microsoft Graph REST API""" From fc3816249914db82a5aa4fd5a454e5937b6f1bae Mon Sep 17 00:00:00 2001 From: James Robinson Date: Fri, 5 Jul 2024 01:28:11 +0100 Subject: [PATCH 16/54] :white_check_mark: Fix failing tests --- tests/commands/conftest.py | 21 ++------------------- tests/commands/test_pulumi.py | 6 +++--- tests/commands/test_shm.py | 6 +++--- tests/commands/test_sre.py | 8 ++++---- tests/conftest.py | 24 ++++++++++++++++++++++-- 5 files changed, 34 insertions(+), 31 deletions(-) diff --git a/tests/commands/conftest.py b/tests/commands/conftest.py index 2751098efb..e5ac3a4646 100644 --- a/tests/commands/conftest.py +++ b/tests/commands/conftest.py @@ -13,7 +13,6 @@ DataSafeHavenAzureError, ) from data_safe_haven.external import AzureApi, GraphApi -from data_safe_haven.external.interface.azure_authenticator import AzureAuthenticator from data_safe_haven.infrastructure import ImperativeSHM, SREProjectManager @@ -27,20 +26,6 @@ def mock_azure_api_blob_exists_false(mocker): mocker.patch.object(AzureApi, "blob_exists", return_value=False) -@fixture -def mock_azure_authenticator_login_exception(mocker): - def login_then_exit(): - print("mock login") # noqa: T201 - msg = "mock login error" - raise DataSafeHavenAzureAPIAuthenticationError(msg) - - mocker.patch.object( - AzureAuthenticator, - "login", - side_effect=login_then_exit, - ) - - @fixture def mock_graph_api_add_custom_domain(mocker): mocker.patch.object( @@ -49,10 +34,8 @@ def mock_graph_api_add_custom_domain(mocker): @fixture -def mock_graph_api_create_token_administrator(mocker): - mocker.patch.object( - GraphApi, "create_token_administrator", return_value="dummy-token" - ) +def mock_graph_api_token(mocker): + mocker.patch.object(GraphApi, "token", return_value="dummy-token") @fixture diff --git a/tests/commands/test_pulumi.py b/tests/commands/test_pulumi.py index 4e86af1049..8182ec501c 100644 --- a/tests/commands/test_pulumi.py +++ b/tests/commands/test_pulumi.py @@ -8,7 +8,7 @@ def test_run_sre( mock_shm_config_from_remote, # noqa: ARG002 mock_sre_config_from_remote, # noqa: ARG002 mock_pulumi_config_no_key_from_remote, # noqa: ARG002 - mock_graph_api_create_token_administrator, # noqa: ARG002 + mock_graph_api_token, # noqa: ARG002 mock_azure_cli_confirm, # noqa: ARG002 mock_install_plugins, # noqa: ARG002 mock_key_vault_key, # noqa: ARG002 @@ -33,7 +33,7 @@ def test_run_sre_invalid_command( mock_shm_config_from_remote, # noqa: ARG002 mock_sre_config_from_remote, # noqa: ARG002 mock_pulumi_config_no_key_from_remote, # noqa: ARG002 - mock_graph_api_create_token_administrator, # noqa: ARG002 + mock_graph_api_token, # noqa: ARG002 mock_azure_cli_confirm, # noqa: ARG002 mock_install_plugins, # noqa: ARG002 mock_key_vault_key, # noqa: ARG002 @@ -52,7 +52,7 @@ def test_run_sre_invalid_name( mock_shm_config_from_remote, # noqa: ARG002 mock_sre_config_alternate_from_remote, # noqa: ARG002 mock_pulumi_config_no_key_from_remote, # noqa: ARG002 - mock_graph_api_create_token_administrator, # noqa: ARG002 + mock_graph_api_token, # noqa: ARG002 mock_azure_cli_confirm, # noqa: ARG002 mock_install_plugins, # noqa: ARG002 mock_key_vault_key, # noqa: ARG002 diff --git a/tests/commands/test_shm.py b/tests/commands/test_shm.py index 19c8fc0575..e426b43f5e 100644 --- a/tests/commands/test_shm.py +++ b/tests/commands/test_shm.py @@ -7,7 +7,7 @@ def test_infrastructure_deploy( runner, mock_imperative_shm_deploy_then_exit, # noqa: ARG002 mock_graph_api_add_custom_domain, # noqa: ARG002 - mock_graph_api_create_token_administrator, # noqa: ARG002 + mock_graph_api_token, # noqa: ARG002 mock_shm_config_from_remote, # noqa: ARG002 mock_shm_config_remote_exists, # noqa: ARG002 mock_shm_config_upload, # noqa: ARG002 @@ -30,7 +30,7 @@ def test_infrastructure_show_none(self, runner_none): def test_infrastructure_auth_failure( self, runner, - mock_azure_authenticator_login_exception, # noqa: ARG002 + mock_azure_cli_confirm_then_exit, # noqa: ARG002 ): result = runner.invoke(shm_command_group, ["deploy"]) assert result.exit_code == 1 @@ -62,7 +62,7 @@ def test_show_none(self, runner_none): def test_auth_failure( self, runner, - mock_azure_authenticator_login_exception, # noqa: ARG002 + mock_azure_cli_confirm_then_exit, # noqa: ARG002 ): result = runner.invoke(shm_command_group, ["teardown"]) assert result.exit_code == 1 diff --git a/tests/commands/test_sre.py b/tests/commands/test_sre.py index 548ff89f83..4b1bc887da 100644 --- a/tests/commands/test_sre.py +++ b/tests/commands/test_sre.py @@ -6,7 +6,7 @@ def test_deploy( self, runner, mock_azure_cli_confirm, # noqa: ARG002 - mock_graph_api_create_token_administrator, # noqa: ARG002 + mock_graph_api_token, # noqa: ARG002 mock_ip_1_2_3_4, # noqa: ARG002 mock_pulumi_config_from_remote_or_create, # noqa: ARG002 mock_pulumi_config_upload, # noqa: ARG002 @@ -27,7 +27,7 @@ def test_no_context_file(self, runner_no_context_file): def test_auth_failure( self, runner, - mock_azure_authenticator_login_exception, # noqa: ARG002 + mock_azure_cli_confirm_then_exit, # noqa: ARG002 ): result = runner.invoke(sre_command_group, ["deploy", "sandbox"]) assert result.exit_code == 1 @@ -51,7 +51,7 @@ def test_teardown( self, runner, mock_azure_cli_confirm, # noqa: ARG002 - mock_graph_api_create_token_administrator, # noqa: ARG002 + mock_graph_api_token, # noqa: ARG002 mock_ip_1_2_3_4, # noqa: ARG002 mock_pulumi_config_from_remote, # noqa: ARG002 mock_shm_config_from_remote, # noqa: ARG002 @@ -83,7 +83,7 @@ def test_no_shm( def test_auth_failure( self, runner, - mock_azure_authenticator_login_exception, # noqa: ARG002 + mock_azure_cli_confirm_then_exit, # noqa: ARG002 ): result = runner.invoke(sre_command_group, ["teardown", "sandbox"]) assert result.exit_code == 1 diff --git a/tests/conftest.py b/tests/conftest.py index 27e89296ed..f7fd8234e5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ from pulumi.automation import ProjectSettings from pytest import fixture +import data_safe_haven.commands.sre as sre_mod import data_safe_haven.config.context_manager as context_mod import data_safe_haven.logging.logger from data_safe_haven.config import ( @@ -22,6 +23,7 @@ ConfigSectionSRE, ConfigSubsectionRemoteDesktopOpts, ) +from data_safe_haven.exceptions import DataSafeHavenAzureAPIAuthenticationError from data_safe_haven.external import AzureApi, AzureCliSingleton, PulumiAccount from data_safe_haven.infrastructure import SREProjectManager from data_safe_haven.infrastructure.project_manager import ProjectManager @@ -123,9 +125,27 @@ def log_directory(session_mocker, tmp_path_factory): @fixture -def mock_azure_cli_confirm(monkeypatch): +def mock_azure_cli_confirm(mocker): """Always pass AzureCliSingleton.confirm without attempting login""" - monkeypatch.setattr(AzureCliSingleton, "confirm", lambda self: None) # noqa: ARG005 + mocker.patch.object( + AzureCliSingleton, + "confirm", + return_value=None, + ) + + +@fixture +def mock_azure_cli_confirm_then_exit(mocker): + def confirm_then_exit(): + print("mock login") # noqa: T201 + msg = "mock login error" + raise DataSafeHavenAzureAPIAuthenticationError(msg) + + mocker.patch.object( + AzureCliSingleton, + "confirm", + side_effect=confirm_then_exit, + ) @fixture From 780905775fe73fc86155005e8643c6b5e21ebdff Mon Sep 17 00:00:00 2001 From: James Robinson Date: Mon, 8 Jul 2024 10:17:52 +0100 Subject: [PATCH 17/54] :sparkles: Enable AzureApi to make GraphApi calls using the graph.microsoft.com/.default scope (note that this only allows a subset of functions). --- data_safe_haven/external/api/azure_api.py | 8 ++++++++ data_safe_haven/external/api/graph_api.py | 16 ++++++++++++---- .../external/interface/credentials.py | 4 ++-- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/data_safe_haven/external/api/azure_api.py b/data_safe_haven/external/api/azure_api.py index 406254bb76..1d752e80b4 100644 --- a/data_safe_haven/external/api/azure_api.py +++ b/data_safe_haven/external/api/azure_api.py @@ -71,6 +71,8 @@ from data_safe_haven.external.interface.credentials import AzureApiCredential from data_safe_haven.logging import get_logger +from .graph_api import GraphApi + class AzureApi: """Interface to the Azure REST API""" @@ -82,6 +84,12 @@ def __init__(self, subscription_name: str) -> None: self.subscription_id_: str | None = None self.tenant_id_: str | None = None + @property + def entra_directory(self) -> GraphApi: + return GraphApi( + credential=AzureApiCredential("https://graph.microsoft.com//.default"), + ) + @property def subscription_id(self) -> str: if not self.subscription_id_: diff --git a/data_safe_haven/external/api/graph_api.py b/data_safe_haven/external/api/graph_api.py index 7bc8b97b8e..d7be843738 100644 --- a/data_safe_haven/external/api/graph_api.py +++ b/data_safe_haven/external/api/graph_api.py @@ -16,7 +16,10 @@ DataSafeHavenMicrosoftGraphError, DataSafeHavenValueError, ) -from data_safe_haven.external.interface.credentials import GraphApiCredential +from data_safe_haven.external.interface.credentials import ( + DeferredCredential, + GraphApiCredential, +) from data_safe_haven.functions import alphanumeric from data_safe_haven.logging import get_logger @@ -51,11 +54,16 @@ class GraphApi: def __init__( self, *, - scopes: Sequence[str], - tenant_id: str, + scopes: Sequence[str] | None = None, + tenant_id: str | None = None, + credential: DeferredCredential | None = None, ): self.base_endpoint = "https://graph.microsoft.com/v1.0" - self.credential = GraphApiCredential(tenant_id, list(scopes)) + self.credential = ( + credential + if credential + else GraphApiCredential(str(tenant_id), list(scopes) if scopes else []) + ) self.logger = get_logger() @classmethod diff --git a/data_safe_haven/external/interface/credentials.py b/data_safe_haven/external/interface/credentials.py index 88de571788..7f2b7ed92d 100644 --- a/data_safe_haven/external/interface/credentials.py +++ b/data_safe_haven/external/interface/credentials.py @@ -66,8 +66,8 @@ class AzureApiCredential(DeferredCredential): Uses AzureCliCredential for authentication """ - def __init__(self) -> None: - super().__init__(scopes=["https://management.azure.com/.default"]) + def __init__(self, scope: str = "https://management.azure.com/.default") -> None: + super().__init__(scopes=[scope]) def get_credential(self) -> TokenCredential: """Get a new AzureCliCredential.""" From d16997f7132380354d24e7bdd64ce30a97fe508c Mon Sep 17 00:00:00 2001 From: James Robinson Date: Mon, 8 Jul 2024 11:07:02 +0100 Subject: [PATCH 18/54] :recycle: Simplify GraphApi constructor so that it will usually be called through classmethods --- data_safe_haven/commands/pulumi.py | 2 +- data_safe_haven/commands/sre.py | 4 ++-- data_safe_haven/commands/users.py | 10 +++++----- data_safe_haven/external/api/graph_api.py | 20 ++++++++++--------- .../infrastructure/programs/imperative_shm.py | 2 +- 5 files changed, 20 insertions(+), 18 deletions(-) diff --git a/data_safe_haven/commands/pulumi.py b/data_safe_haven/commands/pulumi.py index 3949467636..a3c6fec243 100644 --- a/data_safe_haven/commands/pulumi.py +++ b/data_safe_haven/commands/pulumi.py @@ -41,7 +41,7 @@ def run( shm_config = SHMConfig.from_remote(context) sre_config = SREConfig.from_remote_by_name(context, sre_name) - graph_api = GraphApi( + graph_api = GraphApi.from_scopes( scopes=[ "Application.ReadWrite.All", "AppRoleAssignment.ReadWrite.All", diff --git a/data_safe_haven/commands/sre.py b/data_safe_haven/commands/sre.py index 3172b73766..c4c992e33e 100644 --- a/data_safe_haven/commands/sre.py +++ b/data_safe_haven/commands/sre.py @@ -40,7 +40,7 @@ def deploy( shm_config = SHMConfig.from_remote(context) # Load GraphAPI - graph_api = GraphApi( + graph_api = GraphApi.from_scopes( scopes=[ "Application.ReadWrite.All", "AppRoleAssignment.ReadWrite.All", @@ -146,7 +146,7 @@ def teardown( shm_config = SHMConfig.from_remote(context) # Load GraphAPI as this may require user-interaction - graph_api = GraphApi( + graph_api = GraphApi.from_scopes( scopes=["Application.ReadWrite.All", "Group.ReadWrite.All"], tenant_id=shm_config.shm.entra_tenant_id, ) diff --git a/data_safe_haven/commands/users.py b/data_safe_haven/commands/users.py index 272f4dc995..596e386737 100644 --- a/data_safe_haven/commands/users.py +++ b/data_safe_haven/commands/users.py @@ -41,7 +41,7 @@ def add( raise # Load GraphAPI - graph_api = GraphApi( + graph_api = GraphApi.from_scopes( scopes=[ "Group.Read.All", "User.ReadWrite.All", @@ -80,7 +80,7 @@ def list_users( raise # Load GraphAPI - graph_api = GraphApi( + graph_api = GraphApi.from_scopes( scopes=["Directory.Read.All", "Group.Read.All"], tenant_id=shm_config.shm.entra_tenant_id, ) @@ -136,7 +136,7 @@ def register( raise DataSafeHavenError(msg) # Load GraphAPI - graph_api = GraphApi( + graph_api = GraphApi.from_scopes( scopes=["Group.ReadWrite.All", "GroupMember.ReadWrite.All"], tenant_id=shm_config.shm.entra_tenant_id, ) @@ -187,7 +187,7 @@ def remove( raise # Load GraphAPI - graph_api = GraphApi( + graph_api = GraphApi.from_scopes( scopes=["User.ReadWrite.All"], tenant_id=shm_config.shm.entra_tenant_id, ) @@ -241,7 +241,7 @@ def unregister( raise DataSafeHavenError(msg) # Load GraphAPI - graph_api = GraphApi( + graph_api = GraphApi.from_scopes( scopes=["Group.ReadWrite.All", "GroupMember.ReadWrite.All"], tenant_id=shm_config.shm.entra_tenant_id, ) diff --git a/data_safe_haven/external/api/graph_api.py b/data_safe_haven/external/api/graph_api.py index d7be843738..017d63f3f5 100644 --- a/data_safe_haven/external/api/graph_api.py +++ b/data_safe_haven/external/api/graph_api.py @@ -54,18 +54,18 @@ class GraphApi: def __init__( self, *, - scopes: Sequence[str] | None = None, - tenant_id: str | None = None, - credential: DeferredCredential | None = None, + credential: DeferredCredential, ): self.base_endpoint = "https://graph.microsoft.com/v1.0" - self.credential = ( - credential - if credential - else GraphApiCredential(str(tenant_id), list(scopes) if scopes else []) - ) + self.credential = credential self.logger = get_logger() + @classmethod + def from_scopes( + cls: type[Self], *, scopes: Sequence[str], tenant_id: str + ) -> "GraphApi": + return cls(credential=GraphApiCredential(tenant_id, scopes)) + @classmethod def from_token(cls: type[Self], auth_token: str) -> "GraphApi": """Construct a GraphApi from an existing authentication token.""" @@ -73,7 +73,9 @@ def from_token(cls: type[Self], auth_token: str) -> "GraphApi": decoded = jwt.decode( auth_token, algorithms=["RS256"], options={"verify_signature": False} ) - return cls(scopes=str(decoded["scp"]).split(), tenant_id=decoded["tid"]) + return cls.from_scopes( + scopes=str(decoded["scp"]).split(), tenant_id=decoded["tid"] + ) except (jwt.exceptions.DecodeError, KeyError) as exc: msg = "Could not interpret Graph API authentication token." raise DataSafeHavenValueError(msg) from exc diff --git a/data_safe_haven/infrastructure/programs/imperative_shm.py b/data_safe_haven/infrastructure/programs/imperative_shm.py index f59f8480ba..e003980660 100644 --- a/data_safe_haven/infrastructure/programs/imperative_shm.py +++ b/data_safe_haven/infrastructure/programs/imperative_shm.py @@ -117,7 +117,7 @@ def deploy(self) -> None: # Add the SHM domain to the Entra ID via interactive GraphAPI try: # Generate the verification record - graph_api = GraphApi( + graph_api = GraphApi.from_scopes( scopes=[ "Application.ReadWrite.All", "Domain.ReadWrite.All", From 367115d4c57e8cbe0cb059e00b3b4c4d74c7107c Mon Sep 17 00:00:00 2001 From: James Robinson Date: Mon, 8 Jul 2024 12:24:19 +0100 Subject: [PATCH 19/54] :alien: Hack GET responses to include all values from paged content --- data_safe_haven/external/api/graph_api.py | 35 ++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/data_safe_haven/external/api/graph_api.py b/data_safe_haven/external/api/graph_api.py index 017d63f3f5..07f462ed17 100644 --- a/data_safe_haven/external/api/graph_api.py +++ b/data_safe_haven/external/api/graph_api.py @@ -1,6 +1,7 @@ """Interface to the Microsoft Graph API""" import datetime +import json import time from collections.abc import Sequence from contextlib import suppress @@ -691,7 +692,7 @@ def http_delete(self, url: str, **kwargs: Any) -> requests.Response: msg = f"Could not execute DELETE request to '{url}'. Response content received: '{response.content.decode()}'." raise DataSafeHavenMicrosoftGraphError(msg) - def http_get(self, url: str, **kwargs: Any) -> requests.Response: + def http_get_single_page(self, url: str, **kwargs: Any) -> requests.Response: """Make an HTTP GET request Returns: @@ -718,6 +719,38 @@ def http_get(self, url: str, **kwargs: Any) -> requests.Response: msg = f"Could not execute GET request to '{url}'. Response content received: '{response.content.decode()}'. Token {self.token}" raise DataSafeHavenMicrosoftGraphError(msg) + def http_get(self, url: str, **kwargs: Any) -> requests.Response: + """Make a paged HTTP GET request and return all values + + Returns: + requests.Response: The response from the remote server, with all values combined + + Raises: + DataSafeHavenMicrosoftGraphError if the request failed + """ + try: + base_url = url + values = [] + + # Keep requesting new pages until there are no more + while True: + response = self.http_get_single_page(url, **kwargs) + values += response.json()["value"] + url = response.json().get("@odata.nextLink", None) + if not url: + break + + # Add previous response values into the content bytes + json_content = response.json() + json_content["value"] = values + response._content = json.dumps(json_content).encode("utf-8") + + # Return the full response + return response + except requests.exceptions.RequestException as exc: + msg = f"Could not execute GET request to '{base_url}'. Response content received: '{response.content.decode()}'. Token {self.token}" + raise DataSafeHavenMicrosoftGraphError(msg) from exc + def http_patch(self, url: str, **kwargs: Any) -> requests.Response: """Make an HTTP PATCH request From 08a01ad7f6bdc26229853568739220a28cd38c9e Mon Sep 17 00:00:00 2001 From: James Robinson Date: Mon, 8 Jul 2024 12:26:54 +0100 Subject: [PATCH 20/54] :sparkles: Get Azure group name from AzureApi (via GraphApi) --- data_safe_haven/config/shm_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/data_safe_haven/config/shm_config.py b/data_safe_haven/config/shm_config.py index 44339deb57..8df5834511 100644 --- a/data_safe_haven/config/shm_config.py +++ b/data_safe_haven/config/shm_config.py @@ -4,7 +4,7 @@ from typing import ClassVar, Self -from data_safe_haven.external import AzureApi, AzureCliSingleton +from data_safe_haven.external import AzureApi from data_safe_haven.serialisers import AzureSerialisableModel, ContextBase from .config_sections import ConfigSectionAzure, ConfigSectionSHM @@ -27,7 +27,7 @@ def from_args( ) -> SHMConfig: """Construct an SHMConfig from arguments.""" azure_api = AzureApi(subscription_name=context.subscription_name) - admin_group_id = AzureCliSingleton().group_id_from_name( + admin_group_id = azure_api.entra_directory.get_id_from_groupname( context.admin_group_name ) return SHMConfig.model_construct( From e6aae728271f1461455d86660d1cd602b7746494 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Mon, 8 Jul 2024 12:53:51 +0100 Subject: [PATCH 21/54] :truck: Run Azure account verification in AzureApiCredential --- .../external/interface/credentials.py | 29 +++++++++++++++++-- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/data_safe_haven/external/interface/credentials.py b/data_safe_haven/external/interface/credentials.py index 7f2b7ed92d..61eeefdc56 100644 --- a/data_safe_haven/external/interface/credentials.py +++ b/data_safe_haven/external/interface/credentials.py @@ -6,15 +6,17 @@ from datetime import UTC, datetime from typing import Any +import jwt from azure.core.credentials import AccessToken, TokenCredential from azure.identity import ( AuthenticationRecord, AzureCliCredential, + CredentialUnavailableError, DeviceCodeCredential, TokenCachePersistenceOptions, ) -from data_safe_haven.external.api.azure_cli import AzureCliSingleton +from data_safe_haven.exceptions import DataSafeHavenAzureError from data_safe_haven.logging import get_logger @@ -71,8 +73,29 @@ def __init__(self, scope: str = "https://management.azure.com/.default") -> None def get_credential(self) -> TokenCredential: """Get a new AzureCliCredential.""" - AzureCliSingleton().confirm() # get user confirmation of the current account - return AzureCliCredential(additionally_allowed_tenants=["*"]) + credential = AzureCliCredential(additionally_allowed_tenants=["*"]) + # Check that we are logged into Azure + try: + token = credential.get_token(*self.scopes).token + decoded = jwt.decode( + token, algorithms=["RS256"], options={"verify_signature": False} + ) + self.logger.info( + "You are currently logged into the [blue]Azure CLI[/] with the following details:" + ) + self.logger.info( + f"... user: [green]{decoded['name']}[/] ({decoded['oid']})" + ) + self.logger.info( + f"... tenant: [green]{decoded['upn'].split('@')[1]}[/] ({decoded['tid']})" + ) + except CredentialUnavailableError as exc: + self.logger.error( + "Please authenticate with Azure: run '[green]az login[/]' using [bold]infrastructure administrator[/] credentials." + ) + msg = "Error getting account information from Azure CLI." + raise DataSafeHavenAzureError(msg) from exc + return credential class GraphApiCredential(DeferredCredential): From 6acd336a862acfd26865a851d3a28bb1b92b7260 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Mon, 8 Jul 2024 12:55:19 +0100 Subject: [PATCH 22/54] :coffin: Drop unused AzureCli class --- data_safe_haven/external/__init__.py | 2 - data_safe_haven/external/api/azure_cli.py | 139 ---------------------- 2 files changed, 141 deletions(-) delete mode 100644 data_safe_haven/external/api/azure_cli.py diff --git a/data_safe_haven/external/__init__.py b/data_safe_haven/external/__init__.py index 9dc3780e64..f0eb5a42fb 100644 --- a/data_safe_haven/external/__init__.py +++ b/data_safe_haven/external/__init__.py @@ -1,5 +1,4 @@ from .api.azure_api import AzureApi -from .api.azure_cli import AzureCliSingleton from .api.graph_api import GraphApi from .interface.azure_container_instance import AzureContainerInstance from .interface.azure_ipv4_range import AzureIPv4Range @@ -8,7 +7,6 @@ __all__ = [ "AzureApi", - "AzureCliSingleton", "AzureContainerInstance", "AzureIPv4Range", "AzurePostgreSQLDatabase", diff --git a/data_safe_haven/external/api/azure_cli.py b/data_safe_haven/external/api/azure_cli.py deleted file mode 100644 index 135a41752a..0000000000 --- a/data_safe_haven/external/api/azure_cli.py +++ /dev/null @@ -1,139 +0,0 @@ -"""Interface to the Azure CLI""" - -import json -import subprocess -from dataclasses import dataclass -from shutil import which - -import typer - -from data_safe_haven import console -from data_safe_haven.exceptions import DataSafeHavenAzureError -from data_safe_haven.logging import get_logger -from data_safe_haven.singleton import Singleton - - -@dataclass -class AzureCliAccount: - """Dataclass for Azure CLI Account details""" - - name: str - id_: str - tenant_id: str - tenant_name: str - - -class AzureCliSingleton(metaclass=Singleton): - """Interface to the Azure CLI""" - - def __init__(self) -> None: - self.logger = get_logger() - - path = which("az") - if path is None: - msg = "Unable to find Azure CLI executable in your path.\nPlease ensure that Azure CLI is installed" - raise DataSafeHavenAzureError(msg) - self.path = path - - self._account: AzureCliAccount | None = None - self._confirmed = False - - @property - def account(self) -> AzureCliAccount: - if not self._account: - try: - result = subprocess.check_output( - [self.path, "account", "show", "--output", "json"], - stderr=subprocess.PIPE, - encoding="utf8", - ) - result_dict = json.loads(result) - except subprocess.CalledProcessError as exc: - self.logger.error(str(exc.stderr).replace("ERROR:", "").strip()) - msg = "Error getting account information from Azure CLI." - raise DataSafeHavenAzureError(msg) from exc - except json.JSONDecodeError as exc: - msg = f"Unable to parse Azure CLI output as JSON.\n{result}" - raise DataSafeHavenAzureError(msg) from exc - - self._account = AzureCliAccount( - name=result_dict.get("user").get("name"), - id_=result_dict.get("id"), - tenant_id=result_dict.get("tenantId"), - tenant_name=result_dict.get("tenantDisplayName"), - ) - - return self._account - - def confirm(self) -> None: - """Prompt user to confirm the Azure CLI account is correct""" - if self._confirmed: - return None - - account = self.account - self.logger.info( - "You are currently logged into the [blue]Azure CLI[/] with the following details:" - ) - self.logger.info(f"... user: [green]{account.name}[/] ({account.id_})") - self.logger.info( - f"... tenant: [green]{account.tenant_name}[/] ({account.tenant_id})" - ) - if not console.confirm( - "Is this the Azure account you expect?", default_to_yes=False - ): - self.logger.error( - "Please use `az login` to connect to the correct Azure CLI account" - ) - raise typer.Exit(1) - - self._confirmed = True - - def group_id_from_name(self, group_name: str) -> str: - """Get ID for an Entra ID group that this user is permitted to view.""" - try: - result = subprocess.check_output( - [self.path, "ad", "group", "list", "--display-name", group_name], - stderr=subprocess.PIPE, - encoding="utf8", - ) - except subprocess.CalledProcessError as exc: - msg = f"Error reading groups from Azure CLI.\n{exc.stderr}" - raise DataSafeHavenAzureError(msg) from exc - - try: - result_dict = json.loads(result) - return str(result_dict[0]["id"]) - except json.JSONDecodeError as exc: - msg = f"Unable to parse Azure CLI output as JSON.\n{result}" - raise DataSafeHavenAzureError(msg) from exc - except (IndexError, KeyError) as exc: - msg = f"Group '{group_name}' was not found in Azure CLI." - raise DataSafeHavenAzureError(msg) from exc - - def subscription_id(self, subscription_name: str) -> str: - """Get subscription ID from an Azure subscription name.""" - try: - result = subprocess.check_output( - [ - self.path, - "account", - "subscription", - "list", - "--query", - f"[?displayName == '{subscription_name}']", - ], - stderr=subprocess.PIPE, - encoding="utf8", - ) - result_dict = json.loads(result) - return str(result_dict[0]["subscriptionId"]) - except subprocess.CalledProcessError as exc: - self.logger.critical(exc.stderr) - msg = "Error reading subscriptions from Azure CLI." - raise DataSafeHavenAzureError(msg) from exc - except json.JSONDecodeError as exc: - msg = "Unable to parse Azure CLI output as JSON." - raise DataSafeHavenAzureError(msg) from exc - except (IndexError, KeyError) as exc: - msg = f"Subscription '{subscription_name}' was not found in Azure CLI." - raise DataSafeHavenAzureError(msg) from exc From 73137b41e94a4b36618eb80f39c53db5682c7a73 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Mon, 8 Jul 2024 12:58:16 +0100 Subject: [PATCH 23/54] :truck: Move credentials into api submodule --- data_safe_haven/external/api/azure_api.py | 2 +- data_safe_haven/external/{interface => api}/credentials.py | 0 data_safe_haven/external/api/graph_api.py | 7 ++++--- 3 files changed, 5 insertions(+), 4 deletions(-) rename data_safe_haven/external/{interface => api}/credentials.py (100%) diff --git a/data_safe_haven/external/api/azure_api.py b/data_safe_haven/external/api/azure_api.py index 1d752e80b4..ab1908184b 100644 --- a/data_safe_haven/external/api/azure_api.py +++ b/data_safe_haven/external/api/azure_api.py @@ -68,9 +68,9 @@ DataSafeHavenAzureError, DataSafeHavenValueError, ) -from data_safe_haven.external.interface.credentials import AzureApiCredential from data_safe_haven.logging import get_logger +from .credentials import AzureApiCredential from .graph_api import GraphApi diff --git a/data_safe_haven/external/interface/credentials.py b/data_safe_haven/external/api/credentials.py similarity index 100% rename from data_safe_haven/external/interface/credentials.py rename to data_safe_haven/external/api/credentials.py diff --git a/data_safe_haven/external/api/graph_api.py b/data_safe_haven/external/api/graph_api.py index 07f462ed17..eeb77558e0 100644 --- a/data_safe_haven/external/api/graph_api.py +++ b/data_safe_haven/external/api/graph_api.py @@ -17,12 +17,13 @@ DataSafeHavenMicrosoftGraphError, DataSafeHavenValueError, ) -from data_safe_haven.external.interface.credentials import ( +from data_safe_haven.functions import alphanumeric +from data_safe_haven.logging import get_logger + +from .credentials import ( DeferredCredential, GraphApiCredential, ) -from data_safe_haven.functions import alphanumeric -from data_safe_haven.logging import get_logger class GraphApi: From 87db46ae223174fc8289a5902d12305e67b5fa99 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Mon, 8 Jul 2024 15:22:46 +0100 Subject: [PATCH 24/54] :white_check_mark: Fix broken tests --- tests/commands/test_pulumi.py | 27 ++++----- tests/commands/test_shm.py | 13 +++-- tests/commands/test_sre.py | 14 ++--- tests/conftest.py | 58 +++++++++++++------- tests/infrastructure/test_project_manager.py | 6 -- 5 files changed, 63 insertions(+), 55 deletions(-) diff --git a/tests/commands/test_pulumi.py b/tests/commands/test_pulumi.py index 8182ec501c..33b727f3c6 100644 --- a/tests/commands/test_pulumi.py +++ b/tests/commands/test_pulumi.py @@ -5,15 +5,14 @@ class TestRun: def test_run_sre( self, runner, - mock_shm_config_from_remote, # noqa: ARG002 - mock_sre_config_from_remote, # noqa: ARG002 - mock_pulumi_config_no_key_from_remote, # noqa: ARG002 + local_project_settings, # noqa: ARG002 mock_graph_api_token, # noqa: ARG002 - mock_azure_cli_confirm, # noqa: ARG002 mock_install_plugins, # noqa: ARG002 mock_key_vault_key, # noqa: ARG002 + mock_pulumi_config_no_key_from_remote, # noqa: ARG002 + mock_shm_config_from_remote, # noqa: ARG002 + mock_sre_config_from_remote, # noqa: ARG002 offline_pulumi_account, # noqa: ARG002 - local_project_settings, # noqa: ARG002 ): result = runner.invoke(pulumi_command_group, ["sandbox", "stack ls"]) assert result.exit_code == 0 @@ -30,15 +29,14 @@ def test_run_sre_incorrect_arguments( def test_run_sre_invalid_command( self, runner, - mock_shm_config_from_remote, # noqa: ARG002 - mock_sre_config_from_remote, # noqa: ARG002 - mock_pulumi_config_no_key_from_remote, # noqa: ARG002 + local_project_settings, # noqa: ARG002 mock_graph_api_token, # noqa: ARG002 - mock_azure_cli_confirm, # noqa: ARG002 mock_install_plugins, # noqa: ARG002 mock_key_vault_key, # noqa: ARG002 + mock_pulumi_config_no_key_from_remote, # noqa: ARG002 + mock_shm_config_from_remote, # noqa: ARG002 + mock_sre_config_from_remote, # noqa: ARG002 offline_pulumi_account, # noqa: ARG002 - local_project_settings, # noqa: ARG002 ): result = runner.invoke( pulumi_command_group, ["sandbox", "not a pulumi command"] @@ -49,15 +47,14 @@ def test_run_sre_invalid_command( def test_run_sre_invalid_name( self, runner, - mock_shm_config_from_remote, # noqa: ARG002 - mock_sre_config_alternate_from_remote, # noqa: ARG002 - mock_pulumi_config_no_key_from_remote, # noqa: ARG002 + local_project_settings, # noqa: ARG002 mock_graph_api_token, # noqa: ARG002 - mock_azure_cli_confirm, # noqa: ARG002 mock_install_plugins, # noqa: ARG002 mock_key_vault_key, # noqa: ARG002 + mock_pulumi_config_no_key_from_remote, # noqa: ARG002 + mock_shm_config_from_remote, # noqa: ARG002 + mock_sre_config_alternate_from_remote, # noqa: ARG002 offline_pulumi_account, # noqa: ARG002 - local_project_settings, # noqa: ARG002 ): result = runner.invoke(pulumi_command_group, ["alternate", "stack ls"]) assert result.exit_code == 1 diff --git a/tests/commands/test_shm.py b/tests/commands/test_shm.py index e426b43f5e..eac80346d4 100644 --- a/tests/commands/test_shm.py +++ b/tests/commands/test_shm.py @@ -30,12 +30,12 @@ def test_infrastructure_show_none(self, runner_none): def test_infrastructure_auth_failure( self, runner, - mock_azure_cli_confirm_then_exit, # noqa: ARG002 + mock_azureapicredential_get_credential_failure, # noqa: ARG002 ): result = runner.invoke(shm_command_group, ["deploy"]) assert result.exit_code == 1 - assert "mock login" in result.stdout - assert "mock login error" in result.stdout + assert "mock get_credential\n" in result.stdout + assert "mock get_credential error" in result.stdout class TestTeardownSHM: @@ -62,9 +62,10 @@ def test_show_none(self, runner_none): def test_auth_failure( self, runner, - mock_azure_cli_confirm_then_exit, # noqa: ARG002 + mock_azureapicredential_get_credential_failure, # noqa: ARG002 ): result = runner.invoke(shm_command_group, ["teardown"]) assert result.exit_code == 1 - assert "mock login" in result.stdout - assert "mock login error" in result.stdout + assert "mock get_credential\n" in result.stdout + assert "mock get_credential error" in result.stdout + assert "Could not teardown Safe Haven Management environment." in result.stdout diff --git a/tests/commands/test_sre.py b/tests/commands/test_sre.py index 4b1bc887da..ccd0c67700 100644 --- a/tests/commands/test_sre.py +++ b/tests/commands/test_sre.py @@ -5,7 +5,6 @@ class TestDeploySRE: def test_deploy( self, runner, - mock_azure_cli_confirm, # noqa: ARG002 mock_graph_api_token, # noqa: ARG002 mock_ip_1_2_3_4, # noqa: ARG002 mock_pulumi_config_from_remote_or_create, # noqa: ARG002 @@ -27,12 +26,12 @@ def test_no_context_file(self, runner_no_context_file): def test_auth_failure( self, runner, - mock_azure_cli_confirm_then_exit, # noqa: ARG002 + mock_azureapicredential_get_credential_failure, # noqa: ARG002 ): result = runner.invoke(sre_command_group, ["deploy", "sandbox"]) assert result.exit_code == 1 - assert "mock login" in result.stdout - assert "mock login error" in result.stdout + assert "mock get_credential\n" in result.stdout + assert "mock get_credential error" in result.stdout def test_no_shm( self, @@ -50,7 +49,6 @@ class TestTeardownSRE: def test_teardown( self, runner, - mock_azure_cli_confirm, # noqa: ARG002 mock_graph_api_token, # noqa: ARG002 mock_ip_1_2_3_4, # noqa: ARG002 mock_pulumi_config_from_remote, # noqa: ARG002 @@ -83,9 +81,9 @@ def test_no_shm( def test_auth_failure( self, runner, - mock_azure_cli_confirm_then_exit, # noqa: ARG002 + mock_azureapicredential_get_credential_failure, # noqa: ARG002 ): result = runner.invoke(sre_command_group, ["teardown", "sandbox"]) assert result.exit_code == 1 - assert "mock login" in result.stdout - assert "mock login error" in result.stdout + assert "mock get_credential\n" in result.stdout + assert "mock get_credential error" in result.stdout diff --git a/tests/conftest.py b/tests/conftest.py index f7fd8234e5..e424e7c9f2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,8 @@ from subprocess import run import yaml +from azure.core.credentials import AccessToken, TokenCredential +from azure.mgmt.resource.subscriptions.models import Subscription from pulumi.automation import ProjectSettings from pytest import fixture @@ -23,8 +25,9 @@ ConfigSectionSRE, ConfigSubsectionRemoteDesktopOpts, ) -from data_safe_haven.exceptions import DataSafeHavenAzureAPIAuthenticationError -from data_safe_haven.external import AzureApi, AzureCliSingleton, PulumiAccount +from data_safe_haven.exceptions import DataSafeHavenAzureError +from data_safe_haven.external import AzureApi, PulumiAccount +from data_safe_haven.external.api.credentials import AzureApiCredential from data_safe_haven.infrastructure import SREProjectManager from data_safe_haven.infrastructure.project_manager import ProjectManager from data_safe_haven.logging import init_logging @@ -125,26 +128,42 @@ def log_directory(session_mocker, tmp_path_factory): @fixture -def mock_azure_cli_confirm(mocker): - """Always pass AzureCliSingleton.confirm without attempting login""" +def mock_azureapi_get_subscription(mocker): + subscription = Subscription() + subscription.display_name = "Data Safe Haven Acme" + subscription.subscription_id = "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd" + subscription.tenant_id = "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd" mocker.patch.object( - AzureCliSingleton, - "confirm", - return_value=None, + AzureApi, + "get_subscription", + return_value=subscription, ) @fixture -def mock_azure_cli_confirm_then_exit(mocker): - def confirm_then_exit(): - print("mock login") # noqa: T201 - msg = "mock login error" - raise DataSafeHavenAzureAPIAuthenticationError(msg) +def mock_azureapicredential_get_credential(mocker): + class MockCredential(TokenCredential): + def get_token(*args, **kwargs): # noqa: ARG002 + return AccessToken("dummy-token", 0) mocker.patch.object( - AzureCliSingleton, - "confirm", - side_effect=confirm_then_exit, + AzureApiCredential, + "get_credential", + return_value=MockCredential(), + ) + + +@fixture +def mock_azureapicredential_get_credential_failure(mocker): + def fail_get_credential(): + print("mock get_credential") # noqa: T201 + msg = "mock get_credential error" + raise DataSafeHavenAzureError(msg) + + mocker.patch.object( + AzureApiCredential, + "get_credential", + side_effect=fail_get_credential, ) @@ -168,7 +187,7 @@ def mock_get_keyvault_key(self, key_name, key_vault_name): # noqa: ARG001 @fixture -def offline_pulumi_account(monkeypatch, mock_azure_cli_confirm): # noqa: ARG001 +def offline_pulumi_account(monkeypatch): """Overwrite PulumiAccount so that it runs locally""" monkeypatch.setattr( PulumiAccount, "env", {"PULUMI_CONFIG_PASSPHRASE": "passphrase"} @@ -407,11 +426,10 @@ def sre_project_manager( context_no_secrets, sre_config, pulumi_config_no_key, - mock_azure_cli_confirm, # noqa: ARG001 - mock_install_plugins, # noqa: ARG001 - mock_key_vault_key, # noqa: ARG001 - offline_pulumi_account, # noqa: ARG001 local_project_settings, # noqa: ARG001 + mock_azureapi_get_subscription, # noqa: ARG001 + mock_azureapicredential_get_credential, # noqa: ARG001 + offline_pulumi_account, # noqa: ARG001 ): return SREProjectManager( context=context_no_secrets, diff --git a/tests/infrastructure/test_project_manager.py b/tests/infrastructure/test_project_manager.py index 5d7d7926bd..806e6246ce 100644 --- a/tests/infrastructure/test_project_manager.py +++ b/tests/infrastructure/test_project_manager.py @@ -22,8 +22,6 @@ def test_constructor( sre_config, pulumi_config_no_key, pulumi_project_sandbox, - mock_azure_cli_confirm, # noqa: ARG002 - mock_install_plugins, # noqa: ARG002 ): sre = SREProjectManager( context_no_secrets, @@ -40,8 +38,6 @@ def test_new_project( context_no_secrets, sre_config, pulumi_config_empty, - mock_azure_cli_confirm, # noqa: ARG002 - mock_install_plugins, # noqa: ARG002 ): sre = SREProjectManager( context_no_secrets, @@ -63,8 +59,6 @@ def test_new_project_fail( context_no_secrets, sre_config, pulumi_config_empty, - mock_azure_cli_confirm, # noqa: ARG002 - mock_install_plugins, # noqa: ARG002 ): sre = SREProjectManager( context_no_secrets, sre_config, pulumi_config_empty, create_project=False From d60ccd29e82db03d75f591ba7cf789e0e55027aa Mon Sep 17 00:00:00 2001 From: James Robinson Date: Mon, 8 Jul 2024 15:30:58 +0100 Subject: [PATCH 25/54] :wrench: Ignore the resources folder when looking for tests --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f1f7779307..d3d95c0086 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -167,7 +167,7 @@ addopts = [ "-vvv", "--import-mode=importlib", "--disable-warnings", - "--ignore=tests/*", + "--ignore=data_safe_haven/resources/*", ] [tool.ruff.lint] From 5b52f7f9ce41084ff5d011236ec36f665168f4d1 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Mon, 8 Jul 2024 15:41:30 +0100 Subject: [PATCH 26/54] :truck: Rename AzureApi to AzureSdk --- .../administration/users/guacamole_users.py | 6 ++-- data_safe_haven/config/context.py | 6 ++-- data_safe_haven/config/shm_config.py | 10 +++--- data_safe_haven/external/__init__.py | 4 +-- .../api/{azure_api.py => azure_sdk.py} | 10 +++--- data_safe_haven/external/api/credentials.py | 4 +-- .../interface/azure_container_instance.py | 10 +++--- .../interface/azure_postgresql_database.py | 6 ++-- .../external/interface/pulumi_account.py | 6 ++-- .../components/dynamic/blob_container_acl.py | 10 +++--- .../components/dynamic/file_upload.py | 10 +++--- .../components/dynamic/ssl_certificate.py | 18 +++++----- .../infrastructure/programs/imperative_shm.py | 36 +++++++++---------- .../infrastructure/project_manager.py | 6 ++-- .../provisioning/sre_provisioning_manager.py | 10 +++--- .../serialisers/azure_serialisable_model.py | 14 ++++---- tests/commands/conftest.py | 6 ++-- tests/commands/test_config_shm.py | 6 ++-- tests/commands/test_config_sre.py | 16 ++++----- tests/config/test_pulumi.py | 12 +++---- tests/config/test_shm_config.py | 6 ++-- tests/config/test_sre_config.py | 6 ++-- tests/conftest.py | 12 +++---- tests/external/api/test_azure_api.py | 18 +++++----- .../test_azure_serialisable_model.py | 10 +++--- 25 files changed, 129 insertions(+), 129 deletions(-) rename data_safe_haven/external/api/{azure_api.py => azure_sdk.py} (99%) diff --git a/data_safe_haven/administration/users/guacamole_users.py b/data_safe_haven/administration/users/guacamole_users.py index a2c35754be..8c0c5381a0 100644 --- a/data_safe_haven/administration/users/guacamole_users.py +++ b/data_safe_haven/administration/users/guacamole_users.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from data_safe_haven.config import Context, DSHPulumiConfig, SREConfig -from data_safe_haven.external import AzureApi, AzurePostgreSQLDatabase +from data_safe_haven.external import AzurePostgreSQLDatabase, AzureSdk from data_safe_haven.infrastructure import SREProjectManager from .research_user import ResearchUser @@ -23,8 +23,8 @@ def __init__( pulumi_config=pulumi_config, ) # Read the SRE database secret from key vault - azure_api = AzureApi(context.subscription_name) - connection_db_server_password = azure_api.get_keyvault_secret( + azure_sdk = AzureSdk(context.subscription_name) + connection_db_server_password = azure_sdk.get_keyvault_secret( sre_stack.output("data")["key_vault_name"], sre_stack.output("data")["password_user_database_admin_secret"], ) diff --git a/data_safe_haven/config/context.py b/data_safe_haven/config/context.py index b7e5b0cdd6..fe9d52e06a 100644 --- a/data_safe_haven/config/context.py +++ b/data_safe_haven/config/context.py @@ -9,7 +9,7 @@ from data_safe_haven import __version__ from data_safe_haven.directories import config_dir -from data_safe_haven.external import AzureApi +from data_safe_haven.external import AzureSdk from data_safe_haven.serialisers import ContextBase from data_safe_haven.types import ( AzureSubscriptionName, @@ -67,8 +67,8 @@ def pulumi_backend_url(self) -> str: @property def pulumi_encryption_key(self) -> KeyVaultKey: if not self._pulumi_encryption_key: - azure_api = AzureApi(subscription_name=self.subscription_name) - self._pulumi_encryption_key = azure_api.get_keyvault_key( + azure_sdk = AzureSdk(subscription_name=self.subscription_name) + self._pulumi_encryption_key = azure_sdk.get_keyvault_key( key_name=self.pulumi_encryption_key_name, key_vault_name=self.key_vault_name, ) diff --git a/data_safe_haven/config/shm_config.py b/data_safe_haven/config/shm_config.py index 8df5834511..1fa3410f36 100644 --- a/data_safe_haven/config/shm_config.py +++ b/data_safe_haven/config/shm_config.py @@ -4,7 +4,7 @@ from typing import ClassVar, Self -from data_safe_haven.external import AzureApi +from data_safe_haven.external import AzureSdk from data_safe_haven.serialisers import AzureSerialisableModel, ContextBase from .config_sections import ConfigSectionAzure, ConfigSectionSHM @@ -26,15 +26,15 @@ def from_args( location: str, ) -> SHMConfig: """Construct an SHMConfig from arguments.""" - azure_api = AzureApi(subscription_name=context.subscription_name) - admin_group_id = azure_api.entra_directory.get_id_from_groupname( + azure_sdk = AzureSdk(subscription_name=context.subscription_name) + admin_group_id = azure_sdk.entra_directory.get_id_from_groupname( context.admin_group_name ) return SHMConfig.model_construct( azure=ConfigSectionAzure.model_construct( location=location, - subscription_id=azure_api.subscription_id, - tenant_id=azure_api.tenant_id, + subscription_id=azure_sdk.subscription_id, + tenant_id=azure_sdk.tenant_id, ), shm=ConfigSectionSHM.model_construct( admin_group_id=admin_group_id, diff --git a/data_safe_haven/external/__init__.py b/data_safe_haven/external/__init__.py index f0eb5a42fb..5e46325958 100644 --- a/data_safe_haven/external/__init__.py +++ b/data_safe_haven/external/__init__.py @@ -1,4 +1,4 @@ -from .api.azure_api import AzureApi +from .api.azure_sdk import AzureSdk from .api.graph_api import GraphApi from .interface.azure_container_instance import AzureContainerInstance from .interface.azure_ipv4_range import AzureIPv4Range @@ -6,7 +6,7 @@ from .interface.pulumi_account import PulumiAccount __all__ = [ - "AzureApi", + "AzureSdk", "AzureContainerInstance", "AzureIPv4Range", "AzurePostgreSQLDatabase", diff --git a/data_safe_haven/external/api/azure_api.py b/data_safe_haven/external/api/azure_sdk.py similarity index 99% rename from data_safe_haven/external/api/azure_api.py rename to data_safe_haven/external/api/azure_sdk.py index ab1908184b..f6ecfb1d7e 100644 --- a/data_safe_haven/external/api/azure_api.py +++ b/data_safe_haven/external/api/azure_sdk.py @@ -70,16 +70,16 @@ ) from data_safe_haven.logging import get_logger -from .credentials import AzureApiCredential +from .credentials import AzureSdkCredential from .graph_api import GraphApi -class AzureApi: - """Interface to the Azure REST API""" +class AzureSdk: + """Interface to the Azure Python SDK""" def __init__(self, subscription_name: str) -> None: self.logger = get_logger() - self.credential = AzureApiCredential() + self.credential = AzureSdkCredential() self.subscription_name = subscription_name self.subscription_id_: str | None = None self.tenant_id_: str | None = None @@ -87,7 +87,7 @@ def __init__(self, subscription_name: str) -> None: @property def entra_directory(self) -> GraphApi: return GraphApi( - credential=AzureApiCredential("https://graph.microsoft.com//.default"), + credential=AzureSdkCredential("https://graph.microsoft.com//.default"), ) @property diff --git a/data_safe_haven/external/api/credentials.py b/data_safe_haven/external/api/credentials.py index 61eeefdc56..a891560e81 100644 --- a/data_safe_haven/external/api/credentials.py +++ b/data_safe_haven/external/api/credentials.py @@ -61,9 +61,9 @@ def get_token( return DeferredCredential.token_ -class AzureApiCredential(DeferredCredential): +class AzureSdkCredential(DeferredCredential): """ - Credential loader used by AzureApi + Credential loader used by AzureSdk Uses AzureCliCredential for authentication """ diff --git a/data_safe_haven/external/interface/azure_container_instance.py b/data_safe_haven/external/interface/azure_container_instance.py index b2e5cbdafa..0d214ba8ee 100644 --- a/data_safe_haven/external/interface/azure_container_instance.py +++ b/data_safe_haven/external/interface/azure_container_instance.py @@ -10,7 +10,7 @@ ) from data_safe_haven.exceptions import DataSafeHavenAzureError -from data_safe_haven.external import AzureApi +from data_safe_haven.external import AzureSdk from data_safe_haven.logging import get_logger @@ -23,7 +23,7 @@ def __init__( resource_group_name: str, subscription_name: str, ): - self.azure_api = AzureApi(subscription_name) + self.azure_sdk = AzureSdk(subscription_name) self.logger = get_logger() self.resource_group_name = resource_group_name self.container_group_name = container_group_name @@ -36,7 +36,7 @@ def wait(poller: LROPoller[None]) -> None: @property def current_ip_address(self) -> str: aci_client = ContainerInstanceManagementClient( - self.azure_api.credential, self.azure_api.subscription_id + self.azure_sdk.credential, self.azure_sdk.subscription_id ) ip_address = aci_client.container_groups.get( self.resource_group_name, self.container_group_name @@ -51,7 +51,7 @@ def restart(self, target_ip_address: str | None = None) -> None: # Connect to Azure clients try: aci_client = ContainerInstanceManagementClient( - self.azure_api.credential, self.azure_api.subscription_id + self.azure_sdk.credential, self.azure_sdk.subscription_id ) if not target_ip_address: target_ip_address = self.current_ip_address @@ -98,7 +98,7 @@ def run_executable(self, container_name: str, executable_path: str) -> list[str] """ # Connect to Azure clients aci_client = ContainerInstanceManagementClient( - self.azure_api.credential, self.azure_api.subscription_id + self.azure_sdk.credential, self.azure_sdk.subscription_id ) # Run command diff --git a/data_safe_haven/external/interface/azure_postgresql_database.py b/data_safe_haven/external/interface/azure_postgresql_database.py index 2185c142bf..ccd37f2ba4 100644 --- a/data_safe_haven/external/interface/azure_postgresql_database.py +++ b/data_safe_haven/external/interface/azure_postgresql_database.py @@ -16,7 +16,7 @@ DataSafeHavenAzureError, DataSafeHavenValueError, ) -from data_safe_haven.external import AzureApi +from data_safe_haven.external import AzureSdk from data_safe_haven.functions import current_ip_address from data_safe_haven.logging import get_logger from data_safe_haven.types import PathType @@ -43,7 +43,7 @@ def __init__( resource_group_name: str, subscription_name: str, ) -> None: - self.azure_api = AzureApi(subscription_name) + self.azure_sdk = AzureSdk(subscription_name) self.current_ip = current_ip_address() self.db_client_ = None self.db_name = database_name @@ -81,7 +81,7 @@ def db_client(self) -> PostgreSQLManagementClient: """Get the database client.""" if not self.db_client_: self.db_client_ = PostgreSQLManagementClient( - self.azure_api.credential, self.azure_api.subscription_id + self.azure_sdk.credential, self.azure_sdk.subscription_id ) return self.db_client_ diff --git a/data_safe_haven/external/interface/pulumi_account.py b/data_safe_haven/external/interface/pulumi_account.py index 8f7fc55c90..4db66c52c8 100644 --- a/data_safe_haven/external/interface/pulumi_account.py +++ b/data_safe_haven/external/interface/pulumi_account.py @@ -4,7 +4,7 @@ from typing import Any from data_safe_haven.exceptions import DataSafeHavenPulumiError -from data_safe_haven.external import AzureApi +from data_safe_haven.external import AzureSdk class PulumiAccount: @@ -30,8 +30,8 @@ def __init__( def env(self) -> dict[str, Any]: """Get necessary Pulumi environment variables""" if not self._env: - azure_api = AzureApi(self.subscription_name) - storage_account_keys = azure_api.get_storage_account_keys( + azure_sdk = AzureSdk(self.subscription_name) + storage_account_keys = azure_sdk.get_storage_account_keys( self.resource_group_name, self.storage_account_name, ) diff --git a/data_safe_haven/infrastructure/components/dynamic/blob_container_acl.py b/data_safe_haven/infrastructure/components/dynamic/blob_container_acl.py index 8d3e3e8d80..b4317d575b 100644 --- a/data_safe_haven/infrastructure/components/dynamic/blob_container_acl.py +++ b/data_safe_haven/infrastructure/components/dynamic/blob_container_acl.py @@ -6,7 +6,7 @@ from pulumi.dynamic import CreateResult, DiffResult, Resource from data_safe_haven.exceptions import DataSafeHavenPulumiError -from data_safe_haven.external import AzureApi +from data_safe_haven.external import AzureSdk from .dsh_resource_provider import DshResourceProvider @@ -55,8 +55,8 @@ def create(self, props: dict[str, Any]) -> CreateResult: """Set ACLs for a given blob container.""" outs = dict(**props) try: - azure_api = AzureApi(props["subscription_name"]) - azure_api.set_blob_container_acl( + azure_sdk = AzureSdk(props["subscription_name"]) + azure_sdk.set_blob_container_acl( container_name=props["container_name"], desired_acl=props["desired_acl"], resource_group_name=props["resource_group_name"], @@ -75,8 +75,8 @@ def delete(self, id_: str, props: dict[str, Any]) -> None: # Use `id` as a no-op to avoid ARG002 while maintaining function signature id(id_) try: - azure_api = AzureApi(props["subscription_name"]) - azure_api.set_blob_container_acl( + azure_sdk = AzureSdk(props["subscription_name"]) + azure_sdk.set_blob_container_acl( container_name=props["container_name"], desired_acl="user::rwx,group::r-x,other::---", resource_group_name=props["resource_group_name"], diff --git a/data_safe_haven/infrastructure/components/dynamic/file_upload.py b/data_safe_haven/infrastructure/components/dynamic/file_upload.py index f6fa242ef7..bf547da498 100644 --- a/data_safe_haven/infrastructure/components/dynamic/file_upload.py +++ b/data_safe_haven/infrastructure/components/dynamic/file_upload.py @@ -6,7 +6,7 @@ from pulumi.dynamic import CreateResult, DiffResult, Resource, UpdateResult from data_safe_haven.exceptions import DataSafeHavenAzureError -from data_safe_haven.external import AzureApi +from data_safe_haven.external import AzureSdk from data_safe_haven.functions import b64encode from .dsh_resource_provider import DshResourceProvider @@ -42,7 +42,7 @@ class FileUploadProvider(DshResourceProvider): def create(self, props: dict[str, Any]) -> CreateResult: """Run a remote script to create a file on a VM""" outs = dict(**props) - azure_api = AzureApi(props["subscription_name"]) + azure_sdk = AzureSdk(props["subscription_name"]) script_contents = f""" target_dir=$(dirname "$target"); mkdir -p $target_dir 2> /dev/null; @@ -59,7 +59,7 @@ def create(self, props: dict[str, Any]) -> CreateResult: "target": props["file_target"], } # Run remote script - script_output = azure_api.run_remote_script_waiting( + script_output = azure_sdk.run_remote_script_waiting( props["vm_resource_group_name"], script_contents, script_parameters, @@ -83,7 +83,7 @@ def delete(self, id_: str, props: dict[str, Any]) -> None: """Delete the remote file from the VM""" # Use `id` as a no-op to avoid ARG002 while maintaining function signature id(id_) - azure_api = AzureApi(props["subscription_name"]) + azure_sdk = AzureSdk(props["subscription_name"]) script_contents = """ rm -f "$target"; echo "Removed file at $target"; @@ -92,7 +92,7 @@ def delete(self, id_: str, props: dict[str, Any]) -> None: "target": props["file_target"], } # Run remote script - azure_api.run_remote_script_waiting( + azure_sdk.run_remote_script_waiting( props["vm_resource_group_name"], script_contents, script_parameters, diff --git a/data_safe_haven/infrastructure/components/dynamic/ssl_certificate.py b/data_safe_haven/infrastructure/components/dynamic/ssl_certificate.py index 723b0555d0..d9c8bbc4d4 100644 --- a/data_safe_haven/infrastructure/components/dynamic/ssl_certificate.py +++ b/data_safe_haven/infrastructure/components/dynamic/ssl_certificate.py @@ -17,7 +17,7 @@ from simple_acme_dns import ACMEClient from data_safe_haven.exceptions import DataSafeHavenAzureError, DataSafeHavenSSLError -from data_safe_haven.external import AzureApi +from data_safe_haven.external import AzureSdk from .dsh_resource_provider import DshResourceProvider @@ -47,8 +47,8 @@ def refresh(self, props: dict[str, Any]) -> dict[str, Any]: try: outs = dict(**props) with suppress(DataSafeHavenAzureError): - azure_api = AzureApi(outs["subscription_name"]) - certificate = azure_api.get_keyvault_certificate( + azure_sdk = AzureSdk(outs["subscription_name"]) + certificate = azure_sdk.get_keyvault_certificate( outs["certificate_secret_name"], outs["key_vault_name"] ) if certificate.secret_id: @@ -77,10 +77,10 @@ def create(self, props: dict[str, Any]) -> CreateResult: private_key_bytes = client.generate_private_key(key_type="rsa2048") client.generate_csr() # Request DNS verification tokens and add them to the DNS record - azure_api = AzureApi(props["subscription_name"]) + azure_sdk = AzureSdk(props["subscription_name"]) verification_tokens = client.request_verification_tokens().items() for record_name, record_values in verification_tokens: - record_set = azure_api.ensure_dns_txt_record( + record_set = azure_sdk.ensure_dns_txt_record( record_name=record_name.replace(f".{props['domain_name']}", ""), record_value=record_values[0], resource_group_name=props["networking_resource_group_name"], @@ -130,7 +130,7 @@ def create(self, props: dict[str, Any]) -> CreateResult: NoEncryption(), ) # Add certificate to KeyVault - kvcert = azure_api.import_keyvault_certificate( + kvcert = azure_sdk.import_keyvault_certificate( certificate_name=props["certificate_secret_name"], certificate_contents=pfx_bytes, key_vault_name=props["key_vault_name"], @@ -152,14 +152,14 @@ def delete(self, id_: str, props: dict[str, Any]) -> None: id(id_) try: # Remove the DNS record - azure_api = AzureApi(props["subscription_name"]) - azure_api.remove_dns_txt_record( + azure_sdk = AzureSdk(props["subscription_name"]) + azure_sdk.remove_dns_txt_record( record_name="_acme_challenge", resource_group_name=props["networking_resource_group_name"], zone_name=props["domain_name"], ) # Remove the Key Vault certificate - azure_api.remove_keyvault_certificate( + azure_sdk.remove_keyvault_certificate( certificate_name=props["certificate_secret_name"], key_vault_name=props["key_vault_name"], ) diff --git a/data_safe_haven/infrastructure/programs/imperative_shm.py b/data_safe_haven/infrastructure/programs/imperative_shm.py index e003980660..3d461aaeba 100644 --- a/data_safe_haven/infrastructure/programs/imperative_shm.py +++ b/data_safe_haven/infrastructure/programs/imperative_shm.py @@ -3,7 +3,7 @@ DataSafeHavenAzureError, DataSafeHavenMicrosoftGraphError, ) -from data_safe_haven.external import AzureApi, GraphApi +from data_safe_haven.external import AzureSdk, GraphApi from data_safe_haven.logging import get_logger @@ -11,23 +11,23 @@ class ImperativeSHM: """Azure resources to support Data Safe Haven context""" def __init__(self, context: Context, config: SHMConfig) -> None: - self.azure_api_: AzureApi | None = None + self.azure_sdk_: AzureSdk | None = None self.config = config self.context = context self.tags = {"component": "SHM"} | context.tags @property - def azure_api(self) -> AzureApi: + def azure_sdk(self) -> AzureSdk: """Load AzureAPI on demand Returns: - AzureApi: An initialised AzureApi object + AzureSdk: An initialised AzureSdk object """ - if not self.azure_api_: - self.azure_api_ = AzureApi( + if not self.azure_sdk_: + self.azure_sdk_ = AzureSdk( subscription_name=self.context.subscription_name, ) - return self.azure_api_ + return self.azure_sdk_ def deploy(self) -> None: """Deploy all desired resources @@ -39,7 +39,7 @@ def deploy(self) -> None: logger.info(f"Preparing to deploy [green]{self.context.description}[/] SHM.") # Deploy the resources needed by Pulumi try: - resource_group = self.azure_api.ensure_resource_group( + resource_group = self.azure_sdk.ensure_resource_group( location=self.config.azure.location, resource_group_name=self.context.resource_group_name, tags=self.tags, @@ -47,12 +47,12 @@ def deploy(self) -> None: if not resource_group.name: msg = f"Resource group '{self.context.resource_group_name}' was not created." raise DataSafeHavenAzureError(msg) - identity = self.azure_api.ensure_managed_identity( + identity = self.azure_sdk.ensure_managed_identity( identity_name=self.context.managed_identity_name, location=resource_group.location, resource_group_name=resource_group.name, ) - storage_account = self.azure_api.ensure_storage_account( + storage_account = self.azure_sdk.ensure_storage_account( location=resource_group.location, resource_group_name=resource_group.name, storage_account_name=self.context.storage_account_name, @@ -61,17 +61,17 @@ def deploy(self) -> None: if not storage_account.name: msg = f"Storage account '{self.context.storage_account_name}' was not created." raise DataSafeHavenAzureError(msg) - _ = self.azure_api.ensure_storage_blob_container( + _ = self.azure_sdk.ensure_storage_blob_container( container_name=self.context.storage_container_name, resource_group_name=resource_group.name, storage_account_name=storage_account.name, ) - _ = self.azure_api.ensure_storage_blob_container( + _ = self.azure_sdk.ensure_storage_blob_container( container_name=self.context.pulumi_storage_container_name, resource_group_name=resource_group.name, storage_account_name=storage_account.name, ) - keyvault = self.azure_api.ensure_keyvault( + keyvault = self.azure_sdk.ensure_keyvault( admin_group_id=self.config.shm.admin_group_id, key_vault_name=self.context.key_vault_name, location=resource_group.location, @@ -82,7 +82,7 @@ def deploy(self) -> None: if not keyvault.name: msg = f"Keyvault '{self.context.key_vault_name}' was not created." raise DataSafeHavenAzureError(msg) - self.azure_api.ensure_keyvault_key( + self.azure_sdk.ensure_keyvault_key( key_name=self.context.pulumi_encryption_key_name, key_vault_name=keyvault.name, ) @@ -92,7 +92,7 @@ def deploy(self) -> None: # Deploy common resources that will be needed by SREs try: - zone = self.azure_api.ensure_dns_zone( + zone = self.azure_sdk.ensure_dns_zone( resource_group_name=resource_group.name, zone_name=self.config.shm.fqdn, tags=self.tags, @@ -101,7 +101,7 @@ def deploy(self) -> None: msg = f"DNS zone '{self.config.shm.fqdn}' was not created." raise DataSafeHavenAzureError(msg) nameservers = [str(n) for n in zone.name_servers] - self.azure_api.ensure_dns_caa_record( + self.azure_sdk.ensure_dns_caa_record( record_flags=0, record_name="@", record_tag="issue", @@ -127,7 +127,7 @@ def deploy(self) -> None: ) verification_record = graph_api.add_custom_domain(self.config.shm.fqdn) # Add the record to DNS - self.azure_api.ensure_dns_txt_record( + self.azure_sdk.ensure_dns_txt_record( record_name="@", record_value=verification_record, resource_group_name=resource_group.name, @@ -154,7 +154,7 @@ def teardown(self) -> None: logger.info( f"Removing [green]{self.context.description}[/] resource group {self.context.resource_group_name}." ) - self.azure_api.remove_resource_group(self.context.resource_group_name) + self.azure_sdk.remove_resource_group(self.context.resource_group_name) except DataSafeHavenAzureError as exc: msg = "Failed to destroy context resources." raise DataSafeHavenAzureError(msg) from exc diff --git a/data_safe_haven/infrastructure/project_manager.py b/data_safe_haven/infrastructure/project_manager.py index 56885621cd..29a91d7f75 100644 --- a/data_safe_haven/infrastructure/project_manager.py +++ b/data_safe_haven/infrastructure/project_manager.py @@ -19,7 +19,7 @@ DataSafeHavenConfigError, DataSafeHavenPulumiError, ) -from data_safe_haven.external import AzureApi, PulumiAccount +from data_safe_haven.external import AzureSdk, PulumiAccount from data_safe_haven.functions import replace_separators from data_safe_haven.logging import from_ansi, get_console_handler, get_logger @@ -238,8 +238,8 @@ def destroy(self) -> None: self.logger.debug( f"Removing Pulumi stack backup [green]{stack_backup_name}[/]." ) - azure_api = AzureApi(self.context.subscription_name) - azure_api.remove_blob( + azure_sdk = AzureSdk(self.context.subscription_name) + azure_sdk.remove_blob( blob_name=f".pulumi/stacks/{self.project_name}/{stack_backup_name}", resource_group_name=self.context.resource_group_name, storage_account_name=self.context.storage_account_name, diff --git a/data_safe_haven/provisioning/sre_provisioning_manager.py b/data_safe_haven/provisioning/sre_provisioning_manager.py index 096739d742..cd178dbd58 100644 --- a/data_safe_haven/provisioning/sre_provisioning_manager.py +++ b/data_safe_haven/provisioning/sre_provisioning_manager.py @@ -4,9 +4,9 @@ from typing import Any from data_safe_haven.external import ( - AzureApi, AzureContainerInstance, AzurePostgreSQLDatabase, + AzureSdk, GraphApi, ) from data_safe_haven.infrastructure import SREProjectManager @@ -36,8 +36,8 @@ def __init__( # Read secrets from key vault keyvault_name = sre_stack.output("data")["key_vault_name"] secret_name = sre_stack.output("data")["password_user_database_admin_secret"] - azure_api = AzureApi(self.subscription_name) - connection_db_server_password = azure_api.get_keyvault_secret( + azure_sdk = AzureSdk(self.subscription_name) + connection_db_server_password = azure_sdk.get_keyvault_secret( keyvault_name, secret_name ) @@ -67,8 +67,8 @@ def __init__( def available_vm_skus(self) -> dict[str, dict[str, Any]]: """Load available VM SKUs for this region""" if not self._available_vm_skus: - azure_api = AzureApi(self.subscription_name) - self._available_vm_skus = azure_api.list_available_vm_skus(self.location) + azure_sdk = AzureSdk(self.subscription_name) + self._available_vm_skus = azure_sdk.list_available_vm_skus(self.location) return self._available_vm_skus def create_security_groups(self) -> None: diff --git a/data_safe_haven/serialisers/azure_serialisable_model.py b/data_safe_haven/serialisers/azure_serialisable_model.py index 252ab1708b..5e3a41acb6 100644 --- a/data_safe_haven/serialisers/azure_serialisable_model.py +++ b/data_safe_haven/serialisers/azure_serialisable_model.py @@ -3,7 +3,7 @@ from typing import Any, ClassVar, TypeVar from data_safe_haven.exceptions import DataSafeHavenAzureError, DataSafeHavenError -from data_safe_haven.external import AzureApi +from data_safe_haven.external import AzureSdk from .context_base import ContextBase from .yaml_serialisable_model import YAMLSerialisableModel @@ -28,8 +28,8 @@ def from_remote( DataSafeHavenAzureError: if the file cannot be loaded """ try: - azure_api = AzureApi(subscription_name=context.subscription_name) - config_yaml = azure_api.download_blob( + azure_sdk = AzureSdk(subscription_name=context.subscription_name) + config_yaml = azure_sdk.download_blob( filename or cls.default_filename, context.resource_group_name, context.storage_account_name, @@ -58,8 +58,8 @@ def remote_exists( cls: type[T], context: ContextBase, *, filename: str | None = None ) -> bool: """Check whether a remote instance of this model exists.""" - azure_api = AzureApi(subscription_name=context.subscription_name) - return azure_api.blob_exists( + azure_sdk = AzureSdk(subscription_name=context.subscription_name) + return azure_sdk.blob_exists( filename or cls.default_filename, context.resource_group_name, context.storage_account_name, @@ -80,8 +80,8 @@ def remote_yaml_diff( def upload(self: T, context: ContextBase, *, filename: str | None = None) -> None: """Serialise an AzureSerialisableModel to a YAML file in Azure storage.""" - azure_api = AzureApi(subscription_name=context.subscription_name) - azure_api.upload_blob( + azure_sdk = AzureSdk(subscription_name=context.subscription_name) + azure_sdk.upload_blob( self.to_yaml(), filename or self.default_filename, context.resource_group_name, diff --git a/tests/commands/conftest.py b/tests/commands/conftest.py index e5ac3a4646..6459c84d6c 100644 --- a/tests/commands/conftest.py +++ b/tests/commands/conftest.py @@ -12,7 +12,7 @@ DataSafeHavenAzureAPIAuthenticationError, DataSafeHavenAzureError, ) -from data_safe_haven.external import AzureApi, GraphApi +from data_safe_haven.external import AzureSdk, GraphApi from data_safe_haven.infrastructure import ImperativeSHM, SREProjectManager @@ -22,8 +22,8 @@ def context(context_yaml) -> Context: @fixture -def mock_azure_api_blob_exists_false(mocker): - mocker.patch.object(AzureApi, "blob_exists", return_value=False) +def mock_azure_sdk_blob_exists_false(mocker): + mocker.patch.object(AzureSdk, "blob_exists", return_value=False) @fixture diff --git a/tests/commands/test_config_shm.py b/tests/commands/test_config_shm.py index 495e8716d8..a51c6699aa 100644 --- a/tests/commands/test_config_shm.py +++ b/tests/commands/test_config_shm.py @@ -1,12 +1,12 @@ from data_safe_haven.commands.config import config_command_group from data_safe_haven.config import SHMConfig -from data_safe_haven.external import AzureApi +from data_safe_haven.external import AzureSdk class TestShowSHM: def test_show(self, mocker, runner, context, shm_config_yaml): mock_method = mocker.patch.object( - AzureApi, "download_blob", return_value=shm_config_yaml + AzureSdk, "download_blob", return_value=shm_config_yaml ) result = runner.invoke(config_command_group, ["show-shm"]) @@ -21,7 +21,7 @@ def test_show(self, mocker, runner, context, shm_config_yaml): ) def test_show_file(self, mocker, runner, shm_config_yaml, tmp_path): - mocker.patch.object(AzureApi, "download_blob", return_value=shm_config_yaml) + mocker.patch.object(AzureSdk, "download_blob", return_value=shm_config_yaml) template_file = (tmp_path / "template_show.yaml").absolute() result = runner.invoke( config_command_group, ["show-shm", "--file", str(template_file)] diff --git a/tests/commands/test_config_sre.py b/tests/commands/test_config_sre.py index dc1f551923..5e7d8e13cd 100644 --- a/tests/commands/test_config_sre.py +++ b/tests/commands/test_config_sre.py @@ -1,14 +1,14 @@ from data_safe_haven.commands.config import config_command_group from data_safe_haven.config import SREConfig from data_safe_haven.config.sre_config import sre_config_name -from data_safe_haven.external import AzureApi +from data_safe_haven.external import AzureSdk class TestShowSRE: def test_show(self, mocker, runner, context, sre_config_yaml): sre_name = "sandbox" mock_method = mocker.patch.object( - AzureApi, "download_blob", return_value=sre_config_yaml + AzureSdk, "download_blob", return_value=sre_config_yaml ) result = runner.invoke(config_command_group, ["show", sre_name]) @@ -23,7 +23,7 @@ def test_show(self, mocker, runner, context, sre_config_yaml): ) def test_show_file(self, mocker, runner, sre_config_yaml, tmp_path): - mocker.patch.object(AzureApi, "download_blob", return_value=sre_config_yaml) + mocker.patch.object(AzureSdk, "download_blob", return_value=sre_config_yaml) template_file = (tmp_path / "template_show.yaml").absolute() result = runner.invoke( config_command_group, ["show", "sre-name", "--file", str(template_file)] @@ -69,7 +69,7 @@ def test_upload_new( mock_exists = mocker.patch.object( SREConfig, "remote_exists", return_value=False ) - mock_upload = mocker.patch.object(AzureApi, "upload_blob", return_value=None) + mock_upload = mocker.patch.object(AzureSdk, "upload_blob", return_value=None) result = runner.invoke( config_command_group, ["upload", str(sre_config_file)], @@ -94,7 +94,7 @@ def test_upload_no_changes( mock_from_remote = mocker.patch.object( SREConfig, "from_remote", return_value=sre_config ) - mock_upload = mocker.patch.object(AzureApi, "upload_blob", return_value=None) + mock_upload = mocker.patch.object(AzureSdk, "upload_blob", return_value=None) result = runner.invoke( config_command_group, ["upload", str(sre_config_file)], @@ -122,7 +122,7 @@ def test_upload_changes( mock_from_remote = mocker.patch.object( SREConfig, "from_remote", return_value=sre_config_alternate ) - mock_upload = mocker.patch.object(AzureApi, "upload_blob", return_value=None) + mock_upload = mocker.patch.object(AzureSdk, "upload_blob", return_value=None) result = runner.invoke( config_command_group, ["upload", str(sre_config_file)], @@ -152,7 +152,7 @@ def test_upload_changes_n( mock_from_remote = mocker.patch.object( SREConfig, "from_remote", return_value=sre_config_alternate ) - mock_upload = mocker.patch.object(AzureApi, "upload_blob", return_value=None) + mock_upload = mocker.patch.object(AzureSdk, "upload_blob", return_value=None) result = runner.invoke( config_command_group, ["upload", str(sre_config_file)], @@ -168,7 +168,7 @@ def test_upload_changes_n( assert "+++ local" in result.stdout def test_upload_no_file(self, mocker, runner): - mocker.patch.object(AzureApi, "upload_blob", return_value=None) + mocker.patch.object(AzureSdk, "upload_blob", return_value=None) result = runner.invoke( config_command_group, ["upload"], diff --git a/tests/config/test_pulumi.py b/tests/config/test_pulumi.py index 704f1a6a75..5d64ca8da7 100644 --- a/tests/config/test_pulumi.py +++ b/tests/config/test_pulumi.py @@ -5,7 +5,7 @@ DataSafeHavenConfigError, DataSafeHavenTypeError, ) -from data_safe_haven.external import AzureApi +from data_safe_haven.external import AzureSdk class TestDSHPulumiProject: @@ -120,7 +120,7 @@ def test_from_yaml_validation_error(self): DSHPulumiConfig.from_yaml(not_valid) def test_upload(self, mocker, pulumi_config, context): - mock_method = mocker.patch.object(AzureApi, "upload_blob", return_value=None) + mock_method = mocker.patch.object(AzureSdk, "upload_blob", return_value=None) pulumi_config.upload(context) mock_method.assert_called_once_with( @@ -133,7 +133,7 @@ def test_upload(self, mocker, pulumi_config, context): def test_from_remote(self, mocker, pulumi_config_yaml, context): mock_method = mocker.patch.object( - AzureApi, "download_blob", return_value=pulumi_config_yaml + AzureSdk, "download_blob", return_value=pulumi_config_yaml ) pulumi_config = DSHPulumiConfig.from_remote(context) @@ -149,9 +149,9 @@ def test_from_remote(self, mocker, pulumi_config_yaml, context): ) def test_from_remote_or_create(self, mocker, pulumi_config_yaml, context): - mock_exists = mocker.patch.object(AzureApi, "blob_exists", return_value=True) + mock_exists = mocker.patch.object(AzureSdk, "blob_exists", return_value=True) mock_download = mocker.patch.object( - AzureApi, "download_blob", return_value=pulumi_config_yaml + AzureSdk, "download_blob", return_value=pulumi_config_yaml ) pulumi_config = DSHPulumiConfig.from_remote_or_create(context, projects={}) @@ -176,7 +176,7 @@ def test_from_remote_or_create(self, mocker, pulumi_config_yaml, context): def test_from_remote_or_create_create( self, mocker, pulumi_config_yaml, context # noqa: ARG002 ): - mock_exists = mocker.patch.object(AzureApi, "blob_exists", return_value=False) + mock_exists = mocker.patch.object(AzureSdk, "blob_exists", return_value=False) pulumi_config = DSHPulumiConfig.from_remote_or_create( context, encrypted_key="abc", projects={} ) diff --git a/tests/config/test_shm_config.py b/tests/config/test_shm_config.py index 5eda4d2152..af3a0e4c16 100644 --- a/tests/config/test_shm_config.py +++ b/tests/config/test_shm_config.py @@ -9,7 +9,7 @@ from data_safe_haven.exceptions import ( DataSafeHavenTypeError, ) -from data_safe_haven.external import AzureApi +from data_safe_haven.external import AzureSdk class TestConfig: @@ -52,7 +52,7 @@ def test_from_remote( self, mocker, context, shm_config: SHMConfig, shm_config_yaml ) -> None: mock_method = mocker.patch.object( - AzureApi, "download_blob", return_value=shm_config_yaml + AzureSdk, "download_blob", return_value=shm_config_yaml ) config = SHMConfig.from_remote(context) @@ -68,7 +68,7 @@ def test_to_yaml(self, shm_config: SHMConfig, shm_config_yaml) -> None: assert shm_config.to_yaml() == shm_config_yaml def test_upload(self, mocker, context: Context, shm_config: SHMConfig) -> None: - mock_method = mocker.patch.object(AzureApi, "upload_blob", return_value=None) + mock_method = mocker.patch.object(AzureSdk, "upload_blob", return_value=None) shm_config.upload(context) mock_method.assert_called_once_with( diff --git a/tests/config/test_sre_config.py b/tests/config/test_sre_config.py index 8997cd716f..54daf3ac53 100644 --- a/tests/config/test_sre_config.py +++ b/tests/config/test_sre_config.py @@ -10,7 +10,7 @@ from data_safe_haven.exceptions import ( DataSafeHavenTypeError, ) -from data_safe_haven.external import AzureApi +from data_safe_haven.external import AzureSdk from data_safe_haven.types import SoftwarePackageCategory @@ -83,7 +83,7 @@ def test_from_remote( self, mocker, context: Context, sre_config: SREConfig, sre_config_yaml: str ) -> None: mock_method = mocker.patch.object( - AzureApi, "download_blob", return_value=sre_config_yaml + AzureSdk, "download_blob", return_value=sre_config_yaml ) config = SREConfig.from_remote(context) @@ -99,7 +99,7 @@ def test_to_yaml(self, sre_config: SREConfig, sre_config_yaml: str) -> None: assert sre_config.to_yaml() == sre_config_yaml def test_upload(self, mocker, context, sre_config) -> None: - mock_method = mocker.patch.object(AzureApi, "upload_blob", return_value=None) + mock_method = mocker.patch.object(AzureSdk, "upload_blob", return_value=None) sre_config.upload(context) mock_method.assert_called_once_with( diff --git a/tests/conftest.py b/tests/conftest.py index e424e7c9f2..e1fe0f9033 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,8 +26,8 @@ ConfigSubsectionRemoteDesktopOpts, ) from data_safe_haven.exceptions import DataSafeHavenAzureError -from data_safe_haven.external import AzureApi, PulumiAccount -from data_safe_haven.external.api.credentials import AzureApiCredential +from data_safe_haven.external import AzureSdk, PulumiAccount +from data_safe_haven.external.api.credentials import AzureSdkCredential from data_safe_haven.infrastructure import SREProjectManager from data_safe_haven.infrastructure.project_manager import ProjectManager from data_safe_haven.logging import init_logging @@ -134,7 +134,7 @@ def mock_azureapi_get_subscription(mocker): subscription.subscription_id = "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd" subscription.tenant_id = "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd" mocker.patch.object( - AzureApi, + AzureSdk, "get_subscription", return_value=subscription, ) @@ -147,7 +147,7 @@ def get_token(*args, **kwargs): # noqa: ARG002 return AccessToken("dummy-token", 0) mocker.patch.object( - AzureApiCredential, + AzureSdkCredential, "get_credential", return_value=MockCredential(), ) @@ -161,7 +161,7 @@ def fail_get_credential(): raise DataSafeHavenAzureError(msg) mocker.patch.object( - AzureApiCredential, + AzureSdkCredential, "get_credential", side_effect=fail_get_credential, ) @@ -183,7 +183,7 @@ def __init__(self, key_name, key_vault_name): def mock_get_keyvault_key(self, key_name, key_vault_name): # noqa: ARG001 return MockKeyVaultKey(key_name, key_vault_name) - monkeypatch.setattr(AzureApi, "get_keyvault_key", mock_get_keyvault_key) + monkeypatch.setattr(AzureSdk, "get_keyvault_key", mock_get_keyvault_key) @fixture diff --git a/tests/external/api/test_azure_api.py b/tests/external/api/test_azure_api.py index 20e435d1d9..38fbfbed1a 100644 --- a/tests/external/api/test_azure_api.py +++ b/tests/external/api/test_azure_api.py @@ -2,9 +2,9 @@ from azure.core.exceptions import ResourceNotFoundError from pytest import fixture -import data_safe_haven.external.api.azure_api +import data_safe_haven.external.api.azure_sdk from data_safe_haven.exceptions import DataSafeHavenAzureError -from data_safe_haven.external.api.azure_api import AzureApi +from data_safe_haven.external.api.azure_sdk import AzureSdk @fixture @@ -21,7 +21,7 @@ def get_key(self, key_name): raise ResourceNotFoundError monkeypatch.setattr( - data_safe_haven.external.api.azure_api, "KeyClient", MockKeyClient + data_safe_haven.external.api.azure_sdk, "KeyClient", MockKeyClient ) @@ -58,25 +58,25 @@ def mock_blob_client( ) monkeypatch.setattr( - data_safe_haven.external.api.azure_api.AzureApi, "blob_client", mock_blob_client + data_safe_haven.external.api.azure_sdk.AzureSdk, "blob_client", mock_blob_client ) -class TestAzureApi: +class TestAzureSdk: def test_get_keyvault_key(self, mock_key_client): # noqa: ARG002 - api = AzureApi("subscription name") + api = AzureSdk("subscription name") key = api.get_keyvault_key("exists", "key vault name") assert key == "key: exists" def test_get_keyvault_key_missing(self, mock_key_client): # noqa: ARG002 - api = AzureApi("subscription name") + api = AzureSdk("subscription name") with pytest.raises( DataSafeHavenAzureError, match="Failed to retrieve key does not exist" ): api.get_keyvault_key("does not exist", "key vault name") def test_blob_exists(self, mock_blob_client): # noqa: ARG002 - api = AzureApi("subscription name") + api = AzureSdk("subscription name") exists = api.blob_exists( "exists", "resource_group", "storage_account", "storage_container" ) @@ -84,7 +84,7 @@ def test_blob_exists(self, mock_blob_client): # noqa: ARG002 assert exists def test_blob_does_not_exist(self, mock_blob_client): # noqa: ARG002 - api = AzureApi("subscription name") + api = AzureSdk("subscription name") exists = api.blob_exists( "abc.txt", "resource_group", "storage_account", "storage_container" ) diff --git a/tests/serialisers/test_azure_serialisable_model.py b/tests/serialisers/test_azure_serialisable_model.py index 81636386a4..093d4f70d9 100644 --- a/tests/serialisers/test_azure_serialisable_model.py +++ b/tests/serialisers/test_azure_serialisable_model.py @@ -4,7 +4,7 @@ DataSafeHavenConfigError, DataSafeHavenTypeError, ) -from data_safe_haven.external import AzureApi +from data_safe_haven.external import AzureSdk from data_safe_haven.serialisers import AzureSerialisableModel @@ -36,7 +36,7 @@ def test_constructor(self, example_config_class): def test_remote_yaml_diff(self, mocker, example_config_class, context): mocker.patch.object( - AzureApi, "download_blob", return_value=example_config_class.to_yaml() + AzureSdk, "download_blob", return_value=example_config_class.to_yaml() ) diff = example_config_class.remote_yaml_diff(context) assert not diff @@ -44,7 +44,7 @@ def test_remote_yaml_diff(self, mocker, example_config_class, context): def test_remote_yaml_diff_difference(self, mocker, example_config_class, context): mocker.patch.object( - AzureApi, "download_blob", return_value=example_config_class.to_yaml() + AzureSdk, "download_blob", return_value=example_config_class.to_yaml() ) example_config_class.integer = 0 example_config_class.string = "abc" @@ -74,7 +74,7 @@ def test_to_yaml(self, example_config_class): assert "config_type" not in yaml def test_upload(self, mocker, example_config_class, context): - mock_method = mocker.patch.object(AzureApi, "upload_blob", return_value=None) + mock_method = mocker.patch.object(AzureSdk, "upload_blob", return_value=None) example_config_class.upload(context) mock_method.assert_called_once_with( @@ -124,7 +124,7 @@ def test_from_yaml_validation_error(self): def test_from_remote(self, mocker, context, example_config_yaml): mock_method = mocker.patch.object( - AzureApi, "download_blob", return_value=example_config_yaml + AzureSdk, "download_blob", return_value=example_config_yaml ) example_config = ExampleAzureSerialisableModel.from_remote(context) From 8405746e2d93cd496d7f5c6088ce68dad918294d Mon Sep 17 00:00:00 2001 From: James Robinson Date: Mon, 8 Jul 2024 15:45:26 +0100 Subject: [PATCH 27/54] :memo: Fix docstring --- data_safe_haven/external/api/azure_sdk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_safe_haven/external/api/azure_sdk.py b/data_safe_haven/external/api/azure_sdk.py index f6ecfb1d7e..726b3a55ea 100644 --- a/data_safe_haven/external/api/azure_sdk.py +++ b/data_safe_haven/external/api/azure_sdk.py @@ -718,7 +718,7 @@ def get_storage_account_keys( raise DataSafeHavenAzureError(msg) from exc def get_subscription(self, subscription_name: str) -> Subscription: - """Get the current Azure subscription.""" + """Get an Azure subscription by name.""" try: subscription_client = SubscriptionClient(self.credential) for subscription in subscription_client.subscriptions.list(): From 9ff142db1cb23977ae07dc3f3d2a9742566a4341 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Mon, 8 Jul 2024 15:55:18 +0100 Subject: [PATCH 28/54] :truck: Move token decoding into DeferredCredential --- data_safe_haven/external/api/credentials.py | 23 +++++++++++++++------ data_safe_haven/external/api/graph_api.py | 9 +++----- tests/conftest.py | 1 - 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/data_safe_haven/external/api/credentials.py b/data_safe_haven/external/api/credentials.py index a891560e81..b5b869a679 100644 --- a/data_safe_haven/external/api/credentials.py +++ b/data_safe_haven/external/api/credentials.py @@ -16,7 +16,7 @@ TokenCachePersistenceOptions, ) -from data_safe_haven.exceptions import DataSafeHavenAzureError +from data_safe_haven.exceptions import DataSafeHavenAzureError, DataSafeHavenValueError from data_safe_haven.logging import get_logger @@ -39,6 +39,20 @@ def token(self) -> str: """Get a token from the credential provider.""" return str(self.get_token(*self.scopes, tenant_id=self.tenant_id).token) + @classmethod + def decode_token(cls, auth_token: str) -> dict[str, Any]: + try: + return dict( + jwt.decode( + auth_token, + algorithms=["RS256"], + options={"verify_signature": False}, + ) + ) + except (jwt.exceptions.DecodeError, KeyError) as exc: + msg = "Could not interpret input as an Azure authentication token." + raise DataSafeHavenValueError(msg) from exc + @abstractmethod def get_credential(self) -> TokenCredential: """Get a credential provider from the child class.""" @@ -76,10 +90,7 @@ def get_credential(self) -> TokenCredential: credential = AzureCliCredential(additionally_allowed_tenants=["*"]) # Check that we are logged into Azure try: - token = credential.get_token(*self.scopes).token - decoded = jwt.decode( - token, algorithms=["RS256"], options={"verify_signature": False} - ) + decoded = self.decode_token(credential.get_token(*self.scopes).token) self.logger.info( "You are currently logged into the [blue]Azure CLI[/] with the following details:" ) @@ -89,7 +100,7 @@ def get_credential(self) -> TokenCredential: self.logger.info( f"... tenant: [green]{decoded['upn'].split('@')[1]}[/] ({decoded['tid']})" ) - except CredentialUnavailableError as exc: + except (CredentialUnavailableError, DataSafeHavenValueError) as exc: self.logger.error( "Please authenticate with Azure: run '[green]az login[/]' using [bold]infrastructure administrator[/] credentials." ) diff --git a/data_safe_haven/external/api/graph_api.py b/data_safe_haven/external/api/graph_api.py index eeb77558e0..4b0355085a 100644 --- a/data_safe_haven/external/api/graph_api.py +++ b/data_safe_haven/external/api/graph_api.py @@ -7,7 +7,6 @@ from contextlib import suppress from typing import Any, ClassVar, Self -import jwt import requests import typer from dns import resolver @@ -72,14 +71,12 @@ def from_scopes( def from_token(cls: type[Self], auth_token: str) -> "GraphApi": """Construct a GraphApi from an existing authentication token.""" try: - decoded = jwt.decode( - auth_token, algorithms=["RS256"], options={"verify_signature": False} - ) + decoded = DeferredCredential.decode_token(auth_token) return cls.from_scopes( scopes=str(decoded["scp"]).split(), tenant_id=decoded["tid"] ) - except (jwt.exceptions.DecodeError, KeyError) as exc: - msg = "Could not interpret Graph API authentication token." + except DataSafeHavenValueError as exc: + msg = "Could not construct GraphApi from provided token." raise DataSafeHavenValueError(msg) from exc @property diff --git a/tests/conftest.py b/tests/conftest.py index e1fe0f9033..c307746717 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,6 @@ from pulumi.automation import ProjectSettings from pytest import fixture -import data_safe_haven.commands.sre as sre_mod import data_safe_haven.config.context_manager as context_mod import data_safe_haven.logging.logger from data_safe_haven.config import ( From 3555b29c8f0e8ad1239ae75207a0db8a2723d71a Mon Sep 17 00:00:00 2001 From: James Robinson Date: Tue, 9 Jul 2024 11:25:02 +0100 Subject: [PATCH 29/54] :art: Better check for existing authentication record Co-authored-by: Jim Madge --- data_safe_haven/external/api/credentials.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/data_safe_haven/external/api/credentials.py b/data_safe_haven/external/api/credentials.py index b5b869a679..bdb9f032d9 100644 --- a/data_safe_haven/external/api/credentials.py +++ b/data_safe_haven/external/api/credentials.py @@ -132,15 +132,15 @@ def get_credential(self) -> TokenCredential: # Read an existing authentication record, using default arguments if unavailable kwargs = {} - try: - with open(authentication_record_path) as f_auth: - existing_auth_record = AuthenticationRecord.deserialize(f_auth.read()) - kwargs["authentication_record"] = existing_auth_record - except FileNotFoundError: - kwargs["authority"] = "https://login.microsoftonline.com/" - # Use the Microsoft Graph Command Line Tools client ID - kwargs["client_id"] = "14d82eec-204b-4c2f-b7e8-296a70dab67e" - kwargs["tenant_id"] = self.tenant_id +if authentication_record_path.is_file(): + with open(authentication_record_path) as f_auth: + existing_auth_record = AuthenticationRecord.deserialize(f_auth.read()) + kwargs["authentication_record"] = existing_auth_record +else: + kwargs["authority"] = "https://login.microsoftonline.com/" + # Use the Microsoft Graph Command Line Tools client ID + kwargs["client_id"] = "14d82eec-204b-4c2f-b7e8-296a70dab67e" + kwargs["tenant_id"] = self.tenant_id # Get a credential with a custom callback def callback(verification_uri: str, user_code: str, _: datetime) -> None: From a748bc3f84ae9b74e20f7064ef1208ea46850646 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Tue, 9 Jul 2024 11:27:12 +0100 Subject: [PATCH 30/54] :loud_sound: Replace ellipses with tabs Co-authored-by: Jim Madge --- data_safe_haven/external/api/credentials.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/data_safe_haven/external/api/credentials.py b/data_safe_haven/external/api/credentials.py index bdb9f032d9..cedc7fa692 100644 --- a/data_safe_haven/external/api/credentials.py +++ b/data_safe_haven/external/api/credentials.py @@ -95,10 +95,10 @@ def get_credential(self) -> TokenCredential: "You are currently logged into the [blue]Azure CLI[/] with the following details:" ) self.logger.info( - f"... user: [green]{decoded['name']}[/] ({decoded['oid']})" + f"\tuser: [green]{decoded['name']}[/] ({decoded['oid']})" ) self.logger.info( - f"... tenant: [green]{decoded['upn'].split('@')[1]}[/] ({decoded['tid']})" + f"\ttenant: [green]{decoded['upn'].split('@')[1]}[/] ({decoded['tid']})" ) except (CredentialUnavailableError, DataSafeHavenValueError) as exc: self.logger.error( @@ -167,10 +167,10 @@ def callback(verification_uri: str, user_code: str, _: datetime) -> None: "You are currently logged into the [blue]Microsoft Graph API[/] with the following details:" ) self.logger.info( - f"... user: [green]{new_auth_record.username}[/] ({new_auth_record._home_account_id.split('.')[0]})" + f"\tuser: [green]{new_auth_record.username}[/] ({new_auth_record._home_account_id.split('.')[0]})" ) self.logger.info( - f"... tenant: [green]{new_auth_record._username.split('@')[1]}[/] ({new_auth_record._tenant_id})" + f"\ttenant: [green]{new_auth_record._username.split('@')[1]}[/] ({new_auth_record._tenant_id})" ) # Return the credential From c6b75233fd43a73f234037637e33155baaf7ac5a Mon Sep 17 00:00:00 2001 From: James Robinson Date: Tue, 9 Jul 2024 11:29:33 +0100 Subject: [PATCH 31/54] :loud_sound: Explain 'expires_on' property --- data_safe_haven/external/api/credentials.py | 1 + 1 file changed, 1 insertion(+) diff --git a/data_safe_haven/external/api/credentials.py b/data_safe_haven/external/api/credentials.py index cedc7fa692..4b5a84c469 100644 --- a/data_safe_haven/external/api/credentials.py +++ b/data_safe_haven/external/api/credentials.py @@ -64,6 +64,7 @@ def get_token( **kwargs: Any, ) -> AccessToken: # Require at least 10 minutes of remaining validity + # The 'expires_on' property is a Unix timestamp integer in seconds validity_cutoff = datetime.now(tz=UTC).timestamp() + 10 * 60 if not DeferredCredential.token_ or ( DeferredCredential.token_.expires_on < validity_cutoff From fe7f740960a09045e28441cc6f721d0a03ef9f4a Mon Sep 17 00:00:00 2001 From: James Robinson Date: Tue, 9 Jul 2024 12:00:15 +0100 Subject: [PATCH 32/54] :loud_sound: Only print the login confirmation once for each credential --- data_safe_haven/external/api/credentials.py | 41 ++++++++++++--------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/data_safe_haven/external/api/credentials.py b/data_safe_haven/external/api/credentials.py index 4b5a84c469..841124fff4 100644 --- a/data_safe_haven/external/api/credentials.py +++ b/data_safe_haven/external/api/credentials.py @@ -30,6 +30,7 @@ def __init__( scopes: Sequence[str], tenant_id: str | None = None, ) -> None: + self._show_login_msg = False self.logger = get_logger() self.scopes = scopes self.tenant_id = tenant_id @@ -92,15 +93,17 @@ def get_credential(self) -> TokenCredential: # Check that we are logged into Azure try: decoded = self.decode_token(credential.get_token(*self.scopes).token) - self.logger.info( - "You are currently logged into the [blue]Azure CLI[/] with the following details:" - ) - self.logger.info( - f"\tuser: [green]{decoded['name']}[/] ({decoded['oid']})" - ) - self.logger.info( - f"\ttenant: [green]{decoded['upn'].split('@')[1]}[/] ({decoded['tid']})" - ) + if self._show_login_msg: + self.logger.info( + "You are currently logged into the [blue]Azure CLI[/] with the following details:" + ) + self.logger.info( + f"\tuser: [green]{decoded['name']}[/] ({decoded['oid']})" + ) + self.logger.info( + f"\ttenant: [green]{decoded['upn'].split('@')[1]}[/] ({decoded['tid']})" + ) + self._show_login_msg = False except (CredentialUnavailableError, DataSafeHavenValueError) as exc: self.logger.error( "Please authenticate with Azure: run '[green]az login[/]' using [bold]infrastructure administrator[/] credentials." @@ -164,15 +167,17 @@ def callback(verification_uri: str, user_code: str, _: datetime) -> None: f_auth.write(new_auth_record.serialize()) # Write confirmation details about this login - self.logger.info( - "You are currently logged into the [blue]Microsoft Graph API[/] with the following details:" - ) - self.logger.info( - f"\tuser: [green]{new_auth_record.username}[/] ({new_auth_record._home_account_id.split('.')[0]})" - ) - self.logger.info( - f"\ttenant: [green]{new_auth_record._username.split('@')[1]}[/] ({new_auth_record._tenant_id})" - ) + if self._show_login_msg: + self.logger.info( + "You are currently logged into the [blue]Microsoft Graph API[/] with the following details:" + ) + self.logger.info( + f"\tuser: [green]{new_auth_record.username}[/] ({new_auth_record._home_account_id.split('.')[0]})" + ) + self.logger.info( + f"\ttenant: [green]{new_auth_record._username.split('@')[1]}[/] ({new_auth_record._tenant_id})" + ) + self._show_login_msg = False # Return the credential return credential From 0ffbf301a2fe3229752511b366f4060afcec3ee8 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Tue, 9 Jul 2024 12:01:00 +0100 Subject: [PATCH 33/54] :bug: Fix indentation of review suggestion --- data_safe_haven/external/api/credentials.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/data_safe_haven/external/api/credentials.py b/data_safe_haven/external/api/credentials.py index 841124fff4..01ede110d4 100644 --- a/data_safe_haven/external/api/credentials.py +++ b/data_safe_haven/external/api/credentials.py @@ -136,15 +136,15 @@ def get_credential(self) -> TokenCredential: # Read an existing authentication record, using default arguments if unavailable kwargs = {} -if authentication_record_path.is_file(): - with open(authentication_record_path) as f_auth: - existing_auth_record = AuthenticationRecord.deserialize(f_auth.read()) - kwargs["authentication_record"] = existing_auth_record -else: - kwargs["authority"] = "https://login.microsoftonline.com/" - # Use the Microsoft Graph Command Line Tools client ID - kwargs["client_id"] = "14d82eec-204b-4c2f-b7e8-296a70dab67e" - kwargs["tenant_id"] = self.tenant_id + if authentication_record_path.is_file(): + with open(authentication_record_path) as f_auth: + existing_auth_record = AuthenticationRecord.deserialize(f_auth.read()) + kwargs["authentication_record"] = existing_auth_record + else: + kwargs["authority"] = "https://login.microsoftonline.com/" + # Use the Microsoft Graph Command Line Tools client ID + kwargs["client_id"] = "14d82eec-204b-4c2f-b7e8-296a70dab67e" + kwargs["tenant_id"] = self.tenant_id # Get a credential with a custom callback def callback(verification_uri: str, user_code: str, _: datetime) -> None: From e95a0d7a08a032ba46faafd0c426b3c57ca8fa0b Mon Sep 17 00:00:00 2001 From: James Robinson Date: Tue, 9 Jul 2024 12:03:14 +0100 Subject: [PATCH 34/54] :bug: Use config_dir instead of home directory --- data_safe_haven/external/api/credentials.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/data_safe_haven/external/api/credentials.py b/data_safe_haven/external/api/credentials.py index 01ede110d4..34fc301d81 100644 --- a/data_safe_haven/external/api/credentials.py +++ b/data_safe_haven/external/api/credentials.py @@ -1,6 +1,5 @@ """Classes related to Azure credentials""" -import pathlib from abc import abstractmethod from collections.abc import Sequence from datetime import UTC, datetime @@ -16,6 +15,7 @@ TokenCachePersistenceOptions, ) +from data_safe_haven.directories import config_dir from data_safe_haven.exceptions import DataSafeHavenAzureError, DataSafeHavenValueError from data_safe_haven.logging import get_logger @@ -131,7 +131,7 @@ def get_credential(self) -> TokenCredential: """Get a new DeviceCodeCredential, using cached credentials if they are available""" cache_name = f"dsh-{self.tenant_id}" authentication_record_path = ( - pathlib.Path.home() / f".msal-authentication-cache-{cache_name}" + config_dir() / f".msal-authentication-cache-{cache_name}" ) # Read an existing authentication record, using default arguments if unavailable From aa062851256551fdce02316392c7747c377df172 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Tue, 9 Jul 2024 12:07:28 +0100 Subject: [PATCH 35/54] :truck: Rename AzureSdk test file to match classname --- tests/external/api/{test_azure_api.py => test_azure_sdk.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/external/api/{test_azure_api.py => test_azure_sdk.py} (100%) diff --git a/tests/external/api/test_azure_api.py b/tests/external/api/test_azure_sdk.py similarity index 100% rename from tests/external/api/test_azure_api.py rename to tests/external/api/test_azure_sdk.py From d141b0cfb41646a02ca42eb153ece2d95d928b3a Mon Sep 17 00:00:00 2001 From: James Robinson Date: Tue, 9 Jul 2024 12:17:54 +0100 Subject: [PATCH 36/54] :wrench: Factor out the GUIDs used in various config files --- tests/conftest.py | 50 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index c307746717..6a412adc45 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ from shutil import which from subprocess import run +import pytest import yaml from azure.core.credentials import AccessToken, TokenCredential from azure.mgmt.resource.subscriptions.models import Subscription @@ -32,12 +33,20 @@ from data_safe_haven.logging import init_logging +def pytest_configure(): + """Define constants for use across multiple tests""" + pytest.guid_admin = "00edec65-b071-4d26-8779-a9fe791c6e14" + pytest.guid_entra = "48b2425b-5f2c-4cbd-9458-0441daa8994c" + pytest.guid_subscription = "35ebced1-4e7a-4c1f-b634-c0886937085d" + pytest.guid_tenant = "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd" + + @fixture def azure_config(): return ConfigSectionAzure( location="uksouth", - subscription_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", - tenant_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + subscription_id=pytest.guid_subscription, + tenant_id=pytest.guid_tenant, ) @@ -130,8 +139,8 @@ def log_directory(session_mocker, tmp_path_factory): def mock_azureapi_get_subscription(mocker): subscription = Subscription() subscription.display_name = "Data Safe Haven Acme" - subscription.subscription_id = "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd" - subscription.tenant_id = "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd" + subscription.subscription_id = pytest.guid_subscription + subscription.tenant_id = pytest.guid_tenant mocker.patch.object( AzureSdk, "get_subscription", @@ -329,24 +338,31 @@ def shm_config_section(shm_config_section_dict): @fixture def shm_config_section_dict(): return { - "admin_group_id": "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", - "entra_tenant_id": "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + "admin_group_id": pytest.guid_admin, + "entra_tenant_id": pytest.guid_entra, "fqdn": "shm.acme.com", } @fixture def shm_config_yaml(): - content = """--- + content = ( + """--- azure: location: uksouth - subscription_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd - tenant_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd + subscription_id: guid_subscription + tenant_id: guid_tenant shm: - admin_group_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd - entra_tenant_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd + admin_group_id: guid_admin + entra_tenant_id: guid_entra fqdn: shm.acme.com - """ + """.replace( + "guid_admin", pytest.guid_admin + ) + .replace("guid_entra", pytest.guid_entra) + .replace("guid_subscription", pytest.guid_subscription) + .replace("guid_tenant", pytest.guid_tenant) + ) return yaml.dump(yaml.safe_load(content)) @@ -399,8 +415,8 @@ def sre_config_yaml(): content = """--- azure: location: uksouth - subscription_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd - tenant_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd + subscription_id: guid_subscription + tenant_id: guid_tenant description: Sandbox Project name: sandbox sre: @@ -416,7 +432,11 @@ def sre_config_yaml(): software_packages: none timezone: Europe/London workspace_skus: [] - """ + """.replace( + "guid_subscription", pytest.guid_subscription + ).replace( + "guid_tenant", pytest.guid_tenant + ) return yaml.dump(yaml.safe_load(content)) From 334cc80d48036a37bc357e6cd3a131cca9c0cd9a Mon Sep 17 00:00:00 2001 From: James Robinson Date: Tue, 9 Jul 2024 12:23:16 +0100 Subject: [PATCH 37/54] :white_check_mark: Add tests for entra_directory, subscription_id and tenant_id --- tests/external/api/test_azure_sdk.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/tests/external/api/test_azure_sdk.py b/tests/external/api/test_azure_sdk.py index 38fbfbed1a..d85ad968e7 100644 --- a/tests/external/api/test_azure_sdk.py +++ b/tests/external/api/test_azure_sdk.py @@ -4,7 +4,7 @@ import data_safe_haven.external.api.azure_sdk from data_safe_haven.exceptions import DataSafeHavenAzureError -from data_safe_haven.external.api.azure_sdk import AzureSdk +from data_safe_haven.external import AzureSdk, GraphApi @fixture @@ -63,6 +63,24 @@ def mock_blob_client( class TestAzureSdk: + def test_entra_directory(self): + api = AzureSdk("subscription name") + assert isinstance(api.entra_directory, GraphApi) + + def test_subscription_id( + self, + mock_azureapi_get_subscription, # noqa: ARG002 + ): + api = AzureSdk("subscription name") + assert api.subscription_id == pytest.guid_subscription + + def test_tenant_id( + self, + mock_azureapi_get_subscription, # noqa: ARG002 + ): + api = AzureSdk("subscription name") + assert api.tenant_id == pytest.guid_tenant + def test_get_keyvault_key(self, mock_key_client): # noqa: ARG002 api = AzureSdk("subscription name") key = api.get_keyvault_key("exists", "key vault name") From bf88e6fb92bcad949b0b9cbdeed57c0774fc48ce Mon Sep 17 00:00:00 2001 From: James Robinson Date: Tue, 9 Jul 2024 12:23:57 +0100 Subject: [PATCH 38/54] :wrench: Alphabetise TestAzureSdk functions --- tests/external/api/test_azure_sdk.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/external/api/test_azure_sdk.py b/tests/external/api/test_azure_sdk.py index d85ad968e7..51e36aca36 100644 --- a/tests/external/api/test_azure_sdk.py +++ b/tests/external/api/test_azure_sdk.py @@ -81,18 +81,6 @@ def test_tenant_id( api = AzureSdk("subscription name") assert api.tenant_id == pytest.guid_tenant - def test_get_keyvault_key(self, mock_key_client): # noqa: ARG002 - api = AzureSdk("subscription name") - key = api.get_keyvault_key("exists", "key vault name") - assert key == "key: exists" - - def test_get_keyvault_key_missing(self, mock_key_client): # noqa: ARG002 - api = AzureSdk("subscription name") - with pytest.raises( - DataSafeHavenAzureError, match="Failed to retrieve key does not exist" - ): - api.get_keyvault_key("does not exist", "key vault name") - def test_blob_exists(self, mock_blob_client): # noqa: ARG002 api = AzureSdk("subscription name") exists = api.blob_exists( @@ -108,3 +96,15 @@ def test_blob_does_not_exist(self, mock_blob_client): # noqa: ARG002 ) assert isinstance(exists, bool) assert not exists + + def test_get_keyvault_key(self, mock_key_client): # noqa: ARG002 + api = AzureSdk("subscription name") + key = api.get_keyvault_key("exists", "key vault name") + assert key == "key: exists" + + def test_get_keyvault_key_missing(self, mock_key_client): # noqa: ARG002 + api = AzureSdk("subscription name") + with pytest.raises( + DataSafeHavenAzureError, match="Failed to retrieve key does not exist" + ): + api.get_keyvault_key("does not exist", "key vault name") From d27c2b70710a842dea3ade9a8efc7787726d4e3f Mon Sep 17 00:00:00 2001 From: James Robinson Date: Tue, 9 Jul 2024 15:37:11 +0100 Subject: [PATCH 39/54] :white_check_mark: Add tests for AzureSdkCredential --- tests/external/api/conftest.py | 32 ++++++++++++++++++++++++++ tests/external/api/test_credentials.py | 22 ++++++++++++++++++ 2 files changed, 54 insertions(+) create mode 100644 tests/external/api/conftest.py create mode 100644 tests/external/api/test_credentials.py diff --git a/tests/external/api/conftest.py b/tests/external/api/conftest.py new file mode 100644 index 0000000000..81d5b47cc7 --- /dev/null +++ b/tests/external/api/conftest.py @@ -0,0 +1,32 @@ +import jwt +import pytest +from azure.core.credentials import AccessToken +from azure.identity import AzureCliCredential +from pytest import fixture + + +def pytest_configure(): + """Define constants for use across multiple tests""" + pytest.user_id = "80b4ccfd-73ef-41b7-bb22-8ec268ec040b" + + +@fixture +def jwt_token(): + return jwt.encode( + { + "name": "username", + "oid": pytest.user_id, + "upn": "username@example.com", + "tid": pytest.guid_tenant, + }, + "key", + ) + + +@fixture +def mock_azureclicredential_get_token(mocker, jwt_token): + mocker.patch.object( + AzureCliCredential, + "get_token", + return_value=AccessToken(jwt_token, 0), + ) diff --git a/tests/external/api/test_credentials.py b/tests/external/api/test_credentials.py new file mode 100644 index 0000000000..20e4052543 --- /dev/null +++ b/tests/external/api/test_credentials.py @@ -0,0 +1,22 @@ +import pytest +from azure.identity import AzureCliCredential + +from data_safe_haven.external.api.credentials import AzureSdkCredential + + +class TestAzureSdkCredential: + def test_get_credential(self, mock_azureclicredential_get_token): # noqa: ARG002 + credential = AzureSdkCredential() + assert isinstance(credential.get_credential(), AzureCliCredential) + + def test_get_token(self, mock_azureclicredential_get_token): # noqa: ARG002 + credential = AzureSdkCredential() + assert isinstance(credential.token, str) + + def test_decode_token(self, mock_azureclicredential_get_token): # noqa: ARG002 + token = AzureSdkCredential().token + decoded = AzureSdkCredential.decode_token(token) + assert decoded["name"] == "username" + assert decoded["oid"] == pytest.user_id + assert decoded["upn"] == "username@example.com" + assert decoded["tid"] == pytest.guid_tenant From 6ad75e8178c964306095d8c95a0115052a4a0122 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Tue, 9 Jul 2024 16:41:23 +0100 Subject: [PATCH 40/54] :white_check_mark: Add tests for GraphApiCredential --- tests/external/api/conftest.py | 78 +++++++++++++++++++++++--- tests/external/api/test_credentials.py | 38 +++++++++++-- 2 files changed, 104 insertions(+), 12 deletions(-) diff --git a/tests/external/api/conftest.py b/tests/external/api/conftest.py index 81d5b47cc7..fc6179b874 100644 --- a/tests/external/api/conftest.py +++ b/tests/external/api/conftest.py @@ -1,32 +1,94 @@ +import os + import jwt import pytest from azure.core.credentials import AccessToken -from azure.identity import AzureCliCredential -from pytest import fixture +from azure.identity import ( + AuthenticationRecord, + AzureCliCredential, + DeviceCodeCredential, +) + +from data_safe_haven.external.api.credentials import GraphApiCredential def pytest_configure(): """Define constants for use across multiple tests""" + pytest.user_upn = "username@example.com" pytest.user_id = "80b4ccfd-73ef-41b7-bb22-8ec268ec040b" -@fixture -def jwt_token(): +@pytest.fixture +def authentication_record(): + return AuthenticationRecord( + tenant_id=pytest.guid_tenant, + client_id="14d82eec-204b-4c2f-b7e8-296a70dab67e", + authority="login.microsoftonline.com", + home_account_id=pytest.user_id, + username=pytest.user_upn, + ) + + +@pytest.fixture +def azure_cli_token(): return jwt.encode( { "name": "username", "oid": pytest.user_id, - "upn": "username@example.com", + "upn": pytest.user_upn, "tid": pytest.guid_tenant, }, "key", ) -@fixture -def mock_azureclicredential_get_token(mocker, jwt_token): +@pytest.fixture +def graph_api_token(): + return jwt.encode( + { + "scp": "GroupMember.Read.All User.Read.All", + "tid": pytest.guid_tenant, + }, + "key", + ) + + +@pytest.fixture +def mock_azureclicredential_get_token(mocker, azure_cli_token): mocker.patch.object( AzureCliCredential, "get_token", - return_value=AccessToken(jwt_token, 0), + return_value=AccessToken(azure_cli_token, 0), ) + + +@pytest.fixture +def mock_devicecodecredential_get_token(mocker, graph_api_token): + mocker.patch.object( + DeviceCodeCredential, + "get_token", + return_value=AccessToken(graph_api_token, 0), + ) + + +@pytest.fixture +def mock_devicecodecredential_authenticate(mocker, authentication_record): + mocker.patch.object( + DeviceCodeCredential, + "authenticate", + return_value=authentication_record, + ) + + +@pytest.fixture +def mock_graphapicredential_get_token(mocker, graph_api_token): + mocker.patch.object( + GraphApiCredential, + "get_token", + return_value=AccessToken(graph_api_token, 0), + ) + + +@pytest.fixture +def tmp_config_dir(mocker, tmp_path): + mocker.patch.dict(os.environ, {"DSH_CONFIG_DIRECTORY": str(tmp_path)}) diff --git a/tests/external/api/test_credentials.py b/tests/external/api/test_credentials.py index 20e4052543..fa382000db 100644 --- a/tests/external/api/test_credentials.py +++ b/tests/external/api/test_credentials.py @@ -1,7 +1,10 @@ import pytest -from azure.identity import AzureCliCredential +from azure.identity import AzureCliCredential, DeviceCodeCredential -from data_safe_haven.external.api.credentials import AzureSdkCredential +from data_safe_haven.external.api.credentials import ( + AzureSdkCredential, + GraphApiCredential, +) class TestAzureSdkCredential: @@ -14,9 +17,36 @@ def test_get_token(self, mock_azureclicredential_get_token): # noqa: ARG002 assert isinstance(credential.token, str) def test_decode_token(self, mock_azureclicredential_get_token): # noqa: ARG002 - token = AzureSdkCredential().token - decoded = AzureSdkCredential.decode_token(token) + credential = AzureSdkCredential() + decoded = credential.decode_token(credential.token) assert decoded["name"] == "username" assert decoded["oid"] == pytest.user_id assert decoded["upn"] == "username@example.com" assert decoded["tid"] == pytest.guid_tenant + + +class TestGraphApiCredential: + def test_get_credential( + self, + mock_devicecodecredential_get_token, # noqa: ARG002 + mock_devicecodecredential_authenticate, # noqa: ARG002 + tmp_config_dir, # noqa: ARG002 + ): + credential = GraphApiCredential(pytest.guid_tenant) + assert isinstance(credential.get_credential(), DeviceCodeCredential) + + def test_get_token( + self, + mock_graphapicredential_get_token, # noqa: ARG002 + ): + credential = GraphApiCredential(pytest.guid_tenant) + assert isinstance(credential.token, str) + + def test_decode_token( + self, + mock_graphapicredential_get_token, # noqa: ARG002 + ): + credential = GraphApiCredential(pytest.guid_tenant) + decoded = credential.decode_token(credential.token) + assert decoded["scp"] == "GroupMember.Read.All User.Read.All" + assert decoded["tid"] == pytest.guid_tenant From b8c94602d5e80044b205b0cbb5044404a3173383 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Tue, 9 Jul 2024 18:25:25 +0100 Subject: [PATCH 41/54] :white_check_mark: Add tests for get_subscription --- tests/external/api/conftest.py | 1 + tests/external/api/test_azure_sdk.py | 75 ++++++++++++++++++++++------ 2 files changed, 61 insertions(+), 15 deletions(-) diff --git a/tests/external/api/conftest.py b/tests/external/api/conftest.py index fc6179b874..b69e3f1f9d 100644 --- a/tests/external/api/conftest.py +++ b/tests/external/api/conftest.py @@ -14,6 +14,7 @@ def pytest_configure(): """Define constants for use across multiple tests""" + pytest.subscription_id = "64954419-0f4b-4f6f-bd76-0d4de6ca8b83" pytest.user_upn = "username@example.com" pytest.user_id = "80b4ccfd-73ef-41b7-bb22-8ec268ec040b" diff --git a/tests/external/api/test_azure_sdk.py b/tests/external/api/test_azure_sdk.py index 51e36aca36..fb184a4355 100644 --- a/tests/external/api/test_azure_sdk.py +++ b/tests/external/api/test_azure_sdk.py @@ -1,9 +1,10 @@ import pytest from azure.core.exceptions import ResourceNotFoundError +from azure.mgmt.resource.subscriptions.models import Subscription from pytest import fixture import data_safe_haven.external.api.azure_sdk -from data_safe_haven.exceptions import DataSafeHavenAzureError +from data_safe_haven.exceptions import DataSafeHavenAzureError, DataSafeHavenValueError from data_safe_haven.external import AzureSdk, GraphApi @@ -62,49 +63,93 @@ def mock_blob_client( ) +@fixture +def mock_subscription_client(monkeypatch): + class MockSubscriptionsOperations: + def __init__(self, *args, **kwargs): + pass + + def list(self): + subscription_1 = Subscription() + subscription_1.display_name = "Subscription 1" + subscription_1.id = pytest.subscription_id + subscription_2 = Subscription() + subscription_2.display_name = "Subscription 2" + return [subscription_1, subscription_2] + + class MockSubscriptionClient: + def __init__(self, *args, **kwargs): + pass + + @property + def subscriptions(self): + return MockSubscriptionsOperations() + + monkeypatch.setattr( + data_safe_haven.external.api.azure_sdk, + "SubscriptionClient", + MockSubscriptionClient, + ) + + class TestAzureSdk: def test_entra_directory(self): - api = AzureSdk("subscription name") - assert isinstance(api.entra_directory, GraphApi) + sdk = AzureSdk("subscription name") + assert isinstance(sdk.entra_directory, GraphApi) def test_subscription_id( self, mock_azureapi_get_subscription, # noqa: ARG002 ): - api = AzureSdk("subscription name") - assert api.subscription_id == pytest.guid_subscription + sdk = AzureSdk("subscription name") + assert sdk.subscription_id == pytest.guid_subscription def test_tenant_id( self, mock_azureapi_get_subscription, # noqa: ARG002 ): - api = AzureSdk("subscription name") - assert api.tenant_id == pytest.guid_tenant + sdk = AzureSdk("subscription name") + assert sdk.tenant_id == pytest.guid_tenant def test_blob_exists(self, mock_blob_client): # noqa: ARG002 - api = AzureSdk("subscription name") - exists = api.blob_exists( + sdk = AzureSdk("subscription name") + exists = sdk.blob_exists( "exists", "resource_group", "storage_account", "storage_container" ) assert isinstance(exists, bool) assert exists def test_blob_does_not_exist(self, mock_blob_client): # noqa: ARG002 - api = AzureSdk("subscription name") - exists = api.blob_exists( + sdk = AzureSdk("subscription name") + exists = sdk.blob_exists( "abc.txt", "resource_group", "storage_account", "storage_container" ) assert isinstance(exists, bool) assert not exists def test_get_keyvault_key(self, mock_key_client): # noqa: ARG002 - api = AzureSdk("subscription name") - key = api.get_keyvault_key("exists", "key vault name") + sdk = AzureSdk("subscription name") + key = sdk.get_keyvault_key("exists", "key vault name") assert key == "key: exists" def test_get_keyvault_key_missing(self, mock_key_client): # noqa: ARG002 - api = AzureSdk("subscription name") + sdk = AzureSdk("subscription name") with pytest.raises( DataSafeHavenAzureError, match="Failed to retrieve key does not exist" ): - api.get_keyvault_key("does not exist", "key vault name") + sdk.get_keyvault_key("does not exist", "key vault name") + + def test_get_subscription(self, mock_subscription_client): # noqa: ARG002 + sdk = AzureSdk("subscription name") + subscription = sdk.get_subscription("Subscription 1") + assert isinstance(subscription, Subscription) + assert subscription.display_name == "Subscription 1" + assert subscription.id == pytest.subscription_id + + def test_get_subscription_fails(self, mock_subscription_client): # noqa: ARG002 + sdk = AzureSdk("subscription name") + with pytest.raises( + DataSafeHavenValueError, + match="Could not find subscription 'Subscription 3'", + ): + sdk.get_subscription("Subscription 3") From 4d1688297093a94aeae533a2a5906d7bfd16f682 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Tue, 9 Jul 2024 18:26:36 +0100 Subject: [PATCH 42/54] :bug: Show login message by default --- data_safe_haven/external/api/credentials.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_safe_haven/external/api/credentials.py b/data_safe_haven/external/api/credentials.py index 34fc301d81..de820fcff9 100644 --- a/data_safe_haven/external/api/credentials.py +++ b/data_safe_haven/external/api/credentials.py @@ -30,7 +30,7 @@ def __init__( scopes: Sequence[str], tenant_id: str | None = None, ) -> None: - self._show_login_msg = False + self._show_login_msg = True self.logger = get_logger() self.scopes = scopes self.tenant_id = tenant_id From a6913b68d591b8863b2802f960e9ea45e0e81625 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Tue, 9 Jul 2024 18:37:53 +0100 Subject: [PATCH 43/54] :white_check_mark: Add basic tests for GraphApi --- tests/external/api/conftest.py | 1 - tests/external/api/test_azure_sdk.py | 4 ++-- tests/external/api/test_graph_api.py | 19 +++++++++++++++++++ 3 files changed, 21 insertions(+), 3 deletions(-) create mode 100644 tests/external/api/test_graph_api.py diff --git a/tests/external/api/conftest.py b/tests/external/api/conftest.py index b69e3f1f9d..fc6179b874 100644 --- a/tests/external/api/conftest.py +++ b/tests/external/api/conftest.py @@ -14,7 +14,6 @@ def pytest_configure(): """Define constants for use across multiple tests""" - pytest.subscription_id = "64954419-0f4b-4f6f-bd76-0d4de6ca8b83" pytest.user_upn = "username@example.com" pytest.user_id = "80b4ccfd-73ef-41b7-bb22-8ec268ec040b" diff --git a/tests/external/api/test_azure_sdk.py b/tests/external/api/test_azure_sdk.py index fb184a4355..4f7ff84dd2 100644 --- a/tests/external/api/test_azure_sdk.py +++ b/tests/external/api/test_azure_sdk.py @@ -72,7 +72,7 @@ def __init__(self, *args, **kwargs): def list(self): subscription_1 = Subscription() subscription_1.display_name = "Subscription 1" - subscription_1.id = pytest.subscription_id + subscription_1.id = pytest.guid_subscription subscription_2 = Subscription() subscription_2.display_name = "Subscription 2" return [subscription_1, subscription_2] @@ -144,7 +144,7 @@ def test_get_subscription(self, mock_subscription_client): # noqa: ARG002 subscription = sdk.get_subscription("Subscription 1") assert isinstance(subscription, Subscription) assert subscription.display_name == "Subscription 1" - assert subscription.id == pytest.subscription_id + assert subscription.id == pytest.guid_subscription def test_get_subscription_fails(self, mock_subscription_client): # noqa: ARG002 sdk = AzureSdk("subscription name") diff --git a/tests/external/api/test_graph_api.py b/tests/external/api/test_graph_api.py new file mode 100644 index 0000000000..a9813f866c --- /dev/null +++ b/tests/external/api/test_graph_api.py @@ -0,0 +1,19 @@ +import pytest + +from data_safe_haven.external import GraphApi + + +class TestGraphApi: + def test_from_scopes(self): + api = GraphApi.from_scopes( + scopes=["scope1", "scope2"], tenant_id=pytest.guid_tenant + ) + assert api.credential.tenant_id == pytest.guid_tenant + assert "scope1" in api.credential.scopes + assert "scope2" in api.credential.scopes + + def test_from_token(self, graph_api_token): + api = GraphApi.from_token(graph_api_token) + assert api.credential.tenant_id == pytest.guid_tenant + assert "GroupMember.Read.All" in api.credential.scopes + assert "User.Read.All" in api.credential.scopes From f6951620eea894b9343a2910e78dec8612705ba2 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Thu, 11 Jul 2024 10:09:10 +0100 Subject: [PATCH 44/54] Move coverage omit to toml configuration --- pyproject.toml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d3d95c0086..bcd291c64a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,10 @@ source = ["data_safe_haven/"] [tool.coverage.run] relative_files = true +omit= [ + "tests/*", + "data_safe_haven/resources/*", +] [tool.hatch.envs.default] pre-install-commands = ["pip install -r requirements.txt"] @@ -122,7 +126,7 @@ dependencies = [ pre-install-commands = ["pip install -r requirements.txt"] [tool.hatch.envs.test.scripts] -test = "coverage run --omit=tests/* -m pytest {args: tests}" +test = "coverage run -m pytest {args: tests}" test-report = "coverage report {args:}" test-coverage = ["test", "test-report"] From ea2419b564b6f78ddfa0c704ee7afc6b36a8a103 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 11 Jul 2024 11:14:05 +0100 Subject: [PATCH 45/54] :white_check_mark: Test invalid JWT --- data_safe_haven/external/api/credentials.py | 1 - tests/external/api/conftest.py | 9 +++++++++ tests/external/api/test_credentials.py | 13 +++++++++++++ 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/data_safe_haven/external/api/credentials.py b/data_safe_haven/external/api/credentials.py index de820fcff9..b5b79ca5c4 100644 --- a/data_safe_haven/external/api/credentials.py +++ b/data_safe_haven/external/api/credentials.py @@ -57,7 +57,6 @@ def decode_token(cls, auth_token: str) -> dict[str, Any]: @abstractmethod def get_credential(self) -> TokenCredential: """Get a credential provider from the child class.""" - pass def get_token( self, diff --git a/tests/external/api/conftest.py b/tests/external/api/conftest.py index fc6179b874..bbbc088c02 100644 --- a/tests/external/api/conftest.py +++ b/tests/external/api/conftest.py @@ -62,6 +62,15 @@ def mock_azureclicredential_get_token(mocker, azure_cli_token): ) +@pytest.fixture +def mock_azureclicredential_get_token_invalid(mocker): + mocker.patch.object( + AzureCliCredential, + "get_token", + return_value=AccessToken("not a jwt", 0), + ) + + @pytest.fixture def mock_devicecodecredential_get_token(mocker, graph_api_token): mocker.patch.object( diff --git a/tests/external/api/test_credentials.py b/tests/external/api/test_credentials.py index fa382000db..b912e85745 100644 --- a/tests/external/api/test_credentials.py +++ b/tests/external/api/test_credentials.py @@ -1,12 +1,25 @@ import pytest from azure.identity import AzureCliCredential, DeviceCodeCredential +from data_safe_haven.exceptions import DataSafeHavenAzureError from data_safe_haven.external.api.credentials import ( AzureSdkCredential, GraphApiCredential, ) +class TestDeferredCredential: + def test_decode_token_error( + self, mock_azureclicredential_get_token_invalid # noqa: ARG002 + ): + credential = AzureSdkCredential() + with pytest.raises( + DataSafeHavenAzureError, + match="Error getting account information from Azure CLI.", + ): + credential.decode_token(credential.token) + + class TestAzureSdkCredential: def test_get_credential(self, mock_azureclicredential_get_token): # noqa: ARG002 credential = AzureSdkCredential() From 87a4af45424c073e7406fdda222460811535ea63 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 11 Jul 2024 12:05:14 +0100 Subject: [PATCH 46/54] :white_check_mark: Add test for loading an AuthenticationRecord from cache --- tests/external/api/conftest.py | 21 +++++++++++---- tests/external/api/test_credentials.py | 37 +++++++++++++++++++++++++- 2 files changed, 52 insertions(+), 6 deletions(-) diff --git a/tests/external/api/conftest.py b/tests/external/api/conftest.py index bbbc088c02..6fe9f2d37b 100644 --- a/tests/external/api/conftest.py +++ b/tests/external/api/conftest.py @@ -71,6 +71,15 @@ def mock_azureclicredential_get_token_invalid(mocker): ) +@pytest.fixture +def mock_devicecodecredential_authenticate(mocker, authentication_record): + mocker.patch.object( + DeviceCodeCredential, + "authenticate", + return_value=authentication_record, + ) + + @pytest.fixture def mock_devicecodecredential_get_token(mocker, graph_api_token): mocker.patch.object( @@ -81,11 +90,13 @@ def mock_devicecodecredential_get_token(mocker, graph_api_token): @pytest.fixture -def mock_devicecodecredential_authenticate(mocker, authentication_record): - mocker.patch.object( - DeviceCodeCredential, - "authenticate", - return_value=authentication_record, +def mock_devicecodecredential_new(mocker, authentication_record): + class MockDeviceCodeCredential: + def authenticate(self, *args, **kwargs): # noqa: ARG002 + return authentication_record + + return mocker.patch.object( + DeviceCodeCredential, "__new__", return_value=MockDeviceCodeCredential() ) diff --git a/tests/external/api/test_credentials.py b/tests/external/api/test_credentials.py index b912e85745..12e60a83be 100644 --- a/tests/external/api/test_credentials.py +++ b/tests/external/api/test_credentials.py @@ -1,6 +1,10 @@ import pytest -from azure.identity import AzureCliCredential, DeviceCodeCredential +from azure.identity import ( + AzureCliCredential, + DeviceCodeCredential, +) +from data_safe_haven.directories import config_dir from data_safe_haven.exceptions import DataSafeHavenAzureError from data_safe_haven.external.api.credentials import ( AzureSdkCredential, @@ -63,3 +67,34 @@ def test_decode_token( decoded = credential.decode_token(credential.token) assert decoded["scp"] == "GroupMember.Read.All User.Read.All" assert decoded["tid"] == pytest.guid_tenant + + def test_authentication_record_is_used( + self, + mocker, + authentication_record, + mock_devicecodecredential_new, + tmp_config_dir, # noqa: ARG002 + ): + credential = GraphApiCredential(pytest.guid_tenant) + + # Write an authentication record + cache_name = f"dsh-{credential.tenant_id}" + authentication_record_path = ( + config_dir() / f".msal-authentication-cache-{cache_name}" + ) + with open(authentication_record_path, "w") as f_auth: + f_auth.write(authentication_record.serialize()) + + credential.get_credential() + + # Note that we cannot check the calls exactly as the objects we use would have + # different IDs + mock_devicecodecredential_new.assert_called_once_with( + mocker.ANY, # this is 'self' + authentication_record=mocker.ANY, + cache_persistence_options=mocker.ANY, + prompt_callback=mocker.ANY, + ) + + # Remove the authentication record + authentication_record_path.unlink(missing_ok=True) From ac3edd81f96c6c35b46e2bc4710655049c0f988e Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 11 Jul 2024 12:16:18 +0100 Subject: [PATCH 47/54] :white_check_mark: Test GraphApi constructor with invalid token --- tests/external/api/test_graph_api.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/external/api/test_graph_api.py b/tests/external/api/test_graph_api.py index a9813f866c..36ce525dfa 100644 --- a/tests/external/api/test_graph_api.py +++ b/tests/external/api/test_graph_api.py @@ -1,5 +1,6 @@ import pytest +from data_safe_haven.exceptions import DataSafeHavenValueError from data_safe_haven.external import GraphApi @@ -17,3 +18,10 @@ def test_from_token(self, graph_api_token): assert api.credential.tenant_id == pytest.guid_tenant assert "GroupMember.Read.All" in api.credential.scopes assert "User.Read.All" in api.credential.scopes + + def test_from_token_invalid(self): + with pytest.raises( + DataSafeHavenValueError, + match="Could not construct GraphApi from provided token.", + ): + GraphApi.from_token("not a jwt") From 6cf4be678c6ec5123934e5626f0ffc52b32a1837 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 11 Jul 2024 12:34:51 +0100 Subject: [PATCH 48/54] :white_check_mark: Add test for GraphApi.token property --- tests/external/api/test_graph_api.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/external/api/test_graph_api.py b/tests/external/api/test_graph_api.py index 36ce525dfa..0f8892ff13 100644 --- a/tests/external/api/test_graph_api.py +++ b/tests/external/api/test_graph_api.py @@ -25,3 +25,11 @@ def test_from_token_invalid(self): match="Could not construct GraphApi from provided token.", ): GraphApi.from_token("not a jwt") + + def test_token( + self, + graph_api_token, + mock_graphapicredential_get_token, # noqa: ARG002 + ): + api = GraphApi.from_token(graph_api_token) + assert api.token == graph_api_token From badc2996fa2410a77340be2ab7455fc610edd304 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 11 Jul 2024 13:32:08 +0100 Subject: [PATCH 49/54] :white_check_mark: Test handling of ClientAuthenticationError from SubscriptionClient --- tests/external/api/test_azure_sdk.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/tests/external/api/test_azure_sdk.py b/tests/external/api/test_azure_sdk.py index 4f7ff84dd2..28f2fce878 100644 --- a/tests/external/api/test_azure_sdk.py +++ b/tests/external/api/test_azure_sdk.py @@ -1,10 +1,15 @@ import pytest -from azure.core.exceptions import ResourceNotFoundError +from azure.core.exceptions import ClientAuthenticationError, ResourceNotFoundError +from azure.mgmt.resource.subscriptions import SubscriptionClient from azure.mgmt.resource.subscriptions.models import Subscription from pytest import fixture import data_safe_haven.external.api.azure_sdk -from data_safe_haven.exceptions import DataSafeHavenAzureError, DataSafeHavenValueError +from data_safe_haven.exceptions import ( + DataSafeHavenAzureAPIAuthenticationError, + DataSafeHavenAzureError, + DataSafeHavenValueError, +) from data_safe_haven.external import AzureSdk, GraphApi @@ -146,10 +151,26 @@ def test_get_subscription(self, mock_subscription_client): # noqa: ARG002 assert subscription.display_name == "Subscription 1" assert subscription.id == pytest.guid_subscription - def test_get_subscription_fails(self, mock_subscription_client): # noqa: ARG002 + def test_get_subscription_does_not_exist( + self, mock_subscription_client # noqa: ARG002 + ): sdk = AzureSdk("subscription name") with pytest.raises( DataSafeHavenValueError, match="Could not find subscription 'Subscription 3'", ): sdk.get_subscription("Subscription 3") + + def test_get_subscription_authentication_error(self, mocker): + def raise_client_authentication_error(*args): # noqa: ARG001 + raise ClientAuthenticationError + + mocker.patch.object( + SubscriptionClient, "__new__", side_effect=raise_client_authentication_error + ) + sdk = AzureSdk("subscription name") + with pytest.raises( + DataSafeHavenAzureAPIAuthenticationError, + match="Failed to authenticate with Azure API.", + ): + sdk.get_subscription("Subscription 1") From ef575067bb392a915adfc3e9ff9fd319600c89ee Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 11 Jul 2024 14:39:56 +0100 Subject: [PATCH 50/54] :white_check_mark: Test the device code callback --- tests/external/api/conftest.py | 13 ++++++++++++- tests/external/api/test_credentials.py | 15 +++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/tests/external/api/conftest.py b/tests/external/api/conftest.py index 6fe9f2d37b..c56f497550 100644 --- a/tests/external/api/conftest.py +++ b/tests/external/api/conftest.py @@ -1,3 +1,4 @@ +import datetime import os import jwt @@ -92,11 +93,21 @@ def mock_devicecodecredential_get_token(mocker, graph_api_token): @pytest.fixture def mock_devicecodecredential_new(mocker, authentication_record): class MockDeviceCodeCredential: + def __init__(self, *args, prompt_callback, **kwargs): # noqa: ARG002 + self.prompt_callback = prompt_callback + def authenticate(self, *args, **kwargs): # noqa: ARG002 + self.prompt_callback( + "VERIFICATION_URI", + "USER_DEVICE_CODE", + datetime.datetime.now(tz=datetime.UTC), + ) return authentication_record return mocker.patch.object( - DeviceCodeCredential, "__new__", return_value=MockDeviceCodeCredential() + DeviceCodeCredential, + "__new__", + lambda *args, **kwargs: MockDeviceCodeCredential(*args, **kwargs), ) diff --git a/tests/external/api/test_credentials.py b/tests/external/api/test_credentials.py index 12e60a83be..ba01eb315d 100644 --- a/tests/external/api/test_credentials.py +++ b/tests/external/api/test_credentials.py @@ -52,6 +52,21 @@ def test_get_credential( credential = GraphApiCredential(pytest.guid_tenant) assert isinstance(credential.get_credential(), DeviceCodeCredential) + def test_get_credential_callback( + self, + capsys, + mock_devicecodecredential_new, # noqa: ARG002 + tmp_config_dir, # noqa: ARG002 + ): + credential = GraphApiCredential(pytest.guid_tenant) + credential.get_credential() + captured = capsys.readouterr() + cleaned_stdout = " ".join(captured.out.split()) + assert ( + "Go to VERIFICATION_URI in a web browser and enter the code USER_DEVICE_CODE at the prompt." + in cleaned_stdout + ) + def test_get_token( self, mock_graphapicredential_get_token, # noqa: ARG002 From 6333ba2db88414810b8e7d38e767f6752140cc92 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 11 Jul 2024 15:43:12 +0100 Subject: [PATCH 51/54] :wrench: Reorganise tests --- tests/external/api/conftest.py | 28 +++++++-- tests/external/api/test_credentials.py | 79 ++++++++++++-------------- 2 files changed, 60 insertions(+), 47 deletions(-) diff --git a/tests/external/api/conftest.py b/tests/external/api/conftest.py index c56f497550..7f00487f6a 100644 --- a/tests/external/api/conftest.py +++ b/tests/external/api/conftest.py @@ -54,9 +54,18 @@ def graph_api_token(): ) +@pytest.fixture +def mock_authenticationrecord_deserialize(mocker, authentication_record): + return mocker.patch.object( + AuthenticationRecord, + "deserialize", + return_value=authentication_record, + ) + + @pytest.fixture def mock_azureclicredential_get_token(mocker, azure_cli_token): - mocker.patch.object( + return mocker.patch.object( AzureCliCredential, "get_token", return_value=AccessToken(azure_cli_token, 0), @@ -65,7 +74,7 @@ def mock_azureclicredential_get_token(mocker, azure_cli_token): @pytest.fixture def mock_azureclicredential_get_token_invalid(mocker): - mocker.patch.object( + return mocker.patch.object( AzureCliCredential, "get_token", return_value=AccessToken("not a jwt", 0), @@ -74,7 +83,7 @@ def mock_azureclicredential_get_token_invalid(mocker): @pytest.fixture def mock_devicecodecredential_authenticate(mocker, authentication_record): - mocker.patch.object( + return mocker.patch.object( DeviceCodeCredential, "authenticate", return_value=authentication_record, @@ -83,7 +92,7 @@ def mock_devicecodecredential_authenticate(mocker, authentication_record): @pytest.fixture def mock_devicecodecredential_get_token(mocker, graph_api_token): - mocker.patch.object( + return mocker.patch.object( DeviceCodeCredential, "get_token", return_value=AccessToken(graph_api_token, 0), @@ -104,13 +113,22 @@ def authenticate(self, *args, **kwargs): # noqa: ARG002 ) return authentication_record - return mocker.patch.object( + mocker.patch.object( DeviceCodeCredential, "__new__", lambda *args, **kwargs: MockDeviceCodeCredential(*args, **kwargs), ) +@pytest.fixture +def mock_graphapicredential_get_credential(mocker): + mocker.patch.object( + GraphApiCredential, + "get_credential", + return_value=DeviceCodeCredential(), + ) + + @pytest.fixture def mock_graphapicredential_get_token(mocker, graph_api_token): mocker.patch.object( diff --git a/tests/external/api/test_credentials.py b/tests/external/api/test_credentials.py index ba01eb315d..3fe94c3871 100644 --- a/tests/external/api/test_credentials.py +++ b/tests/external/api/test_credentials.py @@ -43,10 +43,44 @@ def test_decode_token(self, mock_azureclicredential_get_token): # noqa: ARG002 class TestGraphApiCredential: + def test_authentication_record_is_used( + self, + authentication_record, + mock_authenticationrecord_deserialize, + mock_devicecodecredential_authenticate, # noqa: ARG002 + tmp_config_dir, # noqa: ARG002 + ): + # Write an authentication record + cache_name = f"dsh-{pytest.guid_tenant}" + authentication_record_path = ( + config_dir() / f".msal-authentication-cache-{cache_name}" + ) + serialised_record = authentication_record.serialize() + with open(authentication_record_path, "w") as f_auth: + f_auth.write(serialised_record) + + # Get a credential + credential = GraphApiCredential(pytest.guid_tenant) + credential.get_credential() + + # Remove the authentication record + authentication_record_path.unlink(missing_ok=True) + + mock_authenticationrecord_deserialize.assert_called_once_with(serialised_record) + + def test_decode_token( + self, + mock_graphapicredential_get_token, # noqa: ARG002 + ): + credential = GraphApiCredential(pytest.guid_tenant) + decoded = credential.decode_token(credential.token) + assert decoded["scp"] == "GroupMember.Read.All User.Read.All" + assert decoded["tid"] == pytest.guid_tenant + def test_get_credential( self, - mock_devicecodecredential_get_token, # noqa: ARG002 mock_devicecodecredential_authenticate, # noqa: ARG002 + mock_devicecodecredential_get_token, # noqa: ARG002 tmp_config_dir, # noqa: ARG002 ): credential = GraphApiCredential(pytest.guid_tenant) @@ -69,47 +103,8 @@ def test_get_credential_callback( def test_get_token( self, - mock_graphapicredential_get_token, # noqa: ARG002 + mock_devicecodecredential_get_token, # noqa: ARG002 + mock_graphapicredential_get_credential, # noqa: ARG002 ): credential = GraphApiCredential(pytest.guid_tenant) assert isinstance(credential.token, str) - - def test_decode_token( - self, - mock_graphapicredential_get_token, # noqa: ARG002 - ): - credential = GraphApiCredential(pytest.guid_tenant) - decoded = credential.decode_token(credential.token) - assert decoded["scp"] == "GroupMember.Read.All User.Read.All" - assert decoded["tid"] == pytest.guid_tenant - - def test_authentication_record_is_used( - self, - mocker, - authentication_record, - mock_devicecodecredential_new, - tmp_config_dir, # noqa: ARG002 - ): - credential = GraphApiCredential(pytest.guid_tenant) - - # Write an authentication record - cache_name = f"dsh-{credential.tenant_id}" - authentication_record_path = ( - config_dir() / f".msal-authentication-cache-{cache_name}" - ) - with open(authentication_record_path, "w") as f_auth: - f_auth.write(authentication_record.serialize()) - - credential.get_credential() - - # Note that we cannot check the calls exactly as the objects we use would have - # different IDs - mock_devicecodecredential_new.assert_called_once_with( - mocker.ANY, # this is 'self' - authentication_record=mocker.ANY, - cache_persistence_options=mocker.ANY, - prompt_callback=mocker.ANY, - ) - - # Remove the authentication record - authentication_record_path.unlink(missing_ok=True) From fb8aa67ac762b31504546cfad6bc4ae6d472f5d5 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 11 Jul 2024 17:35:42 +0100 Subject: [PATCH 52/54] :white_check_mark: Add test for GraphApi.add_custom_domain --- data_safe_haven/external/api/graph_api.py | 5 ++--- tests/external/api/conftest.py | 10 ++++++++++ tests/external/api/test_graph_api.py | 23 +++++++++++++++++++++++ 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/data_safe_haven/external/api/graph_api.py b/data_safe_haven/external/api/graph_api.py index 4b0355085a..884123b536 100644 --- a/data_safe_haven/external/api/graph_api.py +++ b/data_safe_haven/external/api/graph_api.py @@ -95,9 +95,8 @@ def add_custom_domain(self, domain_name: str) -> str: try: # Create the Entra ID custom domain if it does not already exist domains = self.read_domains() - domain_exists = any(domain["id"] == domain_name for domain in domains) - if not domain_exists: - response = self.http_post( + if not any(domain["id"] == domain_name for domain in domains): + self.http_post( f"{self.base_endpoint}/domains", json={"id": domain_name}, ) diff --git a/tests/external/api/conftest.py b/tests/external/api/conftest.py index 7f00487f6a..c6b523743d 100644 --- a/tests/external/api/conftest.py +++ b/tests/external/api/conftest.py @@ -10,6 +10,7 @@ DeviceCodeCredential, ) +from data_safe_haven.external import GraphApi from data_safe_haven.external.api.credentials import GraphApiCredential @@ -120,6 +121,15 @@ def authenticate(self, *args, **kwargs): # noqa: ARG002 ) +@pytest.fixture +def mock_graphapi_read_domains(mocker): + mocker.patch.object( + GraphApi, + "read_domains", + return_value=[{"id": "example.com"}], + ) + + @pytest.fixture def mock_graphapicredential_get_credential(mocker): mocker.patch.object( diff --git a/tests/external/api/test_graph_api.py b/tests/external/api/test_graph_api.py index 0f8892ff13..afb0448760 100644 --- a/tests/external/api/test_graph_api.py +++ b/tests/external/api/test_graph_api.py @@ -26,6 +26,29 @@ def test_from_token_invalid(self): ): GraphApi.from_token("not a jwt") + def test_add_custom_domain( + self, + requests_mock, + mock_graphapicredential_get_token, # noqa: ARG002 + ): + domain_name = "example.com" + requests_mock.get( + "https://graph.microsoft.com/v1.0/domains", + json={"value": [{"id": domain_name}, {"id": "example.org"}]}, + ) + requests_mock.get( + f"https://graph.microsoft.com/v1.0/domains/{domain_name}/verificationDnsRecords", + json={ + "value": [ + {"recordType": "Caa", "text": "caa-record-text"}, + {"recordType": "Txt", "text": "txt-record-text"}, + ] + }, + ) + api = GraphApi.from_scopes(scopes=[], tenant_id=pytest.guid_tenant) + result = api.add_custom_domain(domain_name) + assert result == "txt-record-text" + def test_token( self, graph_api_token, From 5bbbddf09bf35bd5ecc6218559ed94cf97f9de3f Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 11 Jul 2024 18:49:47 +0100 Subject: [PATCH 53/54] :white_check_mark: Add test for http_get failure --- tests/external/api/test_graph_api.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/tests/external/api/test_graph_api.py b/tests/external/api/test_graph_api.py index afb0448760..b9c287e81d 100644 --- a/tests/external/api/test_graph_api.py +++ b/tests/external/api/test_graph_api.py @@ -1,6 +1,10 @@ import pytest +import requests -from data_safe_haven.exceptions import DataSafeHavenValueError +from data_safe_haven.exceptions import ( + DataSafeHavenMicrosoftGraphError, + DataSafeHavenValueError, +) from data_safe_haven.external import GraphApi @@ -49,6 +53,20 @@ def test_add_custom_domain( result = api.add_custom_domain(domain_name) assert result == "txt-record-text" + def test_http_get_failure( + self, + requests_mock, + mock_graphapicredential_get_token, # noqa: ARG002 + ): + url = "https://example.com" + requests_mock.get(url, exc=requests.exceptions.ConnectTimeout) + api = GraphApi.from_scopes(scopes=[], tenant_id=pytest.guid_tenant) + with pytest.raises( + DataSafeHavenMicrosoftGraphError, + match="Could not execute GET request to 'https://example.com'.", + ): + api.http_get(url) + def test_token( self, graph_api_token, From 901a74cef3625ae64accb0044a11306468d5810b Mon Sep 17 00:00:00 2001 From: James Robinson Date: Fri, 12 Jul 2024 11:17:55 +0100 Subject: [PATCH 54/54] :recycle: Switch to using 'config' in pytest_configure --- tests/conftest.py | 46 +++++++++++++------------- tests/external/api/conftest.py | 25 +++++++------- tests/external/api/test_azure_sdk.py | 14 ++++---- tests/external/api/test_credentials.py | 29 ++++++++++------ tests/external/api/test_graph_api.py | 16 +++++---- 5 files changed, 71 insertions(+), 59 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 337933664f..dae130f7c2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,6 @@ from shutil import which from subprocess import run -import pytest import yaml from azure.core.credentials import AccessToken, TokenCredential from azure.mgmt.resource.subscriptions.models import Subscription @@ -34,20 +33,21 @@ from data_safe_haven.logging import init_logging -def pytest_configure(): +def pytest_configure(config): """Define constants for use across multiple tests""" - pytest.guid_admin = "00edec65-b071-4d26-8779-a9fe791c6e14" - pytest.guid_entra = "48b2425b-5f2c-4cbd-9458-0441daa8994c" - pytest.guid_subscription = "35ebced1-4e7a-4c1f-b634-c0886937085d" - pytest.guid_tenant = "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd" + config.guid_admin = "00edec65-b071-4d26-8779-a9fe791c6e14" + config.guid_entra = "48b2425b-5f2c-4cbd-9458-0441daa8994c" + config.guid_subscription = "35ebced1-4e7a-4c1f-b634-c0886937085d" + config.guid_tenant = "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd" + config.guid_user = "80b4ccfd-73ef-41b7-bb22-8ec268ec040b" @fixture -def config_section_azure(): +def config_section_azure(request): return ConfigSectionAzure( location="uksouth", - subscription_id=pytest.guid_subscription, - tenant_id=pytest.guid_tenant, + subscription_id=request.config.guid_subscription, + tenant_id=request.config.guid_tenant, ) @@ -57,10 +57,10 @@ def config_section_shm(config_section_shm_dict): @fixture -def config_section_shm_dict(): +def config_section_shm_dict(request): return { - "admin_group_id": pytest.guid_admin, - "entra_tenant_id": pytest.guid_entra, + "admin_group_id": request.config.guid_admin, + "entra_tenant_id": request.config.guid_entra, "fqdn": "shm.acme.com", } @@ -168,11 +168,11 @@ def log_directory(session_mocker, tmp_path_factory): @fixture -def mock_azureapi_get_subscription(mocker): +def mock_azureapi_get_subscription(mocker, request): subscription = Subscription() subscription.display_name = "Data Safe Haven Acme" - subscription.subscription_id = pytest.guid_subscription - subscription.tenant_id = pytest.guid_tenant + subscription.subscription_id = request.config.guid_subscription + subscription.tenant_id = request.config.guid_tenant mocker.patch.object( AzureSdk, "get_subscription", @@ -363,7 +363,7 @@ def shm_config_file(shm_config_yaml: str, tmp_path: Path) -> Path: @fixture -def shm_config_yaml(): +def shm_config_yaml(request): content = ( """--- azure: @@ -375,11 +375,11 @@ def shm_config_yaml(): entra_tenant_id: guid_entra fqdn: shm.acme.com """.replace( - "guid_admin", pytest.guid_admin + "guid_admin", request.config.guid_admin ) - .replace("guid_entra", pytest.guid_entra) - .replace("guid_subscription", pytest.guid_subscription) - .replace("guid_tenant", pytest.guid_tenant) + .replace("guid_entra", request.config.guid_entra) + .replace("guid_subscription", request.config.guid_subscription) + .replace("guid_tenant", request.config.guid_tenant) ) return yaml.dump(yaml.safe_load(content)) @@ -424,7 +424,7 @@ def sre_config_alternate( @fixture -def sre_config_yaml(): +def sre_config_yaml(request): content = """--- azure: location: uksouth @@ -449,9 +449,9 @@ def sre_config_yaml(): timezone: Europe/London workspace_skus: [] """.replace( - "guid_subscription", pytest.guid_subscription + "guid_subscription", request.config.guid_subscription ).replace( - "guid_tenant", pytest.guid_tenant + "guid_tenant", request.config.guid_tenant ) return yaml.dump(yaml.safe_load(content)) diff --git a/tests/external/api/conftest.py b/tests/external/api/conftest.py index c6b523743d..7f5a59bb63 100644 --- a/tests/external/api/conftest.py +++ b/tests/external/api/conftest.py @@ -14,42 +14,41 @@ from data_safe_haven.external.api.credentials import GraphApiCredential -def pytest_configure(): +def pytest_configure(config): """Define constants for use across multiple tests""" - pytest.user_upn = "username@example.com" - pytest.user_id = "80b4ccfd-73ef-41b7-bb22-8ec268ec040b" + config.user_upn = "username@example.com" @pytest.fixture -def authentication_record(): +def authentication_record(request): return AuthenticationRecord( - tenant_id=pytest.guid_tenant, + tenant_id=request.config.guid_tenant, client_id="14d82eec-204b-4c2f-b7e8-296a70dab67e", authority="login.microsoftonline.com", - home_account_id=pytest.user_id, - username=pytest.user_upn, + home_account_id=request.config.guid_user, + username=request.config.user_upn, ) @pytest.fixture -def azure_cli_token(): +def azure_cli_token(request): return jwt.encode( { "name": "username", - "oid": pytest.user_id, - "upn": pytest.user_upn, - "tid": pytest.guid_tenant, + "oid": request.config.guid_user, + "upn": request.config.user_upn, + "tid": request.config.guid_tenant, }, "key", ) @pytest.fixture -def graph_api_token(): +def graph_api_token(request): return jwt.encode( { "scp": "GroupMember.Read.All User.Read.All", - "tid": pytest.guid_tenant, + "tid": request.config.guid_tenant, }, "key", ) diff --git a/tests/external/api/test_azure_sdk.py b/tests/external/api/test_azure_sdk.py index 28f2fce878..7c2d001386 100644 --- a/tests/external/api/test_azure_sdk.py +++ b/tests/external/api/test_azure_sdk.py @@ -69,7 +69,7 @@ def mock_blob_client( @fixture -def mock_subscription_client(monkeypatch): +def mock_subscription_client(monkeypatch, request): class MockSubscriptionsOperations: def __init__(self, *args, **kwargs): pass @@ -77,7 +77,7 @@ def __init__(self, *args, **kwargs): def list(self): subscription_1 = Subscription() subscription_1.display_name = "Subscription 1" - subscription_1.id = pytest.guid_subscription + subscription_1.id = request.config.guid_subscription subscription_2 = Subscription() subscription_2.display_name = "Subscription 2" return [subscription_1, subscription_2] @@ -104,17 +104,19 @@ def test_entra_directory(self): def test_subscription_id( self, + request, mock_azureapi_get_subscription, # noqa: ARG002 ): sdk = AzureSdk("subscription name") - assert sdk.subscription_id == pytest.guid_subscription + assert sdk.subscription_id == request.config.guid_subscription def test_tenant_id( self, + request, mock_azureapi_get_subscription, # noqa: ARG002 ): sdk = AzureSdk("subscription name") - assert sdk.tenant_id == pytest.guid_tenant + assert sdk.tenant_id == request.config.guid_tenant def test_blob_exists(self, mock_blob_client): # noqa: ARG002 sdk = AzureSdk("subscription name") @@ -144,12 +146,12 @@ def test_get_keyvault_key_missing(self, mock_key_client): # noqa: ARG002 ): sdk.get_keyvault_key("does not exist", "key vault name") - def test_get_subscription(self, mock_subscription_client): # noqa: ARG002 + def test_get_subscription(self, request, mock_subscription_client): # noqa: ARG002 sdk = AzureSdk("subscription name") subscription = sdk.get_subscription("Subscription 1") assert isinstance(subscription, Subscription) assert subscription.display_name == "Subscription 1" - assert subscription.id == pytest.guid_subscription + assert subscription.id == request.config.guid_subscription def test_get_subscription_does_not_exist( self, mock_subscription_client # noqa: ARG002 diff --git a/tests/external/api/test_credentials.py b/tests/external/api/test_credentials.py index 3fe94c3871..0453bf6a87 100644 --- a/tests/external/api/test_credentials.py +++ b/tests/external/api/test_credentials.py @@ -33,25 +33,30 @@ def test_get_token(self, mock_azureclicredential_get_token): # noqa: ARG002 credential = AzureSdkCredential() assert isinstance(credential.token, str) - def test_decode_token(self, mock_azureclicredential_get_token): # noqa: ARG002 + def test_decode_token( + self, + request, + mock_azureclicredential_get_token, # noqa: ARG002 + ): credential = AzureSdkCredential() decoded = credential.decode_token(credential.token) assert decoded["name"] == "username" - assert decoded["oid"] == pytest.user_id + assert decoded["oid"] == request.config.guid_user assert decoded["upn"] == "username@example.com" - assert decoded["tid"] == pytest.guid_tenant + assert decoded["tid"] == request.config.guid_tenant class TestGraphApiCredential: def test_authentication_record_is_used( self, + request, authentication_record, mock_authenticationrecord_deserialize, mock_devicecodecredential_authenticate, # noqa: ARG002 tmp_config_dir, # noqa: ARG002 ): # Write an authentication record - cache_name = f"dsh-{pytest.guid_tenant}" + cache_name = f"dsh-{request.config.guid_tenant}" authentication_record_path = ( config_dir() / f".msal-authentication-cache-{cache_name}" ) @@ -60,7 +65,7 @@ def test_authentication_record_is_used( f_auth.write(serialised_record) # Get a credential - credential = GraphApiCredential(pytest.guid_tenant) + credential = GraphApiCredential(request.config.guid_tenant) credential.get_credential() # Remove the authentication record @@ -70,29 +75,32 @@ def test_authentication_record_is_used( def test_decode_token( self, + request, mock_graphapicredential_get_token, # noqa: ARG002 ): - credential = GraphApiCredential(pytest.guid_tenant) + credential = GraphApiCredential(request.config.guid_tenant) decoded = credential.decode_token(credential.token) assert decoded["scp"] == "GroupMember.Read.All User.Read.All" - assert decoded["tid"] == pytest.guid_tenant + assert decoded["tid"] == request.config.guid_tenant def test_get_credential( self, + request, mock_devicecodecredential_authenticate, # noqa: ARG002 mock_devicecodecredential_get_token, # noqa: ARG002 tmp_config_dir, # noqa: ARG002 ): - credential = GraphApiCredential(pytest.guid_tenant) + credential = GraphApiCredential(request.config.guid_tenant) assert isinstance(credential.get_credential(), DeviceCodeCredential) def test_get_credential_callback( self, capsys, + request, mock_devicecodecredential_new, # noqa: ARG002 tmp_config_dir, # noqa: ARG002 ): - credential = GraphApiCredential(pytest.guid_tenant) + credential = GraphApiCredential(request.config.guid_tenant) credential.get_credential() captured = capsys.readouterr() cleaned_stdout = " ".join(captured.out.split()) @@ -103,8 +111,9 @@ def test_get_credential_callback( def test_get_token( self, + request, mock_devicecodecredential_get_token, # noqa: ARG002 mock_graphapicredential_get_credential, # noqa: ARG002 ): - credential = GraphApiCredential(pytest.guid_tenant) + credential = GraphApiCredential(request.config.guid_tenant) assert isinstance(credential.token, str) diff --git a/tests/external/api/test_graph_api.py b/tests/external/api/test_graph_api.py index b9c287e81d..f9095db3de 100644 --- a/tests/external/api/test_graph_api.py +++ b/tests/external/api/test_graph_api.py @@ -9,17 +9,17 @@ class TestGraphApi: - def test_from_scopes(self): + def test_from_scopes(self, request): api = GraphApi.from_scopes( - scopes=["scope1", "scope2"], tenant_id=pytest.guid_tenant + scopes=["scope1", "scope2"], tenant_id=request.config.guid_tenant ) - assert api.credential.tenant_id == pytest.guid_tenant + assert api.credential.tenant_id == request.config.guid_tenant assert "scope1" in api.credential.scopes assert "scope2" in api.credential.scopes - def test_from_token(self, graph_api_token): + def test_from_token(self, request, graph_api_token): api = GraphApi.from_token(graph_api_token) - assert api.credential.tenant_id == pytest.guid_tenant + assert api.credential.tenant_id == request.config.guid_tenant assert "GroupMember.Read.All" in api.credential.scopes assert "User.Read.All" in api.credential.scopes @@ -32,6 +32,7 @@ def test_from_token_invalid(self): def test_add_custom_domain( self, + request, requests_mock, mock_graphapicredential_get_token, # noqa: ARG002 ): @@ -49,18 +50,19 @@ def test_add_custom_domain( ] }, ) - api = GraphApi.from_scopes(scopes=[], tenant_id=pytest.guid_tenant) + api = GraphApi.from_scopes(scopes=[], tenant_id=request.config.guid_tenant) result = api.add_custom_domain(domain_name) assert result == "txt-record-text" def test_http_get_failure( self, + request, requests_mock, mock_graphapicredential_get_token, # noqa: ARG002 ): url = "https://example.com" requests_mock.get(url, exc=requests.exceptions.ConnectTimeout) - api = GraphApi.from_scopes(scopes=[], tenant_id=pytest.guid_tenant) + api = GraphApi.from_scopes(scopes=[], tenant_id=request.config.guid_tenant) with pytest.raises( DataSafeHavenMicrosoftGraphError, match="Could not execute GET request to 'https://example.com'.",