From dcecb978e1fb97bdc82da24961e61dd5055ab30e Mon Sep 17 00:00:00 2001 From: Bohan Qu Date: Tue, 17 Dec 2024 15:19:51 +0800 Subject: [PATCH] feat: use importlib when deserializing callables --- haystack/utils/callable_serialization.py | 15 ++++++++------- ...n-deserializing-callable-1f36f07c4518c2cf.yaml | 6 ++++++ test/utils/test_callable_serialization.py | 9 +++++++++ 3 files changed, 23 insertions(+), 7 deletions(-) create mode 100644 releasenotes/notes/use-importlib-when-deserializing-callable-1f36f07c4518c2cf.yaml diff --git a/haystack/utils/callable_serialization.py b/haystack/utils/callable_serialization.py index 7815712e7c..b9717cef6c 100644 --- a/haystack/utils/callable_serialization.py +++ b/haystack/utils/callable_serialization.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import inspect -import sys +from importlib import import_module from typing import Callable, Optional from haystack import DeserializationError @@ -37,10 +37,11 @@ def deserialize_callable(callable_handle: str) -> Optional[Callable]: parts = callable_handle.split(".") module_name = ".".join(parts[:-1]) function_name = parts[-1] - module = sys.modules.get(module_name, None) - if not module: - raise DeserializationError(f"Could not locate the module of the callable: {module_name}") - streaming_callback = getattr(module, function_name, None) - if not streaming_callback: + try: + module = import_module(module_name, None) + except Exception as e: + raise DeserializationError(f"Could not locate the module of the callable: {module_name}") from e + deserialized_callable = getattr(module, function_name, None) + if not deserialized_callable: raise DeserializationError(f"Could not locate the callable: {function_name}") - return streaming_callback + return deserialized_callable diff --git a/releasenotes/notes/use-importlib-when-deserializing-callable-1f36f07c4518c2cf.yaml b/releasenotes/notes/use-importlib-when-deserializing-callable-1f36f07c4518c2cf.yaml new file mode 100644 index 0000000000..6780097ccd --- /dev/null +++ b/releasenotes/notes/use-importlib-when-deserializing-callable-1f36f07c4518c2cf.yaml @@ -0,0 +1,6 @@ +--- +enhancements: + - | + Improved performance of deserializing callables by using `importlib` instead of `sys.modules`. + This change allows importing local functions and classes that are not in `sys.modules` + when deserializing callables. diff --git a/test/utils/test_callable_serialization.py b/test/utils/test_callable_serialization.py index c0afafa73e..941aa14cdf 100644 --- a/test/utils/test_callable_serialization.py +++ b/test/utils/test_callable_serialization.py @@ -1,8 +1,10 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +import pytest import requests +from haystack import DeserializationError from haystack.components.generators.utils import print_streaming_chunk from haystack.utils import serialize_callable, deserialize_callable @@ -36,3 +38,10 @@ def test_callable_deserialization_non_local(): result = serialize_callable(requests.api.get) fn = deserialize_callable(result) assert fn is requests.api.get + + +def test_callable_deserialization_error(): + with pytest.raises(DeserializationError): + deserialize_callable("this.is.not.a.valid.module") + with pytest.raises(DeserializationError): + deserialize_callable("sys.foobar")