diff --git a/tests/test_param_validation.py b/tests/test_param_validation.py index 3ed5827..edf59ed 100644 --- a/tests/test_param_validation.py +++ b/tests/test_param_validation.py @@ -12,19 +12,40 @@ def test_minmax(case): param.validate(5) -@pytest.mark.parametrize("case", ("5.3", "hello")) -def test_integer(case): +@pytest.mark.parametrize( + "case, message", + [ + ("5.3", "5.3 is not an integer"), + ("hello", "hello is not an integer"), + (True, "True is not an integer"), + (False, "False is not an integer"), + (None, "No value supplied"), + ("", "No value supplied"), + ], +) +def test_integer(case, message): param = Parameter(name="test", type="integer") - with pytest.raises(ValidationErrors): + with pytest.raises(ValidationErrors) as exc_info: param.validate(case) + assert list(exc_info.value) == [message] param.validate(5) -@pytest.mark.parametrize("case", ("hello",)) -def test_float(case): +@pytest.mark.parametrize( + "case, message", + [ + ("hello", "hello is not a floating-point number"), + (True, "True is not a floating-point number"), + (False, "False is not a floating-point number"), + (None, "No value supplied"), + ("", "No value supplied"), + ], +) +def test_float(case, message): param = Parameter(name="test", type="float") - with pytest.raises(ValidationErrors): + with pytest.raises(ValidationErrors) as exc_info: param.validate(case) + assert list(exc_info.value) == [message] param.validate(1.5) diff --git a/valohai_yaml/objs/parameter.py b/valohai_yaml/objs/parameter.py index 53a9457..047a50a 100644 --- a/valohai_yaml/objs/parameter.py +++ b/valohai_yaml/objs/parameter.py @@ -113,12 +113,18 @@ def _validate_type(self, value: ValueAtomType, errors: List[str]) -> ValueAtomTy try: value = int(str(value), 10) except ValueError: - errors.append(f"{value} is not an integer") + if value == "": + errors.append("No value supplied") + else: + errors.append(f"{value} is not an integer") elif self.type == "float": try: value = float(str(value)) except ValueError: - errors.append(f"{value} is not a floating-point number") + if value == "": + errors.append("No value supplied") + else: + errors.append(f"{value} is not a floating-point number") return value def validate(self, value: ValueType) -> ValueType: @@ -135,6 +141,9 @@ def validate(self, value: ValueType) -> ValueType: if not self.multiple and isinstance(value, (list, tuple)): errors.append("Only a single value is allowed") + if value is None: + errors.append("No value supplied") + for atom in listify(value): if isinstance(atom, list): # type guard raise InvalidType(f"nested list atom {atom!r} not allowed")