diff --git a/data_safe_haven/README.md b/data_safe_haven/README.md index 697d3a823f..e9e69a8761 100644 --- a/data_safe_haven/README.md +++ b/data_safe_haven/README.md @@ -47,6 +47,12 @@ Run `dsh deploy shm -h` to see the necessary command line flags and provide them > dsh config upload config.yaml ``` +- As private endpoints for flexible PostgreSQL are still in preview, the following command is currently needed: + +```console +> az feature register --name "enablePrivateEndpoint" --namespace "Microsoft.DBforPostgreSQL" +``` + - Next deploy the infrastructure [approx 30 minutes]: ```console diff --git a/data_safe_haven/administration/users/guacamole_users.py b/data_safe_haven/administration/users/guacamole_users.py index 3df0f0a89f..035420334b 100644 --- a/data_safe_haven/administration/users/guacamole_users.py +++ b/data_safe_haven/administration/users/guacamole_users.py @@ -3,7 +3,7 @@ from typing import Any from data_safe_haven.config import Config -from data_safe_haven.external import AzurePostgreSQLDatabase +from data_safe_haven.external import AzureApi, AzurePostgreSQLDatabase from data_safe_haven.infrastructure import SREStackManager from .research_user import ResearchUser @@ -13,9 +13,15 @@ class GuacamoleUsers: def __init__(self, config: Config, sre_name: str, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) sre_stack = SREStackManager(config, sre_name) + # Read the SRE database secret from key vault + azure_api = AzureApi(config.context.subscription_name) + connection_db_server_password = azure_api.get_keyvault_secret( + sre_stack.output("data")["key_vault_name"], + sre_stack.output("data")["password_user_database_admin_secret"], + ) self.postgres_provisioner = AzurePostgreSQLDatabase( sre_stack.output("remote_desktop")["connection_db_name"], - sre_stack.secret("password-user-database-admin"), + connection_db_server_password, sre_stack.output("remote_desktop")["connection_db_server_name"], sre_stack.output("remote_desktop")["resource_group_name"], config.context.subscription_name, diff --git a/data_safe_haven/external/interface/azure_postgresql_database.py b/data_safe_haven/external/interface/azure_postgresql_database.py index 3ffe5f9871..13742cf0ed 100644 --- a/data_safe_haven/external/interface/azure_postgresql_database.py +++ b/data_safe_haven/external/interface/azure_postgresql_database.py @@ -2,16 +2,15 @@ import pathlib import time from collections.abc import Sequence -from typing import Any +from typing import Any, cast import psycopg import requests from azure.core.polling import LROPoller -from azure.mgmt.rdbms.postgresql import PostgreSQLManagementClient -from azure.mgmt.rdbms.postgresql.models import ( +from azure.mgmt.rdbms.postgresql_flexibleservers import PostgreSQLManagementClient +from azure.mgmt.rdbms.postgresql_flexibleservers.models import ( FirewallRule, Server, - ServerUpdateParameters, ) from data_safe_haven.exceptions import ( @@ -72,7 +71,7 @@ def connection_string(self) -> str: f"host={self.db_server.fully_qualified_domain_name}", f"password={self.db_server_admin_password}", f"port={self.port}", - f"user={self.db_server.administrator_login}@{self.server_name}", + f"user={self.db_server.administrator_login}", "sslmode=require", ] ) @@ -100,15 +99,17 @@ def db_connection(self, n_retries: int = 0) -> psycopg.Connection: """Get the database connection.""" while True: try: - connection = psycopg.connect(self.connection_string) - break - except psycopg.OperationalError as exc: - if n_retries > 0: + try: + connection = psycopg.connect(self.connection_string) + break + except psycopg.OperationalError as exc: + if n_retries <= 0: + raise exc n_retries -= 1 time.sleep(10) - else: - msg = f"Could not connect to database.\n{exc}" - raise DataSafeHavenAzureError(msg) from exc + except Exception as exc: + msg = f"Could not connect to database.\n{exc}" + raise DataSafeHavenAzureError(msg) from exc return connection def load_sql( @@ -169,24 +170,17 @@ def execute_scripts( def set_database_access(self, action: str) -> None: """Enable/disable database access to the PostgreSQL server.""" - rule_name = f"AllowConfigurationUpdate-{self.rule_suffix}" - if action == "enabled": self.logger.debug( f"Adding temporary firewall rule for [green]{self.current_ip}[/]...", ) - self.wait( - self.db_client.servers.begin_update( - self.resource_group_name, - self.server_name, - ServerUpdateParameters(public_network_access="Enabled"), - ) - ) + # NB. We would like to enable public_network_access at this point but this + # is not currently supported by the flexibleServer API self.wait( self.db_client.firewall_rules.begin_create_or_update( self.resource_group_name, self.server_name, - rule_name, + f"AllowConfigurationUpdate-{self.rule_suffix}", FirewallRule( start_ip_address=self.current_ip, end_ip_address=self.current_ip ), @@ -198,28 +192,44 @@ def set_database_access(self, action: str) -> None: ) elif action == "disabled": self.logger.debug( - f"Removing temporary firewall rule for [green]{self.current_ip}[/]...", + f"Removing all firewall rule(s) from [green]{self.server_name}[/]...", ) - self.wait( - self.db_client.firewall_rules.begin_delete( - self.resource_group_name, self.server_name, rule_name - ) + rules = cast( + list[FirewallRule], + self.db_client.firewall_rules.list_by_server( + self.resource_group_name, self.server_name + ), ) - self.wait( - self.db_client.servers.begin_update( - self.resource_group_name, - self.server_name, - ServerUpdateParameters(public_network_access="Disabled"), + + # Delete all named firewall rules + rule_names = [str(rule.name) for rule in rules if rule.name] + for rule_name in rule_names: + self.wait( + self.db_client.firewall_rules.begin_delete( + self.resource_group_name, self.server_name, rule_name + ) + ) + + # NB. We would like to disable public_network_access at this point but this + # is not currently supported by the flexibleServer API + if len(rule_names) == len(rules): + self.logger.info( + f"Removed all firewall rule(s) from [green]{self.server_name}[/].", + ) + else: + self.logger.warning( + f"Unable to remove all firewall rule(s) from [green]{self.server_name}[/].", ) - ) - self.logger.info( - f"Removed temporary firewall rule for [green]{self.current_ip}[/].", - ) else: msg = f"Database access action {action} was not recognised." raise DataSafeHavenInputError(msg) self.db_server_ = None # Force refresh of self.db_server - self.logger.info( + public_network_access = ( + self.db_server.network.public_network_access + if self.db_server.network + else "UNKNOWN" + ) + self.logger.debug( f"Public network access to [green]{self.server_name}[/]" - f" is [green]{self.db_server.public_network_access}[/]." + f" is [green]{public_network_access}[/]." ) diff --git a/data_safe_haven/infrastructure/components/composite/local_dns_record.py b/data_safe_haven/infrastructure/components/composite/local_dns_record.py index 3777fdc843..03c7f3f712 100644 --- a/data_safe_haven/infrastructure/components/composite/local_dns_record.py +++ b/data_safe_haven/infrastructure/components/composite/local_dns_record.py @@ -21,7 +21,7 @@ def __init__( class LocalDnsRecordComponent(ComponentResource): - """Deploy Gitea server with Pulumi""" + """Deploy public and private DNS records with Pulumi""" def __init__( self, diff --git a/data_safe_haven/infrastructure/components/composite/postgresql_database.py b/data_safe_haven/infrastructure/components/composite/postgresql_database.py index 5dab82e4d7..8723c5ac17 100644 --- a/data_safe_haven/infrastructure/components/composite/postgresql_database.py +++ b/data_safe_haven/infrastructure/components/composite/postgresql_database.py @@ -20,7 +20,7 @@ def __init__( location: Input[str], ) -> None: self.database_names = Output.from_input(database_names) - self.database_password = database_password + self.database_password = Output.secret(database_password) self.database_resource_group_name = database_resource_group_name self.database_server_name = database_server_name self.database_subnet_id = database_subnet_id @@ -45,31 +45,34 @@ def __init__( # Define a PostgreSQL server db_server = dbforpostgresql.Server( f"{self._name}_server", - location=props.location, - properties=dbforpostgresql.ServerPropertiesForDefaultCreateArgs( - administrator_login=props.database_username, - administrator_login_password=props.database_password, - create_mode="Default", - infrastructure_encryption=dbforpostgresql.InfrastructureEncryption.DISABLED, - minimal_tls_version=dbforpostgresql.MinimalTlsVersionEnum.TLS_ENFORCEMENT_DISABLED, - public_network_access=dbforpostgresql.PublicNetworkAccessEnum.DISABLED, - ssl_enforcement=dbforpostgresql.SslEnforcementEnum.ENABLED, - storage_profile=dbforpostgresql.StorageProfileArgs( - backup_retention_days=7, - geo_redundant_backup=dbforpostgresql.GeoRedundantBackup.DISABLED, - storage_autogrow=dbforpostgresql.StorageAutogrow.ENABLED, - storage_mb=5120, - ), - version=dbforpostgresql.ServerVersion.SERVER_VERSION_11, + administrator_login=props.database_username, + administrator_login_password=props.database_password, + auth_config=dbforpostgresql.AuthConfigArgs( + active_directory_auth=dbforpostgresql.ActiveDirectoryAuthEnum.DISABLED, + password_auth=dbforpostgresql.PasswordAuthEnum.ENABLED, + ), + backup=dbforpostgresql.BackupArgs( + backup_retention_days=7, + geo_redundant_backup=dbforpostgresql.GeoRedundantBackupEnum.DISABLED, + ), + create_mode=dbforpostgresql.CreateMode.DEFAULT, + data_encryption=dbforpostgresql.DataEncryptionArgs( + type=dbforpostgresql.ArmServerKeyType.SYSTEM_MANAGED, ), + high_availability=dbforpostgresql.HighAvailabilityArgs( + mode=dbforpostgresql.HighAvailabilityMode.DISABLED, + ), + location=props.location, resource_group_name=props.database_resource_group_name, server_name=props.database_server_name, sku=dbforpostgresql.SkuArgs( - capacity=2, - family="Gen5", - name="GP_Gen5_2", - tier=dbforpostgresql.SkuTier.GENERAL_PURPOSE, # required to use private link + name="Standard_B2s", + tier=dbforpostgresql.SkuTier.BURSTABLE, + ), + storage=dbforpostgresql.StorageArgs( + storage_size_gb=32, ), + version=dbforpostgresql.ServerVersion.SERVER_VERSION_14, opts=child_opts, tags=child_tags, ) diff --git a/data_safe_haven/infrastructure/components/dynamic/ssl_certificate.py b/data_safe_haven/infrastructure/components/dynamic/ssl_certificate.py index 7bbf6a1094..3ce4ad217c 100644 --- a/data_safe_haven/infrastructure/components/dynamic/ssl_certificate.py +++ b/data_safe_haven/infrastructure/components/dynamic/ssl_certificate.py @@ -98,10 +98,11 @@ def create(self, props: dict[str, Any]) -> CreateResult: try: certificate_bytes = client.request_certificate() except ValidationError as exc: - raise DataSafeHavenSSLError( - "ACME validation error:\n" - + "\n".join([str(auth_error) for auth_error in exc.failed_authzrs]) - ) from exc + msg = "\n".join( + ["ACME validation error:"] + + [str(auth_error) for auth_error in exc.failed_authzrs] + ) + raise DataSafeHavenSSLError(msg) from exc # Although KeyVault will accept a PEM certificate (where we simply prepend # the private key) we need a PFX certificate for compatibility with # ApplicationGateway diff --git a/data_safe_haven/infrastructure/stacks/declarative_sre.py b/data_safe_haven/infrastructure/stacks/declarative_sre.py index 7e3141616e..52df097250 100644 --- a/data_safe_haven/infrastructure/stacks/declarative_sre.py +++ b/data_safe_haven/infrastructure/stacks/declarative_sre.py @@ -220,8 +220,6 @@ def run(self) -> None: storage_account_key=data.storage_account_data_configuration_key, storage_account_name=data.storage_account_data_configuration_name, storage_account_resource_group_name=data.resource_group_name, - virtual_network_resource_group_name=networking.resource_group.name, - virtual_network=networking.virtual_network, ), tags=self.cfg.tags.model_dump(), ) @@ -298,8 +296,6 @@ def run(self) -> None: subnet_containers_support=networking.subnet_user_services_containers_support, subnet_databases=networking.subnet_user_services_databases, subnet_software_repositories=networking.subnet_user_services_software_repositories, - virtual_network=networking.virtual_network, - virtual_network_resource_group_name=networking.resource_group.name, ), tags=self.cfg.tags.model_dump(), ) diff --git a/data_safe_haven/infrastructure/stacks/sre/data.py b/data_safe_haven/infrastructure/stacks/sre/data.py index 80625e0f60..3ce5b932d7 100644 --- a/data_safe_haven/infrastructure/stacks/sre/data.py +++ b/data_safe_haven/infrastructure/stacks/sre/data.py @@ -775,7 +775,9 @@ def __init__( password_database_service_admin.result ) self.password_dns_server_admin = Output.secret( - props.password_dns_server_admin.result + Output.from_input(props.password_dns_server_admin).apply( + lambda password: password.result + ) ) self.password_gitea_database_admin = Output.secret( password_gitea_database_admin.result diff --git a/data_safe_haven/infrastructure/stacks/sre/gitea_server.py b/data_safe_haven/infrastructure/stacks/sre/gitea_server.py index 9297f3cd2b..7e665322fd 100644 --- a/data_safe_haven/infrastructure/stacks/sre/gitea_server.py +++ b/data_safe_haven/infrastructure/stacks/sre/gitea_server.py @@ -1,7 +1,7 @@ from collections.abc import Mapping from pulumi import ComponentResource, Input, Output, ResourceOptions -from pulumi_azure_native import containerinstance, network, storage +from pulumi_azure_native import containerinstance, storage from data_safe_haven.infrastructure.common import ( get_ip_address_from_container_group, @@ -42,8 +42,6 @@ def __init__( storage_account_name: Input[str], storage_account_resource_group_name: Input[str], user_services_resource_group_name: Input[str], - virtual_network: Input[network.VirtualNetwork], - virtual_network_resource_group_name: Input[str], database_username: Input[str] | None = None, ) -> None: self.containers_subnet_id = containers_subnet_id @@ -68,8 +66,6 @@ def __init__( self.storage_account_name = storage_account_name self.storage_account_resource_group_name = storage_account_resource_group_name self.user_services_resource_group_name = user_services_resource_group_name - self.virtual_network = virtual_network - self.virtual_network_resource_group_name = virtual_network_resource_group_name class SREGiteaServerComponent(ComponentResource): diff --git a/data_safe_haven/infrastructure/stacks/sre/hedgedoc_server.py b/data_safe_haven/infrastructure/stacks/sre/hedgedoc_server.py index c8c1134ed6..76b34cbc90 100644 --- a/data_safe_haven/infrastructure/stacks/sre/hedgedoc_server.py +++ b/data_safe_haven/infrastructure/stacks/sre/hedgedoc_server.py @@ -1,7 +1,7 @@ from collections.abc import Mapping from pulumi import ComponentResource, Input, Output, ResourceOptions -from pulumi_azure_native import containerinstance, network, storage +from pulumi_azure_native import containerinstance, storage from data_safe_haven.functions import b64encode from data_safe_haven.infrastructure.common import ( @@ -44,8 +44,6 @@ def __init__( storage_account_name: Input[str], storage_account_resource_group_name: Input[str], user_services_resource_group_name: Input[str], - virtual_network: Input[network.VirtualNetwork], - virtual_network_resource_group_name: Input[str], database_username: Input[str] | None = None, ) -> None: self.containers_subnet_id = containers_subnet_id @@ -81,8 +79,6 @@ def __init__( self.storage_account_name = storage_account_name self.storage_account_resource_group_name = storage_account_resource_group_name self.user_services_resource_group_name = user_services_resource_group_name - self.virtual_network = virtual_network - self.virtual_network_resource_group_name = virtual_network_resource_group_name class SREHedgeDocServerComponent(ComponentResource): diff --git a/data_safe_haven/infrastructure/stacks/sre/networking.py b/data_safe_haven/infrastructure/stacks/sre/networking.py index 9a3a1ec970..499e59fa26 100644 --- a/data_safe_haven/infrastructure/stacks/sre/networking.py +++ b/data_safe_haven/infrastructure/stacks/sre/networking.py @@ -1222,7 +1222,7 @@ def __init__( network_security_group=network.NetworkSecurityGroupArgs( id=nsg_guacamole_containers_support.id ), - private_endpoint_network_policies="Disabled", + private_endpoint_network_policies=network.VirtualNetworkPrivateEndpointNetworkPolicies.ENABLED, route_table=network.RouteTableArgs(id=route_table.id), ), # User services containers @@ -1443,10 +1443,10 @@ def __init__( ) # Link virtual network to SHM private DNS zones - # Note that although the DNS virtual network is already linked to these, Azure - # Container Instances do not have an IP address during deployment and so must - # use default Azure DNS when setting up file mounts. This means that we need to - # be able to resolve the "Storage Account" private DNS zones. + # Note that although the DNS virtual network is already linked to the SHM zones, + # Azure Container Instances do not have an IP address during deployment and so + # must use default Azure DNS when setting up file mounts. This means that we + # need to be able to resolve the "Storage Account" private DNS zones. for private_link_domain in ordered_private_dns_zones("Storage account"): network.VirtualNetworkLink( f"{self._name}_private_zone_{private_link_domain}_vnet_link", diff --git a/data_safe_haven/infrastructure/stacks/sre/remote_desktop.py b/data_safe_haven/infrastructure/stacks/sre/remote_desktop.py index c479802d08..8f975969a4 100644 --- a/data_safe_haven/infrastructure/stacks/sre/remote_desktop.py +++ b/data_safe_haven/infrastructure/stacks/sre/remote_desktop.py @@ -52,8 +52,6 @@ def __init__( storage_account_resource_group_name: Input[str], subnet_guacamole_containers: Input[network.GetSubnetResult], subnet_guacamole_containers_support: Input[network.GetSubnetResult], - virtual_network: Input[network.VirtualNetwork], - virtual_network_resource_group_name: Input[str], database_username: Input[str] | None = "postgresadmin", ) -> None: self.aad_application_name = aad_application_name @@ -107,8 +105,6 @@ def __init__( else [] ) ) - self.virtual_network = virtual_network - self.virtual_network_resource_group_name = virtual_network_resource_group_name class SRERemoteDesktopComponent(ComponentResource): @@ -252,16 +248,16 @@ def __init__( value="preferred_username", # this is 'username@domain' ), containerinstance.EnvironmentVariableArgs( - name="POSTGRES_DATABASE", value=db_guacamole_connections + name="POSTGRESQL_DATABASE", value=db_guacamole_connections ), containerinstance.EnvironmentVariableArgs( - name="POSTGRES_HOSTNAME", + name="POSTGRESQL_HOSTNAME", value=props.subnet_guacamole_containers_support_ip_addresses[ 0 ], ), containerinstance.EnvironmentVariableArgs( - name="POSTGRES_PASSWORD", + name="POSTGRESQL_PASSWORD", secure_value=props.database_password, ), containerinstance.EnvironmentVariableArgs( @@ -271,12 +267,8 @@ def __init__( name="POSTGRESQL_SOCKET_TIMEOUT", value="5" ), containerinstance.EnvironmentVariableArgs( - name="POSTGRES_USER", - value=Output.concat( - props.database_username, - "@", - db_server_guacamole.db_server.name, - ), + name="POSTGRESQL_USER", + value=props.database_username, ), ], resources=containerinstance.ResourceRequirementsArgs( @@ -302,7 +294,7 @@ def __init__( ), ), containerinstance.ContainerArgs( - image="ghcr.io/alan-turing-institute/guacamole-user-sync:v0.1.0", + image="ghcr.io/alan-turing-institute/guacamole-user-sync:v0.2.0", name="guacamole-user-sync"[:63], environment_variables=[ containerinstance.EnvironmentVariableArgs( @@ -319,7 +311,7 @@ def __init__( ), containerinstance.EnvironmentVariableArgs( name="LDAP_GROUP_FILTER", - value=Output.concat("(objectClass=group)"), + value="(objectClass=group)", ), containerinstance.EnvironmentVariableArgs( name="LDAP_HOST", @@ -340,26 +332,22 @@ def __init__( ), ), containerinstance.EnvironmentVariableArgs( - name="POSTGRES_DB_NAME", + name="POSTGRESQL_DB_NAME", value=db_guacamole_connections, ), containerinstance.EnvironmentVariableArgs( - name="POSTGRES_HOST", + name="POSTGRESQL_HOST", value=props.subnet_guacamole_containers_support_ip_addresses[ 0 ], ), containerinstance.EnvironmentVariableArgs( - name="POSTGRES_PASSWORD", + name="POSTGRESQL_PASSWORD", secure_value=props.database_password, ), containerinstance.EnvironmentVariableArgs( - name="POSTGRES_USERNAME", - value=Output.concat( - props.database_username, - "@", - db_server_guacamole.db_server.name, - ), + name="POSTGRESQL_USERNAME", + value=props.database_username, ), containerinstance.EnvironmentVariableArgs( name="REPEAT_INTERVAL", diff --git a/data_safe_haven/infrastructure/stacks/sre/software_repositories.py b/data_safe_haven/infrastructure/stacks/sre/software_repositories.py index efb51a456f..073124737e 100644 --- a/data_safe_haven/infrastructure/stacks/sre/software_repositories.py +++ b/data_safe_haven/infrastructure/stacks/sre/software_repositories.py @@ -3,7 +3,7 @@ from collections.abc import Mapping from pulumi import ComponentResource, Input, Output, ResourceOptions -from pulumi_azure_native import containerinstance, network, storage +from pulumi_azure_native import containerinstance, storage from data_safe_haven.infrastructure.common import ( get_ip_address_from_container_group, @@ -35,8 +35,6 @@ def __init__( storage_account_resource_group_name: Input[str], subnet_id: Input[str], user_services_resource_group_name: Input[str], - virtual_network: Input[network.VirtualNetwork], - virtual_network_resource_group_name: Input[str], ) -> None: self.dns_resource_group_name = dns_resource_group_name self.dns_server_ip = dns_server_ip @@ -54,8 +52,6 @@ def __init__( self.storage_account_name = storage_account_name self.storage_account_resource_group_name = storage_account_resource_group_name self.subnet_id = subnet_id - self.virtual_network = virtual_network - self.virtual_network_resource_group_name = virtual_network_resource_group_name class SRESoftwareRepositoriesComponent(ComponentResource): diff --git a/data_safe_haven/infrastructure/stacks/sre/user_services.py b/data_safe_haven/infrastructure/stacks/sre/user_services.py index dff7552868..96beeb769d 100644 --- a/data_safe_haven/infrastructure/stacks/sre/user_services.py +++ b/data_safe_haven/infrastructure/stacks/sre/user_services.py @@ -46,8 +46,6 @@ def __init__( subnet_containers_support: Input[network.GetSubnetResult], subnet_databases: Input[network.GetSubnetResult], subnet_software_repositories: Input[network.GetSubnetResult], - virtual_network: Input[network.VirtualNetwork], - virtual_network_resource_group_name: Input[str], ) -> None: self.database_service_admin_password = database_service_admin_password self.databases = databases @@ -83,8 +81,6 @@ def __init__( self.subnet_software_repositories_id = Output.from_input( subnet_software_repositories ).apply(get_id_from_subnet) - self.virtual_network = virtual_network - self.virtual_network_resource_group_name = virtual_network_resource_group_name class SREUserServicesComponent(ComponentResource): @@ -135,8 +131,6 @@ def __init__( storage_account_name=props.storage_account_name, storage_account_resource_group_name=props.storage_account_resource_group_name, user_services_resource_group_name=resource_group.name, - virtual_network=props.virtual_network, - virtual_network_resource_group_name=props.virtual_network_resource_group_name, ), opts=child_opts, tags=child_tags, @@ -167,8 +161,6 @@ def __init__( storage_account_name=props.storage_account_name, storage_account_resource_group_name=props.storage_account_resource_group_name, user_services_resource_group_name=resource_group.name, - virtual_network=props.virtual_network, - virtual_network_resource_group_name=props.virtual_network_resource_group_name, ), opts=child_opts, tags=child_tags, @@ -191,8 +183,6 @@ def __init__( storage_account_resource_group_name=props.storage_account_resource_group_name, subnet_id=props.subnet_software_repositories_id, user_services_resource_group_name=resource_group.name, - virtual_network=props.virtual_network, - virtual_network_resource_group_name=props.virtual_network_resource_group_name, ), opts=child_opts, tags=child_tags,