Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(concurrency): support failed on http cache write #115

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
44 changes: 26 additions & 18 deletions airbyte_cdk/sources/declarative/concurrent_declarative_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,9 @@


class ConcurrentDeclarativeSource(ManifestDeclarativeSource, Generic[TState]):
# By default, we defer to a value of 1 which represents running a connector using the Concurrent CDK engine on only one thread.
SINGLE_THREADED_CONCURRENCY_LEVEL = 1
# By default, we defer to a value of 2. A value lower than than could cause a PartitionEnqueuer to be stuck in a state of deadlock
# because it has hit the limit of futures but not partition reader is consuming them.
SINGLE_THREADED_CONCURRENCY_LEVEL = 2

def __init__(
self,
Expand All @@ -78,6 +79,9 @@ def __init__(
emit_connector_builder_messages=emit_connector_builder_messages,
disable_resumable_full_refresh=True,
)
self._config = config
self._concurrent_streams: Optional[List[AbstractStream]] = None
self._synchronous_streams: Optional[List[Stream]] = None

super().__init__(
source_config=source_config,
Expand All @@ -88,21 +92,6 @@ def __init__(

self._state = state

self._concurrent_streams: Optional[List[AbstractStream]]
self._synchronous_streams: Optional[List[Stream]]

# If the connector command was SPEC, there is no incoming config, and we cannot instantiate streams because
# they might depend on it. Ideally we want to have a static method on this class to get the spec without
# any other arguments, but the existing entrypoint.py isn't designed to support this. Just noting this
# for our future improvements to the CDK.
if config:
self._concurrent_streams, self._synchronous_streams = self._group_streams(
config=config or {}
)
else:
self._concurrent_streams = None
self._synchronous_streams = None

concurrency_level_from_manifest = self._source_config.get("concurrency_level")
if concurrency_level_from_manifest:
concurrency_level_component = self._constructor.create_component(
Expand All @@ -121,7 +110,7 @@ def __init__(
) # Partition_generation iterates using range based on this value. If this is floored to zero we end up in a dead lock during start up
else:
concurrency_level = self.SINGLE_THREADED_CONCURRENCY_LEVEL
initial_number_of_partitions_to_generate = self.SINGLE_THREADED_CONCURRENCY_LEVEL
initial_number_of_partitions_to_generate = self.SINGLE_THREADED_CONCURRENCY_LEVEL // 2

self._concurrent_source = ConcurrentSource.create(
num_workers=concurrency_level,
Expand All @@ -131,6 +120,19 @@ def __init__(
message_repository=self.message_repository, # type: ignore # message_repository is always instantiated with a value by factory
)

def _actually_group(self) -> None:
# If the connector command was SPEC, there is no incoming config, and we cannot instantiate streams because
# they might depend on it. Ideally we want to have a static method on this class to get the spec without
# any other arguments, but the existing entrypoint.py isn't designed to support this. Just noting this
# for our future improvements to the CDK.
if self._config:
self._concurrent_streams, self._synchronous_streams = self._group_streams(
config=self._config or {}
)
else:
self._concurrent_streams = None
self._synchronous_streams = None

def read(
self,
logger: logging.Logger,
Expand All @@ -140,6 +142,9 @@ def read(
) -> Iterator[AirbyteMessage]:
# ConcurrentReadProcessor pops streams that are finished being read so before syncing, the names of the concurrent
# streams must be saved so that they can be removed from the catalog before starting synchronous streams
if self._concurrent_streams is None:
self._actually_group()

if self._concurrent_streams:
concurrent_stream_names = set(
[concurrent_stream.name for concurrent_stream in self._concurrent_streams]
Expand All @@ -165,6 +170,9 @@ def read(
yield from super().read(logger, config, filtered_catalog, state)

def discover(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteCatalog:
if self._concurrent_streams is None:
self._actually_group()

concurrent_streams = self._concurrent_streams or []
synchronous_streams = self._synchronous_streams or []
return AirbyteCatalog(
Expand Down
69 changes: 35 additions & 34 deletions airbyte_cdk/sources/declarative/interpolation/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import ast
from functools import cache
from typing import Any, Mapping, Optional, Tuple, Type
from typing import Any, Mapping, Optional, Set, Tuple, Type

from jinja2 import meta
from jinja2.environment import Template
Expand All @@ -30,6 +30,34 @@ def is_safe_attribute(self, obj: Any, attr: str, value: Any) -> bool:
return super().is_safe_attribute(obj, attr, value) # type: ignore # for some reason, mypy says 'Returning Any from function declared to return "bool"'


# These aliases are used to deprecate existing keywords without breaking all existing connectors.
_ALIASES = {
"stream_interval": "stream_slice", # Use stream_interval to access incremental_sync values
"stream_partition": "stream_slice", # Use stream_partition to access partition router's values
}

# These extensions are not installed so they're not currently a problem,
# but we're still explicitely removing them from the jinja context.
# At worst, this is documentation that we do NOT want to include these extensions because of the potential security risks
_RESTRICTED_EXTENSIONS = ["jinja2.ext.loopcontrols"] # Adds support for break continue in loops

# By default, these Python builtin functions are available in the Jinja context.
# We explicitely remove them because of the potential security risk.
# Please add a unit test to test_jinja.py when adding a restriction.
_RESTRICTED_BUILTIN_FUNCTIONS = [
"range"
] # The range function can cause very expensive computations

_ENVIRONMENT = StreamPartitionAccessEnvironment()
_ENVIRONMENT.filters.update(**filters)
_ENVIRONMENT.globals.update(**macros)

for extension in _RESTRICTED_EXTENSIONS:
_ENVIRONMENT.extensions.pop(extension, None)
for builtin in _RESTRICTED_BUILTIN_FUNCTIONS:
_ENVIRONMENT.globals.pop(builtin, None)


class JinjaInterpolation(Interpolation):
"""
Interpolation strategy using the Jinja2 template engine.
Expand All @@ -48,34 +76,6 @@ class JinjaInterpolation(Interpolation):
Additional information on jinja templating can be found at https://jinja.palletsprojects.com/en/3.1.x/templates/#
"""

# These aliases are used to deprecate existing keywords without breaking all existing connectors.
ALIASES = {
"stream_interval": "stream_slice", # Use stream_interval to access incremental_sync values
"stream_partition": "stream_slice", # Use stream_partition to access partition router's values
}

# These extensions are not installed so they're not currently a problem,
# but we're still explicitely removing them from the jinja context.
# At worst, this is documentation that we do NOT want to include these extensions because of the potential security risks
RESTRICTED_EXTENSIONS = ["jinja2.ext.loopcontrols"] # Adds support for break continue in loops

# By default, these Python builtin functions are available in the Jinja context.
# We explicitely remove them because of the potential security risk.
# Please add a unit test to test_jinja.py when adding a restriction.
RESTRICTED_BUILTIN_FUNCTIONS = [
"range"
] # The range function can cause very expensive computations

def __init__(self) -> None:
self._environment = StreamPartitionAccessEnvironment()
self._environment.filters.update(**filters)
self._environment.globals.update(**macros)

for extension in self.RESTRICTED_EXTENSIONS:
self._environment.extensions.pop(extension, None)
for builtin in self.RESTRICTED_BUILTIN_FUNCTIONS:
self._environment.globals.pop(builtin, None)

def eval(
self,
input_str: str,
Expand All @@ -86,7 +86,7 @@ def eval(
) -> Any:
context = {"config": config, **additional_parameters}

for alias, equivalent in self.ALIASES.items():
for alias, equivalent in _ALIASES.items():
if alias in context:
# This is unexpected. We could ignore or log a warning, but failing loudly should result in fewer surprises
raise ValueError(
Expand All @@ -105,6 +105,7 @@ def eval(
raise Exception(f"Expected a string, got {input_str}")
except UndefinedError:
pass

# If result is empty or resulted in an undefined error, evaluate and return the default string
return self._literal_eval(self._eval(default, context), valid_types)

Expand Down Expand Up @@ -132,16 +133,16 @@ def _eval(self, s: Optional[str], context: Mapping[str, Any]) -> Optional[str]:
return s

@cache
def _find_undeclared_variables(self, s: Optional[str]) -> Template:
def _find_undeclared_variables(self, s: Optional[str]) -> Set[str]:
"""
Find undeclared variables and cache them
"""
ast = self._environment.parse(s) # type: ignore # parse is able to handle None
ast = _ENVIRONMENT.parse(s) # type: ignore # parse is able to handle None
return meta.find_undeclared_variables(ast)

@cache
def _compile(self, s: Optional[str]) -> Template:
def _compile(self, s: str) -> Template:
"""
We must cache the Jinja Template ourselves because we're using `from_string` instead of a template loader
"""
return self._environment.from_string(s)
return _ENVIRONMENT.from_string(s)
48 changes: 46 additions & 2 deletions airbyte_cdk/sources/streams/http/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from airbyte_cdk.utils.traced_exception import AirbyteTracedException

BODY_REQUEST_METHODS = ("GET", "POST", "PUT", "PATCH")
logger = logging.getLogger("airbyte")


class MessageRepresentationAirbyteTracedErrors(AirbyteTracedException):
Expand Down Expand Up @@ -94,6 +95,7 @@ def __init__(
):
self._name = name
self._api_budget: APIBudget = api_budget or APIBudget(policies=[])
self._logger = logger
if session:
self._session = session
else:
Expand All @@ -107,7 +109,6 @@ def __init__(
)
if isinstance(authenticator, AuthBase):
self._session.auth = authenticator
self._logger = logger
self._error_handler = error_handler or HttpStatusErrorHandler(self._logger)
if backoff_strategy is not None:
if isinstance(backoff_strategy, list):
Expand Down Expand Up @@ -141,9 +142,11 @@ def _request_session(self) -> requests.Session:
if cache_dir:
sqlite_path = str(Path(cache_dir) / self.cache_filename)
else:
self._logger.info("Using memory for cache") # TODO: remove
sqlite_path = "file::memory:?cache=shared"
backend = SkipFailureSQLiteCache(self._name, sqlite_path) # TODO maybe add a busy timeout
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codebase verification

Would you consider adding a busy timeout to prevent SQLite lock errors? Here's a suggestion 🤔

I noticed we have a test case simulating a "database table is locked" error, and the current implementation inherits from requests_cache.backends.sqlite.SQLiteCache without any timeout configuration. What if we modified the __init__ to include a busy timeout? Something like:

def __init__(self, table_name="response", db_path="http_cache", serializer=None):
    super().__init__(table_name, db_path, serializer)
    self._db.execute("PRAGMA busy_timeout = 30000")  # 30 seconds timeout, wdyt?

This would help prevent those pesky lock errors in concurrent scenarios while keeping the implementation simple. What do you think about this approach? 🤔

🔗 Analysis chain

Consider implementing busy timeout for SQLite.

The TODO comment suggests adding a busy timeout. This could help prevent "database is locked" errors in concurrent scenarios. Would you like me to provide an example implementation?

Let's check if there are any existing busy timeout implementations or issues:


Let me check the implementation of the SQLite cache to understand the context better.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for existing busy timeout implementations
rg -A 2 "busy_timeout|pragma busy_timeout"

# Search for database locked errors in tests
rg "database.*locked|sqlite.*locked"

Length of output: 198


Script:

#!/bin/bash
# Find the SQLite cache implementation
ast-grep --pattern 'class SkipFailureSQLiteCache'

# Also check for any test cases related to concurrent access
rg -B 2 -A 5 "SkipFailureSQLiteCache"

Length of output: 1631

return CachedLimiterSession(
sqlite_path, backend="sqlite", api_budget=self._api_budget, match_headers=True
sqlite_path, backend=backend, api_budget=self._api_budget, match_headers=True
) # type: ignore # there are no typeshed stubs for requests_cache
else:
return LimiterSession(api_budget=self._api_budget)
Expand Down Expand Up @@ -517,3 +520,44 @@ def send_request(
)

return request, response


class SkipFailureSQLiteDict(requests_cache.backends.sqlite.SQLiteDict):
def __getitem__(self, key): # type: ignore # lib is not typed
try:
return super().__getitem__(key) # type: ignore # lib is not typed
except Exception as exception:
if not isinstance(exception, KeyError):
logger.warning(f"Error while retrieving item from cache: {exception}")
else:
raise exception

def _write(self, key: str, value: str) -> None:
try:
super()._write(key, value) # type: ignore # lib is not typed
except Exception as exception:
logger.warning(f"Error while saving item to cache: {exception}")

Comment on lines +525 to +540
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Enhance error handling and type safety

A few suggestions to make this more robust:

  1. The broad exception handling could mask critical issues. What do you think about catching specific exceptions?
  2. The warning messages could be more descriptive by including the key. wdyt?
  3. Consider adding type hints for better maintainability?

Here's a potential improvement:

-class SkipFailureSQLiteDict(requests_cache.backends.sqlite.SQLiteDict):
+class SkipFailureSQLiteDict(requests_cache.backends.sqlite.SQLiteDict):
+    """A SQLiteDict that logs warnings instead of raising exceptions on cache operations."""
+
-    def __getitem__(self, key):  # type: ignore  # lib is not typed
+    def __getitem__(self, key: str) -> Any:  # type: ignore  # return type from parent
         try:
             return super().__getitem__(key)  # type: ignore  # lib is not typed
-        except Exception as exception:
+        except (sqlite3.Error, IOError) as exception:
             if not isinstance(exception, KeyError):
-                logger.warning(f"Error while retrieving item from cache: {exception}")
+                logger.warning(f"Error while retrieving key '{key}' from cache: {exception}")
             else:
                 raise exception

     def _write(self, key: str, value: str) -> None:
         try:
             super()._write(key, value)  # type: ignore  # lib is not typed
-        except Exception as exception:
-            logger.warning(f"Error while saving item to cache: {exception}")
+        except (sqlite3.Error, IOError) as exception:
+            logger.warning(f"Error while saving key '{key}' to cache: {exception}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
class SkipFailureSQLiteDict(requests_cache.backends.sqlite.SQLiteDict):
def __getitem__(self, key): # type: ignore # lib is not typed
try:
return super().__getitem__(key) # type: ignore # lib is not typed
except Exception as exception:
if not isinstance(exception, KeyError):
logger.warning(f"Error while retrieving item from cache: {exception}")
else:
raise exception
def _write(self, key: str, value: str) -> None:
try:
super()._write(key, value) # type: ignore # lib is not typed
except Exception as exception:
logger.warning(f"Error while saving item to cache: {exception}")
from typing import Any
import sqlite3
class SkipFailureSQLiteDict(requests_cache.backends.sqlite.SQLiteDict):
"""A SQLiteDict that logs warnings instead of raising exceptions on cache operations."""
def __getitem__(self, key: str) -> Any: # type: ignore # return type from parent
try:
return super().__getitem__(key) # type: ignore # lib is not typed
except (sqlite3.Error, IOError) as exception:
if not isinstance(exception, KeyError):
logger.warning(f"Error while retrieving key '{key}' from cache: {exception}")
else:
raise exception
def _write(self, key: str, value: str) -> None:
try:
super()._write(key, value) # type: ignore # lib is not typed
except (sqlite3.Error, IOError) as exception:
logger.warning(f"Error while saving key '{key}' to cache: {exception}")


class SkipFailureSQLiteCache(requests_cache.backends.sqlite.SQLiteCache):
def __init__( # type: ignore # ignoring as lib is not typed
self,
table_name="response",
db_path="http_cache",
serializer=None,
**kwargs,
) -> None:
super().__init__(db_path, serializer, **kwargs)
skwargs = {"serializer": serializer, **kwargs} if serializer else kwargs
self.responses: requests_cache.backends.sqlite.SQLiteDict = SkipFailureSQLiteDict(
db_path, table_name=table_name, fast_save=True, wal=True, **skwargs
)
self.redirects: requests_cache.backends.sqlite.SQLiteDict = SkipFailureSQLiteDict(
db_path,
table_name=f"redirects_{table_name}",
fast_save=True,
wal=True,
lock=self.responses._lock,
serializer=None,
**kwargs,
)
28 changes: 27 additions & 1 deletion unit_tests/sources/streams/http/test_http_client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.

import contextlib
import logging
from datetime import timedelta
from sqlite3 import OperationalError
from unittest.mock import MagicMock, patch

import pytest
import requests
import requests_cache
from pympler import asizeof
from requests_cache import CachedRequest

Expand Down Expand Up @@ -741,3 +743,27 @@ def test_given_different_headers_then_response_is_not_cached(requests_mock):
)

assert second_response.json()["test"] == "second response"


class RaiseOnInsertConnection:
def execute(*args, **kwargs) -> None:
if "INSERT" in str(args):
raise OperationalError("database table is locked")


def test_given_cache_save_failure_then_do_not_break(requests_mock, monkeypatch):
@contextlib.contextmanager
def _create_sqlite_write_error_connection(*args, **kwargs):
yield RaiseOnInsertConnection()

monkeypatch.setattr(
requests_cache.backends.sqlite.SQLiteDict,
"connection",
_create_sqlite_write_error_connection,
)
http_client = HttpClient(name="test", logger=MagicMock(), use_cache=True)
requests_mock.register_uri("GET", "https://google.com/", json={"test": "response"})

request, response = http_client.send_request("GET", "https://google.com/", request_kwargs={})

assert response.json()
Loading