Skip to content

Commit

Permalink
fix: Enforce basic Python types restriction on serialized component d…
Browse files Browse the repository at this point in the history
…ata (#8473)
  • Loading branch information
shadeMe authored Oct 22, 2024
1 parent a556e11 commit 9061773
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 30 deletions.
2 changes: 1 addition & 1 deletion haystack/core/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
components = {}
for name, instance in self.graph.nodes(data="instance"): # type:ignore
components[name] = component_to_dict(instance)
components[name] = component_to_dict(instance, name)

connections = []
for sender, receiver, edge_data in self.graph.edges.data():
Expand Down
97 changes: 71 additions & 26 deletions haystack/core/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Callable
from dataclasses import dataclass
from importlib import import_module
from typing import Any, Dict, Optional, Type
from typing import Any, Dict, Iterable, Optional, Type

from haystack.core.component.component import _hook_component_init, logger
from haystack.core.errors import DeserializationError, SerializationError
Expand All @@ -30,45 +30,90 @@ class DeserializationCallbacks:
component_pre_init: Optional[Callable] = None


def component_to_dict(obj: Any) -> Dict[str, Any]:
def component_to_dict(obj: Any, name: str) -> Dict[str, Any]:
"""
Converts a component instance into a dictionary.
If a `to_dict` method is present in the component instance, that will be used instead of the default method.
:param obj:
The component to be serialized.
:param name:
The name of the component.
:returns:
A dictionary representation of the component.
:raises SerializationError:
If the component doesn't have a `to_dict` method and the values of the init parameters can't be determined.
If the component doesn't have a `to_dict` method.
If the values of the init parameters can't be determined.
If a non-basic Python type is used in the serialized data.
"""
if hasattr(obj, "to_dict"):
return obj.to_dict()

init_parameters = {}
for name, param in inspect.signature(obj.__init__).parameters.items():
# Ignore `args` and `kwargs`, used by the default constructor
if name in ("args", "kwargs"):
continue
try:
# This only works if the Component constructor assigns the init
# parameter to an instance variable or property with the same name
param_value = getattr(obj, name)
except AttributeError as e:
# If the parameter doesn't have a default value, raise an error
if param.default is param.empty:
data = obj.to_dict()
else:
init_parameters = {}
for param_name, param in inspect.signature(obj.__init__).parameters.items():
# Ignore `args` and `kwargs`, used by the default constructor
if param_name in ("args", "kwargs"):
continue
try:
# This only works if the Component constructor assigns the init
# parameter to an instance variable or property with the same name
param_value = getattr(obj, param_name)
except AttributeError as e:
# If the parameter doesn't have a default value, raise an error
if param.default is param.empty:
raise SerializationError(
f"Cannot determine the value of the init parameter '{param_name}' "
f"for the class {obj.__class__.__name__}."
f"You can fix this error by assigning 'self.{param_name} = {param_name}' or adding a "
f"custom serialization method 'to_dict' to the class."
) from e
# In case the init parameter was not assigned, we use the default value
param_value = param.default
init_parameters[param_name] = param_value

data = default_to_dict(obj, **init_parameters)

_validate_component_to_dict_output(obj, name, data)
return data


def _validate_component_to_dict_output(component: Any, name: str, data: Dict[str, Any]) -> None:
# Ensure that only basic Python types are used in the serde data.
def is_allowed_type(obj: Any) -> bool:
return isinstance(obj, (str, int, float, bool, list, dict, set, tuple, type(None)))

def check_iterable(l: Iterable[Any]):
for v in l:
if not is_allowed_type(v):
raise SerializationError(
f"Component '{name}' of type '{type(component).__name__}' has an unsupported value "
f"of type '{type(v).__name__}' in the serialized data."
)
if isinstance(v, (list, set, tuple)):
check_iterable(v)
elif isinstance(v, dict):
check_dict(v)

def check_dict(d: Dict[str, Any]):
if any(not isinstance(k, str) for k in data.keys()):
raise SerializationError(
f"Component '{name}' of type '{type(component).__name__}' has a non-string key in the serialized data."
)

for k, v in d.items():
if not is_allowed_type(v):
raise SerializationError(
f"Cannot determine the value of the init parameter '{name}' for the class {obj.__class__.__name__}."
f"You can fix this error by assigning 'self.{name} = {name}' or adding a "
f"custom serialization method 'to_dict' to the class."
) from e
# In case the init parameter was not assigned, we use the default value
param_value = param.default
init_parameters[name] = param_value

return default_to_dict(obj, **init_parameters)
f"Component '{name}' of type '{type(component).__name__}' has an unsupported value "
f"of type '{type(v).__name__}' in the serialized data under key '{k}'."
)
if isinstance(v, (list, set, tuple)):
check_iterable(v)
elif isinstance(v, dict):
check_dict(v)

check_dict(data)


def generate_qualified_class_name(cls: Type[object]) -> str:
Expand Down
23 changes: 21 additions & 2 deletions haystack/marshal/yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,33 @@ def construct_python_tuple(self, node: yaml.SequenceNode):
return tuple(self.construct_sequence(node))


class YamlDumper(yaml.SafeDumper): # pylint: disable=too-many-ancestors
def represent_tuple(self, data: tuple):
"""Represent a Python tuple."""
return self.represent_sequence("tag:yaml.org,2002:python/tuple", data)


YamlDumper.add_representer(tuple, YamlDumper.represent_tuple)
YamlLoader.add_constructor("tag:yaml.org,2002:python/tuple", YamlLoader.construct_python_tuple)


class YamlMarshaller:
def marshal(self, dict_: Dict[str, Any]) -> str:
"""Return a YAML representation of the given dictionary."""
return yaml.dump(dict_)
try:
return yaml.dump(dict_, Dumper=YamlDumper)
except yaml.representer.RepresenterError as e:
raise TypeError(
"Error dumping pipeline to YAML - Ensure that all pipeline "
"components only serialize basic Python types"
) from e

def unmarshal(self, data_: Union[str, bytes, bytearray]) -> Dict[str, Any]:
"""Return a dictionary from the given YAML data."""
return yaml.load(data_, Loader=YamlLoader)
try:
return yaml.load(data_, Loader=YamlLoader)
except yaml.constructor.ConstructorError as e:
raise TypeError(
"Error loading pipeline from YAML - Ensure that all pipeline "
"components only serialize basic Python types"
) from e
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
features:
- |
Improved serialization/deserialization errors to provide extra context about the delinquent components when possible.
fixes:
- |
Serialized data of components are now explicitly enforced to be one of the following basic Python datatypes: str, int, float, bool, list, dict, set, tuple or None.
25 changes: 24 additions & 1 deletion test/core/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@

from haystack.core.pipeline import Pipeline
from haystack.core.component import component
from haystack.core.errors import DeserializationError
from haystack.core.errors import DeserializationError, SerializationError
from haystack.testing import factory
from haystack.core.serialization import (
default_to_dict,
default_from_dict,
generate_qualified_class_name,
import_class_by_name,
component_to_dict,
)


Expand Down Expand Up @@ -106,3 +107,25 @@ def test_import_class_by_name_no_valid_class():
data = "some.invalid.class"
with pytest.raises(ImportError):
import_class_by_name(data)


class CustomData:
def __init__(self, a: int, b: str) -> None:
self.a = a
self.b = b


@component()
class UnserializableClass:
def __init__(self, a: int, b: str, c: CustomData) -> None:
self.a = a
self.b = b
self.c = c

def run(self):
pass


def test_component_to_dict_invalid_type():
with pytest.raises(SerializationError, match="unsupported value of type 'CustomData'"):
component_to_dict(UnserializableClass(1, "s", CustomData(99, "aa")), "invalid_component")
18 changes: 18 additions & 0 deletions test/marshal/test_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,23 @@
from haystack.marshal.yaml import YamlMarshaller


class InvalidClass:
def __init__(self) -> None:
self.a = 1
self.b = None
self.c = "string"


@pytest.fixture
def yaml_data():
return {"key": "value", 1: 0.221, "list": [1, 2, 3], "tuple": (1, None, True), "dict": {"set": {False}}}


@pytest.fixture
def invalid_yaml_data():
return {"key": "value", 1: 0.221, "list": [1, 2, 3], "tuple": (1, InvalidClass(), True), "dict": {"set": {False}}}


@pytest.fixture
def serialized_yaml_str():
return """key: value
Expand All @@ -36,6 +48,12 @@ def test_yaml_marshal(yaml_data, serialized_yaml_str):
assert marshalled.strip().replace("\n", "") == serialized_yaml_str.strip().replace("\n", "")


def test_yaml_marshal_invalid_type(invalid_yaml_data):
with pytest.raises(TypeError, match="basic Python types"):
marshaller = YamlMarshaller()
marshalled = marshaller.marshal(invalid_yaml_data)


def test_yaml_unmarshal(yaml_data, serialized_yaml_str):
marshaller = YamlMarshaller()
unmarshalled = marshaller.unmarshal(serialized_yaml_str)
Expand Down

0 comments on commit 9061773

Please sign in to comment.