From e6de1248bedbf6b25392e22da6ce70492c692583 Mon Sep 17 00:00:00 2001 From: Guillermo Perez Date: Mon, 6 Nov 2023 09:57:53 +0100 Subject: [PATCH] Do not use context manager for connection pool borrowing --- src/meta_memcache/connection/pool.py | 23 ++++++++++++++--------- src/meta_memcache/executors/default.py | 21 ++++++++++++++++++--- tests/commands_test.py | 16 +++++++--------- 3 files changed, 39 insertions(+), 21 deletions(-) diff --git a/src/meta_memcache/connection/pool.py b/src/meta_memcache/connection/pool.py index d0ede25..cb67db7 100644 --- a/src/meta_memcache/connection/pool.py +++ b/src/meta_memcache/connection/pool.py @@ -115,24 +115,29 @@ def _discard_connection(self, conn: MemcacheSocket, error: bool = False) -> None @contextmanager def get_connection(self) -> Generator[MemcacheSocket, None, None]: - try: - conn = self._pool.popleft() - except IndexError: - conn = None - - if conn is None: - conn = self._create_connection() - + conn = self.pop_connection() try: yield conn except Exception as e: + self.release_connection(conn, error=True) + raise MemcacheServerError(self.server, "Memcache error") from e + else: + self.release_connection(conn, error=False) + + def pop_connection(self) -> MemcacheSocket: + try: + return self._pool.popleft() + except IndexError: + return self._create_connection() + + def release_connection(self, conn: MemcacheSocket, error: bool) -> None: + if error: # Errors, assume connection is in bad state _log.warning( "Error during cache conn context (discarding connection)", exc_info=True, ) self._discard_connection(conn, error=True) - raise MemcacheServerError(self.server, "Memcache error") from e else: if len(self._pool) < self._max_pool_size: # If there is a race, the deque might end with more than diff --git a/src/meta_memcache/executors/default.py b/src/meta_memcache/executors/default.py index 0c126f2..bba5e3b 100644 --- a/src/meta_memcache/executors/default.py +++ b/src/meta_memcache/executors/default.py @@ -115,7 +115,9 @@ def exec_on_pool( else self._prepare_serialized_value_and_int_flags(value, int_flags) ) try: - with pool.get_connection() as conn: + conn = pool.pop_connection() + error = False + try: self._conn_send_cmd( conn, command=command, @@ -126,6 +128,11 @@ def exec_on_pool( token_flags=token_flags, ) return self._conn_recv_response(conn, flags=flags) + except Exception as e: + error = True + raise MemcacheServerError(pool.server, "Memcache error") from e + finally: + pool.release_connection(conn, error=error) except MemcacheServerError: if track_write_failures and self._is_a_write_failure(command, int_flags): self.on_write_failure(key) @@ -141,7 +148,7 @@ def exec_on_pool( else: return NotStored() - def exec_multi_on_pool( + def exec_multi_on_pool( # noqa: C901 self, pool: ConnectionPool, command: MetaCommand, @@ -154,7 +161,10 @@ def exec_multi_on_pool( ) -> Dict[Key, MemcacheResponse]: results: Dict[Key, MemcacheResponse] = {} try: - with pool.get_connection() as conn: + conn = pool.pop_connection() + error = False + try: + # with pool.get_connection() as conn: for key, value in key_values: cmd_value, int_flags = ( (None, int_flags) @@ -175,6 +185,11 @@ def exec_multi_on_pool( ) for key, _ in key_values: results[key] = self._conn_recv_response(conn, flags=flags) + except Exception as e: + error = True + raise MemcacheServerError(pool.server, "Memcache error") from e + finally: + pool.release_connection(conn, error=error) except MemcacheServerError: if track_write_failures and self._is_a_write_failure(command, int_flags): for key, _ in key_values: diff --git a/tests/commands_test.py b/tests/commands_test.py index 6aa8045..1b3a474 100644 --- a/tests/commands_test.py +++ b/tests/commands_test.py @@ -68,7 +68,7 @@ def connection_pool( mocker: MockerFixture, memcache_socket: MemcacheSocket ) -> ConnectionPool: connection_pool = mocker.MagicMock(spec=ConnectionPool) - connection_pool.get_connection().__enter__.return_value = memcache_socket + connection_pool.pop_connection.return_value = memcache_socket return connection_pool @@ -77,7 +77,7 @@ def connection_pool_1_6_6( mocker: MockerFixture, memcache_socket_1_6_6: MemcacheSocket ) -> ConnectionPool: connection_pool = mocker.MagicMock(spec=ConnectionPool) - connection_pool.get_connection().__enter__.return_value = memcache_socket_1_6_6 + connection_pool.pop_connection.return_value = memcache_socket_1_6_6 return connection_pool @@ -999,7 +999,7 @@ def test_on_write_failure( on_failure: Callable[[Key], None] = lambda key: failures_tracked.append(key) cache_client.on_write_failure += on_failure - connection_pool.get_connection.side_effect = MemcacheServerError( + connection_pool.pop_connection.side_effect = MemcacheServerError( server="broken:11211", message="uh-oh" ) try: @@ -1029,7 +1029,7 @@ def test_on_write_failure_for_reads( on_failure: Callable[[Key], None] = lambda key: failures_tracked.append(key) cache_client.on_write_failure += on_failure - connection_pool.get_connection.side_effect = MemcacheServerError( + connection_pool.pop_connection.side_effect = MemcacheServerError( server="broken:11211", message="uh-oh" ) try: @@ -1065,7 +1065,7 @@ def test_on_write_failure_for_multi_ops( on_failure: Callable[[Key], None] = lambda key: failures_tracked.append(key) cache_client.on_write_failure += on_failure - connection_pool.get_connection.side_effect = MemcacheServerError( + connection_pool.pop_connection.side_effect = MemcacheServerError( server="broken:11211", message="uh-oh" ) @@ -1094,7 +1094,7 @@ def test_on_write_failure_disabled( on_failure: Callable[[Key], None] = lambda key: failures_tracked.append(key) cache_client.on_write_failure += on_failure - connection_pool.get_connection.side_effect = MemcacheServerError( + connection_pool.pop_connection.side_effect = MemcacheServerError( server="broken:11211", message="uh-oh" ) try: @@ -1129,7 +1129,7 @@ def test_write_failure_not_raise_on_server_error( on_failure: Callable[[Key], None] = lambda key: failures_tracked.append(key) cache_client.on_write_failure += on_failure - connection_pool.get_connection.side_effect = MemcacheServerError( + connection_pool.pop_connection.side_effect = MemcacheServerError( server="broken:11211", message="uh-oh" ) result = cache_client.get(key=Key("foo")) @@ -1215,8 +1215,6 @@ def test_delta_cmd(memcache_socket: MemcacheSocket, cache_client: CacheClient) - memcache_socket.sendall.reset_mock() memcache_socket.get_response.reset_mock() - # memcache_socket.get_response.return_value = Value(size=2, b"10") - memcache_socket.get_response.return_value = Success() result = cache_client.delta(key=Key("foo"), delta=1)