diff --git a/data_safe_haven/external/api/azure_api.py b/data_safe_haven/external/api/azure_api.py
index 552bc35713..0ce8044bce 100644
--- a/data_safe_haven/external/api/azure_api.py
+++ b/data_safe_haven/external/api/azure_api.py
@@ -1084,6 +1084,42 @@ def run_remote_script(
             msg = f"Failed to run command on '{vm_name}'.\n{exc}"
             raise DataSafeHavenAzureError(msg) from exc
 
+    def run_remote_script_waiting(
+        self,
+        resource_group_name: str,
+        script: str,
+        script_parameters: dict[str, str],
+        vm_name: str,
+    ) -> str:
+        """Run a script on a remote virtual machine, waiting for other scripts to complete
+
+        Returns:
+            str: The script output
+
+        Raises:
+            DataSafeHavenAzureError if running the script failed
+        """
+        while True:
+            try:
+                script_output = self.run_remote_script(
+                    resource_group_name=resource_group_name,
+                    script=script,
+                    script_parameters=script_parameters,
+                    vm_name=vm_name,
+                )
+                break
+            except DataSafeHavenAzureError as exc:
+                if all(
+                    reason not in str(exc)
+                    for reason in (
+                        "The request failed due to conflict with a concurrent request",
+                        "Run command extension execution is in progress",
+                    )
+                ):
+                    raise
+                time.sleep(5)
+        return script_output
+
     def set_blob_container_acl(
         self,
         container_name: str,
diff --git a/data_safe_haven/external/api/graph_api.py b/data_safe_haven/external/api/graph_api.py
index a329b50437..1b1a00b677 100644
--- a/data_safe_haven/external/api/graph_api.py
+++ b/data_safe_haven/external/api/graph_api.py
@@ -361,7 +361,9 @@ def create_token_administrator(self) -> str:
         result = None
         try:
             # Load local token cache
-            local_token_cache = LocalTokenCache(pathlib.Path.home() / ".msal_cache")
+            local_token_cache = LocalTokenCache(
+                pathlib.Path.home() / f".msal_cache_{self.tenant_id}"
+            )
             # Use the Powershell application by default as this should be pre-installed
             app = PublicClientApplication(
                 authority=f"https://login.microsoftonline.com/{self.tenant_id}",
diff --git a/data_safe_haven/functions/strings.py b/data_safe_haven/functions/strings.py
index a3dac5e106..27089eeaed 100644
--- a/data_safe_haven/functions/strings.py
+++ b/data_safe_haven/functions/strings.py
@@ -90,7 +90,7 @@ def seeded_uuid(seed: str) -> uuid.UUID:
 
 def sha256hash(input_string: str) -> str:
     """Return the SHA256 hash of a string as a string."""
-    return hashlib.sha256(str.encode(input_string, encoding="utf-8")).hexdigest()
+    return hashlib.sha256(input_string.encode("utf-8")).hexdigest()
 
 
 def truncate_tokens(tokens: Sequence[str], max_length: int) -> list[str]:
diff --git a/data_safe_haven/infrastructure/components/__init__.py b/data_safe_haven/infrastructure/components/__init__.py
index ee872fabe0..6fcb8d3f9b 100644
--- a/data_safe_haven/infrastructure/components/__init__.py
+++ b/data_safe_haven/infrastructure/components/__init__.py
@@ -20,6 +20,8 @@
     CompiledDscProps,
     FileShareFile,
     FileShareFileProps,
+    FileUpload,
+    FileUploadProps,
     RemoteScript,
     RemoteScriptProps,
     SSLCertificate,
@@ -41,6 +43,8 @@
     "CompiledDscProps",
     "FileShareFile",
     "FileShareFileProps",
+    "FileUpload",
+    "FileUploadProps",
     "LinuxVMComponentProps",
     "LocalDnsRecordComponent",
     "LocalDnsRecordProps",
diff --git a/data_safe_haven/infrastructure/components/dynamic/__init__.py b/data_safe_haven/infrastructure/components/dynamic/__init__.py
index 2fe0f8decb..4fdfb12dfc 100644
--- a/data_safe_haven/infrastructure/components/dynamic/__init__.py
+++ b/data_safe_haven/infrastructure/components/dynamic/__init__.py
@@ -2,7 +2,8 @@
 from .blob_container_acl import BlobContainerAcl, BlobContainerAclProps
 from .compiled_dsc import CompiledDsc, CompiledDscProps
 from .file_share_file import FileShareFile, FileShareFileProps
-from .remote_powershell import RemoteScript, RemoteScriptProps
+from .file_upload import FileUpload, FileUploadProps
+from .remote_script import RemoteScript, RemoteScriptProps
 from .ssl_certificate import SSLCertificate, SSLCertificateProps
 
 __all__ = [
@@ -14,6 +15,8 @@
     "CompiledDscProps",
     "FileShareFile",
     "FileShareFileProps",
+    "FileUpload",
+    "FileUploadProps",
     "RemoteScript",
     "RemoteScriptProps",
     "SSLCertificate",
diff --git a/data_safe_haven/infrastructure/components/dynamic/file_upload.py b/data_safe_haven/infrastructure/components/dynamic/file_upload.py
new file mode 100644
index 0000000000..4f1f259c47
--- /dev/null
+++ b/data_safe_haven/infrastructure/components/dynamic/file_upload.py
@@ -0,0 +1,144 @@
+"""Pulumi dynamic component for uploading files to an Azure VM."""
+from typing import Any
+
+from pulumi import Input, Output, ResourceOptions
+from pulumi.dynamic import CreateResult, DiffResult, Resource, UpdateResult
+
+from data_safe_haven.exceptions import DataSafeHavenAzureError
+from data_safe_haven.external import AzureApi
+from data_safe_haven.functions import b64encode
+
+from .dsh_resource_provider import DshResourceProvider
+
+
+class FileUploadProps:
+    """Props for the FileUpload class"""
+
+    def __init__(
+        self,
+        file_contents: Input[str],
+        file_hash: Input[str],
+        file_permissions: Input[str],
+        file_target: Input[str],
+        subscription_name: Input[str],
+        vm_name: Input[str],
+        vm_resource_group_name: Input[str],
+        force_refresh: Input[bool] | None = None,
+    ) -> None:
+        self.file_contents = file_contents
+        self.file_hash = file_hash
+        self.file_target = file_target
+        self.file_permissions = file_permissions
+        self.force_refresh = Output.from_input(force_refresh).apply(
+            lambda force: force if force else False
+        )
+        self.subscription_name = subscription_name
+        self.vm_name = vm_name
+        self.vm_resource_group_name = vm_resource_group_name
+
+
+class FileUploadProvider(DshResourceProvider):
+    def create(self, props: dict[str, Any]) -> CreateResult:
+        """Run a remote script to create a file on a VM"""
+        outs = dict(**props)
+        azure_api = AzureApi(props["subscription_name"], disable_logging=True)
+        script_contents = f"""
+            target_dir=$(dirname "$target");
+            mkdir -p "$target_dir" 2> /dev/null;
+            echo "$contents_b64" | base64 --decode > "$target";
+            chmod {props['file_permissions']} "$target";
+            if [ -f "$target" ]; then
+                echo "Wrote file to $target";
+            else
+                echo "Failed to write file to $target";
+            fi
+        """
+        script_parameters = {
+            "contents_b64": b64encode(props["file_contents"]),
+            "target": props["file_target"],
+        }
+        # Run remote script
+        script_output = azure_api.run_remote_script_waiting(
+            props["vm_resource_group_name"],
+            script_contents,
+            script_parameters,
+            props["vm_name"],
+        )
+        outs["script_output"] = "\n".join(
+            [
+                line.strip()
+                for line in script_output.replace("Enable succeeded:", "").split("\n")
+                if line
+            ]
+        )
+        if "Failed to write" in outs["script_output"]:
+            raise DataSafeHavenAzureError(outs["script_output"])
+        return CreateResult(
+            f"FileUpload-{props['file_hash']}",
+            outs=outs,
+        )
+
+    def delete(self, id_: str, props: dict[str, Any]) -> None:
+        """Delete the remote file from the VM"""
+        # Use `id` as a no-op to avoid ARG002 while maintaining function signature
+        id(id_)
+        azure_api = AzureApi(props["subscription_name"], disable_logging=True)
+        script_contents = """
+            rm -f "$target";
+            echo "Removed file at $target";
+        """
+        script_parameters = {
+            "target": props["file_target"],
+        }
+        # Run remote script
+        azure_api.run_remote_script_waiting(
+            props["vm_resource_group_name"],
+            script_contents,
+            script_parameters,
+            props["vm_name"],
+        )
+
+    def diff(
+        self,
+        id_: str,
+        old_props: dict[str, Any],
+        new_props: dict[str, Any],
+    ) -> DiffResult:
+        """Calculate diff between old and new state"""
+        # Use `id` as a no-op to avoid ARG002 while maintaining function signature
+        id(id_)
+        if new_props["force_refresh"]:
+            return DiffResult(
+                changes=True,
+                replaces=list(new_props.keys()),
+                stables=[],
+                delete_before_replace=False,
+            )
+        return self.partial_diff(old_props, new_props, [])
+
+    def update(
+        self,
+        id_: str,
+        old_props: dict[str, Any],
+        new_props: dict[str, Any],
+    ) -> UpdateResult:
+        """Updating is creating without the need to delete."""
+        # Use `id` as a no-op to avoid ARG002 while maintaining function signature
+        id((id_, old_props))
+        updated = self.create(new_props)
+        return UpdateResult(outs=updated.outs)
+
+
+class FileUpload(Resource):
+    script_output: Output[str]
+    _resource_type_name = "dsh:common:FileUpload"  # set resource type
+
+    def __init__(
+        self,
+        name: str,
+        props: FileUploadProps,
+        opts: ResourceOptions | None = None,
+    ):
+        super().__init__(
+            FileUploadProvider(), name, {"script_output": None, **vars(props)}, opts
+        )
diff --git a/data_safe_haven/infrastructure/components/dynamic/remote_powershell.py b/data_safe_haven/infrastructure/components/dynamic/remote_script.py
similarity index 100%
rename from data_safe_haven/infrastructure/components/dynamic/remote_powershell.py
rename to data_safe_haven/infrastructure/components/dynamic/remote_script.py
diff --git a/data_safe_haven/infrastructure/stacks/declarative_sre.py b/data_safe_haven/infrastructure/stacks/declarative_sre.py
index 439da4b8ca..907ebfe667 100644
--- a/data_safe_haven/infrastructure/stacks/declarative_sre.py
+++ b/data_safe_haven/infrastructure/stacks/declarative_sre.py
@@ -253,6 +253,7 @@ def run(self) -> None:
                 storage_account_data_private_user_name=data.storage_account_data_private_user_name,
                 storage_account_data_private_sensitive_name=data.storage_account_data_private_sensitive_name,
                 subnet_workspaces=networking.subnet_workspaces,
+                subscription_name=self.cfg.subscription_name,
                 virtual_network_resource_group=networking.resource_group,
                 virtual_network=networking.virtual_network,
                 vm_details=list(enumerate(self.cfg.sres[self.sre_name].workspace_skus)),
diff --git a/data_safe_haven/infrastructure/stacks/sre/workspaces.py b/data_safe_haven/infrastructure/stacks/sre/workspaces.py
index 2821b6f669..fdf3d46c4d 100644
--- a/data_safe_haven/infrastructure/stacks/sre/workspaces.py
+++ b/data_safe_haven/infrastructure/stacks/sre/workspaces.py
@@ -1,3 +1,4 @@
+import pathlib
 from collections.abc import Mapping
 from typing import Any
 
@@ -14,10 +15,13 @@
     get_name_from_vnet,
 )
 from data_safe_haven.infrastructure.components import (
+    FileUpload,
+    FileUploadProps,
     LinuxVMComponentProps,
     VMComponent,
 )
 from data_safe_haven.resources import resources_path
+from data_safe_haven.utility import FileReader
 
 
 class SREWorkspacesProps:
@@ -43,6 +47,7 @@ def __init__(
         storage_account_data_private_user_name: Input[str],
         storage_account_data_private_sensitive_name: Input[str],
         subnet_workspaces: Input[network.GetSubnetResult],
+        subscription_name: Input[str],
         virtual_network_resource_group: Input[resources.ResourceGroup],
         virtual_network: Input[network.VirtualNetwork],
         vm_details: list[tuple[int, str]],  # this must *not* be passed as an Input[T]
@@ -69,6 +74,7 @@ def __init__(
         self.storage_account_data_private_sensitive_name = (
             storage_account_data_private_sensitive_name
         )
+        self.subscription_name = subscription_name
         self.virtual_network_name = Output.from_input(virtual_network).apply(
             get_name_from_vnet
         )
@@ -161,7 +167,7 @@ def __init__(
         ]
 
         # Get details for each deployed VM
-        vm_outputs = [
+        vm_outputs: list[dict[str, Any]] = [
             {
                 "ip_address": vm.ip_address_private,
                 "name": vm.vm_name,
@@ -170,6 +176,38 @@
             for vm in vms
         ]
 
+        # Upload smoke tests
+        mustache_values = {
+            "check_uninstallable_packages": "0",
+        }
+        file_uploads = [
+            (FileReader(resources_path / "workspace" / "run_all_tests.bats"), "0444")
+        ]
+        for test_file in pathlib.Path(resources_path / "workspace").glob("test*"):
+            file_uploads.append((FileReader(test_file), "0444"))
+        for vm_index, (vm, vm_output) in enumerate(zip(vms, vm_outputs, strict=True)):
+            outputs: dict[str, Output[str]] = {}
+            for file_upload, file_permissions in file_uploads:
+                # Include the VM index in the resource name to keep URNs unique across VMs
+                file_smoke_test = FileUpload(
+                    replace_separators(
+                        f"{self._name}_file_{vm_index}_{file_upload.name}", "_"
+                    ),
+                    FileUploadProps(
+                        file_contents=file_upload.file_contents(
+                            mustache_values=mustache_values
+                        ),
+                        file_hash=file_upload.sha256(),
+                        file_permissions=file_permissions,
+                        file_target=f"/opt/tests/{file_upload.name}",
+                        subscription_name=props.subscription_name,
+                        vm_name=vm.vm_name,
+                        vm_resource_group_name=resource_group.name,
+                    ),
+                    opts=child_opts,
+                )
+                outputs[file_upload.name] = file_smoke_test.script_output
+            vm_output["file_uploads"] = outputs
+
         # Register outputs
         self.resource_group = resource_group
 
diff --git a/data_safe_haven/resources/software_repositories/allowlists/cran.allowlist b/data_safe_haven/resources/software_repositories/allowlists/cran.allowlist
index d65ef196ea..9624ec7060 100644
--- a/data_safe_haven/resources/software_repositories/allowlists/cran.allowlist
+++ b/data_safe_haven/resources/software_repositories/allowlists/cran.allowlist
@@ -1,4 +1,5 @@
 DBI
+MASS
 RPostgres
 Rcpp
 bit
diff --git a/data_safe_haven/resources/software_repositories/allowlists/pypi.allowlist b/data_safe_haven/resources/software_repositories/allowlists/pypi.allowlist
index 3ab3c07dfe..704937893f 100644
--- a/data_safe_haven/resources/software_repositories/allowlists/pypi.allowlist
+++ b/data_safe_haven/resources/software_repositories/allowlists/pypi.allowlist
@@ -15,6 +15,7 @@ pyodbc
 pyparsing
 python-dateutil
 pytz
+scikit-learn
 six
 typing-extensions
 tzdata
diff --git a/data_safe_haven/resources/workspace/run_all_tests.bats b/data_safe_haven/resources/workspace/run_all_tests.bats
new file mode 100644
index 0000000000..800a55cd3d
--- /dev/null
+++ b/data_safe_haven/resources/workspace/run_all_tests.bats
@@ -0,0 +1,128 @@
+#! /usr/bin/env bats
+
+
+# Helper functions
+# ----------------
+initialise_python_environment() {
+    ENV_PATH="${HOME}/.local/bats-python-environment"
+    rm -rf "$ENV_PATH"
+    python -m venv "$ENV_PATH"
+    source "${ENV_PATH}/bin/activate"
+    pip install --upgrade pip --quiet
+}
+
+initialise_r_environment() {
+    ENV_PATH="${HOME}/.local/bats-r-environment"
+    rm -rf "$ENV_PATH"
+    mkdir -p "$ENV_PATH"
+}
+
+install_r_package() {
+    PACKAGE_NAME="$1"
+    ENV_PATH="${HOME}/.local/bats-r-environment"
+    Rscript -e "install.packages('$PACKAGE_NAME', lib='$ENV_PATH');"
+}
+
+install_r_package_version() {
+    PACKAGE_NAME="$1"
+    PACKAGE_VERSION="$2"
+    ENV_PATH="${HOME}/.local/bats-r-environment"
+    Rscript -e "install.packages('remotes', lib='$ENV_PATH');"
+    Rscript -e "library('remotes', lib='$ENV_PATH'); remotes::install_version(package='$PACKAGE_NAME', version='$PACKAGE_VERSION', lib='$ENV_PATH');"
+}
+
+check_db_credentials() {
+    db_credentials="${HOME}/.local/db.dsh"
+    if [ -f "$db_credentials" ]; then
+        return 0
+    fi
+    return 1
+}
+
+
+# Mounted drives
+# --------------
+@test "Mounted drives (/data)" {
+    run bash test_mounted_drives.sh -d data
+    [ "$status" -eq 0 ]
+}
+@test "Mounted drives (/home)" {
+    run bash test_mounted_drives.sh -d home
+    [ "$status" -eq 0 ]
+}
+@test "Mounted drives (/output)" {
+    run bash test_mounted_drives.sh -d output
+    [ "$status" -eq 0 ]
+}
+@test "Mounted drives (/shared)" {
+    run bash test_mounted_drives.sh -d shared
+    [ "$status" -eq 0 ]
+}
+
+
+# Package repositories
+# --------------------
+@test "Python package repository" {
+    initialise_python_environment
+    run bash test_repository_python.sh 2>&1
+    [ "$status" -eq 0 ]
+}
+@test "R package repository" {
+    initialise_r_environment
+    run bash test_repository_R.sh
+    [ "$status" -eq 0 ]
+}
+
+
+# Language functionality
+# ----------------------
+@test "Python functionality" {
+    initialise_python_environment
+    pip install numpy pandas scikit-learn --quiet
+    run python test_functionality_python.py 2>&1
+    [ "$status" -eq 0 ]
+}
+@test "R functionality" {
+    initialise_r_environment
+    install_r_package_version "MASS" "7.3-52"
+    run Rscript test_functionality_R.R
+    [ "$status" -eq 0 ]
+}
+
+
+# Databases
+# ---------
+# Test MS SQL database
+@test "MS SQL database (Python)" {
+    check_db_credentials || skip "No database credentials available"
+    initialise_python_environment
+    pip install pandas psycopg pymssql --quiet
+    run bash test_databases.sh -d mssql -l python
+    [ "$status" -eq 0 ]
+}
+@test "MS SQL database (R)" {
+    check_db_credentials || skip "No database credentials available"
+    initialise_r_environment
+    install_r_package "DBI"
+    install_r_package "odbc"
+    install_r_package "RPostgres"
+    run bash test_databases.sh -d mssql -l R
+    [ "$status" -eq 0 ]
+}
+# Test Postgres database
+@test "Postgres database (Python)" {
+    check_db_credentials || skip "No database credentials available"
+    initialise_python_environment
+    pip install pandas psycopg pymssql --quiet
+    run bash test_databases.sh -d postgresql -l python
+    [ "$status" -eq 0 ]
+}
+@test "Postgres database (R)" {
+    check_db_credentials || skip "No database credentials available"
+    initialise_r_environment
+    install_r_package "DBI"
+    install_r_package "odbc"
+    install_r_package "RPostgres"
+    run bash test_databases.sh -d postgresql -l R
+    [ "$status" -eq 0 ]
+}
diff --git a/data_safe_haven/resources/workspace/test_databases.sh b/data_safe_haven/resources/workspace/test_databases.sh
new file mode 100644
index 0000000000..69fd7a456c
--- /dev/null
+++ b/data_safe_haven/resources/workspace/test_databases.sh
@@ -0,0 +1,51 @@
+#! /bin/bash
+db_type=""
+language=""
+while getopts d:l: flag; do
+    case "${flag}" in
+        d) db_type=${OPTARG} ;;
+        l) language=${OPTARG} ;;
+        *)
+            echo "Invalid option ${OPTARG}"
+            exit 1
+            ;;
+    esac
+done
+
+db_credentials="${HOME}/.local/db.dsh"
+if [ -f "$db_credentials" ]; then
+    username="databaseadmin"
+    password="$(cat "$db_credentials")"
+else
+    echo "Credentials file ($db_credentials) not found."
+    exit 1
+fi
+
+sre_fqdn="$(grep trusted /etc/pip.conf | cut -d "." -f 2-99)"
+sre_prefix="$(hostname | cut -d "-" -f 1-4)"
+if [ "$db_type" == "mssql" ]; then
+    db_name="master"
+    port="1433"
+    server_name="mssql.${sre_fqdn}"
+    hostname="${sre_prefix}-db-server-mssql"
+elif [ "$db_type" == "postgresql" ]; then
+    db_name="postgres"
+    port="5432"
+    server_name="postgresql.${sre_fqdn}"
+    hostname="${sre_prefix}-db-server-postgresql"
+else
+    echo "Did not recognise database type '$db_type'"
+    exit 1
+fi
+
+if [ "$port" == "" ]; then
+    echo "Database type '$db_type' is not part of this SRE"
+    exit 1
+else
+    script_path=$(dirname "$(readlink -f "$0")")
+    if [ "$language" == "python" ]; then
+        python "${script_path}"/test_databases_python.py --db-type "$db_type" --db-name "$db_name" --port "$port" --server-name "$server_name" --hostname "$hostname" --username "$username" --password "$password" || exit 1
+    elif [ "$language" == "R" ]; then
+        Rscript "${script_path}"/test_databases_R.R "$db_type" "$db_name" "$port" "$server_name" "$hostname" "$username" "$password" || exit 1
+    fi
+fi
diff --git a/data_safe_haven/resources/workspace/test_databases_R.R b/data_safe_haven/resources/workspace/test_databases_R.R
new file mode 100644
index 0000000000..a261f21532
--- /dev/null
+++ b/data_safe_haven/resources/workspace/test_databases_R.R
@@ -0,0 +1,50 @@
+#!/usr/bin/env Rscript
+library(DBI, lib.loc='~/.local/bats-r-environment')
+library(odbc, lib.loc='~/.local/bats-r-environment')
+library(RPostgres, lib.loc='~/.local/bats-r-environment')
+
+# Parse command line arguments
+args = commandArgs(trailingOnly=TRUE)
+if (length(args)!=7) {
+  stop("Exactly seven arguments are required: db_type, db_name, port, server_name, hostname, username and password")
+}
+db_type = args[1]
+db_name = args[2]
+port = args[3]
+server_name = args[4]
+hostname = args[5]
+username = args[6]
+password = args[7]
+
+# Connect to the database
+print(paste("Attempting to connect to '", db_name, "' on '", server_name, "' via port '", port, "'", sep=""))
+if (db_type == "mssql") {
+  cnxn <- DBI::dbConnect(
+    odbc::odbc(),
+    Driver = "ODBC Driver 17 for SQL Server",
+    Server = paste(server_name, port, sep=","),
+    Database = db_name,
+    UID = paste(username, "@", hostname, sep=""),
+    PWD = password
+  )
+} else if (db_type == "postgresql") {
+  cnxn <- DBI::dbConnect(
+    RPostgres::Postgres(),
+    host = server_name,
+    port = port,
+    dbname = db_name,
+    user = paste(username, "@", hostname, sep=""),
+    password = password
+  )
+} else {
+  stop(paste("Database type '", db_type, "' was not recognised", sep=""))
+}
+
+# Run a query and save the output into a dataframe
+df <- dbGetQuery(cnxn, "SELECT * FROM information_schema.tables;")
+if (dim(df)[1] > 0) {
+  print(head(df, 5))
+  print("All database tests passed")
+} else {
+  stop(paste("Reading from database '", db_name, "' failed", sep=""))
+}
diff --git a/data_safe_haven/resources/workspace/test_databases_python.py b/data_safe_haven/resources/workspace/test_databases_python.py
new file mode 100644
index 0000000000..ab0f01a3fe
--- /dev/null
+++ b/data_safe_haven/resources/workspace/test_databases_python.py
@@ -0,0 +1,66 @@
+#! /usr/bin/env python
+import argparse
+
+import pandas as pd
+import psycopg
+import pymssql
+
+
+def test_database(
+    server_name: str,
+    hostname: str,
+    port: int,
+    db_type: str,
+    db_name: str,
+    username: str,
+    password: str,
+) -> None:
+    msg = f"Attempting to connect to '{db_name}' on '{server_name}' via port {port}"
+    print(msg)  # noqa: T201
+    username_full = f"{username}@{hostname}"
+    cnxn = None
+    if db_type == "mssql":
+        cnxn = pymssql.connect(
+            server=server_name, user=username_full, password=password, database=db_name
+        )
+    elif db_type == "postgresql":
+        connection_string = f"host={server_name} port={port} dbname={db_name} user={username_full} password={password}"
+        cnxn = psycopg.connect(connection_string)
+    else:
+        msg = f"Database type '{db_type}' was not recognised"
+        raise ValueError(msg)
+    df = pd.read_sql("SELECT * FROM information_schema.tables;", cnxn)
+    if df.size:
+        print(df.head(5))  # noqa: T201
+        print("All database tests passed")  # noqa: T201
+    else:
+        msg = f"Reading from database '{db_name}' failed."
+        raise ValueError(msg)
+
+
+# Parse command line arguments
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--db-type",
+    type=str,
+    choices=["mssql", "postgresql"],
+    help="Which database type to use",
+)
+parser.add_argument("--db-name", type=str, help="Which database to connect to")
+parser.add_argument("--port", type=int, help="Which port to connect to")
+parser.add_argument("--server-name", type=str, help="Which server to connect to")
+parser.add_argument("--username", type=str, help="Database username")
+parser.add_argument("--hostname", type=str, help="Azure hostname of the server")
+parser.add_argument("--password", type=str, help="Database user password")
+args = parser.parse_args()
+
+# Run database test
+test_database(
+    args.server_name,
+    args.hostname,
+    args.port,
+    args.db_type,
+    args.db_name,
+    args.username,
+    args.password,
+)
diff --git a/data_safe_haven/resources/workspace/test_functionality_R.R b/data_safe_haven/resources/workspace/test_functionality_R.R
new file mode 100644
index 0000000000..94c351e7c3
--- /dev/null
+++ b/data_safe_haven/resources/workspace/test_functionality_R.R
@@ -0,0 +1,39 @@
+# Test logistic regression using R
+library('MASS', lib.loc='~/.local/bats-r-environment')
+library('stats')
+
+gen_data <- function(n = 100, p = 3) {
+  set.seed(1)
+  weights <- stats::rgamma(n = n, shape = rep(1, n), rate = rep(1, n))
+  y <- stats::rbinom(n = n, size = 1, prob = 0.5)
+  theta <- stats::rnorm(n = p, mean = 0, sd = 1)
+  means <- colMeans(as.matrix(y) %*% theta)
+  x <- MASS::mvrnorm(n = n, means, diag(1, p, p))
+  return(list(x = x, y = y, weights = weights, theta = theta))
+}
+
+run_logistic_regression <- function(data) {
+  fit <- stats::glm.fit(x = data$x,
+                        y = data$y,
+                        weights = data$weights,
+                        family = stats::quasibinomial(link = "logit"))
+  return(fit$coefficients)
+}
+
+data <- gen_data()
+theta <- run_logistic_regression(data)
+print("Logistic regression ran OK")
+
+
+# Test clustering of random data using R
+num_clusters <- 5
+N <- 10
+set.seed(0, kind = "Mersenne-Twister")
+cluster_means <- runif(num_clusters, 0, 10)
+means_selector <- as.integer(runif(N, 1, num_clusters + 1))
+data_means <- cluster_means[means_selector]
+data <- rnorm(n = N, mean = data_means, sd = 0.5)
+hc <- hclust(dist(data))
+print("Clustering ran OK")
+
+print("All functionality tests passed")
diff --git a/data_safe_haven/resources/workspace/test_functionality_python.py b/data_safe_haven/resources/workspace/test_functionality_python.py
new file mode 100644
index 0000000000..855e5e5f15
--- /dev/null
+++ b/data_safe_haven/resources/workspace/test_functionality_python.py
@@ -0,0 +1,37 @@
+"""Test logistic regression using python"""
+import numpy as np
+import pandas as pd
+from sklearn.linear_model import LogisticRegression
+
+
+def gen_data(n_samples: int, n_points: int) -> pd.DataFrame:
+    """Generate data for fitting"""
+    target = np.random.binomial(n=1, p=0.5, size=(n_samples, 1))
+    theta = np.random.normal(loc=0.0, scale=1.0, size=(1, n_points))
+    means = np.mean(np.multiply(target, theta), axis=0)
+    values = np.random.multivariate_normal(
+        means, np.diag([1] * n_points), size=n_samples
+    ).T
+    data = {f"x{n}": values[n] for n in range(n_points)}
+    data["y"] = target.reshape((n_samples,))
+    data["weights"] = np.random.gamma(shape=1, scale=1.0, size=n_samples)
+    return pd.DataFrame(data=data)
+
+
+def main() -> None:
+    """Logistic regression"""
+    data = gen_data(100, 3)
+    input_data = data.iloc[:, :-2]
+    output_data = data["y"]
+    weights = data["weights"]
+
+    logit = LogisticRegression(solver="liblinear")
+    logit.fit(input_data, output_data, sample_weight=weights)
+    logit.score(input_data, output_data, sample_weight=weights)
+
+    print("Logistic model ran OK")  # noqa: T201
+    print("All functionality tests passed")  # noqa: T201
+
+
+if __name__ == "__main__":
+    main()
diff --git a/data_safe_haven/resources/workspace/test_mounted_drives.sh b/data_safe_haven/resources/workspace/test_mounted_drives.sh
new file mode 100644
index 0000000000..a1812934b9
--- /dev/null
+++ b/data_safe_haven/resources/workspace/test_mounted_drives.sh
@@ -0,0 +1,66 @@
+#! /bin/bash
+while getopts d: flag
+do
+    case "${flag}" in
+        d) directory=${OPTARG};;
+        *)
+            echo "Usage: $0 -d [directory]"
+            exit 1
+    esac
+done
+
+nfailed=0
+if [[ "$directory" = "home" ]]; then directory_path="${HOME}"; else directory_path="/${directory}"; fi
+testfile="$(tr -dc 'a-zA-Z0-9' < /dev/urandom | fold -w 32 | head -n 1)"
+
+# Check that directory exists
+if [ "$(ls "${directory_path}" 2>&1 1>/dev/null)" ]; then
+    echo "Could not find mount '${directory_path}'"
+    nfailed=$((nfailed + 1))
+fi
+
+# Test operations
+CAN_CREATE="$([[ "$(touch "${directory_path}/${testfile}" 2>&1 1>/dev/null)" = "" ]] && echo '1' || echo '0')"
+CAN_WRITE="$([[ -w "${directory_path}/${testfile}" ]] && echo '1' || echo '0')"
+CAN_DELETE="$([[ "$(touch "${directory_path}/${testfile}" 2>&1 1>/dev/null && rm "${directory_path}/${testfile}" 2>&1)" ]] && echo '0' || echo '1')"
+
+# Check that permissions are as expected for each directory
+case "$directory" in
+    data)
+        if [ "$CAN_CREATE" = 1 ]; then echo "Able to create files in ${directory_path}!"; nfailed=$((nfailed + 1)); fi
+        if [ "$CAN_WRITE" = 1 ]; then echo "Able to write files in ${directory_path}!"; nfailed=$((nfailed + 1)); fi
+        if [ "$CAN_DELETE" = 1 ]; then echo "Able to delete files in ${directory_path}!"; nfailed=$((nfailed + 1)); fi
+        ;;
+
+    home)
+        if [ "$CAN_CREATE" = 0 ]; then echo "Unable to create files in ${directory_path}!"; nfailed=$((nfailed + 1)); fi
+        if [ "$CAN_WRITE" = 0 ]; then echo "Unable to write files in ${directory_path}!"; nfailed=$((nfailed + 1)); fi
+        if [ "$CAN_DELETE" = 0 ]; then echo "Unable to delete files in ${directory_path}!"; nfailed=$((nfailed + 1)); fi
+        ;;
+
+    output)
+        if [ "$CAN_CREATE" = 0 ]; then echo "Unable to create files in ${directory_path}!"; nfailed=$((nfailed + 1)); fi
+        if [ "$CAN_WRITE" = 0 ]; then echo "Unable to write files in ${directory_path}!"; nfailed=$((nfailed + 1)); fi
+        if [ "$CAN_DELETE" = 0 ]; then echo "Unable to delete files in ${directory_path}!"; nfailed=$((nfailed + 1)); fi
+        ;;
+
+    shared)
+        if [ "$CAN_CREATE" = 0 ]; then echo "Unable to create files in ${directory_path}!"; nfailed=$((nfailed + 1)); fi
+        if [ "$CAN_WRITE" = 0 ]; then echo "Unable to write files in ${directory_path}!"; nfailed=$((nfailed + 1)); fi
+        if [ "$CAN_DELETE" = 0 ]; then echo "Unable to delete files in ${directory_path}!"; nfailed=$((nfailed + 1)); fi
+        ;;
+
+    *)
+        echo "Usage: $0 -d [directory]"
+        exit 1
+esac
+
+# Cleanup and print output
+rm -f "${directory_path}/${testfile}" 2> /dev/null
+if [ $nfailed = 0 ]; then
+    echo "All tests passed for '${directory_path}'"
+    exit 0
+else
+    echo "$nfailed tests failed for '${directory_path}'!"
+    exit $nfailed
+fi
diff --git a/data_safe_haven/resources/workspace/test_repository_R.mustache.sh b/data_safe_haven/resources/workspace/test_repository_R.mustache.sh
new file mode 100644
index 0000000000..03568b1e62
--- /dev/null
+++ b/data_safe_haven/resources/workspace/test_repository_R.mustache.sh
@@ -0,0 +1,46 @@
+#! /bin/bash
+# We need to test packages that are:
+# - *not* pre-installed
+# - on the tier-3 list (so we can test all tiers)
+# - alphabetically early and late (so we can test the progress of the mirror synchronisation)
+packages=("askpass" "zeallot")
+uninstallable_packages=("aws.s3")
+
+# Use the shared bats R environment as the installation target
+TEST_INSTALL_PATH="${HOME}/.local/bats-r-environment"
+
+# Install sample packages to local user library
+N_FAILURES=0
+for package in "${packages[@]}"; do
+    echo "Attempting to install ${package}..."
+    Rscript -e "options(warn=-1); install.packages('${package}', lib='${TEST_INSTALL_PATH}', quiet=TRUE)"
+    if (Rscript -e "library('${package}', lib.loc='${TEST_INSTALL_PATH}')"); then
+        echo "... $package installation succeeded"
+    else
+        echo "... $package installation failed"
+        N_FAILURES=$((N_FAILURES + 1))
+    fi
+done
+# If requested, demonstrate that installation fails for packages *not* on the approved list
+TEST_FAILURE="{{check_uninstallable_packages}}"
+if [ "$TEST_FAILURE" -eq 1 ]; then
+    for package in "${uninstallable_packages[@]}"; do
+        echo "Attempting to install ${package}..."
+        Rscript -e "options(warn=-1); install.packages('${package}', lib='${TEST_INSTALL_PATH}', quiet=TRUE)"
+        if (Rscript -e "library('${package}', lib.loc='${TEST_INSTALL_PATH}')"); then
+            echo "... $package installation unexpectedly succeeded!"
+            N_FAILURES=$((N_FAILURES + 1))
+        else
+            echo "... $package installation failed as expected"
+        fi
+    done
+fi
+rm -rf "$TEST_INSTALL_PATH"
+
+if [ $N_FAILURES -eq 0 ]; then
+    echo "All package installations behaved as expected"
+    exit 0
+else
+    echo "One or more package installations did not behave as expected!"
+    exit $N_FAILURES
+fi
diff --git a/data_safe_haven/resources/workspace/test_repository_python.mustache.sh b/data_safe_haven/resources/workspace/test_repository_python.mustache.sh
new file mode 100644
index 0000000000..28e46a23e1
--- /dev/null
+++ b/data_safe_haven/resources/workspace/test_repository_python.mustache.sh
@@ -0,0 +1,41 @@
+#! /bin/bash
+
+# We need to test packages that are:
+# - *not* pre-installed
+# - on the allowlist (so we can test this is working)
+# - alphabetically early and late (so we can test the progress of the mirror synchronisation)
+installable_packages=("contourpy" "tzdata")
+uninstallable_packages=("awscli")
+
+# Install sample packages to local user library
+N_FAILURES=0
+for package in "${installable_packages[@]}"; do
+    echo "Attempting to install ${package}..."
+    if (pip install "$package" --quiet); then
+        echo "... $package installation succeeded"
+    else
+        echo "... $package installation failed"
+        N_FAILURES=$((N_FAILURES + 1))
+    fi
+done
+# If requested, demonstrate that installation fails for packages *not* on the approved list
+TEST_FAILURE="{{check_uninstallable_packages}}"
+if [ "$TEST_FAILURE" -eq 1 ]; then
+    for package in "${uninstallable_packages[@]}"; do
+        echo "Attempting to install ${package}..."
+        if (pip install "$package" --quiet); then
+            echo "... $package installation unexpectedly succeeded!"
+            N_FAILURES=$((N_FAILURES + 1))
+        else
+            echo "... $package installation failed as expected"
+        fi
+    done
+fi
+
+if [ $N_FAILURES -eq 0 ]; then
+    echo "All package installations behaved as expected"
+    exit 0
+else
+    echo "One or more package installations did not behave as expected!"
+    exit $N_FAILURES
+fi
diff --git a/data_safe_haven/resources/workspace/workspace.cloud_init.mustache.yaml b/data_safe_haven/resources/workspace/workspace.cloud_init.mustache.yaml
index 17471221de..c4216adb76 100644
--- a/data_safe_haven/resources/workspace/workspace.cloud_init.mustache.yaml
+++ b/data_safe_haven/resources/workspace/workspace.cloud_init.mustache.yaml
@@ -104,6 +104,8 @@ packages:
   - libpq-dev # interact with PostgreSQL databases
   - msodbcsql17 # interact with Microsoft SQL databases
   - unixodbc-dev # interact with Microsoft SQL databases
+  # Bash testing
+  - bats
 
 package_update: true
 package_upgrade: true
diff --git a/pyproject.toml b/pyproject.toml
index e033e8459b..45d64e4052 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -152,11 +152,15 @@ module = [
     "cryptography.*",
     "dns.*",
     "msal.*",
+    "numpy.*",
+    "pandas.*",
     "psycopg.*",
     "pulumi.*",
    "pulumi_azure_native.*",
+    "pymssql.*",
     "rich.*",
     "simple_acme_dns.*",
+    "sklearn.*",
     "typer.*",
     "websocket.*",
 ]
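
Note on `run_remote_script_waiting`: as added above, it retries every 5 seconds for as long as Azure keeps reporting a conflicting run-command execution, so a VM whose run-command slot never frees up would block indefinitely. A bounded variant is sketched below; this is an illustration built from the patch's own retry conditions, and `max_wait_seconds` is a hypothetical parameter that is not part of this changeset.

    import time

    from data_safe_haven.exceptions import DataSafeHavenAzureError

    # The two transient failure modes that the patch treats as retryable
    RETRYABLE_REASONS = (
        "The request failed due to conflict with a concurrent request",
        "Run command extension execution is in progress",
    )

    def run_remote_script_bounded(
        azure_api, resource_group_name, script, script_parameters, vm_name, max_wait_seconds=600
    ):
        """Sketch: retry-on-conflict like run_remote_script_waiting, but give up after a deadline."""
        deadline = time.monotonic() + max_wait_seconds
        while True:
            try:
                return azure_api.run_remote_script(
                    resource_group_name=resource_group_name,
                    script=script,
                    script_parameters=script_parameters,
                    vm_name=vm_name,
                )
            except DataSafeHavenAzureError as exc:
                if all(reason not in str(exc) for reason in RETRYABLE_REASONS):
                    raise  # not a concurrency conflict: fail immediately
                if time.monotonic() > deadline:
                    raise  # still conflicting after max_wait_seconds: give up
                time.sleep(5)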
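For reference, a minimal Pulumi program using the new `FileUpload` component might look like the sketch below. The subscription, VM and resource-group values are placeholders, the hash is computed the same way as `sha256hash` above, and in this patch uploads are actually driven from `workspaces.py` rather than from a standalone program.

    import hashlib

    from data_safe_haven.infrastructure.components import FileUpload, FileUploadProps

    file_contents = "echo 'hello from the workspace'\n"

    upload = FileUpload(
        "example_file_upload",
        FileUploadProps(
            file_contents=file_contents,
            file_hash=hashlib.sha256(file_contents.encode("utf-8")).hexdigest(),
            file_permissions="0444",
            file_target="/opt/tests/example.sh",
            subscription_name="My Subscription",  # placeholder
            vm_name="example-vm-workspace-01",  # placeholder
            vm_resource_group_name="rg-example-workspaces",  # placeholder
        ),
    )
    # upload.script_output is an Output[str] carrying the remote script's stdout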