Skip to content

Commit

Permalink
feat: use importlib when deserializing callables
Browse files Browse the repository at this point in the history
  • Loading branch information
LastRemote committed Dec 17, 2024
1 parent 3b9a60b commit dcecb97
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 7 deletions.
15 changes: 8 additions & 7 deletions haystack/utils/callable_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import inspect
import sys
from importlib import import_module
from typing import Callable, Optional

from haystack import DeserializationError
Expand Down Expand Up @@ -37,10 +37,11 @@ def deserialize_callable(callable_handle: str) -> Optional[Callable]:
parts = callable_handle.split(".")
module_name = ".".join(parts[:-1])
function_name = parts[-1]
module = sys.modules.get(module_name, None)
if not module:
raise DeserializationError(f"Could not locate the module of the callable: {module_name}")
streaming_callback = getattr(module, function_name, None)
if not streaming_callback:
try:
module = import_module(module_name, None)
except Exception as e:
raise DeserializationError(f"Could not locate the module of the callable: {module_name}") from e
deserialized_callable = getattr(module, function_name, None)
if not deserialized_callable:
raise DeserializationError(f"Could not locate the callable: {function_name}")
return streaming_callback
return deserialized_callable
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
enhancements:
- |
Improved performance of deserializing callables by using `importlib` instead of `sys.modules`.
This change allows importing local functions and classes that are not in `sys.modules`
when deserializing callables.
9 changes: 9 additions & 0 deletions test/utils/test_callable_serialization.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import pytest
import requests

from haystack import DeserializationError
from haystack.components.generators.utils import print_streaming_chunk
from haystack.utils import serialize_callable, deserialize_callable

Expand Down Expand Up @@ -36,3 +38,10 @@ def test_callable_deserialization_non_local():
result = serialize_callable(requests.api.get)
fn = deserialize_callable(result)
assert fn is requests.api.get


def test_callable_deserialization_error():
with pytest.raises(DeserializationError):
deserialize_callable("this.is.not.a.valid.module")
with pytest.raises(DeserializationError):
deserialize_callable("sys.foobar")

0 comments on commit dcecb97

Please sign in to comment.