diff --git a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py index 62e0b578..df0f8b37 100644 --- a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py +++ b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py @@ -3,7 +3,7 @@ # import logging -from typing import Any, Generic, Iterator, List, Mapping, Optional, Tuple, Union +from typing import Any, Callable, Generic, Iterator, List, Mapping, Optional, Tuple, Union from airbyte_cdk.models import ( AirbyteCatalog, @@ -229,7 +229,7 @@ def _group_streams( declarative_stream.retriever.cursor = None partition_generator = CursorPartitionGenerator( - stream=declarative_stream, + stream_factory=self._new_stream_instance_factory(declarative_stream, config), message_repository=self.message_repository, # type: ignore # message_repository is always instantiated with a value by factory cursor=cursor, connector_state_converter=connector_state_converter, @@ -344,3 +344,15 @@ def _remove_concurrent_streams_from_catalog( if stream.stream.name not in concurrent_stream_names ] ) + + def _new_stream_instance_factory(self, stream_to_copy: Stream, config: Mapping[str, Any]) -> Callable[[], Stream]: + """ + Some of the declarative components are stateful. Therefore, we create one stream per thread in order to avoid threads updating + the same field for a specific instance. + """ + def _create_new_stream() -> Stream: + streams_with_same_name = list(filter(lambda stream: stream.name == stream_to_copy.name, self.streams(config))) + if len(streams_with_same_name) == 1: + return streams_with_same_name[0] + raise ValueError(f"Expected one stream with name `{stream_to_copy.name}` but got {len(streams_with_same_name)}") + return _create_new_stream diff --git a/airbyte_cdk/sources/streams/concurrent/adapters.py b/airbyte_cdk/sources/streams/concurrent/adapters.py index d4b539a5..b1c95f96 100644 --- a/airbyte_cdk/sources/streams/concurrent/adapters.py +++ b/airbyte_cdk/sources/streams/concurrent/adapters.py @@ -6,7 +6,7 @@ import json import logging from functools import lru_cache -from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union +from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union from airbyte_cdk.models import ( AirbyteLogMessage, @@ -390,7 +390,7 @@ class CursorPartitionGenerator(PartitionGenerator): def __init__( self, - stream: Stream, + stream_factory: Callable[[], Stream], message_repository: MessageRepository, cursor: Cursor, connector_state_converter: DateTimeStreamStateConverter, @@ -400,12 +400,12 @@ def __init__( """ Initialize the CursorPartitionGenerator with a stream, sync mode, and cursor. - :param stream: The stream to delegate to for partition generation. + :param stream_factory: The stream factory that created the stream to delegate to for partition generation. :param message_repository: The message repository to use to emit non-record messages. :param sync_mode: The synchronization mode. :param cursor: A Cursor object that maintains the state and the cursor field. """ - self._stream = stream + self._stream_factory = stream_factory self.message_repository = message_repository self._sync_mode = SyncMode.full_refresh self._cursor = cursor @@ -445,7 +445,7 @@ def generate(self) -> Iterable[Partition]: ) yield StreamPartition( - self._stream, + self._stream_factory(), copy.deepcopy(stream_slice), self.message_repository, self._sync_mode, diff --git a/unit_tests/sources/streams/concurrent/test_adapters.py b/unit_tests/sources/streams/concurrent/test_adapters.py index cbebfe7c..3dbb8db2 100644 --- a/unit_tests/sources/streams/concurrent/test_adapters.py +++ b/unit_tests/sources/streams/concurrent/test_adapters.py @@ -448,7 +448,7 @@ def test_get_error_display_message(exception, expected_display_message): def test_cursor_partition_generator(): - stream = Mock() + stream_factory = Mock() cursor = Mock() message_repository = Mock() connector_state_converter = CustomFormatConcurrentStreamStateConverter( @@ -468,7 +468,7 @@ def test_cursor_partition_generator(): ] partition_generator = CursorPartitionGenerator( - stream, + stream_factory, message_repository, cursor, connector_state_converter,