Skip to content

Commit

Permalink
Merge pull request #2233 from alan-turing-institute/internet
Browse files Browse the repository at this point in the history
Add internet
  • Loading branch information
JimMadge authored Oct 16, 2024
2 parents ea23e1b + 5068e5a commit 727f43d
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 8 deletions.
18 changes: 16 additions & 2 deletions data_safe_haven/config/config_sections.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from data_safe_haven.types import (
AzureLocation,
AzurePremiumFileShareSize,
AzureServiceTag,
AzureVmSku,
DatabaseSystem,
EmailAddress,
Expand Down Expand Up @@ -58,7 +59,7 @@ class ConfigSectionSRE(BaseModel, validate_assignment=True):
databases: UniqueList[DatabaseSystem] = []
data_provider_ip_addresses: list[IpAddress] = []
remote_desktop: ConfigSubsectionRemoteDesktopOpts
research_user_ip_addresses: list[IpAddress] = []
research_user_ip_addresses: list[IpAddress] | AzureServiceTag = []
storage_quota_gb: ConfigSubsectionStorageQuotaGB
software_packages: SoftwarePackageCategory = SoftwarePackageCategory.NONE
timezone: TimeZone = "Etc/UTC"
Expand All @@ -67,7 +68,7 @@ class ConfigSectionSRE(BaseModel, validate_assignment=True):
@field_validator(
"admin_ip_addresses",
"data_provider_ip_addresses",
"research_user_ip_addresses",
# "research_user_ip_addresses",
mode="after",
)
@classmethod
Expand All @@ -78,3 +79,16 @@ def ensure_non_overlapping(cls, v: list[IpAddress]) -> list[IpAddress]:
msg = "IP addresses must not overlap."
raise ValueError(msg)
return v

@field_validator(
"research_user_ip_addresses",
mode="after",
)
@classmethod
def ensure_non_overlapping_or_tag(
cls, v: list[IpAddress] | AzureServiceTag
) -> list[IpAddress] | AzureServiceTag:
if isinstance(v, list):
return cls.ensure_non_overlapping(v)
else:
return v
5 changes: 4 additions & 1 deletion data_safe_haven/config/sre_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,10 @@ def template(cls: type[Self], tier: int | None = None) -> SREConfig:
allow_copy=remote_desktop_allow_copy,
allow_paste=remote_desktop_allow_paste,
),
research_user_ip_addresses=["List of IP addresses belonging to users"],
research_user_ip_addresses=[
"List of IP addresses belonging to users",
"You can also use the tag 'Internet' instead of a list",
],
software_packages=software_packages,
storage_quota_gb=ConfigSubsectionStorageQuotaGB.model_construct(
home="Total size in GiB across all home directories [minimum: 100].", # type: ignore
Expand Down
14 changes: 11 additions & 3 deletions data_safe_haven/infrastructure/programs/sre/networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
get_id_from_vnet,
get_name_from_vnet,
)
from data_safe_haven.types import NetworkingPriorities, Ports
from data_safe_haven.types import AzureServiceTag, NetworkingPriorities, Ports


class SRENetworkingProps:
Expand All @@ -31,7 +31,7 @@ def __init__(
shm_subscription_id: Input[str],
shm_zone_name: Input[str],
sre_name: Input[str],
user_public_ip_ranges: Input[list[str]],
user_public_ip_ranges: Input[list[str]] | AzureServiceTag,
) -> None:
# Other variables
self.dns_private_zones = dns_private_zones
Expand Down Expand Up @@ -68,6 +68,13 @@ def __init__(
child_opts = ResourceOptions.merge(opts, ResourceOptions(parent=self))
child_tags = {"component": "networking"} | (tags if tags else {})

if isinstance(props.user_public_ip_ranges, list):
user_public_ip_ranges = props.user_public_ip_ranges
user_service_tag = None
else:
user_public_ip_ranges = None
user_service_tag = props.user_public_ip_ranges

# Define route table
route_table = network.RouteTable(
f"{self._name}_route_table",
Expand Down Expand Up @@ -125,7 +132,8 @@ def __init__(
name="AllowUsersInternetInbound",
priority=NetworkingPriorities.AUTHORISED_EXTERNAL_USER_IPS,
protocol=network.SecurityRuleProtocol.TCP,
source_address_prefixes=props.user_public_ip_ranges,
source_address_prefix=user_service_tag,
source_address_prefixes=user_public_ip_ranges,
source_port_range="*",
),
network.SecurityRuleArgs(
Expand Down
2 changes: 2 additions & 0 deletions data_safe_haven/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .enums import (
AzureDnsZoneNames,
AzureSdkCredentialScope,
AzureServiceTag,
DatabaseSystem,
FirewallPriorities,
ForbiddenDomains,
Expand All @@ -29,6 +30,7 @@
"AzureDnsZoneNames",
"AzureLocation",
"AzurePremiumFileShareSize",
"AzureServiceTag",
"AzureSdkCredentialScope",
"AzureSubscriptionName",
"AzureVmSku",
Expand Down
5 changes: 5 additions & 0 deletions data_safe_haven/types/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ class AzureSdkCredentialScope(str, Enum):
KEY_VAULT = "https://vault.azure.net"


@verify(UNIQUE)
class AzureServiceTag(str, Enum):
INTERNET = "Internet"


@verify(UNIQUE)
class DatabaseSystem(str, Enum):
MICROSOFT_SQL_SERVER = "mssql"
Expand Down
2 changes: 1 addition & 1 deletion data_safe_haven/validators/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def ip_address(ip_address: str) -> str:
try:
return str(ipaddress.ip_network(ip_address))
except Exception as exc:
msg = "Expected valid IPv4 address, for example '1.1.1.1'."
msg = "Expected valid IPv4 address, for example '1.1.1.1', or 'Internet'."
raise ValueError(msg) from exc


Expand Down
24 changes: 23 additions & 1 deletion tests/config/test_config_sections.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
ConfigSubsectionRemoteDesktopOpts,
ConfigSubsectionStorageQuotaGB,
)
from data_safe_haven.types import DatabaseSystem, SoftwarePackageCategory
from data_safe_haven.types import (
AzureServiceTag,
DatabaseSystem,
SoftwarePackageCategory,
)


class TestConfigSectionAzure:
Expand Down Expand Up @@ -184,6 +188,24 @@ def test_ip_overlap_research_user(self):
research_user_ip_addresses=["1.2.3.4", "1.2.3.4"],
)

def test_research_user_tag_internet(
self,
config_subsection_remote_desktop: ConfigSubsectionRemoteDesktopOpts,
config_subsection_storage_quota_gb: ConfigSubsectionStorageQuotaGB,
):
sre_config = ConfigSectionSRE(
admin_email_address="[email protected]",
remote_desktop=config_subsection_remote_desktop,
storage_quota_gb=config_subsection_storage_quota_gb,
research_user_ip_addresses="Internet",
)
assert isinstance(sre_config.research_user_ip_addresses, AzureServiceTag)
assert sre_config.research_user_ip_addresses == "Internet"

def test_research_user_tag_invalid(self):
with pytest.raises(ValueError, match="Input should be 'Internet'"):
ConfigSectionSRE(research_user_ip_addresses="Not a tag")

@pytest.mark.parametrize(
"addresses",
[
Expand Down
30 changes: 30 additions & 0 deletions tests/validators/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,36 @@ def test_fqdn_fail(self, fqdn):
validators.fqdn(fqdn)


class TestValidateIpAddress:
@pytest.mark.parametrize(
"ip_address,output",
[
("127.0.0.1", "127.0.0.1/32"),
("0.0.0.0/0", "0.0.0.0/0"),
("192.168.171.1/32", "192.168.171.1/32"),
],
)
def test_ip_address(self, ip_address, output):
assert validators.ip_address(ip_address) == output

@pytest.mark.parametrize(
"ip_address",
[
"example.com",
"University of Life",
"999.999.999.999",
"0.0.0.0/-1",
"255.255.255.0/2",
],
)
def test_ip_address_fail(self, ip_address):
with pytest.raises(
ValueError,
match="Expected valid IPv4 address, for example '1.1.1.1', or 'Internet'.",
):
validators.ip_address(ip_address)


class TestValidateSafeString:
@pytest.mark.parametrize(
"safe_string",
Expand Down

0 comments on commit 727f43d

Please sign in to comment.