From 63c52611a9d91639e3232a600d8fb7cf3676e584 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Wed, 20 Sep 2023 19:32:59 +0100 Subject: [PATCH] :rotating_light: Fix linting errors in smoke tests --- .../components/dynamic/file_upload.py | 4 +++- .../infrastructure/stacks/sre/workspaces.py | 10 +++++--- .../workspace/test_databases_python.py | 23 ++++++++++++++----- .../workspace/test_functionality_python.py | 10 ++++---- pyproject.toml | 4 ++++ 5 files changed, 36 insertions(+), 15 deletions(-) diff --git a/data_safe_haven/infrastructure/components/dynamic/file_upload.py b/data_safe_haven/infrastructure/components/dynamic/file_upload.py index 731a662899..4f1f259c47 100644 --- a/data_safe_haven/infrastructure/components/dynamic/file_upload.py +++ b/data_safe_haven/infrastructure/components/dynamic/file_upload.py @@ -29,7 +29,9 @@ def __init__( self.file_hash = file_hash self.file_target = file_target self.file_permissions = file_permissions - self.force_refresh = Output.from_input(force_refresh).apply(lambda force: force if force else False) + self.force_refresh = Output.from_input(force_refresh).apply( + lambda force: force if force else False + ) self.subscription_name = subscription_name self.vm_name = vm_name self.vm_resource_group_name = vm_resource_group_name diff --git a/data_safe_haven/infrastructure/stacks/sre/workspaces.py b/data_safe_haven/infrastructure/stacks/sre/workspaces.py index fa854168f1..fdf3d46c4d 100644 --- a/data_safe_haven/infrastructure/stacks/sre/workspaces.py +++ b/data_safe_haven/infrastructure/stacks/sre/workspaces.py @@ -177,10 +177,12 @@ def __init__( ] # Upload smoke tests - mustache_values={ + mustache_values = { "check_uninstallable_packages": "0", } - file_uploads = [(FileReader(resources_path / "workspace" / "run_all_tests.bats"), "0444")] + file_uploads = [ + (FileReader(resources_path / "workspace" / "run_all_tests.bats"), "0444") + ] for test_file in pathlib.Path(resources_path / "workspace").glob("test*"): file_uploads.append((FileReader(test_file), "0444")) for vm, vm_output in zip(vms, vm_outputs, strict=True): @@ -189,7 +191,9 @@ def __init__( file_smoke_test = FileUpload( replace_separators(f"{self._name}_file_{file_upload.name}", "_"), FileUploadProps( - file_contents=file_upload.file_contents(mustache_values=mustache_values), + file_contents=file_upload.file_contents( + mustache_values=mustache_values + ), file_hash=file_upload.sha256(), file_permissions=file_permissions, file_target=f"/opt/tests/{file_upload.name}", diff --git a/data_safe_haven/resources/workspace/test_databases_python.py b/data_safe_haven/resources/workspace/test_databases_python.py index 37a37acb91..ab0f01a3fe 100644 --- a/data_safe_haven/resources/workspace/test_databases_python.py +++ b/data_safe_haven/resources/workspace/test_databases_python.py @@ -6,8 +6,17 @@ import pymssql -def test_database(server_name, hostname, port, db_type, db_name, username, password): - print(f"Attempting to connect to '{db_name}' on '{server_name}' via port {port}") +def test_database( + server_name: str, + hostname: str, + port: int, + db_type: str, + db_name: str, + username: str, + password: str, +) -> None: + msg = f"Attempting to connect to '{db_name}' on '{server_name}' via port {port}" + print(msg) # noqa: T201 username_full = f"{username}@{hostname}" cnxn = None if db_type == "mssql": @@ -18,13 +27,15 @@ def test_database(server_name, hostname, port, db_type, db_name, username, passw connection_string = f"host={server_name} port={port} dbname={db_name} user={username_full} password={password}" cnxn = psycopg.connect(connection_string) else: - raise ValueError(f"Database type '{db_type}' was not recognised") + msg = f"Database type '{db_type}' was not recognised" + raise ValueError(msg) df = pd.read_sql("SELECT * FROM information_schema.tables;", cnxn) if df.size: - print(df.head(5)) - print("All database tests passed") + print(df.head(5)) # noqa: T201 + print("All database tests passed") # noqa: T201 else: - raise ValueError(f"Reading from database '{db_name}' failed.") + msg = f"Reading from database '{db_name}' failed." + raise ValueError(msg) # Parse command line arguments diff --git a/data_safe_haven/resources/workspace/test_functionality_python.py b/data_safe_haven/resources/workspace/test_functionality_python.py index 9ca9662d98..855e5e5f15 100644 --- a/data_safe_haven/resources/workspace/test_functionality_python.py +++ b/data_safe_haven/resources/workspace/test_functionality_python.py @@ -4,7 +4,7 @@ from sklearn.linear_model import LogisticRegression -def gen_data(n_samples, n_points): +def gen_data(n_samples: int, n_points: int) -> pd.DataFrame: """Generate data for fitting""" target = np.random.binomial(n=1, p=0.5, size=(n_samples, 1)) theta = np.random.normal(loc=0.0, scale=1.0, size=(1, n_points)) @@ -12,13 +12,13 @@ def gen_data(n_samples, n_points): values = np.random.multivariate_normal( means, np.diag([1] * n_points), size=n_samples ).T - data = dict(("x{}".format(n), values[n]) for n in range(n_points)) + data = {f"x{n}": values[n] for n in range(n_points)} data["y"] = target.reshape((n_samples,)) data["weights"] = np.random.gamma(shape=1, scale=1.0, size=n_samples) return pd.DataFrame(data=data) -def main(): +def main() -> None: """Logistic regression""" data = gen_data(100, 3) input_data = data.iloc[:, :-2] @@ -29,8 +29,8 @@ def main(): logit.fit(input_data, output_data, sample_weight=weights) logit.score(input_data, output_data, sample_weight=weights) - print("Logistic model ran OK") - print("All functionality tests passed") + print("Logistic model ran OK") # noqa: T201 + print("All functionality tests passed") # noqa: T201 if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index e033e8459b..45d64e4052 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -152,11 +152,15 @@ module = [ "cryptography.*", "dns.*", "msal.*", + "numpy.*", + "pandas.*", "psycopg.*", "pulumi.*", "pulumi_azure_native.*", + "pymssql.*", "rich.*", "simple_acme_dns.*", + "sklearn.*", "typer.*", "websocket.*", ]