diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py index 9a405f8f2..b62886ac6 100644 --- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py @@ -308,6 +308,9 @@ SinglePartitionRouter, SubstreamPartitionRouter, ) +from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import ( + AsyncJobPartitionRouter, +) from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import ( ParentStreamConfig, ) @@ -2146,22 +2149,28 @@ def create_async_retriever( urls_extractor=urls_extractor, ) - return AsyncRetriever( + async_job_partition_router = AsyncJobPartitionRouter( job_orchestrator_factory=lambda stream_slices: AsyncJobOrchestrator( job_repository, stream_slices, - JobTracker( - 1 - ), # FIXME eventually make the number of concurrent jobs in the API configurable. Until then, we limit to 1 + JobTracker(1), + # FIXME eventually make the number of concurrent jobs in the API configurable. Until then, we limit to 1 self._message_repository, - has_bulk_parent=False, # FIXME work would need to be done here in order to detect if a stream as a parent stream that is bulk + has_bulk_parent=False, + # FIXME work would need to be done here in order to detect if a stream as a parent stream that is bulk ), - record_selector=record_selector, stream_slicer=stream_slicer, config=config, parameters=model.parameters or {}, ) + return AsyncRetriever( + record_selector=record_selector, + stream_slicer=async_job_partition_router, + config=config, + parameters=model.parameters or {}, + ) + @staticmethod def create_spec(model: SpecModel, config: Config, **kwargs: Any) -> Spec: return Spec( diff --git a/airbyte_cdk/sources/declarative/partition_routers/__init__.py b/airbyte_cdk/sources/declarative/partition_routers/__init__.py index 9487f5e1d..4e51ff657 100644 --- a/airbyte_cdk/sources/declarative/partition_routers/__init__.py +++ b/airbyte_cdk/sources/declarative/partition_routers/__init__.py @@ -2,10 +2,18 @@ # Copyright (c) 2022 Airbyte, Inc., all rights reserved. # +from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import AsyncJobPartitionRouter from airbyte_cdk.sources.declarative.partition_routers.cartesian_product_stream_slicer import CartesianProductStreamSlicer from airbyte_cdk.sources.declarative.partition_routers.list_partition_router import ListPartitionRouter from airbyte_cdk.sources.declarative.partition_routers.single_partition_router import SinglePartitionRouter from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import SubstreamPartitionRouter from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter -__all__ = ["CartesianProductStreamSlicer", "ListPartitionRouter", "SinglePartitionRouter", "SubstreamPartitionRouter", "PartitionRouter"] +__all__ = [ + "AsyncJobPartitionRouter", + "CartesianProductStreamSlicer", + "ListPartitionRouter", + "SinglePartitionRouter", + "SubstreamPartitionRouter", + "PartitionRouter" +] diff --git a/airbyte_cdk/sources/declarative/partition_routers/async_job_partition_router.py b/airbyte_cdk/sources/declarative/partition_routers/async_job_partition_router.py new file mode 100644 index 000000000..8e4b5d455 --- /dev/null +++ b/airbyte_cdk/sources/declarative/partition_routers/async_job_partition_router.py @@ -0,0 +1,58 @@ +# Copyright (c) 2024 Airbyte, Inc., all rights reserved. + +from dataclasses import InitVar, dataclass, field +from typing import Any, Callable, Iterable, Mapping, Optional + +from airbyte_cdk.models import FailureType +from airbyte_cdk.sources.declarative.async_job.job_orchestrator import ( + AsyncJobOrchestrator, + AsyncPartition, +) +from airbyte_cdk.sources.declarative.partition_routers.single_partition_router import ( + SinglePartitionRouter, +) +from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import StreamSlicer +from airbyte_cdk.sources.types import Config, StreamSlice +from airbyte_cdk.utils.traced_exception import AirbyteTracedException + + +@dataclass +class AsyncJobPartitionRouter(StreamSlicer): + """ + Partition router that creates async jobs in a source API, periodically polls for job + completion, and supplies the completed job URL locations as stream slices so that + records can be extracted. + """ + + config: Config + parameters: InitVar[Mapping[str, Any]] + job_orchestrator_factory: Callable[[Iterable[StreamSlice]], AsyncJobOrchestrator] + stream_slicer: StreamSlicer = field( + default_factory=lambda: SinglePartitionRouter(parameters={}) + ) + + def __post_init__(self, parameters: Mapping[str, Any]) -> None: + self._job_orchestrator_factory = self.job_orchestrator_factory + self._job_orchestrator: Optional[AsyncJobOrchestrator] = None + self._parameters = parameters + + def stream_slices(self) -> Iterable[StreamSlice]: + slices = self.stream_slicer.stream_slices() + self._job_orchestrator = self._job_orchestrator_factory(slices) + + for completed_partition in self._job_orchestrator.create_and_get_completed_partitions(): + yield StreamSlice( + partition=dict(completed_partition.stream_slice.partition) + | {"partition": completed_partition}, + cursor_slice=completed_partition.stream_slice.cursor_slice, + ) + + def fetch_records(self, partition: AsyncPartition) -> Iterable[Mapping[str, Any]]: + if not self._job_orchestrator: + raise AirbyteTracedException( + message="Invalid state within AsyncJobRetriever. Please contact Airbyte Support", + internal_message="AsyncPartitionRepository is expected to be accessed only after `stream_slices`", + failure_type=FailureType.system_error, + ) + + return self._job_orchestrator.fetch_records(partition=partition) diff --git a/airbyte_cdk/sources/declarative/retrievers/async_retriever.py b/airbyte_cdk/sources/declarative/retrievers/async_retriever.py index 3d9a3ead9..d75237b03 100644 --- a/airbyte_cdk/sources/declarative/retrievers/async_retriever.py +++ b/airbyte_cdk/sources/declarative/retrievers/async_retriever.py @@ -1,8 +1,8 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. -from dataclasses import InitVar, dataclass, field -from typing import Any, Callable, Iterable, Mapping, Optional +from dataclasses import InitVar, dataclass +from typing import Any, Iterable, Mapping, Optional from typing_extensions import deprecated @@ -12,9 +12,10 @@ AsyncPartition, ) from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector -from airbyte_cdk.sources.declarative.partition_routers import SinglePartitionRouter +from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import ( + AsyncJobPartitionRouter, +) from airbyte_cdk.sources.declarative.retrievers import Retriever -from airbyte_cdk.sources.declarative.stream_slicers import StreamSlicer from airbyte_cdk.sources.source import ExperimentalClassWarning from airbyte_cdk.sources.streams.core import StreamData from airbyte_cdk.sources.types import Config, StreamSlice, StreamState @@ -29,15 +30,10 @@ class AsyncRetriever(Retriever): config: Config parameters: InitVar[Mapping[str, Any]] - job_orchestrator_factory: Callable[[Iterable[StreamSlice]], AsyncJobOrchestrator] record_selector: RecordSelector - stream_slicer: StreamSlicer = field( - default_factory=lambda: SinglePartitionRouter(parameters={}) - ) + stream_slicer: AsyncJobPartitionRouter def __post_init__(self, parameters: Mapping[str, Any]) -> None: - self._job_orchestrator_factory = self.job_orchestrator_factory - self.__job_orchestrator: Optional[AsyncJobOrchestrator] = None self._parameters = parameters @property @@ -54,17 +50,6 @@ def state(self, value: StreamState) -> None: """ pass - @property - def _job_orchestrator(self) -> AsyncJobOrchestrator: - if not self.__job_orchestrator: - raise AirbyteTracedException( - message="Invalid state within AsyncJobRetriever. Please contact Airbyte Support", - internal_message="AsyncPartitionRepository is expected to be accessed only after `stream_slices`", - failure_type=FailureType.system_error, - ) - - return self.__job_orchestrator - def _get_stream_state(self) -> StreamState: """ Gets the current state of the stream. @@ -99,15 +84,7 @@ def _validate_and_get_stream_slice_partition( return stream_slice["partition"] # type: ignore # stream_slice["partition"] has been added as an AsyncPartition as part of stream_slices def stream_slices(self) -> Iterable[Optional[StreamSlice]]: - slices = self.stream_slicer.stream_slices() - self.__job_orchestrator = self._job_orchestrator_factory(slices) - - for completed_partition in self._job_orchestrator.create_and_get_completed_partitions(): - yield StreamSlice( - partition=dict(completed_partition.stream_slice.partition) - | {"partition": completed_partition}, - cursor_slice=completed_partition.stream_slice.cursor_slice, - ) + return self.stream_slicer.stream_slices() def read_records( self, @@ -116,7 +93,7 @@ def read_records( ) -> Iterable[StreamData]: stream_state: StreamState = self._get_stream_state() partition: AsyncPartition = self._validate_and_get_stream_slice_partition(stream_slice) - records: Iterable[Mapping[str, Any]] = self._job_orchestrator.fetch_records(partition) + records: Iterable[Mapping[str, Any]] = self.stream_slicer.fetch_records(partition) yield from self.record_selector.filter_and_transform( all_data=records, diff --git a/unit_tests/sources/declarative/async_job/test_integration.py b/unit_tests/sources/declarative/async_job/test_integration.py index be0784885..36814a508 100644 --- a/unit_tests/sources/declarative/async_job/test_integration.py +++ b/unit_tests/sources/declarative/async_job/test_integration.py @@ -20,6 +20,9 @@ from airbyte_cdk.sources.declarative.async_job.status import AsyncJobStatus from airbyte_cdk.sources.declarative.extractors.record_extractor import RecordExtractor from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector +from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import ( + AsyncJobPartitionRouter, +) from airbyte_cdk.sources.declarative.retrievers.async_retriever import AsyncRetriever from airbyte_cdk.sources.declarative.schema import InlineSchemaLoader from airbyte_cdk.sources.declarative.stream_slicers import StreamSlicer @@ -79,12 +82,16 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: config={}, parameters={}, record_selector=noop_record_selector, - stream_slicer=self._stream_slicer, - job_orchestrator_factory=lambda stream_slices: AsyncJobOrchestrator( - MockAsyncJobRepository(), - stream_slices, - JobTracker(_NO_LIMIT), - self._message_repository, + stream_slicer=AsyncJobPartitionRouter( + stream_slicer=self._stream_slicer, + job_orchestrator_factory=lambda stream_slices: AsyncJobOrchestrator( + MockAsyncJobRepository(), + stream_slices, + JobTracker(_NO_LIMIT), + self._message_repository, + ), + config={}, + parameters={}, ), ), config={}, diff --git a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py index e849af853..c50cfd521 100644 --- a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py +++ b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py @@ -14,6 +14,7 @@ from airbyte_cdk import AirbyteTracedException from airbyte_cdk.models import FailureType, Level from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager +from airbyte_cdk.sources.declarative.async_job.job_orchestrator import AsyncJobOrchestrator from airbyte_cdk.sources.declarative.auth import DeclarativeOauth2Authenticator, JwtAuthenticator from airbyte_cdk.sources.declarative.auth.token import ( ApiKeyAuthenticator, @@ -40,6 +41,7 @@ ResumableFullRefreshCursor, ) from airbyte_cdk.sources.declarative.interpolation import InterpolatedString +from airbyte_cdk.sources.declarative.models import AsyncRetriever as AsyncRetrieverModel from airbyte_cdk.sources.declarative.models import CheckStream as CheckStreamModel from airbyte_cdk.sources.declarative.models import ( CompositeErrorHandler as CompositeErrorHandlerModel, @@ -85,6 +87,7 @@ ModelToComponentFactory, ) from airbyte_cdk.sources.declarative.partition_routers import ( + AsyncJobPartitionRouter, CartesianProductStreamSlicer, ListPartitionRouter, SinglePartitionRouter, @@ -102,6 +105,7 @@ WaitTimeFromHeaderBackoffStrategy, WaitUntilTimeFromHeaderBackoffStrategy, ) +from airbyte_cdk.sources.declarative.requesters.http_job_repository import AsyncHttpJobRepository from airbyte_cdk.sources.declarative.requesters.paginators import DefaultPaginator from airbyte_cdk.sources.declarative.requesters.paginators.strategies import ( CursorPaginationStrategy, @@ -121,6 +125,7 @@ from airbyte_cdk.sources.declarative.requesters.request_path import RequestPath from airbyte_cdk.sources.declarative.requesters.requester import HttpMethod from airbyte_cdk.sources.declarative.retrievers import ( + AsyncRetriever, SimpleRetriever, SimpleRetrieverTestReadDecorator, ) @@ -138,6 +143,7 @@ from airbyte_cdk.sources.streams.http.requests_native_auth.oauth import ( SingleUseRefreshTokenOauth2Authenticator, ) +from airbyte_cdk.sources.types import StreamSlice from unit_tests.sources.declarative.parsers.testing_components import ( TestingCustomSubstreamPartitionRouter, TestingSomeComponent, @@ -3294,3 +3300,97 @@ def test_create_custom_record_extractor(): } component = factory.create_component(CustomRecordExtractorModel, definition, {}) assert isinstance(component, CustomRecordExtractor) + + +def test_create_async_retriever(): + config = {"api_key": "123"} + + definition = { + "type": "AsyncRetriever", + "status_mapping": { + "failed": ["failed"], + "running": ["pending"], + "timeout": ["timeout"], + "completed": ["ready"], + }, + "urls_extractor": {"type": "DpathExtractor", "field_path": ["urls"]}, + "record_selector": { + "type": "RecordSelector", + "extractor": {"type": "DpathExtractor", "field_path": ["data"]}, + }, + "status_extractor": {"type": "DpathExtractor", "field_path": ["status"]}, + "polling_requester": { + "type": "HttpRequester", + "path": "/v3/marketing/contacts/exports/{{stream_slice['create_job_response'].json()['id'] }}", + "url_base": "https://api.sendgrid.com", + "http_method": "GET", + "authenticator": { + "type": "BearerAuthenticator", + "api_token": "{{ config['api_key'] }}", + }, + }, + "creation_requester": { + "type": "HttpRequester", + "path": "/v3/marketing/contacts/exports", + "url_base": "https://api.sendgrid.com", + "http_method": "POST", + "authenticator": { + "type": "BearerAuthenticator", + "api_token": "{{ config['api_key'] }}", + }, + }, + "download_requester": { + "type": "HttpRequester", + "path": "{{stream_slice['url']}}", + "url_base": "", + "http_method": "GET", + }, + "abort_requester": { + "type": "HttpRequester", + "path": "{{stream_slice['url']}}/abort", + "url_base": "", + "http_method": "POST", + }, + "delete_requester": { + "type": "HttpRequester", + "path": "{{stream_slice['url']}}", + "url_base": "", + "http_method": "POST", + }, + } + + component = factory.create_component( + model_type=AsyncRetrieverModel, + component_definition=definition, + name="test_stream", + primary_key="id", + stream_slicer=None, + transformations=[], + config=config, + ) + + assert isinstance(component, AsyncRetriever) + + async_job_partition_router = component.stream_slicer + assert isinstance(async_job_partition_router, AsyncJobPartitionRouter) + assert isinstance(async_job_partition_router.stream_slicer, SinglePartitionRouter) + job_orchestrator = async_job_partition_router.job_orchestrator_factory( + [StreamSlice(partition={}, cursor_slice={})] + ) + assert isinstance(job_orchestrator, AsyncJobOrchestrator) + + job_repository = job_orchestrator._job_repository + assert isinstance(job_repository, AsyncHttpJobRepository) + assert job_repository.creation_requester + assert job_repository.polling_requester + assert job_repository.download_retriever + assert job_repository.abort_requester + assert job_repository.delete_requester + assert job_repository.status_extractor + assert job_repository.urls_extractor + + selector = component.record_selector + extractor = selector.extractor + assert isinstance(selector, RecordSelector) + assert isinstance(extractor, DpathExtractor) + assert extractor.field_path == ["data"]