Skip to content

Commit

Permalink
Merge pull request #1735 from jemrobinson/1733-upgrade-to-flexible-se…
Browse files Browse the repository at this point in the history
…rver

Upgrade to PostgreSQL flexible server
  • Loading branch information
JimMadge authored Feb 12, 2024
2 parents 0695ae1 + 010cbdb commit e50fc9d
Show file tree
Hide file tree
Showing 14 changed files with 115 additions and 125 deletions.
6 changes: 6 additions & 0 deletions data_safe_haven/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ Run `dsh deploy shm -h` to see the necessary command line flags and provide them
> dsh config upload config.yaml
```

- As private endpoints for PostgreSQL flexible server are still in preview, you currently need to register the preview feature on your subscription first by running the following command:

```console
> az feature register --name "enablePrivateEndpoint" --namespace "Microsoft.DBforPostgreSQL"
```

- Next deploy the infrastructure [approx 30 minutes]:

```console
Expand Down
10 changes: 8 additions & 2 deletions data_safe_haven/administration/users/guacamole_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any

from data_safe_haven.config import Config
from data_safe_haven.external import AzurePostgreSQLDatabase
from data_safe_haven.external import AzureApi, AzurePostgreSQLDatabase
from data_safe_haven.infrastructure import SREStackManager

from .research_user import ResearchUser
Expand All @@ -13,9 +13,15 @@ class GuacamoleUsers:
def __init__(self, config: Config, sre_name: str, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
sre_stack = SREStackManager(config, sre_name)
# Read the SRE database secret from key vault
azure_api = AzureApi(config.context.subscription_name)
connection_db_server_password = azure_api.get_keyvault_secret(
sre_stack.output("data")["key_vault_name"],
sre_stack.output("data")["password_user_database_admin_secret"],
)
self.postgres_provisioner = AzurePostgreSQLDatabase(
sre_stack.output("remote_desktop")["connection_db_name"],
sre_stack.secret("password-user-database-admin"),
connection_db_server_password,
sre_stack.output("remote_desktop")["connection_db_server_name"],
sre_stack.output("remote_desktop")["resource_group_name"],
config.context.subscription_name,
Expand Down
86 changes: 48 additions & 38 deletions data_safe_haven/external/interface/azure_postgresql_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@
import pathlib
import time
from collections.abc import Sequence
from typing import Any
from typing import Any, cast

import psycopg
import requests
from azure.core.polling import LROPoller
from azure.mgmt.rdbms.postgresql import PostgreSQLManagementClient
from azure.mgmt.rdbms.postgresql.models import (
from azure.mgmt.rdbms.postgresql_flexibleservers import PostgreSQLManagementClient
from azure.mgmt.rdbms.postgresql_flexibleservers.models import (
FirewallRule,
Server,
ServerUpdateParameters,
)

from data_safe_haven.exceptions import (
Expand Down Expand Up @@ -72,7 +71,7 @@ def connection_string(self) -> str:
f"host={self.db_server.fully_qualified_domain_name}",
f"password={self.db_server_admin_password}",
f"port={self.port}",
f"user={self.db_server.administrator_login}@{self.server_name}",
f"user={self.db_server.administrator_login}",
"sslmode=require",
]
)
Expand Down Expand Up @@ -100,15 +99,17 @@ def db_connection(self, n_retries: int = 0) -> psycopg.Connection:
"""Get the database connection."""
while True:
try:
connection = psycopg.connect(self.connection_string)
break
except psycopg.OperationalError as exc:
if n_retries > 0:
try:
connection = psycopg.connect(self.connection_string)
break
except psycopg.OperationalError as exc:
if n_retries <= 0:
raise exc
n_retries -= 1
time.sleep(10)
else:
msg = f"Could not connect to database.\n{exc}"
raise DataSafeHavenAzureError(msg) from exc
except Exception as exc:
msg = f"Could not connect to database.\n{exc}"
raise DataSafeHavenAzureError(msg) from exc
return connection

def load_sql(
Expand Down Expand Up @@ -169,24 +170,17 @@ def execute_scripts(

def set_database_access(self, action: str) -> None:
"""Enable/disable database access to the PostgreSQL server."""
rule_name = f"AllowConfigurationUpdate-{self.rule_suffix}"

if action == "enabled":
self.logger.debug(
f"Adding temporary firewall rule for [green]{self.current_ip}[/]...",
)
self.wait(
self.db_client.servers.begin_update(
self.resource_group_name,
self.server_name,
ServerUpdateParameters(public_network_access="Enabled"),
)
)
# NB. We would like to enable public_network_access at this point but this
# is not currently supported by the flexibleServer API
self.wait(
self.db_client.firewall_rules.begin_create_or_update(
self.resource_group_name,
self.server_name,
rule_name,
f"AllowConfigurationUpdate-{self.rule_suffix}",
FirewallRule(
start_ip_address=self.current_ip, end_ip_address=self.current_ip
),
Expand All @@ -198,28 +192,44 @@ def set_database_access(self, action: str) -> None:
)
elif action == "disabled":
self.logger.debug(
f"Removing temporary firewall rule for [green]{self.current_ip}[/]...",
f"Removing all firewall rule(s) from [green]{self.server_name}[/]...",
)
self.wait(
self.db_client.firewall_rules.begin_delete(
self.resource_group_name, self.server_name, rule_name
)
rules = cast(
list[FirewallRule],
self.db_client.firewall_rules.list_by_server(
self.resource_group_name, self.server_name
),
)
self.wait(
self.db_client.servers.begin_update(
self.resource_group_name,
self.server_name,
ServerUpdateParameters(public_network_access="Disabled"),

# Delete all named firewall rules
rule_names = [str(rule.name) for rule in rules if rule.name]
for rule_name in rule_names:
self.wait(
self.db_client.firewall_rules.begin_delete(
self.resource_group_name, self.server_name, rule_name
)
)

# NB. We would like to disable public_network_access at this point but this
# is not currently supported by the flexibleServer API
if len(rule_names) == len(rules):
self.logger.info(
f"Removed all firewall rule(s) from [green]{self.server_name}[/].",
)
else:
self.logger.warning(
f"Unable to remove all firewall rule(s) from [green]{self.server_name}[/].",
)
)
self.logger.info(
f"Removed temporary firewall rule for [green]{self.current_ip}[/].",
)
else:
msg = f"Database access action {action} was not recognised."
raise DataSafeHavenInputError(msg)
self.db_server_ = None # Force refresh of self.db_server
self.logger.info(
public_network_access = (
self.db_server.network.public_network_access
if self.db_server.network
else "UNKNOWN"
)
self.logger.debug(
f"Public network access to [green]{self.server_name}[/]"
f" is [green]{self.db_server.public_network_access}[/]."
f" is [green]{public_network_access}[/]."
)
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(


class LocalDnsRecordComponent(ComponentResource):
"""Deploy Gitea server with Pulumi"""
"""Deploy public and private DNS records with Pulumi"""

def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(
location: Input[str],
) -> None:
self.database_names = Output.from_input(database_names)
self.database_password = database_password
self.database_password = Output.secret(database_password)
self.database_resource_group_name = database_resource_group_name
self.database_server_name = database_server_name
self.database_subnet_id = database_subnet_id
Expand All @@ -45,31 +45,34 @@ def __init__(
# Define a PostgreSQL server
db_server = dbforpostgresql.Server(
f"{self._name}_server",
location=props.location,
properties=dbforpostgresql.ServerPropertiesForDefaultCreateArgs(
administrator_login=props.database_username,
administrator_login_password=props.database_password,
create_mode="Default",
infrastructure_encryption=dbforpostgresql.InfrastructureEncryption.DISABLED,
minimal_tls_version=dbforpostgresql.MinimalTlsVersionEnum.TLS_ENFORCEMENT_DISABLED,
public_network_access=dbforpostgresql.PublicNetworkAccessEnum.DISABLED,
ssl_enforcement=dbforpostgresql.SslEnforcementEnum.ENABLED,
storage_profile=dbforpostgresql.StorageProfileArgs(
backup_retention_days=7,
geo_redundant_backup=dbforpostgresql.GeoRedundantBackup.DISABLED,
storage_autogrow=dbforpostgresql.StorageAutogrow.ENABLED,
storage_mb=5120,
),
version=dbforpostgresql.ServerVersion.SERVER_VERSION_11,
administrator_login=props.database_username,
administrator_login_password=props.database_password,
auth_config=dbforpostgresql.AuthConfigArgs(
active_directory_auth=dbforpostgresql.ActiveDirectoryAuthEnum.DISABLED,
password_auth=dbforpostgresql.PasswordAuthEnum.ENABLED,
),
backup=dbforpostgresql.BackupArgs(
backup_retention_days=7,
geo_redundant_backup=dbforpostgresql.GeoRedundantBackupEnum.DISABLED,
),
create_mode=dbforpostgresql.CreateMode.DEFAULT,
data_encryption=dbforpostgresql.DataEncryptionArgs(
type=dbforpostgresql.ArmServerKeyType.SYSTEM_MANAGED,
),
high_availability=dbforpostgresql.HighAvailabilityArgs(
mode=dbforpostgresql.HighAvailabilityMode.DISABLED,
),
location=props.location,
resource_group_name=props.database_resource_group_name,
server_name=props.database_server_name,
sku=dbforpostgresql.SkuArgs(
capacity=2,
family="Gen5",
name="GP_Gen5_2",
tier=dbforpostgresql.SkuTier.GENERAL_PURPOSE, # required to use private link
name="Standard_B2s",
tier=dbforpostgresql.SkuTier.BURSTABLE,
),
storage=dbforpostgresql.StorageArgs(
storage_size_gb=32,
),
version=dbforpostgresql.ServerVersion.SERVER_VERSION_14,
opts=child_opts,
tags=child_tags,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,11 @@ def create(self, props: dict[str, Any]) -> CreateResult:
try:
certificate_bytes = client.request_certificate()
except ValidationError as exc:
raise DataSafeHavenSSLError(
"ACME validation error:\n"
+ "\n".join([str(auth_error) for auth_error in exc.failed_authzrs])
) from exc
msg = "\n".join(
["ACME validation error:"]
+ [str(auth_error) for auth_error in exc.failed_authzrs]
)
raise DataSafeHavenSSLError(msg) from exc
# Although KeyVault will accept a PEM certificate (where we simply prepend
# the private key) we need a PFX certificate for compatibility with
# ApplicationGateway
Expand Down
4 changes: 0 additions & 4 deletions data_safe_haven/infrastructure/stacks/declarative_sre.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,6 @@ def run(self) -> None:
storage_account_key=data.storage_account_data_configuration_key,
storage_account_name=data.storage_account_data_configuration_name,
storage_account_resource_group_name=data.resource_group_name,
virtual_network_resource_group_name=networking.resource_group.name,
virtual_network=networking.virtual_network,
),
tags=self.cfg.tags.model_dump(),
)
Expand Down Expand Up @@ -298,8 +296,6 @@ def run(self) -> None:
subnet_containers_support=networking.subnet_user_services_containers_support,
subnet_databases=networking.subnet_user_services_databases,
subnet_software_repositories=networking.subnet_user_services_software_repositories,
virtual_network=networking.virtual_network,
virtual_network_resource_group_name=networking.resource_group.name,
),
tags=self.cfg.tags.model_dump(),
)
Expand Down
4 changes: 3 additions & 1 deletion data_safe_haven/infrastructure/stacks/sre/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,9 @@ def __init__(
password_database_service_admin.result
)
self.password_dns_server_admin = Output.secret(
props.password_dns_server_admin.result
Output.from_input(props.password_dns_server_admin).apply(
lambda password: password.result
)
)
self.password_gitea_database_admin = Output.secret(
password_gitea_database_admin.result
Expand Down
6 changes: 1 addition & 5 deletions data_safe_haven/infrastructure/stacks/sre/gitea_server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Mapping

from pulumi import ComponentResource, Input, Output, ResourceOptions
from pulumi_azure_native import containerinstance, network, storage
from pulumi_azure_native import containerinstance, storage

from data_safe_haven.infrastructure.common import (
get_ip_address_from_container_group,
Expand Down Expand Up @@ -42,8 +42,6 @@ def __init__(
storage_account_name: Input[str],
storage_account_resource_group_name: Input[str],
user_services_resource_group_name: Input[str],
virtual_network: Input[network.VirtualNetwork],
virtual_network_resource_group_name: Input[str],
database_username: Input[str] | None = None,
) -> None:
self.containers_subnet_id = containers_subnet_id
Expand All @@ -68,8 +66,6 @@ def __init__(
self.storage_account_name = storage_account_name
self.storage_account_resource_group_name = storage_account_resource_group_name
self.user_services_resource_group_name = user_services_resource_group_name
self.virtual_network = virtual_network
self.virtual_network_resource_group_name = virtual_network_resource_group_name


class SREGiteaServerComponent(ComponentResource):
Expand Down
6 changes: 1 addition & 5 deletions data_safe_haven/infrastructure/stacks/sre/hedgedoc_server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Mapping

from pulumi import ComponentResource, Input, Output, ResourceOptions
from pulumi_azure_native import containerinstance, network, storage
from pulumi_azure_native import containerinstance, storage

from data_safe_haven.functions import b64encode
from data_safe_haven.infrastructure.common import (
Expand Down Expand Up @@ -44,8 +44,6 @@ def __init__(
storage_account_name: Input[str],
storage_account_resource_group_name: Input[str],
user_services_resource_group_name: Input[str],
virtual_network: Input[network.VirtualNetwork],
virtual_network_resource_group_name: Input[str],
database_username: Input[str] | None = None,
) -> None:
self.containers_subnet_id = containers_subnet_id
Expand Down Expand Up @@ -81,8 +79,6 @@ def __init__(
self.storage_account_name = storage_account_name
self.storage_account_resource_group_name = storage_account_resource_group_name
self.user_services_resource_group_name = user_services_resource_group_name
self.virtual_network = virtual_network
self.virtual_network_resource_group_name = virtual_network_resource_group_name


class SREHedgeDocServerComponent(ComponentResource):
Expand Down
10 changes: 5 additions & 5 deletions data_safe_haven/infrastructure/stacks/sre/networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -1222,7 +1222,7 @@ def __init__(
network_security_group=network.NetworkSecurityGroupArgs(
id=nsg_guacamole_containers_support.id
),
private_endpoint_network_policies="Disabled",
private_endpoint_network_policies=network.VirtualNetworkPrivateEndpointNetworkPolicies.ENABLED,
route_table=network.RouteTableArgs(id=route_table.id),
),
# User services containers
Expand Down Expand Up @@ -1443,10 +1443,10 @@ def __init__(
)

# Link virtual network to SHM private DNS zones
# Note that although the DNS virtual network is already linked to these, Azure
# Container Instances do not have an IP address during deployment and so must
# use default Azure DNS when setting up file mounts. This means that we need to
# be able to resolve the "Storage Account" private DNS zones.
# Note that although the DNS virtual network is already linked to the SHM zones,
# Azure Container Instances do not have an IP address during deployment and so
# must use default Azure DNS when setting up file mounts. This means that we
# need to be able to resolve the "Storage Account" private DNS zones.
for private_link_domain in ordered_private_dns_zones("Storage account"):
network.VirtualNetworkLink(
f"{self._name}_private_zone_{private_link_domain}_vnet_link",
Expand Down
Loading

0 comments on commit e50fc9d

Please sign in to comment.