Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

fix(concurrency): support failed on http cache write #115

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
69 changes: 35 additions & 34 deletions airbyte_cdk/sources/declarative/interpolation/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import ast
from functools import cache
from typing import Any, Mapping, Optional, Tuple, Type
from typing import Any, Mapping, Optional, Set, Tuple, Type

from jinja2 import meta
from jinja2.environment import Template
Expand All @@ -30,6 +30,34 @@ def is_safe_attribute(self, obj: Any, attr: str, value: Any) -> bool:
return super().is_safe_attribute(obj, attr, value) # type: ignore # for some reason, mypy says 'Returning Any from function declared to return "bool"'


# These aliases are used to deprecate existing keywords without breaking all existing connectors.
_ALIASES = {
    "stream_interval": "stream_slice",  # Use stream_interval to access incremental_sync values
    "stream_partition": "stream_slice",  # Use stream_partition to access partition router's values
}

# These extensions are not installed so they're not currently a problem,
# but we're still explicitly removing them from the jinja context.
# At worst, this is documentation that we do NOT want to include these extensions because of the potential security risks
_RESTRICTED_EXTENSIONS = ["jinja2.ext.loopcontrols"]  # Adds support for break continue in loops

# By default, these Python builtin functions are available in the Jinja context.
# We explicitly remove them because of the potential security risk.
# Please add a unit test to test_jinja.py when adding a restriction.
_RESTRICTED_BUILTIN_FUNCTIONS = [
    "range"
]  # The range function can cause very expensive computations

# Single module-level environment shared by all JinjaInterpolation instances
# (presumably so cached template parsing/compilation is shared process-wide — TODO confirm).
_ENVIRONMENT = StreamPartitionAccessEnvironment()
_ENVIRONMENT.filters.update(**filters)
_ENVIRONMENT.globals.update(**macros)

# Strip the restricted extensions/builtins from the shared environment at import time.
for extension in _RESTRICTED_EXTENSIONS:
    _ENVIRONMENT.extensions.pop(extension, None)
for builtin in _RESTRICTED_BUILTIN_FUNCTIONS:
    _ENVIRONMENT.globals.pop(builtin, None)


class JinjaInterpolation(Interpolation):
"""
Interpolation strategy using the Jinja2 template engine.
Expand All @@ -48,34 +76,6 @@ class JinjaInterpolation(Interpolation):
Additional information on jinja templating can be found at https://jinja.palletsprojects.com/en/3.1.x/templates/#
"""

# These aliases are used to deprecate existing keywords without breaking all existing connectors.
ALIASES = {
"stream_interval": "stream_slice", # Use stream_interval to access incremental_sync values
"stream_partition": "stream_slice", # Use stream_partition to access partition router's values
}

# These extensions are not installed so they're not currently a problem,
# but we're still explicitely removing them from the jinja context.
# At worst, this is documentation that we do NOT want to include these extensions because of the potential security risks
RESTRICTED_EXTENSIONS = ["jinja2.ext.loopcontrols"] # Adds support for break continue in loops

# By default, these Python builtin functions are available in the Jinja context.
# We explicitely remove them because of the potential security risk.
# Please add a unit test to test_jinja.py when adding a restriction.
RESTRICTED_BUILTIN_FUNCTIONS = [
"range"
] # The range function can cause very expensive computations

def __init__(self) -> None:
self._environment = StreamPartitionAccessEnvironment()
self._environment.filters.update(**filters)
self._environment.globals.update(**macros)

for extension in self.RESTRICTED_EXTENSIONS:
self._environment.extensions.pop(extension, None)
for builtin in self.RESTRICTED_BUILTIN_FUNCTIONS:
self._environment.globals.pop(builtin, None)

def eval(
self,
input_str: str,
Expand All @@ -86,7 +86,7 @@ def eval(
) -> Any:
context = {"config": config, **additional_parameters}

for alias, equivalent in self.ALIASES.items():
for alias, equivalent in _ALIASES.items():
if alias in context:
# This is unexpected. We could ignore or log a warning, but failing loudly should result in fewer surprises
raise ValueError(
Expand All @@ -105,6 +105,7 @@ def eval(
raise Exception(f"Expected a string, got {input_str}")
except UndefinedError:
pass

# If result is empty or resulted in an undefined error, evaluate and return the default string
return self._literal_eval(self._eval(default, context), valid_types)

Expand Down Expand Up @@ -132,16 +133,16 @@ def _eval(self, s: Optional[str], context: Mapping[str, Any]) -> Optional[str]:
return s

@cache
def _find_undeclared_variables(self, s: Optional[str]) -> Template:
def _find_undeclared_variables(self, s: Optional[str]) -> Set[str]:
"""
Find undeclared variables and cache them
"""
ast = self._environment.parse(s) # type: ignore # parse is able to handle None
ast = _ENVIRONMENT.parse(s) # type: ignore # parse is able to handle None
return meta.find_undeclared_variables(ast)

@cache
def _compile(self, s: Optional[str]) -> Template:
def _compile(self, s: str) -> Template:
"""
We must cache the Jinja Template ourselves because we're using `from_string` instead of a template loader
"""
return self._environment.from_string(s)
return _ENVIRONMENT.from_string(s)
42 changes: 41 additions & 1 deletion airbyte_cdk/sources/streams/http/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from airbyte_cdk.utils.traced_exception import AirbyteTracedException

BODY_REQUEST_METHODS = ("GET", "POST", "PUT", "PATCH")
logger = logging.getLogger("airbyte")


class MessageRepresentationAirbyteTracedErrors(AirbyteTracedException):
Expand Down Expand Up @@ -142,8 +143,9 @@ def _request_session(self) -> requests.Session:
sqlite_path = str(Path(cache_dir) / self.cache_filename)
else:
sqlite_path = "file::memory:?cache=shared"
backend = SkipFailureSQLiteCache(sqlite_path)
return CachedLimiterSession(
sqlite_path, backend="sqlite", api_budget=self._api_budget, match_headers=True
sqlite_path, backend=backend, api_budget=self._api_budget, match_headers=True
) # type: ignore # there are no typeshed stubs for requests_cache
else:
return LimiterSession(api_budget=self._api_budget)
Expand Down Expand Up @@ -517,3 +519,41 @@ def send_request(
)

return request, response


class SkipFailureSQLiteDict(requests_cache.backends.sqlite.SQLiteDict):
    """A SQLiteDict that logs warnings instead of raising on cache backend failures.

    Concurrent syncs can hit transient SQLite errors (e.g. "database table is
    locked"). The HTTP cache is a best-effort optimization, so backend failures
    are logged and swallowed rather than failing the whole request.
    """

    def __getitem__(self, key):  # type: ignore  # lib is not typed
        """Return the cached item for ``key``.

        A plain cache miss (``KeyError``) is re-raised so requests_cache can
        fall back to performing the real HTTP request; any other backend error
        is logged and ``None`` is returned (treated as a miss by the caller).
        """
        try:
            return super().__getitem__(key)  # type: ignore  # lib is not typed
        except KeyError:
            # Cache misses are expected control flow for requests_cache.
            raise
        except Exception as exception:
            logger.warning(f"Error while retrieving item from cache: {exception}")

    def _write(self, key: str, value: str) -> None:
        """Persist ``key``/``value``; on backend failure, log and continue."""
        try:
            super()._write(key, value)  # type: ignore  # lib is not typed
        except Exception as exception:
            logger.warning(f"Error while saving item to cache: {exception}")

Comment on lines +525 to +540
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Enhance error handling and type safety

A few suggestions to make this more robust:

  1. The broad exception handling could mask critical issues. What do you think about catching specific exceptions?
  2. The warning messages could be more descriptive by including the key. wdyt?
  3. Consider adding type hints for better maintainability?

Here's a potential improvement:

-class SkipFailureSQLiteDict(requests_cache.backends.sqlite.SQLiteDict):
+class SkipFailureSQLiteDict(requests_cache.backends.sqlite.SQLiteDict):
+    """A SQLiteDict that logs warnings instead of raising exceptions on cache operations."""
+
-    def __getitem__(self, key):  # type: ignore  # lib is not typed
+    def __getitem__(self, key: str) -> Any:  # type: ignore  # return type from parent
         try:
             return super().__getitem__(key)  # type: ignore  # lib is not typed
-        except Exception as exception:
+        except (sqlite3.Error, IOError) as exception:
             if not isinstance(exception, KeyError):
-                logger.warning(f"Error while retrieving item from cache: {exception}")
+                logger.warning(f"Error while retrieving key '{key}' from cache: {exception}")
             else:
                 raise exception

     def _write(self, key: str, value: str) -> None:
         try:
             super()._write(key, value)  # type: ignore  # lib is not typed
-        except Exception as exception:
-            logger.warning(f"Error while saving item to cache: {exception}")
+        except (sqlite3.Error, IOError) as exception:
+            logger.warning(f"Error while saving key '{key}' to cache: {exception}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
class SkipFailureSQLiteDict(requests_cache.backends.sqlite.SQLiteDict):
def __getitem__(self, key): # type: ignore # lib is not typed
try:
return super().__getitem__(key) # type: ignore # lib is not typed
except Exception as exception:
if not isinstance(exception, KeyError):
logger.warning(f"Error while retrieving item from cache: {exception}")
else:
raise exception
def _write(self, key: str, value: str) -> None:
try:
super()._write(key, value) # type: ignore # lib is not typed
except Exception as exception:
logger.warning(f"Error while saving item to cache: {exception}")
from typing import Any
import sqlite3
class SkipFailureSQLiteDict(requests_cache.backends.sqlite.SQLiteDict):
"""A SQLiteDict that logs warnings instead of raising exceptions on cache operations."""
def __getitem__(self, key: str) -> Any: # type: ignore # return type from parent
try:
return super().__getitem__(key) # type: ignore # lib is not typed
except (sqlite3.Error, IOError) as exception:
if not isinstance(exception, KeyError):
logger.warning(f"Error while retrieving key '{key}' from cache: {exception}")
else:
raise exception
def _write(self, key: str, value: str) -> None:
try:
super()._write(key, value) # type: ignore # lib is not typed
except (sqlite3.Error, IOError) as exception:
logger.warning(f"Error while saving key '{key}' to cache: {exception}")


class SkipFailureSQLiteCache(requests_cache.backends.sqlite.SQLiteCache):
    """SQLiteCache whose response/redirect stores tolerate backend write failures.

    Replaces the stores built by the parent class with ``SkipFailureSQLiteDict``
    instances so transient SQLite errors (e.g. locked tables under concurrency)
    are logged instead of raised.
    """

    def __init__(  # type: ignore  # ignoring as lib is not typed
        self,
        db_path="http_cache",
        serializer=None,
        **kwargs,
    ) -> None:
        # NOTE(review): super().__init__ already constructs its own SQLiteDict
        # stores, which are immediately replaced below — confirm the double
        # initialization is intentional (it opens the tables twice).
        super().__init__(db_path, serializer, **kwargs)
        # Only forward the serializer when one was provided, mirroring the
        # parent's keyword handling.
        skwargs = {"serializer": serializer, **kwargs} if serializer else kwargs
        self.responses: requests_cache.backends.sqlite.SQLiteDict = SkipFailureSQLiteDict(
            db_path, table_name="responses", **skwargs
        )
        # Shares the responses store's private lock so both tables serialize
        # on the same connection — relies on a requests_cache internal
        # (`_lock`); may break on library upgrades.
        self.redirects: requests_cache.backends.sqlite.SQLiteDict = SkipFailureSQLiteDict(
            db_path,
            table_name="redirects",
            lock=self.responses._lock,
            serializer=None,
            **kwargs,
        )
28 changes: 27 additions & 1 deletion unit_tests/sources/streams/http/test_http_client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.

import contextlib
import logging
from datetime import timedelta
from sqlite3 import OperationalError
from unittest.mock import MagicMock, patch

import pytest
import requests
import requests_cache
from pympler import asizeof
from requests_cache import CachedRequest

Expand Down Expand Up @@ -741,3 +743,27 @@ def test_given_different_headers_then_response_is_not_cached(requests_mock):
)

assert second_response.json()["test"] == "second response"


class RaiseOnInsertConnection:
    """Test double for a sqlite3 connection whose writes always fail.

    Simulates the "database table is locked" condition seen when concurrent
    syncs write to the shared HTTP cache: any ``execute`` call carrying an
    INSERT statement raises ``OperationalError``; everything else is a no-op.
    """

    def execute(self, *args, **kwargs) -> None:
        """Raise ``OperationalError`` for INSERT statements, otherwise do nothing.

        The original signature omitted ``self`` (the instance was silently
        captured by ``*args``); declaring it explicitly keeps the statement
        check limited to the actual SQL arguments.
        """
        if any("INSERT" in str(arg) for arg in args):
            raise OperationalError("database table is locked")


def test_given_cache_save_failure_then_do_not_break(requests_mock, monkeypatch):
    """A failed cache write (e.g. locked SQLite table) must not fail the request."""

    @contextlib.contextmanager
    def _create_sqlite_write_error_connection(*args, **kwargs):
        # Every cursor obtained from the cache backend will raise on INSERT,
        # simulating a concurrent writer holding the table lock.
        yield RaiseOnInsertConnection()

    # Patch the backend's connection factory so all cache writes fail.
    monkeypatch.setattr(
        requests_cache.backends.sqlite.SQLiteDict,
        "connection",
        _create_sqlite_write_error_connection,
    )
    http_client = HttpClient(name="test", logger=MagicMock(), use_cache=True)
    requests_mock.register_uri("GET", "https://google.com/", json={"test": "response"})

    request, response = http_client.send_request("GET", "https://google.com/", request_kwargs={})

    # The request succeeds even though caching the response failed.
    assert response.json()
Loading