From 906177329bcc54f6946af361fcd3d0e334e6ce5f Mon Sep 17 00:00:00 2001 From: Madeesh Kannan Date: Tue, 22 Oct 2024 17:08:36 +0200 Subject: [PATCH] fix: Enforce basic Python types restriction on serialized component data (#8473) --- haystack/core/pipeline/base.py | 2 +- haystack/core/serialization.py | 97 ++++++++++++++----- haystack/marshal/yaml.py | 23 ++++- ...e-python-types-serde-0c5cd0a716bab7d0.yaml | 8 ++ test/core/test_serialization.py | 25 ++++- test/marshal/test_yaml.py | 18 ++++ 6 files changed, 143 insertions(+), 30 deletions(-) create mode 100644 releasenotes/notes/enforce-python-types-serde-0c5cd0a716bab7d0.yaml diff --git a/haystack/core/pipeline/base.py b/haystack/core/pipeline/base.py index c31f02ba64..3d1b1c0bba 100644 --- a/haystack/core/pipeline/base.py +++ b/haystack/core/pipeline/base.py @@ -111,7 +111,7 @@ def to_dict(self) -> Dict[str, Any]: """ components = {} for name, instance in self.graph.nodes(data="instance"): # type:ignore - components[name] = component_to_dict(instance) + components[name] = component_to_dict(instance, name) connections = [] for sender, receiver, edge_data in self.graph.edges.data(): diff --git a/haystack/core/serialization.py b/haystack/core/serialization.py index 15cff1e2c9..3477f0c806 100644 --- a/haystack/core/serialization.py +++ b/haystack/core/serialization.py @@ -6,7 +6,7 @@ from collections.abc import Callable from dataclasses import dataclass from importlib import import_module -from typing import Any, Dict, Optional, Type +from typing import Any, Dict, Iterable, Optional, Type from haystack.core.component.component import _hook_component_init, logger from haystack.core.errors import DeserializationError, SerializationError @@ -30,7 +30,7 @@ class DeserializationCallbacks: component_pre_init: Optional[Callable] = None -def component_to_dict(obj: Any) -> Dict[str, Any]: +def component_to_dict(obj: Any, name: str) -> Dict[str, Any]: """ Converts a component instance into a dictionary. @@ -38,37 +38,82 @@ def component_to_dict(obj: Any) -> Dict[str, Any]: :param obj: The component to be serialized. + :param name: + The name of the component. :returns: A dictionary representation of the component. :raises SerializationError: - If the component doesn't have a `to_dict` method and the values of the init parameters can't be determined. + If the component doesn't have a `to_dict` method. + If the values of the init parameters can't be determined. + If a non-basic Python type is used in the serialized data. """ if hasattr(obj, "to_dict"): - return obj.to_dict() - - init_parameters = {} - for name, param in inspect.signature(obj.__init__).parameters.items(): - # Ignore `args` and `kwargs`, used by the default constructor - if name in ("args", "kwargs"): - continue - try: - # This only works if the Component constructor assigns the init - # parameter to an instance variable or property with the same name - param_value = getattr(obj, name) - except AttributeError as e: - # If the parameter doesn't have a default value, raise an error - if param.default is param.empty: + data = obj.to_dict() + else: + init_parameters = {} + for param_name, param in inspect.signature(obj.__init__).parameters.items(): + # Ignore `args` and `kwargs`, used by the default constructor + if param_name in ("args", "kwargs"): + continue + try: + # This only works if the Component constructor assigns the init + # parameter to an instance variable or property with the same name + param_value = getattr(obj, param_name) + except AttributeError as e: + # If the parameter doesn't have a default value, raise an error + if param.default is param.empty: + raise SerializationError( + f"Cannot determine the value of the init parameter '{param_name}' " + f"for the class {obj.__class__.__name__}." + f"You can fix this error by assigning 'self.{param_name} = {param_name}' or adding a " + f"custom serialization method 'to_dict' to the class." + ) from e + # In case the init parameter was not assigned, we use the default value + param_value = param.default + init_parameters[param_name] = param_value + + data = default_to_dict(obj, **init_parameters) + + _validate_component_to_dict_output(obj, name, data) + return data + + +def _validate_component_to_dict_output(component: Any, name: str, data: Dict[str, Any]) -> None: + # Ensure that only basic Python types are used in the serde data. + def is_allowed_type(obj: Any) -> bool: + return isinstance(obj, (str, int, float, bool, list, dict, set, tuple, type(None))) + + def check_iterable(l: Iterable[Any]): + for v in l: + if not is_allowed_type(v): + raise SerializationError( + f"Component '{name}' of type '{type(component).__name__}' has an unsupported value " + f"of type '{type(v).__name__}' in the serialized data." + ) + if isinstance(v, (list, set, tuple)): + check_iterable(v) + elif isinstance(v, dict): + check_dict(v) + + def check_dict(d: Dict[str, Any]): + if any(not isinstance(k, str) for k in data.keys()): + raise SerializationError( + f"Component '{name}' of type '{type(component).__name__}' has a non-string key in the serialized data." + ) + + for k, v in d.items(): + if not is_allowed_type(v): raise SerializationError( - f"Cannot determine the value of the init parameter '{name}' for the class {obj.__class__.__name__}." - f"You can fix this error by assigning 'self.{name} = {name}' or adding a " - f"custom serialization method 'to_dict' to the class." - ) from e - # In case the init parameter was not assigned, we use the default value - param_value = param.default - init_parameters[name] = param_value - - return default_to_dict(obj, **init_parameters) + f"Component '{name}' of type '{type(component).__name__}' has an unsupported value " + f"of type '{type(v).__name__}' in the serialized data under key '{k}'." + ) + if isinstance(v, (list, set, tuple)): + check_iterable(v) + elif isinstance(v, dict): + check_dict(v) + + check_dict(data) def generate_qualified_class_name(cls: Type[object]) -> str: diff --git a/haystack/marshal/yaml.py b/haystack/marshal/yaml.py index c25498a3c5..615cd1916a 100644 --- a/haystack/marshal/yaml.py +++ b/haystack/marshal/yaml.py @@ -14,14 +14,33 @@ def construct_python_tuple(self, node: yaml.SequenceNode): return tuple(self.construct_sequence(node)) +class YamlDumper(yaml.SafeDumper): # pylint: disable=too-many-ancestors + def represent_tuple(self, data: tuple): + """Represent a Python tuple.""" + return self.represent_sequence("tag:yaml.org,2002:python/tuple", data) + + +YamlDumper.add_representer(tuple, YamlDumper.represent_tuple) YamlLoader.add_constructor("tag:yaml.org,2002:python/tuple", YamlLoader.construct_python_tuple) class YamlMarshaller: def marshal(self, dict_: Dict[str, Any]) -> str: """Return a YAML representation of the given dictionary.""" - return yaml.dump(dict_) + try: + return yaml.dump(dict_, Dumper=YamlDumper) + except yaml.representer.RepresenterError as e: + raise TypeError( + "Error dumping pipeline to YAML - Ensure that all pipeline " + "components only serialize basic Python types" + ) from e def unmarshal(self, data_: Union[str, bytes, bytearray]) -> Dict[str, Any]: """Return a dictionary from the given YAML data.""" - return yaml.load(data_, Loader=YamlLoader) + try: + return yaml.load(data_, Loader=YamlLoader) + except yaml.constructor.ConstructorError as e: + raise TypeError( + "Error loading pipeline from YAML - Ensure that all pipeline " + "components only serialize basic Python types" + ) from e diff --git a/releasenotes/notes/enforce-python-types-serde-0c5cd0a716bab7d0.yaml b/releasenotes/notes/enforce-python-types-serde-0c5cd0a716bab7d0.yaml new file mode 100644 index 0000000000..79e3db4074 --- /dev/null +++ b/releasenotes/notes/enforce-python-types-serde-0c5cd0a716bab7d0.yaml @@ -0,0 +1,8 @@ +--- +features: + - | + Improved serialization/deserialization errors to provide extra context about the delinquent components when possible. + +fixes: + - | + Serialized data of components are now explicitly enforced to be one of the following basic Python datatypes: str, int, float, bool, list, dict, set, tuple or None. diff --git a/test/core/test_serialization.py b/test/core/test_serialization.py index d6a3980bf3..fd83122c33 100644 --- a/test/core/test_serialization.py +++ b/test/core/test_serialization.py @@ -10,13 +10,14 @@ from haystack.core.pipeline import Pipeline from haystack.core.component import component -from haystack.core.errors import DeserializationError +from haystack.core.errors import DeserializationError, SerializationError from haystack.testing import factory from haystack.core.serialization import ( default_to_dict, default_from_dict, generate_qualified_class_name, import_class_by_name, + component_to_dict, ) @@ -106,3 +107,25 @@ def test_import_class_by_name_no_valid_class(): data = "some.invalid.class" with pytest.raises(ImportError): import_class_by_name(data) + + +class CustomData: + def __init__(self, a: int, b: str) -> None: + self.a = a + self.b = b + + +@component() +class UnserializableClass: + def __init__(self, a: int, b: str, c: CustomData) -> None: + self.a = a + self.b = b + self.c = c + + def run(self): + pass + + +def test_component_to_dict_invalid_type(): + with pytest.raises(SerializationError, match="unsupported value of type 'CustomData'"): + component_to_dict(UnserializableClass(1, "s", CustomData(99, "aa")), "invalid_component") diff --git a/test/marshal/test_yaml.py b/test/marshal/test_yaml.py index 552370e9ac..1c8e550215 100644 --- a/test/marshal/test_yaml.py +++ b/test/marshal/test_yaml.py @@ -6,11 +6,23 @@ from haystack.marshal.yaml import YamlMarshaller +class InvalidClass: + def __init__(self) -> None: + self.a = 1 + self.b = None + self.c = "string" + + @pytest.fixture def yaml_data(): return {"key": "value", 1: 0.221, "list": [1, 2, 3], "tuple": (1, None, True), "dict": {"set": {False}}} +@pytest.fixture +def invalid_yaml_data(): + return {"key": "value", 1: 0.221, "list": [1, 2, 3], "tuple": (1, InvalidClass(), True), "dict": {"set": {False}}} + + @pytest.fixture def serialized_yaml_str(): return """key: value @@ -36,6 +48,12 @@ def test_yaml_marshal(yaml_data, serialized_yaml_str): assert marshalled.strip().replace("\n", "") == serialized_yaml_str.strip().replace("\n", "") +def test_yaml_marshal_invalid_type(invalid_yaml_data): + with pytest.raises(TypeError, match="basic Python types"): + marshaller = YamlMarshaller() + marshalled = marshaller.marshal(invalid_yaml_data) + + def test_yaml_unmarshal(yaml_data, serialized_yaml_str): marshaller = YamlMarshaller() unmarshalled = marshaller.unmarshal(serialized_yaml_str)