Skip to content

Commit

Permalink
Merge pull request #1617 from alan-turing-institute/login_flow
Browse files Browse the repository at this point in the history
Pulumi: Improve login flow
  • Loading branch information
JimMadge authored Oct 26, 2023
2 parents 63c7910 + 1c87532 commit 653e86d
Show file tree
Hide file tree
Showing 9 changed files with 119 additions and 108 deletions.
4 changes: 2 additions & 2 deletions data_safe_haven/external/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .api.azure_api import AzureApi
from .api.azure_cli import AzureCli
from .api.azure_cli import AzureCliSingleton
from .api.graph_api import GraphApi
from .interface.azure_container_instance import AzureContainerInstance
from .interface.azure_fileshare import AzureFileShare
Expand All @@ -8,7 +8,7 @@

__all__ = [
"AzureApi",
"AzureCli",
"AzureCliSingleton",
"AzureContainerInstance",
"AzureFileShare",
"AzureIPv4Range",
Expand Down
101 changes: 68 additions & 33 deletions data_safe_haven/external/api/azure_cli.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,79 @@
"""Interface to the Azure CLI"""
import json
import subprocess
from typing import Any
from dataclasses import dataclass
from shutil import which

import typer

from data_safe_haven.exceptions import DataSafeHavenAzureError
from data_safe_haven.utility import LoggingSingleton
from data_safe_haven.utility import LoggingSingleton, Singleton


@dataclass
class AzureCliAccount:
"""Dataclass for Azure CLI Account details"""

name: str
id_: str
tenant_id: str


class AzureCli:
class AzureCliSingleton(metaclass=Singleton):
"""Interface to the Azure CLI"""

def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
def __init__(self) -> None:
self.logger = LoggingSingleton()

def login(self) -> None:
"""Force log in via the Azure CLI"""
try:
self.logger.debug("Attempting to login using Azure CLI.")
# We do not use `check` in subprocess as this raises a CalledProcessError
# which would break the loop. Instead we check the return code of
# `az account show` which will be 0 on success.
while True:
# Check whether we are already logged in
process = subprocess.run(
["az", "account", "show"], capture_output=True, check=False
)
if process.returncode == 0:
break
# Note that subprocess.run will block until the process terminates so
# we need to print the guidance first.
self.logger.info(
"Please login in your web browser at [bold]https://login.microsoftonline.com/organizations/oauth2/v2.0/authorize[/]."
)
self.logger.info(
"If no web browser is available, please run [bold]az login --use-device-code[/] in a command line window."
)
# Attempt to log in at the command line
process = subprocess.run(
["az", "login"], capture_output=True, check=False
path = which("az")
if path is None:
msg = "Unable to find Azure CLI executable in your path.\nPlease ensure that Azure CLI is installed"
raise DataSafeHavenAzureError(msg)
self.path = path

self._account: AzureCliAccount | None = None
self._confirmed = False

@property
def account(self) -> AzureCliAccount:
if not self._account:
try:
result = subprocess.check_output(
[self.path, "account", "show", "--output", "json"],
stderr=subprocess.PIPE,
encoding="utf8",
)
except FileNotFoundError as exc:
msg = f"Please ensure that the Azure CLI is installed.\n{exc}"
raise DataSafeHavenAzureError(msg) from exc
except subprocess.CalledProcessError as exc:
msg = f"Error getting account information from Azure CLI.\n{exc.stderr}"
raise DataSafeHavenAzureError(msg) from exc

try:
result_dict = json.loads(result)
except json.JSONDecodeError as exc:
msg = f"Unable to parse Azure CLI output as JSON.\n{result}"
raise DataSafeHavenAzureError(msg) from exc

self._account = AzureCliAccount(
name=result_dict.get("user").get("name"),
id_=result_dict.get("id"),
tenant_id=result_dict.get("tenantId"),
)

return self._account

def confirm(self) -> None:
"""Prompt user to confirm the Azure CLI account is correct"""
if self._confirmed:
return None

account = self.account
self.logger.info(
f"name: {account.name} (id: {account.id_}\ntenant: {account.tenant_id})"
)
if not typer.confirm("Is this the Azure account you expect?\n"):
self.logger.error(
"Please use `az login` to connect to the correct Azure CLI account"
)
raise typer.Exit(1)

self._confirmed = True
11 changes: 4 additions & 7 deletions data_safe_haven/external/interface/azure_authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import cast

from azure.core.exceptions import ClientAuthenticationError
from azure.identity import DefaultAzureCredential
from azure.identity import AzureCliCredential
from azure.mgmt.resource.subscriptions import SubscriptionClient
from azure.mgmt.resource.subscriptions.models import Subscription

Expand All @@ -17,17 +17,14 @@ class AzureAuthenticator:

def __init__(self, subscription_name: str) -> None:
self.subscription_name: str = subscription_name
self.credential_: DefaultAzureCredential | None = None
self.credential_: AzureCliCredential | None = None
self.subscription_id_: str | None = None
self.tenant_id_: str | None = None

@property
def credential(self) -> DefaultAzureCredential:
def credential(self) -> AzureCliCredential:
if not self.credential_:
self.credential_ = DefaultAzureCredential(
exclude_interactive_browser_credential=False,
exclude_shared_token_cache_credential=True, # this requires multiple approvals per sign-in
exclude_visual_studio_code_credential=True, # this often fails
self.credential_ = AzureCliCredential(
additionally_allowed_tenants=["*"],
)
return self.credential_
Expand Down
3 changes: 2 additions & 1 deletion data_safe_haven/infrastructure/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .stack_manager import SHMStackManager, SREStackManager
from .stack_manager import PulumiAccount, SHMStackManager, SREStackManager

__all__ = [
"SHMStackManager",
"SREStackManager",
"PulumiAccount",
]
97 changes: 37 additions & 60 deletions data_safe_haven/infrastructure/stack_manager.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,56 @@
"""Deploy with Pulumi"""
import logging
import os
import pathlib
import shutil
import subprocess
import time
from contextlib import suppress
from importlib import metadata
from shutil import which
from typing import Any

from pulumi import automation

from data_safe_haven.config import Config
from data_safe_haven.exceptions import DataSafeHavenAzureError, DataSafeHavenPulumiError
from data_safe_haven.external import AzureApi, AzureCli
from data_safe_haven.external import AzureApi, AzureCliSingleton
from data_safe_haven.functions import replace_separators
from data_safe_haven.infrastructure.stacks import DeclarativeSHM, DeclarativeSRE
from data_safe_haven.utility import LoggingSingleton


class PulumiAccount:
"""Manage and interact with Pulumi backend account"""

def __init__(self, config: Config):
self.cfg = config
self.env_: dict[str, Any] | None = None
path = which("pulumi")
if path is None:
msg = "Unable to find Pulumi CLI executable in your path.\nPlease ensure that Pulumi is installed"
raise DataSafeHavenPulumiError(msg)

# Ensure Azure CLI account is correct
# This will be needed to populate env
AzureCliSingleton().confirm()

@property
def env(self) -> dict[str, Any]:
"""Get necessary Pulumi environment variables"""
if not self.env_:
azure_api = AzureApi(self.cfg.subscription_name)
backend_storage_account_keys = azure_api.get_storage_account_keys(
self.cfg.backend.resource_group_name,
self.cfg.backend.storage_account_name,
)
self.env_ = {
"AZURE_STORAGE_ACCOUNT": self.cfg.backend.storage_account_name,
"AZURE_STORAGE_KEY": str(backend_storage_account_keys[0].value),
"AZURE_KEYVAULT_AUTH_VIA_CLI": "true",
"PULUMI_BACKEND_URL": f"azblob://{self.cfg.pulumi.storage_container_name}",
}
return self.env_


class StackManager:
"""Interact with infrastructure using Pulumi"""

Expand All @@ -27,8 +59,8 @@ def __init__(
config: Config,
program: DeclarativeSHM | DeclarativeSRE,
) -> None:
self.account = PulumiAccount(config)
self.cfg: Config = config
self.env_: dict[str, Any] | None = None
self.logger = LoggingSingleton()
self.stack_: automation.Stack | None = None
self.stack_outputs_: automation.OutputMap | None = None
Expand All @@ -38,7 +70,6 @@ def __init__(
self.stack_name = self.program.stack_name
self.work_dir = config.work_directory / "pulumi" / self.program.short_name
self.work_dir.mkdir(parents=True, exist_ok=True)
self.login() # Log in to the Pulumi backend
self.initialise_workdir()
self.install_plugins()

Expand All @@ -47,22 +78,6 @@ def local_stack_path(self) -> pathlib.Path:
"""Return the local stack path"""
return self.work_dir / f"Pulumi.{self.stack_name}.yaml"

@property
def env(self) -> dict[str, Any]:
"""Get necessary Pulumi environment variables"""
if not self.env_:
azure_api = AzureApi(self.cfg.subscription_name)
backend_storage_account_keys = azure_api.get_storage_account_keys(
self.cfg.backend.resource_group_name,
self.cfg.backend.storage_account_name,
)
self.env_ = {
"AZURE_STORAGE_ACCOUNT": self.cfg.backend.storage_account_name,
"AZURE_STORAGE_KEY": str(backend_storage_account_keys[0].value),
"AZURE_KEYVAULT_AUTH_VIA_CLI": "true",
}
return self.env_

@property
def pulumi_extra_args(self) -> dict[str, Any]:
extra_args: dict[str, Any] = {}
Expand All @@ -85,7 +100,7 @@ def stack(self) -> automation.Stack:
opts=automation.LocalWorkspaceOptions(
secrets_provider=f"azurekeyvault://{self.cfg.backend.key_vault_name}.vault.azure.net/keys/{self.cfg.pulumi.encryption_key_name}/{self.cfg.pulumi.encryption_key_version}",
work_dir=str(self.work_dir),
env_vars=self.env,
env_vars=self.account.env,
),
)
self.logger.info(f"Loaded stack [green]{self.stack_name}[/].")
Expand Down Expand Up @@ -260,44 +275,6 @@ def install_plugins(self) -> None:
msg = f"Installing Pulumi plugins failed.\n{exc}."
raise DataSafeHavenPulumiError(msg) from exc

def login(self) -> None:
"""Login to Pulumi."""
try:
# Ensure we are authenticated with the Azure CLI
# Without this, we cannot read the encryption key from the keyvault
AzureCli().login()
# Check whether we're already logged in
# Note that we cannot retrieve self.stack without being logged in
self.logger.debug("Logging into Pulumi")
with suppress(DataSafeHavenPulumiError):
result = self.stack.workspace.who_am_i()
if result.user:
self.logger.info(f"Logged into Pulumi as [green]{result.user}[/]")
return
# Otherwise log in to Pulumi
try:
cmd_env = {**os.environ, **self.env}
self.logger.debug(f"Running command using environment {cmd_env}")
process = subprocess.run(
[
"pulumi",
"login",
f"azblob://{self.cfg.pulumi.storage_container_name}",
],
capture_output=True,
check=True,
cwd=self.work_dir,
encoding="UTF-8",
env=cmd_env,
)
self.logger.info(process.stdout)
except (subprocess.CalledProcessError, FileNotFoundError) as exc:
msg = f"Logging into Pulumi failed.\n{exc}."
raise DataSafeHavenPulumiError(msg) from exc
except Exception as exc:
msg = f"Logging into Pulumi failed.\n{exc}."
raise DataSafeHavenPulumiError(msg) from exc

def output(self, name: str) -> Any:
"""Get a named output value from a stack"""
if not self.stack_outputs_:
Expand Down
2 changes: 2 additions & 0 deletions data_safe_haven/utility/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .enums import DatabaseSystem, SoftwarePackageCategory
from .file_reader import FileReader
from .logger import LoggingSingleton, NonLoggingSingleton
from .singleton import Singleton
from .types import PathType

__all__ = [
Expand All @@ -9,5 +10,6 @@
"LoggingSingleton",
"NonLoggingSingleton",
"PathType",
"Singleton",
"SoftwarePackageCategory",
]
4 changes: 1 addition & 3 deletions data_safe_haven/utility/logger.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Standalone logging class implemented as a singleton"""
import io
import logging
from typing import Any, ClassVar, Optional
from typing import Any, ClassVar

from rich.console import Console
from rich.highlighter import RegexHighlighter
Expand Down Expand Up @@ -98,8 +98,6 @@ class LoggingSingleton(logging.Logger, metaclass=Singleton):

date_fmt = r"%Y-%m-%d %H:%M:%S"
rich_format = r"[log.time]%(asctime)s[/] [%(levelname)8s] %(message)s"
# Due to https://bugs.python.org/issue45857 we must use `Optional`
_instance: Optional["LoggingSingleton"] = None

def __init__(self) -> None:
super().__init__(name="data_safe_haven", level=logging.INFO)
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,7 @@ select = [
ignore = [
"E501", # ignore line length
"S106", # ignore check for possible passwords
"S603", # allow subprocess without shell=True
"S607", # allow subprocess without absolute path
"S603", # allow subprocess with shell=False, this is lower severity than those with shell=True
"C901", # ignore complex-structure
"PLR0912", # ignore too-many-branches
"PLR0913", # ignore too-many-arguments
Expand Down
2 changes: 2 additions & 0 deletions typings/typer/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from click.exceptions import BadParameter, Exit
from click.termui import confirm

from .main import Typer
from .params import Argument
Expand All @@ -10,4 +11,5 @@ __all__ = [
"Exit",
"Option",
"Typer",
"confirm",
]

0 comments on commit 653e86d

Please sign in to comment.