Skip to content

Commit

Permalink
Do not use context manager for connection pool borrowing
Browse files Browse the repository at this point in the history
  • Loading branch information
bisho committed Nov 22, 2023
1 parent 949da4b commit e6de124
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 21 deletions.
23 changes: 14 additions & 9 deletions src/meta_memcache/connection/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,24 +115,29 @@ def _discard_connection(self, conn: MemcacheSocket, error: bool = False) -> None

@contextmanager
def get_connection(self) -> Generator[MemcacheSocket, None, None]:
try:
conn = self._pool.popleft()
except IndexError:
conn = None

if conn is None:
conn = self._create_connection()

conn = self.pop_connection()
try:
yield conn
except Exception as e:
self.release_connection(conn, error=True)
raise MemcacheServerError(self.server, "Memcache error") from e
else:
self.release_connection(conn, error=False)

def pop_connection(self) -> MemcacheSocket:
try:
return self._pool.popleft()
except IndexError:
return self._create_connection()

def release_connection(self, conn: MemcacheSocket, error: bool) -> None:
if error:
# Errors, assume connection is in bad state
_log.warning(
"Error during cache conn context (discarding connection)",
exc_info=True,
)
self._discard_connection(conn, error=True)
raise MemcacheServerError(self.server, "Memcache error") from e
else:
if len(self._pool) < self._max_pool_size:
# If there is a race, the deque might end with more than
Expand Down
21 changes: 18 additions & 3 deletions src/meta_memcache/executors/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ def exec_on_pool(
else self._prepare_serialized_value_and_int_flags(value, int_flags)
)
try:
with pool.get_connection() as conn:
conn = pool.pop_connection()
error = False
try:
self._conn_send_cmd(
conn,
command=command,
Expand All @@ -126,6 +128,11 @@ def exec_on_pool(
token_flags=token_flags,
)
return self._conn_recv_response(conn, flags=flags)
except Exception as e:
error = True
raise MemcacheServerError(pool.server, "Memcache error") from e
finally:
pool.release_connection(conn, error=error)
except MemcacheServerError:
if track_write_failures and self._is_a_write_failure(command, int_flags):
self.on_write_failure(key)
Expand All @@ -141,7 +148,7 @@ def exec_on_pool(
else:
return NotStored()

def exec_multi_on_pool(
def exec_multi_on_pool( # noqa: C901
self,
pool: ConnectionPool,
command: MetaCommand,
Expand All @@ -154,7 +161,10 @@ def exec_multi_on_pool(
) -> Dict[Key, MemcacheResponse]:
results: Dict[Key, MemcacheResponse] = {}
try:
with pool.get_connection() as conn:
conn = pool.pop_connection()
error = False
try:
# with pool.get_connection() as conn:
for key, value in key_values:
cmd_value, int_flags = (
(None, int_flags)
Expand All @@ -175,6 +185,11 @@ def exec_multi_on_pool(
)
for key, _ in key_values:
results[key] = self._conn_recv_response(conn, flags=flags)
except Exception as e:
error = True
raise MemcacheServerError(pool.server, "Memcache error") from e
finally:
pool.release_connection(conn, error=error)
except MemcacheServerError:
if track_write_failures and self._is_a_write_failure(command, int_flags):
for key, _ in key_values:
Expand Down
16 changes: 7 additions & 9 deletions tests/commands_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def connection_pool(
mocker: MockerFixture, memcache_socket: MemcacheSocket
) -> ConnectionPool:
connection_pool = mocker.MagicMock(spec=ConnectionPool)
connection_pool.get_connection().__enter__.return_value = memcache_socket
connection_pool.pop_connection.return_value = memcache_socket
return connection_pool


Expand All @@ -77,7 +77,7 @@ def connection_pool_1_6_6(
mocker: MockerFixture, memcache_socket_1_6_6: MemcacheSocket
) -> ConnectionPool:
connection_pool = mocker.MagicMock(spec=ConnectionPool)
connection_pool.get_connection().__enter__.return_value = memcache_socket_1_6_6
connection_pool.pop_connection.return_value = memcache_socket_1_6_6
return connection_pool


Expand Down Expand Up @@ -999,7 +999,7 @@ def test_on_write_failure(
on_failure: Callable[[Key], None] = lambda key: failures_tracked.append(key)
cache_client.on_write_failure += on_failure

connection_pool.get_connection.side_effect = MemcacheServerError(
connection_pool.pop_connection.side_effect = MemcacheServerError(
server="broken:11211", message="uh-oh"
)
try:
Expand Down Expand Up @@ -1029,7 +1029,7 @@ def test_on_write_failure_for_reads(
on_failure: Callable[[Key], None] = lambda key: failures_tracked.append(key)
cache_client.on_write_failure += on_failure

connection_pool.get_connection.side_effect = MemcacheServerError(
connection_pool.pop_connection.side_effect = MemcacheServerError(
server="broken:11211", message="uh-oh"
)
try:
Expand Down Expand Up @@ -1065,7 +1065,7 @@ def test_on_write_failure_for_multi_ops(
on_failure: Callable[[Key], None] = lambda key: failures_tracked.append(key)
cache_client.on_write_failure += on_failure

connection_pool.get_connection.side_effect = MemcacheServerError(
connection_pool.pop_connection.side_effect = MemcacheServerError(
server="broken:11211", message="uh-oh"
)

Expand Down Expand Up @@ -1094,7 +1094,7 @@ def test_on_write_failure_disabled(
on_failure: Callable[[Key], None] = lambda key: failures_tracked.append(key)
cache_client.on_write_failure += on_failure

connection_pool.get_connection.side_effect = MemcacheServerError(
connection_pool.pop_connection.side_effect = MemcacheServerError(
server="broken:11211", message="uh-oh"
)
try:
Expand Down Expand Up @@ -1129,7 +1129,7 @@ def test_write_failure_not_raise_on_server_error(
on_failure: Callable[[Key], None] = lambda key: failures_tracked.append(key)
cache_client.on_write_failure += on_failure

connection_pool.get_connection.side_effect = MemcacheServerError(
connection_pool.pop_connection.side_effect = MemcacheServerError(
server="broken:11211", message="uh-oh"
)
result = cache_client.get(key=Key("foo"))
Expand Down Expand Up @@ -1215,8 +1215,6 @@ def test_delta_cmd(memcache_socket: MemcacheSocket, cache_client: CacheClient) -
memcache_socket.sendall.reset_mock()
memcache_socket.get_response.reset_mock()

# memcache_socket.get_response.return_value = Value(size=2, b"10")

memcache_socket.get_response.return_value = Success()

result = cache_client.delta(key=Key("foo"), delta=1)
Expand Down

0 comments on commit e6de124

Please sign in to comment.