diff --git a/haystack/core/pipeline/base.py b/haystack/core/pipeline/base.py index 31ad2ad93c..d8f2a65932 100644 --- a/haystack/core/pipeline/base.py +++ b/haystack/core/pipeline/base.py @@ -2,7 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -import importlib import itertools from collections import defaultdict from copy import deepcopy @@ -26,7 +25,7 @@ from haystack.core.serialization import DeserializationCallbacks, component_from_dict, component_to_dict from haystack.core.type_utils import _type_name, _types_are_compatible from haystack.marshal import Marshaller, YamlMarshaller -from haystack.utils import is_in_jupyter +from haystack.utils import is_in_jupyter, type_serialization from .descriptions import find_pipeline_inputs, find_pipeline_outputs from .draw import _to_mermaid_image @@ -161,7 +160,7 @@ def from_dict( # Import the module first... module, _ = component_data["type"].rsplit(".", 1) logger.debug("Trying to import module {module_name}", module_name=module) - importlib.import_module(module) + type_serialization.thread_safe_import(module) # ...then try again if component_data["type"] not in component.registry: raise PipelineError( diff --git a/haystack/utils/type_serialization.py b/haystack/utils/type_serialization.py index b2dd319d52..5ffb505bb1 100644 --- a/haystack/utils/type_serialization.py +++ b/haystack/utils/type_serialization.py @@ -6,10 +6,14 @@ import inspect import sys import typing +from threading import Lock +from types import ModuleType from typing import Any, get_args, get_origin from haystack import DeserializationError +_import_lock = Lock() + def serialize_type(target: Any) -> str: """ @@ -132,7 +136,7 @@ def parse_generic_args(args_str): module = sys.modules.get(module_name) if not module: try: - module = importlib.import_module(module_name) + module = thread_safe_import(module_name) except ImportError as e: raise DeserializationError(f"Could not import the module: {module_name}") from e @@ -141,3 +145,17 @@ def parse_generic_args(args_str): raise DeserializationError(f"Could not locate the type: {type_name} in the module: {module_name}") return deserialized_type + + +def thread_safe_import(module_name: str) -> ModuleType: + """ + Import a module in a thread-safe manner. + + Importing modules in a multi-threaded environment can lead to race conditions. + This function ensures that the module is imported in a thread-safe manner without having impact + on the performance of the import for single-threaded environments. + + :param module_name: the module to import + """ + with _import_lock: + return importlib.import_module(module_name) diff --git a/releasenotes/notes/thread-safe-module-import-ed04ad216820ab85.yaml b/releasenotes/notes/thread-safe-module-import-ed04ad216820ab85.yaml new file mode 100644 index 0000000000..3f1a0a2e78 --- /dev/null +++ b/releasenotes/notes/thread-safe-module-import-ed04ad216820ab85.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Fixes issues with deserialization of components in multi-threaded environments.