Skip to content

Commit

Permalink
wrap_query_exception isn't necessary anymore for psycopg3
Browse files Browse the repository at this point in the history
  • Loading branch information
ewjoachim committed Dec 16, 2023
1 parent dc5dc66 commit 7c4cd2d
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 112 deletions.
42 changes: 0 additions & 42 deletions procrastinate/psycopg_connector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import functools
import logging
import re
from typing import Any, Callable, Coroutine, Dict, Iterable, List, Optional

import psycopg
Expand Down Expand Up @@ -43,44 +42,6 @@ async def wrapped(*args, **kwargs):
return wrapped


def wrap_query_exceptions(coro: CoroutineFunction) -> CoroutineFunction:
"""
Detect "admin shutdown" errors and retry a number of times.
This is to handle the case where the database connection (obtained from the pool)
was actually closed by the server. In this case, pyscopg3 raises an AdminShutdown
exception when the connection is used for issuing a query. What we do is retry when
an AdminShutdown is raised, and until the maximum number of retries is reached.
The number of retries is set to the pool maximum size plus one, to handle the case
where the connections we have in the pool were all closed on the server side.
"""

@functools.wraps(coro)
async def wrapped(*args, **kwargs):
final_exc = None
try:
max_tries = args[0]._pool.max_size + 1
except Exception:
max_tries = 1
for _ in range(max_tries):
try:
return await coro(*args, **kwargs)
except psycopg.errors.OperationalError as exc:
if "server closed the connection unexpectedly" in str(exc):
final_exc = exc
continue
raise exc
raise exceptions.ConnectorException(
f"Could not get a valid connection after {max_tries} tries"
) from final_exc

return wrapped


PERCENT_PATTERN = re.compile(r"%(?![\(s])")


class PsycopgConnector(connector.BaseAsyncConnector):
def __init__(
self,
Expand Down Expand Up @@ -230,13 +191,11 @@ def _wrap_json(self, arguments: Dict[str, Any]):
}

@wrap_exceptions
@wrap_query_exceptions
async def execute_query_async(self, query: LiteralString, **arguments: Any) -> None:
async with self.pool.connection() as connection:
await connection.execute(query, self._wrap_json(arguments))

@wrap_exceptions
@wrap_query_exceptions
async def execute_query_one_async(
self, query: LiteralString, **arguments: Any
) -> DictRow:
Expand All @@ -251,7 +210,6 @@ async def execute_query_one_async(
return result

@wrap_exceptions
@wrap_query_exceptions
async def execute_query_all_async(
self, query: LiteralString, **arguments: Any
) -> List[DictRow]:
Expand Down
70 changes: 0 additions & 70 deletions tests/unit/test_psycopg_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,76 +46,6 @@ async def corofunc(a, b):
assert await corofunc(1, 2) == (1, 2)


@pytest.mark.parametrize(
"max_size, expected_calls_count",
[
pytest.param(5, 6, id="Valid max_size"),
pytest.param("5", 1, id="Invalid max_size"),
],
)
async def test_wrap_query_exceptions_reached_max_tries(
mocker, max_size, expected_calls_count
):
called = []

@psycopg_connector.wrap_query_exceptions
async def corofunc(connector):
called.append(True)
raise psycopg.errors.OperationalError(
"server closed the connection unexpectedly"
)

connector = mocker.Mock(_pool=mocker.AsyncMock(max_size=max_size))
coro = corofunc(connector)

with pytest.raises(exceptions.ConnectorException) as excinfo:
await coro

assert len(called) == expected_calls_count
assert (
str(excinfo.value)
== f"Could not get a valid connection after {expected_calls_count} tries"
)


@pytest.mark.parametrize(
"exception_class", [Exception, psycopg.errors.OperationalError]
)
async def test_wrap_query_exceptions_unhandled_exception(mocker, exception_class):
called = []

@psycopg_connector.wrap_query_exceptions
async def corofunc(connector):
called.append(True)
raise exception_class("foo")

connector = mocker.Mock(_pool=mocker.AsyncMock(max_size=5))
coro = corofunc(connector)

with pytest.raises(exception_class):
await coro

assert len(called) == 1


async def test_wrap_query_exceptions_success(mocker):
called = []

@psycopg_connector.wrap_query_exceptions
async def corofunc(connector, a, b):
if len(called) < 2:
called.append(True)
raise psycopg.errors.OperationalError(
"server closed the connection unexpectedly"
)
return a, b

connector = mocker.Mock(_pool=mocker.AsyncMock(max_size=5))

assert await corofunc(connector, 1, 2) == (1, 2)
assert len(called) == 2


@pytest.mark.parametrize(
"method_name",
[
Expand Down

0 comments on commit 7c4cd2d

Please sign in to comment.