diff --git a/procrastinate/__init__.py b/procrastinate/__init__.py index e55fe5172..510449494 100644 --- a/procrastinate/__init__.py +++ b/procrastinate/__init__.py @@ -5,6 +5,7 @@ from procrastinate.connector import BaseConnector from procrastinate.job_context import JobContext from procrastinate.psycopg2_connector import Psycopg2Connector +from procrastinate.psycopg3_connector import Psycopg3Connector from procrastinate.retry import BaseRetryStrategy, RetryStrategy __all__ = [ @@ -15,6 +16,7 @@ "BaseRetryStrategy", "AiopgConnector", "Psycopg2Connector", + "Psycopg3Connector", "RetryStrategy", ] diff --git a/procrastinate/psycopg3_connector.py b/procrastinate/psycopg3_connector.py new file mode 100644 index 000000000..fed3de72d --- /dev/null +++ b/procrastinate/psycopg3_connector.py @@ -0,0 +1,291 @@ +import asyncio +import functools +import logging +import re +from typing import Any, Callable, Coroutine, Dict, Iterable, List, Optional + +import psycopg +import psycopg.errors +import psycopg.sql +import psycopg.types.json +import psycopg_pool +from psycopg.rows import DictRow, dict_row + +from procrastinate import connector, exceptions, sql + +logger = logging.getLogger(__name__) + +LISTEN_TIMEOUT = 30.0 + +CoroutineFunction = Callable[..., Coroutine] + + +def wrap_exceptions(coro: CoroutineFunction) -> CoroutineFunction: + """ + Wrap psycopg3 errors as connector exceptions. + + This decorator is expected to be used on coroutine functions only. + """ + + @functools.wraps(coro) + async def wrapped(*args, **kwargs): + try: + return await coro(*args, **kwargs) + except psycopg.errors.UniqueViolation as exc: + raise exceptions.UniqueViolation(constraint_name=exc.diag.constraint_name) + except psycopg.Error as exc: + raise exceptions.ConnectorException from exc + + # Attaching a custom attribute to ease testability and make the + # decorator more introspectable + wrapped._exceptions_wrapped = True # type: ignore + return wrapped + + +def wrap_query_exceptions(coro: CoroutineFunction) -> CoroutineFunction: + """ + Detect "admin shutdown" errors and retry a number of times. + + This is to handle the case where the database connection (obtained from the pool) + was actually closed by the server. In this case, pyscopg3 raises an AdminShutdown + exception when the connection is used for issuing a query. What we do is retry when + an AdminShutdown is raised, and until the maximum number of retries is reached. + + The number of retries is set to the pool maximum size plus one, to handle the case + where the connections we have in the pool were all closed on the server side. + """ + + @functools.wraps(coro) + async def wrapped(*args, **kwargs): + final_exc = None + try: + max_tries = args[0]._pool.max_size + 1 + except Exception: + max_tries = 1 + for _ in range(max_tries): + try: + return await coro(*args, **kwargs) + except psycopg.errors.OperationalError as exc: + if "server closed the connection unexpectedly" in str(exc): + final_exc = exc + continue + raise exc + raise exceptions.ConnectorException( + f"Could not get a valid connection after {max_tries} tries" + ) from final_exc + + return wrapped + + +PERCENT_PATTERN = re.compile(r"%(?![\(s])") + + +class Psycopg3Connector(connector.BaseAsyncConnector): + def __init__( + self, + *, + json_dumps: Optional[Callable] = None, + json_loads: Optional[Callable] = None, + **kwargs: Any, + ): + """ + Asynchronous connector based on a ``psycopg_pool.AsyncConnectionPool``. + + The pool connection parameters can be provided here. Alternatively, an already + existing ``psycopg_pool.AsyncConnectionPool`` can be provided in the + ``App.open_async``, via the ``pool`` parameter. + + All other arguments than ``json_dumps`` and ``json_loads`` are passed to + :py:func:`AsyncConnectionPool` (see psycopg3 documentation__), with default + values that may differ from those of ``psycopg3`` (see a partial list of + parameters below). + + .. _psycopg3 doc: https://www.psycopg.org/psycopg3/docs/basic/adapt.html#json-adaptation + .. __: https://www.psycopg.org/psycopg3/docs/api/pool.html + #psycopg_pool.AsyncConnectionPool + + Parameters + ---------- + json_dumps : + The JSON dumps function to use for serializing job arguments. Defaults to + the function used by psycopg3. See the `psycopg3 doc`_. + json_loads : + The JSON loads function to use for deserializing job arguments. Defaults + to the function used by psycopg3. See the `psycopg3 doc`_. Unused if the + pool is externally created and set into the connector through the + ``App.open_async`` method. + min_size : int + Passed to psycopg3, default set to 1 (same as aiopg). + max_size : int + Passed to psycopg3, default set to 10 (same as aiopg). + conninfo : ``Optional[str]`` + Passed to psycopg3. Default is "" instead of None, which means if no + argument is passed, it will connect to localhost:5432 instead of a + Unix-domain local socket file. + """ + self.json_dumps = json_dumps + self.json_loads = json_loads + self._pool: Optional[psycopg_pool.AsyncConnectionPool] = None + self._pool_args = self._adapt_pool_args(kwargs, json_loads) + self._pool_externally_set = False + + @staticmethod + def _adapt_pool_args( + pool_args: Dict[str, Any], json_loads: Optional[Callable] + ) -> Dict[str, Any]: + """ + Adapt the pool args for ``psycopg3``, using sensible defaults for Procrastinate. + """ + base_configure = pool_args.pop("configure", None) + + @wrap_exceptions + async def configure(connection: psycopg.AsyncConnection[DictRow]): + if base_configure: + await base_configure(connection) + if json_loads: + psycopg.types.json.set_json_loads(json_loads, connection) + + return { + "conninfo": "", + "min_size": 1, + "max_size": 10, + "kwargs": { + "row_factory": dict_row, + }, + "configure": configure, + "open": False, + **pool_args, + } + + async def open_async( + self, pool: Optional[psycopg_pool.AsyncConnectionPool] = None + ) -> None: + """ + Instantiate the pool. + + pool : + Optional pool. Procrastinate can use an existing pool. Connection parameters + passed in the constructor will be ignored. + """ + if self._pool: + return + + if pool: + self._pool_externally_set = True + self._pool = pool + else: + self._pool = await self._create_pool(self._pool_args) + + # ensure pool is open + await self._pool.open() # type: ignore + + @staticmethod + @wrap_exceptions + async def _create_pool( + pool_args: Dict[str, Any] + ) -> psycopg_pool.AsyncConnectionPool: + return psycopg_pool.AsyncConnectionPool(**pool_args) + + @wrap_exceptions + async def close_async(self) -> None: + """ + Close the pool and awaits all connections to be released. + """ + if not self._pool or self._pool_externally_set: + return + + await self._pool.close() + self._pool = None + + @property + def pool( + self, + ) -> psycopg_pool.AsyncConnectionPool[psycopg.AsyncConnection[DictRow]]: + if self._pool is None: # Set by open + raise exceptions.AppNotOpen + return self._pool + + def _wrap_json(self, arguments: Dict[str, Any]): + return { + key: psycopg.types.json.Jsonb(value, dumps=self.json_dumps) + if isinstance(value, dict) + else value + for key, value in arguments.items() + } + + @wrap_exceptions + @wrap_query_exceptions + async def execute_query_async(self, query: str, **arguments: Any) -> None: + async with self.pool.connection() as connection: + async with connection.cursor() as cursor: + await cursor.execute(query, self._wrap_json(arguments)) + + @wrap_exceptions + @wrap_query_exceptions + async def execute_query_one_async( + self, query: str, **arguments: Any + ) -> Optional[DictRow]: + async with self.pool.connection() as connection: + async with connection.cursor() as cursor: + await cursor.execute(query, self._wrap_json(arguments)) + return await cursor.fetchone() + + @wrap_exceptions + @wrap_query_exceptions + async def execute_query_all_async( + self, query: str, **arguments: Any + ) -> List[DictRow]: + async with self.pool.connection() as connection: + async with connection.cursor() as cursor: + await cursor.execute(query, self._wrap_json(arguments)) + return await cursor.fetchall() + + @wrap_exceptions + async def listen_notify( + self, event: asyncio.Event, channels: Iterable[str] + ) -> None: + # We need to acquire a dedicated connection, and use the listen + # query + if self.pool.max_size == 1: + logger.warning( + "Listen/Notify capabilities disabled because maximum pool size" + "is set to 1", + extra={"action": "listen_notify_disabled"}, + ) + return + + query_template = psycopg.sql.SQL(sql.queries["listen_queue"]) + + while True: + async with self.pool.connection() as connection: + # autocommit is required for async connection notifies + await connection.set_autocommit(True) + + for channel_name in channels: + query = query_template.format( + channel_name=psycopg.sql.Identifier(channel_name) + ) + await connection.execute(query) + + event.set() + + await self._loop_notify(event=event, connection=connection) + + @wrap_exceptions + async def _loop_notify( + self, + event: asyncio.Event, + connection: psycopg.AsyncConnection, + ) -> None: + while True: + if connection.closed: + return + try: + notifies = connection.notifies() + async for _ in notifies: + event.set() + except psycopg.OperationalError: + continue + + def __del__(self): + pass diff --git a/tests/conftest.py b/tests/conftest.py index eb82b094b..39f5705a6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,11 +13,13 @@ import pytest from psycopg2 import sql from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT +from psycopg.conninfo import make_conninfo from procrastinate import aiopg_connector as aiopg_connector_module from procrastinate import app as app_module from procrastinate import blueprints, builtin_tasks, jobs from procrastinate import psycopg2_connector as psycopg2_connector_module +from procrastinate import psycopg3_connector as psycopg3_connector_module from procrastinate import schema, testing from procrastinate.contrib.sqlalchemy import ( psycopg2_connector as sqlalchemy_psycopg2_connector_module, @@ -102,6 +104,13 @@ def connection_params(setup_db, db_factory): yield {"dsn": "", "dbname": "procrastinate_test"} +@pytest.fixture +def psycopg3_connection_params(setup_db, db_factory): + db_factory(dbname="procrastinate_test", template=setup_db) + + yield {"conninfo": make_conninfo(dbname="procrastinate_test")} + + @pytest.fixture def sqlalchemy_engine_dsn(setup_db, db_factory): db_factory(dbname="procrastinate_test", template=setup_db) @@ -120,6 +129,11 @@ async def not_opened_aiopg_connector(connection_params): yield aiopg_connector_module.AiopgConnector(**connection_params) +@pytest.fixture +async def not_opened_psycopg3_connector(psycopg3_connection_params): + yield psycopg3_connector_module.Psycopg3Connector(**psycopg3_connection_params) + + @pytest.fixture def not_opened_psycopg2_connector(connection_params): yield psycopg2_connector_module.Psycopg2Connector(**connection_params) @@ -139,6 +153,13 @@ async def aiopg_connector(not_opened_aiopg_connector): await not_opened_aiopg_connector.close_async() +@pytest.fixture +async def psycopg3_connector(not_opened_psycopg3_connector): + await not_opened_psycopg3_connector.open_async() + yield not_opened_psycopg3_connector + await not_opened_psycopg3_connector.close_async() + + @pytest.fixture def psycopg2_connector(not_opened_psycopg2_connector): not_opened_psycopg2_connector.open() diff --git a/tests/integration/test_psycopg3_connector.py b/tests/integration/test_psycopg3_connector.py new file mode 100644 index 000000000..ce53f24a3 --- /dev/null +++ b/tests/integration/test_psycopg3_connector.py @@ -0,0 +1,202 @@ +import asyncio +import functools +import json + +import attr +import pytest + +from procrastinate import psycopg3_connector + + +@pytest.fixture +async def psycopg3_connector_factory(psycopg3_connection_params): + connectors = [] + + async def _(**kwargs): + json_dumps = kwargs.pop("json_dumps", None) + json_loads = kwargs.pop("json_loads", None) + psycopg3_connection_params.update(kwargs) + connector = psycopg3_connector.Psycopg3Connector( + json_dumps=json_dumps, json_loads=json_loads, **psycopg3_connection_params + ) + connectors.append(connector) + await connector.open_async() + return connector + + yield _ + for connector in connectors: + await connector.close_async() + + +async def test_adapt_pool_args_configure(mocker): + called = [] + + async def configure(connection): + called.append(connection) + + args = psycopg3_connector.Psycopg3Connector._adapt_pool_args( + pool_args={"configure": configure}, json_loads=None + ) + + assert args["configure"] is not configure + + connection = mocker.Mock(_pool=None) + await args["configure"](connection) + + assert called == [connection] + + +@pytest.mark.parametrize( + "method_name, expected", + [ + ("execute_query_one_async", {"json": {"a": "a", "b": "foo"}}), + ("execute_query_all_async", [{"json": {"a": "a", "b": "foo"}}]), + ], +) +async def test_execute_query_json_dumps( + psycopg3_connector_factory, mocker, method_name, expected +): + class NotJSONSerializableByDefault: + pass + + def encode(obj): + if isinstance(obj, NotJSONSerializableByDefault): + return "foo" + raise TypeError() + + query = "SELECT %(arg)s::jsonb as json" + arg = {"a": "a", "b": NotJSONSerializableByDefault()} + json_dumps = functools.partial(json.dumps, default=encode) + connector = await psycopg3_connector_factory(json_dumps=json_dumps) + method = getattr(connector, method_name) + + result = await method(query, arg=arg) + assert result == expected + + +async def test_json_loads(psycopg3_connector_factory, mocker): + @attr.dataclass + class Param: + p: int + + def decode(dct): + if "b" in dct: + dct["b"] = Param(p=dct["b"]) + return dct + + json_loads = functools.partial(json.loads, object_hook=decode) + + query = "SELECT %(arg)s::jsonb as json" + arg = {"a": 1, "b": 2} + connector = await psycopg3_connector_factory(json_loads=json_loads) + + result = await connector.execute_query_one_async(query, arg=arg) + assert result["json"] == {"a": 1, "b": Param(p=2)} + + +async def test_execute_query(psycopg3_connector): + assert ( + await psycopg3_connector.execute_query_async( + "COMMENT ON TABLE \"procrastinate_jobs\" IS 'foo' " + ) + is None + ) + result = await psycopg3_connector.execute_query_one_async( + "SELECT obj_description('public.procrastinate_jobs'::regclass)" + ) + assert result == {"obj_description": "foo"} + + result = await psycopg3_connector.execute_query_all_async( + "SELECT obj_description('public.procrastinate_jobs'::regclass)" + ) + assert result == [{"obj_description": "foo"}] + + +async def test_execute_query_interpolate(psycopg3_connector): + result = await psycopg3_connector.execute_query_one_async( + "SELECT %(foo)s as foo;", foo="bar" + ) + assert result == {"foo": "bar"} + + +@pytest.mark.filterwarnings("error::ResourceWarning") +async def test_execute_query_simultaneous(psycopg3_connector): + # two coroutines doing execute_query_async simultaneously + # + # the test may fail if the connector fails to properly parallelize connections + + async def query(): + await psycopg3_connector.execute_query_async("SELECT 1") + + try: + await asyncio.gather(query(), query()) + except ResourceWarning: + pytest.fail("ResourceWarning") + + +async def test_close_async(psycopg3_connector): + await psycopg3_connector.execute_query_async("SELECT 1") + pool = psycopg3_connector._pool + await psycopg3_connector.close_async() + assert pool.closed is True + assert psycopg3_connector._pool is None + + +async def test_listen_notify(psycopg3_connector): + channel = "somechannel" + event = asyncio.Event() + + task = asyncio.ensure_future( + psycopg3_connector.listen_notify(channels=[channel], event=event) + ) + try: + await event.wait() + event.clear() + await psycopg3_connector.execute_query_async(f"""NOTIFY "{channel}" """) + await asyncio.wait_for(event.wait(), timeout=1) + except asyncio.TimeoutError: + pytest.fail("Notify not received within 1 sec") + finally: + task.cancel() + + +async def test_loop_notify_stop_when_connection_closed(psycopg3_connector): + # We want to make sure that the when the connection is closed, the loop end. + event = asyncio.Event() + await psycopg3_connector.open_async() + async with psycopg3_connector._pool.connection() as connection: + coro = psycopg3_connector._loop_notify(event=event, connection=connection) + + await psycopg3_connector._pool.close() + assert connection.closed + + try: + await asyncio.wait_for(coro, 1) + except asyncio.TimeoutError: + pytest.fail("Failed to detect that connection was closed and stop") + + +async def test_loop_notify_timeout(psycopg3_connector): + # We want to make sure that when the listen starts, we don't listen forever. If the + # connection closes, we eventually finish the coroutine. + event = asyncio.Event() + await psycopg3_connector.open_async() + async with psycopg3_connector._pool.connection() as connection: + task = asyncio.ensure_future( + psycopg3_connector._loop_notify(event=event, connection=connection) + ) + assert not task.done() + + await psycopg3_connector._pool.close() + assert connection.closed + + try: + await asyncio.wait_for(task, 0.1) + except asyncio.TimeoutError: + pytest.fail("Failed to detect that connection was closed and stop") + + assert not event.is_set() + + +async def test_destructor(): + ... diff --git a/tests/unit/test_psycopg3_connector.py b/tests/unit/test_psycopg3_connector.py new file mode 100644 index 000000000..b53fb6070 --- /dev/null +++ b/tests/unit/test_psycopg3_connector.py @@ -0,0 +1,167 @@ +import psycopg +import pytest + +from procrastinate import exceptions, psycopg3_connector + + +@pytest.fixture +def connector(): + return psycopg3_connector.Psycopg3Connector() + + +async def test_adapt_pool_args_configure(mocker): + called = [] + + async def configure(connection): + called.append(connection) + + args = psycopg3_connector.Psycopg3Connector._adapt_pool_args( + pool_args={"configure": configure}, json_loads=None + ) + + assert args["configure"] is not configure + + connection = mocker.Mock(_pool=None) + await args["configure"](connection) + + assert called == [connection] + + +async def test_wrap_exceptions_wraps(): + @psycopg3_connector.wrap_exceptions + async def corofunc(): + raise psycopg.DatabaseError + + coro = corofunc() + + with pytest.raises(exceptions.ConnectorException): + await coro + + +async def test_wrap_exceptions_success(): + @psycopg3_connector.wrap_exceptions + async def corofunc(a, b): + return a, b + + assert await corofunc(1, 2) == (1, 2) + + +@pytest.mark.parametrize( + "max_size, expected_calls_count", + [ + pytest.param(5, 6, id="Valid max_size"), + pytest.param("5", 1, id="Invalid max_size"), + ], +) +async def test_wrap_query_exceptions_reached_max_tries( + mocker, max_size, expected_calls_count +): + called = [] + + @psycopg3_connector.wrap_query_exceptions + async def corofunc(connector): + called.append(True) + raise psycopg.errors.OperationalError( + "server closed the connection unexpectedly" + ) + + connector = mocker.Mock(_pool=mocker.AsyncMock(max_size=max_size)) + coro = corofunc(connector) + + with pytest.raises(exceptions.ConnectorException) as excinfo: + await coro + + assert len(called) == expected_calls_count + assert ( + str(excinfo.value) + == f"Could not get a valid connection after {expected_calls_count} tries" + ) + + +@pytest.mark.parametrize( + "exception_class", [Exception, psycopg.errors.OperationalError] +) +async def test_wrap_query_exceptions_unhandled_exception(mocker, exception_class): + called = [] + + @psycopg3_connector.wrap_query_exceptions + async def corofunc(connector): + called.append(True) + raise exception_class("foo") + + connector = mocker.Mock(_pool=mocker.AsyncMock(max_size=5)) + coro = corofunc(connector) + + with pytest.raises(exception_class): + await coro + + assert len(called) == 1 + + +async def test_wrap_query_exceptions_success(mocker): + called = [] + + @psycopg3_connector.wrap_query_exceptions + async def corofunc(connector, a, b): + if len(called) < 2: + called.append(True) + raise psycopg.errors.OperationalError( + "server closed the connection unexpectedly" + ) + return a, b + + connector = mocker.Mock(_pool=mocker.AsyncMock(max_size=5)) + + assert await corofunc(connector, 1, 2) == (1, 2) + assert len(called) == 2 + + +@pytest.mark.parametrize( + "method_name", + [ + "_create_pool", + "close_async", + "execute_query_async", + "execute_query_one_async", + "execute_query_all_async", + "listen_notify", + ], +) +def test_wrap_exceptions_applied(method_name, connector): + assert getattr(connector, method_name)._exceptions_wrapped is True + + +async def test_listen_notify_pool_one_connection(mocker, caplog, connector): + pool = mocker.AsyncMock(max_size=1) + await connector.open_async(pool) + caplog.clear() + + await connector.listen_notify(None, None) + + assert {e.action for e in caplog.records} == {"listen_notify_disabled"} + + +async def test_open_async_no_pool_specified(mocker, connector): + mocker.patch.object(connector, "_create_pool", return_value=mocker.AsyncMock()) + + await connector.open_async() + + assert connector._create_pool.call_count == 1 + assert connector._pool.open.await_count == 1 + + +async def test_open_async_pool_argument_specified(mocker, connector): + mocker.patch.object(connector, "_create_pool") + pool = mocker.AsyncMock() + + await connector.open_async(pool) + + assert connector._pool_externally_set is True + assert connector._create_pool.call_count == 0 + assert connector._pool.open.await_count == 1 + assert connector._pool == pool + + +def test_get_pool(connector): + with pytest.raises(exceptions.AppNotOpen): + _ = connector.pool