diff --git a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py
index dc99c414..1b607870 100644
--- a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py
+++ b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py
@@ -311,7 +311,9 @@ def _group_streams(
                         declarative_stream=declarative_stream
                     )
                     and hasattr(declarative_stream.retriever, "stream_slicer")
-                    and isinstance(declarative_stream.retriever.stream_slicer, PerPartitionWithGlobalCursor)
+                    and isinstance(
+                        declarative_stream.retriever.stream_slicer, PerPartitionWithGlobalCursor
+                    )
                 ):
                     stream_state = state_manager.get_stream_state(
                         stream_name=declarative_stream.name, namespace=declarative_stream.namespace
@@ -319,16 +321,15 @@ def _group_streams(
                     partition_router = declarative_stream.retriever.stream_slicer._partition_router

                     cursor = self._constructor.create_concurrent_cursor_from_perpartition_cursor(
-                            state_manager=state_manager,
-                            model_type=DatetimeBasedCursorModel,
-                            component_definition=incremental_sync_component_definition,
-                            stream_name=declarative_stream.name,
-                            stream_namespace=declarative_stream.namespace,
-                            config=config or {},
-                            stream_state=stream_state,
-                            partition_router=partition_router,
-                        )
-
+                        state_manager=state_manager,
+                        model_type=DatetimeBasedCursorModel,
+                        component_definition=incremental_sync_component_definition,
+                        stream_name=declarative_stream.name,
+                        stream_namespace=declarative_stream.namespace,
+                        config=config or {},
+                        stream_state=stream_state,
+                        partition_router=partition_router,
+                    )

                     partition_generator = StreamSlicerPartitionGenerator(
                         DeclarativePartitionFactory(
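The first hunk is a pure reflow of the gate in _group_streams: a stream is only routed to the concurrent per-partition path when its retriever exposes a stream_slicer of type PerPartitionWithGlobalCursor. A minimal sketch of that hasattr/isinstance dispatch, using illustrative stand-ins (FakeRetriever and is_concurrent_candidate are not CDK names):

    # Stand-in for the CDK slicer class checked by _group_streams.
    class PerPartitionWithGlobalCursor:
        pass

    class FakeRetriever:
        def __init__(self, stream_slicer):
            self.stream_slicer = stream_slicer

    def is_concurrent_candidate(retriever) -> bool:
        # hasattr guards retrievers that expose no slicer at all; isinstance
        # narrows to the one slicer the concurrent cursor knows how to drive.
        return hasattr(retriever, "stream_slicer") and isinstance(
            retriever.stream_slicer, PerPartitionWithGlobalCursor
        )

    assert is_concurrent_candidate(FakeRetriever(PerPartitionWithGlobalCursor()))
    assert not is_concurrent_candidate(FakeRetriever(stream_slicer=object()))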
diff --git a/airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py b/airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py
index 2e1da77f..e43134f0 100644
--- a/airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py
+++ b/airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py
@@ -1,16 +1,19 @@
 import copy
+import logging
 #
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 #

 import threading
-import logging
 from collections import OrderedDict
 from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional

 from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
-from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import iterate_with_last_flag_and_state, Timer
 from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
+from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import (
+    Timer,
+    iterate_with_last_flag_and_state,
+)
 from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
 from airbyte_cdk.sources.message import MessageRepository
 from airbyte_cdk.sources.streams.checkpoint.per_partition_key_serializer import (
@@ -123,18 +126,31 @@ def state(self) -> MutableMapping[str, Any]:
     def close_partition(self, partition: Partition) -> None:
-        print(f"Closing partition {self._to_partition_key(partition._stream_slice.partition)}")
-        self._cursor_per_partition[self._to_partition_key(partition._stream_slice.partition)].close_partition(partition=partition)
-        with (self._lock):
-            self._semaphore_per_partition[self._to_partition_key(partition._stream_slice.partition)].acquire()
-            cursor = self._cursor_per_partition[self._to_partition_key(partition._stream_slice.partition)]
+        self._cursor_per_partition[
+            self._to_partition_key(partition._stream_slice.partition)
+        ].close_partition(partition=partition)
+        with self._lock:
+            self._semaphore_per_partition[
+                self._to_partition_key(partition._stream_slice.partition)
+            ].acquire()
+            cursor = self._cursor_per_partition[
+                self._to_partition_key(partition._stream_slice.partition)
+            ]
             cursor_state = cursor._connector_state_converter.convert_to_state_message(
                 cursor._cursor_field, cursor.state
             )
-            print(f"State {cursor_state} {cursor.state}")
-            if self._to_partition_key(partition._stream_slice.partition) in self._finished_partitions \
-                    and self._semaphore_per_partition[self._to_partition_key(partition._stream_slice.partition)]._value == 0:
-                if self._new_global_cursor is None \
-                        or self._new_global_cursor[self.cursor_field.cursor_field_key] < cursor_state[self.cursor_field.cursor_field_key]:
+            if (
+                self._to_partition_key(partition._stream_slice.partition)
+                in self._finished_partitions
+                and self._semaphore_per_partition[
+                    self._to_partition_key(partition._stream_slice.partition)
+                ]._value
+                == 0
+            ):
+                if (
+                    self._new_global_cursor is None
+                    or self._new_global_cursor[self.cursor_field.cursor_field_key]
+                    < cursor_state[self.cursor_field.cursor_field_key]
+                ):
                     self._new_global_cursor = copy.deepcopy(cursor_state)

     def ensure_at_least_one_state_emitted(self) -> None:
         """
@@ -142,7 +160,9 @@ def ensure_at_least_one_state_emitted(self) -> None:
-        The platform expect to have at least one state message on successful syncs. Hence, whatever happens, we expect this method to be called.
+        The platform expects at least one state message on successful syncs. Hence, whatever happens, we expect this method to be called.
         """
-        if not any(semaphore_item[1]._value for semaphore_item in self._semaphore_per_partition.items()):
+        if not any(
+            semaphore_item[1]._value for semaphore_item in self._semaphore_per_partition.items()
+        ):
             self._global_cursor = self._new_global_cursor
             self._lookback_window = self._timer.finish()
             self._parent_state = self._partition_router.get_stream_state()
@@ -159,7 +179,6 @@ def _emit_state_message(self) -> None:
         )
         self._message_repository.emit_message(state_message)

-
     def stream_slices(self) -> Iterable[StreamSlice]:
         slices = self._partition_router.stream_slices()
         self._timer.start()
@@ -179,11 +198,13 @@ def generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[StreamSlice]:
         )
         cursor = self._create_cursor(partition_state)
         self._cursor_per_partition[self._to_partition_key(partition.partition)] = cursor
-        self._semaphore_per_partition[self._to_partition_key(partition.partition)] = threading.Semaphore(0)
+        self._semaphore_per_partition[self._to_partition_key(partition.partition)] = (
+            threading.Semaphore(0)
+        )

         for cursor_slice, is_last_slice, _ in iterate_with_last_flag_and_state(
-                cursor.stream_slices(),
-                lambda: None,
+            cursor.stream_slices(),
+            lambda: None,
         ):
             self._semaphore_per_partition[self._to_partition_key(partition.partition)].release()
             if is_last_slice:
@@ -251,7 +272,9 @@ def _set_initial_state(self, stream_state: StreamState) -> None:
                 self._cursor_per_partition[self._to_partition_key(state["partition"])] = (
                     self._create_cursor(state["cursor"])
                 )
-                self._semaphore_per_partition[self._to_partition_key(state["partition"])] = threading.Semaphore(0)
+                self._semaphore_per_partition[self._to_partition_key(state["partition"])] = (
+                    threading.Semaphore(0)
+                )

         # set default state for missing partitions if it is per partition with fallback to global
         if "state" in stream_state:
@@ -262,7 +285,8 @@ def observe(self, record: Record) -> None:
-        print(self._to_partition_key(record.associated_slice.partition), record)
-        self._cursor_per_partition[self._to_partition_key(record.associated_slice.partition)].observe(record)
+        self._cursor_per_partition[
+            self._to_partition_key(record.associated_slice.partition)
+        ].observe(record)

     def _to_partition_key(self, partition: Mapping[str, Any]) -> str:
         return self._partition_serializer.to_partition_key(partition)
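The semaphore bookkeeping reformatted above is the core invariant of ConcurrentPerPartitionCursor: each partition owns a threading.Semaphore(0) that is released once per generated slice and acquired once per closed slice, so the global cursor only advances after a finished partition has been fully drained. A toy, single-threaded model of that accounting (a sketch, not the CDK implementation; inspecting _value mirrors the private access the diff itself relies on):

    import threading
    from typing import Dict, Optional, Set

    semaphores: Dict[str, threading.Semaphore] = {}
    finished: Set[str] = set()
    global_cursor: Optional[str] = None

    def emit_slice(partition_key: str) -> None:
        # One release per slice handed to the concurrent framework.
        semaphores.setdefault(partition_key, threading.Semaphore(0)).release()

    def close_slice(partition_key: str, slice_cursor: str) -> None:
        global global_cursor
        # One acquire per closed slice; _value == 0 means every released
        # slice has been matched by an acquire, i.e. the partition is drained.
        semaphores[partition_key].acquire()
        if partition_key in finished and semaphores[partition_key]._value == 0:
            if global_cursor is None or global_cursor < slice_cursor:
                global_cursor = slice_cursor

    emit_slice("A")
    finished.add("A")
    close_slice("A", "2024-01-15T00:00:00Z")
    assert global_cursor == "2024-01-15T00:00:00Z"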
""" - if not any(semaphore_item[1]._value for semaphore_item in self._semaphore_per_partition.items()): + if not any( + semaphore_item[1]._value for semaphore_item in self._semaphore_per_partition.items() + ): self._global_cursor = self._new_global_cursor self._lookback_window = self._timer.finish() self._parent_state = self._partition_router.get_stream_state() @@ -159,7 +179,6 @@ def _emit_state_message(self) -> None: ) self._message_repository.emit_message(state_message) - def stream_slices(self) -> Iterable[StreamSlice]: slices = self._partition_router.stream_slices() self._timer.start() @@ -179,11 +198,13 @@ def generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[Str ) cursor = self._create_cursor(partition_state) self._cursor_per_partition[self._to_partition_key(partition.partition)] = cursor - self._semaphore_per_partition[self._to_partition_key(partition.partition)] = threading.Semaphore(0) + self._semaphore_per_partition[self._to_partition_key(partition.partition)] = ( + threading.Semaphore(0) + ) for cursor_slice, is_last_slice, _ in iterate_with_last_flag_and_state( - cursor.stream_slices(), - lambda: None, + cursor.stream_slices(), + lambda: None, ): self._semaphore_per_partition[self._to_partition_key(partition.partition)].release() if is_last_slice: @@ -251,7 +272,9 @@ def _set_initial_state(self, stream_state: StreamState) -> None: self._cursor_per_partition[self._to_partition_key(state["partition"])] = ( self._create_cursor(state["cursor"]) ) - self._semaphore_per_partition[self._to_partition_key(state["partition"])] = threading.Semaphore(0) + self._semaphore_per_partition[self._to_partition_key(state["partition"])] = ( + threading.Semaphore(0) + ) # set default state for missing partitions if it is per partition with fallback to global if "state" in stream_state: @@ -262,7 +285,9 @@ def _set_initial_state(self, stream_state: StreamState) -> None: def observe(self, record: Record) -> None: print(self._to_partition_key(record.associated_slice.partition), record) - self._cursor_per_partition[self._to_partition_key(record.associated_slice.partition)].observe(record) + self._cursor_per_partition[ + self._to_partition_key(record.associated_slice.partition) + ].observe(record) def _to_partition_key(self, partition: Mapping[str, Any]) -> str: return self._partition_serializer.to_partition_key(partition) diff --git a/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py b/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py index d7322709..1529e90e 100644 --- a/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py +++ b/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py @@ -303,7 +303,10 @@ def get_request_body_json( raise ValueError("A partition needs to be provided in order to get request body json") def should_be_synced(self, record: Record) -> bool: - if self._to_partition_key(record.associated_slice.partition) not in self._cursor_per_partition: + if ( + self._to_partition_key(record.associated_slice.partition) + not in self._cursor_per_partition + ): partition_state = ( self._state_to_migrate_from if self._state_to_migrate_from @@ -311,7 +314,9 @@ def should_be_synced(self, record: Record) -> bool: ) cursor = self._create_cursor(partition_state) - self._cursor_per_partition[self._to_partition_key(record.associated_slice.partition)] = cursor + self._cursor_per_partition[ + self._to_partition_key(record.associated_slice.partition) + ] = cursor return self._get_cursor(record).should_be_synced( 
diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py
index d210b475..2decfbd4 100644
--- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py
+++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py
@@ -969,7 +969,7 @@ def create_concurrent_cursor_from_perpartition_cursor(
         config: Config,
         stream_state: MutableMapping[str, Any],
         partition_router,
-            **kwargs: Any,
+        **kwargs: Any,
     ) -> ConcurrentPerPartitionCursor:
         component_type = component_definition.get("type")
         if component_definition.get("type") != model_type.__name__:
@@ -1000,21 +1000,21 @@ def create_concurrent_cursor_from_perpartition_cursor(
                 stream_name=stream_name,
                 stream_namespace=stream_namespace,
                 config=config,
-                message_repository=NoopMessageRepository()
+                message_repository=NoopMessageRepository(),
             )
         )

         # Return the concurrent cursor and state converter
         return ConcurrentPerPartitionCursor(
-                cursor_factory=cursor_factory,
-                partition_router=partition_router,
-                stream_name=stream_name,
-                stream_namespace=stream_namespace,
-                stream_state=stream_state,
-                message_repository=self._message_repository,  # type: ignore
-                connector_state_manager=state_manager,
-                cursor_field=cursor_field,
-            )
+            cursor_factory=cursor_factory,
+            partition_router=partition_router,
+            stream_name=stream_name,
+            stream_namespace=stream_namespace,
+            stream_state=stream_state,
+            message_repository=self._message_repository,  # type: ignore
+            connector_state_manager=state_manager,
+            cursor_field=cursor_field,
+        )

     @staticmethod
     def create_constant_backoff_strategy(
@@ -1298,15 +1298,15 @@ def create_declarative_stream(
                     raise ValueError(
                         "Unsupported Slicer is used. PerPartitionWithGlobalCursor should be used here instead"
                     )
-                cursor = combined_slicers if isinstance(
-                    combined_slicers, (PerPartitionWithGlobalCursor, GlobalSubstreamCursor)
-                ) else self._create_component_from_model(
-                    model=model.incremental_sync, config=config
+                cursor = (
+                    combined_slicers
+                    if isinstance(
+                        combined_slicers, (PerPartitionWithGlobalCursor, GlobalSubstreamCursor)
+                    )
+                    else self._create_component_from_model(model=model.incremental_sync, config=config)
                 )

-                client_side_incremental_sync = {
-                    "cursor": cursor
-                }
+                client_side_incremental_sync = {"cursor": cursor}

         if model.incremental_sync and isinstance(model.incremental_sync, DatetimeBasedCursorModel):
             cursor_model = model.incremental_sync
diff --git a/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py b/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py
index 09ed2bc8..31f6377f 100644
--- a/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py
+++ b/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py
@@ -38,6 +38,7 @@ def create(self, stream_slice: StreamSlice) -> Partition:
             stream_slice,
         )

+
 class DeclarativePartition(Partition):
     def __init__(
         self,
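create_concurrent_cursor_from_perpartition_cursor first validates that the manifest component's declared type matches the expected model before building anything. A sketch of that guard, assuming a stand-in model class and an error message paraphrased rather than copied from the factory:

    from typing import Any, Mapping

    class DatetimeBasedCursorModel:  # stand-in for the real pydantic model
        pass

    def check_component(component_definition: Mapping[str, Any], model_type: type) -> None:
        # Raise when the manifest declares a different component type than the
        # model the factory was asked to build.
        component_type = component_definition.get("type")
        if component_type != model_type.__name__:
            raise ValueError(
                f"Expected component of type {model_type.__name__}, got {component_type}"
            )

    check_component({"type": "DatetimeBasedCursorModel"}, DatetimeBasedCursorModel)  # passes
    try:
        check_component({"type": "CustomCursor"}, DatetimeBasedCursorModel)
    except ValueError as exc:
        print(exc)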
diff --git a/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py b/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py
index ab713a40..57948d79 100644
--- a/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py
+++ b/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py
@@ -1,8 +1,8 @@
 # Copyright (c) 2024 Airbyte, Inc., all rights reserved.

+import copy
 from typing import Any, List, Mapping, MutableMapping, Optional, Union
 from unittest.mock import MagicMock
-import copy

 import pytest
 import requests_mock
@@ -264,9 +264,7 @@ def _run_read(
     source = ConcurrentDeclarativeSource(
         source_config=manifest, config=config, catalog=catalog, state=state
    )
-    messages = list(
-        source.read(logger=source.logger, config=config, catalog=catalog, state=[])
-    )
+    messages = list(source.read(logger=source.logger, config=config, catalog=catalog, state=[]))
     return messages
@@ -474,11 +472,9 @@ def _run_read(
                     "cursor": {"created_at": "2024-01-10T00:00:00Z"},
                 },
             ],
-            'lookback_window': 1,
-            'parent_state': {},
-            'state': {'created_at': '2024-01-15T00:00:00Z'}
-
-
+            "lookback_window": 1,
+            "parent_state": {},
+            "state": {"created_at": "2024-01-15T00:00:00Z"},
         },
     ),
 ],
 )
@@ -520,7 +516,9 @@ def test_incremental_parent_state_no_incremental_dependency(
     output = _run_read(manifest, config, _stream_name, initial_state)
     output_data = [message.record.data for message in output if message.record]

-    assert set(tuple(sorted(d.items())) for d in output_data) == set(tuple(sorted(d.items())) for d in expected_records)
+    assert set(tuple(sorted(d.items())) for d in output_data) == set(
+        tuple(sorted(d.items())) for d in expected_records
+    )
     final_state = [
         orjson.loads(orjson.dumps(message.state.stream.stream_state))
         for message in output
@@ -565,8 +563,9 @@ def run_incremental_parent_state_test(
     output_data = [message.record.data for message in output if message.record]

     # Assert that output_data equals expected_records
-    assert (sorted(output_data, key=lambda x: orjson.dumps(x))
-            == sorted(expected_records, key=lambda x: orjson.dumps(x)))
+    assert sorted(output_data, key=lambda x: orjson.dumps(x)) == sorted(
+        expected_records, key=lambda x: orjson.dumps(x)
+    )

     # Collect the intermediate states and records produced before each state
     cumulative_records = []
@@ -884,8 +883,8 @@ def run_incremental_parent_state_test(
                     "cursor": {"created_at": "2024-01-13T00:00:00Z"},
                 },
                 {
-                    'partition': {'id': 12, 'parent_slice': {'id': 1, 'parent_slice': {}}},
-                    'cursor': {'created_at': '2024-01-01T00:00:01Z'},
+                    "partition": {"id": 12, "parent_slice": {"id": 1, "parent_slice": {}}},
+                    "cursor": {"created_at": "2024-01-01T00:00:01Z"},
                 },
                 {
                     "partition": {"id": 20, "parent_slice": {"id": 2, "parent_slice": {}}},
@@ -1141,7 +1140,9 @@ def test_incremental_parent_state_migration(
     output = _run_read(manifest, config, _stream_name, initial_state)
     output_data = [message.record.data for message in output if message.record]

-    assert set(tuple(sorted(d.items())) for d in output_data) == set(tuple(sorted(d.items())) for d in expected_records)
+    assert set(tuple(sorted(d.items())) for d in output_data) == set(
+        tuple(sorted(d.items())) for d in expected_records
+    )
     final_state = [
         orjson.loads(orjson.dumps(message.state.stream.stream_state))
         for message in output
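The assertion rewrites above keep the tests order-insensitive, which matters once slices are processed concurrently: flat records are compared as sets of sorted item tuples, while records containing nested (unhashable) dicts are sorted by a serialized key instead, orjson in the tests. A dependency-free sketch of both normalizations using json:

    import json

    # Concurrent reads give no ordering guarantee, so compare as sets.
    a = [{"id": 1, "created_at": "2024-01-10"}, {"id": 2, "created_at": "2024-01-11"}]
    b = list(reversed(a))
    assert set(tuple(sorted(d.items())) for d in a) == set(tuple(sorted(d.items())) for d in b)

    # Nested dicts are unhashable, so sort by a canonical serialization instead.
    nested_a = [{"id": 1, "slice": {"parent": 2}}]
    nested_b = [{"id": 1, "slice": {"parent": 2}}]
    assert sorted(nested_a, key=lambda x: json.dumps(x, sort_keys=True)) == sorted(
        nested_b, key=lambda x: json.dumps(x, sort_keys=True)
    )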