diff --git a/data_safe_haven/external/__init__.py b/data_safe_haven/external/__init__.py index 05a3aed597..18473b5f6d 100644 --- a/data_safe_haven/external/__init__.py +++ b/data_safe_haven/external/__init__.py @@ -1,5 +1,5 @@ from .api.azure_api import AzureApi -from .api.azure_cli import AzureCli +from .api.azure_cli import AzureCliSingleton from .api.graph_api import GraphApi from .interface.azure_container_instance import AzureContainerInstance from .interface.azure_fileshare import AzureFileShare @@ -8,7 +8,7 @@ __all__ = [ "AzureApi", - "AzureCli", + "AzureCliSingleton", "AzureContainerInstance", "AzureFileShare", "AzureIPv4Range", diff --git a/data_safe_haven/external/api/azure_cli.py b/data_safe_haven/external/api/azure_cli.py index 54c86e6689..2848b22fba 100644 --- a/data_safe_haven/external/api/azure_cli.py +++ b/data_safe_haven/external/api/azure_cli.py @@ -1,44 +1,79 @@ """Interface to the Azure CLI""" +import json import subprocess -from typing import Any +from dataclasses import dataclass +from shutil import which + +import typer from data_safe_haven.exceptions import DataSafeHavenAzureError -from data_safe_haven.utility import LoggingSingleton +from data_safe_haven.utility import LoggingSingleton, Singleton + + +@dataclass +class AzureCliAccount: + """Dataclass for Azure CLI Account details""" + + name: str + id_: str + tenant_id: str -class AzureCli: +class AzureCliSingleton(metaclass=Singleton): """Interface to the Azure CLI""" - def __init__(self, *args: Any, **kwargs: Any): - super().__init__(*args, **kwargs) + def __init__(self) -> None: self.logger = LoggingSingleton() - def login(self) -> None: - """Force log in via the Azure CLI""" - try: - self.logger.debug("Attempting to login using Azure CLI.") - # We do not use `check` in subprocess as this raises a CalledProcessError - # which would break the loop. Instead we check the return code of - # `az account show` which will be 0 on success. - while True: - # Check whether we are already logged in - process = subprocess.run( - ["az", "account", "show"], capture_output=True, check=False - ) - if process.returncode == 0: - break - # Note that subprocess.run will block until the process terminates so - # we need to print the guidance first. - self.logger.info( - "Please login in your web browser at [bold]https://login.microsoftonline.com/organizations/oauth2/v2.0/authorize[/]." - ) - self.logger.info( - "If no web browser is available, please run [bold]az login --use-device-code[/] in a command line window." - ) - # Attempt to log in at the command line - process = subprocess.run( - ["az", "login"], capture_output=True, check=False + path = which("az") + if path is None: + msg = "Unable to find Azure CLI executable in your path.\nPlease ensure that Azure CLI is installed" + raise DataSafeHavenAzureError(msg) + self.path = path + + self._account: AzureCliAccount | None = None + self._confirmed = False + + @property + def account(self) -> AzureCliAccount: + if not self._account: + try: + result = subprocess.check_output( + [self.path, "account", "show", "--output", "json"], + stderr=subprocess.PIPE, + encoding="utf8", ) - except FileNotFoundError as exc: - msg = f"Please ensure that the Azure CLI is installed.\n{exc}" - raise DataSafeHavenAzureError(msg) from exc + except subprocess.CalledProcessError as exc: + msg = f"Error getting account information from Azure CLI.\n{exc.stderr}" + raise DataSafeHavenAzureError(msg) from exc + + try: + result_dict = json.loads(result) + except json.JSONDecodeError as exc: + msg = f"Unable to parse Azure CLI output as JSON.\n{result}" + raise DataSafeHavenAzureError(msg) from exc + + self._account = AzureCliAccount( + name=result_dict.get("user").get("name"), + id_=result_dict.get("id"), + tenant_id=result_dict.get("tenantId"), + ) + + return self._account + + def confirm(self) -> None: + """Prompt user to confirm the Azure CLI account is correct""" + if self._confirmed: + return None + + account = self.account + self.logger.info( + f"name: {account.name} (id: {account.id_}\ntenant: {account.tenant_id})" + ) + if not typer.confirm("Is this the Azure account you expect?\n"): + self.logger.error( + "Please use `az login` to connect to the correct Azure CLI account" + ) + raise typer.Exit(1) + + self._confirmed = True diff --git a/data_safe_haven/external/interface/azure_authenticator.py b/data_safe_haven/external/interface/azure_authenticator.py index 4652843ec9..c6f32d91e0 100644 --- a/data_safe_haven/external/interface/azure_authenticator.py +++ b/data_safe_haven/external/interface/azure_authenticator.py @@ -2,7 +2,7 @@ from typing import cast from azure.core.exceptions import ClientAuthenticationError -from azure.identity import DefaultAzureCredential +from azure.identity import AzureCliCredential from azure.mgmt.resource.subscriptions import SubscriptionClient from azure.mgmt.resource.subscriptions.models import Subscription @@ -17,17 +17,14 @@ class AzureAuthenticator: def __init__(self, subscription_name: str) -> None: self.subscription_name: str = subscription_name - self.credential_: DefaultAzureCredential | None = None + self.credential_: AzureCliCredential | None = None self.subscription_id_: str | None = None self.tenant_id_: str | None = None @property - def credential(self) -> DefaultAzureCredential: + def credential(self) -> AzureCliCredential: if not self.credential_: - self.credential_ = DefaultAzureCredential( - exclude_interactive_browser_credential=False, - exclude_shared_token_cache_credential=True, # this requires multiple approvals per sign-in - exclude_visual_studio_code_credential=True, # this often fails + self.credential_ = AzureCliCredential( additionally_allowed_tenants=["*"], ) return self.credential_ diff --git a/data_safe_haven/infrastructure/__init__.py b/data_safe_haven/infrastructure/__init__.py index 864f0180f7..2643ac9466 100644 --- a/data_safe_haven/infrastructure/__init__.py +++ b/data_safe_haven/infrastructure/__init__.py @@ -1,6 +1,7 @@ -from .stack_manager import SHMStackManager, SREStackManager +from .stack_manager import PulumiAccount, SHMStackManager, SREStackManager __all__ = [ "SHMStackManager", "SREStackManager", + "PulumiAccount", ] diff --git a/data_safe_haven/infrastructure/stack_manager.py b/data_safe_haven/infrastructure/stack_manager.py index d86b994b4f..48d6c02d1b 100644 --- a/data_safe_haven/infrastructure/stack_manager.py +++ b/data_safe_haven/infrastructure/stack_manager.py @@ -1,24 +1,56 @@ """Deploy with Pulumi""" import logging -import os import pathlib import shutil -import subprocess import time from contextlib import suppress from importlib import metadata +from shutil import which from typing import Any from pulumi import automation from data_safe_haven.config import Config from data_safe_haven.exceptions import DataSafeHavenAzureError, DataSafeHavenPulumiError -from data_safe_haven.external import AzureApi, AzureCli +from data_safe_haven.external import AzureApi, AzureCliSingleton from data_safe_haven.functions import replace_separators from data_safe_haven.infrastructure.stacks import DeclarativeSHM, DeclarativeSRE from data_safe_haven.utility import LoggingSingleton +class PulumiAccount: + """Manage and interact with Pulumi backend account""" + + def __init__(self, config: Config): + self.cfg = config + self.env_: dict[str, Any] | None = None + path = which("pulumi") + if path is None: + msg = "Unable to find Pulumi CLI executable in your path.\nPlease ensure that Pulumi is installed" + raise DataSafeHavenPulumiError(msg) + + # Ensure Azure CLI account is correct + # This will be needed to populate env + AzureCliSingleton().confirm() + + @property + def env(self) -> dict[str, Any]: + """Get necessary Pulumi environment variables""" + if not self.env_: + azure_api = AzureApi(self.cfg.subscription_name) + backend_storage_account_keys = azure_api.get_storage_account_keys( + self.cfg.backend.resource_group_name, + self.cfg.backend.storage_account_name, + ) + self.env_ = { + "AZURE_STORAGE_ACCOUNT": self.cfg.backend.storage_account_name, + "AZURE_STORAGE_KEY": str(backend_storage_account_keys[0].value), + "AZURE_KEYVAULT_AUTH_VIA_CLI": "true", + "PULUMI_BACKEND_URL": f"azblob://{self.cfg.pulumi.storage_container_name}", + } + return self.env_ + + class StackManager: """Interact with infrastructure using Pulumi""" @@ -27,8 +59,8 @@ def __init__( config: Config, program: DeclarativeSHM | DeclarativeSRE, ) -> None: + self.account = PulumiAccount(config) self.cfg: Config = config - self.env_: dict[str, Any] | None = None self.logger = LoggingSingleton() self.stack_: automation.Stack | None = None self.stack_outputs_: automation.OutputMap | None = None @@ -38,7 +70,6 @@ def __init__( self.stack_name = self.program.stack_name self.work_dir = config.work_directory / "pulumi" / self.program.short_name self.work_dir.mkdir(parents=True, exist_ok=True) - self.login() # Log in to the Pulumi backend self.initialise_workdir() self.install_plugins() @@ -47,22 +78,6 @@ def local_stack_path(self) -> pathlib.Path: """Return the local stack path""" return self.work_dir / f"Pulumi.{self.stack_name}.yaml" - @property - def env(self) -> dict[str, Any]: - """Get necessary Pulumi environment variables""" - if not self.env_: - azure_api = AzureApi(self.cfg.subscription_name) - backend_storage_account_keys = azure_api.get_storage_account_keys( - self.cfg.backend.resource_group_name, - self.cfg.backend.storage_account_name, - ) - self.env_ = { - "AZURE_STORAGE_ACCOUNT": self.cfg.backend.storage_account_name, - "AZURE_STORAGE_KEY": str(backend_storage_account_keys[0].value), - "AZURE_KEYVAULT_AUTH_VIA_CLI": "true", - } - return self.env_ - @property def pulumi_extra_args(self) -> dict[str, Any]: extra_args: dict[str, Any] = {} @@ -85,7 +100,7 @@ def stack(self) -> automation.Stack: opts=automation.LocalWorkspaceOptions( secrets_provider=f"azurekeyvault://{self.cfg.backend.key_vault_name}.vault.azure.net/keys/{self.cfg.pulumi.encryption_key_name}/{self.cfg.pulumi.encryption_key_version}", work_dir=str(self.work_dir), - env_vars=self.env, + env_vars=self.account.env, ), ) self.logger.info(f"Loaded stack [green]{self.stack_name}[/].") @@ -260,44 +275,6 @@ def install_plugins(self) -> None: msg = f"Installing Pulumi plugins failed.\n{exc}." raise DataSafeHavenPulumiError(msg) from exc - def login(self) -> None: - """Login to Pulumi.""" - try: - # Ensure we are authenticated with the Azure CLI - # Without this, we cannot read the encryption key from the keyvault - AzureCli().login() - # Check whether we're already logged in - # Note that we cannot retrieve self.stack without being logged in - self.logger.debug("Logging into Pulumi") - with suppress(DataSafeHavenPulumiError): - result = self.stack.workspace.who_am_i() - if result.user: - self.logger.info(f"Logged into Pulumi as [green]{result.user}[/]") - return - # Otherwise log in to Pulumi - try: - cmd_env = {**os.environ, **self.env} - self.logger.debug(f"Running command using environment {cmd_env}") - process = subprocess.run( - [ - "pulumi", - "login", - f"azblob://{self.cfg.pulumi.storage_container_name}", - ], - capture_output=True, - check=True, - cwd=self.work_dir, - encoding="UTF-8", - env=cmd_env, - ) - self.logger.info(process.stdout) - except (subprocess.CalledProcessError, FileNotFoundError) as exc: - msg = f"Logging into Pulumi failed.\n{exc}." - raise DataSafeHavenPulumiError(msg) from exc - except Exception as exc: - msg = f"Logging into Pulumi failed.\n{exc}." - raise DataSafeHavenPulumiError(msg) from exc - def output(self, name: str) -> Any: """Get a named output value from a stack""" if not self.stack_outputs_: diff --git a/data_safe_haven/utility/__init__.py b/data_safe_haven/utility/__init__.py index 64a9e412ab..73ab4620ec 100644 --- a/data_safe_haven/utility/__init__.py +++ b/data_safe_haven/utility/__init__.py @@ -1,6 +1,7 @@ from .enums import DatabaseSystem, SoftwarePackageCategory from .file_reader import FileReader from .logger import LoggingSingleton, NonLoggingSingleton +from .singleton import Singleton from .types import PathType __all__ = [ @@ -9,5 +10,6 @@ "LoggingSingleton", "NonLoggingSingleton", "PathType", + "Singleton", "SoftwarePackageCategory", ] diff --git a/data_safe_haven/utility/logger.py b/data_safe_haven/utility/logger.py index c42420de41..7eb8967a87 100644 --- a/data_safe_haven/utility/logger.py +++ b/data_safe_haven/utility/logger.py @@ -1,7 +1,7 @@ """Standalone logging class implemented as a singleton""" import io import logging -from typing import Any, ClassVar, Optional +from typing import Any, ClassVar from rich.console import Console from rich.highlighter import RegexHighlighter @@ -98,8 +98,6 @@ class LoggingSingleton(logging.Logger, metaclass=Singleton): date_fmt = r"%Y-%m-%d %H:%M:%S" rich_format = r"[log.time]%(asctime)s[/] [%(levelname)8s] %(message)s" - # Due to https://bugs.python.org/issue45857 we must use `Optional` - _instance: Optional["LoggingSingleton"] = None def __init__(self) -> None: super().__init__(name="data_safe_haven", level=logging.INFO) diff --git a/pyproject.toml b/pyproject.toml index f8538fb66d..95afe80918 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,8 +120,7 @@ select = [ ignore = [ "E501", # ignore line length "S106", # ignore check for possible passwords - "S603", # allow subprocess without shell=True - "S607", # allow subprocess without absolute path + "S603", # allow subprocess with shell=False, this is lower severity than those with shell=True "C901", # ignore complex-structure "PLR0912", # ignore too-many-branches "PLR0913", # ignore too-many-arguments diff --git a/typings/typer/__init__.pyi b/typings/typer/__init__.pyi index 2bb704a21a..16e679947b 100644 --- a/typings/typer/__init__.pyi +++ b/typings/typer/__init__.pyi @@ -1,4 +1,5 @@ from click.exceptions import BadParameter, Exit +from click.termui import confirm from .main import Typer from .params import Argument @@ -10,4 +11,5 @@ __all__ = [ "Exit", "Option", "Typer", + "confirm", ]