Auto-fix lint and format issues
octavia-squidington-iii committed Dec 18, 2024
1 parent 79ffb77 commit a36726b
Showing 6 changed files with 96 additions and 63 deletions.
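The hunks below use unified-diff markers: lines prefixed with - are removed by the commit, lines prefixed with + are added, and unprefixed lines are unchanged context. The changes are the mechanical rewrites a line-length-aware auto-formatter produces (the commit does not name the tool, so Ruff/Black-style behavior is an assumption): over-long calls and boolean conditions are wrapped into parenthesized multi-line form with trailing commas, imports are re-sorted, and string quoting is normalized. A minimal sketch of the wrapping rule, using placeholder names rather than code from this commit:

# Hypothetical illustration of the reformatting pattern; not code from this repository.
def uses_per_partition_slicer(stream, slicer_type) -> bool:
    # Before the fix, the condition sat on a single over-long line:
    #     return hasattr(stream.retriever, "stream_slicer") and isinstance(stream.retriever.stream_slicer, slicer_type)
    # After the fix, the formatter parenthesizes it and puts one clause per line:
    return (
        hasattr(stream.retriever, "stream_slicer")
        and isinstance(stream.retriever.stream_slicer, slicer_type)
    )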
23 changes: 12 additions & 11 deletions airbyte_cdk/sources/declarative/concurrent_declarative_source.py
@@ -311,24 +311,25 @@ def _group_streams(
declarative_stream=declarative_stream
)
and hasattr(declarative_stream.retriever, "stream_slicer")
- and isinstance(declarative_stream.retriever.stream_slicer, PerPartitionWithGlobalCursor)
+ and isinstance(
+ declarative_stream.retriever.stream_slicer, PerPartitionWithGlobalCursor
+ )
):
stream_state = state_manager.get_stream_state(
stream_name=declarative_stream.name, namespace=declarative_stream.namespace
)
partition_router = declarative_stream.retriever.stream_slicer._partition_router

cursor = self._constructor.create_concurrent_cursor_from_perpartition_cursor(
- state_manager=state_manager,
- model_type=DatetimeBasedCursorModel,
- component_definition=incremental_sync_component_definition,
- stream_name=declarative_stream.name,
- stream_namespace=declarative_stream.namespace,
- config=config or {},
- stream_state=stream_state,
- partition_router=partition_router,
- )

+ state_manager=state_manager,
+ model_type=DatetimeBasedCursorModel,
+ component_definition=incremental_sync_component_definition,
+ stream_name=declarative_stream.name,
+ stream_namespace=declarative_stream.namespace,
+ config=config or {},
+ stream_state=stream_state,
+ partition_router=partition_router,
+ )

partition_generator = StreamSlicerPartitionGenerator(
DeclarativePartitionFactory(
@@ -1,16 +1,19 @@
import copy
import logging

#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import threading
import logging
from collections import OrderedDict
from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional

from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
- from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import iterate_with_last_flag_and_state, Timer
from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
+ from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import (
+ Timer,
+ iterate_with_last_flag_and_state,
+ )
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
from airbyte_cdk.sources.message import MessageRepository
from airbyte_cdk.sources.streams.checkpoint.per_partition_key_serializer import (
@@ -123,26 +126,43 @@ def state(self) -> MutableMapping[str, Any]:

def close_partition(self, partition: Partition) -> None:
print(f"Closing partition {self._to_partition_key(partition._stream_slice.partition)}")
- self._cursor_per_partition[self._to_partition_key(partition._stream_slice.partition)].close_partition(partition=partition)
- with (self._lock):
- self._semaphore_per_partition[self._to_partition_key(partition._stream_slice.partition)].acquire()
- cursor = self._cursor_per_partition[self._to_partition_key(partition._stream_slice.partition)]
+ self._cursor_per_partition[
+ self._to_partition_key(partition._stream_slice.partition)
+ ].close_partition(partition=partition)
+ with self._lock:
+ self._semaphore_per_partition[
+ self._to_partition_key(partition._stream_slice.partition)
+ ].acquire()
+ cursor = self._cursor_per_partition[
+ self._to_partition_key(partition._stream_slice.partition)
+ ]
cursor_state = cursor._connector_state_converter.convert_to_state_message(
cursor._cursor_field, cursor.state
)
print(f"State {cursor_state} {cursor.state}")
- if self._to_partition_key(partition._stream_slice.partition) in self._finished_partitions \
- and self._semaphore_per_partition[self._to_partition_key(partition._stream_slice.partition)]._value == 0:
- if self._new_global_cursor is None \
- or self._new_global_cursor[self.cursor_field.cursor_field_key] < cursor_state[self.cursor_field.cursor_field_key]:
+ if (
+ self._to_partition_key(partition._stream_slice.partition)
+ in self._finished_partitions
+ and self._semaphore_per_partition[
+ self._to_partition_key(partition._stream_slice.partition)
+ ]._value
+ == 0
+ ):
+ if (
+ self._new_global_cursor is None
+ or self._new_global_cursor[self.cursor_field.cursor_field_key]
+ < cursor_state[self.cursor_field.cursor_field_key]
+ ):
self._new_global_cursor = copy.deepcopy(cursor_state)

def ensure_at_least_one_state_emitted(self) -> None:
"""
The platform expect to have at least one state message on successful syncs. Hence, whatever happens, we expect this method to be
called.
"""
- if not any(semaphore_item[1]._value for semaphore_item in self._semaphore_per_partition.items()):
+ if not any(
+ semaphore_item[1]._value for semaphore_item in self._semaphore_per_partition.items()
+ ):
self._global_cursor = self._new_global_cursor
self._lookback_window = self._timer.finish()
self._parent_state = self._partition_router.get_stream_state()
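For orientation, the code being reformatted in this hunk tracks in-flight slices with one threading.Semaphore per partition and only promotes the accumulated global cursor once every semaphore has been drained, so the final state message reflects all finished partitions. A minimal, self-contained sketch of that idea (hypothetical names, with a plain print standing in for the message repository; not the CDK class itself):

import threading
from typing import Any, Dict, Optional


class FinalStateSketch:
    # Illustrative only: drain per-partition semaphores, then promote the global cursor.
    def __init__(self) -> None:
        self._semaphore_per_partition: Dict[str, threading.Semaphore] = {}
        self._global_cursor: Optional[Dict[str, Any]] = None
        self._new_global_cursor: Optional[Dict[str, Any]] = None

    def ensure_at_least_one_state_emitted(self) -> None:
        # All partitions are finished only when no semaphore still holds a pending slice
        # (the hunk above reads the same private _value counter).
        if not any(sem._value for sem in self._semaphore_per_partition.values()):
            self._global_cursor = self._new_global_cursor
        self._emit_state_message()

    def _emit_state_message(self) -> None:
        # Stand-in for emitting an AirbyteStateMessage through a message repository.
        print({"state": self._global_cursor or {}})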
@@ -159,7 +179,6 @@ def _emit_state_message(self) -> None:
)
self._message_repository.emit_message(state_message)


def stream_slices(self) -> Iterable[StreamSlice]:
slices = self._partition_router.stream_slices()
self._timer.start()
Expand All @@ -179,11 +198,13 @@ def generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[Str
)
cursor = self._create_cursor(partition_state)
self._cursor_per_partition[self._to_partition_key(partition.partition)] = cursor
- self._semaphore_per_partition[self._to_partition_key(partition.partition)] = threading.Semaphore(0)
+ self._semaphore_per_partition[self._to_partition_key(partition.partition)] = (
+ threading.Semaphore(0)
+ )

for cursor_slice, is_last_slice, _ in iterate_with_last_flag_and_state(
- cursor.stream_slices(),
- lambda: None,
+ cursor.stream_slices(),
+ lambda: None,
):
self._semaphore_per_partition[self._to_partition_key(partition.partition)].release()
if is_last_slice:
@@ -251,7 +272,9 @@ def _set_initial_state(self, stream_state: StreamState) -> None:
self._cursor_per_partition[self._to_partition_key(state["partition"])] = (
self._create_cursor(state["cursor"])
)
- self._semaphore_per_partition[self._to_partition_key(state["partition"])] = threading.Semaphore(0)
+ self._semaphore_per_partition[self._to_partition_key(state["partition"])] = (
+ threading.Semaphore(0)
+ )

# set default state for missing partitions if it is per partition with fallback to global
if "state" in stream_state:
@@ -262,7 +285,9 @@ def _set_initial_state(self, stream_state: StreamState) -> None:

def observe(self, record: Record) -> None:
print(self._to_partition_key(record.associated_slice.partition), record)
- self._cursor_per_partition[self._to_partition_key(record.associated_slice.partition)].observe(record)
+ self._cursor_per_partition[
+ self._to_partition_key(record.associated_slice.partition)
+ ].observe(record)

def _to_partition_key(self, partition: Mapping[str, Any]) -> str:
return self._partition_serializer.to_partition_key(partition)
@@ -303,15 +303,20 @@ def get_request_body_json(
raise ValueError("A partition needs to be provided in order to get request body json")

def should_be_synced(self, record: Record) -> bool:
- if self._to_partition_key(record.associated_slice.partition) not in self._cursor_per_partition:
+ if (
+ self._to_partition_key(record.associated_slice.partition)
+ not in self._cursor_per_partition
+ ):
partition_state = (
self._state_to_migrate_from
if self._state_to_migrate_from
else self._NO_CURSOR_STATE
)
cursor = self._create_cursor(partition_state)

- self._cursor_per_partition[self._to_partition_key(record.associated_slice.partition)] = cursor
+ self._cursor_per_partition[
+ self._to_partition_key(record.associated_slice.partition)
+ ] = cursor
return self._get_cursor(record).should_be_synced(
self._convert_record_to_cursor_record(record)
)
@@ -969,7 +969,7 @@ def create_concurrent_cursor_from_perpartition_cursor(
config: Config,
stream_state: MutableMapping[str, Any],
partition_router,
- **kwargs: Any,
+ **kwargs: Any,
) -> ConcurrentPerPartitionCursor:
component_type = component_definition.get("type")
if component_definition.get("type") != model_type.__name__:
@@ -1000,21 +1000,21 @@ def create_concurrent_cursor_from_perpartition_cursor(
stream_name=stream_name,
stream_namespace=stream_namespace,
config=config,
- message_repository=NoopMessageRepository()
+ message_repository=NoopMessageRepository(),
)
)

# Return the concurrent cursor and state converter
return ConcurrentPerPartitionCursor(
- cursor_factory=cursor_factory,
- partition_router=partition_router,
- stream_name=stream_name,
- stream_namespace=stream_namespace,
- stream_state=stream_state,
- message_repository=self._message_repository, # type: ignore
- connector_state_manager=state_manager,
- cursor_field=cursor_field,
- )
+ cursor_factory=cursor_factory,
+ partition_router=partition_router,
+ stream_name=stream_name,
+ stream_namespace=stream_namespace,
+ stream_state=stream_state,
+ message_repository=self._message_repository, # type: ignore
+ connector_state_manager=state_manager,
+ cursor_field=cursor_field,
+ )

@staticmethod
def create_constant_backoff_strategy(
@@ -1298,15 +1298,15 @@ def create_declarative_stream(
raise ValueError(
"Unsupported Slicer is used. PerPartitionWithGlobalCursor should be used here instead"
)
- cursor = combined_slicers if isinstance(
- combined_slicers, (PerPartitionWithGlobalCursor, GlobalSubstreamCursor)
- ) else self._create_component_from_model(
- model=model.incremental_sync, config=config
- )
+ cursor = (
+ combined_slicers
+ if isinstance(
+ combined_slicers, (PerPartitionWithGlobalCursor, GlobalSubstreamCursor)
+ )
+ else self._create_component_from_model(model=model.incremental_sync, config=config)
+ )

- client_side_incremental_sync = {
- "cursor": cursor
- }
+ client_side_incremental_sync = {"cursor": cursor}

if model.incremental_sync and isinstance(model.incremental_sync, DatetimeBasedCursorModel):
cursor_model = model.incremental_sync
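For reference, the factory method reformatted above is the one the first file in this commit invokes when grouping concurrent streams. A hedged sketch of that call shape, wrapped in a hypothetical helper so the required arguments are explicit (all parameter objects are built elsewhere in the CDK; this is not code from the commit):

from typing import Any, Mapping, MutableMapping, Optional

# Hypothetical helper; mirrors the call shown in concurrent_declarative_source.py above.
def build_per_partition_cursor(
    constructor: Any,
    state_manager: Any,
    model_type: Any,
    component_definition: Mapping[str, Any],
    stream_name: str,
    stream_namespace: Optional[str],
    config: Mapping[str, Any],
    stream_state: MutableMapping[str, Any],
    partition_router: Any,
) -> Any:
    return constructor.create_concurrent_cursor_from_perpartition_cursor(
        state_manager=state_manager,
        model_type=model_type,
        component_definition=component_definition,
        stream_name=stream_name,
        stream_namespace=stream_namespace,
        config=config or {},
        stream_state=stream_state,
        partition_router=partition_router,
    )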
@@ -38,6 +38,7 @@ def create(self, stream_slice: StreamSlice) -> Partition:
stream_slice,
)


class DeclarativePartition(Partition):
def __init__(
self,
@@ -1,8 +1,8 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.

+ import copy
from typing import Any, List, Mapping, MutableMapping, Optional, Union
from unittest.mock import MagicMock
- import copy

import pytest
import requests_mock
@@ -264,9 +264,7 @@ def _run_read(
source = ConcurrentDeclarativeSource(
source_config=manifest, config=config, catalog=catalog, state=state
)
- messages = list(
- source.read(logger=source.logger, config=config, catalog=catalog, state=[])
- )
+ messages = list(source.read(logger=source.logger, config=config, catalog=catalog, state=[]))
return messages


@@ -474,11 +472,9 @@ def _run_read(
"cursor": {"created_at": "2024-01-10T00:00:00Z"},
},
],
- 'lookback_window': 1,
- 'parent_state': {},
- 'state': {'created_at': '2024-01-15T00:00:00Z'}
+ "lookback_window": 1,
+ "parent_state": {},
+ "state": {"created_at": "2024-01-15T00:00:00Z"},
},
),
],
@@ -520,7 +516,9 @@ def test_incremental_parent_state_no_incremental_dependency(
output = _run_read(manifest, config, _stream_name, initial_state)
output_data = [message.record.data for message in output if message.record]

- assert set(tuple(sorted(d.items())) for d in output_data) == set(tuple(sorted(d.items())) for d in expected_records)
+ assert set(tuple(sorted(d.items())) for d in output_data) == set(
+ tuple(sorted(d.items())) for d in expected_records
+ )
final_state = [
orjson.loads(orjson.dumps(message.state.stream.stream_state))
for message in output
@@ -565,8 +563,9 @@ def run_incremental_parent_state_test(
output_data = [message.record.data for message in output if message.record]

# Assert that output_data equals expected_records
- assert (sorted(output_data, key=lambda x: orjson.dumps(x))
- == sorted(expected_records, key=lambda x: orjson.dumps(x)))
+ assert sorted(output_data, key=lambda x: orjson.dumps(x)) == sorted(
+ expected_records, key=lambda x: orjson.dumps(x)
+ )

# Collect the intermediate states and records produced before each state
cumulative_records = []
@@ -884,8 +883,8 @@ def run_incremental_parent_state_test(
"cursor": {"created_at": "2024-01-13T00:00:00Z"},
},
{
- 'partition': {'id': 12, 'parent_slice': {'id': 1, 'parent_slice': {}}},
- 'cursor': {'created_at': '2024-01-01T00:00:01Z'},
+ "partition": {"id": 12, "parent_slice": {"id": 1, "parent_slice": {}}},
+ "cursor": {"created_at": "2024-01-01T00:00:01Z"},
},
{
"partition": {"id": 20, "parent_slice": {"id": 2, "parent_slice": {}}},
@@ -1141,7 +1140,9 @@ def test_incremental_parent_state_migration(
output = _run_read(manifest, config, _stream_name, initial_state)
output_data = [message.record.data for message in output if message.record]

- assert set(tuple(sorted(d.items())) for d in output_data) == set(tuple(sorted(d.items())) for d in expected_records)
+ assert set(tuple(sorted(d.items())) for d in output_data) == set(
+ tuple(sorted(d.items())) for d in expected_records
+ )
final_state = [
orjson.loads(orjson.dumps(message.state.stream.stream_state))
for message in output
