From ba05dd7e48be9bb6e1cc439bae7cfad681f8ce7d Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Wed, 29 Nov 2023 15:23:51 +0000 Subject: [PATCH] Correct typer validator factory --- data_safe_haven/functions/typer_validators.py | 1 + tests_/functions/test_typer_validators.py | 28 +++++++++++++++++++ tests_/functions/test_validators.py | 27 ++++++++++++++++++ 3 files changed, 56 insertions(+) create mode 100644 tests_/functions/test_typer_validators.py create mode 100644 tests_/functions/test_validators.py diff --git a/data_safe_haven/functions/typer_validators.py b/data_safe_haven/functions/typer_validators.py index 30d6227249..df72cd149c 100644 --- a/data_safe_haven/functions/typer_validators.py +++ b/data_safe_haven/functions/typer_validators.py @@ -16,6 +16,7 @@ def typer_validator_factory(validator: Callable[[Any], Any]) -> Callable[[Any], def typer_validator(x: Any) -> Any: try: validator(x) + return x except ValueError as exc: raise BadParameter(str(exc)) from exc diff --git a/tests_/functions/test_typer_validators.py b/tests_/functions/test_typer_validators.py new file mode 100644 index 0000000000..0c9c02ff90 --- /dev/null +++ b/tests_/functions/test_typer_validators.py @@ -0,0 +1,28 @@ +import pytest +from typer import BadParameter + +from data_safe_haven.functions.typer_validators import typer_validate_aad_guid + + +class TestTyperValidateAadGuid: + @pytest.mark.parametrize( + "guid", + [ + "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + "10de18e7-b238-6f1e-a4ad-772708929203", + ] + ) + def test_typer_validate_aad_guid(self, guid): + assert typer_validate_aad_guid(guid) == guid + + @pytest.mark.parametrize( + "guid", + [ + "10de18e7_b238_6f1e_a4ad_772708929203", + "not a guid", + ] + ) + def test_typer_validate_aad_guid_fail(self, guid): + with pytest.raises(BadParameter) as exc: + typer_validate_aad_guid(guid) + assert "Expected GUID" in exc diff --git a/tests_/functions/test_validators.py b/tests_/functions/test_validators.py new file mode 100644 index 0000000000..bcf41ca263 --- /dev/null +++ b/tests_/functions/test_validators.py @@ -0,0 +1,27 @@ +import pytest + +from data_safe_haven.functions.validators import validate_aad_guid + + +class TestValidateAadGuid: + @pytest.mark.parametrize( + "guid", + [ + "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + "10de18e7-b238-6f1e-a4ad-772708929203", + ] + ) + def test_validate_aad_guid(self, guid): + assert validate_aad_guid(guid) == guid + + @pytest.mark.parametrize( + "guid", + [ + "10de18e7_b238_6f1e_a4ad_772708929203", + "not a guid", + ] + ) + def test_validate_aad_guid_fail(self, guid): + with pytest.raises(ValueError) as exc: + validate_aad_guid(guid) + assert "Expected GUID" in exc