diff --git a/python/DEVELOPER.md b/python/DEVELOPER.md index 58918ce352..32ab7531bb 100644 --- a/python/DEVELOPER.md +++ b/python/DEVELOPER.md @@ -33,6 +33,7 @@ source "$HOME/.cargo/env" rustc --version # Install protobuf compiler PB_REL="https://github.com/protocolbuffers/protobuf/releases" +# For other arch the signature of the should be protoc---.zip, e.g. protoc-3.20.3-linux-aarch_64.zip for ARM64. curl -LO $PB_REL/download/v3.20.3/protoc-3.20.3-linux-x86_64.zip unzip protoc-3.20.3-linux-x86_64.zip -d $HOME/.local export PATH="$PATH:$HOME/.local/bin" diff --git a/python/python/glide/__init__.py b/python/python/glide/__init__.py index 7b0510dbb1..6906c31a96 100644 --- a/python/python/glide/__init__.py +++ b/python/python/glide/__init__.py @@ -96,7 +96,7 @@ SlotType, ) -from .glide import Script +from .glide import ClusterScanCursor, Script __all__ = [ # Client @@ -173,6 +173,7 @@ "TrimByMaxLen", "TrimByMinId", "UpdateOptions", + "ClusterScanCursor" # Logger "Logger", "LogLevel", diff --git a/python/python/glide/async_commands/cluster_commands.py b/python/python/glide/async_commands/cluster_commands.py index f2f17a01c6..b3e40f3115 100644 --- a/python/python/glide/async_commands/cluster_commands.py +++ b/python/python/glide/async_commands/cluster_commands.py @@ -2,9 +2,9 @@ from __future__ import annotations -from typing import Dict, List, Mapping, Optional, cast +from typing import Dict, List, Mapping, Optional, Union, cast -from glide.async_commands.command_args import Limit, OrderBy +from glide.async_commands.command_args import Limit, ObjectType, OrderBy from glide.async_commands.core import ( CoreCommands, FlushMode, @@ -16,6 +16,8 @@ from glide.protobuf.redis_request_pb2 import RequestType from glide.routes import Route +from ..glide import ClusterScanCursor + class ClusterCommands(CoreCommands): async def custom_command( @@ -624,3 +626,86 @@ async def lolwut( TClusterResponse[str], await self._execute_command(RequestType.Lolwut, args, route), ) + + async def scan( + self, + cursor: ClusterScanCursor, + match: Optional[str] = None, + count: Optional[int] = None, + type: Optional[ObjectType] = None, + ) -> List[Union[ClusterScanCursor, List[str]]]: + """ + Incrementally iterates over the keys in the Redis Cluster. + The method returns a list containing the next cursor and a list of keys. + + This command is similar to the SCAN command, but it is designed to work in a Redis Cluster environment. + It do so by iterating over the keys in the cluster, one node at a time, while maintaining a consistent view of + the slots that are being scanned. + The view is maintaining by saving the slots that have been scanned in the scanState, while returning a ref to the + the state in the cursor object. + After every node that been scanned the method check for changes as failover or resharding and get a validated + result of the slots that been covered in the scan and checking for the next node own the next slots to scan. + Every cursor is a new state object, which mean that using the same cursor object will result the scan to handle + the same scan iteration again. + For each iteration the new cursor object should be used to continue the scan. + + As the SCAN command, the method can be used to iterate over the keys in the database, the guarantee of the scan is + to return all keys the database have from the time the scan started that stay in the database till the scan ends. + The same key can be returned in multiple scans iteration. + + See https://valkey.io/commands/scan/ for more details. + + Args: + cursor (ClusterScanCursor): The cursor object wrapping the scan state - when starting a new scan + creation of new empty ClusterScanCursor is needed `ClusterScanCursor()`. + match (Optional[str]): A pattern to match keys against. + count (Optional[int]): The number of keys to return in a single iteration - the amount returned can vary and + not obligated to return exactly count. + This param is just a hint to the server of how much steps to do in each iteration. + type (Optional[ObjectType]): The type of object to scan for (STRING, LIST, SET, ZSET, HASH). + + Returns: + List[str, List[str]]: A list containing the next cursor and a list of keys. + + Examples: + >>> In the following example, we will iterate over the keys in the cluster. + client.set("key1", "value1") + client.set("key2", "value2") + client.set("key3", "value3") + let cursor = ClusterScanCursor() + all_keys = [] + while not cursor.is_finished(): + cursor, keys = await client.scan(cursor, count=10) + all_keys.extend(keys) + print(all_keys) # ['key1', 'key2', 'key3'] + >>> In the following example, we will iterate over the keys in the cluster that match the pattern "my_key*". + client.set("my_key1", "value1") + client.set("my_key2", "value2") + client.set("not_my_key", "value3") + client.set("something_else", "value4") + let cursor = ClusterScanCursor() + all_keys = [] + while not cursor.is_finished(): + cursor, keys = await client.cluster_scan(cursor, match="my_key*", count=10) + all_keys.extend(keys) + print(all_keys) # ['my_key1', 'my_key2', 'not_my_key'] + >>> In the following example, we will iterate over the keys in the cluster that are of type STRING. + client.set("str_key1", "value1") + client.set("str_key2", "value2") + client.set("str_key3", "value3") + client.sadd("it_is_a_set", "value4") + let cursor = ClusterScanCursor() + all_keys = [] + while not cursor.is_finished(): + cursor, keys = await client.cluster_scan(cursor, type=ObjectType.STRING) + all_keys.extend(keys) + print(all_keys) # ['str_key1', 'str_key2', 'str_key3'] + """ + response = await self._cluster_scan(cursor, match, count, type) + casted_response = cast( + List[Union[str, List[str]]], + response, + ) + cursor_str = cast(str, casted_response[0]) + cursor = ClusterScanCursor(cursor_str) + return cast(List[Union[ClusterScanCursor, List[str]]], [cursor, casted_response[1]]) diff --git a/python/python/glide/async_commands/command_args.py b/python/python/glide/async_commands/command_args.py index ce76fd2d55..139fc0e868 100644 --- a/python/python/glide/async_commands/command_args.py +++ b/python/python/glide/async_commands/command_args.py @@ -63,3 +63,34 @@ class ListDirection(Enum): """ RIGHT: Represents the option that elements should be popped from or added to the right side of a list. """ + + +class ObjectType(Enum): + """ + Enumeration representing the type of ValKey object. + """ + + STRING = "String" + """ + Represents a string object in Redis. + """ + + LIST = "List" + """ + Represents a list object in Redis. + """ + + SET = "Set" + """ + Represents a set object in Redis. + """ + + ZSET = "ZSet" + """ + Represents a sorted set object in Redis. + """ + + HASH = "Hash" + """ + Represents a hash object in Redis. + """ diff --git a/python/python/glide/async_commands/core.py b/python/python/glide/async_commands/core.py index 330a9018d3..7d03998321 100644 --- a/python/python/glide/async_commands/core.py +++ b/python/python/glide/async_commands/core.py @@ -26,7 +26,7 @@ _create_bitfield_args, _create_bitfield_read_only_args, ) -from glide.async_commands.command_args import Limit, ListDirection, OrderBy +from glide.async_commands.command_args import Limit, ListDirection, ObjectType, OrderBy from glide.async_commands.sorted_set import ( AggregationType, GeoSearchByBox, @@ -57,7 +57,7 @@ from glide.protobuf.redis_request_pb2 import RequestType from glide.routes import Route -from ..glide import Script +from ..glide import ClusterScanCursor, Script class ConditionalChange(Enum): @@ -360,6 +360,14 @@ async def _execute_script( route: Optional[Route] = None, ) -> TResult: ... + async def _cluster_scan( + self, + cursor: ClusterScanCursor, + match: Optional[str] = ..., + count: Optional[int] = ..., + type: Optional[ObjectType] = ..., + ) -> TResult: ... + async def set( self, key: str, diff --git a/python/python/glide/async_commands/standalone_commands.py b/python/python/glide/async_commands/standalone_commands.py index cd5518d417..65b4bc7597 100644 --- a/python/python/glide/async_commands/standalone_commands.py +++ b/python/python/glide/async_commands/standalone_commands.py @@ -2,9 +2,9 @@ from __future__ import annotations -from typing import Dict, List, Mapping, Optional, cast +from typing import Dict, List, Mapping, Optional, Union, cast -from glide.async_commands.command_args import Limit, OrderBy +from glide.async_commands.command_args import Limit, ObjectType, OrderBy from glide.async_commands.core import ( CoreCommands, FlushMode, @@ -566,3 +566,72 @@ async def lolwut( str, await self._execute_command(RequestType.Lolwut, args), ) + + async def scan( + self, + cursor: int, + match: Optional[str] = None, + count: Optional[int] = None, + type: Optional[ObjectType] = None, + ) -> List[Union[int, List[str]]]: + """ + Incrementally iterate over a collection of keys. + SCAN is a cursor based iterator. This means that at every call of the command, + the server returns an updated cursor that the user needs to use as the cursor argument in the next call. + An iteration starts when the cursor is set to 0, and terminates when the cursor returned by the server is 0. + + The SCAN command, and the other commands in the SCAN family, + are able to provide to the user a set of guarantees associated to full iterations. + + A full iteration always retrieves all the elements that were present + in the collection from the start to the end of a full iteration. + A full iteration never returns any element that was NOT present in the collection + from the start to the end of a full iteration. + However because SCAN has very little state associated (just the cursor) it has the following drawbacks: + A given element may be returned multiple times. + Elements that were not constantly present in the collection during a full iteration, may be returned or not/ + + See https://valkey.io/commands/scan for more details. + + Args: + cursor (int): The cursor used for the iteration. In the first iteration, the cursor should be set to 0. + If the cursor sent to the server is not 0 or is not a valid cursor, + the result are undefined. + match (Optional[str]): A pattern to match keys against, + for example, "key*" will return all keys starting with "key". + count (Optional[int]): The number of keys to return per iteration. + The number of keys returned per iteration is not guaranteed to be the same as the count argument. + the argument is used as a hint for the server to know how many "steps" it can use to retrieve the keys. + The default value is 10. + type (ObjectType): The type of object to scan for: STRING, LIST, SET, HASH, ZSET. + + Returns: + List[int, List[str]]: A tuple containing the next cursor value and a list of keys. + + Examples: + >>> result = await client.scan(0) + print(result) #[17, ['key1', 'key2', 'key3', 'key4', 'key5', 'set1', 'set2', 'set3']] + result = await client.scan(17) + print(result) #[349, ['key4', 'key5', 'set1', 'hash1', 'zset1', 'list1', 'list2', + 'list3', 'zset2', 'zset3', 'zset4', 'zset5', 'zset6']] + result = await client.scan(349) + print(result) #[0, ['key6', 'key7']] + + >>> result = await client.scan(17, match="key*", count=2) + print(result) #[6, ['key4', 'key5']] + + >>> result = await client.scan(0, type=ObjectType.Set) + print(result) #[362, ['set1', 'set2', 'set3']] + """ + args = [str(cursor)] + if match: + args.extend(["MATCH", match]) + if count: + args.extend(["COUNT", str(count)]) + if type: + args.extend(["TYPE", type.value]) + response = await self._execute_command(RequestType.Scan, args) + casted_response = cast(List[Union[int, List[str]]], response) + str_cursor = cast(str, casted_response[0]) + keys = cast(List[str], casted_response[1]) + return [int(str_cursor), keys] diff --git a/python/python/glide/constants.py b/python/python/glide/constants.py index f78398895a..0a800e1b95 100644 --- a/python/python/glide/constants.py +++ b/python/python/glide/constants.py @@ -1,6 +1,6 @@ # Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 -from typing import Dict, List, Literal, Mapping, Optional, Set, TypeVar, Union +from typing import Dict, List, Literal, Mapping, Optional, Set, Tuple, TypeVar, Union from glide.protobuf.connection_request_pb2 import ConnectionRequest from glide.protobuf.redis_request_pb2 import RedisRequest @@ -22,6 +22,7 @@ float, Set[T], List[T], + List[Union[str, List[str]]], ] TRequest = Union[RedisRequest, ConnectionRequest] # When routing to a single node, response will be T diff --git a/python/python/glide/glide.pyi b/python/python/glide/glide.pyi index fde1ac0d99..c3cbb29e20 100644 --- a/python/python/glide/glide.pyi +++ b/python/python/glide/glide.pyi @@ -21,6 +21,12 @@ class Script: def get_hash(self) -> str: ... def __del__(self) -> None: ... +class ClusterScanCursor: + def __init__(self, cursor: Optional[str]) -> None: ... + def get_cursor(self) -> str: ... + def is_finished(self) -> bool: ... + def __del__(self) -> None: ... + def start_socket_listener_external(init_callback: Callable) -> None: ... def value_from_pointer(pointer: int) -> TResult: ... def create_leaked_value(message: str) -> int: ... diff --git a/python/python/glide/glide_client.py b/python/python/glide/glide_client.py index 2840caf9a6..2901224f04 100644 --- a/python/python/glide/glide_client.py +++ b/python/python/glide/glide_client.py @@ -7,6 +7,7 @@ import async_timeout from glide.async_commands.cluster_commands import ClusterCommands +from glide.async_commands.command_args import ObjectType from glide.async_commands.core import CoreCommands from glide.async_commands.standalone_commands import StandaloneCommands from glide.config import BaseClientConfiguration @@ -31,6 +32,7 @@ from .glide import ( DEFAULT_TIMEOUT_IN_MILLISECONDS, MAX_REQUEST_ARGS_LEN, + ClusterScanCursor, create_leaked_bytes_vec, start_socket_listener_external, value_from_pointer, @@ -516,6 +518,31 @@ class GlideClusterClient(BaseClient, ClusterCommands): https://github.com/aws/babushka/wiki/Python-wrapper#redis-cluster """ + async def _cluster_scan( + self, + cursor: ClusterScanCursor, + match: Optional[str] = None, + count: Optional[int] = None, + type: Optional[ObjectType] = None, + ) -> List[Union[str, List[str]]]: + if self._is_closed: + raise ClosingError( + "Unable to execute requests; the client is closed. Please create a new client." + ) + request = RedisRequest() + request.callback_idx = self._get_callback_index() + # Take out the hash string from the wrapping object + cursor_str = cursor.get_cursor() + if cursor_str is not None: + request.cluster_scan.cursor = cursor_str + if match is not None: + request.cluster_scan.match_pattern = match + if count is not None: + request.cluster_scan.count = count + if type is not None: + request.cluster_scan.object_type = type.value + return await self._write_request_await_response(request) + def _get_protobuf_conn_request(self) -> ConnectionRequest: return self.config._create_a_protobuf_conn_request(cluster_mode=True) diff --git a/python/python/tests/test_async_client.py b/python/python/tests/test_async_client.py index b5134b2248..a038e3dbd5 100644 --- a/python/python/tests/test_async_client.py +++ b/python/python/tests/test_async_client.py @@ -11,7 +11,7 @@ from typing import Any, Dict, List, Union, cast import pytest -from glide import ClosingError, RequestError, Script +from glide import ClosingError, ClusterScanCursor, RequestError, Script from glide.async_commands.bitmap import ( BitFieldGet, BitFieldIncrBy, @@ -26,7 +26,7 @@ SignedEncoding, UnsignedEncoding, ) -from glide.async_commands.command_args import Limit, ListDirection, OrderBy +from glide.async_commands.command_args import Limit, ListDirection, ObjectType, OrderBy from glide.async_commands.core import ( ConditionalChange, ExpireOptions, @@ -6487,6 +6487,350 @@ async def test_lolwut(self, redis_client: TGlideClient): result = await redis_client.lolwut(2, [10, 20], RandomNode()) assert "Redis ver. " in node_result + # Cluster scan tests + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_simple_cluster_scan(self, redis_client: RedisClusterClient): + expected_keys = [f"key:{i}" for i in range(100)] + for key in expected_keys: + await redis_client.set(key, "value") + cursor = ClusterScanCursor(None) + keys = [] + while not cursor.is_finished(): + result = await redis_client.scan(cursor) + cursor = result[0] + keys.extend(result[1]) + + assert set(expected_keys) == set(keys) + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_cluster_scan_with_object_type_and_pattern( + self, redis_client: RedisClusterClient + ): + expected_keys = [f"key:{i}" for i in range(100)] + for key in expected_keys: + await redis_client.set(key, "value") + unexpected_type_keys = [f"key:{i}" for i in range(100, 200)] + for key in unexpected_type_keys: + await redis_client.sadd(key, ["value"]) + unexpected_pattern_keys = [f"{i}" for i in range(200, 300)] + for key in unexpected_pattern_keys: + await redis_client.set(key, "value") + keys = [] + cursor = ClusterScanCursor(None) + while not cursor.is_finished(): + result = await redis_client.scan( + cursor, match="key:*", type=ObjectType.STRING + ) + cursor = result[0] + keys.extend(result[1]) + + assert set(expected_keys) == set(keys) + assert not set(unexpected_type_keys).intersection(set(keys)) + assert not set(unexpected_pattern_keys).intersection(set(keys)) + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_cluster_scan_with_count(self, redis_client: RedisClusterClient): + expected_keys = [f"key:{i}" for i in range(1000)] + for key in expected_keys: + await redis_client.set(key, "value") + cursor = ClusterScanCursor(None) + keys = [] + succefull_compared_scans = 0 + while not cursor.is_finished(): + result_of_1 = await redis_client.scan(cursor, count=1) + cursor = result_of_1[0] + keys_of_1 = result_of_1[1] + keys.extend(keys_of_1) + if cursor.is_finished(): + break + result_of_100 = await redis_client.scan(cursor, count=100) + cursor = result_of_100[0] + keys_of_100 = result_of_100[1] + keys.extend(keys_of_100) + if keys_of_100 > keys_of_1: + succefull_compared_scans += 1 + + assert set(expected_keys) == set(keys) + assert succefull_compared_scans > 0 + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_cluster_scan_with_match(self, redis_client: RedisClusterClient): + unexpected_keys = [f"{i}" for i in range(100)] + for key in unexpected_keys: + await redis_client.set(key, "value") + expected_keys = [f"key:{i}" for i in range(100, 200)] + for key in expected_keys: + await redis_client.set(key, "value") + cursor = ClusterScanCursor(None) + keys = [] + while not cursor.is_finished(): + result = await redis_client.scan(cursor, match="key:*") + cursor = result[0] + keys.extend(result[1]) + assert set(expected_keys) == set(keys) + assert not set(unexpected_keys).intersection(set(keys)) + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + # We test whther the cursor is cleaned up after it is deleted, which we expect to happen when th GC is called + async def test_cluster_scan_cleaning_cursor(self, redis_client: RedisClusterClient): + keys = [f"key:{i}" for i in range(100)] + for key in keys: + await redis_client.set(key, "value") + cursor = ClusterScanCursor(None) + result = await redis_client.scan(cursor) + cursor = result[0] + cursor_string = str(cursor) + cursor.__del__() + new_cursor_with_same_id = ClusterScanCursor(cursor_string) + with pytest.raises(RequestError) as e_info: + await redis_client.scan(new_cursor_with_same_id) + assert "Invalid scan_state_cursor hash" in str(e_info.value) + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_cluster_scan_all_types(self, redis_client: RedisClusterClient): + # We test that the scan command work for all types of keys + string_keys = [f"key:{i}" for i in range(100)] + for key in string_keys: + await redis_client.set(key, "value") + + set_keys = [f"key:{i}" for i in range(100, 200)] + for key in set_keys: + await redis_client.sadd(key, ["value"]) + + hash_keys = [f"key:{i}" for i in range(200, 300)] + for key in hash_keys: + await redis_client.hset(key, {"field": "value"}) + + list_keys = [f"key:{i}" for i in range(300, 400)] + for key in list_keys: + await redis_client.lpush(key, ["value"]) + + zset_keys = [f"key:{i}" for i in range(400, 500)] + for key in zset_keys: + await redis_client.zadd(key, {"value": 1}) + + cursor = ClusterScanCursor(None) + keys = [] + while not cursor.is_finished(): + result = await redis_client.scan(cursor, type=ObjectType.STRING) + cursor = result[0] + keys.extend(result[1]) + assert set(string_keys) == set(keys) + assert not set(set_keys).intersection(set(keys)) + assert not set(hash_keys).intersection(set(keys)) + assert not set(list_keys).intersection(set(keys)) + assert not set(zset_keys).intersection(set(keys)) + + cursor = ClusterScanCursor(None) + keys = [] + while not cursor.is_finished(): + result = await redis_client.scan(cursor, type=ObjectType.SET) + cursor = result[0] + keys.extend(result[1]) + assert set(set_keys) == set(keys) + assert not set(string_keys).intersection(set(keys)) + assert not set(hash_keys).intersection(set(keys)) + assert not set(list_keys).intersection(set(keys)) + assert not set(zset_keys).intersection(set(keys)) + + cursor = ClusterScanCursor(None) + keys = [] + while not cursor.is_finished(): + result = await redis_client.scan(cursor, type=ObjectType.HASH) + cursor = result[0] + keys.extend(result[1]) + assert set(hash_keys) == set(keys) + assert not set(string_keys).intersection(set(keys)) + assert not set(set_keys).intersection(set(keys)) + assert not set(list_keys).intersection(set(keys)) + assert not set(zset_keys).intersection(set(keys)) + + cursor = ClusterScanCursor(None) + keys = [] + while not cursor.is_finished(): + result = await redis_client.scan(cursor, type=ObjectType.LIST) + cursor = result[0] + keys.extend(result[1]) + assert set(list_keys) == set(keys) + assert not set(string_keys).intersection(set(keys)) + assert not set(set_keys).intersection(set(keys)) + assert not set(hash_keys).intersection(set(keys)) + assert not set(zset_keys).intersection(set(keys)) + + # Standalone scan tests + @pytest.mark.parametrize("cluster_mode", [False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_simple_scan(self, redis_client: RedisClient): + expected_keys = [f"key:{i}" for i in range(100)] + for key in expected_keys: + await redis_client.set(key, "value") + keys = [] + cursor = 0 + while True: + result = await redis_client.scan(cursor) + cursor = result[0] + keys.extend(result[1]) + if cursor == 0: + break + assert set(expected_keys) == set(keys) + + @pytest.mark.parametrize("cluster_mode", [False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_scan_with_object_type_and_pattern(self, redis_client: RedisClient): + expected_keys = [f"key:{i}" for i in range(100)] + for key in expected_keys: + await redis_client.set(key, "value") + unexpected_type_keys = [f"key:{i}" for i in range(100, 200)] + for key in unexpected_type_keys: + await redis_client.sadd(key, ["value"]) + unexpected_pattern_keys = [f"{i}" for i in range(200, 300)] + for key in unexpected_pattern_keys: + await redis_client.set(key, "value") + keys = [] + cursor = 0 + while True: + result = await redis_client.scan( + cursor, match="key:*", type=ObjectType.STRING + ) + cursor = result[0] + keys.extend(result[1]) + if cursor == 0: + break + assert set(expected_keys) == set(keys) + assert not set(unexpected_type_keys).intersection(set(keys)) + assert not set(unexpected_pattern_keys).intersection(set(keys)) + + @pytest.mark.parametrize("cluster_mode", [False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_scan_with_count(self, redis_client: RedisClient): + expected_keys = [f"key:{i}" for i in range(1000)] + for key in expected_keys: + await redis_client.set(key, "value") + cursor = 0 + keys = [] + succefull_compared_scans = 0 + while True: + result_of_1 = await redis_client.scan(cursor, count=1) + cursor = result_of_1[0] + keys_of_1 = result_of_1[1] + keys.extend(keys_of_1) + result_of_100 = await redis_client.scan(cursor, count=100) + cursor = result_of_100[0] + keys_of_100 = result_of_100[1] + keys.extend(keys_of_100) + if keys_of_100 > keys_of_1: + succefull_compared_scans += 1 + if cursor == 0: + break + assert set(expected_keys) == set(keys) + assert succefull_compared_scans > 0 + + @pytest.mark.parametrize("cluster_mode", [False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_cluster_with_match(self, redis_client: RedisClient): + unexpected_keys = [f"{i}" for i in range(100)] + for key in unexpected_keys: + await redis_client.set(key, "value") + expected_keys = [f"key:{i}" for i in range(100, 200)] + for key in expected_keys: + await redis_client.set(key, "value") + cursor = 0 + keys = [] + while True: + result = await redis_client.scan(cursor, match="key:*") + cursor = result[0] + keys.extend(result[1]) + if cursor == 0: + break + assert set(expected_keys) == set(keys) + assert not set(unexpected_keys).intersection(set(keys)) + + @pytest.mark.parametrize("cluster_mode", [False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_scan_all_types(self, redis_client: RedisClient): + # We test that the scan command work for all types of keys + string_keys = [f"key:{i}" for i in range(100)] + for key in string_keys: + await redis_client.set(key, "value") + + set_keys = [f"key:{i}" for i in range(100, 200)] + for key in set_keys: + await redis_client.sadd(key, ["value"]) + + hash_keys = [f"key:{i}" for i in range(200, 300)] + for key in hash_keys: + await redis_client.hset(key, {"field": "value"}) + + list_keys = [f"key:{i}" for i in range(300, 400)] + for key in list_keys: + await redis_client.lpush(key, ["value"]) + + zset_keys = [f"key:{i}" for i in range(400, 500)] + for key in zset_keys: + await redis_client.zadd(key, {"value": 1}) + + cursor = 0 + keys = [] + while True: + result = await redis_client.scan(cursor, type=ObjectType.STRING) + cursor = result[0] + keys.extend(result[1]) + if cursor == 0: + break + assert set(string_keys) == set(keys) + assert not set(set_keys).intersection(set(keys)) + assert not set(hash_keys).intersection(set(keys)) + assert not set(list_keys).intersection(set(keys)) + assert not set(zset_keys).intersection(set(keys)) + + cursor = 0 + keys = [] + while True: + result = await redis_client.scan(cursor, type=ObjectType.SET) + cursor = result[0] + keys.extend(result[1]) + if cursor == 0: + break + assert set(set_keys) == set(keys) + assert not set(string_keys).intersection(set(keys)) + assert not set(hash_keys).intersection(set(keys)) + assert not set(list_keys).intersection(set(keys)) + assert not set(zset_keys).intersection(set(keys)) + + cursor = 0 + keys = [] + while True: + result = await redis_client.scan(cursor, type=ObjectType.HASH) + cursor = result[0] + keys.extend(result[1]) + if cursor == 0: + break + assert set(hash_keys) == set(keys) + assert not set(string_keys).intersection(set(keys)) + assert not set(set_keys).intersection(set(keys)) + assert not set(list_keys).intersection(set(keys)) + assert not set(zset_keys).intersection(set(keys)) + + cursor = 0 + keys = [] + while True: + result = await redis_client.scan(cursor, type=ObjectType.LIST) + cursor = result[0] + keys.extend(result[1]) + if cursor == 0: + break + assert set(list_keys) == set(keys) + assert not set(string_keys).intersection(set(keys)) + assert not set(set_keys).intersection(set(keys)) + assert not set(hash_keys).intersection(set(keys)) + assert not set(zset_keys).intersection(set(keys)) + class TestMultiKeyCommandCrossSlot: @pytest.mark.parametrize("cluster_mode", [True]) diff --git a/python/src/lib.rs b/python/src/lib.rs index 143d706f99..255c9ec8cb 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -32,6 +32,41 @@ impl Level { } } +/// This struct is used to keep track of the cursor of a cluster scan. +/// We want to avoid passing the cursor between layers of the application, +/// So we keep the state in the container and only pass the hash of the cursor. +/// The cursor is stored in the container and can be retrieved using the hash. +/// The cursor is removed from the container when the object is deleted (dropped). +#[pyclass] +pub struct ClusterScanCursor { + cursor: String, +} + +#[pymethods] +impl ClusterScanCursor { + #[new] + fn new(new_cursor: Option) -> Self { + match new_cursor { + Some(cursor) => ClusterScanCursor { cursor }, + None => ClusterScanCursor { + cursor: String::new(), + }, + } + } + + fn get_cursor(&self) -> String { + self.cursor.clone() + } + + fn __del__(&mut self) { + glide_core::cluster_scan_container::remove_scan_state_cursor(self.cursor.clone()); + } + + fn is_finished(&self) -> bool { + self.cursor == "finished" + } +} + #[pyclass] pub struct Script { hash: String, @@ -59,6 +94,7 @@ impl Script { fn glide(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::