Skip to content

Commit

Permalink
Instantiate one stream per thread
Browse files Browse the repository at this point in the history
  • Loading branch information
maxi297 committed Nov 11, 2024
1 parent 734d778 commit 09fc5d7
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 9 deletions.
16 changes: 14 additions & 2 deletions airbyte_cdk/sources/declarative/concurrent_declarative_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#

import logging
from typing import Any, Generic, Iterator, List, Mapping, Optional, Tuple, Union
from typing import Any, Callable, Generic, Iterator, List, Mapping, Optional, Tuple, Union

from airbyte_cdk.models import (
AirbyteCatalog,
Expand Down Expand Up @@ -229,7 +229,7 @@ def _group_streams(
declarative_stream.retriever.cursor = None

partition_generator = CursorPartitionGenerator(
stream=declarative_stream,
stream_factory=self._new_stream_instance_factory(declarative_stream, config),
message_repository=self.message_repository, # type: ignore # message_repository is always instantiated with a value by factory
cursor=cursor,
connector_state_converter=connector_state_converter,
Expand Down Expand Up @@ -344,3 +344,15 @@ def _remove_concurrent_streams_from_catalog(
if stream.stream.name not in concurrent_stream_names
]
)

def _new_stream_instance_factory(self, stream_to_copy: Stream, config: Mapping[str, Any]) -> Callable[[], Stream]:
    """
    Build a zero-argument callable that yields a stream matching `stream_to_copy` by name.

    Some declarative components hold state, so handing the same stream instance to several threads
    would let them overwrite each other's fields. The returned callable asks `self.streams(config)`
    for the streams again on every invocation and returns the single one whose name equals
    `stream_to_copy.name`.

    :param stream_to_copy: The stream whose name identifies which stream the factory should return.
    :param config: The configuration forwarded to `self.streams` each time the factory is invoked.
    :return: A callable taking no arguments and returning the matching stream.
    :raises ValueError: Raised by the returned callable (not by this method) when the number of streams
        sharing that name is not exactly one.
    """

    def _create_new_stream() -> Stream:
        # Streams are looked up at call time, not captured when the factory is built.
        candidates = [stream for stream in self.streams(config) if stream.name == stream_to_copy.name]
        match_count = len(candidates)
        if match_count != 1:
            raise ValueError(f"Expected one stream with name `{stream_to_copy.name}` but got {match_count}")
        return candidates[0]

    return _create_new_stream
10 changes: 5 additions & 5 deletions airbyte_cdk/sources/streams/concurrent/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import json
import logging
from functools import lru_cache
from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union
from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union

from airbyte_cdk.models import (
AirbyteLogMessage,
Expand Down Expand Up @@ -390,7 +390,7 @@ class CursorPartitionGenerator(PartitionGenerator):

def __init__(
self,
stream: Stream,
stream_factory: Callable[[], Stream],
message_repository: MessageRepository,
cursor: Cursor,
connector_state_converter: DateTimeStreamStateConverter,
Expand All @@ -400,12 +400,12 @@ def __init__(
"""
Initialize the CursorPartitionGenerator with a stream, sync mode, and cursor.
:param stream: The stream to delegate to for partition generation.
:param stream_factory: A factory that creates the stream to delegate to for partition generation.
:param message_repository: The message repository to use to emit non-record messages.
:param sync_mode: The synchronization mode.
:param cursor: A Cursor object that maintains the state and the cursor field.
"""
self._stream = stream
self._stream_factory = stream_factory
self.message_repository = message_repository
self._sync_mode = SyncMode.full_refresh
self._cursor = cursor
Expand Down Expand Up @@ -445,7 +445,7 @@ def generate(self) -> Iterable[Partition]:
)

yield StreamPartition(
self._stream,
self._stream_factory(),
copy.deepcopy(stream_slice),
self.message_repository,
self._sync_mode,
Expand Down
4 changes: 2 additions & 2 deletions unit_tests/sources/streams/concurrent/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def test_get_error_display_message(exception, expected_display_message):


def test_cursor_partition_generator():
stream = Mock()
stream_factory = Mock()
cursor = Mock()
message_repository = Mock()
connector_state_converter = CustomFormatConcurrentStreamStateConverter(
Expand All @@ -468,7 +468,7 @@ def test_cursor_partition_generator():
]

partition_generator = CursorPartitionGenerator(
stream,
stream_factory,
message_repository,
cursor,
connector_state_converter,
Expand Down

0 comments on commit 09fc5d7

Please sign in to comment.