Skip to content

Commit

Permalink
refactor async retriever to use a AsyncJobStreamPartitionRouter to fo…
Browse files Browse the repository at this point in the history
…llow a more standard low-code pattern
  • Loading branch information
brianjlai committed Dec 13, 2024
1 parent 580f60e commit a5321da
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,9 @@
SinglePartitionRouter,
SubstreamPartitionRouter,
)
from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import (
AsyncJobPartitionRouter,
)
from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import (
ParentStreamConfig,
)
Expand Down Expand Up @@ -2146,22 +2149,28 @@ def create_async_retriever(
urls_extractor=urls_extractor,
)

return AsyncRetriever(
async_job_partition_router = AsyncJobPartitionRouter(
job_orchestrator_factory=lambda stream_slices: AsyncJobOrchestrator(
job_repository,
stream_slices,
JobTracker(
1
), # FIXME eventually make the number of concurrent jobs in the API configurable. Until then, we limit to 1
JobTracker(1),
# FIXME eventually make the number of concurrent jobs in the API configurable. Until then, we limit to 1
self._message_repository,
has_bulk_parent=False, # FIXME work would need to be done here in order to detect if a stream as a parent stream that is bulk
has_bulk_parent=False,
# FIXME work would need to be done here in order to detect if a stream as a parent stream that is bulk
),
record_selector=record_selector,
stream_slicer=stream_slicer,
config=config,
parameters=model.parameters or {},
)

return AsyncRetriever(
record_selector=record_selector,
stream_slicer=async_job_partition_router,
config=config,
parameters=model.parameters or {},
)

@staticmethod
def create_spec(model: SpecModel, config: Config, **kwargs: Any) -> Spec:
return Spec(
Expand Down
10 changes: 9 additions & 1 deletion airbyte_cdk/sources/declarative/partition_routers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,18 @@
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import AsyncJobPartitionRouter
from airbyte_cdk.sources.declarative.partition_routers.cartesian_product_stream_slicer import CartesianProductStreamSlicer
from airbyte_cdk.sources.declarative.partition_routers.list_partition_router import ListPartitionRouter
from airbyte_cdk.sources.declarative.partition_routers.single_partition_router import SinglePartitionRouter
from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import SubstreamPartitionRouter
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter

__all__ = ["CartesianProductStreamSlicer", "ListPartitionRouter", "SinglePartitionRouter", "SubstreamPartitionRouter", "PartitionRouter"]
__all__ = [
"AsyncJobPartitionRouter",
"CartesianProductStreamSlicer",
"ListPartitionRouter",
"SinglePartitionRouter",
"SubstreamPartitionRouter",
"PartitionRouter"
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.

from dataclasses import InitVar, dataclass, field
from typing import Any, Callable, Iterable, Mapping, Optional

from airbyte_cdk.models import FailureType
from airbyte_cdk.sources.declarative.async_job.job_orchestrator import (
AsyncJobOrchestrator,
AsyncPartition,
)
from airbyte_cdk.sources.declarative.partition_routers.single_partition_router import (
SinglePartitionRouter,
)
from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import StreamSlicer
from airbyte_cdk.sources.types import Config, StreamSlice
from airbyte_cdk.utils.traced_exception import AirbyteTracedException


@dataclass
class AsyncJobPartitionRouter(StreamSlicer):
"""
Partition router that creates async jobs in a source API, periodically polls for job
completion, and supplies the completed job URL locations as stream slices so that
records can be extracted.
"""

config: Config
parameters: InitVar[Mapping[str, Any]]
job_orchestrator_factory: Callable[[Iterable[StreamSlice]], AsyncJobOrchestrator]
stream_slicer: StreamSlicer = field(
default_factory=lambda: SinglePartitionRouter(parameters={})
)

def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._job_orchestrator_factory = self.job_orchestrator_factory
self._job_orchestrator: Optional[AsyncJobOrchestrator] = None
self._parameters = parameters

def stream_slices(self) -> Iterable[StreamSlice]:
slices = self.stream_slicer.stream_slices()
self._job_orchestrator = self._job_orchestrator_factory(slices)

for completed_partition in self._job_orchestrator.create_and_get_completed_partitions():
yield StreamSlice(
partition=dict(completed_partition.stream_slice.partition)
| {"partition": completed_partition},
cursor_slice=completed_partition.stream_slice.cursor_slice,
)

def fetch_records(self, partition: AsyncPartition) -> Iterable[Mapping[str, Any]]:
if not self._job_orchestrator:
raise AirbyteTracedException(
message="Invalid state within AsyncJobRetriever. Please contact Airbyte Support",
internal_message="AsyncPartitionRepository is expected to be accessed only after `stream_slices`",
failure_type=FailureType.system_error,
)

return self._job_orchestrator.fetch_records(partition=partition)
39 changes: 8 additions & 31 deletions airbyte_cdk/sources/declarative/retrievers/async_retriever.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.


from dataclasses import InitVar, dataclass, field
from typing import Any, Callable, Iterable, Mapping, Optional
from dataclasses import InitVar, dataclass
from typing import Any, Iterable, Mapping, Optional

from typing_extensions import deprecated

Expand All @@ -12,9 +12,10 @@
AsyncPartition,
)
from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector
from airbyte_cdk.sources.declarative.partition_routers import SinglePartitionRouter
from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import (
AsyncJobPartitionRouter,
)
from airbyte_cdk.sources.declarative.retrievers import Retriever
from airbyte_cdk.sources.declarative.stream_slicers import StreamSlicer
from airbyte_cdk.sources.source import ExperimentalClassWarning
from airbyte_cdk.sources.streams.core import StreamData
from airbyte_cdk.sources.types import Config, StreamSlice, StreamState
Expand All @@ -29,15 +30,10 @@
class AsyncRetriever(Retriever):
config: Config
parameters: InitVar[Mapping[str, Any]]
job_orchestrator_factory: Callable[[Iterable[StreamSlice]], AsyncJobOrchestrator]
record_selector: RecordSelector
stream_slicer: StreamSlicer = field(
default_factory=lambda: SinglePartitionRouter(parameters={})
)
stream_slicer: AsyncJobPartitionRouter

def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._job_orchestrator_factory = self.job_orchestrator_factory
self.__job_orchestrator: Optional[AsyncJobOrchestrator] = None
self._parameters = parameters

@property
Expand All @@ -54,17 +50,6 @@ def state(self, value: StreamState) -> None:
"""
pass

@property
def _job_orchestrator(self) -> AsyncJobOrchestrator:
if not self.__job_orchestrator:
raise AirbyteTracedException(
message="Invalid state within AsyncJobRetriever. Please contact Airbyte Support",
internal_message="AsyncPartitionRepository is expected to be accessed only after `stream_slices`",
failure_type=FailureType.system_error,
)

return self.__job_orchestrator

def _get_stream_state(self) -> StreamState:
"""
Gets the current state of the stream.
Expand Down Expand Up @@ -99,15 +84,7 @@ def _validate_and_get_stream_slice_partition(
return stream_slice["partition"] # type: ignore # stream_slice["partition"] has been added as an AsyncPartition as part of stream_slices

def stream_slices(self) -> Iterable[Optional[StreamSlice]]:
slices = self.stream_slicer.stream_slices()
self.__job_orchestrator = self._job_orchestrator_factory(slices)

for completed_partition in self._job_orchestrator.create_and_get_completed_partitions():
yield StreamSlice(
partition=dict(completed_partition.stream_slice.partition)
| {"partition": completed_partition},
cursor_slice=completed_partition.stream_slice.cursor_slice,
)
return self.stream_slicer.stream_slices()

def read_records(
self,
Expand All @@ -116,7 +93,7 @@ def read_records(
) -> Iterable[StreamData]:
stream_state: StreamState = self._get_stream_state()
partition: AsyncPartition = self._validate_and_get_stream_slice_partition(stream_slice)
records: Iterable[Mapping[str, Any]] = self._job_orchestrator.fetch_records(partition)
records: Iterable[Mapping[str, Any]] = self.stream_slicer.fetch_records(partition)

yield from self.record_selector.filter_and_transform(
all_data=records,
Expand Down
19 changes: 13 additions & 6 deletions unit_tests/sources/declarative/async_job/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from airbyte_cdk.sources.declarative.async_job.status import AsyncJobStatus
from airbyte_cdk.sources.declarative.extractors.record_extractor import RecordExtractor
from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector
from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import (
AsyncJobPartitionRouter,
)
from airbyte_cdk.sources.declarative.retrievers.async_retriever import AsyncRetriever
from airbyte_cdk.sources.declarative.schema import InlineSchemaLoader
from airbyte_cdk.sources.declarative.stream_slicers import StreamSlicer
Expand Down Expand Up @@ -79,12 +82,16 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
config={},
parameters={},
record_selector=noop_record_selector,
stream_slicer=self._stream_slicer,
job_orchestrator_factory=lambda stream_slices: AsyncJobOrchestrator(
MockAsyncJobRepository(),
stream_slices,
JobTracker(_NO_LIMIT),
self._message_repository,
stream_slicer=AsyncJobPartitionRouter(
stream_slicer=self._stream_slicer,
job_orchestrator_factory=lambda stream_slices: AsyncJobOrchestrator(
MockAsyncJobRepository(),
stream_slices,
JobTracker(_NO_LIMIT),
self._message_repository,
),
config={},
parameters={},
),
),
config={},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from airbyte_cdk import AirbyteTracedException
from airbyte_cdk.models import FailureType, Level
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
from airbyte_cdk.sources.declarative.async_job.job_orchestrator import AsyncJobOrchestrator
from airbyte_cdk.sources.declarative.auth import DeclarativeOauth2Authenticator, JwtAuthenticator
from airbyte_cdk.sources.declarative.auth.token import (
ApiKeyAuthenticator,
Expand All @@ -40,6 +41,7 @@
ResumableFullRefreshCursor,
)
from airbyte_cdk.sources.declarative.interpolation import InterpolatedString
from airbyte_cdk.sources.declarative.models import AsyncRetriever as AsyncRetrieverModel
from airbyte_cdk.sources.declarative.models import CheckStream as CheckStreamModel
from airbyte_cdk.sources.declarative.models import (
CompositeErrorHandler as CompositeErrorHandlerModel,
Expand Down Expand Up @@ -85,6 +87,7 @@
ModelToComponentFactory,
)
from airbyte_cdk.sources.declarative.partition_routers import (
AsyncJobPartitionRouter,
CartesianProductStreamSlicer,
ListPartitionRouter,
SinglePartitionRouter,
Expand All @@ -102,6 +105,7 @@
WaitTimeFromHeaderBackoffStrategy,
WaitUntilTimeFromHeaderBackoffStrategy,
)
from airbyte_cdk.sources.declarative.requesters.http_job_repository import AsyncHttpJobRepository
from airbyte_cdk.sources.declarative.requesters.paginators import DefaultPaginator
from airbyte_cdk.sources.declarative.requesters.paginators.strategies import (
CursorPaginationStrategy,
Expand All @@ -121,6 +125,7 @@
from airbyte_cdk.sources.declarative.requesters.request_path import RequestPath
from airbyte_cdk.sources.declarative.requesters.requester import HttpMethod
from airbyte_cdk.sources.declarative.retrievers import (
AsyncRetriever,
SimpleRetriever,
SimpleRetrieverTestReadDecorator,
)
Expand All @@ -138,6 +143,7 @@
from airbyte_cdk.sources.streams.http.requests_native_auth.oauth import (
SingleUseRefreshTokenOauth2Authenticator,
)
from airbyte_cdk.sources.types import StreamSlice
from unit_tests.sources.declarative.parsers.testing_components import (
TestingCustomSubstreamPartitionRouter,
TestingSomeComponent,
Expand Down Expand Up @@ -3294,3 +3300,97 @@ def test_create_custom_record_extractor():
}
component = factory.create_component(CustomRecordExtractorModel, definition, {})
assert isinstance(component, CustomRecordExtractor)


def test_create_async_retriever():
config = {"api_key": "123"}

definition = {
"type": "AsyncRetriever",
"status_mapping": {
"failed": ["failed"],
"running": ["pending"],
"timeout": ["timeout"],
"completed": ["ready"],
},
"urls_extractor": {"type": "DpathExtractor", "field_path": ["urls"]},
"record_selector": {
"type": "RecordSelector",
"extractor": {"type": "DpathExtractor", "field_path": ["data"]},
},
"status_extractor": {"type": "DpathExtractor", "field_path": ["status"]},
"polling_requester": {
"type": "HttpRequester",
"path": "/v3/marketing/contacts/exports/{{stream_slice['create_job_response'].json()['id'] }}",
"url_base": "https://api.sendgrid.com",
"http_method": "GET",
"authenticator": {
"type": "BearerAuthenticator",
"api_token": "{{ config['api_key'] }}",
},
},
"creation_requester": {
"type": "HttpRequester",
"path": "/v3/marketing/contacts/exports",
"url_base": "https://api.sendgrid.com",
"http_method": "POST",
"authenticator": {
"type": "BearerAuthenticator",
"api_token": "{{ config['api_key'] }}",
},
},
"download_requester": {
"type": "HttpRequester",
"path": "{{stream_slice['url']}}",
"url_base": "",
"http_method": "GET",
},
"abort_requester": {
"type": "HttpRequester",
"path": "{{stream_slice['url']}}/abort",
"url_base": "",
"http_method": "POST",
},
"delete_requester": {
"type": "HttpRequester",
"path": "{{stream_slice['url']}}",
"url_base": "",
"http_method": "POST",
},
}

component = factory.create_component(
model_type=AsyncRetrieverModel,
component_definition=definition,
name="test_stream",
primary_key="id",
stream_slicer=None,
transformations=[],
config=config,
)

assert isinstance(component, AsyncRetriever)

async_job_partition_router = component.stream_slicer
assert isinstance(async_job_partition_router, AsyncJobPartitionRouter)
assert isinstance(async_job_partition_router.stream_slicer, SinglePartitionRouter)
job_orchestrator = async_job_partition_router.job_orchestrator_factory(
[StreamSlice(partition={}, cursor_slice={})]
)
assert isinstance(job_orchestrator, AsyncJobOrchestrator)

job_repository = job_orchestrator._job_repository
assert isinstance(job_repository, AsyncHttpJobRepository)
assert job_repository.creation_requester
assert job_repository.polling_requester
assert job_repository.download_retriever
assert job_repository.abort_requester
assert job_repository.delete_requester
assert job_repository.status_extractor
assert job_repository.urls_extractor

selector = component.record_selector
extractor = selector.extractor
assert isinstance(selector, RecordSelector)
assert isinstance(extractor, DpathExtractor)
assert extractor.field_path == ["data"]

0 comments on commit a5321da

Please sign in to comment.