diff --git a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py
index 85bce965d..729e9001a 100644
--- a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py
+++ b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py
@@ -56,8 +56,9 @@ class ConcurrentDeclarativeSource(ManifestDeclarativeSource, Generic[TState]):
-    # By default, we defer to a value of 1 which represents running a connector using the Concurrent CDK engine on only one thread.
-    SINGLE_THREADED_CONCURRENCY_LEVEL = 1
+    # By default, we defer to a value of 2. A value lower than that could cause a PartitionEnqueuer to be stuck in a deadlock
+    # because it has hit the limit of futures while no partition reader is consuming them.
+    SINGLE_THREADED_CONCURRENCY_LEVEL = 2

     def __init__(
         self,
@@ -78,6 +79,9 @@ def __init__(
             emit_connector_builder_messages=emit_connector_builder_messages,
             disable_resumable_full_refresh=True,
         )
+        self._config = config
+        self._concurrent_streams: Optional[List[AbstractStream]] = None
+        self._synchronous_streams: Optional[List[Stream]] = None

         super().__init__(
             source_config=source_config,
@@ -88,21 +92,6 @@ def __init__(
         self._state = state

-        self._concurrent_streams: Optional[List[AbstractStream]]
-        self._synchronous_streams: Optional[List[Stream]]
-
-        # If the connector command was SPEC, there is no incoming config, and we cannot instantiate streams because
-        # they might depend on it. Ideally we want to have a static method on this class to get the spec without
-        # any other arguments, but the existing entrypoint.py isn't designed to support this. Just noting this
-        # for our future improvements to the CDK.
-        if config:
-            self._concurrent_streams, self._synchronous_streams = self._group_streams(
-                config=config or {}
-            )
-        else:
-            self._concurrent_streams = None
-            self._synchronous_streams = None
-
         concurrency_level_from_manifest = self._source_config.get("concurrency_level")
         if concurrency_level_from_manifest:
             concurrency_level_component = self._constructor.create_component(
@@ -121,7 +110,7 @@ def __init__(
             )  # Partition_generation iterates using range based on this value. If this is floored to zero we end up in a dead lock during start up
         else:
             concurrency_level = self.SINGLE_THREADED_CONCURRENCY_LEVEL
-            initial_number_of_partitions_to_generate = self.SINGLE_THREADED_CONCURRENCY_LEVEL
+            initial_number_of_partitions_to_generate = self.SINGLE_THREADED_CONCURRENCY_LEVEL // 2

         self._concurrent_source = ConcurrentSource.create(
             num_workers=concurrency_level,
@@ -131,6 +120,19 @@ def __init__(
             message_repository=self.message_repository,  # type: ignore  # message_repository is always instantiated with a value by factory
         )

+    def _actually_group(self) -> None:
+        # If the connector command was SPEC, there is no incoming config, and we cannot instantiate streams because
+        # they might depend on it. Ideally we want to have a static method on this class to get the spec without
+        # any other arguments, but the existing entrypoint.py isn't designed to support this. Just noting this
+        # for our future improvements to the CDK.
+        if self._config:
+            self._concurrent_streams, self._synchronous_streams = self._group_streams(
+                config=self._config or {}
+            )
+        else:
+            self._concurrent_streams = None
+            self._synchronous_streams = None
+
     def read(
         self,
         logger: logging.Logger,
@@ -140,6 +142,9 @@ def read(
     ) -> Iterator[AirbyteMessage]:
         # ConcurrentReadProcessor pops streams that are finished being read so before syncing, the names of the concurrent
         # streams must be saved so that they can be removed from the catalog before starting synchronous streams
+        if self._concurrent_streams is None:
+            self._actually_group()
+
         if self._concurrent_streams:
             concurrent_stream_names = set(
                 [concurrent_stream.name for concurrent_stream in self._concurrent_streams]
             )
@@ -165,6 +170,9 @@ def read(
         yield from super().read(logger, config, filtered_catalog, state)

     def discover(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteCatalog:
+        if self._concurrent_streams is None:
+            self._actually_group()
+
         concurrent_streams = self._concurrent_streams or []
         synchronous_streams = self._synchronous_streams or []
         return AirbyteCatalog(
@@ -193,7 +201,7 @@ def _group_streams(
         state_manager = ConnectorStateManager(state=self._state)  # type: ignore  # state is always in the form of List[AirbyteStateMessage]. The ConnectorStateManager should use generics, but this can be done later
         name_to_stream_mapping = {
-            stream["name"]: stream for stream in self.resolved_manifest["streams"]
+            stream["name"]: stream for stream in self._initialize_cache_for_parent_streams(self.resolved_manifest["streams"])
         }

         for declarative_stream in self.streams(config=config):
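The change above swaps the eager stream grouping in `__init__` for on-demand grouping driven by `read()` and `discover()`. A minimal sketch of that lazy-initialization pattern, using hypothetical names rather than the CDK's real classes:

```python
from typing import Any, List, Mapping, Optional, Tuple


class LazilyGroupedSource:
    """Hypothetical source: defer config-dependent work until a command needs it."""

    def __init__(self, config: Optional[Mapping[str, Any]]) -> None:
        # SPEC runs without a config, so streams cannot be built here.
        self._config = config
        self._concurrent: Optional[List[Any]] = None
        self._synchronous: Optional[List[Any]] = None

    def _group_streams(self, config: Mapping[str, Any]) -> Tuple[List[Any], List[Any]]:
        return [], []  # stand-in for the real, config-dependent grouping logic

    def _ensure_grouped(self) -> None:
        # Group at most once, and only when a config is actually available.
        if self._concurrent is None and self._config:
            self._concurrent, self._synchronous = self._group_streams(self._config)

    def read(self) -> None:
        self._ensure_grouped()  # commands that need streams trigger grouping lazily
```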
diff --git a/airbyte_cdk/sources/declarative/interpolation/jinja.py b/airbyte_cdk/sources/declarative/interpolation/jinja.py
index ecbe9a349..f1f126e91 100644
--- a/airbyte_cdk/sources/declarative/interpolation/jinja.py
+++ b/airbyte_cdk/sources/declarative/interpolation/jinja.py
@@ -4,7 +4,7 @@

 import ast
 from functools import cache
-from typing import Any, Mapping, Optional, Tuple, Type
+from typing import Any, Mapping, Optional, Set, Tuple, Type

 from jinja2 import meta
 from jinja2.environment import Template
@@ -30,6 +30,34 @@ def is_safe_attribute(self, obj: Any, attr: str, value: Any) -> bool:
         return super().is_safe_attribute(obj, attr, value)  # type: ignore  # for some reason, mypy says 'Returning Any from function declared to return "bool"'

+# These aliases are used to deprecate existing keywords without breaking all existing connectors.
+_ALIASES = {
+    "stream_interval": "stream_slice",  # Use stream_interval to access incremental_sync values
+    "stream_partition": "stream_slice",  # Use stream_partition to access partition router's values
+}
+
+# These extensions are not installed so they're not currently a problem,
+# but we're still explicitly removing them from the jinja context.
+# At worst, this is documentation that we do NOT want to include these extensions because of the potential security risks
+_RESTRICTED_EXTENSIONS = ["jinja2.ext.loopcontrols"]  # Adds support for break continue in loops
+
+# By default, these Python builtin functions are available in the Jinja context.
+# We explicitly remove them because of the potential security risk.
+# Please add a unit test to test_jinja.py when adding a restriction.
+_RESTRICTED_BUILTIN_FUNCTIONS = [
+    "range"
+]  # The range function can cause very expensive computations
+
+_ENVIRONMENT = StreamPartitionAccessEnvironment()
+_ENVIRONMENT.filters.update(**filters)
+_ENVIRONMENT.globals.update(**macros)
+
+for extension in _RESTRICTED_EXTENSIONS:
+    _ENVIRONMENT.extensions.pop(extension, None)
+for builtin in _RESTRICTED_BUILTIN_FUNCTIONS:
+    _ENVIRONMENT.globals.pop(builtin, None)
+
+
 class JinjaInterpolation(Interpolation):
     """
     Interpolation strategy using the Jinja2 template engine.
@@ -48,34 +76,6 @@ class JinjaInterpolation(Interpolation):
     Additional information on jinja templating can be found at https://jinja.palletsprojects.com/en/3.1.x/templates/#
     """

-    # These aliases are used to deprecate existing keywords without breaking all existing connectors.
-    ALIASES = {
-        "stream_interval": "stream_slice",  # Use stream_interval to access incremental_sync values
-        "stream_partition": "stream_slice",  # Use stream_partition to access partition router's values
-    }
-
-    # These extensions are not installed so they're not currently a problem,
-    # but we're still explicitely removing them from the jinja context.
-    # At worst, this is documentation that we do NOT want to include these extensions because of the potential security risks
-    RESTRICTED_EXTENSIONS = ["jinja2.ext.loopcontrols"]  # Adds support for break continue in loops
-
-    # By default, these Python builtin functions are available in the Jinja context.
-    # We explicitely remove them because of the potential security risk.
-    # Please add a unit test to test_jinja.py when adding a restriction.
-    RESTRICTED_BUILTIN_FUNCTIONS = [
-        "range"
-    ]  # The range function can cause very expensive computations
-
-    def __init__(self) -> None:
-        self._environment = StreamPartitionAccessEnvironment()
-        self._environment.filters.update(**filters)
-        self._environment.globals.update(**macros)
-
-        for extension in self.RESTRICTED_EXTENSIONS:
-            self._environment.extensions.pop(extension, None)
-        for builtin in self.RESTRICTED_BUILTIN_FUNCTIONS:
-            self._environment.globals.pop(builtin, None)
-
     def eval(
         self,
         input_str: str,
@@ -86,7 +86,7 @@ def eval(
     ) -> Any:
         context = {"config": config, **additional_parameters}

-        for alias, equivalent in self.ALIASES.items():
+        for alias, equivalent in _ALIASES.items():
             if alias in context:
                 # This is unexpected. We could ignore or log a warning, but failing loudly should result in fewer surprises
                 raise ValueError(
@@ -105,6 +105,7 @@ def eval(
             raise Exception(f"Expected a string, got {input_str}")
         except UndefinedError:
             pass
+        # If the result is empty or rendering raised an UndefinedError, evaluate and return the default string
         return self._literal_eval(self._eval(default, context), valid_types)

@@ -132,16 +133,16 @@ def _eval(self, s: Optional[str], context: Mapping[str, Any]) -> Optional[str]:
         return s

     @cache
-    def _find_undeclared_variables(self, s: Optional[str]) -> Template:
+    def _find_undeclared_variables(self, s: Optional[str]) -> Set[str]:
         """
         Find undeclared variables and cache them
         """
-        ast = self._environment.parse(s)  # type: ignore  # parse is able to handle None
+        ast = _ENVIRONMENT.parse(s)  # type: ignore  # parse is able to handle None
         return meta.find_undeclared_variables(ast)

     @cache
-    def _compile(self, s: Optional[str]) -> Template:
+    def _compile(self, s: str) -> Template:
         """
         We must cache the Jinja Template ourselves because we're using `from_string` instead of a template loader
         """
-        return self._environment.from_string(s)
+        return _ENVIRONMENT.from_string(s)
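Moving the environment and its restrictions to module scope means the `@cache` entries on `_compile` and `_find_undeclared_variables` no longer depend on per-instance state, and every interpolation shares one sandbox. A standalone sketch of the same pattern built on plain Jinja APIs (`SandboxedEnvironment`, `meta.find_undeclared_variables`); the function names here are illustrative, not the CDK's:

```python
from functools import cache
from typing import FrozenSet

from jinja2 import meta
from jinja2.environment import Template
from jinja2.sandbox import SandboxedEnvironment

# One shared, sandboxed environment; "range" is dropped from the default globals
# because unbounded ranges can make template rendering arbitrarily expensive.
_ENV = SandboxedEnvironment()
_ENV.globals.pop("range", None)


@cache
def compile_template(source: str) -> Template:
    # from_string bypasses template loaders, so compiled templates are cached here.
    return _ENV.from_string(source)


@cache
def undeclared_variables(source: str) -> FrozenSet[str]:
    # Names the template references that are not provided by the render context.
    return frozenset(meta.find_undeclared_variables(_ENV.parse(source)))
```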
diff --git a/airbyte_cdk/sources/streams/http/http_client.py b/airbyte_cdk/sources/streams/http/http_client.py
index 91e2a63d9..04f99afc4 100644
--- a/airbyte_cdk/sources/streams/http/http_client.py
+++ b/airbyte_cdk/sources/streams/http/http_client.py
@@ -54,6 +54,7 @@
 from airbyte_cdk.utils.traced_exception import AirbyteTracedException

 BODY_REQUEST_METHODS = ("GET", "POST", "PUT", "PATCH")
+logger = logging.getLogger("airbyte")

 class MessageRepresentationAirbyteTracedErrors(AirbyteTracedException):
@@ -94,6 +95,7 @@ def __init__(
     ):
         self._name = name
         self._api_budget: APIBudget = api_budget or APIBudget(policies=[])
+        self._logger = logger
         if session:
             self._session = session
         else:
@@ -107,7 +109,6 @@ def __init__(
             )
         if isinstance(authenticator, AuthBase):
             self._session.auth = authenticator
-        self._logger = logger
         self._error_handler = error_handler or HttpStatusErrorHandler(self._logger)
         if backoff_strategy is not None:
             if isinstance(backoff_strategy, list):
@@ -141,9 +142,11 @@ def _request_session(self) -> requests.Session:
             if cache_dir:
                 sqlite_path = str(Path(cache_dir) / self.cache_filename)
             else:
+                self._logger.info("Using memory for cache")  # TODO: remove
                 sqlite_path = "file::memory:?cache=shared"
+            backend = SkipFailureSQLiteCache(self._name, sqlite_path)  # TODO maybe add a busy timeout
             return CachedLimiterSession(
-                sqlite_path, backend="sqlite", api_budget=self._api_budget, match_headers=True
+                sqlite_path, backend=backend, api_budget=self._api_budget, match_headers=True
             )  # type: ignore  # there are no typeshed stubs for requests_cache
         else:
             return LimiterSession(api_budget=self._api_budget)
@@ -517,3 +520,44 @@ def send_request(
         )

         return request, response
+
+
+class SkipFailureSQLiteDict(requests_cache.backends.sqlite.SQLiteDict):
+    def __getitem__(self, key):  # type: ignore  # lib is not typed
+        try:
+            return super().__getitem__(key)  # type: ignore  # lib is not typed
+        except KeyError:
+            raise  # a KeyError is a legitimate cache miss, not a storage failure
+        except Exception as exception:
+            # Degrade storage failures to a cache miss instead of failing the sync
+            logger.warning(f"Error while retrieving item from cache: {exception}")
+
+    def _write(self, key: str, value: str) -> None:
+        try:
+            super()._write(key, value)  # type: ignore  # lib is not typed
+        except Exception as exception:
+            logger.warning(f"Error while saving item to cache: {exception}")  # persisting is best-effort
+
+
+class SkipFailureSQLiteCache(requests_cache.backends.sqlite.SQLiteCache):
+    def __init__(  # type: ignore  # ignoring as lib is not typed
+        self,
+        table_name="response",
+        db_path="http_cache",
+        serializer=None,
+        **kwargs,
+    ) -> None:
+        super().__init__(db_path, serializer, **kwargs)
+        skwargs = {"serializer": serializer, **kwargs} if serializer else kwargs
+        self.responses: requests_cache.backends.sqlite.SQLiteDict = SkipFailureSQLiteDict(
+            db_path, table_name=table_name, fast_save=True, wal=True, **skwargs
+        )
+        self.redirects: requests_cache.backends.sqlite.SQLiteDict = SkipFailureSQLiteDict(
+            db_path,
+            table_name=f"redirects_{table_name}",
+            fast_save=True,
+            wal=True,
+            lock=self.responses._lock,
+            serializer=None,
+            **kwargs,
+        )
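`SkipFailureSQLiteCache` is an ordinary `requests_cache` backend instance, so it can also be handed to a plain `CachedSession` outside the CDK's limiter sessions. A hedged usage sketch, assuming the class above is importable; the cache name and URL are illustrative:

```python
from requests_cache import CachedSession

# Table name and path mirror how HttpClient wires the backend above.
backend = SkipFailureSQLiteCache("my_connector", "http_cache.sqlite")
session = CachedSession("http_cache.sqlite", backend=backend, match_headers=True)

# A locked or corrupted cache file now logs a warning and falls back to the
# network instead of raising in the middle of a sync.
response = session.get("https://example.com/")
```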
diff --git a/unit_tests/sources/streams/http/test_http_client.py b/unit_tests/sources/streams/http/test_http_client.py
index 29bac0ec8..a62bb14c0 100644
--- a/unit_tests/sources/streams/http/test_http_client.py
+++ b/unit_tests/sources/streams/http/test_http_client.py
@@ -1,11 +1,13 @@
 # Copyright (c) 2024 Airbyte, Inc., all rights reserved.
-
+import contextlib
 import logging
 from datetime import timedelta
+from sqlite3 import OperationalError
 from unittest.mock import MagicMock, patch

 import pytest
 import requests
+import requests_cache
 from pympler import asizeof
 from requests_cache import CachedRequest
@@ -741,3 +743,27 @@ def test_given_different_headers_then_response_is_not_cached(requests_mock):
     )

     assert second_response.json()["test"] == "second response"
+
+
+class RaiseOnInsertConnection:
+    def execute(self, *args, **kwargs) -> None:
+        if "INSERT" in str(args):  # simulate a locked SQLite database on write
+            raise OperationalError("database table is locked")
+
+
+def test_given_cache_save_failure_then_do_not_break(requests_mock, monkeypatch):
+    @contextlib.contextmanager
+    def _create_sqlite_write_error_connection(*args, **kwargs):
+        yield RaiseOnInsertConnection()
+
+    monkeypatch.setattr(
+        requests_cache.backends.sqlite.SQLiteDict,
+        "connection",
+        _create_sqlite_write_error_connection,
+    )
+    http_client = HttpClient(name="test", logger=MagicMock(), use_cache=True)
+    requests_mock.register_uri("GET", "https://google.com/", json={"test": "response"})

+    request, response = http_client.send_request("GET", "https://google.com/", request_kwargs={})
+
+    assert response.json() == {"test": "response"}
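The new test only exercises the write path (`INSERT`). A companion sketch, hypothetical but mirroring the same monkeypatch trick, could cover the read path by failing `SELECT`s instead, relying on `SkipFailureSQLiteDict.__getitem__` to degrade the failure to a cache miss:

```python
class RaiseOnSelectConnection:
    def execute(self, *args, **kwargs) -> None:
        if "SELECT" in str(args):  # simulate a locked SQLite database on read
            raise OperationalError("database table is locked")


def test_given_cache_read_failure_then_do_not_break(requests_mock, monkeypatch):
    @contextlib.contextmanager
    def _create_sqlite_read_error_connection(*args, **kwargs):
        yield RaiseOnSelectConnection()

    monkeypatch.setattr(
        requests_cache.backends.sqlite.SQLiteDict,
        "connection",
        _create_sqlite_read_error_connection,
    )
    http_client = HttpClient(name="test", logger=MagicMock(), use_cache=True)
    requests_mock.register_uri("GET", "https://google.com/", json={"test": "response"})

    _, response = http_client.send_request("GET", "https://google.com/", request_kwargs={})

    # The request should still succeed via the network despite the cache read failure.
    assert response.json() == {"test": "response"}
```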