Skip to content

Commit

Permalink
Move Pulumi storage account and key to context
Browse files Browse the repository at this point in the history
  • Loading branch information
JimMadge committed Apr 17, 2024
1 parent eb1a73b commit d78d1d8
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 58 deletions.
23 changes: 1 addition & 22 deletions data_safe_haven/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, ClassVar
from typing import Any

import yaml
from azure.keyvault.keys import KeyVaultKey
from pydantic import (
BaseModel,
Field,
Expand Down Expand Up @@ -59,8 +58,6 @@ def __init__(self, context: Context, **kwargs: dict[Any, Any]):


class ConfigSectionPulumi(BaseModel, validate_assignment=True):
storage_container_name: ClassVar[str] = "pulumi"
encryption_key_name: ClassVar[str] = "pulumi-encryption-key"
stacks: dict[str, str] = Field(..., default_factory=dict[str, str])


Expand Down Expand Up @@ -239,8 +236,6 @@ class Config(BaseModel, validate_assignment=True):
)
tags: ConfigSectionTags = Field(..., exclude=True)

_pulumi_encryption_key = None

def __init__(self, context: Context, **kwargs: dict[Any, Any]):
tags = ConfigSectionTags(context)
super().__init__(context=context, tags=tags, **kwargs)
Expand All @@ -258,22 +253,6 @@ def all_sre_indices_must_be_unique(
def work_directory(self) -> Path:
return self.context.work_directory

@property
def pulumi_encryption_key(self) -> KeyVaultKey:
if not self._pulumi_encryption_key:
azure_api = AzureApi(subscription_name=self.context.subscription_name)
self._pulumi_encryption_key = azure_api.get_keyvault_key(
key_name=self.pulumi.encryption_key_name,
key_vault_name=self.context.key_vault_name,
)
return self._pulumi_encryption_key

@property
def pulumi_encryption_key_version(self) -> str:
"""ID for the Pulumi encryption key"""
key_id: str = self.pulumi_encryption_key.id
return key_id.split("/")[-1]

@property
def sre_names(self) -> list[str]:
"""Names of all SREs"""
Expand Down
30 changes: 30 additions & 0 deletions data_safe_haven/config/context_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
from typing import ClassVar

import yaml
from azure.keyvault.keys import KeyVaultKey
from pydantic import BaseModel, Field, ValidationError, model_validator
from yaml import YAMLError

from data_safe_haven.exceptions import (
DataSafeHavenConfigError,
DataSafeHavenParameterError,
)
from data_safe_haven.external import AzureApi
from data_safe_haven.functions import alphanumeric
from data_safe_haven.utility import LoggingSingleton, config_dir
from data_safe_haven.utility.annotated_types import (
Expand All @@ -35,6 +37,10 @@ class Context(BaseModel, validate_assignment=True):
name: str
subscription_name: AzureLongName
storage_container_name: ClassVar[str] = "config"
pulumi_storage_container_name: ClassVar[str] = "pulumi"
pulumi_encryption_key_name: ClassVar[str] = "pulumi-encryption-key"

_pulumi_encryption_key = None

@property
def shm_name(self) -> str:
Expand Down Expand Up @@ -68,6 +74,30 @@ def managed_identity_name(self) -> str:
def to_yaml(self) -> str:
return yaml.dump(self.model_dump(), indent=2)

@property
def pulumi_backend_url(self) -> str:
return f"azblob://{self.pulumi_storage_container_name}"

@property
def pulumi_encryption_key(self) -> KeyVaultKey:
if not self._pulumi_encryption_key:
azure_api = AzureApi(subscription_name=self.subscription_name)
self._pulumi_encryption_key = azure_api.get_keyvault_key(
key_name=self.pulumi_encryption_key_name,
key_vault_name=self.key_vault_name,
)
return self._pulumi_encryption_key

@property
def pulumi_encryption_key_version(self) -> str:
"""ID for the Pulumi encryption key"""
key_id: str = self.pulumi_encryption_key.id
return key_id.split("/")[-1]

@property
def pulumi_secrets_provider_url(self) -> str:
return f"azurekeyvault://{self.key_vault_name}.vault.azure.net/keys/{self.pulumi_encryption_key_name}/{self.pulumi_encryption_key_version}"


class ContextSettings(BaseModel, validate_assignment=True):
"""Load global and local settings from dotfiles with structure like the following
Expand Down
6 changes: 3 additions & 3 deletions data_safe_haven/context/context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from data_safe_haven.config import ConfigSectionPulumi, ConfigSectionTags, Context
from data_safe_haven.config import ConfigSectionTags, Context
from data_safe_haven.exceptions import DataSafeHavenAzureError
from data_safe_haven.external import AzureApi

Expand Down Expand Up @@ -59,7 +59,7 @@ def create(self) -> None:
storage_account_name=storage_account.name,
)
_ = self.azure_api.ensure_storage_blob_container(
container_name=ConfigSectionPulumi.storage_container_name,
container_name=self.context.pulumi_storage_container_name,
resource_group_name=resource_group.name,
storage_account_name=storage_account.name,
)
Expand All @@ -75,7 +75,7 @@ def create(self) -> None:
msg = f"Keyvault '{self.context.key_vault_name}' was not created."
raise DataSafeHavenAzureError(msg)
self.azure_api.ensure_keyvault_key(
key_name=ConfigSectionPulumi.encryption_key_name,
key_name=self.context.pulumi_encryption_key_name,
key_vault_name=keyvault.name,
)
except Exception as exc:
Expand Down
6 changes: 3 additions & 3 deletions data_safe_haven/infrastructure/stack_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def env(self) -> dict[str, Any]:
"AZURE_STORAGE_ACCOUNT": self.cfg.context.storage_account_name,
"AZURE_STORAGE_KEY": str(backend_storage_account_keys[0].value),
"AZURE_KEYVAULT_AUTH_VIA_CLI": "true",
"PULUMI_BACKEND_URL": f"azblob://{self.cfg.pulumi.storage_container_name}",
"PULUMI_BACKEND_URL": self.cfg.context.pulumi_backend_url,
}
return self.env_

Expand Down Expand Up @@ -101,7 +101,7 @@ def stack(self) -> automation.Stack:
stack_name=self.stack_name,
program=self.program.run,
opts=automation.LocalWorkspaceOptions(
secrets_provider=f"azurekeyvault://{self.cfg.context.key_vault_name}.vault.azure.net/keys/{self.cfg.pulumi.encryption_key_name}/{self.cfg.pulumi_encryption_key_version}",
secrets_provider=self.cfg.context.pulumi_secrets_provider_url,
work_dir=str(self.work_dir),
env_vars=self.account.env,
),
Expand Down Expand Up @@ -216,7 +216,7 @@ def destroy(self) -> None:
blob_name=f".pulumi/stacks/{self.project_name}/{stack_backup_name}",
resource_group_name=self.cfg.context.resource_group_name,
storage_account_name=self.cfg.context.storage_account_name,
storage_container_name=self.cfg.pulumi.storage_container_name,
storage_container_name=self.cfg.context.pulumi_storage_container_name,
)
except DataSafeHavenAzureError as exc:
if "blob does not exist" in str(exc):
Expand Down
30 changes: 0 additions & 30 deletions tests/config/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
DataSafeHavenConfigError,
DataSafeHavenParameterError,
)
from data_safe_haven.external import AzureApi
from data_safe_haven.utility.enums import DatabaseSystem, SoftwarePackageCategory
from data_safe_haven.version import __version__

Expand Down Expand Up @@ -47,9 +46,7 @@ def pulumi_config():
class TestConfigSectionPulumi:
def test_constructor_defaults(self):
pulumi_config = ConfigSectionPulumi()
assert pulumi_config.encryption_key_name == "pulumi-encryption-key"
assert pulumi_config.stacks == {}
assert pulumi_config.storage_container_name == "pulumi"


@fixture
Expand Down Expand Up @@ -212,20 +209,6 @@ def config_sres(context, azure_config, pulumi_config, shm_config):
)


@fixture
def mock_key_vault_key(monkeypatch):
class MockKeyVaultKey:
def __init__(self, key_name, key_vault_name):
self.key_name = key_name
self.key_vault_name = key_vault_name
self.id = "mock_key/version"

def mock_get_keyvault_key(self, key_name, key_vault_name): # noqa: ARG001
return MockKeyVaultKey(key_name, key_vault_name)

monkeypatch.setattr(AzureApi, "get_keyvault_key", mock_get_keyvault_key)


class TestConfig:
def test_constructor(self, context, azure_config, pulumi_config, shm_config):
config = Config(
Expand Down Expand Up @@ -257,19 +240,6 @@ def test_work_directory(self, config_sres):
config = config_sres
assert config.work_directory == config.context.work_directory

def test_pulumi_encryption_key(
self, config_sres, mock_key_vault_key # noqa: ARG002
):
key = config_sres.pulumi_encryption_key
assert key.key_name == config_sres.pulumi.encryption_key_name
assert key.key_vault_name == config_sres.context.key_vault_name

def test_pulumi_encryption_key_version(
self, config_sres, mock_key_vault_key # noqa: ARG002
):
version = config_sres.pulumi_encryption_key_version
assert version == "version"

@pytest.mark.parametrize("require_sres,expected", [(False, True), (True, False)])
def test_is_complete_no_sres(self, config_no_sres, require_sres, expected):
assert config_no_sres.is_complete(require_sres=require_sres) is expected
Expand Down
28 changes: 28 additions & 0 deletions tests/config/test_context_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,21 @@
DataSafeHavenConfigError,
DataSafeHavenParameterError,
)
from data_safe_haven.external import AzureApi


@fixture
def mock_key_vault_key(monkeypatch):
class MockKeyVaultKey:
def __init__(self, key_name, key_vault_name):
self.key_name = key_name
self.key_vault_name = key_vault_name
self.id = "mock_key/version"

def mock_get_keyvault_key(self, key_name, key_vault_name): # noqa: ARG001
return MockKeyVaultKey(key_name, key_vault_name)

monkeypatch.setattr(AzureApi, "get_keyvault_key", mock_get_keyvault_key)


class TestContext:
Expand Down Expand Up @@ -61,6 +76,19 @@ def test_long_storage_account_name(self, context_dict):
context = Context(**context_dict)
assert context.storage_account_name == "shmveryveryveryvecontext"

def test_pulumi_encryption_key(
self, context, mock_key_vault_key # noqa: ARG002
):
key = context.pulumi_encryption_key
assert key.key_name == context.pulumi_encryption_key_name
assert key.key_vault_name == context.key_vault_name

def test_pulumi_encryption_key_version(
self, context, mock_key_vault_key # noqa: ARG002
):
version = context.pulumi_encryption_key_version
assert version == "version"


@fixture
def context_yaml():
Expand Down

0 comments on commit d78d1d8

Please sign in to comment.