Skip to content

Commit

Permalink
fix: fix deserialization issues in multi-threading environments (#8651)
Browse files Browse the repository at this point in the history
  • Loading branch information
wochinge authored and julian-risch committed Dec 18, 2024
1 parent 55e7fbf commit e734e6b
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 4 deletions.
5 changes: 2 additions & 3 deletions haystack/core/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#
# SPDX-License-Identifier: Apache-2.0

import importlib
import itertools
from collections import defaultdict
from copy import deepcopy
Expand All @@ -26,7 +25,7 @@
from haystack.core.serialization import DeserializationCallbacks, component_from_dict, component_to_dict
from haystack.core.type_utils import _type_name, _types_are_compatible
from haystack.marshal import Marshaller, YamlMarshaller
from haystack.utils import is_in_jupyter
from haystack.utils import is_in_jupyter, type_serialization

from .descriptions import find_pipeline_inputs, find_pipeline_outputs
from .draw import _to_mermaid_image
Expand Down Expand Up @@ -161,7 +160,7 @@ def from_dict(
# Import the module first...
module, _ = component_data["type"].rsplit(".", 1)
logger.debug("Trying to import module {module_name}", module_name=module)
importlib.import_module(module)
type_serialization.thread_safe_import(module)
# ...then try again
if component_data["type"] not in component.registry:
raise PipelineError(
Expand Down
20 changes: 19 additions & 1 deletion haystack/utils/type_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@
import inspect
import sys
import typing
from threading import Lock
from types import ModuleType
from typing import Any, get_args, get_origin

from haystack import DeserializationError

_import_lock = Lock()


def serialize_type(target: Any) -> str:
"""
Expand Down Expand Up @@ -132,7 +136,7 @@ def parse_generic_args(args_str):
module = sys.modules.get(module_name)
if not module:
try:
module = importlib.import_module(module_name)
module = thread_safe_import(module_name)
except ImportError as e:
raise DeserializationError(f"Could not import the module: {module_name}") from e

Expand All @@ -141,3 +145,17 @@ def parse_generic_args(args_str):
raise DeserializationError(f"Could not locate the type: {type_name} in the module: {module_name}")

return deserialized_type


def thread_safe_import(module_name: str) -> ModuleType:
"""
Import a module in a thread-safe manner.
Importing modules in a multi-threaded environment can lead to race conditions.
This function ensures that the module is imported in a thread-safe manner without having impact
on the performance of the import for single-threaded environments.
:param module_name: the module to import
"""
with _import_lock:
return importlib.import_module(module_name)
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
fixes:
- |
Fixes issues with deserialization of components in multi-threaded environments.

0 comments on commit e734e6b

Please sign in to comment.